Files
nutshell/cashu/mint/auth/server.py
callebtc fc0e3fe663 Mint: watchdog balance log and killswitch (#705)
* wip store balance

* store balances in watchdog worker

* move mint_auth_database setting

* auth db

* balances returned as Amount (instead of int)

* add test for balance change on invoice receive

* fix 1 test

* cancel tasks on shutdown

* watchdog can now abort

* remove wallet api server

* fix lndgrpc

* fix lnbits balance

* disable watchdog

* balance lnbits msat

* test db watcher with its own database connection

* init superclass only once

* wip: log balance in keysets table

* check max balance using new keyset balance

* fix test

* fix another test

* store fees in keysets

* format

* cleanup

* shorter

* add keyset migration to auth server

* fix fakewallet

* fix db tests

* fix postgres problems during migration 26 (mint)

* fix cln

* ledger

* working with pending

* super fast watchdog, errors

* test new pipeline

* delete walletapi

* delete unneeded files

* revert workflows
2025-05-11 20:29:13 +02:00

241 lines
7.6 KiB
Python

from contextlib import asynccontextmanager
from typing import Any, List, Optional
import httpx
import jwt
from loguru import logger
from ...core.base import AuthProof
from ...core.db import Database
from ...core.errors import (
BlindAuthAmountExceededError,
BlindAuthFailedError,
BlindAuthRateLimitExceededError,
ClearAuthFailedError,
)
from ...core.models import BlindedMessage, BlindedSignature
from ...core.settings import settings
from ..crud import LedgerCrudSqlite
from ..ledger import Ledger
from ..limit import assert_limit
from .base import User
from .crud import AuthLedgerCrud, AuthLedgerCrudSqlite
class AuthLedger(Ledger):
    """Ledger that authenticates users and issues/verifies blind auth tokens.

    Clear-auth tokens are JWTs checked against an OpenID Connect provider whose
    discovery document is read from ``settings.mint_auth_oicd_discovery_url``.
    ``init_auth`` must be awaited before any of the verification methods are
    used, because it populates ``issuer``, ``jwks_client`` and ``auth_crud``.
    """

    auth_crud: AuthLedgerCrud
    jwks_url: str
    jwks_client: jwt.PyJWKClient
    issuer: str
    oicd_discovery_json: dict

    def __init__(
        self,
        db: Database,
        seed: str,
        seed_decryption_key: Optional[str] = None,
        derivation_path="",
        amounts: Optional[List[int]] = None,
        # NOTE: this default instance is created once, at definition time, and
        # is shared by every AuthLedger constructed without an explicit `crud`.
        crud=LedgerCrudSqlite(),
    ):
        """Create the auth ledger; it is always constructed without backends.

        Args:
            db (Database): Database passed through to ``Ledger``.
            seed (str): Seed passed through to ``Ledger``.
            seed_decryption_key (Optional[str]): Passed through to ``Ledger``.
            derivation_path (str): Passed through to ``Ledger``.
            amounts (Optional[List[int]]): Passed through to ``Ledger``.
            crud: CRUD implementation passed through to ``Ledger``.
        """
        super().__init__(
            db=db,
            seed=seed,
            backends=None,
            seed_decryption_key=seed_decryption_key,
            derivation_path=derivation_path,
            crud=crud,
            amounts=amounts,
        )
        # An unset URL is stored as "" and rejected later in init_auth().
        self.oicd_discovery_url = settings.mint_auth_oicd_discovery_url or ""

    async def init_auth(self):
        """Load the OpenID Connect configuration and set up JWT verification.

        Raises:
            Exception: No OpenID Connect discovery URL is configured.
            KeyError: The discovery document lacks ``jwks_uri`` or ``issuer``.
        """
        if not self.oicd_discovery_url:
            raise Exception("Missing OpenID Connect discovery URL.")
        logger.info(f"Initializing OpenID Connect: {self.oicd_discovery_url}")
        self.oicd_discovery_json = self._get_oicd_discovery_json()
        self.jwks_url = self.oicd_discovery_json["jwks_uri"]
        self.jwks_client = jwt.PyJWKClient(self.jwks_url)
        logger.info(f"Getting JWKS from: {self.jwks_url}")
        self.auth_crud = AuthLedgerCrudSqlite()
        self.issuer = self.oicd_discovery_json["issuer"]
        logger.info(f"Initialized OpenID Connect: {self.issuer}")

    def _get_oicd_discovery_json(self) -> dict:
        """Fetch and parse the OpenID Connect discovery document.

        NOTE: ``httpx.get`` is a synchronous call; when reached from
        ``init_auth`` it blocks the event loop for the duration of the request.

        Raises:
            httpx.HTTPStatusError: The provider answered with an error status.

        Returns:
            dict: Parsed JSON body of the discovery document.
        """
        logger.debug(
            f"Getting OpenID Connect discovery JSON from: {self.oicd_discovery_url}"
        )
        resp = httpx.get(self.oicd_discovery_url)
        resp.raise_for_status()
        return resp.json()

    def _verify_oicd_issuer(self, clear_auth_token: str) -> None:
        """Verify the issuer of the clear-auth token.

        The token is decoded WITHOUT signature verification here; this only
        compares the ``iss`` claim with the configured issuer. The signature
        is checked separately in ``_verify_decode_jwt``.

        Args:
            clear_auth_token (str): JWT token.

        Raises:
            Exception: Invalid issuer.
            KeyError: The token carries no ``iss`` claim.
        """
        decoded = jwt.decode(
            clear_auth_token,
            options={"verify_signature": False},
        )
        issuer = decoded["iss"]
        if issuer != self.issuer:
            raise Exception(f"Invalid issuer: {issuer}. Expected: {self.issuer}")

    def _verify_decode_jwt(self, clear_auth_token: str) -> Any:
        """Verify the clear-auth JWT token.

        Args:
            clear_auth_token (str): JWT token.

        Raises:
            jwt.ExpiredSignatureError: Token has expired.
            jwt.InvalidSignatureError: Invalid signature.
            jwt.InvalidTokenError: Invalid token.

        Returns:
            Any: Decoded JWT.
        """
        try:
            # Use PyJWKClient to fetch the appropriate key based on the token's header
            signing_key = self.jwks_client.get_signing_key_from_jwt(clear_auth_token)
            decoded = jwt.decode(
                clear_auth_token,
                signing_key.key,
                algorithms=["RS256", "ES256"],
                # The audience claim is deliberately not checked.
                options={"verify_aud": False},
                issuer=self.issuer,
            )
            logger.trace(f"Decoded JWT: {decoded}")
        # The two specific handlers come first so they get their own log line
        # before the generic InvalidTokenError handler can match.
        except jwt.ExpiredSignatureError:
            logger.error("Token has expired")
            raise
        except jwt.InvalidSignatureError:
            logger.error("Invalid signature")
            raise
        except jwt.InvalidTokenError:
            logger.error("Invalid token")
            raise
        return decoded

    async def _get_user(self, decoded_token: Any) -> User:
        """Get the user from the decoded token. If the user does not exist, create a new one.

        Args:
            decoded_token (Any): decoded JWT from PyJWT.decode

        Returns:
            User: User object
        """
        user_id = decoded_token["sub"]
        user = await self.auth_crud.get_user(user_id=user_id, db=self.db)
        if not user:
            logger.info(f"Creating new user: {user_id}")
            user = User(id=user_id)
            await self.auth_crud.create_user(user=user, db=self.db)
        return user

    async def verify_clear_auth(self, clear_auth_token: str) -> User:
        """Verify the clear-auth JWT token and return the user.

        Checks:
            - Token issuer matches the configured issuer.
            - Token not expired.
            - Token signature valid.
            - User is looked up and created if it does not exist yet.
            - User is within the rate limit.

        Args:
            clear_auth_token (str): JWT token.

        Raises:
            ClearAuthFailedError: Any of the token or user checks failed.
            BlindAuthRateLimitExceededError: ``assert_limit`` rejected the user.

        Returns:
            User: Authenticated user.
        """
        try:
            self._verify_oicd_issuer(clear_auth_token)
            decoded = self._verify_decode_jwt(clear_auth_token)
            user = await self._get_user(decoded)
        except Exception as e:
            # Log and chain the cause so a failed login can be diagnosed.
            logger.error(f"Clear auth error: {e}")
            raise ClearAuthFailedError() from e

        logger.info(f"User authenticated: {user.id}")
        try:
            assert_limit(user.id)
        except Exception as e:
            raise BlindAuthRateLimitExceededError() from e

        return user

    async def mint_blind_auth(
        self,
        *,
        outputs: List[BlindedMessage],
        user: User,
    ) -> List[BlindedSignature]:
        """Mints auth tokens. Returns a list of promises.

        Args:
            outputs (List[BlindedMessage]): Outputs to sign.
            user (User): Authenticated user.

        Raises:
            BlindAuthAmountExceededError: More outputs were requested than
                ``settings.mint_auth_max_blind_tokens`` allows.
            Exception: Output verification failed.

        Returns:
            List[BlindedSignature]: List of blinded signatures.
        """
        if len(outputs) > settings.mint_auth_max_blind_tokens:
            raise BlindAuthAmountExceededError(
                f"Too many outputs. You can only mint {settings.mint_auth_max_blind_tokens} tokens."
            )

        await self._verify_outputs(outputs)
        promises = await self._generate_promises(outputs)

        # update last_access timestamp of the user
        await self.auth_crud.update_user(user_id=user.id, db=self.db)

        return promises

    @asynccontextmanager
    async def verify_blind_auth(self, blind_auth_token):
        """Wrapper context that puts blind auth tokens into pending list and
        invalidates them if the wrapped call succeeds. If it fails, the blind
        auth token is not invalidated.

        Any exception raised by the wrapped call or by the invalidation is
        logged and re-raised as ``BlindAuthFailedError`` (with the original
        exception chained as its cause). The pending flag is removed in every
        case once the wrapped call has run.

        Args:
            blind_auth_token (str): Blind auth token.

        Raises:
            BlindAuthFailedError: Blind auth token validation failed, or the
                wrapped call / token invalidation raised.
        """
        try:
            proof = AuthProof.from_base64(blind_auth_token).to_proof()
            await self.verify_inputs_and_outputs(proofs=[proof])
            await self.db_write._verify_spent_proofs_and_set_pending(
                [proof], self.keysets
            )
        except Exception as e:
            logger.error(f"Blind auth error: {e}")
            raise BlindAuthFailedError() from e

        try:
            yield
            # Reached only if the wrapped call did not raise.
            await self._invalidate_proofs(proofs=[proof])
        except Exception as e:
            logger.error(f"Blind auth error: {e}")
            raise BlindAuthFailedError() from e
        finally:
            await self.db_write._unset_proofs_pending([proof], self.keysets)