Blind authentication (#675)

* auth server

* cleaning up

* auth ledger class

* class variables -> instance variables

* annotations

* add models and api route

* custom amount and api prefix

* add auth db

* blind auth token working

* jwt working

* clean up

* JWT works

* using openid connect server

* use oauth server with password flow

* new realm

* add keycloak docker

* hopefully not garbage

* auth works

* auth kinda working

* fix cli

* auth works for send and receive

* pass auth_db to Wallet

* auth in info

* refactor

* fix supported

* cache mint info

* fix settings and endpoints

* add description to .env.example

* track changes for openid connect client

* store mint in db

* store credentials

* clean up v1_api.py

* load mint info into auth wallet

* fix first login

* authenticate if refresh token fails

* clear auth also middleware

* use regex

* add cli command

* pw works

* persist keyset amounts

* add errors.py

* do not start auth server if disabled in config

* upadte poetry

* disvoery url

* fix test

* support device code flow

* adopt latest spec changes

* fix code flow

* mint max bat dynamic

* mypy ignore

* fix test

* do not serialize amount in authproof

* all auth flows working

* fix tests

* submodule

* refactor

* test

* dont sleep

* test

* add wallet auth tests

* test differently

* test only keycloak for now

* fix creds

* daemon

* fix test

* install everything

* install jinja

* delete wallet for every test

* auth: use global rate limiter

* test auth rate limit

* keycloak hostname

* move keycloak test data

* reactivate all tests

* add readme

* load proofs

* remove unused code

* remove unused code

* implement change suggestions by ok300

* add error codes

* test errors
This commit is contained in:
callebtc
2025-01-29 22:48:51 -06:00
committed by GitHub
parent b67ffd8705
commit a0ef44dba0
58 changed files with 8188 additions and 701 deletions

235
cashu/mint/auth/server.py Normal file
View File

@@ -0,0 +1,235 @@
from contextlib import asynccontextmanager
from typing import Any, List, Optional
import httpx
import jwt
from loguru import logger
from ...core.base import AuthProof
from ...core.db import Database
from ...core.errors import (
BlindAuthAmountExceededError,
BlindAuthFailedError,
BlindAuthRateLimitExceededError,
ClearAuthFailedError,
)
from ...core.models import BlindedMessage, BlindedSignature
from ...core.settings import settings
from ..crud import LedgerCrudSqlite
from ..ledger import Ledger
from ..limit import assert_limit
from .base import User
from .crud import AuthLedgerCrud, AuthLedgerCrudSqlite
class AuthLedger(Ledger):
auth_crud: AuthLedgerCrud
jwks_url: str
jwks_client: jwt.PyJWKClient
issuer: str
oicd_discovery_json: dict
def __init__(
self,
db: Database,
seed: str,
seed_decryption_key: Optional[str] = None,
derivation_path="",
amounts: Optional[List[int]] = None,
crud=LedgerCrudSqlite(),
):
super().__init__(
db=db,
seed=seed,
backends=None,
seed_decryption_key=seed_decryption_key,
derivation_path=derivation_path,
crud=crud,
amounts=amounts,
)
self.oicd_discovery_url = settings.mint_auth_oicd_discovery_url or ""
async def init_auth(self):
if not self.oicd_discovery_url:
raise Exception("Missing OpenID Connect discovery URL.")
logger.info(f"Initializing OpenID Connect: {self.oicd_discovery_url}")
self.oicd_discovery_json = self._get_oicd_discovery_json()
self.jwks_url = self.oicd_discovery_json["jwks_uri"]
self.jwks_client = jwt.PyJWKClient(self.jwks_url)
logger.info(f"Getting JWKS from: {self.jwks_url}")
self.auth_crud = AuthLedgerCrudSqlite()
self.issuer: str = self.oicd_discovery_json["issuer"]
logger.info(f"Initialized OpenID Connect: {self.issuer}")
def _get_oicd_discovery_json(self) -> dict:
resp = httpx.get(self.oicd_discovery_url)
resp.raise_for_status()
return resp.json()
def _verify_oicd_issuer(self, clear_auth_token: str) -> None:
"""Verify the issuer of the clear-auth token.
Args:
clear_auth_token (str): JWT token.
Raises:
Exception: Invalid issuer.
"""
try:
decoded = jwt.decode(
clear_auth_token,
options={"verify_signature": False},
)
issuer = decoded["iss"]
if issuer != self.issuer:
raise Exception(f"Invalid issuer: {issuer}. Expected: {self.issuer}")
except Exception as e:
raise e
def _verify_decode_jwt(self, clear_auth_token: str) -> Any:
"""Verify the clear-auth JWT token.
Args:
clear_auth_token (str): JWT token.
Raises:
jwt.ExpiredSignatureError: Token has expired.
jwt.InvalidSignatureError: Invalid signature.
jwt.InvalidTokenError: Invalid token.
Returns:
Any: Decoded JWT.
"""
try:
# Use PyJWKClient to fetch the appropriate key based on the token's header
signing_key = self.jwks_client.get_signing_key_from_jwt(clear_auth_token)
decoded = jwt.decode(
clear_auth_token,
signing_key.key,
algorithms=["RS256", "ES256"],
options={"verify_aud": False},
issuer=self.issuer,
)
logger.trace(f"Decoded JWT: {decoded}")
except jwt.ExpiredSignatureError as e:
logger.error("Token has expired")
raise e
except jwt.InvalidSignatureError as e:
logger.error("Invalid signature")
raise e
except jwt.InvalidTokenError as e:
logger.error("Invalid token")
raise e
except Exception as e:
raise e
return decoded
async def _get_user(self, decoded_token: Any) -> User:
"""Get the user from the decoded token. If the user does not exist, create a new one.
Args:
decoded_token (Any): decoded JWT from PyJWT.decode
Returns:
User: User object
"""
user_id = decoded_token["sub"]
user = await self.auth_crud.get_user(user_id=user_id, db=self.db)
if not user:
logger.info(f"Creating new user: {user_id}")
user = User(id=user_id)
await self.auth_crud.create_user(user=user, db=self.db)
return user
async def verify_clear_auth(self, clear_auth_token: str) -> User:
"""Verify the clear-auth JWT token and return the user.
Checks:
- Token not expired.
- Token signature valid.
- User exists.
Args:
auth_token (str): JWT token.
Returns:
User: Authenticated user.
"""
try:
self._verify_oicd_issuer(clear_auth_token)
decoded = self._verify_decode_jwt(clear_auth_token)
user = await self._get_user(decoded)
except Exception:
raise ClearAuthFailedError()
logger.info(f"User authenticated: {user.id}")
try:
assert_limit(user.id)
except Exception:
raise BlindAuthRateLimitExceededError()
return user
async def mint_blind_auth(
self,
*,
outputs: List[BlindedMessage],
user: User,
) -> List[BlindedSignature]:
"""Mints auth tokens. Returns a list of promises.
Args:
outputs (List[BlindedMessage]): Outputs to sign.
user (User): Authenticated user.
Raises:
Exception: Invalid auth.
Exception: Output verification failed.
Exception: Output quota exceeded.
Returns:
List[BlindedSignature]: List of blinded signatures.
"""
if len(outputs) > settings.mint_auth_max_blind_tokens:
raise BlindAuthAmountExceededError(
f"Too many outputs. You can only mint {settings.mint_auth_max_blind_tokens} tokens."
)
await self._verify_outputs(outputs)
promises = await self._generate_promises(outputs)
# update last_access timestamp of the user
await self.auth_crud.update_user(user_id=user.id, db=self.db)
return promises
@asynccontextmanager
async def verify_blind_auth(self, blind_auth_token):
"""Wrapper context that puts blind auth tokens into pending list and
melts them if the wrapped call succeeds. If it fails, the blind auth
token is not invalidated.
Args:
blind_auth_token (str): Blind auth token.
Raises:
Exception: Blind auth token validation failed.
"""
try:
proof = AuthProof.from_base64(blind_auth_token).to_proof()
await self.verify_inputs_and_outputs(proofs=[proof])
await self.db_write._verify_spent_proofs_and_set_pending([proof])
except Exception as e:
logger.error(f"Blind auth error: {e}")
raise BlindAuthFailedError()
try:
yield
await self._invalidate_proofs(proofs=[proof])
except Exception as e:
logger.error(f"Blind auth error: {e}")
raise BlindAuthFailedError()
finally:
await self.db_write._unset_proofs_pending([proof])