Blind authentication (#675)

* auth server

* cleaning up

* auth ledger class

* class variables -> instance variables

* annotations

* add models and api route

* custom amount and api prefix

* add auth db

* blind auth token working

* jwt working

* clean up

* JWT works

* using openid connect server

* use oauth server with password flow

* new realm

* add keycloak docker

* hopefully not garbage

* auth works

* auth kinda working

* fix cli

* auth works for send and receive

* pass auth_db to Wallet

* auth in info

* refactor

* fix supported

* cache mint info

* fix settings and endpoints

* add description to .env.example

* track changes for openid connect client

* store mint in db

* store credentials

* clean up v1_api.py

* load mint info into auth wallet

* fix first login

* authenticate if refresh token fails

* clear auth also middleware

* use regex

* add cli command

* pw works

* persist keyset amounts

* add errors.py

* do not start auth server if disabled in config

* upadte poetry

* disvoery url

* fix test

* support device code flow

* adopt latest spec changes

* fix code flow

* mint max bat dynamic

* mypy ignore

* fix test

* do not serialize amount in authproof

* all auth flows working

* fix tests

* submodule

* refactor

* test

* dont sleep

* test

* add wallet auth tests

* test differently

* test only keycloak for now

* fix creds

* daemon

* fix test

* install everything

* install jinja

* delete wallet for every test

* auth: use global rate limiter

* test auth rate limit

* keycloak hostname

* move keycloak test data

* reactivate all tests

* add readme

* load proofs

* remove unused code

* remove unused code

* implement change suggestions by ok300

* add error codes

* test errors
This commit is contained in:
callebtc
2025-01-29 22:48:51 -06:00
committed by GitHub
parent b67ffd8705
commit a0ef44dba0
58 changed files with 8188 additions and 701 deletions

View File

@@ -133,3 +133,21 @@ LIGHTNING_RESERVE_FEE_MIN=2000
# MINT_GLOBAL_RATE_LIMIT_PER_MINUTE=60
# Determines the number of transactions (mint, melt, swap) allowed per minute per IP
# MINT_TRANSACTION_RATE_LIMIT_PER_MINUTE=20
# Authentication
# These settings allow you to enable blind authentication to limit the user of your mint to a group of authenticated users.
# To use this, you need to set up an OpenID Connect provider like Keycloak, Auth0, or Hydra.
# - Add the client ID "cashu-client"
# - Enable the ES256 and RS256 algorithms for this client
# - If you want to use the authorization flow, you must add the redirect URI "http://localhost:33388/callback".
# - To support other wallets, use the well-known list of allowed redirect URIs here: https://...TODO.md
#
# Turn on authentication
# MINT_REQUIRE_AUTH=TRUE
# OpenID Connect discovery URL of the authentication provider
# MINT_AUTH_OICD_DISCOVERY_URL=http://localhost:8080/realms/nutshell/.well-known/openid-configuration
# MINT_AUTH_OICD_CLIENT_ID=cashu-client
# Number of authentication attempts allowed per minute per user
# MINT_AUTH_RATE_LIMIT_PER_MINUTE=5
# Maximum number of blind auth tokens per authentication request
# MINT_AUTH_MAX_BLIND_TOKENS=100

View File

@@ -42,6 +42,21 @@ jobs:
poetry-version: ${{ matrix.poetry-version }}
mint-database: ${{ matrix.mint-database }}
tests_keycloak_auth:
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
python-version: ["3.10"]
poetry-version: ["1.8.5"]
mint-database: ["./test_data/test_mint", "postgres://cashu:cashu@localhost:5432/cashu"]
uses: ./.github/workflows/tests_keycloak_auth.yml
with:
os: ${{ matrix.os }}
python-version: ${{ matrix.python-version }}
poetry-version: ${{ matrix.poetry-version }}
mint-database: ${{ matrix.mint-database }}
regtest:
uses: ./.github/workflows/regtest.yml
strategy:

View File

@@ -0,0 +1,77 @@
name: tests_keycloak
on:
workflow_call:
inputs:
python-version:
default: "3.10.4"
type: string
poetry-version:
default: "1.8.5"
type: string
mint-database:
default: ""
type: string
os:
default: "ubuntu-latest"
type: string
jobs:
poetry:
name: Auth tests with Keycloak (db ${{ inputs.mint-database }})
runs-on: ${{ inputs.os }}
steps:
- name: Checkout repository
uses: actions/checkout@v2
- name: Prepare environment
uses: ./.github/actions/prepare
with:
python-version: ${{ inputs.python-version }}
poetry-version: ${{ inputs.poetry-version }}
- name: Start PostgreSQL service
if: contains(inputs.mint-database, 'postgres')
run: |
docker run -d --name postgres \
-e POSTGRES_USER=cashu \
-e POSTGRES_PASSWORD=cashu \
-e POSTGRES_DB=cashu \
-p 5432:5432 postgres:16.4
until docker exec postgres pg_isready; do sleep 1; done
- name: Prepare environment
uses: ./.github/actions/prepare
with:
python-version: ${{ inputs.python-version }}
poetry-version: ${{ inputs.poetry-version }}
- name: Start Keycloak with Backup
run: |
docker compose -f tests/keycloak_data/docker-compose-restore.yml up -d
until docker logs $(docker ps -q --filter "ancestor=quay.io/keycloak/keycloak:25.0.6") | grep "Keycloak 25.0.6 on JVM (powered by Quarkus 3.8.5) started"; do sleep 1; done
- name: Verify Keycloak Import
run: |
docker logs $(docker ps -q --filter "ancestor=quay.io/keycloak/keycloak:25.0.6") | grep "Imported"
- name: Run tests
env:
MINT_BACKEND_BOLT11_SAT: FakeWallet
WALLET_NAME: test_wallet
MINT_HOST: localhost
MINT_PORT: 3337
MINT_TEST_DATABASE: ${{ inputs.mint-database }}
TOR: false
MINT_REQUIRE_AUTH: TRUE
MINT_AUTH_OICD_DISCOVERY_URL: http://localhost:8080/realms/nutshell/.well-known/openid-configuration
MINT_AUTH_OICD_CLIENT_ID: cashu-client
run: |
poetry run pytest tests/test_wallet_auth.py -v --cov=mint --cov-report=xml
- name: Stop and clean up Docker Compose
run: |
docker compose -f tests/keycloak_data/docker-compose-restore.yml down
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3

View File

@@ -6,7 +6,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from sqlite3 import Row
from typing import Any, Dict, List, Optional, Union
from typing import Any, ClassVar, Dict, List, Optional, Union
import cbor2
from loguru import logger
@@ -19,7 +19,7 @@ from .crypto.aes import AESCipher
from .crypto.b_dhke import hash_to_curve
from .crypto.keys import (
derive_keys,
derive_keys_sha256,
derive_keys_deprecated_pre_0_15,
derive_keyset_id,
derive_keyset_id_deprecated,
derive_pubkeys,
@@ -173,6 +173,9 @@ class Proof(BaseModel):
return return_dict
def to_base64(self):
return base64.b64encode(cbor2.dumps(self.to_dict(include_dleq=True))).decode()
def to_dict_no_dleq(self):
# dictionary without the fields that don't need to be send to Carol
return dict(id=self.id, amount=self.amount, secret=self.secret, C=self.C)
@@ -541,6 +544,7 @@ class Unit(Enum):
usd = 2
eur = 3
btc = 4
auth = 999
def str(self, amount: int) -> str:
if self == Unit.sat:
@@ -553,6 +557,8 @@ class Unit(Enum):
return f"{amount/100:.2f} EUR"
elif self == Unit.btc:
return f"{amount/1e8:.8f} BTC"
elif self == Unit.auth:
return f"{amount} AUTH"
else:
raise Exception("Invalid unit")
@@ -724,6 +730,7 @@ class MintKeyset:
valid_to: Optional[str] = None
first_seen: Optional[str] = None
version: Optional[str] = None
amounts: List[int]
duplicate_keyset_id: Optional[str] = None # BACKWARDS COMPATIBILITY < 0.15.0
@@ -734,6 +741,7 @@ class MintKeyset:
seed: Optional[str] = None,
encrypted_seed: Optional[str] = None,
seed_encryption_method: Optional[str] = None,
amounts: Optional[List[int]] = None,
valid_from: Optional[str] = None,
valid_to: Optional[str] = None,
first_seen: Optional[str] = None,
@@ -762,6 +770,12 @@ class MintKeyset:
assert self.seed, "seed not set"
if amounts:
self.amounts = amounts
else:
# use 2^n amounts by default
self.amounts = [2**i for i in range(settings.max_order)]
self.id = id
self.valid_from = valid_from
self.valid_to = valid_to
@@ -805,6 +819,24 @@ class MintKeyset:
logger.trace(f"Loaded keyset id: {self.id} ({self.unit.name})")
@classmethod
def from_row(cls, row: Row):
return cls(
id=row["id"],
derivation_path=row["derivation_path"],
seed=row["seed"],
encrypted_seed=row["encrypted_seed"],
seed_encryption_method=row["seed_encryption_method"],
valid_from=row["valid_from"],
valid_to=row["valid_to"],
first_seen=row["first_seen"],
active=row["active"],
unit=row["unit"],
version=row["version"],
input_fee_ppk=row["input_fee_ppk"],
amounts=json.loads(row["amounts"]),
)
@property
def public_keys_hex(self) -> Dict[int, str]:
assert self.public_keys, "public keys not set"
@@ -830,23 +862,27 @@ class MintKeyset:
self.private_keys = derive_keys_backwards_compatible_insecure_pre_0_12(
self.seed, self.derivation_path
)
self.public_keys = derive_pubkeys(self.private_keys) # type: ignore
self.public_keys = derive_pubkeys(self.private_keys, self.amounts) # type: ignore
logger.trace(
f"WARNING: Using weak key derivation for keyset {self.id} (backwards"
" compatibility < 0.12)"
)
self.id = id_in_db or derive_keyset_id_deprecated(self.public_keys) # type: ignore
elif self.version_tuple < (0, 15):
self.private_keys = derive_keys_sha256(self.seed, self.derivation_path)
self.private_keys = derive_keys_deprecated_pre_0_15(
self.seed, self.amounts, self.derivation_path
)
logger.trace(
f"WARNING: Using non-bip32 derivation for keyset {self.id} (backwards"
" compatibility < 0.15)"
)
self.public_keys = derive_pubkeys(self.private_keys) # type: ignore
self.public_keys = derive_pubkeys(self.private_keys, self.amounts) # type: ignore
self.id = id_in_db or derive_keyset_id_deprecated(self.public_keys) # type: ignore
else:
self.private_keys = derive_keys(self.seed, self.derivation_path)
self.public_keys = derive_pubkeys(self.private_keys) # type: ignore
self.private_keys = derive_keys(
self.seed, self.derivation_path, self.amounts
)
self.public_keys = derive_pubkeys(self.private_keys, self.amounts) # type: ignore
self.id = id_in_db or derive_keyset_id(self.public_keys) # type: ignore
@@ -1254,3 +1290,48 @@ class TokenV4(Token):
t=[TokenV4Token(**t) for t in token_dict["t"]],
d=token_dict.get("d", None),
)
class AuthProof(BaseModel):
"""
Blind authentication token
"""
id: str
secret: str # secret
C: str # signature
amount: int = 1 # default amount
prefix: ClassVar[str] = "authA"
@classmethod
def from_proof(cls, proof: Proof):
return cls(id=proof.id, secret=proof.secret, C=proof.C)
def to_base64(self):
serialize_dict = self.dict()
serialize_dict.pop("amount", None)
return (
self.prefix + base64.b64encode(json.dumps(serialize_dict).encode()).decode()
)
@classmethod
def from_base64(cls, base64_str: str):
assert base64_str.startswith(cls.prefix), Exception(
f"Token prefix not valid. Expected {cls.prefix}."
)
base64_str = base64_str[len(cls.prefix) :]
return cls.parse_obj(json.loads(base64.b64decode(base64_str).decode()))
def to_proof(self):
return Proof(id=self.id, secret=self.secret, C=self.C, amount=self.amount)
class WalletMint(BaseModel):
url: str
info: str
updated: Optional[str] = None
access_token: Optional[str] = None
refresh_token: Optional[str] = None
username: Optional[str] = None
password: Optional[str] = None

View File

@@ -1,54 +1,56 @@
import base64
import hashlib
import random
from typing import Dict
from typing import Dict, List
from bip32 import BIP32
from ..settings import settings
from .secp import PrivateKey, PublicKey
def derive_keys(mnemonic: str, derivation_path: str):
def derive_keys(mnemonic: str, derivation_path: str, amounts: List[int]):
"""
Deterministic derivation of keys for 2^n values.
"""
bip32 = BIP32.from_seed(mnemonic.encode())
orders_str = [f"/{i}'" for i in range(settings.max_order)]
orders_str = [f"/{a}'" for a in range(len(amounts))]
return {
2**i: PrivateKey(
a: PrivateKey(
bip32.get_privkey_from_path(derivation_path + orders_str[i]),
raw=True,
)
for i in range(settings.max_order)
for i, a in enumerate(amounts)
}
def derive_keys_sha256(seed: str, derivation_path: str = ""):
def derive_keys_deprecated_pre_0_15(
seed: str, amounts: List[int], derivation_path: str = ""
):
"""
Deterministic derivation of keys for 2^n values.
TODO: Implement BIP32.
"""
return {
2**i: PrivateKey(
a: PrivateKey(
hashlib.sha256((seed + derivation_path + str(i)).encode("utf-8")).digest()[
:32
],
raw=True,
)
for i in range(settings.max_order)
for i, a in enumerate(amounts)
}
def derive_pubkey(seed: str):
return PrivateKey(
def derive_pubkey(seed: str) -> PublicKey:
pubkey = PrivateKey(
hashlib.sha256((seed).encode("utf-8")).digest()[:32],
raw=True,
).pubkey
assert pubkey
return pubkey
def derive_pubkeys(keys: Dict[int, PrivateKey]):
return {amt: keys[amt].pubkey for amt in [2**i for i in range(settings.max_order)]}
def derive_pubkeys(keys: Dict[int, PrivateKey], amounts: List[int]):
return {amt: keys[amt].pubkey for amt in amounts}
def derive_keyset_id(keys: Dict[int, PublicKey]):

View File

@@ -339,11 +339,21 @@ class Database(Compat):
raise Exception("Timestamp is None")
return timestamp
def to_timestamp(self, timestamp_str: str) -> Union[str, datetime.datetime]:
if not timestamp_str:
timestamp_str = self.timestamp_now_str()
def to_timestamp(
self, timestamp: Union[str, datetime.datetime]
) -> Union[str, datetime.datetime]:
if not timestamp:
timestamp = self.timestamp_now_str()
if self.type in {POSTGRES, COCKROACH}:
return datetime.datetime.strptime(timestamp_str, "%Y-%m-%d %H:%M:%S")
# return datetime.datetime
if isinstance(timestamp, datetime.datetime):
return timestamp
elif isinstance(timestamp, str):
return datetime.datetime.strptime(timestamp, "%Y-%m-%d %H:%M:%S")
elif self.type == SQLITE:
return timestamp_str
# return str
if isinstance(timestamp, datetime.datetime):
return timestamp.strftime("%Y-%m-%d %H:%M:%S")
elif isinstance(timestamp, str):
return timestamp
return "<nothing>"

View File

@@ -18,6 +18,7 @@ class NotAllowedError(CashuError):
def __init__(self, detail: Optional[str] = None, code: Optional[int] = None):
super().__init__(detail or self.detail, code=code or self.code)
class OutputsAlreadySignedError(CashuError):
detail = "outputs have already been signed before."
code = 10002
@@ -25,6 +26,7 @@ class OutputsAlreadySignedError(CashuError):
def __init__(self, detail: Optional[str] = None, code: Optional[int] = None):
super().__init__(detail or self.detail, code=code or self.code)
class InvalidProofsError(CashuError):
detail = "proofs could not be verified"
code = 10003
@@ -32,6 +34,7 @@ class InvalidProofsError(CashuError):
def __init__(self, detail: Optional[str] = None, code: Optional[int] = None):
super().__init__(detail or self.detail, code=code or self.code)
class TransactionError(CashuError):
detail = "transaction error"
code = 11000
@@ -76,12 +79,14 @@ class TransactionUnitError(TransactionError):
def __init__(self, detail):
super().__init__(detail, code=self.code)
class TransactionAmountExceedsLimitError(TransactionError):
code = 11006
def __init__(self, detail):
super().__init__(detail, code=self.code)
class KeysetError(CashuError):
detail = "keyset error"
code = 12000
@@ -113,7 +118,7 @@ class QuoteNotPaidError(CashuError):
code = 20001
def __init__(self):
super().__init__(self.detail, code=2001)
super().__init__(self.detail, code=self.code)
class QuoteSignatureInvalidError(CashuError):
@@ -121,7 +126,7 @@ class QuoteSignatureInvalidError(CashuError):
code = 20008
def __init__(self):
super().__init__(self.detail, code=20008)
super().__init__(self.detail, code=self.code)
class QuoteRequiresPubkeyError(CashuError):
@@ -129,4 +134,52 @@ class QuoteRequiresPubkeyError(CashuError):
code = 20009
def __init__(self):
super().__init__(self.detail, code=20009)
super().__init__(self.detail, code=self.code)
class ClearAuthRequiredError(CashuError):
detail = "Endpoint requires clear auth"
code = 80001
def __init__(self):
super().__init__(self.detail, code=self.code)
class ClearAuthFailedError(CashuError):
detail = "Clear authentication failed"
code = 80002
def __init__(self):
super().__init__(self.detail, code=self.code)
class BlindAuthRequiredError(CashuError):
detail = "Endpoint requires blind auth"
code = 81001
def __init__(self):
super().__init__(self.detail, code=self.code)
class BlindAuthFailedError(CashuError):
detail = "Blind authentication failed"
code = 81002
def __init__(self):
super().__init__(self.detail, code=self.code)
class BlindAuthAmountExceededError(CashuError):
detail = "Maximum blind auth amount exceeded"
code = 81003
def __init__(self, detail: Optional[str] = None):
super().__init__(detail or self.detail, code=self.code)
class BlindAuthRateLimitExceededError(CashuError):
detail = "Blind auth token mint rate limit exceeded"
code = 81004
def __init__(self):
super().__init__(self.detail, code=self.code)

125
cashu/core/mint_info.py Normal file
View File

@@ -0,0 +1,125 @@
import json
import re
from typing import Any, Dict, List, Optional
from pydantic import BaseModel
from .base import Method, Unit
from .models import MintInfoContact, MintInfoProtectedEndpoint, Nut15MppSupport
from .nuts.nuts import BLIND_AUTH_NUT, CLEAR_AUTH_NUT, MPP_NUT, WEBSOCKETS_NUT
class MintInfo(BaseModel):
name: Optional[str]
pubkey: Optional[str]
version: Optional[str]
description: Optional[str]
description_long: Optional[str]
contact: Optional[List[MintInfoContact]]
motd: Optional[str]
icon_url: Optional[str]
time: Optional[int]
nuts: Dict[int, Any]
def __str__(self):
return f"{self.name} ({self.description})"
@classmethod
def from_json_str(cls, json_str: str):
return cls.parse_obj(json.loads(json_str))
def supports_nut(self, nut: int) -> bool:
if self.nuts is None:
return False
return nut in self.nuts
def supports_mpp(self, method: str, unit: Unit) -> bool:
if not self.nuts:
return False
nut_15 = self.nuts.get(MPP_NUT)
if not nut_15 or not self.supports_nut(MPP_NUT) or not nut_15.get("methods"):
return False
for entry in nut_15["methods"]:
entry_obj = Nut15MppSupport.parse_obj(entry)
if entry_obj.method == method and entry_obj.unit == unit.name:
return True
return False
def supports_websocket_mint_quote(self, method: Method, unit: Unit) -> bool:
if not self.nuts or not self.supports_nut(WEBSOCKETS_NUT):
return False
websocket_settings = self.nuts[WEBSOCKETS_NUT]
if not websocket_settings or "supported" not in websocket_settings:
return False
websocket_supported = websocket_settings["supported"]
for entry in websocket_supported:
if entry["method"] == method.name and entry["unit"] == unit.name:
if "bolt11_mint_quote" in entry["commands"]:
return True
return False
def requires_clear_auth(self) -> bool:
return self.supports_nut(CLEAR_AUTH_NUT)
def oidc_discovery_url(self) -> str:
if not self.requires_clear_auth():
raise Exception(
"Could not get OIDC discovery URL. Mint info does not support clear auth."
)
return self.nuts[CLEAR_AUTH_NUT]["openid_discovery"]
def oidc_client_id(self) -> str:
if not self.requires_clear_auth():
raise Exception(
"Could not get client_id. Mint info does not support clear auth."
)
return self.nuts[CLEAR_AUTH_NUT]["client_id"]
def required_clear_auth_endpoints(self) -> List[MintInfoProtectedEndpoint]:
if not self.requires_clear_auth():
return []
return [
MintInfoProtectedEndpoint.parse_obj(e)
for e in self.nuts[CLEAR_AUTH_NUT]["protected_endpoints"]
]
def requires_clear_auth_path(self, method: str, path: str) -> bool:
if not self.requires_clear_auth():
return False
path = "/" + path if not path.startswith("/") else path
for endpoint in self.required_clear_auth_endpoints():
if method == endpoint.method and re.match(endpoint.path, path):
return True
return False
def requires_blind_auth(self) -> bool:
return self.supports_nut(BLIND_AUTH_NUT)
@property
def bat_max_mint(self) -> int:
if not self.requires_blind_auth():
raise Exception(
"Could not get max mint. Mint info does not support blind auth."
)
if not self.nuts[BLIND_AUTH_NUT].get("bat_max_mint"):
raise Exception("Could not get max mint. bat_max_mint not set.")
return self.nuts[BLIND_AUTH_NUT]["bat_max_mint"]
def required_blind_auth_paths(self) -> List[MintInfoProtectedEndpoint]:
if not self.requires_blind_auth():
return []
return [
MintInfoProtectedEndpoint.parse_obj(e)
for e in self.nuts[BLIND_AUTH_NUT]["protected_endpoints"]
]
def requires_blind_auth_path(self, method: str, path: str) -> bool:
if not self.requires_blind_auth():
return False
path = "/" + path if not path.startswith("/") else path
for endpoint in self.required_blind_auth_paths():
if method == endpoint.method and re.match(endpoint.path, path):
return True
return False

View File

@@ -38,6 +38,11 @@ class MintInfoContact(BaseModel):
info: str
class MintInfoProtectedEndpoint(BaseModel):
method: str
path: str
class GetInfoResponse(BaseModel):
name: Optional[str] = None
pubkey: Optional[str] = None
@@ -57,7 +62,7 @@ class GetInfoResponse(BaseModel):
# BEGIN DEPRECATED: NUT-06 contact field change
# NUT-06 PR: https://github.com/cashubtc/nuts/pull/117
@root_validator(pre=True)
def preprocess_deprecated_contact_field(cls, values):
def preprocess_deprecated_contact_field(cls, values: dict):
if "contact" in values and values["contact"]:
if isinstance(values["contact"][0], list):
values["contact"] = [
@@ -346,3 +351,16 @@ class PostRestoreResponse(BaseModel):
def __init__(self, **data):
super().__init__(**data)
self.promises = self.signatures
# ------- API: BLIND AUTH -------
class PostAuthBlindMintRequest(BaseModel):
outputs: List[BlindedMessage] = Field(
...,
max_items=settings.mint_max_request_length,
description="Blinded messages for creating blind auth tokens.",
)
class PostAuthBlindMintResponse(BaseModel):
signatures: List[BlindedSignature] = []

View File

@@ -14,3 +14,5 @@ MPP_NUT = 15
WEBSOCKETS_NUT = 17
CACHE_NUT = 19
MINT_QUOTE_SIGNATURE_NUT = 20
CLEAR_AUTH_NUT = 21
BLIND_AUTH_NUT = 22

View File

@@ -68,6 +68,8 @@ class MintSettings(CashuSettings):
class MintDeprecationFlags(MintSettings):
mint_inactivate_base64_keysets: bool = Field(default=False)
auth_database: str = Field(default="data/mint")
class MintBackends(MintSettings):
mint_lightning_backend: str = Field(default="") # deprecated
@@ -231,6 +233,27 @@ class CoreLightningRestFundingSource(MintSettings):
mint_corelightning_rest_cert: Optional[str] = Field(default=None)
class AuthSettings(MintSettings):
mint_require_auth: bool = Field(default=False)
mint_auth_oicd_discovery_url: Optional[str] = Field(default=None)
mint_auth_oicd_client_id: str = Field(default="cashu-client")
mint_auth_rate_limit_per_minute: int = Field(
default=5,
title="Auth rate limit per minute",
description="Number of requests a user can authenticate per minute.",
)
mint_auth_max_blind_tokens: int = Field(default=100, gt=0)
mint_require_clear_auth_paths: List[List[str]] = [
["POST", "/v1/auth/blind/mint"],
]
mint_require_blind_auth_paths: List[List[str]] = [
["POST", "/v1/swap"],
["POST", "/v1/mint/quote/bolt11"],
["POST", "/v1/mint/bolt11"],
["POST", "/v1/melt/bolt11"],
]
class MintRedisCache(MintSettings):
mint_redis_cache_enabled: bool = Field(default=False)
mint_redis_cache_url: Optional[str] = Field(default=None)
@@ -246,6 +269,7 @@ class Settings(
FakeWalletSettings,
MintLimits,
MintBackends,
AuthSettings,
MintRedisCache,
MintDeprecationFlags,
MintSettings,

View File

@@ -13,10 +13,10 @@ from starlette.requests import Request
from ..core.errors import CashuError
from ..core.logging import configure_logger
from ..core.settings import settings
from .auth.router import auth_router
from .router import redis, router
from .router_deprecated import router_deprecated
from .startup import shutdown_mint as shutdown_mint_init
from .startup import start_mint_init
from .startup import shutdown_mint, start_auth, start_mint
if settings.debug_profiling:
pass
@@ -29,7 +29,9 @@ from .middleware import add_middlewares, request_validation_exception_handler
@asynccontextmanager
async def lifespan(_: FastAPI) -> AsyncIterator[None]:
await start_mint_init()
await start_mint()
if settings.mint_require_auth:
await start_auth()
try:
yield
except asyncio.CancelledError:
@@ -38,7 +40,7 @@ async def lifespan(_: FastAPI) -> AsyncIterator[None]:
finally:
try:
await redis.disconnect()
await shutdown_mint_init()
await shutdown_mint()
except asyncio.CancelledError:
logger.info("CancelledError during shutdown, shutting down forcefully")
@@ -110,3 +112,6 @@ if settings.debug_mint_only_deprecated:
else:
app.include_router(router=router, tags=["Mint"])
app.include_router(router=router_deprecated, tags=["Deprecated"], deprecated=True)
if settings.mint_require_auth:
app.include_router(auth_router, tags=["Auth"])

13
cashu/mint/auth/base.py Normal file
View File

@@ -0,0 +1,13 @@
import datetime
from typing import Optional
class User:
id: str
last_access: Optional[datetime.datetime]
def __init__(self, id: str, last_access: Optional[datetime.datetime] = None):
self.id = id
if isinstance(last_access, int):
last_access = datetime.datetime.fromtimestamp(last_access)
self.last_access = last_access

669
cashu/mint/auth/crud.py Normal file
View File

@@ -0,0 +1,669 @@
import json
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from ...core.base import (
BlindedSignature,
MeltQuote,
MintKeyset,
MintQuote,
Proof,
)
from ...core.db import (
Connection,
Database,
)
from .base import User
class AuthLedgerCrud(ABC):
"""
Database interface for Nutshell auth ledger.
"""
@abstractmethod
async def create_user(
self,
*,
db: Database,
user: User,
conn: Optional[Connection] = None,
) -> None: ...
@abstractmethod
async def get_user(
self,
*,
db: Database,
user_id: str,
conn: Optional[Connection] = None,
) -> Optional[User]: ...
async def update_user(
self,
*,
db: Database,
user_id: str,
conn: Optional[Connection] = None,
) -> None: ...
@abstractmethod
async def get_keyset(
self,
*,
db: Database,
id: str = "",
derivation_path: str = "",
seed: str = "",
conn: Optional[Connection] = None,
) -> List[MintKeyset]: ...
@abstractmethod
async def get_proofs_used(
self,
*,
Ys: List[str],
db: Database,
conn: Optional[Connection] = None,
) -> List[Proof]: ...
@abstractmethod
async def invalidate_proof(
self,
*,
db: Database,
proof: Proof,
quote_id: Optional[str] = None,
conn: Optional[Connection] = None,
) -> None: ...
@abstractmethod
async def get_proofs_pending(
self,
*,
Ys: List[str],
db: Database,
conn: Optional[Connection] = None,
) -> List[Proof]: ...
@abstractmethod
async def set_proof_pending(
self,
*,
db: Database,
proof: Proof,
quote_id: Optional[str] = None,
conn: Optional[Connection] = None,
) -> None: ...
@abstractmethod
async def unset_proof_pending(
self,
*,
proof: Proof,
db: Database,
conn: Optional[Connection] = None,
) -> None: ...
@abstractmethod
async def store_keyset(
self,
*,
db: Database,
keyset: MintKeyset,
conn: Optional[Connection] = None,
) -> None: ...
@abstractmethod
async def store_promise(
self,
*,
db: Database,
amount: int,
b_: str,
c_: str,
id: str,
e: str = "",
s: str = "",
conn: Optional[Connection] = None,
) -> None: ...
@abstractmethod
async def get_promise(
self,
*,
db: Database,
b_: str,
conn: Optional[Connection] = None,
) -> Optional[BlindedSignature]: ...
@abstractmethod
async def get_promises(
self,
*,
db: Database,
b_s: List[str],
conn: Optional[Connection] = None,
) -> List[BlindedSignature]: ...
class AuthLedgerCrudSqlite(AuthLedgerCrud):
"""Implementation of AuthLedgerCrud for sqlite.
Args:
AuthLedgerCrud (ABC): Abstract base class for AuthLedgerCrud.
"""
async def create_user(
self,
*,
db: Database,
user: User,
conn: Optional[Connection] = None,
) -> None:
await (conn or db).execute(
f"""
INSERT INTO {db.table_with_schema('users')}
(id)
VALUES (:id)
""",
{"id": user.id},
)
async def get_user(
self,
*,
db: Database,
user_id: str,
conn: Optional[Connection] = None,
) -> Optional[User]:
row = await (conn or db).fetchone(
f"""
SELECT * from {db.table_with_schema('users')}
WHERE id = :user_id
""",
{"user_id": user_id},
)
return User(**row) if row else None
async def update_user(
self,
*,
db: Database,
user_id: str,
conn: Optional[Connection] = None,
) -> None:
await (conn or db).execute(
f"""
UPDATE {db.table_with_schema('users')}
SET last_access = :last_access
WHERE id = :user_id
""",
{
"last_access": db.to_timestamp(db.timestamp_now_str()),
"user_id": user_id,
},
)
async def store_promise(
self,
*,
db: Database,
amount: int,
b_: str,
c_: str,
id: str,
e: str = "",
s: str = "",
conn: Optional[Connection] = None,
) -> None:
await (conn or db).execute(
f"""
INSERT INTO {db.table_with_schema('promises')}
(amount, b_, c_, dleq_e, dleq_s, id, created)
VALUES (:amount, :b_, :c_, :dleq_e, :dleq_s, :id, :created)
""",
{
"amount": amount,
"b_": b_,
"c_": c_,
"dleq_e": e,
"dleq_s": s,
"id": id,
"created": db.to_timestamp(db.timestamp_now_str()),
},
)
async def get_promise(
self,
*,
db: Database,
b_: str,
conn: Optional[Connection] = None,
) -> Optional[BlindedSignature]:
row = await (conn or db).fetchone(
f"""
SELECT * from {db.table_with_schema('promises')}
WHERE b_ = :b_
""",
{"b_": str(b_)},
)
return BlindedSignature.from_row(row) if row else None
async def get_promises(
self,
*,
db: Database,
b_s: List[str],
conn: Optional[Connection] = None,
) -> List[BlindedSignature]:
rows = await (conn or db).fetchall(
f"""
SELECT * from {db.table_with_schema('promises')}
WHERE b_ IN ({','.join([':b_' + str(i) for i in range(len(b_s))])})
""",
{f"b_{i}": b_s[i] for i in range(len(b_s))},
)
return [BlindedSignature.from_row(r) for r in rows] if rows else []
async def invalidate_proof(
self,
*,
db: Database,
proof: Proof,
quote_id: Optional[str] = None,
conn: Optional[Connection] = None,
) -> None:
await (conn or db).execute(
f"""
INSERT INTO {db.table_with_schema('proofs_used')}
(amount, c, secret, y, id, witness, created, melt_quote)
VALUES (:amount, :c, :secret, :y, :id, :witness, :created, :melt_quote)
""",
{
"amount": proof.amount,
"c": proof.C,
"secret": proof.secret,
"y": proof.Y,
"id": proof.id,
"witness": proof.witness,
"created": db.to_timestamp(db.timestamp_now_str()),
"melt_quote": quote_id,
},
)
async def get_all_melt_quotes_from_pending_proofs(
self,
*,
db: Database,
conn: Optional[Connection] = None,
) -> List[MeltQuote]:
rows = await (conn or db).fetchall(
f"""
SELECT * from {db.table_with_schema('melt_quotes')} WHERE quote in (SELECT DISTINCT melt_quote FROM {db.table_with_schema('proofs_pending')})
"""
)
return [MeltQuote.from_row(r) for r in rows]
async def get_pending_proofs_for_quote(
self,
*,
quote_id: str,
db: Database,
conn: Optional[Connection] = None,
) -> List[Proof]:
rows = await (conn or db).fetchall(
f"""
SELECT * from {db.table_with_schema('proofs_pending')}
WHERE melt_quote = :quote_id
""",
{"quote_id": quote_id},
)
return [Proof(**r) for r in rows]
async def get_proofs_pending(
self,
*,
Ys: List[str],
db: Database,
conn: Optional[Connection] = None,
) -> List[Proof]:
query = f"""
SELECT * from {db.table_with_schema('proofs_pending')}
WHERE y IN ({','.join([':y_' + str(i) for i in range(len(Ys))])})
"""
values = {f"y_{i}": Ys[i] for i in range(len(Ys))}
rows = await (conn or db).fetchall(query, values)
return [Proof(**r) for r in rows]
async def set_proof_pending(
self,
*,
db: Database,
proof: Proof,
quote_id: Optional[str] = None,
conn: Optional[Connection] = None,
) -> None:
await (conn or db).execute(
f"""
INSERT INTO {db.table_with_schema('proofs_pending')}
(amount, c, secret, y, id, witness, created, melt_quote)
VALUES (:amount, :c, :secret, :y, :id, :witness, :created, :melt_quote)
""",
{
"amount": proof.amount,
"c": proof.C,
"secret": proof.secret,
"y": proof.Y,
"id": proof.id,
"witness": proof.witness,
"created": db.to_timestamp(db.timestamp_now_str()),
"melt_quote": quote_id,
},
)
async def unset_proof_pending(
self,
*,
proof: Proof,
db: Database,
conn: Optional[Connection] = None,
) -> None:
await (conn or db).execute(
f"""
DELETE FROM {db.table_with_schema('proofs_pending')}
WHERE secret = :secret
""",
{"secret": proof.secret},
)
async def store_mint_quote(
self,
*,
quote: MintQuote,
db: Database,
conn: Optional[Connection] = None,
) -> None:
await (conn or db).execute(
f"""
INSERT INTO {db.table_with_schema('mint_quotes')}
(quote, method, request, checking_id, unit, amount, issued, paid, state, created_time, paid_time)
VALUES (:quote, :method, :request, :checking_id, :unit, :amount, :issued, :paid, :state, :created_time, :paid_time)
""",
{
"quote": quote.quote,
"method": quote.method,
"request": quote.request,
"checking_id": quote.checking_id,
"unit": quote.unit,
"amount": quote.amount,
"issued": quote.issued,
"paid": quote.paid,
"state": quote.state.name,
"created_time": db.to_timestamp(
db.timestamp_from_seconds(quote.created_time) or ""
),
"paid_time": db.to_timestamp(
db.timestamp_from_seconds(quote.paid_time) or ""
),
},
)
async def get_mint_quote(
self,
*,
quote_id: Optional[str] = None,
checking_id: Optional[str] = None,
request: Optional[str] = None,
db: Database,
conn: Optional[Connection] = None,
) -> Optional[MintQuote]:
clauses = []
values: Dict[str, Any] = {}
if quote_id:
clauses.append("quote = :quote_id")
values["quote_id"] = quote_id
if checking_id:
clauses.append("checking_id = :checking_id")
values["checking_id"] = checking_id
if request:
clauses.append("request = :request")
values["request"] = request
if not any(clauses):
raise ValueError("No search criteria")
where = f"WHERE {' AND '.join(clauses)}"
row = await (conn or db).fetchone(
f"""
SELECT * from {db.table_with_schema('mint_quotes')}
{where}
""",
values,
)
if row is None:
return None
return MintQuote.from_row(row) if row else None
async def get_mint_quote_by_request(
self,
*,
request: str,
db: Database,
conn: Optional[Connection] = None,
) -> Optional[MintQuote]:
row = await (conn or db).fetchone(
f"""
SELECT * from {db.table_with_schema('mint_quotes')}
WHERE request = :request
""",
{"request": request},
)
return MintQuote.from_row(row) if row else None
async def update_mint_quote(
self,
*,
quote: MintQuote,
db: Database,
conn: Optional[Connection] = None,
) -> None:
await (conn or db).execute(
f"UPDATE {db.table_with_schema('mint_quotes')} SET issued = :issued, paid = :paid, state = :state, paid_time = :paid_time WHERE quote = :quote",
{
"issued": quote.issued,
"paid": quote.paid,
"state": quote.state.name,
"paid_time": db.to_timestamp(
db.timestamp_from_seconds(quote.paid_time) or ""
),
"quote": quote.quote,
},
)
async def store_melt_quote(
self,
*,
quote: MeltQuote,
db: Database,
conn: Optional[Connection] = None,
) -> None:
await (conn or db).execute(
f"""
INSERT INTO {db.table_with_schema('melt_quotes')}
(quote, method, request, checking_id, unit, amount, fee_reserve, paid, state, created_time, paid_time, fee_paid, proof, change, expiry)
VALUES (:quote, :method, :request, :checking_id, :unit, :amount, :fee_reserve, :paid, :state, :created_time, :paid_time, :fee_paid, :proof, :change, :expiry)
""",
{
"quote": quote.quote,
"method": quote.method,
"request": quote.request,
"checking_id": quote.checking_id,
"unit": quote.unit,
"amount": quote.amount,
"fee_reserve": quote.fee_reserve or 0,
"paid": quote.paid,
"state": quote.state.name,
"created_time": db.to_timestamp(
db.timestamp_from_seconds(quote.created_time) or ""
),
"paid_time": db.to_timestamp(
db.timestamp_from_seconds(quote.paid_time) or ""
),
"fee_paid": quote.fee_paid,
"proof": quote.payment_preimage,
"change": json.dumps(quote.change) if quote.change else None,
"expiry": db.to_timestamp(
db.timestamp_from_seconds(quote.expiry) or ""
),
},
)
async def get_melt_quote(
self,
*,
quote_id: Optional[str] = None,
checking_id: Optional[str] = None,
request: Optional[str] = None,
db: Database,
conn: Optional[Connection] = None,
) -> Optional[MeltQuote]:
clauses = []
values: Dict[str, Any] = {}
if quote_id:
clauses.append("quote = :quote_id")
values["quote_id"] = quote_id
if checking_id:
clauses.append("checking_id = :checking_id")
values["checking_id"] = checking_id
if request:
clauses.append("request = :request")
values["request"] = request
if not any(clauses):
raise ValueError("No search criteria")
where = f"WHERE {' AND '.join(clauses)}"
row = await (conn or db).fetchone(
f"""
SELECT * from {db.table_with_schema('melt_quotes')}
{where}
""",
values,
)
if row is None:
return None
return MeltQuote.from_row(row) if row else None
async def update_melt_quote(
self,
*,
quote: MeltQuote,
db: Database,
conn: Optional[Connection] = None,
) -> None:
await (conn or db).execute(
f"""
UPDATE {db.table_with_schema('melt_quotes')} SET paid = :paid, state = :state, fee_paid = :fee_paid, paid_time = :paid_time, proof = :proof, change = :change WHERE quote = :quote
""",
{
"paid": quote.paid,
"state": quote.state.name,
"fee_paid": quote.fee_paid,
"paid_time": db.to_timestamp(
db.timestamp_from_seconds(quote.paid_time) or ""
),
"proof": quote.payment_preimage,
"change": (
json.dumps([s.dict() for s in quote.change])
if quote.change
else None
),
"quote": quote.quote,
},
)
async def store_keyset(
self,
*,
db: Database,
keyset: MintKeyset,
conn: Optional[Connection] = None,
) -> None:
await (conn or db).execute(
f"""
INSERT INTO {db.table_with_schema('keysets')}
(id, seed, encrypted_seed, seed_encryption_method, derivation_path, valid_from, valid_to, first_seen, active, version, unit, input_fee_ppk)
VALUES (:id, :seed, :encrypted_seed, :seed_encryption_method, :derivation_path, :valid_from, :valid_to, :first_seen, :active, :version, :unit, :input_fee_ppk)
""",
{
"id": keyset.id,
"seed": keyset.seed,
"encrypted_seed": keyset.encrypted_seed,
"seed_encryption_method": keyset.seed_encryption_method,
"derivation_path": keyset.derivation_path,
"valid_from": db.to_timestamp(
keyset.valid_from or db.timestamp_now_str()
),
"valid_to": db.to_timestamp(keyset.valid_to or db.timestamp_now_str()),
"first_seen": db.to_timestamp(
keyset.first_seen or db.timestamp_now_str()
),
"active": True,
"version": keyset.version,
"unit": keyset.unit.name,
"input_fee_ppk": keyset.input_fee_ppk,
},
)
async def get_keyset(
self,
*,
db: Database,
id: Optional[str] = None,
derivation_path: Optional[str] = None,
seed: Optional[str] = None,
unit: Optional[str] = None,
active: Optional[bool] = None,
conn: Optional[Connection] = None,
) -> List[MintKeyset]:
clauses = []
values: Dict = {}
if active is not None:
clauses.append("active = :active")
values["active"] = active
if id is not None:
clauses.append("id = :id")
values["id"] = id
if derivation_path is not None:
clauses.append("derivation_path = :derivation_path")
values["derivation_path"] = derivation_path
if seed is not None:
clauses.append("seed = :seed")
values["seed"] = seed
if unit is not None:
clauses.append("unit = :unit")
values["unit"] = unit
where = ""
if clauses:
where = f"WHERE {' AND '.join(clauses)}"
rows = await (conn or db).fetchall( # type: ignore
f"""
SELECT * from {db.table_with_schema('keysets')}
{where}
""",
values,
)
return [MintKeyset(**row) for row in rows]
async def get_proofs_used(
self,
*,
Ys: List[str],
db: Database,
conn: Optional[Connection] = None,
) -> List[Proof]:
query = f"""
SELECT * from {db.table_with_schema('proofs_used')}
WHERE y IN ({','.join([':y_' + str(i) for i in range(len(Ys))])})
"""
values = {f"y_{i}": Ys[i] for i in range(len(Ys))}
rows = await (conn or db).fetchall(query, values)
return [Proof(**r) for r in rows] if rows else []

View File

@@ -0,0 +1,100 @@
from ...core.db import Connection, Database
async def m000_create_migrations_table(conn: Connection):
await conn.execute(
f"""
CREATE TABLE IF NOT EXISTS {conn.table_with_schema('dbversions')} (
db TEXT PRIMARY KEY,
version INT NOT NULL
)
"""
)
async def m001_initial(db: Database):
async with db.connect() as conn:
await conn.execute(
f"""
CREATE TABLE IF NOT EXISTS {db.table_with_schema('users')} (
id TEXT PRIMARY KEY,
last_access TIMESTAMP,
UNIQUE (id)
);
"""
)
# columns: (id, seed, encrypted_seed, seed_encryption_method, derivation_path, valid_from, valid_to, first_seen, active, version, unit, input_fee_ppk)
await conn.execute(
f"""
CREATE TABLE IF NOT EXISTS {db.table_with_schema('keysets')} (
id TEXT NOT NULL,
seed TEXT NOT NULL,
encrypted_seed TEXT,
seed_encryption_method TEXT,
derivation_path TEXT NOT NULL,
valid_from TIMESTAMP NOT NULL DEFAULT {db.timestamp_now},
valid_to TIMESTAMP NOT NULL DEFAULT {db.timestamp_now},
first_seen TIMESTAMP NOT NULL DEFAULT {db.timestamp_now},
active BOOL DEFAULT TRUE,
version TEXT,
unit TEXT NOT NULL,
input_fee_ppk INT,
amounts TEXT,
UNIQUE (derivation_path)
);
"""
)
await conn.execute(
f"""
CREATE TABLE IF NOT EXISTS {db.table_with_schema('promises')} (
id TEXT NOT NULL,
amount {db.big_int} NOT NULL,
b_ TEXT NOT NULL,
c_ TEXT NOT NULL,
dleq_e TEXT,
dleq_s TEXT,
created TIMESTAMP,
UNIQUE (b_)
);
"""
)
await conn.execute(
f"""
CREATE TABLE IF NOT EXISTS {db.table_with_schema('proofs_used')} (
id TEXT NOT NULL,
amount {db.big_int} NOT NULL,
c TEXT NOT NULL,
secret TEXT NOT NULL,
y TEXT NOT NULL,
witness TEXT,
created TIMESTAMP,
melt_quote TEXT,
UNIQUE (secret)
);
"""
)
await conn.execute(
f"""
CREATE TABLE IF NOT EXISTS {db.table_with_schema('proofs_pending')} (
id TEXT NOT NULL,
amount {db.big_int} NOT NULL,
c TEXT NOT NULL,
secret TEXT NOT NULL,
y TEXT NOT NULL,
witness TEXT,
created TIMESTAMP,
melt_quote TEXT,
UNIQUE (secret)
);
"""
)

106
cashu/mint/auth/router.py Normal file
View File

@@ -0,0 +1,106 @@
from fastapi import APIRouter, Request
from loguru import logger
from ...core.errors import KeysetNotFoundError
from ...core.models import (
KeysetsResponse,
KeysetsResponseKeyset,
KeysResponse,
KeysResponseKeyset,
PostAuthBlindMintRequest,
PostAuthBlindMintResponse,
)
from ...mint.startup import auth_ledger
auth_router: APIRouter = APIRouter()
@auth_router.get(
"/v1/auth/blind/keys",
name="Mint public keys",
summary="Get the public keys of the newest mint keyset",
response_description=(
"All supported token values their associated public keys for all active keysets"
),
response_model=KeysResponse,
)
async def keys():
"""This endpoint returns a dictionary of all supported token values of the mint and their associated public key."""
logger.trace("> GET /v1/auth/blind/keys")
keyset = auth_ledger.keyset
keyset_for_response = []
for keyset in auth_ledger.keysets.values():
if keyset.active:
keyset_for_response.append(
KeysResponseKeyset(
id=keyset.id,
unit=keyset.unit.name,
keys={k: v for k, v in keyset.public_keys_hex.items()},
)
)
return KeysResponse(keysets=keyset_for_response)
@auth_router.get(
"/v1/auth/blind/keys/{keyset_id}",
name="Keyset public keys",
summary="Public keys of a specific keyset",
response_description=(
"All supported token values of the mint and their associated"
" public key for a specific keyset."
),
response_model=KeysResponse,
)
async def keyset_keys(keyset_id: str) -> KeysResponse:
"""
Get the public keys of the mint from a specific keyset id.
"""
logger.trace(f"> GET /v1/auth/blind/keys/{keyset_id}")
keyset = auth_ledger.keysets.get(keyset_id)
if keyset is None:
raise KeysetNotFoundError(keyset_id)
keyset_for_response = KeysResponseKeyset(
id=keyset.id,
unit=keyset.unit.name,
keys={k: v for k, v in keyset.public_keys_hex.items()},
)
return KeysResponse(keysets=[keyset_for_response])
@auth_router.get(
"/v1/auth/blind/keysets",
name="Active keysets",
summary="Get all active keyset id of the mind",
response_model=KeysetsResponse,
response_description="A list of all active keyset ids of the mint.",
)
async def keysets() -> KeysetsResponse:
"""This endpoint returns a list of keysets that the mint currently supports and will accept tokens from."""
logger.trace("> GET /v1/auth/blind/keysets")
keysets = []
for id, keyset in auth_ledger.keysets.items():
keysets.append(
KeysetsResponseKeyset(
id=keyset.id,
unit=keyset.unit.name,
active=keyset.active,
)
)
return KeysetsResponse(keysets=keysets)
@auth_router.post(
"/v1/auth/blind/mint",
name="Mint blind auth tokens",
summary="Mint blind auth tokens for a user.",
response_model=PostAuthBlindMintResponse,
)
async def auth_blind_mint(
request_data: PostAuthBlindMintRequest, request: Request
) -> PostAuthBlindMintResponse:
signatures = await auth_ledger.mint_blind_auth(
outputs=request_data.outputs, user=request.state.user
)
return PostAuthBlindMintResponse(signatures=signatures)

235
cashu/mint/auth/server.py Normal file
View File

@@ -0,0 +1,235 @@
from contextlib import asynccontextmanager
from typing import Any, List, Optional
import httpx
import jwt
from loguru import logger
from ...core.base import AuthProof
from ...core.db import Database
from ...core.errors import (
BlindAuthAmountExceededError,
BlindAuthFailedError,
BlindAuthRateLimitExceededError,
ClearAuthFailedError,
)
from ...core.models import BlindedMessage, BlindedSignature
from ...core.settings import settings
from ..crud import LedgerCrudSqlite
from ..ledger import Ledger
from ..limit import assert_limit
from .base import User
from .crud import AuthLedgerCrud, AuthLedgerCrudSqlite
class AuthLedger(Ledger):
auth_crud: AuthLedgerCrud
jwks_url: str
jwks_client: jwt.PyJWKClient
issuer: str
oicd_discovery_json: dict
def __init__(
self,
db: Database,
seed: str,
seed_decryption_key: Optional[str] = None,
derivation_path="",
amounts: Optional[List[int]] = None,
crud=LedgerCrudSqlite(),
):
super().__init__(
db=db,
seed=seed,
backends=None,
seed_decryption_key=seed_decryption_key,
derivation_path=derivation_path,
crud=crud,
amounts=amounts,
)
self.oicd_discovery_url = settings.mint_auth_oicd_discovery_url or ""
async def init_auth(self):
if not self.oicd_discovery_url:
raise Exception("Missing OpenID Connect discovery URL.")
logger.info(f"Initializing OpenID Connect: {self.oicd_discovery_url}")
self.oicd_discovery_json = self._get_oicd_discovery_json()
self.jwks_url = self.oicd_discovery_json["jwks_uri"]
self.jwks_client = jwt.PyJWKClient(self.jwks_url)
logger.info(f"Getting JWKS from: {self.jwks_url}")
self.auth_crud = AuthLedgerCrudSqlite()
self.issuer: str = self.oicd_discovery_json["issuer"]
logger.info(f"Initialized OpenID Connect: {self.issuer}")
def _get_oicd_discovery_json(self) -> dict:
resp = httpx.get(self.oicd_discovery_url)
resp.raise_for_status()
return resp.json()
def _verify_oicd_issuer(self, clear_auth_token: str) -> None:
"""Verify the issuer of the clear-auth token.
Args:
clear_auth_token (str): JWT token.
Raises:
Exception: Invalid issuer.
"""
try:
decoded = jwt.decode(
clear_auth_token,
options={"verify_signature": False},
)
issuer = decoded["iss"]
if issuer != self.issuer:
raise Exception(f"Invalid issuer: {issuer}. Expected: {self.issuer}")
except Exception as e:
raise e
def _verify_decode_jwt(self, clear_auth_token: str) -> Any:
"""Verify the clear-auth JWT token.
Args:
clear_auth_token (str): JWT token.
Raises:
jwt.ExpiredSignatureError: Token has expired.
jwt.InvalidSignatureError: Invalid signature.
jwt.InvalidTokenError: Invalid token.
Returns:
Any: Decoded JWT.
"""
try:
# Use PyJWKClient to fetch the appropriate key based on the token's header
signing_key = self.jwks_client.get_signing_key_from_jwt(clear_auth_token)
decoded = jwt.decode(
clear_auth_token,
signing_key.key,
algorithms=["RS256", "ES256"],
options={"verify_aud": False},
issuer=self.issuer,
)
logger.trace(f"Decoded JWT: {decoded}")
except jwt.ExpiredSignatureError as e:
logger.error("Token has expired")
raise e
except jwt.InvalidSignatureError as e:
logger.error("Invalid signature")
raise e
except jwt.InvalidTokenError as e:
logger.error("Invalid token")
raise e
except Exception as e:
raise e
return decoded
async def _get_user(self, decoded_token: Any) -> User:
"""Get the user from the decoded token. If the user does not exist, create a new one.
Args:
decoded_token (Any): decoded JWT from PyJWT.decode
Returns:
User: User object
"""
user_id = decoded_token["sub"]
user = await self.auth_crud.get_user(user_id=user_id, db=self.db)
if not user:
logger.info(f"Creating new user: {user_id}")
user = User(id=user_id)
await self.auth_crud.create_user(user=user, db=self.db)
return user
async def verify_clear_auth(self, clear_auth_token: str) -> User:
"""Verify the clear-auth JWT token and return the user.
Checks:
- Token not expired.
- Token signature valid.
- User exists.
Args:
auth_token (str): JWT token.
Returns:
User: Authenticated user.
"""
try:
self._verify_oicd_issuer(clear_auth_token)
decoded = self._verify_decode_jwt(clear_auth_token)
user = await self._get_user(decoded)
except Exception:
raise ClearAuthFailedError()
logger.info(f"User authenticated: {user.id}")
try:
assert_limit(user.id)
except Exception:
raise BlindAuthRateLimitExceededError()
return user
async def mint_blind_auth(
self,
*,
outputs: List[BlindedMessage],
user: User,
) -> List[BlindedSignature]:
"""Mints auth tokens. Returns a list of promises.
Args:
outputs (List[BlindedMessage]): Outputs to sign.
user (User): Authenticated user.
Raises:
Exception: Invalid auth.
Exception: Output verification failed.
Exception: Output quota exceeded.
Returns:
List[BlindedSignature]: List of blinded signatures.
"""
if len(outputs) > settings.mint_auth_max_blind_tokens:
raise BlindAuthAmountExceededError(
f"Too many outputs. You can only mint {settings.mint_auth_max_blind_tokens} tokens."
)
await self._verify_outputs(outputs)
promises = await self._generate_promises(outputs)
# update last_access timestamp of the user
await self.auth_crud.update_user(user_id=user.id, db=self.db)
return promises
@asynccontextmanager
async def verify_blind_auth(self, blind_auth_token):
"""Wrapper context that puts blind auth tokens into pending list and
melts them if the wrapped call succeeds. If it fails, the blind auth
token is not invalidated.
Args:
blind_auth_token (str): Blind auth token.
Raises:
Exception: Blind auth token validation failed.
"""
try:
proof = AuthProof.from_base64(blind_auth_token).to_proof()
await self.verify_inputs_and_outputs(proofs=[proof])
await self.db_write._verify_spent_proofs_and_set_pending([proof])
except Exception as e:
logger.error(f"Blind auth error: {e}")
raise BlindAuthFailedError()
try:
yield
await self._invalidate_proofs(proofs=[proof])
except Exception as e:
logger.error(f"Blind auth error: {e}")
raise BlindAuthFailedError()
finally:
await self.db_write._unset_proofs_pending([proof])

View File

@@ -642,8 +642,8 @@ class LedgerCrudSqlite(LedgerCrud):
await (conn or db).execute(
f"""
INSERT INTO {db.table_with_schema('keysets')}
(id, seed, encrypted_seed, seed_encryption_method, derivation_path, valid_from, valid_to, first_seen, active, version, unit, input_fee_ppk)
VALUES (:id, :seed, :encrypted_seed, :seed_encryption_method, :derivation_path, :valid_from, :valid_to, :first_seen, :active, :version, :unit, :input_fee_ppk)
(id, seed, encrypted_seed, seed_encryption_method, derivation_path, valid_from, valid_to, first_seen, active, version, unit, input_fee_ppk, amounts)
VALUES (:id, :seed, :encrypted_seed, :seed_encryption_method, :derivation_path, :valid_from, :valid_to, :first_seen, :active, :version, :unit, :input_fee_ppk, :amounts)
""",
{
"id": keyset.id,
@@ -662,6 +662,7 @@ class LedgerCrudSqlite(LedgerCrud):
"version": keyset.version,
"unit": keyset.unit.name,
"input_fee_ppk": keyset.input_fee_ppk,
"amounts": json.dumps(keyset.amounts),
},
)
@@ -720,7 +721,7 @@ class LedgerCrudSqlite(LedgerCrud):
""",
values,
)
return [MintKeyset(**row) for row in rows]
return [MintKeyset.from_row(row) for row in rows] # type: ignore
async def update_keyset(
self,

View File

@@ -1,12 +1,17 @@
from typing import Any, Dict, List, Union
from ..core.base import Method
from ..core.mint_info import MintInfo
from ..core.models import (
MeltMethodSetting,
MintInfoContact,
MintInfoProtectedEndpoint,
MintMethodSetting,
)
from ..core.nuts.nuts import (
BLIND_AUTH_NUT,
CACHE_NUT,
CLEAR_AUTH_NUT,
DLEQ_NUT,
FEE_RETURN_NUT,
HTLC_NUT,
@@ -21,10 +26,46 @@ from ..core.nuts.nuts import (
WEBSOCKETS_NUT,
)
from ..core.settings import settings
from ..mint.protocols import SupportsBackends
from ..mint.protocols import SupportsBackends, SupportsPubkey
_VERSION_PREFIX = "Nutshell"
_SUPPORTED = "supported"
_METHOD = "method"
_UNIT = "unit"
_BOLT11 = "bolt11"
_MPP = "mpp"
_COMMANDS = "commands"
_BOLT11_MINT_QUOTE = "bolt11_mint_quote"
_BOLT11_MELT_QUOTE = "bolt11_melt_quote"
_PROOF_STATE = "proof_state"
_PROTECTED_ENDPOINTS = "protected_endpoints"
_BAT_MAX_MINT = "bat_max_mint"
_OPENID_DISCOVERY = "openid_discovery"
_CLIENT_ID = "client_id"
class LedgerFeatures(SupportsBackends):
class LedgerFeatures(SupportsBackends, SupportsPubkey):
@property
def mint_info(self) -> MintInfo:
contact_info = [
MintInfoContact(method=m, info=i)
for m, i in settings.mint_info_contact
if m and i
]
return MintInfo(
name=settings.mint_info_name,
pubkey=self.pubkey.serialize().hex() if self.pubkey else None,
version=f"{_VERSION_PREFIX}/{settings.version}",
description=settings.mint_info_description,
description_long=settings.mint_info_description_long,
contact=contact_info,
nuts=self.mint_features,
icon_url=settings.mint_info_icon_url,
motd=settings.mint_info_motd,
time=None,
)
@property
def mint_features(self) -> Dict[int, Union[List[Any], Dict[str, Any]]]:
mint_features = self.create_mint_features()
mint_features = self.add_supported_features(mint_features)
@@ -100,30 +141,62 @@ class LedgerFeatures(SupportsBackends):
# specify which websocket features are supported
# these two are supported by default
websocket_features: Dict[str, List[Dict[str, Union[str, List[str]]]]] = {
"supported": []
_SUPPORTED: []
}
# we check the backend to see if "bolt11_mint_quote" is supported as well
for method, unit_dict in self.backends.items():
if method == Method["bolt11"]:
if method == Method[_BOLT11]:
for unit in unit_dict.keys():
websocket_features["supported"].append(
websocket_features[_SUPPORTED].append(
{
"method": method.name,
"unit": unit.name,
"commands": ["bolt11_melt_quote", "proof_state"],
_METHOD: method.name,
_UNIT: unit.name,
_COMMANDS: [_BOLT11_MELT_QUOTE, _PROOF_STATE],
}
)
if unit_dict[unit].supports_incoming_payment_stream:
supported_features: List[str] = list(
websocket_features["supported"][-1]["commands"]
websocket_features[_SUPPORTED][-1][_COMMANDS]
)
websocket_features["supported"][-1]["commands"] = (
supported_features + ["bolt11_mint_quote"]
websocket_features[_SUPPORTED][-1][_COMMANDS] = (
supported_features + [_BOLT11_MINT_QUOTE]
)
if websocket_features:
mint_features[WEBSOCKETS_NUT] = websocket_features
# signal authentication features
if settings.mint_require_auth:
if not settings.mint_auth_oicd_discovery_url:
raise Exception(
"Missing OpenID Connect discovery URL: MINT_AUTH_OICD_DISCOVERY_URL"
)
clear_auth_features: Dict[str, Union[bool, str, List[str]]] = {
_OPENID_DISCOVERY: settings.mint_auth_oicd_discovery_url,
_CLIENT_ID: settings.mint_auth_oicd_client_id,
_PROTECTED_ENDPOINTS: [],
}
for endpoint in [
MintInfoProtectedEndpoint(method=e[0], path=e[1])
for e in settings.mint_require_clear_auth_paths
]:
clear_auth_features[_PROTECTED_ENDPOINTS].append(endpoint.dict()) # type: ignore
mint_features[CLEAR_AUTH_NUT] = clear_auth_features
blind_auth_features: Dict[str, Union[bool, int, str, List[str]]] = {
_BAT_MAX_MINT: settings.mint_auth_max_blind_tokens,
_PROTECTED_ENDPOINTS: [],
}
for endpoint in [
MintInfoProtectedEndpoint(method=e[0], path=e[1])
for e in settings.mint_require_blind_auth_paths
]:
blind_auth_features[_PROTECTED_ENDPOINTS].append(endpoint.dict()) # type: ignore
mint_features[BLIND_AUTH_NUT] = blind_auth_features
return mint_features
def add_cache_features(

View File

@@ -75,16 +75,26 @@ class Ledger(LedgerVerification, LedgerSpendingConditions, LedgerTasks, LedgerFe
db_read: DbReadHelper
invoice_listener_tasks: List[asyncio.Task] = []
disable_melt: bool = False
pubkey: PublicKey
def __init__(
self,
*,
db: Database,
seed: str,
backends: Mapping[Method, Mapping[Unit, LightningBackend]],
seed_decryption_key: Optional[str] = None,
derivation_path="",
amounts: Optional[List[int]] = None,
backends: Optional[Mapping[Method, Mapping[Unit, LightningBackend]]] = None,
seed_decryption_key: Optional[str] = None,
crud=LedgerCrudSqlite(),
):
) -> None:
self.keysets: Dict[str, MintKeyset] = {}
self.backends: Mapping[Method, Mapping[Unit, LightningBackend]] = {}
self.events = LedgerEventManager()
self.db_read: DbReadHelper
self.locks: Dict[str, asyncio.Lock] = {} # holds multiprocessing locks
self.invoice_listener_tasks: List[asyncio.Task] = []
if not seed:
raise Exception("seed not set")
@@ -103,24 +113,33 @@ class Ledger(LedgerVerification, LedgerSpendingConditions, LedgerTasks, LedgerFe
self.db = db
self.crud = crud
self.backends = backends
if backends:
self.backends = backends
if amounts:
self.amounts = amounts
else:
self.amounts = [2**n for n in range(settings.max_order)]
self.pubkey = derive_pubkey(self.seed)
self.db_read = DbReadHelper(self.db, self.crud)
self.db_write = DbWriteHelper(self.db, self.crud, self.events, self.db_read)
# ------- STARTUP -------
async def startup_ledger(self):
await self._startup_ledger()
async def startup_ledger(self) -> None:
await self._startup_keysets()
await self._check_backends()
await self._check_pending_proofs_and_melt_quotes()
self.invoice_listener_tasks = await self.dispatch_listeners()
async def _startup_ledger(self):
async def _startup_keysets(self) -> None:
await self.init_keysets()
for derivation_path in settings.mint_derivation_path_list:
await self.activate_keyset(derivation_path=derivation_path)
async def _check_backends(self) -> None:
for method in self.backends:
for unit in self.backends[method]:
logger.info(
@@ -139,7 +158,7 @@ class Ledger(LedgerVerification, LedgerSpendingConditions, LedgerTasks, LedgerFe
logger.info(f"Data dir: {settings.cashu_dir}")
async def shutdown_ledger(self):
async def shutdown_ledger(self) -> None:
await self.db.engine.dispose()
for task in self.invoice_listener_tasks:
task.cancel()
@@ -169,57 +188,65 @@ class Ledger(LedgerVerification, LedgerSpendingConditions, LedgerTasks, LedgerFe
version: Optional[str] = None,
autosave=True,
) -> MintKeyset:
"""Load the keyset for a derivation path if it already exists. If not generate new one and store in the db.
"""
Load an existing keyset for the specified derivation path or generate a new one if it doesn't exist.
Optionally store the newly created keyset in the database.
Args:
derivation_path (_type_): Derivation path from which the keyset is generated.
autosave (bool, optional): Store newly-generated keyset if not already in database. Defaults to True.
derivation_path (str): Derivation path for keyset generation.
seed (Optional[str], optional): Seed value. Defaults to None.
version (Optional[str], optional): Version identifier. Defaults to None.
autosave (bool, optional): Whether to store the keyset if newly created. Defaults to True.
Returns:
MintKeyset: Keyset
MintKeyset: The activated keyset.
"""
if not derivation_path:
raise Exception("derivation path not set")
raise ValueError("Derivation path must be provided.")
seed = seed or self.seed
tmp_keyset_local = MintKeyset(
version = version or settings.version
# Initialize a temporary keyset to derive the ID
temp_keyset = MintKeyset(
seed=seed,
derivation_path=derivation_path,
version=version or settings.version,
version=version,
amounts=self.amounts,
)
logger.debug(
f"Activating keyset for derivation path {derivation_path} with id"
f" {tmp_keyset_local.id}."
f"Activating keyset for derivation path '{derivation_path}' with ID '{temp_keyset.id}'."
)
# load the keyset from db
logger.trace(f"crud: loading keyset for {derivation_path}")
tmp_keysets_local: List[MintKeyset] = await self.crud.get_keyset(
id=tmp_keyset_local.id, db=self.db
# Attempt to retrieve existing keysets from the database
existing_keysets: List[MintKeyset] = await self.crud.get_keyset(
id=temp_keyset.id, db=self.db
)
logger.trace(f"crud: loaded {len(tmp_keysets_local)} keysets")
if tmp_keysets_local:
# we have a keyset with this derivation path in the database
keyset = tmp_keysets_local[0]
logger.trace(
f"Retrieved {len(existing_keysets)} keyset(s) for derivation path '{derivation_path}'."
)
if existing_keysets:
keyset = existing_keysets[0]
else:
# no keyset for this derivation path yet
# we create a new keyset (keys will be generated at instantiation)
# Create a new keyset if none exists
keyset = MintKeyset(
seed=seed or self.seed,
seed=seed,
derivation_path=derivation_path,
version=version or settings.version,
amounts=self.amounts,
version=version,
input_fee_ppk=settings.mint_input_fee_ppk,
)
logger.debug(f"Generated new keyset {keyset.id}.")
logger.debug(f"Generated new keyset with ID '{keyset.id}'.")
if autosave:
logger.debug(f"crud: storing new keyset {keyset.id}.")
logger.debug(f"Storing new keyset with ID '{keyset.id}'.")
await self.crud.store_keyset(keyset=keyset, db=self.db)
logger.trace(f"crud: stored new keyset {keyset.id}.")
# activate this keyset
# Activate the keyset
keyset.active = True
# load the new keyset in self.keysets
self.keysets[keyset.id] = keyset
logger.debug(f"Keyset with ID '{keyset.id}' is now active.")
logger.debug(f"Loaded keyset {keyset.id}")
return keyset
async def init_keysets(self, autosave: bool = True) -> None:

View File

@@ -1,3 +1,5 @@
from typing import Optional
from fastapi import WebSocket, status
from fastapi.responses import JSONResponse
from limits import RateLimitItemPerMinute
@@ -42,26 +44,26 @@ limiter = Limiter(
)
def assert_limit(identifier: str):
def assert_limit(identifier: str, limit: Optional[int] = None):
"""Custom rate limit handler that accepts a string identifier
and raises an exception if the rate limit is exceeded. Uses the
setting `mint_transaction_rate_limit_per_minute` for the rate limit.
Args:
identifier (str): The identifier to use for the rate limit. IP address for example.
limit (Optional[int], optional): The rate limit per minute to use. Defaults to None
Raises:
Exception: If the rate limit is exceeded.
"""
global limiter
limit_per_minute = limit or settings.mint_transaction_rate_limit_per_minute
success = limiter._limiter.hit(
RateLimitItemPerMinute(settings.mint_transaction_rate_limit_per_minute),
RateLimitItemPerMinute(limit_per_minute),
identifier,
)
if not success:
logger.warning(
f"Rate limit {settings.mint_transaction_rate_limit_per_minute}/minute exceeded: {identifier}"
)
logger.warning(f"Rate limit {limit_per_minute}/minute exceeded: {identifier}")
raise Exception("Rate limit exceeded")

View File

@@ -10,7 +10,10 @@ from fastapi.exception_handlers import (
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from loguru import logger
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.base import (
BaseHTTPMiddleware,
RequestResponseEndpoint,
)
from starlette.middleware.cors import CORSMiddleware
from ..core.settings import settings
@@ -22,6 +25,8 @@ if settings.debug_profiling:
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
from .startup import auth_ledger
def add_middlewares(app: FastAPI):
app.add_middleware(
@@ -42,6 +47,52 @@ def add_middlewares(app: FastAPI):
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
app.add_middleware(SlowAPIMiddleware)
if settings.mint_require_auth:
app.add_middleware(BlindAuthMiddleware)
app.add_middleware(ClearAuthMiddleware)
class ClearAuthMiddleware(BaseHTTPMiddleware):
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
if (
settings.mint_require_auth
and auth_ledger.mint_info.requires_clear_auth_path(
method=request.method, path=request.url.path
)
):
clear_auth_token = request.headers.get("clear-auth")
if not clear_auth_token:
raise Exception("Missing clear auth token.")
try:
user = await auth_ledger.verify_clear_auth(
clear_auth_token=clear_auth_token
)
request.state.user = user
except Exception as e:
raise e
return await call_next(request)
class BlindAuthMiddleware(BaseHTTPMiddleware):
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
if (
settings.mint_require_auth
and auth_ledger.mint_info.requires_blind_auth_path(
method=request.method, path=request.url.path
)
):
blind_auth_token = request.headers.get("blind-auth")
if not blind_auth_token:
raise Exception("Missing blind auth token.")
async with auth_ledger.verify_blind_auth(blind_auth_token):
return await call_next(request)
else:
return await call_next(request)
async def request_validation_exception_handler(
request: Request, exc: RequestValidationError
@@ -66,11 +117,11 @@ class CompressionMiddleware(BaseHTTPMiddleware):
response = await call_next(request)
# Handle streaming responses differently
if response.__class__.__name__ == 'StreamingResponse':
if response.__class__.__name__ == "StreamingResponse":
return response
response_body = b''
async for chunk in response.body_iterator:
response_body = b""
async for chunk in response.body_iterator: # type: ignore
response_body += chunk
accept_encoding = request.headers.get("Accept-Encoding", "")
@@ -97,5 +148,5 @@ class CompressionMiddleware(BaseHTTPMiddleware):
content=content,
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.media_type
media_type=response.media_type,
)

View File

@@ -858,3 +858,13 @@ async def m024_add_melt_quote_outputs(db: Database):
ADD COLUMN outputs TEXT DEFAULT NULL
"""
)
async def m025_add_amounts_to_keysets(db: Database):
async with db.connect() as conn:
await conn.execute(
f"ALTER TABLE {db.table_with_schema('keysets')} ADD COLUMN amounts TEXT"
)
await conn.execute(
f"UPDATE {db.table_with_schema('keysets')} SET amounts = '[]'"
)

View File

@@ -1,6 +1,7 @@
from typing import Dict, Mapping, Protocol
from ..core.base import Method, MintKeyset, Unit
from ..core.crypto.secp import PublicKey
from ..core.db import Database
from ..lightning.base import LightningBackend
from ..mint.crud import LedgerCrud
@@ -18,6 +19,10 @@ class SupportsBackends(Protocol):
backends: Mapping[Method, Mapping[Unit, LightningBackend]] = {}
class SupportsPubkey(Protocol):
pubkey: PublicKey
class SupportsDb(Protocol):
db: Database
db_read: DbReadHelper

View File

@@ -11,7 +11,6 @@ from ..core.models import (
KeysetsResponseKeyset,
KeysResponse,
KeysResponseKeyset,
MintInfoContact,
PostCheckStateRequest,
PostCheckStateResponse,
PostMeltQuoteRequest,
@@ -44,23 +43,18 @@ redis = RedisCache()
)
async def info() -> GetInfoResponse:
logger.trace("> GET /v1/info")
mint_features = ledger.mint_features()
contact_info = [
MintInfoContact(method=m, info=i)
for m, i in settings.mint_info_contact
if m and i
]
mint_info = ledger.mint_info
return GetInfoResponse(
name=settings.mint_info_name,
pubkey=ledger.pubkey.serialize().hex() if ledger.pubkey else None,
version=f"Nutshell/{settings.version}",
description=settings.mint_info_description,
description_long=settings.mint_info_description_long,
contact=contact_info,
nuts=mint_features,
icon_url=settings.mint_info_icon_url,
name=mint_info.name,
pubkey=mint_info.pubkey,
version=mint_info.version,
description=mint_info.description,
description_long=mint_info.description_long,
contact=mint_info.contact,
nuts=mint_info.nuts,
icon_url=mint_info.icon_url,
urls=settings.mint_info_urls,
motd=settings.mint_info_motd,
motd=mint_info.motd,
time=int(time.time()),
)

View File

@@ -12,7 +12,9 @@ from ..core.db import Database
from ..core.migrations import migrate_databases
from ..core.settings import settings
from ..lightning.base import LightningBackend
from ..mint import migrations
from ..mint import migrations as mint_migrations
from ..mint.auth import migrations as auth_migrations
from ..mint.auth.server import AuthLedger
from ..mint.crud import LedgerCrudSqlite
from ..mint.ledger import Ledger
@@ -76,6 +78,15 @@ ledger = Ledger(
crud=LedgerCrudSqlite(),
)
# start auth ledger
auth_ledger = AuthLedger(
db=Database("auth", settings.auth_database),
seed="auth seed here",
amounts=[1],
derivation_path="m/0'/999'/0'",
crud=LedgerCrudSqlite(),
)
async def rotate_keys(n_seconds=60):
"""Rotate keyset epoch every n_seconds.
@@ -93,8 +104,17 @@ async def rotate_keys(n_seconds=60):
await asyncio.sleep(n_seconds)
async def start_mint_init():
await migrate_databases(ledger.db, migrations)
async def start_auth():
await migrate_databases(auth_ledger.db, auth_migrations)
logger.info("Starting auth ledger.")
await auth_ledger.init_keysets()
await auth_ledger.init_auth()
logger.info("Auth ledger started.")
async def start_mint():
await migrate_databases(ledger.db, mint_migrations)
logger.info("Starting mint ledger.")
await ledger.startup_ledger()
logger.info("Mint started.")
# asyncio.create_task(rotate_keys())

View File

@@ -1,22 +1,14 @@
import asyncio
from typing import List, Mapping
from typing import List
from loguru import logger
from ..core.base import Method, MintQuoteState, Unit
from ..core.db import Database
from ..core.base import MintQuoteState
from ..lightning.base import LightningBackend
from ..mint.crud import LedgerCrud
from .events.events import LedgerEventManager
from .protocols import SupportsBackends, SupportsDb, SupportsEvents
class LedgerTasks(SupportsDb, SupportsBackends, SupportsEvents):
backends: Mapping[Method, Mapping[Unit, LightningBackend]] = {}
db: Database
crud: LedgerCrud
events: LedgerEventManager
async def dispatch_listeners(self) -> List[asyncio.Task]:
tasks = []
for method, unitbackends in self.backends.items():

View File

@@ -1,4 +1,4 @@
from typing import Dict, List, Literal, Optional, Tuple, Union
from typing import List, Literal, Optional, Tuple, Union
from loguru import logger
@@ -6,14 +6,13 @@ from ..core.base import (
BlindedMessage,
BlindedSignature,
Method,
MintKeyset,
MintQuote,
Proof,
Unit,
)
from ..core.crypto import b_dhke
from ..core.crypto.secp import PublicKey
from ..core.db import Connection, Database
from ..core.db import Connection
from ..core.errors import (
InvalidProofsError,
NoSecretInProofsError,
@@ -25,11 +24,7 @@ from ..core.errors import (
)
from ..core.nuts import nut20
from ..core.settings import settings
from ..lightning.base import LightningBackend
from ..mint.crud import LedgerCrud
from .conditions import LedgerSpendingConditions
from .db.read import DbReadHelper
from .db.write import DbWriteHelper
from .protocols import SupportsBackends, SupportsDb, SupportsKeysets
@@ -38,14 +33,6 @@ class LedgerVerification(
):
"""Verification functions for the ledger."""
keyset: MintKeyset
keysets: Dict[str, MintKeyset]
crud: LedgerCrud
db: Database
db_read: DbReadHelper
db_write: DbWriteHelper
lightning: Dict[Unit, LightningBackend]
async def verify_inputs_and_outputs(
self,
*,
@@ -55,6 +42,8 @@ class LedgerVerification(
):
"""Checks all proofs and outputs for validity.
Warning: Does NOT check if the proofs were already spent. Use `db_write._verify_proofs_spendable` for that.
Args:
proofs (List[Proof]): List of proofs to check.
outputs (Optional[List[BlindedMessage]], optional): List of outputs to check.

View File

@@ -22,14 +22,11 @@ class NostrClient:
relays = [
"wss://nostr-pub.wellorder.net",
"wss://relay.damus.io",
"wss://nostr.zebedee.cloud",
"wss://relay.snort.social",
"wss://nostr.fmt.wiz.biz",
"wss://nos.lol",
"wss://nostr.oxtr.dev",
"wss://relay.current.fyi",
"wss://relay.snort.social",
] # ["wss://nostr.oxtr.dev"] # ["wss://relay.nostr.info"] "wss://nostr-pub.wellorder.net" "ws://91.237.88.218:2700", "wss://nostrrr.bublina.eu.org", ""wss://nostr-relay.freeberty.net"", , "wss://nostr.oxtr.dev", "wss://relay.nostr.info", "wss://nostr-pub.wellorder.net" , "wss://relayer.fiatjaf.com", "wss://nodestr.fmt.wiz.biz/", "wss://no.str.cr"
]
relay_manager = RelayManager()
private_key: PrivateKey
public_key: PublicKey

View File

240
cashu/wallet/auth/auth.py Normal file
View File

@@ -0,0 +1,240 @@
import hashlib
import os
from typing import List, Optional
from loguru import logger
from cashu.core.helpers import sum_proofs
from cashu.core.mint_info import MintInfo
from ...core.base import Proof
from ...core.crypto.secp import PrivateKey
from ...core.db import Database
from ..crud import get_mint_by_url, update_mint
from ..wallet import Wallet
from .openid_connect.openid_client import AuthorizationFlow, OpenIDClient
class WalletAuth(Wallet):
oidc_discovery_url: str
oidc_client: OpenIDClient
wallet_db: Database
auth_flow: AuthorizationFlow
username: str | None
password: str | None
# API prefix for all requests
api_prefix = "/v1/auth/blind"
def __init__(
self, url: str, db: str, name: str = "auth", unit: str = "auth", **kwargs
):
"""Authentication wallet.
Args:
url (str): Mint url.
db (str): Auth wallet db location.
wallet_db (str): Wallet db location.
name (str, optional): Wallet name. Defaults to "auth".
unit (str, optional): Wallet unit. Defaults to "auth".
kwargs: Additional keyword arguments.
client_id (str, optional): OpenID client id. Defaults to "cashu-client".
client_secret (str, optional): OpenID client secret. Defaults to "".
username (str, optional): OpenID username. When set, the username and
password flow will be used to authenticate. If a username is already
stored in the database, it will be used. Will be stored in the
database if not already stored.
password (str, optional): OpenID password. Used if username is set. Will
be read from the database if already stored. Will be stored in the
database if not already stored.
"""
super().__init__(url, db, name, unit)
self.client_id = kwargs.get("client_id", "cashu-client")
logger.trace(f"client_id: {self.client_id}")
self.client_secret = kwargs.get("client_secret", "")
self.username = kwargs.get("username")
self.password = kwargs.get("password")
if self.username:
if self.password is None:
raise Exception("Password must be set if username is set.")
self.auth_flow = AuthorizationFlow.PASSWORD
else:
self.auth_flow = AuthorizationFlow.AUTHORIZATION_CODE
# self.auth_flow = AuthorizationFlow.DEVICE_CODE
self.access_token = kwargs.get("access_token")
self.refresh_token = kwargs.get("refresh_token")
# overload with_db
@classmethod
async def with_db(cls, *args, **kwargs) -> "WalletAuth":
"""Create a new wallet with a database.
Keyword arguments:
url (str): Mint url.
db (str): Wallet db location.
name (str, optional): Wallet name. Defaults to "auth".
username (str, optional): OpenID username. When set, the username and
password flow will be used to authenticate. If a username is already
stored in the database, it will be used. Will be stored in the
database if not already stored.
password (str, optional): OpenID password. Used if username is set. Will
be read from the database if already stored. Will be stored in the
database if not already stored.
client_id (str, optional): OpenID client id. Defaults to "cashu-client".
client_secret (str, optional): OpenID client secret. Defaults to "".
access_token (str, optional): OpenID access token. Defaults to None.
refresh_token (str, optional): OpenID refresh token. Defaults to None.
Returns:
WalletAuth: WalletAuth instance.
"""
url: str = kwargs.get("url", "")
db = kwargs.get("db", "")
kwargs["name"] = kwargs.get("name", "auth")
name = kwargs["name"]
username = kwargs.get("username")
password = kwargs.get("password")
wallet_db = Database(name, db)
# run migrations etc
kwargs.update(dict(skip_db_read=True))
await super().with_db(*args, **kwargs)
# the wallet might not have been created yet
# if it was though, we load the username, password,
# access token and refresh token from the database
try:
mint_db = await get_mint_by_url(wallet_db, url)
if mint_db:
kwargs.update(
{
"username": username or mint_db.username,
"password": password or mint_db.password,
"access_token": mint_db.access_token,
"refresh_token": mint_db.refresh_token,
}
)
except Exception:
pass
return cls(*args, **kwargs)
async def init_auth_wallet(
self,
mint_info: Optional[MintInfo] = None,
mint_auth_proofs=True,
force_auth=False,
) -> bool:
"""Initialize authentication wallet.
Args:
mint_info (MintInfo, optional): Mint information. If not provided, we load the
info from the database or the mint directly. Defaults to None.
mint_auth_proofs (bool, optional): Whether to mint auth proofs if necessary.
Defaults to True.
force_auth (bool, optional): Whether to force authentication. Defaults to False.
Returns:
bool: False if the mint does not require clear auth. True otherwise.
"""
if mint_info:
self.mint_info = mint_info
await self.load_mint_info()
if not self.mint_info.requires_clear_auth():
return False
# Use the blind auth api_prefix for all following requests
await self.load_mint_keysets()
await self.activate_keyset()
await self.load_proofs()
self.oidc_discovery_url = self.mint_info.oidc_discovery_url()
self.client_id = self.mint_info.oidc_client_id()
# Initialize OpenIDClient
self.oidc_client = OpenIDClient(
discovery_url=self.oidc_discovery_url,
client_id=self.client_id,
client_secret=self.client_secret,
auth_flow=self.auth_flow,
username=self.username,
password=self.password,
access_token=self.access_token,
refresh_token=self.refresh_token,
)
# Authenticate using OpenIDClient
await self.oidc_client.initialize()
await self.oidc_client.authenticate(force_authenticate=force_auth)
await self.store_username_password()
await self.store_clear_auth_token()
if mint_auth_proofs:
await self.mint_blind_auth_min_balance()
return True
async def mint_blind_auth_min_balance(self) -> None:
"""Mint auth tokens if balance is too low."""
MIN_BALANCE = self.mint_info.bat_max_mint
if self.available_balance < MIN_BALANCE:
logger.debug(
f"Balance too low. Minting {self.unit.str(MIN_BALANCE)} auth tokens."
)
try:
await self.mint_blind_auth()
except Exception as e:
logger.error(f"Error minting auth proofs: {str(e)}")
async def store_username_password(self) -> None:
"""Store the username and password in the database."""
if self.username and self.password:
mint_db = await get_mint_by_url(self.db, self.url)
if not mint_db:
raise Exception("Mint not found.")
if mint_db.username != self.username or mint_db.password != self.password:
mint_db.username = self.username
mint_db.password = self.password
await update_mint(self.db, mint_db)
async def store_clear_auth_token(self) -> None:
"""Store the access and refresh tokens in the database."""
access_token = self.oidc_client.access_token
refresh_token = self.oidc_client.refresh_token
if not access_token or not refresh_token:
raise Exception("Access or refresh token not available.")
# Store the tokens in the database
mint_db = await get_mint_by_url(self.db, self.url)
if not mint_db:
raise Exception("Mint not found.")
if (
mint_db.access_token != access_token
or mint_db.refresh_token != refresh_token
):
mint_db.access_token = access_token
mint_db.refresh_token = refresh_token
await update_mint(self.db, mint_db)
async def mint_blind_auth(self) -> List[Proof]:
# Ensure access token is valid
if self.oidc_client.is_token_expired():
await self.oidc_client.refresh_access_token()
await self.store_clear_auth_token()
clear_auth_token = self.oidc_client.access_token
if not clear_auth_token:
raise Exception("No clear auth token available.")
amounts = self.mint_info.bat_max_mint * [1] # 1 AUTH tokens
secrets = [hashlib.sha256(os.urandom(32)).hexdigest() for _ in amounts]
rs = [PrivateKey(privkey=os.urandom(32), raw=True) for _ in amounts]
derivation_paths = ["" for _ in amounts]
outputs, rs = self._construct_outputs(amounts, secrets, rs)
promises = await self.blind_mint_blind_auth(clear_auth_token, outputs)
new_proofs = await self._construct_proofs(
promises, secrets, rs, derivation_paths
)
logger.debug(
f"Minted {self.unit.str(sum_proofs(new_proofs))} blind auth proofs."
)
return new_proofs

View File

@@ -0,0 +1,487 @@
import argparse
import asyncio
import base64
import secrets
import webbrowser
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Dict, Optional
from urllib.parse import urlencode
import httpx
import jwt
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from loguru import logger
class AuthorizationFlow(Enum):
AUTHORIZATION_CODE = "authorization_code"
PASSWORD = "password"
DEVICE_CODE = "urn:ietf:params:oauth:grant-type:device_code"
class OpenIDClient:
"""OpenID Connect client for authentication."""
oidc_config: Dict[str, Any]
def __init__(
self,
discovery_url: str,
client_id: str,
client_secret: str = "",
auth_flow: Optional[AuthorizationFlow] = None,
username: Optional[str] = None,
password: Optional[str] = None,
access_token: Optional[str] = None,
refresh_token: Optional[str] = None,
token_expiration_time: Optional[datetime] = None,
device_code: Optional[str] = None,
) -> None:
self.discovery_url: str = discovery_url
self.client_id: str = client_id
self.client_secret: str = client_secret
self.auth_flow: Optional[AuthorizationFlow] = auth_flow
self.username: Optional[str] = username
self.password: Optional[str] = password
self.access_token: Optional[str] = access_token
self.refresh_token: Optional[str] = refresh_token
self.token_expiration_time: Optional[datetime] = token_expiration_time
self.device_code: Optional[str] = device_code
self.redirect_uri: str = "http://localhost:33388/callback"
self.expected_state: str = secrets.token_urlsafe(16)
self.token_response: Dict[str, Any] = {}
self.token_event: asyncio.Event = asyncio.Event()
self.token_endpoint: str = ""
self.authorization_endpoint: str = ""
self.introspection_endpoint: Optional[str] = None
self.revocation_endpoint: Optional[str] = None
self.device_authorization_endpoint: Optional[str] = None
self.templates: Jinja2Templates = Jinja2Templates(
directory="cashu/wallet/auth/openid_connect/templates"
)
self.app: FastAPI = FastAPI()
self.app.state.client = self # Store self in app state
async def initialize(self) -> None:
"""Initialize the client asynchronously."""
await self.fetch_oidc_configuration()
await self.determine_auth_flow()
async def determine_auth_flow(self) -> AuthorizationFlow:
"""Determine the authentication flow to use from the oidc configuration.
Supported flows are chosen in the following order:
- device_code
- authorization_code
- password
"""
if not hasattr(self, "oidc_config"):
raise ValueError(
"OIDC configuration not loaded. Call fetch_oidc_configuration first."
)
supported_flows = self.oidc_config.get("grant_types_supported", [])
# if self.auth_flow is already set, check if it is supported
if self.auth_flow:
if self.auth_flow.value not in supported_flows:
raise ValueError(
f"Authentication flow {self.auth_flow.value} not supported by the OIDC configuration."
)
return self.auth_flow
if AuthorizationFlow.DEVICE_CODE.value in supported_flows:
self.auth_flow = AuthorizationFlow.DEVICE_CODE
elif AuthorizationFlow.AUTHORIZATION_CODE.value in supported_flows:
self.auth_flow = AuthorizationFlow.AUTHORIZATION_CODE
elif AuthorizationFlow.PASSWORD.value in supported_flows:
self.auth_flow = AuthorizationFlow.PASSWORD
else:
raise ValueError(
"No supported authentication flows found in the OIDC configuration."
)
logger.debug(f"Determined authentication flow: {self.auth_flow.value}")
return self.auth_flow
async def fetch_oidc_configuration(self) -> None:
"""Fetch OIDC configuration from the discovery URL."""
try:
async with httpx.AsyncClient() as client:
response = await client.get(self.discovery_url)
response.raise_for_status()
self.oidc_config = response.json()
self.authorization_endpoint = self.oidc_config.get( # type: ignore
"authorization_endpoint"
)
self.token_endpoint = self.oidc_config.get("token_endpoint") # type: ignore
self.introspection_endpoint = self.oidc_config.get(
"introspection_endpoint"
)
self.revocation_endpoint = self.oidc_config.get("revocation_endpoint")
self.device_authorization_endpoint = self.oidc_config.get(
"device_authorization_endpoint"
)
except httpx.HTTPError as e:
logger.error(f"Failed to get OpenID configuration: {e}")
raise
async def handle_callback(self, request: Request) -> HTMLResponse:
"""Endpoint to handle the redirect from the OpenID provider."""
params = request.query_params
if "error" in params:
error_str = params["error"]
if "error_description" in params:
error_str += f": {params['error_description']}"
return self.templates.TemplateResponse(
"error.html",
{"request": request, "error": error_str},
)
elif "code" in params and "state" in params:
code: str = params["code"]
state: str = params["state"]
if state != self.expected_state:
raise HTTPException(status_code=400, detail="Invalid state parameter")
token_data: Dict[str, Any] = await self.exchange_code_for_token(code)
self.update_token_data(token_data)
response = self.render_success_page(request, token_data)
self.token_event.set() # Signal that the token has been received
return response
else:
return self.templates.TemplateResponse(
"error.html",
{"request": request, "error": "Missing 'code' or 'state' parameter"},
)
async def exchange_code_for_token(self, code: str) -> Dict[str, Any]:
"""Exchange the authorization code for tokens."""
data: Dict[str, str] = {
"grant_type": "authorization_code",
"code": code,
"redirect_uri": self.redirect_uri,
}
headers: Dict[str, str] = {}
if self.client_secret:
# Use HTTP Basic Auth if client_secret is provided
basic_auth: str = f"{self.client_id}:{self.client_secret}"
basic_auth_bytes: bytes = basic_auth.encode("ascii")
basic_auth_b64: str = base64.b64encode(basic_auth_bytes).decode("ascii")
headers["Authorization"] = f"Basic {basic_auth_b64}"
else:
# Include client_id in the POST body for public clients
data["client_id"] = self.client_id
async with httpx.AsyncClient() as client:
try:
response = await client.post(
self.token_endpoint, data=data, headers=headers
)
response.raise_for_status()
return response.json()
except httpx.HTTPError as e:
logger.error(f"HTTP error occurred during token exchange: {e}")
self.token_event.set()
return {}
async def run_server(self) -> None:
"""Run the FastAPI server."""
config = uvicorn.Config(
self.app, host="127.0.0.1", port=33388, log_level="error"
)
self.server = uvicorn.Server(config)
await self.server.serve()
async def shutdown_server(self) -> None:
"""Shut down the uvicorn server."""
self.server.should_exit = True
async def authenticate(self, force_authenticate: bool = False) -> None:
"""Start the authentication process."""
need_authenticate = force_authenticate
if self.access_token and self.refresh_token:
# We have a token and a refresh token, check if token is expired
if self.is_token_expired():
try:
await self.refresh_access_token()
except httpx.HTTPError:
logger.debug("Failed to refresh token.")
need_authenticate = True
else:
logger.debug("Using existing access token.")
else:
need_authenticate = True
if need_authenticate:
if self.auth_flow == AuthorizationFlow.AUTHORIZATION_CODE:
await self.authenticate_with_authorization_code()
elif self.auth_flow == AuthorizationFlow.PASSWORD:
await self.authenticate_with_password()
elif self.auth_flow == AuthorizationFlow.DEVICE_CODE:
await self.authenticate_with_device_code()
else:
raise ValueError(f"Unknown authentication flow: {self.auth_flow}")
def is_token_expired(self) -> bool:
"""Check if the access token is expired."""
if not self.access_token:
raise ValueError("Access token is not set.")
decoded = jwt.decode(self.access_token, options={"verify_signature": False})
exp = decoded.get("exp")
if not exp:
return False
return datetime.now() >= datetime.fromtimestamp(exp) - timedelta(minutes=1)
async def refresh_access_token(self) -> None:
"""Refresh the access token using the refresh token."""
data = {
"grant_type": "refresh_token",
"refresh_token": self.refresh_token,
"client_id": self.client_id,
}
if self.client_secret:
data["client_secret"] = self.client_secret
async with httpx.AsyncClient() as client:
try:
response = await client.post(self.token_endpoint, data=data)
response.raise_for_status()
token_data = response.json()
self.update_token_data(token_data)
logger.info("Token refreshed successfully.")
except httpx.HTTPError as e:
logger.debug(f"Failed to refresh token: {e}")
raise
async def authenticate_with_authorization_code(self) -> None:
"""Authenticate using the authorization code flow."""
# Set up the route handlers
@self.app.get("/callback", response_class=HTMLResponse)
async def handle_callback(request: Request) -> Any:
print("CALLBACK")
return await self.handle_callback(request)
# Build the authorization URL
params = {
"response_type": "code",
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"scope": "openid",
"state": self.expected_state,
}
auth_url = f"{self.authorization_endpoint}?{urlencode(params)}"
# Start the web server as an asyncio task
server_task = asyncio.create_task(self.run_server())
# Open the browser or print the URL for the user
logger.info("Please open the following URL in your browser to authenticate:")
logger.info(auth_url)
webbrowser.open(auth_url)
# Wait for the token response
logger.info("Waiting for authentication...")
await self.token_event.wait()
# Use the retrieved tokens
if self.token_response:
logger.info("Authentication successful!")
# logger.info("Token response:")
# logger.info(self.token_response)
else:
logger.error("Authentication failed.")
# Signal the server to shut down
await self.shutdown_server()
# Wait for the server task to finish
await server_task
def update_token_data(self, token_data: Dict[str, Any]) -> None:
self.token_response.update(token_data)
self.access_token = token_data.get("access_token")
self.refresh_token = token_data.get("refresh_token")
if not self.access_token or not self.refresh_token:
raise ValueError(
"Access token or refresh token not found in token response."
)
expires_in = token_data.get("expires_in")
if expires_in:
self.token_expiration_time = datetime.utcnow() + timedelta(
seconds=int(expires_in)
)
refresh_expires_in = token_data.get("refresh_expires_in")
if refresh_expires_in:
logger.debug(f"Refresh token expires in {refresh_expires_in} seconds.")
self.refresh_token_expiration_time = datetime.utcnow() + timedelta(
seconds=int(refresh_expires_in)
)
async def authenticate_with_password(self) -> None:
"""Authenticate using the resource owner password credentials flow."""
if not self.username or not self.password:
raise ValueError(
'Username and password must be provided. To set a password use: "cashu auth -p"'
)
data = {
"grant_type": "password",
"client_id": self.client_id,
"username": self.username,
"password": self.password,
"scope": "openid",
}
if self.client_secret:
data["client_secret"] = self.client_secret
async with httpx.AsyncClient() as client:
try:
response = await client.post(self.token_endpoint, data=data)
response.raise_for_status()
token_data = response.json()
self.update_token_data(token_data)
logger.info("Authentication successful!")
# logger.info("Token response:")
# logger.info(self.token_response)
except httpx.HTTPError as e:
logger.error(f"Failed to obtain token: {e}")
raise
def render_success_page(
self, request: Request, token_data: Dict[str, Any]
) -> HTMLResponse:
"""Render an HTML page with a green check mark and user information."""
context = {"request": request, "token_data": token_data}
return self.templates.TemplateResponse("success.html", context)
async def authenticate_with_device_code(self) -> None:
"""Authenticate using the device code flow."""
if not self.device_authorization_endpoint:
raise ValueError("Device authorization endpoint not available.")
data = {
"client_id": self.client_id,
"scope": "openid",
}
if self.client_secret:
data["client_secret"] = self.client_secret
async with httpx.AsyncClient() as client:
try:
response = await client.post(
self.device_authorization_endpoint, data=data
)
response.raise_for_status()
device_data = response.json()
except httpx.HTTPError as e:
logger.error(f"Failed to obtain device code: {e}")
raise
# Extract device code data
device_code = device_data.get("device_code")
user_code = device_data.get("user_code")
verification_uri = device_data.get("verification_uri")
verification_uri_complete = device_data.get("verification_uri_complete")
expires_in = device_data.get("expires_in")
interval = device_data.get("interval", 5) # Default interval is 5 seconds
if not device_code or not verification_uri:
raise ValueError("Invalid response from device authorization endpoint.")
# Display instructions to the user and open the browser
if verification_uri_complete:
logger.info("Opening browser to complete authorization...")
logger.info(verification_uri_complete)
webbrowser.open(verification_uri_complete)
else:
logger.info("Please visit the following URL to authorize:")
logger.info(verification_uri)
logger.info(f"Enter the user code: {user_code}")
# Construct the URL for the user to enter the code
full_verification_uri = f"{verification_uri}?user_code={user_code}"
webbrowser.open(full_verification_uri)
# Start polling the token endpoint
start_time = datetime.now()
expires_at = start_time + timedelta(seconds=expires_in)
token_data = None
while datetime.now() < expires_at:
await asyncio.sleep(interval)
data = {
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
"device_code": device_code,
"client_id": self.client_id,
}
if self.client_secret:
data["client_secret"] = self.client_secret
async with httpx.AsyncClient() as client:
try:
response = await client.post(self.token_endpoint, data=data)
if response.status_code == 200:
# Successful response
token_data = response.json()
self.update_token_data(token_data)
logger.info("Authentication successful!")
break
else:
error_data = response.json()
error = error_data.get("error")
if error == "authorization_pending":
# Continue polling
pass
elif error == "slow_down":
# Increase interval by 5 seconds
interval += 5
elif error == "access_denied":
logger.error("Access denied by user.")
break
elif error == "expired_token":
logger.error("Device code has expired.")
break
else:
logger.error(f"Error during polling: {error}")
break
except httpx.HTTPError as e:
logger.error(f"HTTP error during token polling: {e}")
break
else:
logger.error("Device code has expired before authorization.")
raise Exception("Device code expired")
async def main() -> None:
parser = argparse.ArgumentParser(description="OpenID Connect Authentication Client")
parser.add_argument("discovery_url", help="OpenID Connect Discovery URL")
parser.add_argument("client_id", help="Client ID")
parser.add_argument("--client_secret", help="Client Secret", default="")
parser.add_argument(
"--auth_flow",
choices=["authorization_code", "password", "device_code"],
default="authorization_code",
help="Authentication flow to use",
)
parser.add_argument("--username", help="Username for password flow")
parser.add_argument("--password", help="Password for password flow")
parser.add_argument("--access_token", help="Stored access token")
parser.add_argument("--refresh_token", help="Stored refresh token")
parser.add_argument("--device_code", help="Device code for device flow")
args = parser.parse_args()
client = OpenIDClient(
discovery_url=args.discovery_url,
client_id=args.client_id,
client_secret=args.client_secret,
auth_flow=AuthorizationFlow(args.auth_flow),
username=args.username,
password=args.password,
access_token=args.access_token,
refresh_token=args.refresh_token,
device_code=args.device_code,
)
await client.initialize()
await client.authenticate()
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,26 @@
<!DOCTYPE html>
<html>
<head>
<title>Authentication Error</title>
<style>
.error {
text-align: center;
margin-top: 50px;
}
.error h1 {
color: red;
}
.crossmark {
font-size: 100px;
color: red;
}
</style>
</head>
<body>
<div class="error">
<div class="crossmark">&#10006;</div>
<h1>Authentication Error</h1>
<p>{{ error }}</p>
</div>
</body>
</html>

View File

@@ -0,0 +1,33 @@
<!DOCTYPE html>
<html>
<head>
<title>Authentication Successful</title>
<style>
.success {
text-align: center;
margin-top: 50px;
}
.success h1 {
color: green;
}
.checkmark {
font-size: 100px;
color: green;
}
.token-data {
margin-top: 20px;
}
</style>
</head>
<body>
<div class="success">
<div class="checkmark">&#10004;</div>
<h1>Authentication Successful</h1>
<h3>You can close this window now.</h3>
<!-- <div class="token-data">
<h3>User Information:</h3>
<pre>{{ token_data | tojson(indent=2) }}</pre>
</div> -->
</div>
</body>
</html>

View File

@@ -1,6 +1,7 @@
#!/usr/bin/env python
import asyncio
import getpass
import os
import time
from datetime import datetime, timezone
@@ -41,6 +42,7 @@ from ...wallet.crud import (
)
from ...wallet.wallet import Wallet as Wallet
from ..api.api_server import start_api_server
from ..auth.auth import WalletAuth
from ..cli.cli_helpers import (
get_mint_wallet,
get_unit_wallet,
@@ -84,6 +86,49 @@ def coro(f):
return wrapper
def init_auth_wallet(func):
"""Decorator to pass auth_db and auth_keyset_id to the Wallet object."""
@wraps(func)
async def wrapper(*args, **kwargs):
ctx = args[0] # Assuming the first argument is 'ctx'
wallet: Wallet = ctx.obj["WALLET"]
db_location = wallet.db.db_location
auth_wallet = await WalletAuth.with_db(
url=ctx.obj["HOST"],
db=db_location,
)
requires_auth = await auth_wallet.init_auth_wallet(wallet.mint_info)
if not requires_auth:
logger.debug("Mint does not require clear auth.")
return await func(*args, **kwargs)
# Pass auth_db and auth_keyset_id to the wallet object
wallet.auth_db = auth_wallet.db
wallet.auth_keyset_id = auth_wallet.keyset_id
# pass the mint_info so the wallet doesn't need to re-fetch it
wallet.mint_info = auth_wallet.mint_info
# Pass the auth_wallet to context
args[0].obj["AUTH_WALLET"] = auth_wallet
# Proceed to the original function
ret = await func(*args, **kwargs)
if settings.debug:
await auth_wallet.load_proofs(reload=True)
logger.debug(
f"Auth balance: {auth_wallet.unit.str(auth_wallet.available_balance)}"
)
return ret
return wrapper
@click.group(cls=NaturalOrderGroup)
@click.option(
"--host",
@@ -176,6 +221,10 @@ async def cli(ctx: Context, host: str, walletname: str, unit: str, tests: bool):
ctx.obj["HOST"], db_path, name=walletname, unit=unit
)
# if we have never seen this mint before, we load its information
if not wallet.mint_info:
await wallet.load_mint()
assert wallet, "Wallet not found."
ctx.obj["WALLET"] = wallet
@@ -205,6 +254,7 @@ async def cli(ctx: Context, host: str, walletname: str, unit: str, tests: bool):
)
@click.pass_context
@coro
@init_auth_wallet
async def pay(
ctx: Context, invoice: str, amount: Optional[int] = None, yes: bool = False
):
@@ -291,6 +341,7 @@ async def pay(
)
@click.pass_context
@coro
@init_auth_wallet
async def invoice(
ctx: Context,
amount: float,
@@ -451,6 +502,7 @@ async def invoice(
@cli.command("swap", help="Swap funds between mints.")
@click.pass_context
@coro
@init_auth_wallet
async def swap(ctx: Context):
print("Select the mint to swap from:")
outgoing_wallet: Wallet = await get_mint_wallet(ctx, force_select=True)
@@ -621,8 +673,9 @@ async def balance(ctx: Context, verbose):
)
@click.pass_context
@coro
@init_auth_wallet
async def send_command(
ctx,
ctx: Context,
amount: int,
memo: str,
nostr: str,
@@ -668,6 +721,7 @@ async def send_command(
)
@click.pass_context
@coro
@init_auth_wallet
async def receive_cli(
ctx: Context,
token: str,
@@ -685,6 +739,8 @@ async def receive_cli(
mint_url,
os.path.join(settings.cashu_dir, wallet.name),
unit=token_obj.unit,
auth_db=wallet.auth_db.db_location if wallet.auth_db else None,
auth_keyset_id=wallet.auth_keyset_id,
)
await verify_mint(mint_wallet, mint_url)
receive_wallet = await receive(mint_wallet, token_obj)
@@ -830,7 +886,7 @@ async def pending(ctx: Context, legacy, number: int, offset: int):
@cli.command("lock", help="Generate receiving lock.")
@click.pass_context
@coro
async def lock(ctx):
async def lock(ctx: Context):
wallet: Wallet = ctx.obj["WALLET"]
pubkey = await wallet.create_p2pk_pubkey()
@@ -851,7 +907,7 @@ async def lock(ctx):
@cli.command("locks", help="Show unused receiving locks.")
@click.pass_context
@coro
async def locks(ctx):
async def locks(ctx: Context):
wallet: Wallet = ctx.obj["WALLET"]
# P2PK lock
pubkey = await wallet.create_p2pk_pubkey()
@@ -899,7 +955,7 @@ async def locks(ctx):
)
@click.pass_context
@coro
async def invoices(ctx, paid: bool, unpaid: bool, pending: bool, mint: bool):
async def invoices(ctx: Context, paid: bool, unpaid: bool, pending: bool, mint: bool):
wallet: Wallet = ctx.obj["WALLET"]
if paid and unpaid:
@@ -1001,7 +1057,7 @@ async def invoices(ctx, paid: bool, unpaid: bool, pending: bool, mint: bool):
@cli.command("wallets", help="List of all available wallets.")
@click.pass_context
@coro
async def wallets(ctx):
async def wallets(ctx: Context):
# list all directories
wallets = [
d for d in listdir(settings.cashu_dir) if isdir(join(settings.cashu_dir, d))
@@ -1013,7 +1069,7 @@ async def wallets(ctx):
for w in wallets:
wallet = Wallet(ctx.obj["HOST"], os.path.join(settings.cashu_dir, w))
try:
await wallet.load_proofs()
await wallet.load_proofs(reload=True, all_keysets=True)
if wallet.proofs and len(wallet.proofs):
active_wallet = False
if w == ctx.obj["WALLET_NAME"]:
@@ -1031,9 +1087,10 @@ async def wallets(ctx):
@cli.command("info", help="Information about Cashu wallet.")
@click.option("--mint", default=False, is_flag=True, help="Fetch mint information.")
@click.option("--mnemonic", default=False, is_flag=True, help="Show your mnemonic.")
@click.option("--reload", default=False, is_flag=True, help="Reload mint info.")
@click.pass_context
@coro
async def info(ctx: Context, mint: bool, mnemonic: bool):
async def info(ctx: Context, mint: bool, mnemonic: bool, reload: bool):
wallet: Wallet = ctx.obj["WALLET"]
await wallet.load_keysets_from_db(unit=None)
@@ -1057,7 +1114,11 @@ async def info(ctx: Context, mint: bool, mnemonic: bool):
if mint:
wallet.url = mint_url
try:
mint_info: dict = (await wallet.load_mint_info()).dict()
mint_info_obj = await wallet.load_mint_info(reload)
if not mint_info_obj:
print(" - Mint information not available.")
continue
mint_info = mint_info_obj.dict()
if mint_info:
print(f" - Mint name: {mint_info['name']}")
if mint_info.get("description"):
@@ -1126,6 +1187,7 @@ async def info(ctx: Context, mint: bool, mnemonic: bool):
)
@click.pass_context
@coro
@init_auth_wallet
async def restore(ctx: Context, to: int, batch: int):
wallet: Wallet = ctx.obj["WALLET"]
# check if there is already a mnemonic in the database
@@ -1160,6 +1222,7 @@ async def restore(ctx: Context, to: int, batch: int):
# @click.option("--all", default=False, is_flag=True, help="Execute on all available mints.")
@click.pass_context
@coro
@init_auth_wallet
async def selfpay(ctx: Context, all: bool = False):
wallet = await get_mint_wallet(ctx, force_select=True)
await wallet.load_mint()
@@ -1183,3 +1246,46 @@ async def selfpay(ctx: Context, all: bool = False):
print(token)
token_obj = TokenV4.deserialize(token)
await receive(wallet, token_obj)
@cli.command("auth", help="Authenticate with mint.")
@click.option("--mint", "-m", default=False, is_flag=True, help="Mint new auth tokens.")
@click.option(
"--force", "-f", default=False, is_flag=True, help="Force authentication."
)
@click.option(
"--password",
"-p",
default=False,
is_flag=True,
help="Use username and password for authentication.",
)
@click.pass_context
@coro
async def auth(ctx: Context, mint: bool, force: bool, password: bool):
# auth_wallet: WalletAuth = ctx.obj["AUTH_WALLET"]
wallet: Wallet = ctx.obj["WALLET"]
username = None
password_str = None
if password:
username = input("Enter username: ")
password_str = getpass.getpass("Enter password: ")
auth_wallet = await WalletAuth.with_db(
url=ctx.obj["HOST"],
db=wallet.db.db_location,
username=username,
password=password_str,
)
requires_auth = await auth_wallet.init_auth_wallet(
wallet.mint_info, mint_auth_proofs=False, force_auth=force
)
if not requires_auth:
print("Mint does not require authentication.")
return
if mint:
new_proofs = await auth_wallet.mint_blind_auth()
print(f"Minted {auth_wallet.unit.str(sum_proofs(new_proofs))} auth tokens.")
print(f"Auth balance: {auth_wallet.unit.str(auth_wallet.available_balance)}")

View File

@@ -9,6 +9,7 @@ from ..core.base import (
MintQuoteState,
Proof,
WalletKeyset,
WalletMint,
)
from ..core.db import Connection, Database
@@ -577,3 +578,59 @@ async def store_seed_and_mnemonic(
"mnemonic": mnemonic,
},
)
async def store_mint(
db: Database,
mint: WalletMint,
conn: Optional[Connection] = None,
) -> None:
await (conn or db).execute(
"""
INSERT INTO mints
(url, info, updated)
VALUES (:url, :info, :updated)
""",
{
"url": mint.url,
"info": mint.info,
"updated": int(time.time()),
},
)
async def update_mint(
db: Database,
mint: WalletMint,
conn: Optional[Connection] = None,
) -> None:
await (conn or db).execute(
"""
UPDATE mints
SET info = :info, updated = :updated, access_token = :access_token, refresh_token = :refresh_token, username = :username, password = :password
WHERE url = :url
""",
{
"url": mint.url,
"info": mint.info,
"updated": int(time.time()),
"access_token": mint.access_token,
"refresh_token": mint.refresh_token,
"username": mint.username,
"password": mint.password,
},
)
async def get_mint_by_url(
db: Database,
url: str,
conn: Optional[Connection] = None,
) -> Optional[WalletMint]:
row = await (conn or db).fetchone(
"""
SELECT * from mints WHERE url = :url
""",
{"url": url},
)
return WalletMint.parse_obj(dict(row)) if row else None

16
cashu/wallet/errors.py Normal file
View File

@@ -0,0 +1,16 @@
from typing import Optional
class WalletError(Exception):
msg: str
def __init__(self, msg):
super().__init__(msg)
self.msg = msg
class BalanceTooLowError(WalletError):
msg = "Balance too low"
def __init__(self, msg: Optional[str] = None):
super().__init__(msg or self.msg)

View File

@@ -52,6 +52,8 @@ async def redeem_TokenV3(wallet: Wallet, token: TokenV3) -> Wallet:
t.mint,
os.path.join(settings.cashu_dir, wallet.name),
unit=token.unit or wallet.unit.name,
auth_db=wallet.auth_db.db_location if wallet.auth_db else None,
auth_keyset_id=wallet.auth_keyset_id,
)
keyset_ids = mint_wallet._get_proofs_keyset_ids(t.proofs)
logger.trace(f"Keysets in tokens: {' '.join(set(keyset_ids))}")

View File

@@ -214,7 +214,6 @@ async def m010_add_ids_to_proofs_and_out_to_invoices(db: Database):
Columns that store mint and melt id for proofs and invoices.
"""
async with db.connect() as conn:
print("Running wallet migrations")
await conn.execute("ALTER TABLE proofs ADD COLUMN mint_id TEXT")
await conn.execute("ALTER TABLE proofs ADD COLUMN melt_id TEXT")
@@ -287,11 +286,30 @@ async def m013_add_mint_and_melt_quote_tables(db: Database):
"""
)
async def m013_add_key_to_mint_quote_table(db: Database):
async def m014_add_key_to_mint_quote_table(db: Database):
async with db.connect() as conn:
await conn.execute(
"""
ALTER TABLE bolt11_mint_quotes
ADD COLUMN privkey TEXT DEFAULT NULL;
"""
)
)
async def m015_add_mints_table(db: Database):
async with db.connect() as conn:
await conn.execute(
f"""
CREATE TABLE IF NOT EXISTS mints (
id INTEGER PRIMARY KEY AUTOINCREMENT,
url TEXT NOT NULL,
info TEXT NOT NULL,
updated TIMESTAMP DEFAULT {db.timestamp_now},
access_token TEXT,
refresh_token TEXT,
username TEXT,
password TEXT
);
"""
)

View File

@@ -1,10 +1,11 @@
from typing import Dict, List, Protocol
from typing import Dict, List, Optional, Protocol
import httpx
from ..core.base import Proof, Unit, WalletKeyset
from ..core.crypto.secp import PrivateKey
from ..core.db import Database
from ..core.mint_info import MintInfo
class SupportsPrivateKey(Protocol):
@@ -28,3 +29,9 @@ class SupportsHttpxClient(Protocol):
class SupportsMintURL(Protocol):
url: str
class SupportsAuth(Protocol):
auth_db: Optional[Database] = None
auth_keyset_id: Optional[str] = None
mint_info: Optional[MintInfo] = None

View File

@@ -47,7 +47,7 @@ class SubscriptionManager:
try:
msg = JSONRPCNotification.parse_raw(message)
logger.debug(f"Received notification: {msg}")
logger.trace(f"Received notification: {msg}")
except Exception as e:
logger.error(f"Error parsing notification: {e}")
return

View File

@@ -12,6 +12,7 @@ from pydantic import ValidationError
from cashu.wallet.crud import get_bolt11_melt_quote
from ..core.base import (
AuthProof,
BlindedMessage,
BlindedSignature,
MeltQuoteState,
@@ -29,6 +30,8 @@ from ..core.models import (
KeysetsResponse,
KeysetsResponseKeyset,
KeysResponse,
PostAuthBlindMintRequest,
PostAuthBlindMintResponse,
PostCheckStateRequest,
PostCheckStateResponse,
PostMeltQuoteRequest,
@@ -47,8 +50,16 @@ from ..core.models import (
)
from ..core.settings import settings
from ..tor.tor import TorProxy
from .crud import (
get_proofs,
invalidate_proof,
)
from .protocols import SupportsAuth
from .wallet_deprecated import LedgerAPIDeprecated
GET = "GET"
POST = "POST"
def async_set_httpx_client(func):
"""
@@ -78,7 +89,7 @@ def async_set_httpx_client(func):
verify=not settings.debug,
proxies=proxies_dict, # type: ignore
headers=headers_dict,
base_url=self.url,
base_url=self.url.rstrip("/"),
timeout=None if settings.debug else 60,
)
return await func(self, *args, **kwargs)
@@ -99,10 +110,10 @@ def async_ensure_mint_loaded(func):
return wrapper
class LedgerAPI(LedgerAPIDeprecated):
class LedgerAPI(LedgerAPIDeprecated, SupportsAuth):
tor: TorProxy
db: Database # we need the db for melt_deprecated
httpx: httpx.AsyncClient
api_prefix = "v1"
def __init__(self, url: str, db: Database):
self.url = url
@@ -128,7 +139,6 @@ class LedgerAPI(LedgerAPIDeprecated):
try:
resp_dict = resp.json()
except json.JSONDecodeError:
# if we can't decode the response, raise for status
resp.raise_for_status()
return
if "detail" in resp_dict:
@@ -137,9 +147,49 @@ class LedgerAPI(LedgerAPIDeprecated):
if "code" in resp_dict:
error_message += f" (Code: {resp_dict['code']})"
raise Exception(error_message)
# raise for status if no error
resp.raise_for_status()
async def _request(self, method: str, path: str, noprefix=False, **kwargs):
if not noprefix:
path = join(self.api_prefix, path)
if self.mint_info and self.mint_info.requires_blind_auth_path(method, path):
if not self.auth_db:
raise Exception(
"Mint requires blind auth, but no auth database is set."
)
if not self.auth_keyset_id:
raise Exception(
"Mint requires blind auth, but no auth keyset id is set."
)
proofs = await get_proofs(db=self.auth_db, id=self.auth_keyset_id)
if not proofs:
raise Exception(
"Mint requires blind auth, but no blind auth tokens were found."
)
# select one auth proof
proof = proofs[0]
auth_token = AuthProof.from_proof(proof).to_base64()
kwargs.setdefault("headers", {}).update(
{
"Blind-auth": f"{auth_token}",
}
)
await invalidate_proof(proof=proof, db=self.auth_db)
if self.mint_info and self.mint_info.requires_clear_auth_path(method, path):
logger.debug(f"Using clear auth token for {path}")
clear_auth_token = kwargs.pop("clear_auth_token")
if not clear_auth_token:
raise Exception(
"Mint requires clear auth, but no clear auth token is set."
)
kwargs.setdefault("headers", {}).update(
{
"Clear-auth": f"{clear_auth_token}",
}
)
return await self.httpx.request(method, path, **kwargs)
"""
ENDPOINTS
"""
@@ -157,9 +207,7 @@ class LedgerAPI(LedgerAPIDeprecated):
Raises:
Exception: If no keys are received from the mint
"""
resp = await self.httpx.get(
join(self.url, "/v1/keys"),
)
resp = await self._request(GET, "keys")
# BEGIN backwards compatibility < 0.15.0
# assume the mint has not upgraded yet if we get a 404
if resp.status_code == 404:
@@ -201,9 +249,7 @@ class LedgerAPI(LedgerAPIDeprecated):
Exception: If no keys are received from the mint
"""
keyset_id_urlsafe = keyset_id.replace("+", "-").replace("/", "_")
resp = await self.httpx.get(
join(self.url, f"/v1/keys/{keyset_id_urlsafe}"),
)
resp = await self._request(GET, f"keys/{keyset_id_urlsafe}")
# BEGIN backwards compatibility < 0.15.0
# assume the mint has not upgraded yet if we get a 404
if resp.status_code == 404:
@@ -238,9 +284,7 @@ class LedgerAPI(LedgerAPIDeprecated):
Raises:
Exception: If no keysets are received from the mint
"""
resp = await self.httpx.get(
join(self.url, "/v1/keysets"),
)
resp = await self._request(GET, "keysets")
# BEGIN backwards compatibility < 0.15.0
# assume the mint has not upgraded yet if we get a 404
if resp.status_code == 404:
@@ -265,9 +309,7 @@ class LedgerAPI(LedgerAPIDeprecated):
Raises:
Exception: If the mint info request fails
"""
resp = await self.httpx.get(
join(self.url, "/v1/info"),
)
resp = await self._request(GET, "/v1/info", noprefix=True)
# BEGIN backwards compatibility < 0.15.0
# assume the mint has not upgraded yet if we get a 404
if resp.status_code == 404:
@@ -305,9 +347,12 @@ class LedgerAPI(LedgerAPIDeprecated):
payload = PostMintQuoteRequest(
unit=unit.name, amount=amount, description=memo, pubkey=pubkey
)
resp = await self.httpx.post(
join(self.url, "/v1/mint/quote/bolt11"), json=payload.dict()
resp = await self._request(
POST,
"mint/quote/bolt11",
json=payload.dict(),
)
# BEGIN backwards compatibility < 0.15.0
# assume the mint has not upgraded yet if we get a 404
if resp.status_code == 404:
@@ -329,9 +374,7 @@ class LedgerAPI(LedgerAPIDeprecated):
Returns:
PostMintQuoteResponse: Mint Quote Response
"""
resp = await self.httpx.get(
join(self.url, f"/v1/mint/quote/bolt11/{quote}"),
)
resp = await self._request(GET, f"mint/quote/bolt11/{quote}")
self.raise_on_error_request(resp)
return_dict = resp.json()
return PostMintQuoteResponse.parse_obj(return_dict)
@@ -371,8 +414,9 @@ class LedgerAPI(LedgerAPIDeprecated):
return res
payload = outputs_payload.dict(include=_mintrequest_include_fields(outputs)) # type: ignore
resp = await self.httpx.post(
join(self.url, "/v1/mint/bolt11"),
resp = await self._request(
POST,
"mint/bolt11",
json=payload, # type: ignore
)
# BEGIN backwards compatibility < 0.15.0
@@ -383,7 +427,7 @@ class LedgerAPI(LedgerAPIDeprecated):
# END backwards compatibility < 0.15.0
self.raise_on_error_request(resp)
response_dict = resp.json()
logger.trace("Lightning invoice checked. POST /v1/mint/bolt11")
logger.trace(f"Lightning invoice checked. POST {self.api_prefix}/mint/bolt11")
promises = PostMintResponse.parse_obj(response_dict).signatures
return promises
@@ -406,8 +450,9 @@ class LedgerAPI(LedgerAPIDeprecated):
unit=unit.name, request=payment_request, options=melt_options
)
resp = await self.httpx.post(
join(self.url, "/v1/melt/quote/bolt11"),
resp = await self._request(
POST,
"melt/quote/bolt11",
json=payload.dict(),
)
# BEGIN backwards compatibility < 0.15.0
@@ -441,9 +486,7 @@ class LedgerAPI(LedgerAPIDeprecated):
Returns:
PostMeltQuoteResponse: Melt Quote Response
"""
resp = await self.httpx.get(
join(self.url, f"/v1/melt/quote/bolt11/{quote}"),
)
resp = await self._request(GET, f"melt/quote/bolt11/{quote}")
self.raise_on_error_request(resp)
return_dict = resp.json()
return PostMeltQuoteResponse.parse_obj(return_dict)
@@ -474,8 +517,9 @@ class LedgerAPI(LedgerAPIDeprecated):
"outputs": {i: outputs_include for i in range(len(outputs))},
}
resp = await self.httpx.post(
join(self.url, "/v1/melt/bolt11"),
resp = await self._request(
POST,
"melt/bolt11",
json=payload.dict(include=_meltrequest_include_fields(proofs, outputs)), # type: ignore
timeout=None,
)
@@ -523,7 +567,7 @@ class LedgerAPI(LedgerAPIDeprecated):
outputs: List[BlindedMessage],
) -> List[BlindedSignature]:
"""Consume proofs and create new promises based on amount split."""
logger.debug("Calling split. POST /v1/swap")
logger.debug(f"Calling split. POST {self.api_prefix}/swap")
split_payload = PostSwapRequest(inputs=proofs, outputs=outputs)
# construct payload
@@ -541,8 +585,9 @@ class LedgerAPI(LedgerAPIDeprecated):
"inputs": {i: proofs_include for i in range(len(proofs))},
}
resp = await self.httpx.post(
join(self.url, "/v1/swap"),
resp = await self._request(
POST,
"swap",
json=split_payload.dict(include=_splitrequest_include_fields(proofs)), # type: ignore
)
# BEGIN backwards compatibility < 0.15.0
@@ -568,8 +613,9 @@ class LedgerAPI(LedgerAPIDeprecated):
Checks whether the secrets in proofs are already spent or not and returns a list of booleans.
"""
payload = PostCheckStateRequest(Ys=[p.Y for p in proofs])
resp = await self.httpx.post(
join(self.url, "/v1/checkstate"),
resp = await self._request(
POST,
"checkstate",
json=payload.dict(),
)
# BEGIN backwards compatibility < 0.15.0
@@ -595,10 +641,7 @@ class LedgerAPI(LedgerAPIDeprecated):
"Received HTTP Error 422. Attempting state check with < 0.16.0 compatibility."
)
payload_secrets = {"secrets": [p.secret for p in proofs]}
resp_secrets = await self.httpx.post(
join(self.url, "/v1/checkstate"),
json=payload_secrets,
)
resp_secrets = await self._request(POST, "checkstate", json=payload_secrets)
self.raise_on_error(resp_secrets)
states = [
ProofState(Y=p.Y, state=ProofSpentState(s["state"]))
@@ -619,7 +662,7 @@ class LedgerAPI(LedgerAPIDeprecated):
Asks the mint to restore promises corresponding to outputs.
"""
payload = PostMintRequest(quote="restore", outputs=outputs)
resp = await self.httpx.post(join(self.url, "/v1/restore"), json=payload.dict())
resp = await self._request(POST, "restore", json=payload.dict())
# BEGIN backwards compatibility < 0.15.0
# assume the mint has not upgraded yet if we get a 404
if resp.status_code == 404:
@@ -637,3 +680,21 @@ class LedgerAPI(LedgerAPIDeprecated):
# END backwards compatibility < 0.15.1
return returnObj.outputs, returnObj.signatures
@async_set_httpx_client
async def blind_mint_blind_auth(
self, clear_auth_token: str, outputs: List[BlindedMessage]
) -> List[BlindedSignature]:
"""
Asks the mint to mint blind auth tokens. Needs to provide a clear auth token.
"""
payload = PostAuthBlindMintRequest(outputs=outputs)
resp = await self._request(
POST,
"mint",
json=payload.dict(),
clear_auth_token=clear_auth_token,
)
self.raise_on_error_request(resp)
response_dict = resp.json()
return PostAuthBlindMintResponse.parse_obj(response_dict).signatures

View File

@@ -1,4 +1,5 @@
import copy
import json
import threading
import time
from typing import Callable, Dict, List, Optional, Tuple, Union
@@ -17,6 +18,7 @@ from ..core.base import (
Proof,
Unit,
WalletKeyset,
WalletMint,
)
from ..core.crypto import b_dhke
from ..core.crypto.secp import PrivateKey, PublicKey
@@ -30,6 +32,7 @@ from ..core.helpers import (
)
from ..core.json_rpc.base import JSONRPCSubscriptionKinds
from ..core.migrations import migrate_databases
from ..core.mint_info import MintInfo
from ..core.models import (
PostCheckStateResponse,
PostMeltQuoteResponse,
@@ -43,6 +46,7 @@ from .crud import (
bump_secret_derivation,
get_bolt11_mint_quote,
get_keysets,
get_mint_by_url,
get_proofs,
invalidate_proof,
secret_used,
@@ -50,14 +54,16 @@ from .crud import (
store_bolt11_melt_quote,
store_bolt11_mint_quote,
store_keyset,
store_mint,
store_proof,
update_bolt11_melt_quote,
update_bolt11_mint_quote,
update_keyset,
update_mint,
update_proof,
)
from .errors import BalanceTooLowError
from .htlc import WalletHTLC
from .mint_info import MintInfo
from .p2pk import WalletP2PK
from .proofs import WalletProofs
from .secrets import WalletSecrets
@@ -108,21 +114,41 @@ class Wallet(
db: Database
bip32: BIP32
# private_key: Optional[PrivateKey] = None
auth_db: Optional[Database] = None
auth_keyset_id: Optional[str] = None
def __init__(self, url: str, db: str, name: str = "wallet", unit: str = "sat"):
def __init__(
self,
url: str,
db: str,
name: str = "wallet",
unit: str = "sat",
auth_db: Optional[str] = None,
auth_keyset_id: Optional[str] = None,
):
"""A Cashu wallet.
Args:
url (str): URL of the mint.
db (str): Path to the database directory.
name (str, optional): Name of the wallet database file. Defaults to "wallet".
unit (str, optional): Unit of the wallet. Defaults to "sat".
auth_db (Optional[str], optional): Path to the auth database directory. Defaults to None.
auth_keyset_id (Optional[str], optional): Keyset ID of the auth keyset. Defaults to None.
"""
self.db = Database("wallet", db)
self.db = Database(name, db)
self.proofs: List[Proof] = []
self.name = name
self.unit = Unit[unit]
url = sanitize_url(url)
# if this is an auth wallet
if (auth_db and not auth_keyset_id) or (not auth_db and auth_keyset_id):
raise Exception("Both auth_db and auth_keyset_id must be provided.")
if auth_db and auth_keyset_id:
self.auth_db = Database("auth", auth_db)
self.auth_keyset_id = auth_keyset_id
super().__init__(url=url, db=self.db)
logger.debug("Wallet initialized")
logger.debug(f"Mint URL: {url}")
@@ -137,7 +163,10 @@ class Wallet(
name: str = "wallet",
skip_db_read: bool = False,
unit: str = "sat",
auth_db: Optional[str] = None,
auth_keyset_id: Optional[str] = None,
load_all_keysets: bool = False,
**kwargs,
):
"""Initializes a wallet with a database and initializes the private key.
@@ -151,12 +180,22 @@ class Wallet(
unit (str, optional): Unit of the wallet. Defaults to "sat".
load_all_keysets (bool, optional): If true, all keysets are loaded from the database.
Defaults to False.
auth_db (Optional[str], optional): Path to the auth database directory. Defaults to None.
auth_keyset_id (Optional[str], optional): Keyset ID of the auth keyset. Defaults to None.
kwargs: Additional keyword arguments.
Returns:
Wallet: Initialized wallet.
"""
logger.trace(f"Initializing wallet with database: {db}")
self = cls(url=url, db=db, name=name, unit=unit)
self = cls(
url=url,
db=db,
name=name,
unit=unit,
auth_db=auth_db,
auth_keyset_id=auth_keyset_id,
)
await self._migrate_database()
if skip_db_read:
@@ -172,8 +211,13 @@ class Wallet(
self.keysets = {k.id: k for k in keysets_active_unit}
else:
self.keysets = {k.id: k for k in keysets_list}
keysets_str = " ".join([f"{i} {k.unit}" for i, k in self.keysets.items()])
logger.debug(f"Loaded keysets: {keysets_str}")
if self.keysets:
keysets_str = " ".join([f"{i} {k.unit}" for i, k in self.keysets.items()])
logger.debug(f"Loaded keysets: {keysets_str}")
await self.load_mint_info(offline=True)
return self
async def _migrate_database(self):
@@ -185,12 +229,63 @@ class Wallet(
# ---------- API ----------
async def load_mint_info(self) -> MintInfo:
"""Loads the mint info from the mint."""
mint_info_resp = await self._get_info()
self.mint_info = MintInfo(**mint_info_resp.dict())
logger.debug(f"Mint info: {self.mint_info}")
return self.mint_info
async def load_mint_info(self, reload=False, offline=False) -> MintInfo | None:
"""Loads the mint info from the mint.
Args:
reload (bool, optional): If True, the mint info is reloaded from the mint. Defaults to False.
offline (bool, optional): If True, the mint info is not loaded from the mint. Defaults to False.
"""
# if self.mint_info and not reload:
# return self.mint_info
# read mint info from db
if reload:
if offline:
raise Exception("Cannot reload mint info offline.")
logger.debug("Forcing reload of mint info.")
mint_info_resp = await self._get_info()
self.mint_info = MintInfo(**mint_info_resp.dict())
wallet_mint_db = await get_mint_by_url(url=self.url, db=self.db)
if not wallet_mint_db:
if self.mint_info:
logger.debug("Storing mint info in db.")
await store_mint(
db=self.db,
mint=WalletMint(
url=self.url, info=json.dumps(self.mint_info.dict())
),
)
else:
if offline:
return None
logger.debug("Loading mint info from mint.")
mint_info_resp = await self._get_info()
self.mint_info = MintInfo(**mint_info_resp.dict())
if not wallet_mint_db:
logger.debug("Storing mint info in db.")
await store_mint(
db=self.db,
mint=WalletMint(
url=self.url, info=json.dumps(self.mint_info.dict())
),
)
return self.mint_info
elif (
self.mint_info
and not json.dumps(self.mint_info.dict()) == wallet_mint_db.info
):
logger.debug("Updating mint info in db.")
await update_mint(
db=self.db,
mint=WalletMint(url=self.url, info=json.dumps(self.mint_info.dict())),
)
return self.mint_info
else:
logger.debug("Loading mint info from db.")
self.mint_info = MintInfo.from_json_str(wallet_mint_db.info)
return self.mint_info
async def load_mint_keysets(self, force_old_keysets=False):
"""Loads all keyset of the mint and makes sure we have them all in the database.
@@ -304,11 +399,11 @@ class Wallet(
force_old_keysets (bool, optional): If true, old deprecated base64 keysets are not ignored. This is necessary for restoring tokens from old base64 keysets.
Defaults to False.
"""
logger.trace("Loading mint.")
logger.trace(f"Loading mint {self.url}")
await self.load_mint_keysets(force_old_keysets)
await self.activate_keyset(keyset_id)
try:
await self.load_mint_info()
await self.load_mint_info(reload=True)
except Exception as e:
logger.debug(f"Could not load mint info: {e}")
pass
@@ -979,7 +1074,7 @@ class Wallet(
# select proofs that are not reserved and are in the active keysets of the mint
proofs = self.active_proofs(proofs)
if sum_proofs(proofs) < amount:
raise Exception("balance too low.")
raise BalanceTooLowError()
# coin selection for potentially offline sending
send_proofs = self.coinselect(proofs, amount, include_fees=include_fees)
@@ -1037,7 +1132,7 @@ class Wallet(
# select proofs that are not reserved and are in the active keysets of the mint
proofs = self.active_proofs(proofs)
if sum_proofs(proofs) < amount:
raise Exception("balance too low.")
raise BalanceTooLowError()
# coin selection for swapping, needs to include fees
swap_proofs = self.coinselect(proofs, amount, include_fees=True)

View File

@@ -144,6 +144,8 @@ class LedgerAPIDeprecated(SupportsHttpxClient, SupportsMintURL):
mint_info = GetInfoResponse(
**mint_info_deprecated.dict(exclude={"parameter", "nuts", "contact"})
)
# monkeypatch nuts
mint_info.nuts = {}
return mint_info
@async_set_httpx_client
@@ -261,7 +263,7 @@ class LedgerAPIDeprecated(SupportsHttpxClient, SupportsMintURL):
paid=False,
state=MintQuoteState.unpaid.value,
expiry=decoded_invoice.date + (decoded_invoice.expiry or 0),
pubkey=None
pubkey=None,
)
@async_set_httpx_client

7
keycloak/.env.example Normal file
View File

@@ -0,0 +1,7 @@
POSTGRES_DB=keycloak_db
POSTGRES_USER=keycloak_db_user
POSTGRES_PASSWORD=keycloak_db_user_password
KEYCLOAK_ADMIN=admin
KEYCLOAK_ADMIN_PASSWORD=password
KC_HOSTNAME=localhost
KC_HOSTNAME_PORT=8080

129
keycloak/README.md Normal file
View File

@@ -0,0 +1,129 @@
## Docker compose
This docker-compose starts a new keycloak instance. Set up the server as you wish, add realms, users etc. We will then export the data and restore an instance with the exported data.
We will modify this file later to start the server with the backup data.
```
services:
postgres:
image: postgres:16.4
volumes:
- ./postgres_data:/var/lib/postgresql/data
environment:
POSTGRES_DB: ${POSTGRES_DB}
POSTGRES_USER: ${POSTGRES_USER}
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
networks:
- keycloak_network
keycloak:
image: quay.io/keycloak/keycloak:25.0.6
command: start
environment:
KC_HOSTNAME: localhost
KC_HOSTNAME_PORT: 8080
KC_HOSTNAME_STRICT_BACKCHANNEL: false
KC_HTTP_ENABLED: true
KC_HOSTNAME_STRICT_HTTPS: false
KC_HEALTH_ENABLED: true
KEYCLOAK_ADMIN: ${KEYCLOAK_ADMIN}
KEYCLOAK_ADMIN_PASSWORD: ${KEYCLOAK_ADMIN_PASSWORD}
KC_DB: postgres
KC_DB_URL: jdbc:postgresql://postgres/${POSTGRES_DB}
KC_DB_USERNAME: ${POSTGRES_USER}
KC_DB_PASSWORD: ${POSTGRES_PASSWORD}
ports:
- 8080:8080
restart: always
depends_on:
- postgres
networks:
- keycloak_network
volumes:
postgres_data:
driver: local
networks:
keycloak_network:
driver: bridge
```
## Backup
Export realm and users from running container:
```
docker exec keycloak-keycloak-1 \
/opt/keycloak/bin/kc.sh export \
--dir /opt/keycloak/data/export \
--users different_files \
--http-management-port 46566
```
Copy export out of the docker
```
docker cp keycloak-keycloak-1:/opt/keycloak/data/export ./keycloak-export
```
## Restore
Use this docker-compose.yml to start keycloak with the exported backup:
```
services:
postgres:
image: postgres:16.4
volumes:
- ./postgres_data:/var/lib/postgresql/data
environment:
POSTGRES_DB: ${POSTGRES_DB}
POSTGRES_USER: ${POSTGRES_USER}
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
networks:
- keycloak_network
keycloak:
image: quay.io/keycloak/keycloak:25.0.6
command: start --import-realm
volumes:
- ./keycloak-export:/opt/keycloak/data/import
environment:
KC_HOSTNAME: localhost
KC_HOSTNAME_PORT: 8080
KC_HOSTNAME_STRICT_BACKCHANNEL: false
KC_HTTP_ENABLED: true
KC_HOSTNAME_STRICT_HTTPS: false
KC_HEALTH_ENABLED: true
KEYCLOAK_ADMIN: ${KEYCLOAK_ADMIN}
KEYCLOAK_ADMIN_PASSWORD: ${KEYCLOAK_ADMIN_PASSWORD}
KC_DB: postgres
KC_DB_URL: jdbc:postgresql://postgres/${POSTGRES_DB}
KC_DB_USERNAME: ${POSTGRES_USER}
KC_DB_PASSWORD: ${POSTGRES_PASSWORD}
ports:
- 8080:8080
restart: always
depends_on:
- postgres
networks:
- keycloak_network
volumes:
postgres_data:
driver: local
networks:
keycloak_network:
driver: bridge
```
Difference to first docker-compose is only the following part:
```
command: start --import-realm
volumes:
- ./keycloak-export:/opt/keycloak/data/import
```

View File

@@ -0,0 +1,45 @@
services:
postgres:
image: postgres:16.4
volumes:
- ./postgres_data:/var/lib/postgresql/data
environment:
POSTGRES_DB: ${POSTGRES_DB}
POSTGRES_USER: ${POSTGRES_USER}
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
networks:
- keycloak_network
keycloak:
image: quay.io/keycloak/keycloak:25.0.6
command: start --import-realm
volumes:
- ./keycloak-export:/opt/keycloak/data/import
environment:
KC_HOSTNAME: ${KC_HOSTNAME}
KC_HOSTNAME_PORT: ${KC_HOSTNAME_PORT}
KC_HOSTNAME_STRICT_BACKCHANNEL: false
KC_HTTP_ENABLED: true
KC_HOSTNAME_STRICT_HTTPS: true
KC_HEALTH_ENABLED: true
KEYCLOAK_ADMIN: ${KEYCLOAK_ADMIN}
KEYCLOAK_ADMIN_PASSWORD: ${KEYCLOAK_ADMIN_PASSWORD}
KC_DB: postgres
KC_DB_URL: jdbc:postgresql://postgres/${POSTGRES_DB}
KC_DB_USERNAME: ${POSTGRES_USER}
KC_DB_PASSWORD: ${POSTGRES_PASSWORD}
ports:
- 8080:8080
restart: always
depends_on:
- postgres
networks:
- keycloak_network
volumes:
postgres_data:
driver: local
networks:
keycloak_network:
driver: bridge

1099
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -39,9 +39,11 @@ googleapis-common-protos = "^1.63.2"
mypy-protobuf = "^3.6.0"
types-protobuf = "^5.27.0.20240626"
grpcio-tools = "^1.65.1"
pyjwt = "^2.9.0"
redis = "^5.1.1"
brotli = "^1.1.0"
zstandard = "^0.23.0"
jinja2 = "^3.1.5"
[tool.poetry.group.dev.dependencies]
pytest-asyncio = "^0.24.0"

View File

@@ -51,6 +51,7 @@ settings.mint_lnd_enable_mpp = True
settings.mint_clnrest_enable_mpp = True
settings.mint_input_fee_ppk = 0
settings.db_connection_pool = True
# settings.mint_require_auth = False
assert "test" in settings.cashu_dir
shutil.rmtree(settings.cashu_dir, ignore_errors=True)

View File

@@ -0,0 +1,45 @@
services:
postgres:
image: postgres:16.4
volumes:
- ./postgres_data:/var/lib/postgresql/data
environment:
POSTGRES_DB: cashu
POSTGRES_USER: cashu
POSTGRES_PASSWORD: cashu
networks:
- keycloak_network
keycloak:
image: quay.io/keycloak/keycloak:25.0.6
command: start --import-realm
volumes:
- ./keycloak-export:/opt/keycloak/data/import
environment:
KC_HOSTNAME: localhost
KC_HOSTNAME_PORT: 8080
KC_HOSTNAME_STRICT_BACKCHANNEL: false
KC_HTTP_ENABLED: true
KC_HOSTNAME_STRICT_HTTPS: false
KC_HEALTH_ENABLED: true
KEYCLOAK_ADMIN: admin
KEYCLOAK_ADMIN_PASSWORD: admin
KC_DB: postgres
KC_DB_URL: jdbc:postgresql://postgres/cashu
KC_DB_USERNAME: cashu
KC_DB_PASSWORD: cashu
ports:
- 8080:8080
restart: always
depends_on:
- postgres
networks:
- keycloak_network
volumes:
postgres_data:
driver: local
networks:
keycloak_network:
driver: bridge

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,26 @@
{
"realm" : "master",
"users" : [ {
"id" : "0ff227f7-c163-4fca-9ae4-c8751c725421",
"username" : "admin",
"emailVerified" : false,
"createdTimestamp" : 1727128354842,
"enabled" : true,
"totp" : false,
"credentials" : [ {
"id" : "11a5f9ed-19c9-4164-be31-28ce6e23955b",
"type" : "password",
"createdDate" : 1727128354904,
"secretData" : "{\"value\":\"s/6M2/FCFd1fOyHJRMvOLvKM7e2JIOC6LZ3ovFVkGi8=\",\"salt\":\"Zjn7ChOL5688O84xf1ElGA==\",\"additionalParameters\":{}}",
"credentialData" : "{\"hashIterations\":5,\"algorithm\":\"argon2\",\"additionalParameters\":{\"hashLength\":[\"32\"],\"memory\":[\"7168\"],\"type\":[\"id\"],\"version\":[\"1.3\"],\"parallelism\":[\"1\"]}}"
} ],
"disableableCredentialTypes" : [ ],
"requiredActions" : [ ],
"realmRoles" : [ "default-roles-master", "admin" ],
"clientRoles" : {
"nutshell-realm" : [ "query-realms", "query-users", "manage-identity-providers", "manage-authorization", "view-identity-providers", "view-realm", "view-authorization", "query-clients", "manage-clients", "create-client", "view-events", "manage-events", "manage-realm", "manage-users", "view-users", "view-clients", "query-groups" ]
},
"notBefore" : 0,
"groups" : [ ]
} ]
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,53 @@
{
"realm" : "nutshell",
"users" : [ {
"id" : "c4fc742a-700f-4c83-96f2-8777c8bb56d1",
"username" : "asd@asd.com",
"firstName" : "asd",
"lastName" : "asd",
"email" : "asd@asd.com",
"emailVerified" : false,
"createdTimestamp" : 1727128876722,
"enabled" : true,
"totp" : false,
"credentials" : [ {
"id" : "23ea2b79-9c09-4133-b53b-2708258da890",
"type" : "password",
"createdDate" : 1727128876754,
"secretData" : "{\"value\":\"fDXqE3IjxS5uIYfn9eYgW5GwokWvGsg2wWY0lOgeYyE=\",\"salt\":\"Wlb5f8yPTh4QreuC99b7Zg==\",\"additionalParameters\":{}}",
"credentialData" : "{\"hashIterations\":5,\"algorithm\":\"argon2\",\"additionalParameters\":{\"hashLength\":[\"32\"],\"memory\":[\"7168\"],\"type\":[\"id\"],\"version\":[\"1.3\"],\"parallelism\":[\"1\"]}}"
} ],
"disableableCredentialTypes" : [ ],
"requiredActions" : [ ],
"realmRoles" : [ "default-roles-nutshell" ],
"clientConsents" : [ {
"clientId" : "cashu-client",
"grantedClientScopes" : [ "email", "roles", "profile" ],
"createdDate" : 1732651444894,
"lastUpdatedDate" : 1732651444908
} ],
"notBefore" : 0,
"groups" : [ ]
}, {
"id" : "43a16bd6-f5c5-4dfa-bcd4-6a5540564797",
"username" : "callebtc@protonmail.com",
"firstName" : "asdasd",
"lastName" : "asdasdasdasd",
"email" : "callebtc@protonmail.com",
"emailVerified" : false,
"createdTimestamp" : 1732639511706,
"enabled" : true,
"totp" : false,
"credentials" : [ ],
"disableableCredentialTypes" : [ ],
"requiredActions" : [ ],
"federatedIdentities" : [ {
"identityProvider" : "github",
"userId" : "93376500",
"userName" : "callebtc"
} ],
"realmRoles" : [ "default-roles-nutshell" ],
"notBefore" : 0,
"groups" : [ ]
} ]
}

View File

@@ -183,14 +183,14 @@ async def test_mint(wallet1: Wallet):
assert wallet1.balance == 64
# verify that proofs in proofs_used db have the same mint_id as the invoice in the db
mint_quote = await get_bolt11_mint_quote(db=wallet1.db, quote=mint_quote.quote)
assert mint_quote
mint_quote_2 = await get_bolt11_mint_quote(db=wallet1.db, quote=mint_quote.quote)
assert mint_quote_2
proofs_minted = await get_proofs(
db=wallet1.db, mint_id=mint_quote.quote, table="proofs"
db=wallet1.db, mint_id=mint_quote_2.quote, table="proofs"
)
assert len(proofs_minted) == len(expected_proof_amounts)
assert all([p.amount in expected_proof_amounts for p in proofs_minted])
assert all([p.mint_id == mint_quote.quote for p in proofs_minted])
assert all([p.mint_id == mint_quote_2.quote for p in proofs_minted])
@pytest.mark.asyncio
@@ -356,7 +356,7 @@ async def test_swap_to_send_more_than_balance(wallet1: Wallet):
await wallet1.mint(64, quote_id=mint_quote.quote)
await assert_err(
wallet1.swap_to_send(wallet1.proofs, 128, set_reserved=True),
"balance too low.",
"Balance too low",
)
assert wallet1.balance == 64
assert wallet1.available_balance == 64

251
tests/test_wallet_auth.py Normal file
View File

@@ -0,0 +1,251 @@
import hashlib
import os
import shutil
from pathlib import Path
import pytest
import pytest_asyncio
from cashu.core.base import Unit
from cashu.core.crypto.keys import random_hash
from cashu.core.crypto.secp import PrivateKey
from cashu.core.errors import (
BlindAuthFailedError,
BlindAuthRateLimitExceededError,
ClearAuthFailedError,
)
from cashu.core.settings import settings
from cashu.wallet.auth.auth import WalletAuth
from cashu.wallet.wallet import Wallet
from tests.conftest import SERVER_ENDPOINT
from tests.helpers import assert_err
@pytest_asyncio.fixture(scope="function")
async def wallet():
dirpath = Path("test_data/wallet")
if dirpath.exists() and dirpath.is_dir():
shutil.rmtree(dirpath)
wallet = await Wallet.with_db(
url=SERVER_ENDPOINT,
db="test_data/wallet",
name="wallet",
)
await wallet.load_mint()
yield wallet
@pytest.mark.skipif(
not settings.mint_require_auth,
reason="settings.mint_require_auth is False",
)
@pytest.mark.asyncio
async def test_wallet_auth_password(wallet: Wallet):
auth_wallet = await WalletAuth.with_db(
url=wallet.url,
db=wallet.db.db_location,
username="asd@asd.com",
password="asdasd",
)
requires_auth = await auth_wallet.init_auth_wallet(
wallet.mint_info, mint_auth_proofs=False
)
assert requires_auth
# expect JWT (CAT) with format ey*.ey*
assert auth_wallet.oidc_client.access_token
assert auth_wallet.oidc_client.access_token.split(".")[0].startswith("ey")
assert auth_wallet.oidc_client.access_token.split(".")[1].startswith("ey")
@pytest.mark.skipif(
not settings.mint_require_auth,
reason="settings.mint_require_auth is False",
)
@pytest.mark.asyncio
async def test_wallet_auth_wrong_password(wallet: Wallet):
auth_wallet = await WalletAuth.with_db(
url=wallet.url,
db=wallet.db.db_location,
username="asd@asd.com",
password="wrong_password",
)
await assert_err(auth_wallet.init_auth_wallet(wallet.mint_info), "401 Unauthorized")
@pytest.mark.skipif(
not settings.mint_require_auth,
reason="settings.mint_require_auth is False",
)
@pytest.mark.asyncio
async def test_wallet_auth_mint(wallet: Wallet):
auth_wallet = await WalletAuth.with_db(
url=wallet.url,
db=wallet.db.db_location,
username="asd@asd.com",
password="asdasd",
)
requires_auth = await auth_wallet.init_auth_wallet(wallet.mint_info)
assert requires_auth
await auth_wallet.load_proofs()
assert len(auth_wallet.proofs) == auth_wallet.mint_info.bat_max_mint
@pytest.mark.skipif(
not settings.mint_require_auth,
reason="settings.mint_require_auth is False",
)
@pytest.mark.asyncio
async def test_wallet_auth_mint_manually(wallet: Wallet):
auth_wallet = await WalletAuth.with_db(
url=wallet.url,
db=wallet.db.db_location,
username="asd@asd.com",
password="asdasd",
)
requires_auth = await auth_wallet.init_auth_wallet(
wallet.mint_info, mint_auth_proofs=False
)
assert requires_auth
assert len(auth_wallet.proofs) == 0
await auth_wallet.mint_blind_auth()
assert len(auth_wallet.proofs) == auth_wallet.mint_info.bat_max_mint
@pytest.mark.skipif(
not settings.mint_require_auth,
reason="settings.mint_require_auth is False",
)
@pytest.mark.asyncio
async def test_wallet_auth_mint_manually_invalid_cat(wallet: Wallet):
auth_wallet = await WalletAuth.with_db(
url=wallet.url,
db=wallet.db.db_location,
username="asd@asd.com",
password="asdasd",
)
requires_auth = await auth_wallet.init_auth_wallet(
wallet.mint_info, mint_auth_proofs=False
)
assert requires_auth
assert len(auth_wallet.proofs) == 0
# invalidate CAT in the database
auth_wallet.oidc_client.access_token = random_hash()
# this is the code executed in auth_wallet.mint_blind_auth():
clear_auth_token = auth_wallet.oidc_client.access_token
if not clear_auth_token:
raise Exception("No clear auth token available.")
amounts = auth_wallet.mint_info.bat_max_mint * [1] # 1 AUTH tokens
secrets = [hashlib.sha256(os.urandom(32)).hexdigest() for _ in amounts]
rs = [PrivateKey(privkey=os.urandom(32), raw=True) for _ in amounts]
outputs, rs = auth_wallet._construct_outputs(amounts, secrets, rs)
# should fail because of invalid CAT
await assert_err(
auth_wallet.blind_mint_blind_auth(clear_auth_token, outputs),
ClearAuthFailedError.detail,
)
@pytest.mark.skipif(
not settings.mint_require_auth,
reason="settings.mint_require_auth is False",
)
@pytest.mark.asyncio
async def test_wallet_auth_invoice(wallet: Wallet):
# should fail, wallet error
await assert_err(wallet.mint_quote(10, Unit.sat), "Mint requires blind auth")
auth_wallet = await WalletAuth.with_db(
url=wallet.url,
db=wallet.db.db_location,
username="asd@asd.com",
password="asdasd",
)
requires_auth = await auth_wallet.init_auth_wallet(wallet.mint_info)
assert requires_auth
await auth_wallet.load_proofs()
assert len(auth_wallet.proofs) == auth_wallet.mint_info.bat_max_mint
wallet.auth_db = auth_wallet.db
wallet.auth_keyset_id = auth_wallet.keyset_id
# should succeed
await wallet.mint_quote(10, Unit.sat)
@pytest.mark.skipif(
not settings.mint_require_auth,
reason="settings.mint_require_auth is False",
)
@pytest.mark.asyncio
async def test_wallet_auth_invoice_invalid_bat(wallet: Wallet):
# should fail, wallet error
await assert_err(wallet.mint_quote(10, Unit.sat), "Mint requires blind auth")
auth_wallet = await WalletAuth.with_db(
url=wallet.url,
db=wallet.db.db_location,
username="asd@asd.com",
password="asdasd",
)
requires_auth = await auth_wallet.init_auth_wallet(wallet.mint_info)
assert requires_auth
await auth_wallet.load_proofs()
assert len(auth_wallet.proofs) == auth_wallet.mint_info.bat_max_mint
# invalidate blind auth proofs
for p in auth_wallet.proofs:
await auth_wallet.db.execute(
f"UPDATE proofs SET secret = '{random_hash()}' WHERE secret = '{p.secret}'"
)
wallet.auth_db = auth_wallet.db
wallet.auth_keyset_id = auth_wallet.keyset_id
# blind auth failed
await assert_err(wallet.mint_quote(10, Unit.sat), BlindAuthFailedError.detail)
@pytest.mark.skipif(
not settings.mint_require_auth,
reason="settings.mint_require_auth is False",
)
@pytest.mark.asyncio
async def test_wallet_auth_rate_limit(wallet: Wallet):
auth_wallet = await WalletAuth.with_db(
url=wallet.url,
db=wallet.db.db_location,
username="asd@asd.com",
password="asdasd",
)
requires_auth = await auth_wallet.init_auth_wallet(
wallet.mint_info, mint_auth_proofs=False
)
assert requires_auth
errored = False
for _ in range(100):
try:
await auth_wallet.mint_blind_auth()
except Exception as e:
assert BlindAuthRateLimitExceededError.detail in str(e)
errored = True
break
assert errored
# should have minted at least twice
assert len(auth_wallet.proofs) > auth_wallet.mint_info.bat_max_mint

View File

@@ -54,7 +54,7 @@ async def init_wallet():
wallet = await Wallet.with_db(
url=settings.mint_url,
db="test_data/test_cli_wallet",
name="wallet",
name="test_cli_wallet",
)
await wallet.load_proofs()
return wallet
@@ -411,7 +411,7 @@ def test_wallets(cli_prefix):
print("WALLETS")
# on github this is empty
if len(result.output):
assert "test_cli_wallet" in result.output
assert "wallet" in result.output
assert result.exit_code == 0
@@ -474,7 +474,7 @@ def test_send_too_much(mint, cli_prefix):
cli,
[*cli_prefix, "send", "100000"],
)
assert "balance too low" in str(result.exception)
assert "Balance too low" in str(result.exception)
def test_receive_tokenv3(mint, cli_prefix):