mirror of
https://github.com/aljazceru/nutshell.git
synced 2025-12-20 18:44:20 +01:00
* wip store balance * store balances in watchdog worker * move mint_auth_database setting * auth db * balances returned as Amount (instead of int) * add test for balance change on invoice receive * fix 1 test * cancel tasks on shutdown * watchdog can now abort * remove wallet api server * fix lndgrpc * fix lnbits balance * disable watchdog * balance lnbits msat * test db watcher with its own database connection * init superclass only once * wip: log balance in keysets table * check max balance using new keyset balance * fix test * fix another test * store fees in keysets * format * cleanup * shorter * add keyset migration to auth server * fix fakewallet * fix db tests * fix postgres problems during migration 26 (mint) * fix cln * ledger * working with pending * super fast watchdog, errors * test new pipeline * delete walletapi * delete unneeded files * revert workflows
241 lines
7.6 KiB
Python
241 lines
7.6 KiB
Python
from contextlib import asynccontextmanager
|
|
from typing import Any, List, Optional
|
|
|
|
import httpx
|
|
import jwt
|
|
from loguru import logger
|
|
|
|
from ...core.base import AuthProof
|
|
from ...core.db import Database
|
|
from ...core.errors import (
|
|
BlindAuthAmountExceededError,
|
|
BlindAuthFailedError,
|
|
BlindAuthRateLimitExceededError,
|
|
ClearAuthFailedError,
|
|
)
|
|
from ...core.models import BlindedMessage, BlindedSignature
|
|
from ...core.settings import settings
|
|
from ..crud import LedgerCrudSqlite
|
|
from ..ledger import Ledger
|
|
from ..limit import assert_limit
|
|
from .base import User
|
|
from .crud import AuthLedgerCrud, AuthLedgerCrudSqlite
|
|
|
|
|
|
class AuthLedger(Ledger):
    """Ledger that authenticates users and handles blind auth tokens.

    Clear auth: a JWT is checked against the OpenID Connect issuer and JWKS
    loaded in `init_auth`, and mapped to a `User` record.
    Blind auth: blinded outputs are signed for an authenticated user
    (`mint_blind_auth`) and later redeemed around a wrapped call
    (`verify_blind_auth`).
    """

    auth_crud: AuthLedgerCrud
    jwks_url: str
    jwks_client: jwt.PyJWKClient
    issuer: str
    oicd_discovery_json: dict
    oicd_discovery_url: str

    def __init__(
        self,
        db: Database,
        seed: str,
        seed_decryption_key: Optional[str] = None,
        derivation_path="",
        amounts: Optional[List[int]] = None,
        crud=None,
    ):
        """Create the auth ledger on top of the base `Ledger`.

        Args:
            db (Database): Database passed to the base ledger.
            seed (str): Seed passed to the base ledger.
            seed_decryption_key (Optional[str]): Passed through to the base ledger.
            derivation_path (str): Passed through to the base ledger.
            amounts (Optional[List[int]]): Passed through to the base ledger.
            crud: CRUD implementation for the base ledger. A new
                `LedgerCrudSqlite` is created when omitted.
        """
        # Build the default CRUD per call instead of in the signature: a call
        # in a default argument is evaluated once at definition time and the
        # resulting object would be shared by every instance.
        if crud is None:
            crud = LedgerCrudSqlite()
        super().__init__(
            db=db,
            seed=seed,
            backends=None,
            seed_decryption_key=seed_decryption_key,
            derivation_path=derivation_path,
            crud=crud,
            amounts=amounts,
        )
        self.oicd_discovery_url = settings.mint_auth_oicd_discovery_url or ""

    async def init_auth(self):
        """Load the OpenID Connect configuration and prepare JWT verification.

        Fetches the discovery document, reads its `jwks_uri` and `issuer`
        entries, and sets up the JWKS client and the auth CRUD.

        Raises:
            Exception: No discovery URL is configured.
            KeyError: The discovery document lacks `jwks_uri` or `issuer`.
        """
        if not self.oicd_discovery_url:
            raise Exception("Missing OpenID Connect discovery URL.")
        logger.info(f"Initializing OpenID Connect: {self.oicd_discovery_url}")
        # NOTE(review): this is a synchronous HTTP request made from a
        # coroutine, so the event loop is blocked until it returns.
        self.oicd_discovery_json = self._get_oicd_discovery_json()
        self.jwks_url = self.oicd_discovery_json["jwks_uri"]
        self.jwks_client = jwt.PyJWKClient(self.jwks_url)
        logger.info(f"Getting JWKS from: {self.jwks_url}")
        self.auth_crud = AuthLedgerCrudSqlite()
        self.issuer = self.oicd_discovery_json["issuer"]
        logger.info(f"Initialized OpenID Connect: {self.issuer}")

    def _get_oicd_discovery_json(self) -> dict:
        """Fetch and parse the OpenID Connect discovery document.

        Raises:
            httpx.HTTPStatusError: The server answered with an error status.

        Returns:
            dict: Parsed JSON body of the discovery response.
        """
        logger.debug(
            f"Getting OpenID Connect discovery JSON from: {self.oicd_discovery_url}"
        )
        resp = httpx.get(self.oicd_discovery_url)
        resp.raise_for_status()
        return resp.json()

    def _verify_oicd_issuer(self, clear_auth_token: str) -> None:
        """Verify the issuer of the clear-auth token.

        Args:
            clear_auth_token (str): JWT token.

        Raises:
            Exception: Invalid issuer.
            KeyError: The token carries no `iss` claim.
        """
        # The signature is not checked in this decode; only the `iss` claim is
        # compared. The signature-verifying decode is `_verify_decode_jwt`.
        decoded = jwt.decode(
            clear_auth_token,
            options={"verify_signature": False},
        )
        issuer = decoded["iss"]
        if issuer != self.issuer:
            raise Exception(f"Invalid issuer: {issuer}. Expected: {self.issuer}")

    def _verify_decode_jwt(self, clear_auth_token: str) -> Any:
        """Verify the clear-auth JWT token.

        Args:
            clear_auth_token (str): JWT token.

        Raises:
            jwt.ExpiredSignatureError: Token has expired.
            jwt.InvalidSignatureError: Invalid signature.
            jwt.InvalidTokenError: Invalid token.

        Returns:
            Any: Decoded JWT.
        """
        try:
            # Use PyJWKClient to fetch the appropriate key based on the token's header
            signing_key = self.jwks_client.get_signing_key_from_jwt(clear_auth_token)
            decoded = jwt.decode(
                clear_auth_token,
                signing_key.key,
                algorithms=["RS256", "ES256"],
                options={"verify_aud": False},
                issuer=self.issuer,
            )
            logger.trace(f"Decoded JWT: {decoded}")
        # The specific handlers must precede InvalidTokenError so that each
        # failure is logged with its own message.
        except jwt.ExpiredSignatureError:
            logger.error("Token has expired")
            raise
        except jwt.InvalidSignatureError:
            logger.error("Invalid signature")
            raise
        except jwt.InvalidTokenError:
            logger.error("Invalid token")
            raise

        return decoded

    async def _get_user(self, decoded_token: Any) -> User:
        """Get the user from the decoded token. If the user does not exist, create a new one.

        Args:
            decoded_token (Any): decoded JWT from PyJWT.decode

        Returns:
            User: User object
        """
        user_id = decoded_token["sub"]
        user = await self.auth_crud.get_user(user_id=user_id, db=self.db)
        if not user:
            logger.info(f"Creating new user: {user_id}")
            user = User(id=user_id)
            await self.auth_crud.create_user(user=user, db=self.db)
        return user

    async def verify_clear_auth(self, clear_auth_token: str) -> User:
        """Verify the clear-auth JWT token and return the user.

        Checks:
            - Token issuer matches the configured issuer.
            - Token not expired.
            - Token signature valid.
            - User is looked up, and created if not yet known.
            - Rate limit for the user is not exceeded.

        Args:
            clear_auth_token (str): JWT token.

        Raises:
            ClearAuthFailedError: Any of the token or user checks failed.
            BlindAuthRateLimitExceededError: The rate limit check failed.

        Returns:
            User: Authenticated user.
        """
        try:
            self._verify_oicd_issuer(clear_auth_token)
            decoded = self._verify_decode_jwt(clear_auth_token)
            user = await self._get_user(decoded)
        except Exception as e:
            # Log the underlying reason before it is replaced by the generic
            # error, matching how verify_blind_auth reports failures.
            logger.error(f"Clear auth error: {e}")
            raise ClearAuthFailedError() from e

        logger.info(f"User authenticated: {user.id}")
        try:
            assert_limit(user.id)
        except Exception as e:
            raise BlindAuthRateLimitExceededError() from e

        return user

    async def mint_blind_auth(
        self,
        *,
        outputs: List[BlindedMessage],
        user: User,
    ) -> List[BlindedSignature]:
        """Mints auth tokens. Returns a list of promises.

        Args:
            outputs (List[BlindedMessage]): Outputs to sign.
            user (User): Authenticated user.

        Raises:
            BlindAuthAmountExceededError: More outputs were requested than
                `settings.mint_auth_max_blind_tokens` allows.
            Exception: Output verification or signing failed (raised by the
                ledger helpers called here).

        Returns:
            List[BlindedSignature]: List of blinded signatures.
        """

        if len(outputs) > settings.mint_auth_max_blind_tokens:
            raise BlindAuthAmountExceededError(
                f"Too many outputs. You can only mint {settings.mint_auth_max_blind_tokens} tokens."
            )

        await self._verify_outputs(outputs)
        promises = await self._generate_promises(outputs)

        # update last_access timestamp of the user
        await self.auth_crud.update_user(user_id=user.id, db=self.db)

        return promises

    @asynccontextmanager
    async def verify_blind_auth(self, blind_auth_token):
        """Wrapper context that puts blind auth tokens into pending list and
        melts them if the wrapped call succeeds. If it fails, the blind auth
        token is not invalidated.

        Args:
            blind_auth_token (str): Blind auth token.

        Raises:
            BlindAuthFailedError: Blind auth token validation failed, the
                wrapped call raised, or invalidating the token failed.
        """
        try:
            proof = AuthProof.from_base64(blind_auth_token).to_proof()
            await self.verify_inputs_and_outputs(proofs=[proof])
            await self.db_write._verify_spent_proofs_and_set_pending(
                [proof], self.keysets
            )
        except Exception as e:
            logger.error(f"Blind auth error: {e}")
            raise BlindAuthFailedError() from e

        try:
            yield
            # Only reached when the wrapped body finished without raising.
            await self._invalidate_proofs(proofs=[proof])
        except Exception as e:
            # NOTE: errors raised by the wrapped body are reported to the
            # caller as BlindAuthFailedError as well.
            logger.error(f"Blind auth error: {e}")
            raise BlindAuthFailedError() from e
        finally:
            # Remove the pending marker whether the wrapped call succeeded or not.
            await self.db_write._unset_proofs_pending([proof], self.keysets)
|