Files
nutshell/cashu/mint/auth/server.py
callebtc a0ef44dba0 Blind authentication (#675)
* auth server

* cleaning up

* auth ledger class

* class variables -> instance variables

* annotations

* add models and api route

* custom amount and api prefix

* add auth db

* blind auth token working

* jwt working

* clean up

* JWT works

* using openid connect server

* use oauth server with password flow

* new realm

* add keycloak docker

* hopefully not garbage

* auth works

* auth kinda working

* fix cli

* auth works for send and receive

* pass auth_db to Wallet

* auth in info

* refactor

* fix supported

* cache mint info

* fix settings and endpoints

* add description to .env.example

* track changes for openid connect client

* store mint in db

* store credentials

* clean up v1_api.py

* load mint info into auth wallet

* fix first login

* authenticate if refresh token fails

* clear auth also middleware

* use regex

* add cli command

* pw works

* persist keyset amounts

* add errors.py

* do not start auth server if disabled in config

* upadte poetry

* disvoery url

* fix test

* support device code flow

* adopt latest spec changes

* fix code flow

* mint max bat dynamic

* mypy ignore

* fix test

* do not serialize amount in authproof

* all auth flows working

* fix tests

* submodule

* refactor

* test

* dont sleep

* test

* add wallet auth tests

* test differently

* test only keycloak for now

* fix creds

* daemon

* fix test

* install everything

* install jinja

* delete wallet for every test

* auth: use global rate limiter

* test auth rate limit

* keycloak hostname

* move keycloak test data

* reactivate all tests

* add readme

* load proofs

* remove unused code

* remove unused code

* implement change suggestions by ok300

* add error codes

* test errors
2025-01-29 22:48:51 -06:00

236 lines
7.4 KiB
Python

from contextlib import asynccontextmanager
from typing import Any, List, Optional
import httpx
import jwt
from loguru import logger
from ...core.base import AuthProof
from ...core.db import Database
from ...core.errors import (
BlindAuthAmountExceededError,
BlindAuthFailedError,
BlindAuthRateLimitExceededError,
ClearAuthFailedError,
)
from ...core.models import BlindedMessage, BlindedSignature
from ...core.settings import settings
from ..crud import LedgerCrudSqlite
from ..ledger import Ledger
from ..limit import assert_limit
from .base import User
from .crud import AuthLedgerCrud, AuthLedgerCrudSqlite
class AuthLedger(Ledger):
auth_crud: AuthLedgerCrud
jwks_url: str
jwks_client: jwt.PyJWKClient
issuer: str
oicd_discovery_json: dict
def __init__(
self,
db: Database,
seed: str,
seed_decryption_key: Optional[str] = None,
derivation_path="",
amounts: Optional[List[int]] = None,
crud=LedgerCrudSqlite(),
):
super().__init__(
db=db,
seed=seed,
backends=None,
seed_decryption_key=seed_decryption_key,
derivation_path=derivation_path,
crud=crud,
amounts=amounts,
)
self.oicd_discovery_url = settings.mint_auth_oicd_discovery_url or ""
async def init_auth(self):
if not self.oicd_discovery_url:
raise Exception("Missing OpenID Connect discovery URL.")
logger.info(f"Initializing OpenID Connect: {self.oicd_discovery_url}")
self.oicd_discovery_json = self._get_oicd_discovery_json()
self.jwks_url = self.oicd_discovery_json["jwks_uri"]
self.jwks_client = jwt.PyJWKClient(self.jwks_url)
logger.info(f"Getting JWKS from: {self.jwks_url}")
self.auth_crud = AuthLedgerCrudSqlite()
self.issuer: str = self.oicd_discovery_json["issuer"]
logger.info(f"Initialized OpenID Connect: {self.issuer}")
def _get_oicd_discovery_json(self) -> dict:
resp = httpx.get(self.oicd_discovery_url)
resp.raise_for_status()
return resp.json()
def _verify_oicd_issuer(self, clear_auth_token: str) -> None:
"""Verify the issuer of the clear-auth token.
Args:
clear_auth_token (str): JWT token.
Raises:
Exception: Invalid issuer.
"""
try:
decoded = jwt.decode(
clear_auth_token,
options={"verify_signature": False},
)
issuer = decoded["iss"]
if issuer != self.issuer:
raise Exception(f"Invalid issuer: {issuer}. Expected: {self.issuer}")
except Exception as e:
raise e
def _verify_decode_jwt(self, clear_auth_token: str) -> Any:
"""Verify the clear-auth JWT token.
Args:
clear_auth_token (str): JWT token.
Raises:
jwt.ExpiredSignatureError: Token has expired.
jwt.InvalidSignatureError: Invalid signature.
jwt.InvalidTokenError: Invalid token.
Returns:
Any: Decoded JWT.
"""
try:
# Use PyJWKClient to fetch the appropriate key based on the token's header
signing_key = self.jwks_client.get_signing_key_from_jwt(clear_auth_token)
decoded = jwt.decode(
clear_auth_token,
signing_key.key,
algorithms=["RS256", "ES256"],
options={"verify_aud": False},
issuer=self.issuer,
)
logger.trace(f"Decoded JWT: {decoded}")
except jwt.ExpiredSignatureError as e:
logger.error("Token has expired")
raise e
except jwt.InvalidSignatureError as e:
logger.error("Invalid signature")
raise e
except jwt.InvalidTokenError as e:
logger.error("Invalid token")
raise e
except Exception as e:
raise e
return decoded
async def _get_user(self, decoded_token: Any) -> User:
"""Get the user from the decoded token. If the user does not exist, create a new one.
Args:
decoded_token (Any): decoded JWT from PyJWT.decode
Returns:
User: User object
"""
user_id = decoded_token["sub"]
user = await self.auth_crud.get_user(user_id=user_id, db=self.db)
if not user:
logger.info(f"Creating new user: {user_id}")
user = User(id=user_id)
await self.auth_crud.create_user(user=user, db=self.db)
return user
async def verify_clear_auth(self, clear_auth_token: str) -> User:
"""Verify the clear-auth JWT token and return the user.
Checks:
- Token not expired.
- Token signature valid.
- User exists.
Args:
auth_token (str): JWT token.
Returns:
User: Authenticated user.
"""
try:
self._verify_oicd_issuer(clear_auth_token)
decoded = self._verify_decode_jwt(clear_auth_token)
user = await self._get_user(decoded)
except Exception:
raise ClearAuthFailedError()
logger.info(f"User authenticated: {user.id}")
try:
assert_limit(user.id)
except Exception:
raise BlindAuthRateLimitExceededError()
return user
async def mint_blind_auth(
self,
*,
outputs: List[BlindedMessage],
user: User,
) -> List[BlindedSignature]:
"""Mints auth tokens. Returns a list of promises.
Args:
outputs (List[BlindedMessage]): Outputs to sign.
user (User): Authenticated user.
Raises:
Exception: Invalid auth.
Exception: Output verification failed.
Exception: Output quota exceeded.
Returns:
List[BlindedSignature]: List of blinded signatures.
"""
if len(outputs) > settings.mint_auth_max_blind_tokens:
raise BlindAuthAmountExceededError(
f"Too many outputs. You can only mint {settings.mint_auth_max_blind_tokens} tokens."
)
await self._verify_outputs(outputs)
promises = await self._generate_promises(outputs)
# update last_access timestamp of the user
await self.auth_crud.update_user(user_id=user.id, db=self.db)
return promises
@asynccontextmanager
async def verify_blind_auth(self, blind_auth_token):
"""Wrapper context that puts blind auth tokens into pending list and
melts them if the wrapped call succeeds. If it fails, the blind auth
token is not invalidated.
Args:
blind_auth_token (str): Blind auth token.
Raises:
Exception: Blind auth token validation failed.
"""
try:
proof = AuthProof.from_base64(blind_auth_token).to_proof()
await self.verify_inputs_and_outputs(proofs=[proof])
await self.db_write._verify_spent_proofs_and_set_pending([proof])
except Exception as e:
logger.error(f"Blind auth error: {e}")
raise BlindAuthFailedError()
try:
yield
await self._invalidate_proofs(proofs=[proof])
except Exception as e:
logger.error(f"Blind auth error: {e}")
raise BlindAuthFailedError()
finally:
await self.db_write._unset_proofs_pending([proof])