mirror of
https://github.com/aljazceru/nutshell.git
synced 2025-12-21 19:14:19 +01:00
Blind authentication (#675)
* auth server * cleaning up * auth ledger class * class variables -> instance variables * annotations * add models and api route * custom amount and api prefix * add auth db * blind auth token working * jwt working * clean up * JWT works * using openid connect server * use oauth server with password flow * new realm * add keycloak docker * hopefully not garbage * auth works * auth kinda working * fix cli * auth works for send and receive * pass auth_db to Wallet * auth in info * refactor * fix supported * cache mint info * fix settings and endpoints * add description to .env.example * track changes for openid connect client * store mint in db * store credentials * clean up v1_api.py * load mint info into auth wallet * fix first login * authenticate if refresh token fails * clear auth also middleware * use regex * add cli command * pw works * persist keyset amounts * add errors.py * do not start auth server if disabled in config * upadte poetry * disvoery url * fix test * support device code flow * adopt latest spec changes * fix code flow * mint max bat dynamic * mypy ignore * fix test * do not serialize amount in authproof * all auth flows working * fix tests * submodule * refactor * test * dont sleep * test * add wallet auth tests * test differently * test only keycloak for now * fix creds * daemon * fix test * install everything * install jinja * delete wallet for every test * auth: use global rate limiter * test auth rate limit * keycloak hostname * move keycloak test data * reactivate all tests * add readme * load proofs * remove unused code * remove unused code * implement change suggestions by ok300 * add error codes * test errors
This commit is contained in:
235
cashu/mint/auth/server.py
Normal file
235
cashu/mint/auth/server.py
Normal file
@@ -0,0 +1,235 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import httpx
|
||||
import jwt
|
||||
from loguru import logger
|
||||
|
||||
from ...core.base import AuthProof
|
||||
from ...core.db import Database
|
||||
from ...core.errors import (
|
||||
BlindAuthAmountExceededError,
|
||||
BlindAuthFailedError,
|
||||
BlindAuthRateLimitExceededError,
|
||||
ClearAuthFailedError,
|
||||
)
|
||||
from ...core.models import BlindedMessage, BlindedSignature
|
||||
from ...core.settings import settings
|
||||
from ..crud import LedgerCrudSqlite
|
||||
from ..ledger import Ledger
|
||||
from ..limit import assert_limit
|
||||
from .base import User
|
||||
from .crud import AuthLedgerCrud, AuthLedgerCrudSqlite
|
||||
|
||||
|
||||
class AuthLedger(Ledger):
|
||||
auth_crud: AuthLedgerCrud
|
||||
jwks_url: str
|
||||
jwks_client: jwt.PyJWKClient
|
||||
issuer: str
|
||||
oicd_discovery_json: dict
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db: Database,
|
||||
seed: str,
|
||||
seed_decryption_key: Optional[str] = None,
|
||||
derivation_path="",
|
||||
amounts: Optional[List[int]] = None,
|
||||
crud=LedgerCrudSqlite(),
|
||||
):
|
||||
super().__init__(
|
||||
db=db,
|
||||
seed=seed,
|
||||
backends=None,
|
||||
seed_decryption_key=seed_decryption_key,
|
||||
derivation_path=derivation_path,
|
||||
crud=crud,
|
||||
amounts=amounts,
|
||||
)
|
||||
self.oicd_discovery_url = settings.mint_auth_oicd_discovery_url or ""
|
||||
|
||||
async def init_auth(self):
|
||||
if not self.oicd_discovery_url:
|
||||
raise Exception("Missing OpenID Connect discovery URL.")
|
||||
logger.info(f"Initializing OpenID Connect: {self.oicd_discovery_url}")
|
||||
self.oicd_discovery_json = self._get_oicd_discovery_json()
|
||||
self.jwks_url = self.oicd_discovery_json["jwks_uri"]
|
||||
self.jwks_client = jwt.PyJWKClient(self.jwks_url)
|
||||
logger.info(f"Getting JWKS from: {self.jwks_url}")
|
||||
self.auth_crud = AuthLedgerCrudSqlite()
|
||||
self.issuer: str = self.oicd_discovery_json["issuer"]
|
||||
logger.info(f"Initialized OpenID Connect: {self.issuer}")
|
||||
|
||||
def _get_oicd_discovery_json(self) -> dict:
|
||||
resp = httpx.get(self.oicd_discovery_url)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
def _verify_oicd_issuer(self, clear_auth_token: str) -> None:
|
||||
"""Verify the issuer of the clear-auth token.
|
||||
|
||||
Args:
|
||||
clear_auth_token (str): JWT token.
|
||||
|
||||
Raises:
|
||||
Exception: Invalid issuer.
|
||||
"""
|
||||
try:
|
||||
decoded = jwt.decode(
|
||||
clear_auth_token,
|
||||
options={"verify_signature": False},
|
||||
)
|
||||
issuer = decoded["iss"]
|
||||
if issuer != self.issuer:
|
||||
raise Exception(f"Invalid issuer: {issuer}. Expected: {self.issuer}")
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def _verify_decode_jwt(self, clear_auth_token: str) -> Any:
|
||||
"""Verify the clear-auth JWT token.
|
||||
|
||||
Args:
|
||||
clear_auth_token (str): JWT token.
|
||||
|
||||
Raises:
|
||||
jwt.ExpiredSignatureError: Token has expired.
|
||||
jwt.InvalidSignatureError: Invalid signature.
|
||||
jwt.InvalidTokenError: Invalid token.
|
||||
|
||||
Returns:
|
||||
Any: Decoded JWT.
|
||||
"""
|
||||
try:
|
||||
# Use PyJWKClient to fetch the appropriate key based on the token's header
|
||||
signing_key = self.jwks_client.get_signing_key_from_jwt(clear_auth_token)
|
||||
decoded = jwt.decode(
|
||||
clear_auth_token,
|
||||
signing_key.key,
|
||||
algorithms=["RS256", "ES256"],
|
||||
options={"verify_aud": False},
|
||||
issuer=self.issuer,
|
||||
)
|
||||
logger.trace(f"Decoded JWT: {decoded}")
|
||||
except jwt.ExpiredSignatureError as e:
|
||||
logger.error("Token has expired")
|
||||
raise e
|
||||
except jwt.InvalidSignatureError as e:
|
||||
logger.error("Invalid signature")
|
||||
raise e
|
||||
except jwt.InvalidTokenError as e:
|
||||
logger.error("Invalid token")
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
return decoded
|
||||
|
||||
async def _get_user(self, decoded_token: Any) -> User:
|
||||
"""Get the user from the decoded token. If the user does not exist, create a new one.
|
||||
|
||||
Args:
|
||||
decoded_token (Any): decoded JWT from PyJWT.decode
|
||||
|
||||
Returns:
|
||||
User: User object
|
||||
"""
|
||||
user_id = decoded_token["sub"]
|
||||
user = await self.auth_crud.get_user(user_id=user_id, db=self.db)
|
||||
if not user:
|
||||
logger.info(f"Creating new user: {user_id}")
|
||||
user = User(id=user_id)
|
||||
await self.auth_crud.create_user(user=user, db=self.db)
|
||||
return user
|
||||
|
||||
async def verify_clear_auth(self, clear_auth_token: str) -> User:
|
||||
"""Verify the clear-auth JWT token and return the user.
|
||||
|
||||
Checks:
|
||||
- Token not expired.
|
||||
- Token signature valid.
|
||||
- User exists.
|
||||
|
||||
Args:
|
||||
auth_token (str): JWT token.
|
||||
|
||||
Returns:
|
||||
User: Authenticated user.
|
||||
"""
|
||||
try:
|
||||
self._verify_oicd_issuer(clear_auth_token)
|
||||
decoded = self._verify_decode_jwt(clear_auth_token)
|
||||
user = await self._get_user(decoded)
|
||||
except Exception:
|
||||
raise ClearAuthFailedError()
|
||||
|
||||
logger.info(f"User authenticated: {user.id}")
|
||||
try:
|
||||
assert_limit(user.id)
|
||||
except Exception:
|
||||
raise BlindAuthRateLimitExceededError()
|
||||
|
||||
return user
|
||||
|
||||
async def mint_blind_auth(
|
||||
self,
|
||||
*,
|
||||
outputs: List[BlindedMessage],
|
||||
user: User,
|
||||
) -> List[BlindedSignature]:
|
||||
"""Mints auth tokens. Returns a list of promises.
|
||||
|
||||
Args:
|
||||
outputs (List[BlindedMessage]): Outputs to sign.
|
||||
user (User): Authenticated user.
|
||||
|
||||
Raises:
|
||||
Exception: Invalid auth.
|
||||
Exception: Output verification failed.
|
||||
Exception: Output quota exceeded.
|
||||
|
||||
Returns:
|
||||
List[BlindedSignature]: List of blinded signatures.
|
||||
"""
|
||||
|
||||
if len(outputs) > settings.mint_auth_max_blind_tokens:
|
||||
raise BlindAuthAmountExceededError(
|
||||
f"Too many outputs. You can only mint {settings.mint_auth_max_blind_tokens} tokens."
|
||||
)
|
||||
|
||||
await self._verify_outputs(outputs)
|
||||
promises = await self._generate_promises(outputs)
|
||||
|
||||
# update last_access timestamp of the user
|
||||
await self.auth_crud.update_user(user_id=user.id, db=self.db)
|
||||
|
||||
return promises
|
||||
|
||||
@asynccontextmanager
|
||||
async def verify_blind_auth(self, blind_auth_token):
|
||||
"""Wrapper context that puts blind auth tokens into pending list and
|
||||
melts them if the wrapped call succeeds. If it fails, the blind auth
|
||||
token is not invalidated.
|
||||
|
||||
Args:
|
||||
blind_auth_token (str): Blind auth token.
|
||||
|
||||
Raises:
|
||||
Exception: Blind auth token validation failed.
|
||||
"""
|
||||
try:
|
||||
proof = AuthProof.from_base64(blind_auth_token).to_proof()
|
||||
await self.verify_inputs_and_outputs(proofs=[proof])
|
||||
await self.db_write._verify_spent_proofs_and_set_pending([proof])
|
||||
except Exception as e:
|
||||
logger.error(f"Blind auth error: {e}")
|
||||
raise BlindAuthFailedError()
|
||||
|
||||
try:
|
||||
yield
|
||||
await self._invalidate_proofs(proofs=[proof])
|
||||
except Exception as e:
|
||||
logger.error(f"Blind auth error: {e}")
|
||||
raise BlindAuthFailedError()
|
||||
finally:
|
||||
await self.db_write._unset_proofs_pending([proof])
|
||||
Reference in New Issue
Block a user