import base64
import tempfile
import traceback
import time
from typing import Optional
from flask import abort, request
from google.cloud import storage
from common.rpc.ag_master import trigger_jobs
from common.rpc.ag_worker import batch_grade, ping_worker
from common.rpc.auth import get_endpoint
from common.rpc.secrets import only
from common.secrets import new_secret
from models import Assignment, Job, db
from utils import BUCKET
[docs]def create_okpy_endpoints(app):
    """Creates various RPC endpoints to interface with Okpy. See the
    following:
    - :func:`~common.rpc.ag_master.trigger_jobs`
    """
    @app.route("/api/ok/v3/grade/batch", methods=["POST"])
    def okpy_batch_grade_impl():
        data = request.json
        subm_ids = data["subm_ids"]
        assignment = data["assignment"]
        access_token = data["access_token"]
        if assignment == "test":
            return "OK"
        assignment: Optional[Assignment] = Assignment.query.get(assignment)
        if not assignment or assignment.endpoint != get_endpoint(
            course=assignment.course
        ):
            abort(404, "Unknown Assignment")
        if len(subm_ids) / assignment.batch_size > 50:
            abort(
                405,
                "Too many batches! Please set the batch_size so that there are <= 50 batches.",
            )
        job_secrets = [new_secret() for _ in subm_ids]
        queue_time = int(time.time())
        jobs = [
            Job(
                assignment_secret=assignment.assignment_secret,
                backup=backup_id,
                status="queued",
                job_secret=job_secret,
                external_job_id=new_secret(),
                access_token=access_token,
                queued_at=queue_time,
            )
            for backup_id, job_secret in zip(subm_ids, job_secrets)
        ]
        db.session.bulk_save_objects(jobs)
        db.session.commit()
        trigger_jobs(
            assignment_id=assignment.assignment_secret, jobs=job_secrets, noreply=True
        )
        return dict(jobs=[job.external_job_id for job in jobs])
    @trigger_jobs.bind(app)
    @only("ag-master")
    def trigger_jobs_impl(assignment_id, jobs):
        assignment: Assignment = Assignment.query.get(assignment_id)
        job_batches = [
            jobs[i : i + assignment.batch_size]
            for i in range(0, len(jobs), assignment.batch_size)
        ]
        bucket = storage.Client().get_bucket(BUCKET)
        blob = bucket.blob(f"zips/{assignment.endpoint}/{assignment.file}")
        with tempfile.NamedTemporaryFile() as temp:
            blob.download_to_filename(temp.name)
            with open(temp.name, "rb") as zf:
                encoded_zip = base64.b64encode(zf.read()).decode("ascii")
        for job_batch in job_batches:
            try:
                ping_worker(retries=3)
                batch_grade(
                    command=assignment.command,
                    jobs=job_batch,
                    grading_zip=encoded_zip,
                    noreply=True,
                    timeout=8,
                )
            except:
                Job.query.filter(Job.job_secret.in_(job_batch)).update(
                    {
                        Job.status: "failed",
                        Job.result: "trigger_job error\n" + traceback.format_exc(),
                    },
                    synchronize_session="fetch",
                )
                db.session.commit()
    @app.route("/results/<job_id>", methods=["GET"])
    def get_results_for(job_id):
        job = Job.query.filter_by(external_job_id=job_id).one()
        if job.status in ("finished", "failed"):
            return job.result, 200
        return "Nope!", 202
    @app.route("/results", methods=["POST"])
    def get_results_impl():
        job_ids = request.json
        jobs = Job.query.filter(Job.external_job_id.in_(job_ids)).all()
        res = {
            job.external_job_id: dict(status=job.status, result=job.result)
            for job in jobs
        }
        for job in job_ids:
            if job not in res:
                res[job] = None
        return res