import asyncio
import datetime
import os
import re
import time
from contextlib import asynccontextmanager
from typing import Optional, Union

from loguru import logger
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import AsyncAdaptedQueuePool, NullPool
from sqlalchemy.sql.expression import TextClause

from cashu.core.settings import settings

POSTGRES = "POSTGRES"
COCKROACH = "COCKROACH"
SQLITE = "SQLITE"

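# Backend selection: Database.__init__ below derives the backend from db_location:
# a connection URL selects POSTGRES or COCKROACH, a plain directory path selects SQLITE.
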
class Compat:
    type: Optional[str] = "<inherited>"
    schema: Optional[str] = "<inherited>"

    def interval_seconds(self, seconds: int) -> str:
        if self.type in {POSTGRES, COCKROACH}:
            return f"interval '{seconds} seconds'"
        elif self.type == SQLITE:
            return f"{seconds}"
        return "<nothing>"

    @property
    def timestamp_now(self) -> str:
        if self.type in {POSTGRES, COCKROACH}:
            return "now()"
        elif self.type == SQLITE:
            # return "(strftime('%s', 'now'))"
            return str(int(time.time()))
        return "<nothing>"

    @property
    def serial_primary_key(self) -> str:
        if self.type in {POSTGRES, COCKROACH}:
            return "SERIAL PRIMARY KEY"
        elif self.type == SQLITE:
            return "INTEGER PRIMARY KEY AUTOINCREMENT"
        return "<nothing>"

    @property
    def references_schema(self) -> str:
        if self.type in {POSTGRES, COCKROACH}:
            return f"{self.schema}."
        elif self.type == SQLITE:
            return ""
        return "<nothing>"

    @property
    def big_int(self) -> str:
        if self.type in {POSTGRES}:
            return "BIGINT"
        return "INT"

    def table_with_schema(self, table: str):
        return f"{self.references_schema if self.schema else ''}{table}"

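# Illustration (not part of the original code; "proofs" and "myschema" are
# hypothetical): the Compat helpers emit dialect-specific SQL fragments, e.g.
# for a POSTGRES-typed instance `db`:
#   db.big_int                     -> "BIGINT"
#   db.serial_primary_key          -> "SERIAL PRIMARY KEY"
#   db.table_with_schema("proofs") -> "myschema.proofs" ("proofs" if no schema is set)
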
# https://docs.sqlalchemy.org/en/14/core/connections.html#sqlalchemy.engine.CursorResult
class Connection(Compat):
    def __init__(self, conn: AsyncSession, txn, typ, name, schema):
        self.conn = conn
        self.txn = txn
        self.type = typ
        self.name = name
        self.schema = schema

    def rewrite_query(self, query) -> TextClause:
        if self.type in {POSTGRES, COCKROACH}:
            query = query.replace("%", "%%")
            query = query.replace("?", "%s")
        return text(query)

    async def fetchall(self, query: str, values: dict = {}):
        result = await self.conn.execute(self.rewrite_query(query), values)
        return [
            r._mapping for r in result.all()
        ]  # will return [] if result list is empty

    async def fetchone(self, query: str, values: dict = {}):
        result = await self.conn.execute(self.rewrite_query(query), values)
        r = result.fetchone()
        return r._mapping if r is not None else None

    async def execute(self, query: str, values: dict = {}):
        return await self.conn.execute(self.rewrite_query(query), values)

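# Example (illustrative, not part of the original code): on POSTGRES/COCKROACH,
# rewrite_query("SELECT * FROM proofs WHERE id = ?") yields
# text("SELECT * FROM proofs WHERE id = %s"), and a literal "%" is escaped as
# "%%"; on SQLITE the query string is wrapped in text() unchanged.
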
class Database(Compat):
    _connection: Optional[AsyncSession] = None

    def __init__(self, db_name: str, db_location: str):
        self.name = db_name
        self.db_location = db_location
        self.db_location_is_url = "://" in self.db_location
        if self.db_location_is_url:
            # raise Exception("Remote databases not supported. Use SQLite.")
            database_uri = self.db_location

            if database_uri.startswith("cockroachdb://"):
                self.type = COCKROACH
            else:
                self.type = POSTGRES
                database_uri = database_uri.replace(
                    "postgres://", "postgresql+asyncpg://"
                )
                database_uri = database_uri.replace(
                    "postgresql://", "postgresql+asyncpg://"
                )
                # Disable prepared statement cache: https://docs.sqlalchemy.org/en/14/dialects/postgresql.html#prepared-statement-cache
                database_uri += "?prepared_statement_cache_size=0"
        else:
            if not os.path.exists(self.db_location):
                logger.info(f"Creating database directory: {self.db_location}")
                os.makedirs(self.db_location)
            self.path = os.path.join(self.db_location, f"{self.name}.sqlite3")
            database_uri = f"sqlite+aiosqlite:///{self.path}?check_same_thread=false"
            self.type = SQLITE

        self.schema = self.name
        if self.name.startswith("ext_"):
            self.schema = self.name[4:]
        else:
            self.schema = None

        kwargs = {}
        if not settings.db_connection_pool:
            kwargs["poolclass"] = NullPool
        elif self.type == POSTGRES:
            kwargs["poolclass"] = AsyncAdaptedQueuePool  # type: ignore[assignment]
            kwargs["pool_size"] = 50  # type: ignore[assignment]
            kwargs["max_overflow"] = 100  # type: ignore[assignment]

        self.engine = create_async_engine(database_uri, **kwargs)
        self.async_session = sessionmaker(
            self.engine,  # type: ignore
            expire_on_commit=False,
            class_=AsyncSession,  # type: ignore
        )

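    # Construction sketch (illustrative; paths and credentials are placeholders):
    #   Database("mint", "data/mint")
    #       -> SQLite file at data/mint/mint.sqlite3 (aiosqlite driver)
    #   Database("mint", "postgres://user:pass@host:5432/cashu")
    #       -> Postgres via asyncpg, with the prepared statement cache disabled
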
    @asynccontextmanager
    async def get_connection(
        self,
        conn: Optional[Connection] = None,
        lock_table: Optional[str] = None,
        lock_select_statement: Optional[str] = None,
        lock_timeout: Optional[float] = None,
    ):
        """Either yield the existing database connection (passthrough) or create a new one.

        Args:
            conn (Optional[Connection], optional): Connection object. Defaults to None.
            lock_table (Optional[str], optional): Table to lock. Defaults to None.
            lock_select_statement (Optional[str], optional): Lock select statement. Defaults to None.
            lock_timeout (Optional[float], optional): Lock timeout. Defaults to None.

        Yields:
            Connection: Connection object.
        """
        if conn is not None:
            # Yield the existing connection
            logger.trace("Reusing existing connection")
            yield conn
        else:
            logger.trace("get_connection: Creating new connection")
            async with self.connect(
                lock_table, lock_select_statement, lock_timeout
            ) as new_conn:
                yield new_conn

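    # Usage sketch (illustrative; table and column names are hypothetical):
    #
    #   async def get_total(db: Database, conn: Optional[Connection] = None):
    #       async with db.get_connection(conn) as conn:
    #           row = await conn.fetchone("SELECT SUM(amount) AS total FROM proofs")
    #           return row["total"] if row else 0
    #
    # Passing an existing `conn` joins the caller's transaction; passing None
    # opens a fresh connection and transaction for the duration of the block.
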
    @asynccontextmanager
    async def connect(
        self,
        lock_table: Optional[str] = None,
        lock_select_statement: Optional[str] = None,
        lock_timeout: Optional[float] = None,
    ):
        async def _handle_lock_retry(retry_delay, timeout, start_time) -> float:
            await asyncio.sleep(retry_delay)
            retry_delay = min(retry_delay * 2, timeout - (time.time() - start_time))
            return retry_delay

        def _is_lock_exception(e):
            if "database is locked" in str(e) or "could not obtain lock" in str(e):
                logger.trace(f"Lock exception: {e}")
                return True
            return False

        timeout = lock_timeout or 5  # default to 5 seconds
        start_time = time.time()
        retry_delay = 0.1
        random_int = int(time.time() * 1000)
        trial = 0

        while time.time() - start_time < timeout:
            trial += 1
            session: AsyncSession = self.async_session()  # type: ignore
            try:
                logger.trace(f"Connecting to database trial: {trial} ({random_int})")
                async with session.begin() as txn:  # type: ignore
                    logger.trace("Connected to database. Starting transaction")
                    wconn = Connection(session, txn, self.type, self.name, self.schema)
                    if lock_table:
                        await self.acquire_lock(
                            wconn, lock_table, lock_select_statement
                        )
                    logger.trace(
                        f"> Yielding connection. Lock: {lock_table} - trial {trial} ({random_int})"
                    )
                    yield wconn
                    logger.trace(
                        f"< Connection yielded. Unlock: {lock_table} - trial {trial} ({random_int})"
                    )
                    return
            except Exception as e:
                if _is_lock_exception(e):
                    retry_delay = await _handle_lock_retry(
                        retry_delay, timeout, start_time
                    )
                else:
                    logger.error(f"Error in session trial: {trial} ({random_int}): {e}")
                    raise
            finally:
                logger.trace(f"Closing session trial: {trial} ({random_int})")
                await session.close()

        raise Exception(
            f"failed to acquire database lock on {lock_table} after {timeout}s and {trial} trials ({random_int})"
        )

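    # Backoff illustration (derived from connect() above): with the default 5 s
    # timeout, retry delays after lock errors grow roughly 0.1 -> 0.2 -> 0.4 ->
    # 0.8 -> 1.6 s, each capped at the time remaining before the timeout expires.
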
    async def acquire_lock(
        self,
        wconn: Connection,
        lock_table: str,
        lock_select_statement: Optional[str] = None,
    ):
        """Acquire a lock on a table or a row in a table.

        Args:
            wconn (Connection): Connection object.
            lock_table (str): Table to lock.
            lock_select_statement (Optional[str], optional): Select statement that
                identifies the row to lock. Defaults to None.

        Raises:
            Exception: If the lock cannot be acquired.
        """
        if lock_select_statement:
            assert (
                len(re.findall(r"^[^=]+='[^']+'$", lock_select_statement)) == 1
            ), "lock_select_statement must have exactly one {column}='{value}' pattern."
        try:
            logger.trace(
                f"Acquiring lock on {lock_table} with statement {self.lock_table(lock_table, lock_select_statement)}"
            )
            await wconn.execute(self.lock_table(lock_table, lock_select_statement))
            logger.trace(f"Success: Acquired lock on {lock_table}")
            return
        except Exception as e:
            if (
                (
                    self.type == POSTGRES
                    and "could not obtain lock on relation" in str(e)
                )
                or (self.type == COCKROACH and "already locked" in str(e))
                or (self.type == SQLITE and "database is locked" in str(e))
            ):
                logger.trace(f"Table {lock_table} is already locked: {e}")
            else:
                logger.trace(f"Failed to acquire lock on {lock_table}: {e}")

            raise e

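    # Example (illustrative; column and value are hypothetical): a valid
    # lock_select_statement is a single {column}='{value}' pattern such as
    # "quote='abc123'", which on Postgres becomes a row-level lock via
    # SELECT ... FOR UPDATE NOWAIT (see lock_table() below).
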
    async def fetchall(self, query: str, values: dict = {}) -> list:
        async with self.connect() as conn:
            result = await conn.execute(query, values)
            return [r._mapping for r in result.all()]

    async def fetchone(self, query: str, values: dict = {}):
        async with self.connect() as conn:
            result = await conn.execute(query, values)
            r = result.fetchone()
            return r._mapping if r is not None else None

    async def execute(self, query: str, values: dict = {}):
        async with self.connect() as conn:
            return await conn.execute(query, values)

    @asynccontextmanager
    async def reuse_conn(self, conn: Connection):
        yield conn

    def lock_table(
        self,
        table: str,
        lock_select_statement: Optional[str] = None,
    ) -> str:
        # with postgres, we can lock a row with a SELECT statement with FOR UPDATE NOWAIT
        if lock_select_statement:
            if self.type == POSTGRES:
                return f"SELECT 1 FROM {self.table_with_schema(table)} WHERE {lock_select_statement} FOR UPDATE NOWAIT;"

        if self.type == POSTGRES:
            return (
                f"LOCK TABLE {self.table_with_schema(table)} IN EXCLUSIVE MODE NOWAIT;"
            )
        elif self.type == COCKROACH:
            return f"LOCK TABLE {table};"
        elif self.type == SQLITE:
            return "BEGIN EXCLUSIVE TRANSACTION;"
        return "<nothing>"

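    # For reference (illustrative; "proofs" is a hypothetical table), lock_table("proofs") yields:
    #   POSTGRES:  "LOCK TABLE proofs IN EXCLUSIVE MODE NOWAIT;" (schema-qualified if a schema is set)
    #   COCKROACH: "LOCK TABLE proofs;"
    #   SQLITE:    "BEGIN EXCLUSIVE TRANSACTION;"
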
    def timestamp_from_seconds(
        self, seconds: Union[int, float, None]
    ) -> Union[str, None]:
        if seconds is None:
            return None
        seconds = int(seconds)
        if self.type in {POSTGRES, COCKROACH}:
            return datetime.datetime.fromtimestamp(seconds).strftime(
                "%Y-%m-%d %H:%M:%S"
            )
        elif self.type == SQLITE:
            return str(seconds)
        return None

    def timestamp_now_str(self) -> str:
        timestamp = self.timestamp_from_seconds(time.time())
        if timestamp is None:
            raise Exception("Timestamp is None")
        return timestamp

    def to_timestamp(
        self, timestamp: Union[str, datetime.datetime]
    ) -> Union[str, datetime.datetime]:
        if not timestamp:
            timestamp = self.timestamp_now_str()
        if self.type in {POSTGRES, COCKROACH}:
            # return datetime.datetime
            if isinstance(timestamp, datetime.datetime):
                return timestamp
            elif isinstance(timestamp, str):
                return datetime.datetime.strptime(timestamp, "%Y-%m-%d %H:%M:%S")
        elif self.type == SQLITE:
            # return str
            if isinstance(timestamp, datetime.datetime):
                return timestamp.strftime("%Y-%m-%d %H:%M:%S")
            elif isinstance(timestamp, str):
                return timestamp
        return "<nothing>"
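
# Note (derived from the methods above): timestamp_from_seconds() returns a
# "YYYY-MM-DD HH:MM:SS" string (local time) for POSTGRES/COCKROACH and the raw
# integer seconds as a string for SQLITE; to_timestamp() converts the other way,
# returning a datetime.datetime for POSTGRES/COCKROACH and a string for SQLITE.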