import json
import traceback
from functools import wraps
from time import sleep
from typing import List
from urllib.error import HTTPError
import flask
import requests
from cachetools import TTLCache
from flask import Response, has_request_context, jsonify, request, stream_with_context
from common.rpc.auth_utils import get_token, refresh_token
from common.secrets import get_master_secret
# Sentinel written into a streamed response by `stream_encode` and searched for
# by `receive_stream`: it marks the end of the streamed data, and any text that
# follows it is treated as an error message.
STATUS_MARKER = "__INTERNAL_STATUS_MARKER"
# HTTP status code treated as transient in this module: a request answered with
# it is retried after a one-second pause.
GCP_INTERNAL_ERROR_CODE = 503
class Service:
    """
    Minimal holder for an RPC service's `route` decorator factory.

    The object passed in is stored unchanged on the `route` attribute.
    """

    def __init__(self, route):
        # No validation or wrapping: whatever the caller supplies is exposed as-is.
        self.route = route
def find_default_endpoints(app: str, path: str):
    """
    List the candidate URLs for calling `path` on the service `app`.

    The production URL is always the last entry. When called inside a Flask
    request whose "X-Forwarded-For-Host" header contains a "pr" label, a
    PR-specific URL (built from the first label of that host) is placed
    before it.

    :param app: service name used as the subdomain
    :param path: request path appended verbatim to each URL
    :return: list of one or two endpoint URLs, production last
    """
    prod_endpoint = f"https://{app}.cs61a.org{path}"

    # Outside of a request there is no header to inspect.
    if not has_request_context():
        return [prod_endpoint]

    forwarded_host = request.headers.get("X-Forwarded-For-Host")
    if not forwarded_host:
        return [prod_endpoint]

    labels = forwarded_host.split(".")
    if "pr" not in labels:
        return [prod_endpoint]

    # The first label of the forwarded host identifies the PR deployment.
    pr_endpoint = f"https://{labels[0]}.{app}.pr.cs61a.org{path}"
    return [pr_endpoint, prod_endpoint]
def select_endpoint(endpoints: List[str], path: str, retries: int):
    """
    Pick the first candidate endpoint whose base URL responds successfully.

    Every endpoint except the last is probed with a GET (redirects are not
    followed) against the endpoint with `path` stripped from its end. The last
    endpoint is never probed; it is returned as the fallback.

    :param endpoints: candidate URLs in priority order, each ending in `path`;
        must contain at least one entry
    :param path: the path suffix that was appended to build each endpoint
    :param retries: number of extra probes allowed per endpoint when the
        response status is `GCP_INTERNAL_ERROR_CODE`
    :return: the chosen endpoint URL
    """
    # try all the PR candidates
    for endpoint in endpoints[:-1]:
        # Strip the trailing `path`. Computed from lengths so that an empty
        # `path` leaves the endpoint intact (`endpoint[:-0]` would be "").
        base_url = endpoint[: len(endpoint) - len(path)]
        try:
            for _ in range(retries + 1):
                # check if the PR / endpoint exists
                check_exists = requests.get(base_url, allow_redirects=False)
                if check_exists.status_code == GCP_INTERNAL_ERROR_CODE:
                    # this error is not our fault, retry after a short pause
                    sleep(1)
                    continue
                check_exists.raise_for_status()
                return endpoint
            # if we exhaust all our retries, give up on this endpoint
        except (HTTPError, requests.HTTPError, requests.ConnectionError):
            # `raise_for_status` raises `requests.HTTPError`, which is unrelated
            # to `urllib.error.HTTPError`; it must be caught explicitly or a
            # 4xx/5xx from a candidate would abort the whole selection.
            continue
    # fall back to prod
    return endpoints[-1]
def stream_encode(out):
    """
    Encode an iterable of strings into the wire format of a streamed reply.

    Each item is yielded as ASCII bytes (unencodable characters are replaced).
    `STATUS_MARKER` is then yielded to terminate the data. If iterating `out`
    raised an `Exception`, the marker is followed by the exception's message as
    ASCII bytes.

    :param out: iterable of `str` chunks
    """
    failure = None
    try:
        for chunk in out:
            yield bytes(chunk, encoding="ascii", errors="replace")
    except Exception as exc:
        # Remember the failure; it is reported after the marker below.
        failure = exc

    # The marker is emitted on both the success and the failure path.
    yield STATUS_MARKER
    if failure is not None:
        yield bytes(str(failure), encoding="ascii", errors="replace")
