Files
nutshell/cashu/mint/auth/server.py
callebtc 9fed0f0f07 Fix duplicate blank outputs during melt (#795)
* wip blank outputs

* wip: working

* store ids for promises correctly

* tests

* fix migration

* revert

* fix tests

* fix auth server

* fix last tests

* retroactively change migration, initial and m017_foreign_keys_proof_tables, remove c_b and replace with c_ (same for b_)

* fix constraint

* oops

* msg stack fix

* fix test foreign key constraint

* fix postgres tests

* foreign key constraint test

* should fix psql error

* foreign key constraint sqlite

* rename to update_blinded_message_signature

* drop outputs and change columns from melt_quotes table

* switch migration order

* reorder migrations again

* fix migration

* add tests

* fix postgres migration too

* create signed_at column postgres

* foreign key constraint promises table

* migration tool

* readme
2025-10-19 15:50:47 +02:00

242 lines
7.7 KiB
Python

from contextlib import asynccontextmanager
from typing import Any, List, Optional
import httpx
import jwt
from loguru import logger
from ...core.base import AuthProof
from ...core.db import Database
from ...core.errors import (
BlindAuthAmountExceededError,
BlindAuthFailedError,
BlindAuthRateLimitExceededError,
ClearAuthFailedError,
)
from ...core.models import BlindedMessage, BlindedSignature
from ...core.settings import settings
from ..crud import LedgerCrudSqlite
from ..ledger import Ledger
from ..limit import assert_limit
from .base import User
from .crud import AuthLedgerCrud, AuthLedgerCrudSqlite
class AuthLedger(Ledger):
    """Ledger that authenticates users and handles blind auth tokens.

    Clear-auth JWTs are verified against the OpenID Connect provider described
    by the discovery document at ``settings.mint_auth_oicd_discovery_url``.
    Authenticated users can mint blind auth tokens, which are later redeemed
    through the ``verify_blind_auth`` context manager.
    """

    # CRUD used for user records; set in ``init_auth``.
    auth_crud: AuthLedgerCrud
    # URL of the provider's JSON Web Key Set, read from the discovery document.
    jwks_url: str
    # Client that fetches signing keys from ``jwks_url``.
    jwks_client: jwt.PyJWKClient
    # Expected ``iss`` claim, read from the discovery document.
    issuer: str
    # Parsed OpenID Connect discovery document.
    oicd_discovery_json: dict
    # Discovery document URL taken from settings ("" when not configured).
    oicd_discovery_url: str

    def __init__(
        self,
        db: Database,
        seed: str,
        seed_decryption_key: Optional[str] = None,
        derivation_path: str = "",
        amounts: Optional[List[int]] = None,
        # NOTE(review): this default instance is created once at definition
        # time and shared by every AuthLedger that does not pass ``crud``;
        # fine as long as LedgerCrudSqlite is stateless -- TODO confirm.
        crud=LedgerCrudSqlite(),
    ):
        """Create the auth ledger; ``init_auth`` must be awaited before use.

        Args:
            db (Database): Database used by the ledger.
            seed (str): Seed forwarded to ``Ledger``.
            seed_decryption_key (Optional[str]): Forwarded to ``Ledger``.
            derivation_path (str): Forwarded to ``Ledger``.
            amounts (Optional[List[int]]): Forwarded to ``Ledger``.
            crud: Ledger CRUD implementation forwarded to ``Ledger``.
        """
        super().__init__(
            db=db,
            seed=seed,
            backends=None,
            seed_decryption_key=seed_decryption_key,
            derivation_path=derivation_path,
            crud=crud,
            amounts=amounts,
        )
        self.oicd_discovery_url = settings.mint_auth_oicd_discovery_url or ""

    async def init_auth(self):
        """Load the OpenID Connect configuration and prepare JWT verification.

        Fetches the discovery document, then sets ``jwks_url``,
        ``jwks_client``, ``auth_crud`` and ``issuer``.

        Raises:
            Exception: The discovery URL is not configured.
            KeyError: The discovery document lacks ``jwks_uri`` or ``issuer``.
        """
        if not self.oicd_discovery_url:
            raise Exception("Missing OpenID Connect discovery URL.")
        logger.info(f"Initializing OpenID Connect: {self.oicd_discovery_url}")
        self.oicd_discovery_json = self._get_oicd_discovery_json()
        self.jwks_url = self.oicd_discovery_json["jwks_uri"]
        self.jwks_client = jwt.PyJWKClient(self.jwks_url)
        logger.info(f"Getting JWKS from: {self.jwks_url}")
        self.auth_crud = AuthLedgerCrudSqlite()
        self.issuer = self.oicd_discovery_json["issuer"]
        logger.info(f"Initialized OpenID Connect: {self.issuer}")

    def _get_oicd_discovery_json(self) -> dict:
        """Fetch and parse the OpenID Connect discovery document.

        This is a synchronous HTTP request; a non-success status code raises
        via ``raise_for_status``.

        Returns:
            dict: Parsed JSON body of the discovery document.
        """
        logger.debug(
            f"Getting OpenID Connect discovery JSON from: {self.oicd_discovery_url}"
        )
        resp = httpx.get(self.oicd_discovery_url)
        resp.raise_for_status()
        return resp.json()

    def _verify_oicd_issuer(self, clear_auth_token: str) -> None:
        """Verify the issuer of the clear-auth token.

        The token is decoded WITHOUT signature verification here; this is only
        a cheap pre-check of the ``iss`` claim and must always be followed by
        ``_verify_decode_jwt``.

        Args:
            clear_auth_token (str): JWT token.

        Raises:
            Exception: Invalid (or missing) issuer.
        """
        decoded = jwt.decode(
            clear_auth_token,
            options={"verify_signature": False},
        )
        # .get() so a token without an ``iss`` claim is reported as an invalid
        # issuer instead of surfacing as a bare KeyError.
        issuer = decoded.get("iss")
        if issuer != self.issuer:
            raise Exception(f"Invalid issuer: {issuer}. Expected: {self.issuer}")

    def _verify_decode_jwt(self, clear_auth_token: str) -> Any:
        """Verify the clear-auth JWT token.

        The signature is checked with the key that ``jwks_client`` selects for
        this token, and the ``iss`` claim must equal ``self.issuer``. Audience
        verification is disabled.

        Args:
            clear_auth_token (str): JWT token.

        Raises:
            jwt.ExpiredSignatureError: Token has expired.
            jwt.InvalidSignatureError: Invalid signature.
            jwt.InvalidTokenError: Invalid token.

        Returns:
            Any: Decoded JWT.
        """
        try:
            # Use PyJWKClient to fetch the appropriate key based on the token's header
            signing_key = self.jwks_client.get_signing_key_from_jwt(clear_auth_token)
            decoded = jwt.decode(
                clear_auth_token,
                signing_key.key,
                algorithms=["RS256", "ES256"],
                options={"verify_aud": False},
                issuer=self.issuer,
            )
            logger.trace(f"Decoded JWT: {decoded}")
        except jwt.ExpiredSignatureError:
            logger.error("Token has expired")
            raise
        except jwt.InvalidSignatureError:
            logger.error("Invalid signature")
            raise
        except jwt.InvalidTokenError:
            logger.error("Invalid token")
            raise
        return decoded

    async def _get_user(self, decoded_token: Any) -> User:
        """Get the user from the decoded token. If the user does not exist, create a new one.

        Args:
            decoded_token (Any): decoded JWT from PyJWT.decode

        Returns:
            User: User object
        """
        user_id = decoded_token["sub"]
        user = await self.auth_crud.get_user(user_id=user_id, db=self.db)
        if not user:
            logger.info(f"Creating new user: {user_id}")
            user = User(id=user_id)
            await self.auth_crud.create_user(user=user, db=self.db)
        return user

    async def verify_clear_auth(self, clear_auth_token: str) -> User:
        """Verify the clear-auth JWT token and return the user.

        Checks:
            - Token issuer matches the configured issuer.
            - Token not expired.
            - Token signature valid.
            - User exists (created on first sight).
            - User is within the rate limit.

        Args:
            clear_auth_token (str): JWT token.

        Raises:
            ClearAuthFailedError: Token verification or user lookup failed.
            BlindAuthRateLimitExceededError: The user exceeded the rate limit.

        Returns:
            User: Authenticated user.
        """
        try:
            self._verify_oicd_issuer(clear_auth_token)
            decoded = self._verify_decode_jwt(clear_auth_token)
            user = await self._get_user(decoded)
        except Exception as e:
            # Log and chain the cause so the real reason is not lost behind
            # the generic error returned to the caller.
            logger.error(f"Clear auth error: {e}")
            raise ClearAuthFailedError() from e

        logger.info(f"User authenticated: {user.id}")
        try:
            assert_limit(user.id)
        except Exception as e:
            raise BlindAuthRateLimitExceededError() from e

        return user

    async def mint_blind_auth(
        self,
        *,
        outputs: List[BlindedMessage],
        user: User,
    ) -> List[BlindedSignature]:
        """Mints auth tokens. Returns a list of promises.

        Args:
            outputs (List[BlindedMessage]): Outputs to sign.
            user (User): Authenticated user.

        Raises:
            BlindAuthAmountExceededError: More outputs requested than
                ``settings.mint_auth_max_blind_tokens`` allows.
            Exception: Output verification failed.

        Returns:
            List[BlindedSignature]: List of blinded signatures.
        """
        if len(outputs) > settings.mint_auth_max_blind_tokens:
            raise BlindAuthAmountExceededError(
                f"Too many outputs. You can only mint {settings.mint_auth_max_blind_tokens} tokens."
            )

        await self._verify_outputs(outputs)
        await self._store_blinded_messages(outputs)
        promises = await self._sign_blinded_messages(outputs)

        # update last_access timestamp of the user
        await self.auth_crud.update_user(user_id=user.id, db=self.db)

        return promises

    @asynccontextmanager
    async def verify_blind_auth(self, blind_auth_token: str):
        """Wrapper context that puts blind auth tokens into pending list and
        melts them if the wrapped call succeeds. If it fails, the blind auth
        token is not invalidated.

        Note that ANY exception raised by the wrapped body is logged and
        re-raised as ``BlindAuthFailedError`` (with the original exception
        chained as its cause). The pending flag is always cleared on exit.

        Args:
            blind_auth_token (str): Blind auth token.

        Raises:
            BlindAuthFailedError: Blind auth token validation failed, or the
                wrapped call or the token invalidation raised.
        """
        try:
            proof = AuthProof.from_base64(blind_auth_token).to_proof()
            await self.verify_inputs_and_outputs(proofs=[proof])
            # Setting pending is the last step, so a failure above leaves
            # nothing to clean up.
            await self.db_write._verify_spent_proofs_and_set_pending(
                [proof], self.keysets
            )
        except Exception as e:
            logger.error(f"Blind auth error: {e}")
            raise BlindAuthFailedError() from e

        try:
            yield
            # Only reached when the wrapped body completed without raising.
            await self._invalidate_proofs(proofs=[proof])
        except Exception as e:
            logger.error(f"Blind auth error: {e}")
            raise BlindAuthFailedError() from e
        finally:
            await self.db_write._unset_proofs_pending([proof], self.keysets)