Source code for common.db

import os
import sys
from contextlib import contextmanager
from os import getenv
from os.path import exists
from time import sleep
from typing import List

import sqlalchemy.engine.url

from common.rpc.secrets import get_secret

use_devdb = getenv("ENV", "DEV") in ("DEV", "TEST")
use_prod_proxy = getenv("ENV") == "DEV_ON_PROD"

is_sphinx = "sphinx" in sys.argv[0]

if use_devdb:
    database_url = "sqlite:///" + os.path.join(
        os.path.abspath(os.path.dirname(__file__)), "../app.db"
    )
elif use_prod_proxy:
    curr = os.path.abspath(os.path.curdir)
    prev = None
    while True:
        if ".git" in os.listdir(curr):
            app = os.path.basename(prev)
            break
        prev = curr
        curr = os.path.dirname(curr)
    assert (
        input(
            "Are you sure? You are about to use the PROD database on a DEV build (y/N) "
        )
        .strip()
        .lower()
        == "y"
    )
    if "cloud_sql_proxy" not in os.listdir(curr):
        print(
            "Follow the instructions at https://cloud.google.com/sql/docs/mysql/sql-proxy to install the proxy "
            "in the root directory of the repository, then try again."
        )
        exit(0)
    database_url = sqlalchemy.engine.url.URL(
        drivername="mysql+pymysql",
        username="buildserver",
        password=get_secret(secret_name="DATABASE_PW"),
        host="127.0.0.1",
        port=3307,
        database=app.replace("-", "_"),
    ).__to_string__(hide_password=False)
else:
    database_url = getenv("DATABASE_URL")

engine = sqlalchemy.create_engine(
    database_url,
    **(
        dict(
            connect_args=dict(
                ssl=dict(
                    cert="/shared/client-cert.pem",
                    key="/shared/client-key.pem",
                    ca="/shared/server-ca.pem",
                ),
            )
        )
        if exists("/shared/client-cert.pem")
        else {}
    ),
    **(
        {}
        if use_devdb
        else dict(pool_size=5, max_overflow=2, pool_timeout=30, pool_recycle=1800)
    ),
)


[docs]@contextmanager def connect_db(*, retries=3): """Create a context that uses a connection to the current database. :param retries: the number of times to try connecting to the database :type retries: int :yields: a function with parameters ``(query: str, args: List[str] = [])``, where the ``query_str`` should use ``%s`` to represent sequential arguments in the ``args_list`` :example usage: .. code-block:: python with connect_db() as db: db("INSERT INTO animals VALUES %s", ["cat"]) output = db("SELECT * FROM animals") """ if is_sphinx: def no_op(*args, **kwargs): class NoOp: fetchone = lambda *args: None fetchall = lambda *args: None return NoOp() yield no_op else: for i in range(retries): try: conn = engine.connect() break except: sleep(3) continue else: raise with conn: def db(query: str, args: List[str] = []): if use_devdb: query = query.replace("%s", "?") return conn.execute(query, args) yield db
[docs]@contextmanager def transaction_db(): """Create a context that performs a transaction on current database. The difference between this and :meth:`~common.db.connect_db` is that this method batches queries in a transaction, and only runs them once the context is abandoned. :yields: a function with parameters ``(query: str, args: List[str] = [])``, where the ``query_str`` should use ``%s`` to represent sequential arguments in the ``args_list`` :example usage: .. code-block:: python with transaction_db() as db: db("INSERT INTO animals VALUES %s", ["cat"]) output = db("SELECT * FROM animals") """ with engine.begin() as conn: def db(*args): query, *rest = args if use_devdb: query = query.replace("%s", "?") return conn.execute(query, *rest) yield db