def receive_stream(resp: requests.Response):
    """
    Decode a streamed reply produced by `stream_encode`, yielding its text.

    Data is yielded as it arrives, except that the most recent
    `len(STATUS_MARKER)` characters are always held back until it is known
    whether they belong to the marker. Once the marker has been seen nothing
    more is flushed; the remaining text is collected as the error message.

    :param resp: a `requests` response opened with `stream=True`
    :raises Exception: with the remote error message, if text follows the marker
    :raises ValueError: if the stream ends without containing the marker
    """
    marker_len = len(STATUS_MARKER)
    pending = ""
    marker_seen = False
    for chunk in resp.iter_content():
        pending += chunk.decode("ascii")
        if not marker_seen and STATUS_MARKER in pending:
            # We are now reading in the error message
            # Stop flushing the buffer
            marker_seen = True
        if not marker_seen and len(pending) > marker_len:
            yield pending[:-marker_len]
            pending = pending[-marker_len:]

    if pending.endswith(STATUS_MARKER):
        # Success: flush whatever data was still buffered ahead of the marker.
        # (Slicing the joined string, not a list of chunks, so no data is lost.)
        remaining = pending[:-marker_len]
        if remaining:
            yield remaining
        return

    marker_pos = pending.find(STATUS_MARKER)
    if marker_pos == -1:
        # The connection ended before the sender wrote its status marker.
        raise ValueError("stream ended without a status marker")

    # some error occurred: deliver the data received before it, then report it
    if marker_pos:
        yield pending[:marker_pos]
    raise Exception(pending[marker_pos + marker_len :])
def create_service(app: str, override=None, providers=None):
    """
    Build a `Service` whose `route` decorator turns functions into RPC stubs.

    :param app: dotted name; only its final component is used as the service
        name (presumably callers pass `__name__` -- TODO confirm)
    :param override: if truthy, used as the service name instead of `app`
    :param providers: optional collection of base URLs; when truthy, candidate
        endpoints are `f"{provider}{path}"` instead of the defaults from
        `find_default_endpoints`
    :return: a `Service` wrapping the `route` decorator factory
    """
    app = override or app.split(".")[-1]

    def route(path, *, streaming=False):
        # `path` is the URL path of the RPC; `streaming` selects the chunked
        # wire format (`stream_encode` / `receive_stream`) instead of JSON.
        def decorator(func):
            # Client-side stub: `func`'s own body is not called here. It only
            # supplies the metadata copied by `wraps`.
            @wraps(func)
            def wrapped(*, noreply=False, timeout=1, retries=0, **kwargs):
                assert (
                    not noreply or not retries
                ), "Cannot retry a noreply request, use streaming instead"

                if providers:
                    endpoints = [f"{provider}{path}" for provider in providers]
                else:
                    endpoints = find_default_endpoints(app, path)

                endpoint = select_endpoint(endpoints, path, retries)

                if noreply:
                    # Fire-and-forget: the response is ignored, and a read
                    # timeout after `timeout` seconds is not treated as an error.
                    try:
                        requests.post(endpoint, json=kwargs, timeout=timeout)
                    except requests.exceptions.ReadTimeout:
                        return
                else:
                    for _ in range(retries + 1):
                        resp = requests.post(endpoint, json=kwargs, stream=streaming)
                        if resp.status_code == GCP_INTERNAL_ERROR_CODE:
                            # transient failure: pause, then try again
                            sleep(1)
                            continue
                        break
                    else:
                        # we exhausted all our retries
                        # (reached only when every attempt returned the
                        # transient status, so this raises for that response)
                        resp.raise_for_status()
                    # Map the statuses produced by `handler` below back onto
                    # exceptions: 401 -> PermissionError, 500 -> Exception.
                    if resp.status_code == 401:
                        raise PermissionError(resp.text)
                    elif resp.status_code == 500:
                        raise Exception(resp.text)
                    resp.raise_for_status()
                    if streaming:
                        return receive_stream(resp)
                    else:
                        return resp.json()

            # Server-side registration: `stub.bind(flask_app)` returns a
            # decorator for the function that implements the RPC.
            # Note that this `app` parameter shadows the service-name string
            # of the enclosing scope.
            def bind(app: flask.Flask):
                # This `decorator` / `func` pair shadows the outer one; here
                # `func` is the implementation being registered.
                def decorator(func):
                    def handler():
                        # The request body is the JSON-encoded keyword arguments.
                        kwargs = request.json
                        try:
                            out = func(**kwargs)
                            if streaming:
                                return Response(stream_with_context(stream_encode(out)))
                            else:
                                return jsonify(out)
                        except PermissionError as e:
                            return "", 401
                        except Exception as e:
                            traceback.print_exc()
                            print(str(e))
                            return "", 500

                    # Registered as POST-only at `path`, using the
                    # implementation's name as the Flask endpoint name.
                    app.add_url_rule(path, func.__name__, handler, methods=["POST"])
                    # The implementation itself is returned unchanged.
                    return func

                return decorator

            wrapped.bind = bind

            return wrapped

        return decorator

    return Service(route)
def requires_master_secret(func):
    """
    Decorator that supplies `master_secret` to `func`.

    The wrapper accepts two extra keyword-only arguments:

    - `_sudo_token`: when truthy, `func` is called with `_impersonate` and
      `_sudo_token` passed through and no `master_secret` is added.
    - `_impersonate`: when truthy and no local master secret is available,
      the secret named "MASTER" is requested via `get_secret_from_server`
      using the current token, retrying once after `refresh_token()`.

    In every other case `func` receives `get_master_secret()`.

    :raises PermissionError: if both attempts to fetch the secret are refused
    """

    @wraps(func)
    def wrapped(*, _impersonate=None, _sudo_token=None, **kwargs):
        # A sudo token is simply forwarded together with the impersonation target.
        if _sudo_token:
            return func(**kwargs, _impersonate=_impersonate, _sudo_token=_sudo_token)

        # Use the local secret when there is one, or when nobody is impersonated.
        if get_master_secret() or not _impersonate:
            return func(**kwargs, master_secret=get_master_secret())

        from common.rpc.secrets import (
            get_secret_from_server,
        )  # placed here to avoid circular imports

        print(f"Attempting to impersonate {_impersonate}")

        def fetch_master_secret():
            # The token is re-read on every call so a refreshed one is picked up.
            return get_secret_from_server(
                secret_name="MASTER",
                _impersonate=_impersonate,
                _sudo_token=get_token(),
            )

        try:
            master_secret = fetch_master_secret()
        except PermissionError:
            refresh_token()
            try:  # second attempt, in case the first was just an expired token
                master_secret = fetch_master_secret()
            except PermissionError:
                raise PermissionError("You must be logged in as an admin to do that.")

        return func(**kwargs, master_secret=master_secret)

    return wrapped
def requires_access_token(func):
    """
    Decorator that supplies `access_token=get_token()` to `func`.

    If the first call (including obtaining the token) raises `PermissionError`,
    `refresh_token()` is called and `func` is invoked exactly once more; an
    error from that second call propagates to the caller.
    """

    @wraps(func)
    def wrapped(**kwargs):
        def attempt():
            # The token is read at call time so the retry sees the refreshed one.
            return func(**kwargs, access_token=get_token())

        try:
            result = attempt()
        except PermissionError:
            refresh_token()
            result = attempt()
        return result

    return wrapped
def cached(ttl: int = 1800):
    """
    Caches the return value of this RPC method for `ttl` seconds (defaults to 1800s)

    Results are keyed on the JSON encoding of the keyword arguments, so the
    decorated function must be called with JSON-serializable keyword arguments
    only. Each decorated function gets its own cache of up to 1000 entries.

    :param ttl: lifetime of a cached value, in seconds
    :return: a decorator for a keyword-only callable
    """
    # Unique sentinel so that a cached `None` (or other falsy value) is a hit.
    missing = object()

    def decorator(func):
        # One cache per decorated function: sharing it would make two functions
        # called with the same arguments return each other's results.
        cache = TTLCache(1000, ttl)

        @wraps(func)
        def wrapped(**kwargs):
            # sort_keys makes the key independent of keyword argument order
            key = json.dumps(kwargs, sort_keys=True)
            value = cache.get(key, missing)
            if value is missing:
                value = func(**kwargs)
                cache[key] = value
            # Return the local value instead of re-reading the cache: the entry
            # may already have expired or been evicted.
            return value

        return wrapped

    return decorator