mirror of
https://github.com/aljazceru/nutshell.git
synced 2025-12-20 02:24:20 +01:00
Blind authentication (#675)
* auth server * cleaning up * auth ledger class * class variables -> instance variables * annotations * add models and api route * custom amount and api prefix * add auth db * blind auth token working * jwt working * clean up * JWT works * using openid connect server * use oauth server with password flow * new realm * add keycloak docker * hopefully not garbage * auth works * auth kinda working * fix cli * auth works for send and receive * pass auth_db to Wallet * auth in info * refactor * fix supported * cache mint info * fix settings and endpoints * add description to .env.example * track changes for openid connect client * store mint in db * store credentials * clean up v1_api.py * load mint info into auth wallet * fix first login * authenticate if refresh token fails * clear auth also middleware * use regex * add cli command * pw works * persist keyset amounts * add errors.py * do not start auth server if disabled in config * upadte poetry * disvoery url * fix test * support device code flow * adopt latest spec changes * fix code flow * mint max bat dynamic * mypy ignore * fix test * do not serialize amount in authproof * all auth flows working * fix tests * submodule * refactor * test * dont sleep * test * add wallet auth tests * test differently * test only keycloak for now * fix creds * daemon * fix test * install everything * install jinja * delete wallet for every test * auth: use global rate limiter * test auth rate limit * keycloak hostname * move keycloak test data * reactivate all tests * add readme * load proofs * remove unused code * remove unused code * implement change suggestions by ok300 * add error codes * test errors
This commit is contained in:
18
.env.example
18
.env.example
@@ -133,3 +133,21 @@ LIGHTNING_RESERVE_FEE_MIN=2000
|
||||
# MINT_GLOBAL_RATE_LIMIT_PER_MINUTE=60
|
||||
# Determines the number of transactions (mint, melt, swap) allowed per minute per IP
|
||||
# MINT_TRANSACTION_RATE_LIMIT_PER_MINUTE=20
|
||||
|
||||
# Authentication
|
||||
# These settings allow you to enable blind authentication to limit the user of your mint to a group of authenticated users.
|
||||
# To use this, you need to set up an OpenID Connect provider like Keycloak, Auth0, or Hydra.
|
||||
# - Add the client ID "cashu-client"
|
||||
# - Enable the ES256 and RS256 algorithms for this client
|
||||
# - If you want to use the authorization flow, you must add the redirect URI "http://localhost:33388/callback".
|
||||
# - To support other wallets, use the well-known list of allowed redirect URIs here: https://...TODO.md
|
||||
#
|
||||
# Turn on authentication
|
||||
# MINT_REQUIRE_AUTH=TRUE
|
||||
# OpenID Connect discovery URL of the authentication provider
|
||||
# MINT_AUTH_OICD_DISCOVERY_URL=http://localhost:8080/realms/nutshell/.well-known/openid-configuration
|
||||
# MINT_AUTH_OICD_CLIENT_ID=cashu-client
|
||||
# Number of authentication attempts allowed per minute per user
|
||||
# MINT_AUTH_RATE_LIMIT_PER_MINUTE=5
|
||||
# Maximum number of blind auth tokens per authentication request
|
||||
# MINT_AUTH_MAX_BLIND_TOKENS=100
|
||||
|
||||
15
.github/workflows/ci.yml
vendored
15
.github/workflows/ci.yml
vendored
@@ -42,6 +42,21 @@ jobs:
|
||||
poetry-version: ${{ matrix.poetry-version }}
|
||||
mint-database: ${{ matrix.mint-database }}
|
||||
|
||||
tests_keycloak_auth:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
python-version: ["3.10"]
|
||||
poetry-version: ["1.8.5"]
|
||||
mint-database: ["./test_data/test_mint", "postgres://cashu:cashu@localhost:5432/cashu"]
|
||||
uses: ./.github/workflows/tests_keycloak_auth.yml
|
||||
with:
|
||||
os: ${{ matrix.os }}
|
||||
python-version: ${{ matrix.python-version }}
|
||||
poetry-version: ${{ matrix.poetry-version }}
|
||||
mint-database: ${{ matrix.mint-database }}
|
||||
|
||||
regtest:
|
||||
uses: ./.github/workflows/regtest.yml
|
||||
strategy:
|
||||
|
||||
77
.github/workflows/tests_keycloak_auth.yml
vendored
Normal file
77
.github/workflows/tests_keycloak_auth.yml
vendored
Normal file
@@ -0,0 +1,77 @@
|
||||
name: tests_keycloak
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
python-version:
|
||||
default: "3.10.4"
|
||||
type: string
|
||||
poetry-version:
|
||||
default: "1.8.5"
|
||||
type: string
|
||||
mint-database:
|
||||
default: ""
|
||||
type: string
|
||||
os:
|
||||
default: "ubuntu-latest"
|
||||
type: string
|
||||
|
||||
jobs:
|
||||
poetry:
|
||||
name: Auth tests with Keycloak (db ${{ inputs.mint-database }})
|
||||
runs-on: ${{ inputs.os }}
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Prepare environment
|
||||
uses: ./.github/actions/prepare
|
||||
with:
|
||||
python-version: ${{ inputs.python-version }}
|
||||
poetry-version: ${{ inputs.poetry-version }}
|
||||
|
||||
- name: Start PostgreSQL service
|
||||
if: contains(inputs.mint-database, 'postgres')
|
||||
run: |
|
||||
docker run -d --name postgres \
|
||||
-e POSTGRES_USER=cashu \
|
||||
-e POSTGRES_PASSWORD=cashu \
|
||||
-e POSTGRES_DB=cashu \
|
||||
-p 5432:5432 postgres:16.4
|
||||
until docker exec postgres pg_isready; do sleep 1; done
|
||||
|
||||
- name: Prepare environment
|
||||
uses: ./.github/actions/prepare
|
||||
with:
|
||||
python-version: ${{ inputs.python-version }}
|
||||
poetry-version: ${{ inputs.poetry-version }}
|
||||
|
||||
- name: Start Keycloak with Backup
|
||||
run: |
|
||||
docker compose -f tests/keycloak_data/docker-compose-restore.yml up -d
|
||||
until docker logs $(docker ps -q --filter "ancestor=quay.io/keycloak/keycloak:25.0.6") | grep "Keycloak 25.0.6 on JVM (powered by Quarkus 3.8.5) started"; do sleep 1; done
|
||||
|
||||
- name: Verify Keycloak Import
|
||||
run: |
|
||||
docker logs $(docker ps -q --filter "ancestor=quay.io/keycloak/keycloak:25.0.6") | grep "Imported"
|
||||
|
||||
- name: Run tests
|
||||
env:
|
||||
MINT_BACKEND_BOLT11_SAT: FakeWallet
|
||||
WALLET_NAME: test_wallet
|
||||
MINT_HOST: localhost
|
||||
MINT_PORT: 3337
|
||||
MINT_TEST_DATABASE: ${{ inputs.mint-database }}
|
||||
TOR: false
|
||||
MINT_REQUIRE_AUTH: TRUE
|
||||
MINT_AUTH_OICD_DISCOVERY_URL: http://localhost:8080/realms/nutshell/.well-known/openid-configuration
|
||||
MINT_AUTH_OICD_CLIENT_ID: cashu-client
|
||||
run: |
|
||||
poetry run pytest tests/test_wallet_auth.py -v --cov=mint --cov-report=xml
|
||||
|
||||
- name: Stop and clean up Docker Compose
|
||||
run: |
|
||||
docker compose -f tests/keycloak_data/docker-compose-restore.yml down
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v3
|
||||
@@ -6,7 +6,7 @@ from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from sqlite3 import Row
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, ClassVar, Dict, List, Optional, Union
|
||||
|
||||
import cbor2
|
||||
from loguru import logger
|
||||
@@ -19,7 +19,7 @@ from .crypto.aes import AESCipher
|
||||
from .crypto.b_dhke import hash_to_curve
|
||||
from .crypto.keys import (
|
||||
derive_keys,
|
||||
derive_keys_sha256,
|
||||
derive_keys_deprecated_pre_0_15,
|
||||
derive_keyset_id,
|
||||
derive_keyset_id_deprecated,
|
||||
derive_pubkeys,
|
||||
@@ -173,6 +173,9 @@ class Proof(BaseModel):
|
||||
|
||||
return return_dict
|
||||
|
||||
def to_base64(self):
|
||||
return base64.b64encode(cbor2.dumps(self.to_dict(include_dleq=True))).decode()
|
||||
|
||||
def to_dict_no_dleq(self):
|
||||
# dictionary without the fields that don't need to be send to Carol
|
||||
return dict(id=self.id, amount=self.amount, secret=self.secret, C=self.C)
|
||||
@@ -541,6 +544,7 @@ class Unit(Enum):
|
||||
usd = 2
|
||||
eur = 3
|
||||
btc = 4
|
||||
auth = 999
|
||||
|
||||
def str(self, amount: int) -> str:
|
||||
if self == Unit.sat:
|
||||
@@ -553,6 +557,8 @@ class Unit(Enum):
|
||||
return f"{amount/100:.2f} EUR"
|
||||
elif self == Unit.btc:
|
||||
return f"{amount/1e8:.8f} BTC"
|
||||
elif self == Unit.auth:
|
||||
return f"{amount} AUTH"
|
||||
else:
|
||||
raise Exception("Invalid unit")
|
||||
|
||||
@@ -724,6 +730,7 @@ class MintKeyset:
|
||||
valid_to: Optional[str] = None
|
||||
first_seen: Optional[str] = None
|
||||
version: Optional[str] = None
|
||||
amounts: List[int]
|
||||
|
||||
duplicate_keyset_id: Optional[str] = None # BACKWARDS COMPATIBILITY < 0.15.0
|
||||
|
||||
@@ -734,6 +741,7 @@ class MintKeyset:
|
||||
seed: Optional[str] = None,
|
||||
encrypted_seed: Optional[str] = None,
|
||||
seed_encryption_method: Optional[str] = None,
|
||||
amounts: Optional[List[int]] = None,
|
||||
valid_from: Optional[str] = None,
|
||||
valid_to: Optional[str] = None,
|
||||
first_seen: Optional[str] = None,
|
||||
@@ -762,6 +770,12 @@ class MintKeyset:
|
||||
|
||||
assert self.seed, "seed not set"
|
||||
|
||||
if amounts:
|
||||
self.amounts = amounts
|
||||
else:
|
||||
# use 2^n amounts by default
|
||||
self.amounts = [2**i for i in range(settings.max_order)]
|
||||
|
||||
self.id = id
|
||||
self.valid_from = valid_from
|
||||
self.valid_to = valid_to
|
||||
@@ -805,6 +819,24 @@ class MintKeyset:
|
||||
|
||||
logger.trace(f"Loaded keyset id: {self.id} ({self.unit.name})")
|
||||
|
||||
@classmethod
|
||||
def from_row(cls, row: Row):
|
||||
return cls(
|
||||
id=row["id"],
|
||||
derivation_path=row["derivation_path"],
|
||||
seed=row["seed"],
|
||||
encrypted_seed=row["encrypted_seed"],
|
||||
seed_encryption_method=row["seed_encryption_method"],
|
||||
valid_from=row["valid_from"],
|
||||
valid_to=row["valid_to"],
|
||||
first_seen=row["first_seen"],
|
||||
active=row["active"],
|
||||
unit=row["unit"],
|
||||
version=row["version"],
|
||||
input_fee_ppk=row["input_fee_ppk"],
|
||||
amounts=json.loads(row["amounts"]),
|
||||
)
|
||||
|
||||
@property
|
||||
def public_keys_hex(self) -> Dict[int, str]:
|
||||
assert self.public_keys, "public keys not set"
|
||||
@@ -830,23 +862,27 @@ class MintKeyset:
|
||||
self.private_keys = derive_keys_backwards_compatible_insecure_pre_0_12(
|
||||
self.seed, self.derivation_path
|
||||
)
|
||||
self.public_keys = derive_pubkeys(self.private_keys) # type: ignore
|
||||
self.public_keys = derive_pubkeys(self.private_keys, self.amounts) # type: ignore
|
||||
logger.trace(
|
||||
f"WARNING: Using weak key derivation for keyset {self.id} (backwards"
|
||||
" compatibility < 0.12)"
|
||||
)
|
||||
self.id = id_in_db or derive_keyset_id_deprecated(self.public_keys) # type: ignore
|
||||
elif self.version_tuple < (0, 15):
|
||||
self.private_keys = derive_keys_sha256(self.seed, self.derivation_path)
|
||||
self.private_keys = derive_keys_deprecated_pre_0_15(
|
||||
self.seed, self.amounts, self.derivation_path
|
||||
)
|
||||
logger.trace(
|
||||
f"WARNING: Using non-bip32 derivation for keyset {self.id} (backwards"
|
||||
" compatibility < 0.15)"
|
||||
)
|
||||
self.public_keys = derive_pubkeys(self.private_keys) # type: ignore
|
||||
self.public_keys = derive_pubkeys(self.private_keys, self.amounts) # type: ignore
|
||||
self.id = id_in_db or derive_keyset_id_deprecated(self.public_keys) # type: ignore
|
||||
else:
|
||||
self.private_keys = derive_keys(self.seed, self.derivation_path)
|
||||
self.public_keys = derive_pubkeys(self.private_keys) # type: ignore
|
||||
self.private_keys = derive_keys(
|
||||
self.seed, self.derivation_path, self.amounts
|
||||
)
|
||||
self.public_keys = derive_pubkeys(self.private_keys, self.amounts) # type: ignore
|
||||
self.id = id_in_db or derive_keyset_id(self.public_keys) # type: ignore
|
||||
|
||||
|
||||
@@ -1254,3 +1290,48 @@ class TokenV4(Token):
|
||||
t=[TokenV4Token(**t) for t in token_dict["t"]],
|
||||
d=token_dict.get("d", None),
|
||||
)
|
||||
|
||||
|
||||
class AuthProof(BaseModel):
|
||||
"""
|
||||
Blind authentication token
|
||||
"""
|
||||
|
||||
id: str
|
||||
secret: str # secret
|
||||
C: str # signature
|
||||
amount: int = 1 # default amount
|
||||
|
||||
prefix: ClassVar[str] = "authA"
|
||||
|
||||
@classmethod
|
||||
def from_proof(cls, proof: Proof):
|
||||
return cls(id=proof.id, secret=proof.secret, C=proof.C)
|
||||
|
||||
def to_base64(self):
|
||||
serialize_dict = self.dict()
|
||||
serialize_dict.pop("amount", None)
|
||||
return (
|
||||
self.prefix + base64.b64encode(json.dumps(serialize_dict).encode()).decode()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_base64(cls, base64_str: str):
|
||||
assert base64_str.startswith(cls.prefix), Exception(
|
||||
f"Token prefix not valid. Expected {cls.prefix}."
|
||||
)
|
||||
base64_str = base64_str[len(cls.prefix) :]
|
||||
return cls.parse_obj(json.loads(base64.b64decode(base64_str).decode()))
|
||||
|
||||
def to_proof(self):
|
||||
return Proof(id=self.id, secret=self.secret, C=self.C, amount=self.amount)
|
||||
|
||||
|
||||
class WalletMint(BaseModel):
|
||||
url: str
|
||||
info: str
|
||||
updated: Optional[str] = None
|
||||
access_token: Optional[str] = None
|
||||
refresh_token: Optional[str] = None
|
||||
username: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
|
||||
@@ -1,54 +1,56 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import random
|
||||
from typing import Dict
|
||||
from typing import Dict, List
|
||||
|
||||
from bip32 import BIP32
|
||||
|
||||
from ..settings import settings
|
||||
from .secp import PrivateKey, PublicKey
|
||||
|
||||
|
||||
def derive_keys(mnemonic: str, derivation_path: str):
|
||||
def derive_keys(mnemonic: str, derivation_path: str, amounts: List[int]):
|
||||
"""
|
||||
Deterministic derivation of keys for 2^n values.
|
||||
"""
|
||||
bip32 = BIP32.from_seed(mnemonic.encode())
|
||||
orders_str = [f"/{i}'" for i in range(settings.max_order)]
|
||||
orders_str = [f"/{a}'" for a in range(len(amounts))]
|
||||
return {
|
||||
2**i: PrivateKey(
|
||||
a: PrivateKey(
|
||||
bip32.get_privkey_from_path(derivation_path + orders_str[i]),
|
||||
raw=True,
|
||||
)
|
||||
for i in range(settings.max_order)
|
||||
for i, a in enumerate(amounts)
|
||||
}
|
||||
|
||||
|
||||
def derive_keys_sha256(seed: str, derivation_path: str = ""):
|
||||
def derive_keys_deprecated_pre_0_15(
|
||||
seed: str, amounts: List[int], derivation_path: str = ""
|
||||
):
|
||||
"""
|
||||
Deterministic derivation of keys for 2^n values.
|
||||
TODO: Implement BIP32.
|
||||
"""
|
||||
return {
|
||||
2**i: PrivateKey(
|
||||
a: PrivateKey(
|
||||
hashlib.sha256((seed + derivation_path + str(i)).encode("utf-8")).digest()[
|
||||
:32
|
||||
],
|
||||
raw=True,
|
||||
)
|
||||
for i in range(settings.max_order)
|
||||
for i, a in enumerate(amounts)
|
||||
}
|
||||
|
||||
|
||||
def derive_pubkey(seed: str):
|
||||
return PrivateKey(
|
||||
def derive_pubkey(seed: str) -> PublicKey:
|
||||
pubkey = PrivateKey(
|
||||
hashlib.sha256((seed).encode("utf-8")).digest()[:32],
|
||||
raw=True,
|
||||
).pubkey
|
||||
assert pubkey
|
||||
return pubkey
|
||||
|
||||
|
||||
def derive_pubkeys(keys: Dict[int, PrivateKey]):
|
||||
return {amt: keys[amt].pubkey for amt in [2**i for i in range(settings.max_order)]}
|
||||
def derive_pubkeys(keys: Dict[int, PrivateKey], amounts: List[int]):
|
||||
return {amt: keys[amt].pubkey for amt in amounts}
|
||||
|
||||
|
||||
def derive_keyset_id(keys: Dict[int, PublicKey]):
|
||||
|
||||
@@ -339,11 +339,21 @@ class Database(Compat):
|
||||
raise Exception("Timestamp is None")
|
||||
return timestamp
|
||||
|
||||
def to_timestamp(self, timestamp_str: str) -> Union[str, datetime.datetime]:
|
||||
if not timestamp_str:
|
||||
timestamp_str = self.timestamp_now_str()
|
||||
def to_timestamp(
|
||||
self, timestamp: Union[str, datetime.datetime]
|
||||
) -> Union[str, datetime.datetime]:
|
||||
if not timestamp:
|
||||
timestamp = self.timestamp_now_str()
|
||||
if self.type in {POSTGRES, COCKROACH}:
|
||||
return datetime.datetime.strptime(timestamp_str, "%Y-%m-%d %H:%M:%S")
|
||||
# return datetime.datetime
|
||||
if isinstance(timestamp, datetime.datetime):
|
||||
return timestamp
|
||||
elif isinstance(timestamp, str):
|
||||
return datetime.datetime.strptime(timestamp, "%Y-%m-%d %H:%M:%S")
|
||||
elif self.type == SQLITE:
|
||||
return timestamp_str
|
||||
# return str
|
||||
if isinstance(timestamp, datetime.datetime):
|
||||
return timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||
elif isinstance(timestamp, str):
|
||||
return timestamp
|
||||
return "<nothing>"
|
||||
|
||||
@@ -18,6 +18,7 @@ class NotAllowedError(CashuError):
|
||||
def __init__(self, detail: Optional[str] = None, code: Optional[int] = None):
|
||||
super().__init__(detail or self.detail, code=code or self.code)
|
||||
|
||||
|
||||
class OutputsAlreadySignedError(CashuError):
|
||||
detail = "outputs have already been signed before."
|
||||
code = 10002
|
||||
@@ -25,6 +26,7 @@ class OutputsAlreadySignedError(CashuError):
|
||||
def __init__(self, detail: Optional[str] = None, code: Optional[int] = None):
|
||||
super().__init__(detail or self.detail, code=code or self.code)
|
||||
|
||||
|
||||
class InvalidProofsError(CashuError):
|
||||
detail = "proofs could not be verified"
|
||||
code = 10003
|
||||
@@ -32,6 +34,7 @@ class InvalidProofsError(CashuError):
|
||||
def __init__(self, detail: Optional[str] = None, code: Optional[int] = None):
|
||||
super().__init__(detail or self.detail, code=code or self.code)
|
||||
|
||||
|
||||
class TransactionError(CashuError):
|
||||
detail = "transaction error"
|
||||
code = 11000
|
||||
@@ -76,12 +79,14 @@ class TransactionUnitError(TransactionError):
|
||||
def __init__(self, detail):
|
||||
super().__init__(detail, code=self.code)
|
||||
|
||||
|
||||
class TransactionAmountExceedsLimitError(TransactionError):
|
||||
code = 11006
|
||||
|
||||
def __init__(self, detail):
|
||||
super().__init__(detail, code=self.code)
|
||||
|
||||
|
||||
class KeysetError(CashuError):
|
||||
detail = "keyset error"
|
||||
code = 12000
|
||||
@@ -113,7 +118,7 @@ class QuoteNotPaidError(CashuError):
|
||||
code = 20001
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(self.detail, code=2001)
|
||||
super().__init__(self.detail, code=self.code)
|
||||
|
||||
|
||||
class QuoteSignatureInvalidError(CashuError):
|
||||
@@ -121,7 +126,7 @@ class QuoteSignatureInvalidError(CashuError):
|
||||
code = 20008
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(self.detail, code=20008)
|
||||
super().__init__(self.detail, code=self.code)
|
||||
|
||||
|
||||
class QuoteRequiresPubkeyError(CashuError):
|
||||
@@ -129,4 +134,52 @@ class QuoteRequiresPubkeyError(CashuError):
|
||||
code = 20009
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(self.detail, code=20009)
|
||||
super().__init__(self.detail, code=self.code)
|
||||
|
||||
|
||||
class ClearAuthRequiredError(CashuError):
|
||||
detail = "Endpoint requires clear auth"
|
||||
code = 80001
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(self.detail, code=self.code)
|
||||
|
||||
|
||||
class ClearAuthFailedError(CashuError):
|
||||
detail = "Clear authentication failed"
|
||||
code = 80002
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(self.detail, code=self.code)
|
||||
|
||||
|
||||
class BlindAuthRequiredError(CashuError):
|
||||
detail = "Endpoint requires blind auth"
|
||||
code = 81001
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(self.detail, code=self.code)
|
||||
|
||||
|
||||
class BlindAuthFailedError(CashuError):
|
||||
detail = "Blind authentication failed"
|
||||
code = 81002
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(self.detail, code=self.code)
|
||||
|
||||
|
||||
class BlindAuthAmountExceededError(CashuError):
|
||||
detail = "Maximum blind auth amount exceeded"
|
||||
code = 81003
|
||||
|
||||
def __init__(self, detail: Optional[str] = None):
|
||||
super().__init__(detail or self.detail, code=self.code)
|
||||
|
||||
|
||||
class BlindAuthRateLimitExceededError(CashuError):
|
||||
detail = "Blind auth token mint rate limit exceeded"
|
||||
code = 81004
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(self.detail, code=self.code)
|
||||
|
||||
125
cashu/core/mint_info.py
Normal file
125
cashu/core/mint_info.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .base import Method, Unit
|
||||
from .models import MintInfoContact, MintInfoProtectedEndpoint, Nut15MppSupport
|
||||
from .nuts.nuts import BLIND_AUTH_NUT, CLEAR_AUTH_NUT, MPP_NUT, WEBSOCKETS_NUT
|
||||
|
||||
|
||||
class MintInfo(BaseModel):
|
||||
name: Optional[str]
|
||||
pubkey: Optional[str]
|
||||
version: Optional[str]
|
||||
description: Optional[str]
|
||||
description_long: Optional[str]
|
||||
contact: Optional[List[MintInfoContact]]
|
||||
motd: Optional[str]
|
||||
icon_url: Optional[str]
|
||||
time: Optional[int]
|
||||
nuts: Dict[int, Any]
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name} ({self.description})"
|
||||
|
||||
@classmethod
|
||||
def from_json_str(cls, json_str: str):
|
||||
return cls.parse_obj(json.loads(json_str))
|
||||
|
||||
def supports_nut(self, nut: int) -> bool:
|
||||
if self.nuts is None:
|
||||
return False
|
||||
return nut in self.nuts
|
||||
|
||||
def supports_mpp(self, method: str, unit: Unit) -> bool:
|
||||
if not self.nuts:
|
||||
return False
|
||||
nut_15 = self.nuts.get(MPP_NUT)
|
||||
if not nut_15 or not self.supports_nut(MPP_NUT) or not nut_15.get("methods"):
|
||||
return False
|
||||
|
||||
for entry in nut_15["methods"]:
|
||||
entry_obj = Nut15MppSupport.parse_obj(entry)
|
||||
if entry_obj.method == method and entry_obj.unit == unit.name:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def supports_websocket_mint_quote(self, method: Method, unit: Unit) -> bool:
|
||||
if not self.nuts or not self.supports_nut(WEBSOCKETS_NUT):
|
||||
return False
|
||||
websocket_settings = self.nuts[WEBSOCKETS_NUT]
|
||||
if not websocket_settings or "supported" not in websocket_settings:
|
||||
return False
|
||||
websocket_supported = websocket_settings["supported"]
|
||||
for entry in websocket_supported:
|
||||
if entry["method"] == method.name and entry["unit"] == unit.name:
|
||||
if "bolt11_mint_quote" in entry["commands"]:
|
||||
return True
|
||||
return False
|
||||
|
||||
def requires_clear_auth(self) -> bool:
|
||||
return self.supports_nut(CLEAR_AUTH_NUT)
|
||||
|
||||
def oidc_discovery_url(self) -> str:
|
||||
if not self.requires_clear_auth():
|
||||
raise Exception(
|
||||
"Could not get OIDC discovery URL. Mint info does not support clear auth."
|
||||
)
|
||||
return self.nuts[CLEAR_AUTH_NUT]["openid_discovery"]
|
||||
|
||||
def oidc_client_id(self) -> str:
|
||||
if not self.requires_clear_auth():
|
||||
raise Exception(
|
||||
"Could not get client_id. Mint info does not support clear auth."
|
||||
)
|
||||
return self.nuts[CLEAR_AUTH_NUT]["client_id"]
|
||||
|
||||
def required_clear_auth_endpoints(self) -> List[MintInfoProtectedEndpoint]:
|
||||
if not self.requires_clear_auth():
|
||||
return []
|
||||
return [
|
||||
MintInfoProtectedEndpoint.parse_obj(e)
|
||||
for e in self.nuts[CLEAR_AUTH_NUT]["protected_endpoints"]
|
||||
]
|
||||
|
||||
def requires_clear_auth_path(self, method: str, path: str) -> bool:
|
||||
if not self.requires_clear_auth():
|
||||
return False
|
||||
path = "/" + path if not path.startswith("/") else path
|
||||
for endpoint in self.required_clear_auth_endpoints():
|
||||
if method == endpoint.method and re.match(endpoint.path, path):
|
||||
return True
|
||||
return False
|
||||
|
||||
def requires_blind_auth(self) -> bool:
|
||||
return self.supports_nut(BLIND_AUTH_NUT)
|
||||
|
||||
@property
|
||||
def bat_max_mint(self) -> int:
|
||||
if not self.requires_blind_auth():
|
||||
raise Exception(
|
||||
"Could not get max mint. Mint info does not support blind auth."
|
||||
)
|
||||
if not self.nuts[BLIND_AUTH_NUT].get("bat_max_mint"):
|
||||
raise Exception("Could not get max mint. bat_max_mint not set.")
|
||||
return self.nuts[BLIND_AUTH_NUT]["bat_max_mint"]
|
||||
|
||||
def required_blind_auth_paths(self) -> List[MintInfoProtectedEndpoint]:
|
||||
if not self.requires_blind_auth():
|
||||
return []
|
||||
return [
|
||||
MintInfoProtectedEndpoint.parse_obj(e)
|
||||
for e in self.nuts[BLIND_AUTH_NUT]["protected_endpoints"]
|
||||
]
|
||||
|
||||
def requires_blind_auth_path(self, method: str, path: str) -> bool:
|
||||
if not self.requires_blind_auth():
|
||||
return False
|
||||
path = "/" + path if not path.startswith("/") else path
|
||||
for endpoint in self.required_blind_auth_paths():
|
||||
if method == endpoint.method and re.match(endpoint.path, path):
|
||||
return True
|
||||
return False
|
||||
@@ -38,6 +38,11 @@ class MintInfoContact(BaseModel):
|
||||
info: str
|
||||
|
||||
|
||||
class MintInfoProtectedEndpoint(BaseModel):
|
||||
method: str
|
||||
path: str
|
||||
|
||||
|
||||
class GetInfoResponse(BaseModel):
|
||||
name: Optional[str] = None
|
||||
pubkey: Optional[str] = None
|
||||
@@ -57,7 +62,7 @@ class GetInfoResponse(BaseModel):
|
||||
# BEGIN DEPRECATED: NUT-06 contact field change
|
||||
# NUT-06 PR: https://github.com/cashubtc/nuts/pull/117
|
||||
@root_validator(pre=True)
|
||||
def preprocess_deprecated_contact_field(cls, values):
|
||||
def preprocess_deprecated_contact_field(cls, values: dict):
|
||||
if "contact" in values and values["contact"]:
|
||||
if isinstance(values["contact"][0], list):
|
||||
values["contact"] = [
|
||||
@@ -346,3 +351,16 @@ class PostRestoreResponse(BaseModel):
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
self.promises = self.signatures
|
||||
|
||||
|
||||
# ------- API: BLIND AUTH -------
|
||||
class PostAuthBlindMintRequest(BaseModel):
|
||||
outputs: List[BlindedMessage] = Field(
|
||||
...,
|
||||
max_items=settings.mint_max_request_length,
|
||||
description="Blinded messages for creating blind auth tokens.",
|
||||
)
|
||||
|
||||
|
||||
class PostAuthBlindMintResponse(BaseModel):
|
||||
signatures: List[BlindedSignature] = []
|
||||
|
||||
@@ -14,3 +14,5 @@ MPP_NUT = 15
|
||||
WEBSOCKETS_NUT = 17
|
||||
CACHE_NUT = 19
|
||||
MINT_QUOTE_SIGNATURE_NUT = 20
|
||||
CLEAR_AUTH_NUT = 21
|
||||
BLIND_AUTH_NUT = 22
|
||||
|
||||
@@ -68,6 +68,8 @@ class MintSettings(CashuSettings):
|
||||
class MintDeprecationFlags(MintSettings):
|
||||
mint_inactivate_base64_keysets: bool = Field(default=False)
|
||||
|
||||
auth_database: str = Field(default="data/mint")
|
||||
|
||||
|
||||
class MintBackends(MintSettings):
|
||||
mint_lightning_backend: str = Field(default="") # deprecated
|
||||
@@ -231,6 +233,27 @@ class CoreLightningRestFundingSource(MintSettings):
|
||||
mint_corelightning_rest_cert: Optional[str] = Field(default=None)
|
||||
|
||||
|
||||
class AuthSettings(MintSettings):
|
||||
mint_require_auth: bool = Field(default=False)
|
||||
mint_auth_oicd_discovery_url: Optional[str] = Field(default=None)
|
||||
mint_auth_oicd_client_id: str = Field(default="cashu-client")
|
||||
mint_auth_rate_limit_per_minute: int = Field(
|
||||
default=5,
|
||||
title="Auth rate limit per minute",
|
||||
description="Number of requests a user can authenticate per minute.",
|
||||
)
|
||||
mint_auth_max_blind_tokens: int = Field(default=100, gt=0)
|
||||
mint_require_clear_auth_paths: List[List[str]] = [
|
||||
["POST", "/v1/auth/blind/mint"],
|
||||
]
|
||||
mint_require_blind_auth_paths: List[List[str]] = [
|
||||
["POST", "/v1/swap"],
|
||||
["POST", "/v1/mint/quote/bolt11"],
|
||||
["POST", "/v1/mint/bolt11"],
|
||||
["POST", "/v1/melt/bolt11"],
|
||||
]
|
||||
|
||||
|
||||
class MintRedisCache(MintSettings):
|
||||
mint_redis_cache_enabled: bool = Field(default=False)
|
||||
mint_redis_cache_url: Optional[str] = Field(default=None)
|
||||
@@ -246,6 +269,7 @@ class Settings(
|
||||
FakeWalletSettings,
|
||||
MintLimits,
|
||||
MintBackends,
|
||||
AuthSettings,
|
||||
MintRedisCache,
|
||||
MintDeprecationFlags,
|
||||
MintSettings,
|
||||
|
||||
@@ -13,10 +13,10 @@ from starlette.requests import Request
|
||||
from ..core.errors import CashuError
|
||||
from ..core.logging import configure_logger
|
||||
from ..core.settings import settings
|
||||
from .auth.router import auth_router
|
||||
from .router import redis, router
|
||||
from .router_deprecated import router_deprecated
|
||||
from .startup import shutdown_mint as shutdown_mint_init
|
||||
from .startup import start_mint_init
|
||||
from .startup import shutdown_mint, start_auth, start_mint
|
||||
|
||||
if settings.debug_profiling:
|
||||
pass
|
||||
@@ -29,7 +29,9 @@ from .middleware import add_middlewares, request_validation_exception_handler
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(_: FastAPI) -> AsyncIterator[None]:
|
||||
await start_mint_init()
|
||||
await start_mint()
|
||||
if settings.mint_require_auth:
|
||||
await start_auth()
|
||||
try:
|
||||
yield
|
||||
except asyncio.CancelledError:
|
||||
@@ -38,7 +40,7 @@ async def lifespan(_: FastAPI) -> AsyncIterator[None]:
|
||||
finally:
|
||||
try:
|
||||
await redis.disconnect()
|
||||
await shutdown_mint_init()
|
||||
await shutdown_mint()
|
||||
except asyncio.CancelledError:
|
||||
logger.info("CancelledError during shutdown, shutting down forcefully")
|
||||
|
||||
@@ -110,3 +112,6 @@ if settings.debug_mint_only_deprecated:
|
||||
else:
|
||||
app.include_router(router=router, tags=["Mint"])
|
||||
app.include_router(router=router_deprecated, tags=["Deprecated"], deprecated=True)
|
||||
|
||||
if settings.mint_require_auth:
|
||||
app.include_router(auth_router, tags=["Auth"])
|
||||
|
||||
13
cashu/mint/auth/base.py
Normal file
13
cashu/mint/auth/base.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import datetime
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class User:
|
||||
id: str
|
||||
last_access: Optional[datetime.datetime]
|
||||
|
||||
def __init__(self, id: str, last_access: Optional[datetime.datetime] = None):
|
||||
self.id = id
|
||||
if isinstance(last_access, int):
|
||||
last_access = datetime.datetime.fromtimestamp(last_access)
|
||||
self.last_access = last_access
|
||||
669
cashu/mint/auth/crud.py
Normal file
669
cashu/mint/auth/crud.py
Normal file
@@ -0,0 +1,669 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from ...core.base import (
|
||||
BlindedSignature,
|
||||
MeltQuote,
|
||||
MintKeyset,
|
||||
MintQuote,
|
||||
Proof,
|
||||
)
|
||||
from ...core.db import (
|
||||
Connection,
|
||||
Database,
|
||||
)
|
||||
from .base import User
|
||||
|
||||
|
||||
class AuthLedgerCrud(ABC):
|
||||
"""
|
||||
Database interface for Nutshell auth ledger.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def create_user(
|
||||
self,
|
||||
*,
|
||||
db: Database,
|
||||
user: User,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
async def get_user(
|
||||
self,
|
||||
*,
|
||||
db: Database,
|
||||
user_id: str,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> Optional[User]: ...
|
||||
|
||||
async def update_user(
|
||||
self,
|
||||
*,
|
||||
db: Database,
|
||||
user_id: str,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
async def get_keyset(
|
||||
self,
|
||||
*,
|
||||
db: Database,
|
||||
id: str = "",
|
||||
derivation_path: str = "",
|
||||
seed: str = "",
|
||||
conn: Optional[Connection] = None,
|
||||
) -> List[MintKeyset]: ...
|
||||
|
||||
@abstractmethod
|
||||
async def get_proofs_used(
|
||||
self,
|
||||
*,
|
||||
Ys: List[str],
|
||||
db: Database,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> List[Proof]: ...
|
||||
|
||||
@abstractmethod
|
||||
async def invalidate_proof(
|
||||
self,
|
||||
*,
|
||||
db: Database,
|
||||
proof: Proof,
|
||||
quote_id: Optional[str] = None,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
async def get_proofs_pending(
|
||||
self,
|
||||
*,
|
||||
Ys: List[str],
|
||||
db: Database,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> List[Proof]: ...
|
||||
|
||||
@abstractmethod
|
||||
async def set_proof_pending(
|
||||
self,
|
||||
*,
|
||||
db: Database,
|
||||
proof: Proof,
|
||||
quote_id: Optional[str] = None,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
async def unset_proof_pending(
|
||||
self,
|
||||
*,
|
||||
proof: Proof,
|
||||
db: Database,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
async def store_keyset(
|
||||
self,
|
||||
*,
|
||||
db: Database,
|
||||
keyset: MintKeyset,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
async def store_promise(
|
||||
self,
|
||||
*,
|
||||
db: Database,
|
||||
amount: int,
|
||||
b_: str,
|
||||
c_: str,
|
||||
id: str,
|
||||
e: str = "",
|
||||
s: str = "",
|
||||
conn: Optional[Connection] = None,
|
||||
) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
async def get_promise(
|
||||
self,
|
||||
*,
|
||||
db: Database,
|
||||
b_: str,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> Optional[BlindedSignature]: ...
|
||||
|
||||
@abstractmethod
|
||||
async def get_promises(
|
||||
self,
|
||||
*,
|
||||
db: Database,
|
||||
b_s: List[str],
|
||||
conn: Optional[Connection] = None,
|
||||
) -> List[BlindedSignature]: ...
|
||||
|
||||
|
||||
class AuthLedgerCrudSqlite(AuthLedgerCrud):
|
||||
"""Implementation of AuthLedgerCrud for sqlite.
|
||||
|
||||
Args:
|
||||
AuthLedgerCrud (ABC): Abstract base class for AuthLedgerCrud.
|
||||
"""
|
||||
|
||||
async def create_user(
|
||||
self,
|
||||
*,
|
||||
db: Database,
|
||||
user: User,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> None:
|
||||
await (conn or db).execute(
|
||||
f"""
|
||||
INSERT INTO {db.table_with_schema('users')}
|
||||
(id)
|
||||
VALUES (:id)
|
||||
""",
|
||||
{"id": user.id},
|
||||
)
|
||||
|
||||
async def get_user(
|
||||
self,
|
||||
*,
|
||||
db: Database,
|
||||
user_id: str,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> Optional[User]:
|
||||
row = await (conn or db).fetchone(
|
||||
f"""
|
||||
SELECT * from {db.table_with_schema('users')}
|
||||
WHERE id = :user_id
|
||||
""",
|
||||
{"user_id": user_id},
|
||||
)
|
||||
return User(**row) if row else None
|
||||
|
||||
async def update_user(
|
||||
self,
|
||||
*,
|
||||
db: Database,
|
||||
user_id: str,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> None:
|
||||
await (conn or db).execute(
|
||||
f"""
|
||||
UPDATE {db.table_with_schema('users')}
|
||||
SET last_access = :last_access
|
||||
WHERE id = :user_id
|
||||
""",
|
||||
{
|
||||
"last_access": db.to_timestamp(db.timestamp_now_str()),
|
||||
"user_id": user_id,
|
||||
},
|
||||
)
|
||||
|
||||
async def store_promise(
|
||||
self,
|
||||
*,
|
||||
db: Database,
|
||||
amount: int,
|
||||
b_: str,
|
||||
c_: str,
|
||||
id: str,
|
||||
e: str = "",
|
||||
s: str = "",
|
||||
conn: Optional[Connection] = None,
|
||||
) -> None:
|
||||
await (conn or db).execute(
|
||||
f"""
|
||||
INSERT INTO {db.table_with_schema('promises')}
|
||||
(amount, b_, c_, dleq_e, dleq_s, id, created)
|
||||
VALUES (:amount, :b_, :c_, :dleq_e, :dleq_s, :id, :created)
|
||||
""",
|
||||
{
|
||||
"amount": amount,
|
||||
"b_": b_,
|
||||
"c_": c_,
|
||||
"dleq_e": e,
|
||||
"dleq_s": s,
|
||||
"id": id,
|
||||
"created": db.to_timestamp(db.timestamp_now_str()),
|
||||
},
|
||||
)
|
||||
|
||||
async def get_promise(
|
||||
self,
|
||||
*,
|
||||
db: Database,
|
||||
b_: str,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> Optional[BlindedSignature]:
|
||||
row = await (conn or db).fetchone(
|
||||
f"""
|
||||
SELECT * from {db.table_with_schema('promises')}
|
||||
WHERE b_ = :b_
|
||||
""",
|
||||
{"b_": str(b_)},
|
||||
)
|
||||
return BlindedSignature.from_row(row) if row else None
|
||||
|
||||
async def get_promises(
|
||||
self,
|
||||
*,
|
||||
db: Database,
|
||||
b_s: List[str],
|
||||
conn: Optional[Connection] = None,
|
||||
) -> List[BlindedSignature]:
|
||||
rows = await (conn or db).fetchall(
|
||||
f"""
|
||||
SELECT * from {db.table_with_schema('promises')}
|
||||
WHERE b_ IN ({','.join([':b_' + str(i) for i in range(len(b_s))])})
|
||||
""",
|
||||
{f"b_{i}": b_s[i] for i in range(len(b_s))},
|
||||
)
|
||||
return [BlindedSignature.from_row(r) for r in rows] if rows else []
|
||||
|
||||
async def invalidate_proof(
|
||||
self,
|
||||
*,
|
||||
db: Database,
|
||||
proof: Proof,
|
||||
quote_id: Optional[str] = None,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> None:
|
||||
await (conn or db).execute(
|
||||
f"""
|
||||
INSERT INTO {db.table_with_schema('proofs_used')}
|
||||
(amount, c, secret, y, id, witness, created, melt_quote)
|
||||
VALUES (:amount, :c, :secret, :y, :id, :witness, :created, :melt_quote)
|
||||
""",
|
||||
{
|
||||
"amount": proof.amount,
|
||||
"c": proof.C,
|
||||
"secret": proof.secret,
|
||||
"y": proof.Y,
|
||||
"id": proof.id,
|
||||
"witness": proof.witness,
|
||||
"created": db.to_timestamp(db.timestamp_now_str()),
|
||||
"melt_quote": quote_id,
|
||||
},
|
||||
)
|
||||
|
||||
async def get_all_melt_quotes_from_pending_proofs(
|
||||
self,
|
||||
*,
|
||||
db: Database,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> List[MeltQuote]:
|
||||
rows = await (conn or db).fetchall(
|
||||
f"""
|
||||
SELECT * from {db.table_with_schema('melt_quotes')} WHERE quote in (SELECT DISTINCT melt_quote FROM {db.table_with_schema('proofs_pending')})
|
||||
"""
|
||||
)
|
||||
return [MeltQuote.from_row(r) for r in rows]
|
||||
|
||||
async def get_pending_proofs_for_quote(
|
||||
self,
|
||||
*,
|
||||
quote_id: str,
|
||||
db: Database,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> List[Proof]:
|
||||
rows = await (conn or db).fetchall(
|
||||
f"""
|
||||
SELECT * from {db.table_with_schema('proofs_pending')}
|
||||
WHERE melt_quote = :quote_id
|
||||
""",
|
||||
{"quote_id": quote_id},
|
||||
)
|
||||
return [Proof(**r) for r in rows]
|
||||
|
||||
async def get_proofs_pending(
|
||||
self,
|
||||
*,
|
||||
Ys: List[str],
|
||||
db: Database,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> List[Proof]:
|
||||
query = f"""
|
||||
SELECT * from {db.table_with_schema('proofs_pending')}
|
||||
WHERE y IN ({','.join([':y_' + str(i) for i in range(len(Ys))])})
|
||||
"""
|
||||
values = {f"y_{i}": Ys[i] for i in range(len(Ys))}
|
||||
rows = await (conn or db).fetchall(query, values)
|
||||
return [Proof(**r) for r in rows]
|
||||
|
||||
async def set_proof_pending(
|
||||
self,
|
||||
*,
|
||||
db: Database,
|
||||
proof: Proof,
|
||||
quote_id: Optional[str] = None,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> None:
|
||||
await (conn or db).execute(
|
||||
f"""
|
||||
INSERT INTO {db.table_with_schema('proofs_pending')}
|
||||
(amount, c, secret, y, id, witness, created, melt_quote)
|
||||
VALUES (:amount, :c, :secret, :y, :id, :witness, :created, :melt_quote)
|
||||
""",
|
||||
{
|
||||
"amount": proof.amount,
|
||||
"c": proof.C,
|
||||
"secret": proof.secret,
|
||||
"y": proof.Y,
|
||||
"id": proof.id,
|
||||
"witness": proof.witness,
|
||||
"created": db.to_timestamp(db.timestamp_now_str()),
|
||||
"melt_quote": quote_id,
|
||||
},
|
||||
)
|
||||
|
||||
async def unset_proof_pending(
|
||||
self,
|
||||
*,
|
||||
proof: Proof,
|
||||
db: Database,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> None:
|
||||
await (conn or db).execute(
|
||||
f"""
|
||||
DELETE FROM {db.table_with_schema('proofs_pending')}
|
||||
WHERE secret = :secret
|
||||
""",
|
||||
{"secret": proof.secret},
|
||||
)
|
||||
|
||||
async def store_mint_quote(
|
||||
self,
|
||||
*,
|
||||
quote: MintQuote,
|
||||
db: Database,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> None:
|
||||
await (conn or db).execute(
|
||||
f"""
|
||||
INSERT INTO {db.table_with_schema('mint_quotes')}
|
||||
(quote, method, request, checking_id, unit, amount, issued, paid, state, created_time, paid_time)
|
||||
VALUES (:quote, :method, :request, :checking_id, :unit, :amount, :issued, :paid, :state, :created_time, :paid_time)
|
||||
""",
|
||||
{
|
||||
"quote": quote.quote,
|
||||
"method": quote.method,
|
||||
"request": quote.request,
|
||||
"checking_id": quote.checking_id,
|
||||
"unit": quote.unit,
|
||||
"amount": quote.amount,
|
||||
"issued": quote.issued,
|
||||
"paid": quote.paid,
|
||||
"state": quote.state.name,
|
||||
"created_time": db.to_timestamp(
|
||||
db.timestamp_from_seconds(quote.created_time) or ""
|
||||
),
|
||||
"paid_time": db.to_timestamp(
|
||||
db.timestamp_from_seconds(quote.paid_time) or ""
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
async def get_mint_quote(
|
||||
self,
|
||||
*,
|
||||
quote_id: Optional[str] = None,
|
||||
checking_id: Optional[str] = None,
|
||||
request: Optional[str] = None,
|
||||
db: Database,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> Optional[MintQuote]:
|
||||
clauses = []
|
||||
values: Dict[str, Any] = {}
|
||||
if quote_id:
|
||||
clauses.append("quote = :quote_id")
|
||||
values["quote_id"] = quote_id
|
||||
if checking_id:
|
||||
clauses.append("checking_id = :checking_id")
|
||||
values["checking_id"] = checking_id
|
||||
if request:
|
||||
clauses.append("request = :request")
|
||||
values["request"] = request
|
||||
if not any(clauses):
|
||||
raise ValueError("No search criteria")
|
||||
where = f"WHERE {' AND '.join(clauses)}"
|
||||
row = await (conn or db).fetchone(
|
||||
f"""
|
||||
SELECT * from {db.table_with_schema('mint_quotes')}
|
||||
{where}
|
||||
""",
|
||||
values,
|
||||
)
|
||||
if row is None:
|
||||
return None
|
||||
return MintQuote.from_row(row) if row else None
|
||||
|
||||
async def get_mint_quote_by_request(
|
||||
self,
|
||||
*,
|
||||
request: str,
|
||||
db: Database,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> Optional[MintQuote]:
|
||||
row = await (conn or db).fetchone(
|
||||
f"""
|
||||
SELECT * from {db.table_with_schema('mint_quotes')}
|
||||
WHERE request = :request
|
||||
""",
|
||||
{"request": request},
|
||||
)
|
||||
return MintQuote.from_row(row) if row else None
|
||||
|
||||
async def update_mint_quote(
|
||||
self,
|
||||
*,
|
||||
quote: MintQuote,
|
||||
db: Database,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> None:
|
||||
await (conn or db).execute(
|
||||
f"UPDATE {db.table_with_schema('mint_quotes')} SET issued = :issued, paid = :paid, state = :state, paid_time = :paid_time WHERE quote = :quote",
|
||||
{
|
||||
"issued": quote.issued,
|
||||
"paid": quote.paid,
|
||||
"state": quote.state.name,
|
||||
"paid_time": db.to_timestamp(
|
||||
db.timestamp_from_seconds(quote.paid_time) or ""
|
||||
),
|
||||
"quote": quote.quote,
|
||||
},
|
||||
)
|
||||
|
||||
async def store_melt_quote(
|
||||
self,
|
||||
*,
|
||||
quote: MeltQuote,
|
||||
db: Database,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> None:
|
||||
await (conn or db).execute(
|
||||
f"""
|
||||
INSERT INTO {db.table_with_schema('melt_quotes')}
|
||||
(quote, method, request, checking_id, unit, amount, fee_reserve, paid, state, created_time, paid_time, fee_paid, proof, change, expiry)
|
||||
VALUES (:quote, :method, :request, :checking_id, :unit, :amount, :fee_reserve, :paid, :state, :created_time, :paid_time, :fee_paid, :proof, :change, :expiry)
|
||||
""",
|
||||
{
|
||||
"quote": quote.quote,
|
||||
"method": quote.method,
|
||||
"request": quote.request,
|
||||
"checking_id": quote.checking_id,
|
||||
"unit": quote.unit,
|
||||
"amount": quote.amount,
|
||||
"fee_reserve": quote.fee_reserve or 0,
|
||||
"paid": quote.paid,
|
||||
"state": quote.state.name,
|
||||
"created_time": db.to_timestamp(
|
||||
db.timestamp_from_seconds(quote.created_time) or ""
|
||||
),
|
||||
"paid_time": db.to_timestamp(
|
||||
db.timestamp_from_seconds(quote.paid_time) or ""
|
||||
),
|
||||
"fee_paid": quote.fee_paid,
|
||||
"proof": quote.payment_preimage,
|
||||
"change": json.dumps(quote.change) if quote.change else None,
|
||||
"expiry": db.to_timestamp(
|
||||
db.timestamp_from_seconds(quote.expiry) or ""
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
async def get_melt_quote(
|
||||
self,
|
||||
*,
|
||||
quote_id: Optional[str] = None,
|
||||
checking_id: Optional[str] = None,
|
||||
request: Optional[str] = None,
|
||||
db: Database,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> Optional[MeltQuote]:
|
||||
clauses = []
|
||||
values: Dict[str, Any] = {}
|
||||
if quote_id:
|
||||
clauses.append("quote = :quote_id")
|
||||
values["quote_id"] = quote_id
|
||||
if checking_id:
|
||||
clauses.append("checking_id = :checking_id")
|
||||
values["checking_id"] = checking_id
|
||||
if request:
|
||||
clauses.append("request = :request")
|
||||
values["request"] = request
|
||||
if not any(clauses):
|
||||
raise ValueError("No search criteria")
|
||||
where = f"WHERE {' AND '.join(clauses)}"
|
||||
row = await (conn or db).fetchone(
|
||||
f"""
|
||||
SELECT * from {db.table_with_schema('melt_quotes')}
|
||||
{where}
|
||||
""",
|
||||
values,
|
||||
)
|
||||
if row is None:
|
||||
return None
|
||||
return MeltQuote.from_row(row) if row else None
|
||||
|
||||
async def update_melt_quote(
|
||||
self,
|
||||
*,
|
||||
quote: MeltQuote,
|
||||
db: Database,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> None:
|
||||
await (conn or db).execute(
|
||||
f"""
|
||||
UPDATE {db.table_with_schema('melt_quotes')} SET paid = :paid, state = :state, fee_paid = :fee_paid, paid_time = :paid_time, proof = :proof, change = :change WHERE quote = :quote
|
||||
""",
|
||||
{
|
||||
"paid": quote.paid,
|
||||
"state": quote.state.name,
|
||||
"fee_paid": quote.fee_paid,
|
||||
"paid_time": db.to_timestamp(
|
||||
db.timestamp_from_seconds(quote.paid_time) or ""
|
||||
),
|
||||
"proof": quote.payment_preimage,
|
||||
"change": (
|
||||
json.dumps([s.dict() for s in quote.change])
|
||||
if quote.change
|
||||
else None
|
||||
),
|
||||
"quote": quote.quote,
|
||||
},
|
||||
)
|
||||
|
||||
async def store_keyset(
|
||||
self,
|
||||
*,
|
||||
db: Database,
|
||||
keyset: MintKeyset,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> None:
|
||||
await (conn or db).execute(
|
||||
f"""
|
||||
INSERT INTO {db.table_with_schema('keysets')}
|
||||
(id, seed, encrypted_seed, seed_encryption_method, derivation_path, valid_from, valid_to, first_seen, active, version, unit, input_fee_ppk)
|
||||
VALUES (:id, :seed, :encrypted_seed, :seed_encryption_method, :derivation_path, :valid_from, :valid_to, :first_seen, :active, :version, :unit, :input_fee_ppk)
|
||||
""",
|
||||
{
|
||||
"id": keyset.id,
|
||||
"seed": keyset.seed,
|
||||
"encrypted_seed": keyset.encrypted_seed,
|
||||
"seed_encryption_method": keyset.seed_encryption_method,
|
||||
"derivation_path": keyset.derivation_path,
|
||||
"valid_from": db.to_timestamp(
|
||||
keyset.valid_from or db.timestamp_now_str()
|
||||
),
|
||||
"valid_to": db.to_timestamp(keyset.valid_to or db.timestamp_now_str()),
|
||||
"first_seen": db.to_timestamp(
|
||||
keyset.first_seen or db.timestamp_now_str()
|
||||
),
|
||||
"active": True,
|
||||
"version": keyset.version,
|
||||
"unit": keyset.unit.name,
|
||||
"input_fee_ppk": keyset.input_fee_ppk,
|
||||
},
|
||||
)
|
||||
|
||||
async def get_keyset(
|
||||
self,
|
||||
*,
|
||||
db: Database,
|
||||
id: Optional[str] = None,
|
||||
derivation_path: Optional[str] = None,
|
||||
seed: Optional[str] = None,
|
||||
unit: Optional[str] = None,
|
||||
active: Optional[bool] = None,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> List[MintKeyset]:
|
||||
clauses = []
|
||||
values: Dict = {}
|
||||
if active is not None:
|
||||
clauses.append("active = :active")
|
||||
values["active"] = active
|
||||
if id is not None:
|
||||
clauses.append("id = :id")
|
||||
values["id"] = id
|
||||
if derivation_path is not None:
|
||||
clauses.append("derivation_path = :derivation_path")
|
||||
values["derivation_path"] = derivation_path
|
||||
if seed is not None:
|
||||
clauses.append("seed = :seed")
|
||||
values["seed"] = seed
|
||||
if unit is not None:
|
||||
clauses.append("unit = :unit")
|
||||
values["unit"] = unit
|
||||
where = ""
|
||||
if clauses:
|
||||
where = f"WHERE {' AND '.join(clauses)}"
|
||||
|
||||
rows = await (conn or db).fetchall( # type: ignore
|
||||
f"""
|
||||
SELECT * from {db.table_with_schema('keysets')}
|
||||
{where}
|
||||
""",
|
||||
values,
|
||||
)
|
||||
return [MintKeyset(**row) for row in rows]
|
||||
|
||||
async def get_proofs_used(
|
||||
self,
|
||||
*,
|
||||
Ys: List[str],
|
||||
db: Database,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> List[Proof]:
|
||||
query = f"""
|
||||
SELECT * from {db.table_with_schema('proofs_used')}
|
||||
WHERE y IN ({','.join([':y_' + str(i) for i in range(len(Ys))])})
|
||||
"""
|
||||
values = {f"y_{i}": Ys[i] for i in range(len(Ys))}
|
||||
rows = await (conn or db).fetchall(query, values)
|
||||
return [Proof(**r) for r in rows] if rows else []
|
||||
100
cashu/mint/auth/migrations.py
Normal file
100
cashu/mint/auth/migrations.py
Normal file
@@ -0,0 +1,100 @@
|
||||
from ...core.db import Connection, Database
|
||||
|
||||
|
||||
async def m000_create_migrations_table(conn: Connection):
|
||||
await conn.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS {conn.table_with_schema('dbversions')} (
|
||||
db TEXT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
async def m001_initial(db: Database):
|
||||
async with db.connect() as conn:
|
||||
await conn.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS {db.table_with_schema('users')} (
|
||||
id TEXT PRIMARY KEY,
|
||||
last_access TIMESTAMP,
|
||||
|
||||
UNIQUE (id)
|
||||
);
|
||||
"""
|
||||
)
|
||||
# columns: (id, seed, encrypted_seed, seed_encryption_method, derivation_path, valid_from, valid_to, first_seen, active, version, unit, input_fee_ppk)
|
||||
await conn.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS {db.table_with_schema('keysets')} (
|
||||
id TEXT NOT NULL,
|
||||
seed TEXT NOT NULL,
|
||||
encrypted_seed TEXT,
|
||||
seed_encryption_method TEXT,
|
||||
derivation_path TEXT NOT NULL,
|
||||
valid_from TIMESTAMP NOT NULL DEFAULT {db.timestamp_now},
|
||||
valid_to TIMESTAMP NOT NULL DEFAULT {db.timestamp_now},
|
||||
first_seen TIMESTAMP NOT NULL DEFAULT {db.timestamp_now},
|
||||
active BOOL DEFAULT TRUE,
|
||||
version TEXT,
|
||||
unit TEXT NOT NULL,
|
||||
input_fee_ppk INT,
|
||||
amounts TEXT,
|
||||
|
||||
UNIQUE (derivation_path)
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
await conn.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS {db.table_with_schema('promises')} (
|
||||
id TEXT NOT NULL,
|
||||
amount {db.big_int} NOT NULL,
|
||||
b_ TEXT NOT NULL,
|
||||
c_ TEXT NOT NULL,
|
||||
dleq_e TEXT,
|
||||
dleq_s TEXT,
|
||||
created TIMESTAMP,
|
||||
|
||||
UNIQUE (b_)
|
||||
|
||||
);
|
||||
"""
|
||||
)
|
||||
await conn.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS {db.table_with_schema('proofs_used')} (
|
||||
id TEXT NOT NULL,
|
||||
amount {db.big_int} NOT NULL,
|
||||
c TEXT NOT NULL,
|
||||
secret TEXT NOT NULL,
|
||||
y TEXT NOT NULL,
|
||||
witness TEXT,
|
||||
created TIMESTAMP,
|
||||
melt_quote TEXT,
|
||||
|
||||
UNIQUE (secret)
|
||||
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
await conn.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS {db.table_with_schema('proofs_pending')} (
|
||||
id TEXT NOT NULL,
|
||||
amount {db.big_int} NOT NULL,
|
||||
c TEXT NOT NULL,
|
||||
secret TEXT NOT NULL,
|
||||
y TEXT NOT NULL,
|
||||
witness TEXT,
|
||||
created TIMESTAMP,
|
||||
melt_quote TEXT,
|
||||
|
||||
UNIQUE (secret)
|
||||
|
||||
);
|
||||
"""
|
||||
)
|
||||
106
cashu/mint/auth/router.py
Normal file
106
cashu/mint/auth/router.py
Normal file
@@ -0,0 +1,106 @@
|
||||
from fastapi import APIRouter, Request
|
||||
from loguru import logger
|
||||
|
||||
from ...core.errors import KeysetNotFoundError
|
||||
from ...core.models import (
|
||||
KeysetsResponse,
|
||||
KeysetsResponseKeyset,
|
||||
KeysResponse,
|
||||
KeysResponseKeyset,
|
||||
PostAuthBlindMintRequest,
|
||||
PostAuthBlindMintResponse,
|
||||
)
|
||||
from ...mint.startup import auth_ledger
|
||||
|
||||
auth_router: APIRouter = APIRouter()
|
||||
|
||||
|
||||
@auth_router.get(
|
||||
"/v1/auth/blind/keys",
|
||||
name="Mint public keys",
|
||||
summary="Get the public keys of the newest mint keyset",
|
||||
response_description=(
|
||||
"All supported token values their associated public keys for all active keysets"
|
||||
),
|
||||
response_model=KeysResponse,
|
||||
)
|
||||
async def keys():
|
||||
"""This endpoint returns a dictionary of all supported token values of the mint and their associated public key."""
|
||||
logger.trace("> GET /v1/auth/blind/keys")
|
||||
keyset = auth_ledger.keyset
|
||||
keyset_for_response = []
|
||||
for keyset in auth_ledger.keysets.values():
|
||||
if keyset.active:
|
||||
keyset_for_response.append(
|
||||
KeysResponseKeyset(
|
||||
id=keyset.id,
|
||||
unit=keyset.unit.name,
|
||||
keys={k: v for k, v in keyset.public_keys_hex.items()},
|
||||
)
|
||||
)
|
||||
return KeysResponse(keysets=keyset_for_response)
|
||||
|
||||
|
||||
@auth_router.get(
|
||||
"/v1/auth/blind/keys/{keyset_id}",
|
||||
name="Keyset public keys",
|
||||
summary="Public keys of a specific keyset",
|
||||
response_description=(
|
||||
"All supported token values of the mint and their associated"
|
||||
" public key for a specific keyset."
|
||||
),
|
||||
response_model=KeysResponse,
|
||||
)
|
||||
async def keyset_keys(keyset_id: str) -> KeysResponse:
|
||||
"""
|
||||
Get the public keys of the mint from a specific keyset id.
|
||||
"""
|
||||
logger.trace(f"> GET /v1/auth/blind/keys/{keyset_id}")
|
||||
|
||||
keyset = auth_ledger.keysets.get(keyset_id)
|
||||
if keyset is None:
|
||||
raise KeysetNotFoundError(keyset_id)
|
||||
|
||||
keyset_for_response = KeysResponseKeyset(
|
||||
id=keyset.id,
|
||||
unit=keyset.unit.name,
|
||||
keys={k: v for k, v in keyset.public_keys_hex.items()},
|
||||
)
|
||||
return KeysResponse(keysets=[keyset_for_response])
|
||||
|
||||
|
||||
@auth_router.get(
|
||||
"/v1/auth/blind/keysets",
|
||||
name="Active keysets",
|
||||
summary="Get all active keyset id of the mind",
|
||||
response_model=KeysetsResponse,
|
||||
response_description="A list of all active keyset ids of the mint.",
|
||||
)
|
||||
async def keysets() -> KeysetsResponse:
|
||||
"""This endpoint returns a list of keysets that the mint currently supports and will accept tokens from."""
|
||||
logger.trace("> GET /v1/auth/blind/keysets")
|
||||
keysets = []
|
||||
for id, keyset in auth_ledger.keysets.items():
|
||||
keysets.append(
|
||||
KeysetsResponseKeyset(
|
||||
id=keyset.id,
|
||||
unit=keyset.unit.name,
|
||||
active=keyset.active,
|
||||
)
|
||||
)
|
||||
return KeysetsResponse(keysets=keysets)
|
||||
|
||||
|
||||
@auth_router.post(
|
||||
"/v1/auth/blind/mint",
|
||||
name="Mint blind auth tokens",
|
||||
summary="Mint blind auth tokens for a user.",
|
||||
response_model=PostAuthBlindMintResponse,
|
||||
)
|
||||
async def auth_blind_mint(
|
||||
request_data: PostAuthBlindMintRequest, request: Request
|
||||
) -> PostAuthBlindMintResponse:
|
||||
signatures = await auth_ledger.mint_blind_auth(
|
||||
outputs=request_data.outputs, user=request.state.user
|
||||
)
|
||||
return PostAuthBlindMintResponse(signatures=signatures)
|
||||
235
cashu/mint/auth/server.py
Normal file
235
cashu/mint/auth/server.py
Normal file
@@ -0,0 +1,235 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import httpx
|
||||
import jwt
|
||||
from loguru import logger
|
||||
|
||||
from ...core.base import AuthProof
|
||||
from ...core.db import Database
|
||||
from ...core.errors import (
|
||||
BlindAuthAmountExceededError,
|
||||
BlindAuthFailedError,
|
||||
BlindAuthRateLimitExceededError,
|
||||
ClearAuthFailedError,
|
||||
)
|
||||
from ...core.models import BlindedMessage, BlindedSignature
|
||||
from ...core.settings import settings
|
||||
from ..crud import LedgerCrudSqlite
|
||||
from ..ledger import Ledger
|
||||
from ..limit import assert_limit
|
||||
from .base import User
|
||||
from .crud import AuthLedgerCrud, AuthLedgerCrudSqlite
|
||||
|
||||
|
||||
class AuthLedger(Ledger):
|
||||
auth_crud: AuthLedgerCrud
|
||||
jwks_url: str
|
||||
jwks_client: jwt.PyJWKClient
|
||||
issuer: str
|
||||
oicd_discovery_json: dict
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db: Database,
|
||||
seed: str,
|
||||
seed_decryption_key: Optional[str] = None,
|
||||
derivation_path="",
|
||||
amounts: Optional[List[int]] = None,
|
||||
crud=LedgerCrudSqlite(),
|
||||
):
|
||||
super().__init__(
|
||||
db=db,
|
||||
seed=seed,
|
||||
backends=None,
|
||||
seed_decryption_key=seed_decryption_key,
|
||||
derivation_path=derivation_path,
|
||||
crud=crud,
|
||||
amounts=amounts,
|
||||
)
|
||||
self.oicd_discovery_url = settings.mint_auth_oicd_discovery_url or ""
|
||||
|
||||
async def init_auth(self):
|
||||
if not self.oicd_discovery_url:
|
||||
raise Exception("Missing OpenID Connect discovery URL.")
|
||||
logger.info(f"Initializing OpenID Connect: {self.oicd_discovery_url}")
|
||||
self.oicd_discovery_json = self._get_oicd_discovery_json()
|
||||
self.jwks_url = self.oicd_discovery_json["jwks_uri"]
|
||||
self.jwks_client = jwt.PyJWKClient(self.jwks_url)
|
||||
logger.info(f"Getting JWKS from: {self.jwks_url}")
|
||||
self.auth_crud = AuthLedgerCrudSqlite()
|
||||
self.issuer: str = self.oicd_discovery_json["issuer"]
|
||||
logger.info(f"Initialized OpenID Connect: {self.issuer}")
|
||||
|
||||
def _get_oicd_discovery_json(self) -> dict:
|
||||
resp = httpx.get(self.oicd_discovery_url)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
def _verify_oicd_issuer(self, clear_auth_token: str) -> None:
|
||||
"""Verify the issuer of the clear-auth token.
|
||||
|
||||
Args:
|
||||
clear_auth_token (str): JWT token.
|
||||
|
||||
Raises:
|
||||
Exception: Invalid issuer.
|
||||
"""
|
||||
try:
|
||||
decoded = jwt.decode(
|
||||
clear_auth_token,
|
||||
options={"verify_signature": False},
|
||||
)
|
||||
issuer = decoded["iss"]
|
||||
if issuer != self.issuer:
|
||||
raise Exception(f"Invalid issuer: {issuer}. Expected: {self.issuer}")
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def _verify_decode_jwt(self, clear_auth_token: str) -> Any:
|
||||
"""Verify the clear-auth JWT token.
|
||||
|
||||
Args:
|
||||
clear_auth_token (str): JWT token.
|
||||
|
||||
Raises:
|
||||
jwt.ExpiredSignatureError: Token has expired.
|
||||
jwt.InvalidSignatureError: Invalid signature.
|
||||
jwt.InvalidTokenError: Invalid token.
|
||||
|
||||
Returns:
|
||||
Any: Decoded JWT.
|
||||
"""
|
||||
try:
|
||||
# Use PyJWKClient to fetch the appropriate key based on the token's header
|
||||
signing_key = self.jwks_client.get_signing_key_from_jwt(clear_auth_token)
|
||||
decoded = jwt.decode(
|
||||
clear_auth_token,
|
||||
signing_key.key,
|
||||
algorithms=["RS256", "ES256"],
|
||||
options={"verify_aud": False},
|
||||
issuer=self.issuer,
|
||||
)
|
||||
logger.trace(f"Decoded JWT: {decoded}")
|
||||
except jwt.ExpiredSignatureError as e:
|
||||
logger.error("Token has expired")
|
||||
raise e
|
||||
except jwt.InvalidSignatureError as e:
|
||||
logger.error("Invalid signature")
|
||||
raise e
|
||||
except jwt.InvalidTokenError as e:
|
||||
logger.error("Invalid token")
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
return decoded
|
||||
|
||||
async def _get_user(self, decoded_token: Any) -> User:
|
||||
"""Get the user from the decoded token. If the user does not exist, create a new one.
|
||||
|
||||
Args:
|
||||
decoded_token (Any): decoded JWT from PyJWT.decode
|
||||
|
||||
Returns:
|
||||
User: User object
|
||||
"""
|
||||
user_id = decoded_token["sub"]
|
||||
user = await self.auth_crud.get_user(user_id=user_id, db=self.db)
|
||||
if not user:
|
||||
logger.info(f"Creating new user: {user_id}")
|
||||
user = User(id=user_id)
|
||||
await self.auth_crud.create_user(user=user, db=self.db)
|
||||
return user
|
||||
|
||||
async def verify_clear_auth(self, clear_auth_token: str) -> User:
|
||||
"""Verify the clear-auth JWT token and return the user.
|
||||
|
||||
Checks:
|
||||
- Token not expired.
|
||||
- Token signature valid.
|
||||
- User exists.
|
||||
|
||||
Args:
|
||||
auth_token (str): JWT token.
|
||||
|
||||
Returns:
|
||||
User: Authenticated user.
|
||||
"""
|
||||
try:
|
||||
self._verify_oicd_issuer(clear_auth_token)
|
||||
decoded = self._verify_decode_jwt(clear_auth_token)
|
||||
user = await self._get_user(decoded)
|
||||
except Exception:
|
||||
raise ClearAuthFailedError()
|
||||
|
||||
logger.info(f"User authenticated: {user.id}")
|
||||
try:
|
||||
assert_limit(user.id)
|
||||
except Exception:
|
||||
raise BlindAuthRateLimitExceededError()
|
||||
|
||||
return user
|
||||
|
||||
async def mint_blind_auth(
|
||||
self,
|
||||
*,
|
||||
outputs: List[BlindedMessage],
|
||||
user: User,
|
||||
) -> List[BlindedSignature]:
|
||||
"""Mints auth tokens. Returns a list of promises.
|
||||
|
||||
Args:
|
||||
outputs (List[BlindedMessage]): Outputs to sign.
|
||||
user (User): Authenticated user.
|
||||
|
||||
Raises:
|
||||
Exception: Invalid auth.
|
||||
Exception: Output verification failed.
|
||||
Exception: Output quota exceeded.
|
||||
|
||||
Returns:
|
||||
List[BlindedSignature]: List of blinded signatures.
|
||||
"""
|
||||
|
||||
if len(outputs) > settings.mint_auth_max_blind_tokens:
|
||||
raise BlindAuthAmountExceededError(
|
||||
f"Too many outputs. You can only mint {settings.mint_auth_max_blind_tokens} tokens."
|
||||
)
|
||||
|
||||
await self._verify_outputs(outputs)
|
||||
promises = await self._generate_promises(outputs)
|
||||
|
||||
# update last_access timestamp of the user
|
||||
await self.auth_crud.update_user(user_id=user.id, db=self.db)
|
||||
|
||||
return promises
|
||||
|
||||
@asynccontextmanager
|
||||
async def verify_blind_auth(self, blind_auth_token):
|
||||
"""Wrapper context that puts blind auth tokens into pending list and
|
||||
melts them if the wrapped call succeeds. If it fails, the blind auth
|
||||
token is not invalidated.
|
||||
|
||||
Args:
|
||||
blind_auth_token (str): Blind auth token.
|
||||
|
||||
Raises:
|
||||
Exception: Blind auth token validation failed.
|
||||
"""
|
||||
try:
|
||||
proof = AuthProof.from_base64(blind_auth_token).to_proof()
|
||||
await self.verify_inputs_and_outputs(proofs=[proof])
|
||||
await self.db_write._verify_spent_proofs_and_set_pending([proof])
|
||||
except Exception as e:
|
||||
logger.error(f"Blind auth error: {e}")
|
||||
raise BlindAuthFailedError()
|
||||
|
||||
try:
|
||||
yield
|
||||
await self._invalidate_proofs(proofs=[proof])
|
||||
except Exception as e:
|
||||
logger.error(f"Blind auth error: {e}")
|
||||
raise BlindAuthFailedError()
|
||||
finally:
|
||||
await self.db_write._unset_proofs_pending([proof])
|
||||
@@ -642,8 +642,8 @@ class LedgerCrudSqlite(LedgerCrud):
|
||||
await (conn or db).execute(
|
||||
f"""
|
||||
INSERT INTO {db.table_with_schema('keysets')}
|
||||
(id, seed, encrypted_seed, seed_encryption_method, derivation_path, valid_from, valid_to, first_seen, active, version, unit, input_fee_ppk)
|
||||
VALUES (:id, :seed, :encrypted_seed, :seed_encryption_method, :derivation_path, :valid_from, :valid_to, :first_seen, :active, :version, :unit, :input_fee_ppk)
|
||||
(id, seed, encrypted_seed, seed_encryption_method, derivation_path, valid_from, valid_to, first_seen, active, version, unit, input_fee_ppk, amounts)
|
||||
VALUES (:id, :seed, :encrypted_seed, :seed_encryption_method, :derivation_path, :valid_from, :valid_to, :first_seen, :active, :version, :unit, :input_fee_ppk, :amounts)
|
||||
""",
|
||||
{
|
||||
"id": keyset.id,
|
||||
@@ -662,6 +662,7 @@ class LedgerCrudSqlite(LedgerCrud):
|
||||
"version": keyset.version,
|
||||
"unit": keyset.unit.name,
|
||||
"input_fee_ppk": keyset.input_fee_ppk,
|
||||
"amounts": json.dumps(keyset.amounts),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -720,7 +721,7 @@ class LedgerCrudSqlite(LedgerCrud):
|
||||
""",
|
||||
values,
|
||||
)
|
||||
return [MintKeyset(**row) for row in rows]
|
||||
return [MintKeyset.from_row(row) for row in rows] # type: ignore
|
||||
|
||||
async def update_keyset(
|
||||
self,
|
||||
|
||||
@@ -1,12 +1,17 @@
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from ..core.base import Method
|
||||
from ..core.mint_info import MintInfo
|
||||
from ..core.models import (
|
||||
MeltMethodSetting,
|
||||
MintInfoContact,
|
||||
MintInfoProtectedEndpoint,
|
||||
MintMethodSetting,
|
||||
)
|
||||
from ..core.nuts.nuts import (
|
||||
BLIND_AUTH_NUT,
|
||||
CACHE_NUT,
|
||||
CLEAR_AUTH_NUT,
|
||||
DLEQ_NUT,
|
||||
FEE_RETURN_NUT,
|
||||
HTLC_NUT,
|
||||
@@ -21,10 +26,46 @@ from ..core.nuts.nuts import (
|
||||
WEBSOCKETS_NUT,
|
||||
)
|
||||
from ..core.settings import settings
|
||||
from ..mint.protocols import SupportsBackends
|
||||
from ..mint.protocols import SupportsBackends, SupportsPubkey
|
||||
|
||||
_VERSION_PREFIX = "Nutshell"
|
||||
_SUPPORTED = "supported"
|
||||
_METHOD = "method"
|
||||
_UNIT = "unit"
|
||||
_BOLT11 = "bolt11"
|
||||
_MPP = "mpp"
|
||||
_COMMANDS = "commands"
|
||||
_BOLT11_MINT_QUOTE = "bolt11_mint_quote"
|
||||
_BOLT11_MELT_QUOTE = "bolt11_melt_quote"
|
||||
_PROOF_STATE = "proof_state"
|
||||
_PROTECTED_ENDPOINTS = "protected_endpoints"
|
||||
_BAT_MAX_MINT = "bat_max_mint"
|
||||
_OPENID_DISCOVERY = "openid_discovery"
|
||||
_CLIENT_ID = "client_id"
|
||||
|
||||
|
||||
class LedgerFeatures(SupportsBackends):
|
||||
class LedgerFeatures(SupportsBackends, SupportsPubkey):
|
||||
@property
|
||||
def mint_info(self) -> MintInfo:
|
||||
contact_info = [
|
||||
MintInfoContact(method=m, info=i)
|
||||
for m, i in settings.mint_info_contact
|
||||
if m and i
|
||||
]
|
||||
return MintInfo(
|
||||
name=settings.mint_info_name,
|
||||
pubkey=self.pubkey.serialize().hex() if self.pubkey else None,
|
||||
version=f"{_VERSION_PREFIX}/{settings.version}",
|
||||
description=settings.mint_info_description,
|
||||
description_long=settings.mint_info_description_long,
|
||||
contact=contact_info,
|
||||
nuts=self.mint_features,
|
||||
icon_url=settings.mint_info_icon_url,
|
||||
motd=settings.mint_info_motd,
|
||||
time=None,
|
||||
)
|
||||
|
||||
@property
|
||||
def mint_features(self) -> Dict[int, Union[List[Any], Dict[str, Any]]]:
|
||||
mint_features = self.create_mint_features()
|
||||
mint_features = self.add_supported_features(mint_features)
|
||||
@@ -100,30 +141,62 @@ class LedgerFeatures(SupportsBackends):
|
||||
# specify which websocket features are supported
|
||||
# these two are supported by default
|
||||
websocket_features: Dict[str, List[Dict[str, Union[str, List[str]]]]] = {
|
||||
"supported": []
|
||||
_SUPPORTED: []
|
||||
}
|
||||
# we check the backend to see if "bolt11_mint_quote" is supported as well
|
||||
for method, unit_dict in self.backends.items():
|
||||
if method == Method["bolt11"]:
|
||||
if method == Method[_BOLT11]:
|
||||
for unit in unit_dict.keys():
|
||||
websocket_features["supported"].append(
|
||||
websocket_features[_SUPPORTED].append(
|
||||
{
|
||||
"method": method.name,
|
||||
"unit": unit.name,
|
||||
"commands": ["bolt11_melt_quote", "proof_state"],
|
||||
_METHOD: method.name,
|
||||
_UNIT: unit.name,
|
||||
_COMMANDS: [_BOLT11_MELT_QUOTE, _PROOF_STATE],
|
||||
}
|
||||
)
|
||||
if unit_dict[unit].supports_incoming_payment_stream:
|
||||
supported_features: List[str] = list(
|
||||
websocket_features["supported"][-1]["commands"]
|
||||
websocket_features[_SUPPORTED][-1][_COMMANDS]
|
||||
)
|
||||
websocket_features["supported"][-1]["commands"] = (
|
||||
supported_features + ["bolt11_mint_quote"]
|
||||
websocket_features[_SUPPORTED][-1][_COMMANDS] = (
|
||||
supported_features + [_BOLT11_MINT_QUOTE]
|
||||
)
|
||||
|
||||
if websocket_features:
|
||||
mint_features[WEBSOCKETS_NUT] = websocket_features
|
||||
|
||||
# signal authentication features
|
||||
if settings.mint_require_auth:
|
||||
if not settings.mint_auth_oicd_discovery_url:
|
||||
raise Exception(
|
||||
"Missing OpenID Connect discovery URL: MINT_AUTH_OICD_DISCOVERY_URL"
|
||||
)
|
||||
clear_auth_features: Dict[str, Union[bool, str, List[str]]] = {
|
||||
_OPENID_DISCOVERY: settings.mint_auth_oicd_discovery_url,
|
||||
_CLIENT_ID: settings.mint_auth_oicd_client_id,
|
||||
_PROTECTED_ENDPOINTS: [],
|
||||
}
|
||||
|
||||
for endpoint in [
|
||||
MintInfoProtectedEndpoint(method=e[0], path=e[1])
|
||||
for e in settings.mint_require_clear_auth_paths
|
||||
]:
|
||||
clear_auth_features[_PROTECTED_ENDPOINTS].append(endpoint.dict()) # type: ignore
|
||||
|
||||
mint_features[CLEAR_AUTH_NUT] = clear_auth_features
|
||||
|
||||
blind_auth_features: Dict[str, Union[bool, int, str, List[str]]] = {
|
||||
_BAT_MAX_MINT: settings.mint_auth_max_blind_tokens,
|
||||
_PROTECTED_ENDPOINTS: [],
|
||||
}
|
||||
for endpoint in [
|
||||
MintInfoProtectedEndpoint(method=e[0], path=e[1])
|
||||
for e in settings.mint_require_blind_auth_paths
|
||||
]:
|
||||
blind_auth_features[_PROTECTED_ENDPOINTS].append(endpoint.dict()) # type: ignore
|
||||
|
||||
mint_features[BLIND_AUTH_NUT] = blind_auth_features
|
||||
|
||||
return mint_features
|
||||
|
||||
def add_cache_features(
|
||||
|
||||
@@ -75,16 +75,26 @@ class Ledger(LedgerVerification, LedgerSpendingConditions, LedgerTasks, LedgerFe
|
||||
db_read: DbReadHelper
|
||||
invoice_listener_tasks: List[asyncio.Task] = []
|
||||
disable_melt: bool = False
|
||||
pubkey: PublicKey
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
db: Database,
|
||||
seed: str,
|
||||
backends: Mapping[Method, Mapping[Unit, LightningBackend]],
|
||||
seed_decryption_key: Optional[str] = None,
|
||||
derivation_path="",
|
||||
amounts: Optional[List[int]] = None,
|
||||
backends: Optional[Mapping[Method, Mapping[Unit, LightningBackend]]] = None,
|
||||
seed_decryption_key: Optional[str] = None,
|
||||
crud=LedgerCrudSqlite(),
|
||||
):
|
||||
) -> None:
|
||||
self.keysets: Dict[str, MintKeyset] = {}
|
||||
self.backends: Mapping[Method, Mapping[Unit, LightningBackend]] = {}
|
||||
self.events = LedgerEventManager()
|
||||
self.db_read: DbReadHelper
|
||||
self.locks: Dict[str, asyncio.Lock] = {} # holds multiprocessing locks
|
||||
self.invoice_listener_tasks: List[asyncio.Task] = []
|
||||
|
||||
if not seed:
|
||||
raise Exception("seed not set")
|
||||
|
||||
@@ -103,24 +113,33 @@ class Ledger(LedgerVerification, LedgerSpendingConditions, LedgerTasks, LedgerFe
|
||||
|
||||
self.db = db
|
||||
self.crud = crud
|
||||
self.backends = backends
|
||||
|
||||
if backends:
|
||||
self.backends = backends
|
||||
|
||||
if amounts:
|
||||
self.amounts = amounts
|
||||
else:
|
||||
self.amounts = [2**n for n in range(settings.max_order)]
|
||||
|
||||
self.pubkey = derive_pubkey(self.seed)
|
||||
self.db_read = DbReadHelper(self.db, self.crud)
|
||||
self.db_write = DbWriteHelper(self.db, self.crud, self.events, self.db_read)
|
||||
|
||||
# ------- STARTUP -------
|
||||
|
||||
async def startup_ledger(self):
|
||||
await self._startup_ledger()
|
||||
async def startup_ledger(self) -> None:
|
||||
await self._startup_keysets()
|
||||
await self._check_backends()
|
||||
await self._check_pending_proofs_and_melt_quotes()
|
||||
self.invoice_listener_tasks = await self.dispatch_listeners()
|
||||
|
||||
async def _startup_ledger(self):
|
||||
async def _startup_keysets(self) -> None:
|
||||
await self.init_keysets()
|
||||
|
||||
for derivation_path in settings.mint_derivation_path_list:
|
||||
await self.activate_keyset(derivation_path=derivation_path)
|
||||
|
||||
async def _check_backends(self) -> None:
|
||||
for method in self.backends:
|
||||
for unit in self.backends[method]:
|
||||
logger.info(
|
||||
@@ -139,7 +158,7 @@ class Ledger(LedgerVerification, LedgerSpendingConditions, LedgerTasks, LedgerFe
|
||||
|
||||
logger.info(f"Data dir: {settings.cashu_dir}")
|
||||
|
||||
async def shutdown_ledger(self):
|
||||
async def shutdown_ledger(self) -> None:
|
||||
await self.db.engine.dispose()
|
||||
for task in self.invoice_listener_tasks:
|
||||
task.cancel()
|
||||
@@ -169,57 +188,65 @@ class Ledger(LedgerVerification, LedgerSpendingConditions, LedgerTasks, LedgerFe
|
||||
version: Optional[str] = None,
|
||||
autosave=True,
|
||||
) -> MintKeyset:
|
||||
"""Load the keyset for a derivation path if it already exists. If not generate new one and store in the db.
|
||||
"""
|
||||
Load an existing keyset for the specified derivation path or generate a new one if it doesn't exist.
|
||||
Optionally store the newly created keyset in the database.
|
||||
|
||||
Args:
|
||||
derivation_path (_type_): Derivation path from which the keyset is generated.
|
||||
autosave (bool, optional): Store newly-generated keyset if not already in database. Defaults to True.
|
||||
derivation_path (str): Derivation path for keyset generation.
|
||||
seed (Optional[str], optional): Seed value. Defaults to None.
|
||||
version (Optional[str], optional): Version identifier. Defaults to None.
|
||||
autosave (bool, optional): Whether to store the keyset if newly created. Defaults to True.
|
||||
|
||||
Returns:
|
||||
MintKeyset: Keyset
|
||||
MintKeyset: The activated keyset.
|
||||
"""
|
||||
if not derivation_path:
|
||||
raise Exception("derivation path not set")
|
||||
raise ValueError("Derivation path must be provided.")
|
||||
|
||||
seed = seed or self.seed
|
||||
tmp_keyset_local = MintKeyset(
|
||||
version = version or settings.version
|
||||
# Initialize a temporary keyset to derive the ID
|
||||
temp_keyset = MintKeyset(
|
||||
seed=seed,
|
||||
derivation_path=derivation_path,
|
||||
version=version or settings.version,
|
||||
version=version,
|
||||
amounts=self.amounts,
|
||||
)
|
||||
logger.debug(
|
||||
f"Activating keyset for derivation path {derivation_path} with id"
|
||||
f" {tmp_keyset_local.id}."
|
||||
f"Activating keyset for derivation path '{derivation_path}' with ID '{temp_keyset.id}'."
|
||||
)
|
||||
# load the keyset from db
|
||||
logger.trace(f"crud: loading keyset for {derivation_path}")
|
||||
tmp_keysets_local: List[MintKeyset] = await self.crud.get_keyset(
|
||||
id=tmp_keyset_local.id, db=self.db
|
||||
|
||||
# Attempt to retrieve existing keysets from the database
|
||||
existing_keysets: List[MintKeyset] = await self.crud.get_keyset(
|
||||
id=temp_keyset.id, db=self.db
|
||||
)
|
||||
logger.trace(f"crud: loaded {len(tmp_keysets_local)} keysets")
|
||||
if tmp_keysets_local:
|
||||
# we have a keyset with this derivation path in the database
|
||||
keyset = tmp_keysets_local[0]
|
||||
logger.trace(
|
||||
f"Retrieved {len(existing_keysets)} keyset(s) for derivation path '{derivation_path}'."
|
||||
)
|
||||
|
||||
if existing_keysets:
|
||||
keyset = existing_keysets[0]
|
||||
else:
|
||||
# no keyset for this derivation path yet
|
||||
# we create a new keyset (keys will be generated at instantiation)
|
||||
# Create a new keyset if none exists
|
||||
keyset = MintKeyset(
|
||||
seed=seed or self.seed,
|
||||
seed=seed,
|
||||
derivation_path=derivation_path,
|
||||
version=version or settings.version,
|
||||
amounts=self.amounts,
|
||||
version=version,
|
||||
input_fee_ppk=settings.mint_input_fee_ppk,
|
||||
)
|
||||
logger.debug(f"Generated new keyset {keyset.id}.")
|
||||
logger.debug(f"Generated new keyset with ID '{keyset.id}'.")
|
||||
|
||||
if autosave:
|
||||
logger.debug(f"crud: storing new keyset {keyset.id}.")
|
||||
logger.debug(f"Storing new keyset with ID '{keyset.id}'.")
|
||||
await self.crud.store_keyset(keyset=keyset, db=self.db)
|
||||
logger.trace(f"crud: stored new keyset {keyset.id}.")
|
||||
|
||||
# activate this keyset
|
||||
# Activate the keyset
|
||||
keyset.active = True
|
||||
# load the new keyset in self.keysets
|
||||
self.keysets[keyset.id] = keyset
|
||||
logger.debug(f"Keyset with ID '{keyset.id}' is now active.")
|
||||
|
||||
logger.debug(f"Loaded keyset {keyset.id}")
|
||||
return keyset
|
||||
|
||||
async def init_keysets(self, autosave: bool = True) -> None:
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import WebSocket, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from limits import RateLimitItemPerMinute
|
||||
@@ -42,26 +44,26 @@ limiter = Limiter(
|
||||
)
|
||||
|
||||
|
||||
def assert_limit(identifier: str):
|
||||
def assert_limit(identifier: str, limit: Optional[int] = None):
|
||||
"""Custom rate limit handler that accepts a string identifier
|
||||
and raises an exception if the rate limit is exceeded. Uses the
|
||||
setting `mint_transaction_rate_limit_per_minute` for the rate limit.
|
||||
|
||||
Args:
|
||||
identifier (str): The identifier to use for the rate limit. IP address for example.
|
||||
limit (Optional[int], optional): The rate limit per minute to use. Defaults to None
|
||||
|
||||
Raises:
|
||||
Exception: If the rate limit is exceeded.
|
||||
"""
|
||||
global limiter
|
||||
limit_per_minute = limit or settings.mint_transaction_rate_limit_per_minute
|
||||
success = limiter._limiter.hit(
|
||||
RateLimitItemPerMinute(settings.mint_transaction_rate_limit_per_minute),
|
||||
RateLimitItemPerMinute(limit_per_minute),
|
||||
identifier,
|
||||
)
|
||||
if not success:
|
||||
logger.warning(
|
||||
f"Rate limit {settings.mint_transaction_rate_limit_per_minute}/minute exceeded: {identifier}"
|
||||
)
|
||||
logger.warning(f"Rate limit {limit_per_minute}/minute exceeded: {identifier}")
|
||||
raise Exception("Rate limit exceeded")
|
||||
|
||||
|
||||
|
||||
@@ -10,7 +10,10 @@ from fastapi.exception_handlers import (
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
from loguru import logger
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.middleware.base import (
|
||||
BaseHTTPMiddleware,
|
||||
RequestResponseEndpoint,
|
||||
)
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
from ..core.settings import settings
|
||||
@@ -22,6 +25,8 @@ if settings.debug_profiling:
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
from slowapi.middleware import SlowAPIMiddleware
|
||||
|
||||
from .startup import auth_ledger
|
||||
|
||||
|
||||
def add_middlewares(app: FastAPI):
|
||||
app.add_middleware(
|
||||
@@ -42,6 +47,52 @@ def add_middlewares(app: FastAPI):
|
||||
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||
app.add_middleware(SlowAPIMiddleware)
|
||||
|
||||
if settings.mint_require_auth:
|
||||
app.add_middleware(BlindAuthMiddleware)
|
||||
app.add_middleware(ClearAuthMiddleware)
|
||||
|
||||
|
||||
class ClearAuthMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(
|
||||
self, request: Request, call_next: RequestResponseEndpoint
|
||||
) -> Response:
|
||||
if (
|
||||
settings.mint_require_auth
|
||||
and auth_ledger.mint_info.requires_clear_auth_path(
|
||||
method=request.method, path=request.url.path
|
||||
)
|
||||
):
|
||||
clear_auth_token = request.headers.get("clear-auth")
|
||||
if not clear_auth_token:
|
||||
raise Exception("Missing clear auth token.")
|
||||
try:
|
||||
user = await auth_ledger.verify_clear_auth(
|
||||
clear_auth_token=clear_auth_token
|
||||
)
|
||||
request.state.user = user
|
||||
except Exception as e:
|
||||
raise e
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
class BlindAuthMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(
|
||||
self, request: Request, call_next: RequestResponseEndpoint
|
||||
) -> Response:
|
||||
if (
|
||||
settings.mint_require_auth
|
||||
and auth_ledger.mint_info.requires_blind_auth_path(
|
||||
method=request.method, path=request.url.path
|
||||
)
|
||||
):
|
||||
blind_auth_token = request.headers.get("blind-auth")
|
||||
if not blind_auth_token:
|
||||
raise Exception("Missing blind auth token.")
|
||||
async with auth_ledger.verify_blind_auth(blind_auth_token):
|
||||
return await call_next(request)
|
||||
else:
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
async def request_validation_exception_handler(
|
||||
request: Request, exc: RequestValidationError
|
||||
@@ -66,11 +117,11 @@ class CompressionMiddleware(BaseHTTPMiddleware):
|
||||
response = await call_next(request)
|
||||
|
||||
# Handle streaming responses differently
|
||||
if response.__class__.__name__ == 'StreamingResponse':
|
||||
if response.__class__.__name__ == "StreamingResponse":
|
||||
return response
|
||||
|
||||
response_body = b''
|
||||
async for chunk in response.body_iterator:
|
||||
response_body = b""
|
||||
async for chunk in response.body_iterator: # type: ignore
|
||||
response_body += chunk
|
||||
|
||||
accept_encoding = request.headers.get("Accept-Encoding", "")
|
||||
@@ -97,5 +148,5 @@ class CompressionMiddleware(BaseHTTPMiddleware):
|
||||
content=content,
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers),
|
||||
media_type=response.media_type
|
||||
media_type=response.media_type,
|
||||
)
|
||||
|
||||
@@ -858,3 +858,13 @@ async def m024_add_melt_quote_outputs(db: Database):
|
||||
ADD COLUMN outputs TEXT DEFAULT NULL
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
async def m025_add_amounts_to_keysets(db: Database):
|
||||
async with db.connect() as conn:
|
||||
await conn.execute(
|
||||
f"ALTER TABLE {db.table_with_schema('keysets')} ADD COLUMN amounts TEXT"
|
||||
)
|
||||
await conn.execute(
|
||||
f"UPDATE {db.table_with_schema('keysets')} SET amounts = '[]'"
|
||||
)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing import Dict, Mapping, Protocol
|
||||
|
||||
from ..core.base import Method, MintKeyset, Unit
|
||||
from ..core.crypto.secp import PublicKey
|
||||
from ..core.db import Database
|
||||
from ..lightning.base import LightningBackend
|
||||
from ..mint.crud import LedgerCrud
|
||||
@@ -18,6 +19,10 @@ class SupportsBackends(Protocol):
|
||||
backends: Mapping[Method, Mapping[Unit, LightningBackend]] = {}
|
||||
|
||||
|
||||
class SupportsPubkey(Protocol):
|
||||
pubkey: PublicKey
|
||||
|
||||
|
||||
class SupportsDb(Protocol):
|
||||
db: Database
|
||||
db_read: DbReadHelper
|
||||
|
||||
@@ -11,7 +11,6 @@ from ..core.models import (
|
||||
KeysetsResponseKeyset,
|
||||
KeysResponse,
|
||||
KeysResponseKeyset,
|
||||
MintInfoContact,
|
||||
PostCheckStateRequest,
|
||||
PostCheckStateResponse,
|
||||
PostMeltQuoteRequest,
|
||||
@@ -44,23 +43,18 @@ redis = RedisCache()
|
||||
)
|
||||
async def info() -> GetInfoResponse:
|
||||
logger.trace("> GET /v1/info")
|
||||
mint_features = ledger.mint_features()
|
||||
contact_info = [
|
||||
MintInfoContact(method=m, info=i)
|
||||
for m, i in settings.mint_info_contact
|
||||
if m and i
|
||||
]
|
||||
mint_info = ledger.mint_info
|
||||
return GetInfoResponse(
|
||||
name=settings.mint_info_name,
|
||||
pubkey=ledger.pubkey.serialize().hex() if ledger.pubkey else None,
|
||||
version=f"Nutshell/{settings.version}",
|
||||
description=settings.mint_info_description,
|
||||
description_long=settings.mint_info_description_long,
|
||||
contact=contact_info,
|
||||
nuts=mint_features,
|
||||
icon_url=settings.mint_info_icon_url,
|
||||
name=mint_info.name,
|
||||
pubkey=mint_info.pubkey,
|
||||
version=mint_info.version,
|
||||
description=mint_info.description,
|
||||
description_long=mint_info.description_long,
|
||||
contact=mint_info.contact,
|
||||
nuts=mint_info.nuts,
|
||||
icon_url=mint_info.icon_url,
|
||||
urls=settings.mint_info_urls,
|
||||
motd=settings.mint_info_motd,
|
||||
motd=mint_info.motd,
|
||||
time=int(time.time()),
|
||||
)
|
||||
|
||||
|
||||
@@ -12,7 +12,9 @@ from ..core.db import Database
|
||||
from ..core.migrations import migrate_databases
|
||||
from ..core.settings import settings
|
||||
from ..lightning.base import LightningBackend
|
||||
from ..mint import migrations
|
||||
from ..mint import migrations as mint_migrations
|
||||
from ..mint.auth import migrations as auth_migrations
|
||||
from ..mint.auth.server import AuthLedger
|
||||
from ..mint.crud import LedgerCrudSqlite
|
||||
from ..mint.ledger import Ledger
|
||||
|
||||
@@ -76,6 +78,15 @@ ledger = Ledger(
|
||||
crud=LedgerCrudSqlite(),
|
||||
)
|
||||
|
||||
# start auth ledger
|
||||
auth_ledger = AuthLedger(
|
||||
db=Database("auth", settings.auth_database),
|
||||
seed="auth seed here",
|
||||
amounts=[1],
|
||||
derivation_path="m/0'/999'/0'",
|
||||
crud=LedgerCrudSqlite(),
|
||||
)
|
||||
|
||||
|
||||
async def rotate_keys(n_seconds=60):
|
||||
"""Rotate keyset epoch every n_seconds.
|
||||
@@ -93,8 +104,17 @@ async def rotate_keys(n_seconds=60):
|
||||
await asyncio.sleep(n_seconds)
|
||||
|
||||
|
||||
async def start_mint_init():
|
||||
await migrate_databases(ledger.db, migrations)
|
||||
async def start_auth():
|
||||
await migrate_databases(auth_ledger.db, auth_migrations)
|
||||
logger.info("Starting auth ledger.")
|
||||
await auth_ledger.init_keysets()
|
||||
await auth_ledger.init_auth()
|
||||
logger.info("Auth ledger started.")
|
||||
|
||||
|
||||
async def start_mint():
|
||||
await migrate_databases(ledger.db, mint_migrations)
|
||||
logger.info("Starting mint ledger.")
|
||||
await ledger.startup_ledger()
|
||||
logger.info("Mint started.")
|
||||
# asyncio.create_task(rotate_keys())
|
||||
|
||||
@@ -1,22 +1,14 @@
|
||||
import asyncio
|
||||
from typing import List, Mapping
|
||||
from typing import List
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from ..core.base import Method, MintQuoteState, Unit
|
||||
from ..core.db import Database
|
||||
from ..core.base import MintQuoteState
|
||||
from ..lightning.base import LightningBackend
|
||||
from ..mint.crud import LedgerCrud
|
||||
from .events.events import LedgerEventManager
|
||||
from .protocols import SupportsBackends, SupportsDb, SupportsEvents
|
||||
|
||||
|
||||
class LedgerTasks(SupportsDb, SupportsBackends, SupportsEvents):
|
||||
backends: Mapping[Method, Mapping[Unit, LightningBackend]] = {}
|
||||
db: Database
|
||||
crud: LedgerCrud
|
||||
events: LedgerEventManager
|
||||
|
||||
async def dispatch_listeners(self) -> List[asyncio.Task]:
|
||||
tasks = []
|
||||
for method, unitbackends in self.backends.items():
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Dict, List, Literal, Optional, Tuple, Union
|
||||
from typing import List, Literal, Optional, Tuple, Union
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -6,14 +6,13 @@ from ..core.base import (
|
||||
BlindedMessage,
|
||||
BlindedSignature,
|
||||
Method,
|
||||
MintKeyset,
|
||||
MintQuote,
|
||||
Proof,
|
||||
Unit,
|
||||
)
|
||||
from ..core.crypto import b_dhke
|
||||
from ..core.crypto.secp import PublicKey
|
||||
from ..core.db import Connection, Database
|
||||
from ..core.db import Connection
|
||||
from ..core.errors import (
|
||||
InvalidProofsError,
|
||||
NoSecretInProofsError,
|
||||
@@ -25,11 +24,7 @@ from ..core.errors import (
|
||||
)
|
||||
from ..core.nuts import nut20
|
||||
from ..core.settings import settings
|
||||
from ..lightning.base import LightningBackend
|
||||
from ..mint.crud import LedgerCrud
|
||||
from .conditions import LedgerSpendingConditions
|
||||
from .db.read import DbReadHelper
|
||||
from .db.write import DbWriteHelper
|
||||
from .protocols import SupportsBackends, SupportsDb, SupportsKeysets
|
||||
|
||||
|
||||
@@ -38,14 +33,6 @@ class LedgerVerification(
|
||||
):
|
||||
"""Verification functions for the ledger."""
|
||||
|
||||
keyset: MintKeyset
|
||||
keysets: Dict[str, MintKeyset]
|
||||
crud: LedgerCrud
|
||||
db: Database
|
||||
db_read: DbReadHelper
|
||||
db_write: DbWriteHelper
|
||||
lightning: Dict[Unit, LightningBackend]
|
||||
|
||||
async def verify_inputs_and_outputs(
|
||||
self,
|
||||
*,
|
||||
@@ -55,6 +42,8 @@ class LedgerVerification(
|
||||
):
|
||||
"""Checks all proofs and outputs for validity.
|
||||
|
||||
Warning: Does NOT check if the proofs were already spent. Use `db_write._verify_proofs_spendable` for that.
|
||||
|
||||
Args:
|
||||
proofs (List[Proof]): List of proofs to check.
|
||||
outputs (Optional[List[BlindedMessage]], optional): List of outputs to check.
|
||||
|
||||
@@ -22,14 +22,11 @@ class NostrClient:
|
||||
relays = [
|
||||
"wss://nostr-pub.wellorder.net",
|
||||
"wss://relay.damus.io",
|
||||
"wss://nostr.zebedee.cloud",
|
||||
"wss://relay.snort.social",
|
||||
"wss://nostr.fmt.wiz.biz",
|
||||
"wss://nos.lol",
|
||||
"wss://nostr.oxtr.dev",
|
||||
"wss://relay.current.fyi",
|
||||
"wss://relay.snort.social",
|
||||
] # ["wss://nostr.oxtr.dev"] # ["wss://relay.nostr.info"] "wss://nostr-pub.wellorder.net" "ws://91.237.88.218:2700", "wss://nostrrr.bublina.eu.org", ""wss://nostr-relay.freeberty.net"", , "wss://nostr.oxtr.dev", "wss://relay.nostr.info", "wss://nostr-pub.wellorder.net" , "wss://relayer.fiatjaf.com", "wss://nodestr.fmt.wiz.biz/", "wss://no.str.cr"
|
||||
]
|
||||
relay_manager = RelayManager()
|
||||
private_key: PrivateKey
|
||||
public_key: PublicKey
|
||||
|
||||
0
cashu/wallet/auth/__init__.py
Normal file
0
cashu/wallet/auth/__init__.py
Normal file
240
cashu/wallet/auth/auth.py
Normal file
240
cashu/wallet/auth/auth.py
Normal file
@@ -0,0 +1,240 @@
|
||||
import hashlib
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from cashu.core.helpers import sum_proofs
|
||||
from cashu.core.mint_info import MintInfo
|
||||
|
||||
from ...core.base import Proof
|
||||
from ...core.crypto.secp import PrivateKey
|
||||
from ...core.db import Database
|
||||
from ..crud import get_mint_by_url, update_mint
|
||||
from ..wallet import Wallet
|
||||
from .openid_connect.openid_client import AuthorizationFlow, OpenIDClient
|
||||
|
||||
|
||||
class WalletAuth(Wallet):
|
||||
oidc_discovery_url: str
|
||||
oidc_client: OpenIDClient
|
||||
wallet_db: Database
|
||||
auth_flow: AuthorizationFlow
|
||||
username: str | None
|
||||
password: str | None
|
||||
# API prefix for all requests
|
||||
api_prefix = "/v1/auth/blind"
|
||||
|
||||
def __init__(
|
||||
self, url: str, db: str, name: str = "auth", unit: str = "auth", **kwargs
|
||||
):
|
||||
"""Authentication wallet.
|
||||
|
||||
Args:
|
||||
url (str): Mint url.
|
||||
db (str): Auth wallet db location.
|
||||
wallet_db (str): Wallet db location.
|
||||
name (str, optional): Wallet name. Defaults to "auth".
|
||||
unit (str, optional): Wallet unit. Defaults to "auth".
|
||||
kwargs: Additional keyword arguments.
|
||||
client_id (str, optional): OpenID client id. Defaults to "cashu-client".
|
||||
client_secret (str, optional): OpenID client secret. Defaults to "".
|
||||
username (str, optional): OpenID username. When set, the username and
|
||||
password flow will be used to authenticate. If a username is already
|
||||
stored in the database, it will be used. Will be stored in the
|
||||
database if not already stored.
|
||||
password (str, optional): OpenID password. Used if username is set. Will
|
||||
be read from the database if already stored. Will be stored in the
|
||||
database if not already stored.
|
||||
"""
|
||||
super().__init__(url, db, name, unit)
|
||||
self.client_id = kwargs.get("client_id", "cashu-client")
|
||||
logger.trace(f"client_id: {self.client_id}")
|
||||
self.client_secret = kwargs.get("client_secret", "")
|
||||
self.username = kwargs.get("username")
|
||||
self.password = kwargs.get("password")
|
||||
|
||||
if self.username:
|
||||
if self.password is None:
|
||||
raise Exception("Password must be set if username is set.")
|
||||
self.auth_flow = AuthorizationFlow.PASSWORD
|
||||
else:
|
||||
self.auth_flow = AuthorizationFlow.AUTHORIZATION_CODE
|
||||
# self.auth_flow = AuthorizationFlow.DEVICE_CODE
|
||||
|
||||
self.access_token = kwargs.get("access_token")
|
||||
self.refresh_token = kwargs.get("refresh_token")
|
||||
|
||||
# overload with_db
|
||||
@classmethod
|
||||
async def with_db(cls, *args, **kwargs) -> "WalletAuth":
|
||||
"""Create a new wallet with a database.
|
||||
Keyword arguments:
|
||||
url (str): Mint url.
|
||||
db (str): Wallet db location.
|
||||
name (str, optional): Wallet name. Defaults to "auth".
|
||||
username (str, optional): OpenID username. When set, the username and
|
||||
password flow will be used to authenticate. If a username is already
|
||||
stored in the database, it will be used. Will be stored in the
|
||||
database if not already stored.
|
||||
password (str, optional): OpenID password. Used if username is set. Will
|
||||
be read from the database if already stored. Will be stored in the
|
||||
database if not already stored.
|
||||
client_id (str, optional): OpenID client id. Defaults to "cashu-client".
|
||||
client_secret (str, optional): OpenID client secret. Defaults to "".
|
||||
access_token (str, optional): OpenID access token. Defaults to None.
|
||||
refresh_token (str, optional): OpenID refresh token. Defaults to None.
|
||||
Returns:
|
||||
WalletAuth: WalletAuth instance.
|
||||
"""
|
||||
|
||||
url: str = kwargs.get("url", "")
|
||||
db = kwargs.get("db", "")
|
||||
kwargs["name"] = kwargs.get("name", "auth")
|
||||
name = kwargs["name"]
|
||||
username = kwargs.get("username")
|
||||
password = kwargs.get("password")
|
||||
wallet_db = Database(name, db)
|
||||
|
||||
# run migrations etc
|
||||
kwargs.update(dict(skip_db_read=True))
|
||||
await super().with_db(*args, **kwargs)
|
||||
|
||||
# the wallet might not have been created yet
|
||||
# if it was though, we load the username, password,
|
||||
# access token and refresh token from the database
|
||||
try:
|
||||
mint_db = await get_mint_by_url(wallet_db, url)
|
||||
if mint_db:
|
||||
kwargs.update(
|
||||
{
|
||||
"username": username or mint_db.username,
|
||||
"password": password or mint_db.password,
|
||||
"access_token": mint_db.access_token,
|
||||
"refresh_token": mint_db.refresh_token,
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return cls(*args, **kwargs)
|
||||
|
||||
async def init_auth_wallet(
|
||||
self,
|
||||
mint_info: Optional[MintInfo] = None,
|
||||
mint_auth_proofs=True,
|
||||
force_auth=False,
|
||||
) -> bool:
|
||||
"""Initialize authentication wallet.
|
||||
|
||||
Args:
|
||||
mint_info (MintInfo, optional): Mint information. If not provided, we load the
|
||||
info from the database or the mint directly. Defaults to None.
|
||||
mint_auth_proofs (bool, optional): Whether to mint auth proofs if necessary.
|
||||
Defaults to True.
|
||||
force_auth (bool, optional): Whether to force authentication. Defaults to False.
|
||||
|
||||
Returns:
|
||||
bool: False if the mint does not require clear auth. True otherwise.
|
||||
"""
|
||||
if mint_info:
|
||||
self.mint_info = mint_info
|
||||
await self.load_mint_info()
|
||||
if not self.mint_info.requires_clear_auth():
|
||||
return False
|
||||
|
||||
# Use the blind auth api_prefix for all following requests
|
||||
await self.load_mint_keysets()
|
||||
await self.activate_keyset()
|
||||
await self.load_proofs()
|
||||
|
||||
self.oidc_discovery_url = self.mint_info.oidc_discovery_url()
|
||||
self.client_id = self.mint_info.oidc_client_id()
|
||||
|
||||
# Initialize OpenIDClient
|
||||
self.oidc_client = OpenIDClient(
|
||||
discovery_url=self.oidc_discovery_url,
|
||||
client_id=self.client_id,
|
||||
client_secret=self.client_secret,
|
||||
auth_flow=self.auth_flow,
|
||||
username=self.username,
|
||||
password=self.password,
|
||||
access_token=self.access_token,
|
||||
refresh_token=self.refresh_token,
|
||||
)
|
||||
# Authenticate using OpenIDClient
|
||||
await self.oidc_client.initialize()
|
||||
await self.oidc_client.authenticate(force_authenticate=force_auth)
|
||||
|
||||
await self.store_username_password()
|
||||
await self.store_clear_auth_token()
|
||||
|
||||
if mint_auth_proofs:
|
||||
await self.mint_blind_auth_min_balance()
|
||||
|
||||
return True
|
||||
|
||||
async def mint_blind_auth_min_balance(self) -> None:
|
||||
"""Mint auth tokens if balance is too low."""
|
||||
MIN_BALANCE = self.mint_info.bat_max_mint
|
||||
|
||||
if self.available_balance < MIN_BALANCE:
|
||||
logger.debug(
|
||||
f"Balance too low. Minting {self.unit.str(MIN_BALANCE)} auth tokens."
|
||||
)
|
||||
try:
|
||||
await self.mint_blind_auth()
|
||||
except Exception as e:
|
||||
logger.error(f"Error minting auth proofs: {str(e)}")
|
||||
|
||||
async def store_username_password(self) -> None:
|
||||
"""Store the username and password in the database."""
|
||||
if self.username and self.password:
|
||||
mint_db = await get_mint_by_url(self.db, self.url)
|
||||
if not mint_db:
|
||||
raise Exception("Mint not found.")
|
||||
if mint_db.username != self.username or mint_db.password != self.password:
|
||||
mint_db.username = self.username
|
||||
mint_db.password = self.password
|
||||
await update_mint(self.db, mint_db)
|
||||
|
||||
async def store_clear_auth_token(self) -> None:
|
||||
"""Store the access and refresh tokens in the database."""
|
||||
access_token = self.oidc_client.access_token
|
||||
refresh_token = self.oidc_client.refresh_token
|
||||
if not access_token or not refresh_token:
|
||||
raise Exception("Access or refresh token not available.")
|
||||
# Store the tokens in the database
|
||||
mint_db = await get_mint_by_url(self.db, self.url)
|
||||
if not mint_db:
|
||||
raise Exception("Mint not found.")
|
||||
if (
|
||||
mint_db.access_token != access_token
|
||||
or mint_db.refresh_token != refresh_token
|
||||
):
|
||||
mint_db.access_token = access_token
|
||||
mint_db.refresh_token = refresh_token
|
||||
await update_mint(self.db, mint_db)
|
||||
|
||||
async def mint_blind_auth(self) -> List[Proof]:
|
||||
# Ensure access token is valid
|
||||
if self.oidc_client.is_token_expired():
|
||||
await self.oidc_client.refresh_access_token()
|
||||
await self.store_clear_auth_token()
|
||||
clear_auth_token = self.oidc_client.access_token
|
||||
if not clear_auth_token:
|
||||
raise Exception("No clear auth token available.")
|
||||
|
||||
amounts = self.mint_info.bat_max_mint * [1] # 1 AUTH tokens
|
||||
secrets = [hashlib.sha256(os.urandom(32)).hexdigest() for _ in amounts]
|
||||
rs = [PrivateKey(privkey=os.urandom(32), raw=True) for _ in amounts]
|
||||
derivation_paths = ["" for _ in amounts]
|
||||
outputs, rs = self._construct_outputs(amounts, secrets, rs)
|
||||
promises = await self.blind_mint_blind_auth(clear_auth_token, outputs)
|
||||
new_proofs = await self._construct_proofs(
|
||||
promises, secrets, rs, derivation_paths
|
||||
)
|
||||
logger.debug(
|
||||
f"Minted {self.unit.str(sum_proofs(new_proofs))} blind auth proofs."
|
||||
)
|
||||
return new_proofs
|
||||
487
cashu/wallet/auth/openid_connect/openid_client.py
Normal file
487
cashu/wallet/auth/openid_connect/openid_client.py
Normal file
@@ -0,0 +1,487 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import base64
|
||||
import secrets
|
||||
import webbrowser
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import httpx
|
||||
import jwt
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class AuthorizationFlow(Enum):
|
||||
AUTHORIZATION_CODE = "authorization_code"
|
||||
PASSWORD = "password"
|
||||
DEVICE_CODE = "urn:ietf:params:oauth:grant-type:device_code"
|
||||
|
||||
|
||||
class OpenIDClient:
|
||||
"""OpenID Connect client for authentication."""
|
||||
|
||||
oidc_config: Dict[str, Any]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
discovery_url: str,
|
||||
client_id: str,
|
||||
client_secret: str = "",
|
||||
auth_flow: Optional[AuthorizationFlow] = None,
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
access_token: Optional[str] = None,
|
||||
refresh_token: Optional[str] = None,
|
||||
token_expiration_time: Optional[datetime] = None,
|
||||
device_code: Optional[str] = None,
|
||||
) -> None:
|
||||
self.discovery_url: str = discovery_url
|
||||
self.client_id: str = client_id
|
||||
self.client_secret: str = client_secret
|
||||
self.auth_flow: Optional[AuthorizationFlow] = auth_flow
|
||||
self.username: Optional[str] = username
|
||||
self.password: Optional[str] = password
|
||||
self.access_token: Optional[str] = access_token
|
||||
self.refresh_token: Optional[str] = refresh_token
|
||||
self.token_expiration_time: Optional[datetime] = token_expiration_time
|
||||
self.device_code: Optional[str] = device_code
|
||||
|
||||
self.redirect_uri: str = "http://localhost:33388/callback"
|
||||
self.expected_state: str = secrets.token_urlsafe(16)
|
||||
self.token_response: Dict[str, Any] = {}
|
||||
self.token_event: asyncio.Event = asyncio.Event()
|
||||
self.token_endpoint: str = ""
|
||||
self.authorization_endpoint: str = ""
|
||||
self.introspection_endpoint: Optional[str] = None
|
||||
self.revocation_endpoint: Optional[str] = None
|
||||
self.device_authorization_endpoint: Optional[str] = None
|
||||
self.templates: Jinja2Templates = Jinja2Templates(
|
||||
directory="cashu/wallet/auth/openid_connect/templates"
|
||||
)
|
||||
|
||||
self.app: FastAPI = FastAPI()
|
||||
self.app.state.client = self # Store self in app state
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize the client asynchronously."""
|
||||
await self.fetch_oidc_configuration()
|
||||
await self.determine_auth_flow()
|
||||
|
||||
async def determine_auth_flow(self) -> AuthorizationFlow:
|
||||
"""Determine the authentication flow to use from the oidc configuration.
|
||||
Supported flows are chosen in the following order:
|
||||
- device_code
|
||||
- authorization_code
|
||||
- password
|
||||
"""
|
||||
if not hasattr(self, "oidc_config"):
|
||||
raise ValueError(
|
||||
"OIDC configuration not loaded. Call fetch_oidc_configuration first."
|
||||
)
|
||||
|
||||
supported_flows = self.oidc_config.get("grant_types_supported", [])
|
||||
|
||||
# if self.auth_flow is already set, check if it is supported
|
||||
if self.auth_flow:
|
||||
if self.auth_flow.value not in supported_flows:
|
||||
raise ValueError(
|
||||
f"Authentication flow {self.auth_flow.value} not supported by the OIDC configuration."
|
||||
)
|
||||
return self.auth_flow
|
||||
|
||||
if AuthorizationFlow.DEVICE_CODE.value in supported_flows:
|
||||
self.auth_flow = AuthorizationFlow.DEVICE_CODE
|
||||
elif AuthorizationFlow.AUTHORIZATION_CODE.value in supported_flows:
|
||||
self.auth_flow = AuthorizationFlow.AUTHORIZATION_CODE
|
||||
elif AuthorizationFlow.PASSWORD.value in supported_flows:
|
||||
self.auth_flow = AuthorizationFlow.PASSWORD
|
||||
else:
|
||||
raise ValueError(
|
||||
"No supported authentication flows found in the OIDC configuration."
|
||||
)
|
||||
|
||||
logger.debug(f"Determined authentication flow: {self.auth_flow.value}")
|
||||
return self.auth_flow
|
||||
|
||||
async def fetch_oidc_configuration(self) -> None:
|
||||
"""Fetch OIDC configuration from the discovery URL."""
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(self.discovery_url)
|
||||
response.raise_for_status()
|
||||
self.oidc_config = response.json()
|
||||
self.authorization_endpoint = self.oidc_config.get( # type: ignore
|
||||
"authorization_endpoint"
|
||||
)
|
||||
self.token_endpoint = self.oidc_config.get("token_endpoint") # type: ignore
|
||||
self.introspection_endpoint = self.oidc_config.get(
|
||||
"introspection_endpoint"
|
||||
)
|
||||
self.revocation_endpoint = self.oidc_config.get("revocation_endpoint")
|
||||
self.device_authorization_endpoint = self.oidc_config.get(
|
||||
"device_authorization_endpoint"
|
||||
)
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"Failed to get OpenID configuration: {e}")
|
||||
raise
|
||||
|
||||
async def handle_callback(self, request: Request) -> HTMLResponse:
|
||||
"""Endpoint to handle the redirect from the OpenID provider."""
|
||||
params = request.query_params
|
||||
if "error" in params:
|
||||
error_str = params["error"]
|
||||
if "error_description" in params:
|
||||
error_str += f": {params['error_description']}"
|
||||
return self.templates.TemplateResponse(
|
||||
"error.html",
|
||||
{"request": request, "error": error_str},
|
||||
)
|
||||
elif "code" in params and "state" in params:
|
||||
code: str = params["code"]
|
||||
state: str = params["state"]
|
||||
if state != self.expected_state:
|
||||
raise HTTPException(status_code=400, detail="Invalid state parameter")
|
||||
token_data: Dict[str, Any] = await self.exchange_code_for_token(code)
|
||||
self.update_token_data(token_data)
|
||||
response = self.render_success_page(request, token_data)
|
||||
self.token_event.set() # Signal that the token has been received
|
||||
return response
|
||||
else:
|
||||
return self.templates.TemplateResponse(
|
||||
"error.html",
|
||||
{"request": request, "error": "Missing 'code' or 'state' parameter"},
|
||||
)
|
||||
|
||||
async def exchange_code_for_token(self, code: str) -> Dict[str, Any]:
|
||||
"""Exchange the authorization code for tokens."""
|
||||
data: Dict[str, str] = {
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
}
|
||||
headers: Dict[str, str] = {}
|
||||
if self.client_secret:
|
||||
# Use HTTP Basic Auth if client_secret is provided
|
||||
basic_auth: str = f"{self.client_id}:{self.client_secret}"
|
||||
basic_auth_bytes: bytes = basic_auth.encode("ascii")
|
||||
basic_auth_b64: str = base64.b64encode(basic_auth_bytes).decode("ascii")
|
||||
headers["Authorization"] = f"Basic {basic_auth_b64}"
|
||||
else:
|
||||
# Include client_id in the POST body for public clients
|
||||
data["client_id"] = self.client_id
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response = await client.post(
|
||||
self.token_endpoint, data=data, headers=headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"HTTP error occurred during token exchange: {e}")
|
||||
self.token_event.set()
|
||||
return {}
|
||||
|
||||
async def run_server(self) -> None:
|
||||
"""Run the FastAPI server."""
|
||||
config = uvicorn.Config(
|
||||
self.app, host="127.0.0.1", port=33388, log_level="error"
|
||||
)
|
||||
self.server = uvicorn.Server(config)
|
||||
await self.server.serve()
|
||||
|
||||
async def shutdown_server(self) -> None:
|
||||
"""Shut down the uvicorn server."""
|
||||
self.server.should_exit = True
|
||||
|
||||
async def authenticate(self, force_authenticate: bool = False) -> None:
|
||||
"""Start the authentication process."""
|
||||
need_authenticate = force_authenticate
|
||||
if self.access_token and self.refresh_token:
|
||||
# We have a token and a refresh token, check if token is expired
|
||||
if self.is_token_expired():
|
||||
try:
|
||||
await self.refresh_access_token()
|
||||
except httpx.HTTPError:
|
||||
logger.debug("Failed to refresh token.")
|
||||
need_authenticate = True
|
||||
else:
|
||||
logger.debug("Using existing access token.")
|
||||
else:
|
||||
need_authenticate = True
|
||||
|
||||
if need_authenticate:
|
||||
if self.auth_flow == AuthorizationFlow.AUTHORIZATION_CODE:
|
||||
await self.authenticate_with_authorization_code()
|
||||
elif self.auth_flow == AuthorizationFlow.PASSWORD:
|
||||
await self.authenticate_with_password()
|
||||
elif self.auth_flow == AuthorizationFlow.DEVICE_CODE:
|
||||
await self.authenticate_with_device_code()
|
||||
else:
|
||||
raise ValueError(f"Unknown authentication flow: {self.auth_flow}")
|
||||
|
||||
def is_token_expired(self) -> bool:
|
||||
"""Check if the access token is expired."""
|
||||
if not self.access_token:
|
||||
raise ValueError("Access token is not set.")
|
||||
decoded = jwt.decode(self.access_token, options={"verify_signature": False})
|
||||
exp = decoded.get("exp")
|
||||
if not exp:
|
||||
return False
|
||||
return datetime.now() >= datetime.fromtimestamp(exp) - timedelta(minutes=1)
|
||||
|
||||
async def refresh_access_token(self) -> None:
|
||||
"""Refresh the access token using the refresh token."""
|
||||
data = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": self.refresh_token,
|
||||
"client_id": self.client_id,
|
||||
}
|
||||
if self.client_secret:
|
||||
data["client_secret"] = self.client_secret
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response = await client.post(self.token_endpoint, data=data)
|
||||
response.raise_for_status()
|
||||
token_data = response.json()
|
||||
self.update_token_data(token_data)
|
||||
logger.info("Token refreshed successfully.")
|
||||
except httpx.HTTPError as e:
|
||||
logger.debug(f"Failed to refresh token: {e}")
|
||||
raise
|
||||
|
||||
async def authenticate_with_authorization_code(self) -> None:
|
||||
"""Authenticate using the authorization code flow."""
|
||||
|
||||
# Set up the route handlers
|
||||
@self.app.get("/callback", response_class=HTMLResponse)
|
||||
async def handle_callback(request: Request) -> Any:
|
||||
print("CALLBACK")
|
||||
return await self.handle_callback(request)
|
||||
|
||||
# Build the authorization URL
|
||||
params = {
|
||||
"response_type": "code",
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"scope": "openid",
|
||||
"state": self.expected_state,
|
||||
}
|
||||
auth_url = f"{self.authorization_endpoint}?{urlencode(params)}"
|
||||
|
||||
# Start the web server as an asyncio task
|
||||
server_task = asyncio.create_task(self.run_server())
|
||||
|
||||
# Open the browser or print the URL for the user
|
||||
logger.info("Please open the following URL in your browser to authenticate:")
|
||||
logger.info(auth_url)
|
||||
webbrowser.open(auth_url)
|
||||
|
||||
# Wait for the token response
|
||||
logger.info("Waiting for authentication...")
|
||||
await self.token_event.wait()
|
||||
|
||||
# Use the retrieved tokens
|
||||
if self.token_response:
|
||||
logger.info("Authentication successful!")
|
||||
# logger.info("Token response:")
|
||||
# logger.info(self.token_response)
|
||||
else:
|
||||
logger.error("Authentication failed.")
|
||||
|
||||
# Signal the server to shut down
|
||||
await self.shutdown_server()
|
||||
|
||||
# Wait for the server task to finish
|
||||
await server_task
|
||||
|
||||
def update_token_data(self, token_data: Dict[str, Any]) -> None:
|
||||
self.token_response.update(token_data)
|
||||
self.access_token = token_data.get("access_token")
|
||||
self.refresh_token = token_data.get("refresh_token")
|
||||
if not self.access_token or not self.refresh_token:
|
||||
raise ValueError(
|
||||
"Access token or refresh token not found in token response."
|
||||
)
|
||||
expires_in = token_data.get("expires_in")
|
||||
if expires_in:
|
||||
self.token_expiration_time = datetime.utcnow() + timedelta(
|
||||
seconds=int(expires_in)
|
||||
)
|
||||
refresh_expires_in = token_data.get("refresh_expires_in")
|
||||
if refresh_expires_in:
|
||||
logger.debug(f"Refresh token expires in {refresh_expires_in} seconds.")
|
||||
self.refresh_token_expiration_time = datetime.utcnow() + timedelta(
|
||||
seconds=int(refresh_expires_in)
|
||||
)
|
||||
|
||||
async def authenticate_with_password(self) -> None:
|
||||
"""Authenticate using the resource owner password credentials flow."""
|
||||
if not self.username or not self.password:
|
||||
raise ValueError(
|
||||
'Username and password must be provided. To set a password use: "cashu auth -p"'
|
||||
)
|
||||
data = {
|
||||
"grant_type": "password",
|
||||
"client_id": self.client_id,
|
||||
"username": self.username,
|
||||
"password": self.password,
|
||||
"scope": "openid",
|
||||
}
|
||||
if self.client_secret:
|
||||
data["client_secret"] = self.client_secret
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response = await client.post(self.token_endpoint, data=data)
|
||||
response.raise_for_status()
|
||||
token_data = response.json()
|
||||
self.update_token_data(token_data)
|
||||
logger.info("Authentication successful!")
|
||||
# logger.info("Token response:")
|
||||
# logger.info(self.token_response)
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"Failed to obtain token: {e}")
|
||||
raise
|
||||
|
||||
def render_success_page(
|
||||
self, request: Request, token_data: Dict[str, Any]
|
||||
) -> HTMLResponse:
|
||||
"""Render an HTML page with a green check mark and user information."""
|
||||
context = {"request": request, "token_data": token_data}
|
||||
return self.templates.TemplateResponse("success.html", context)
|
||||
|
||||
async def authenticate_with_device_code(self) -> None:
|
||||
"""Authenticate using the device code flow."""
|
||||
if not self.device_authorization_endpoint:
|
||||
raise ValueError("Device authorization endpoint not available.")
|
||||
|
||||
data = {
|
||||
"client_id": self.client_id,
|
||||
"scope": "openid",
|
||||
}
|
||||
if self.client_secret:
|
||||
data["client_secret"] = self.client_secret
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response = await client.post(
|
||||
self.device_authorization_endpoint, data=data
|
||||
)
|
||||
response.raise_for_status()
|
||||
device_data = response.json()
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"Failed to obtain device code: {e}")
|
||||
raise
|
||||
|
||||
# Extract device code data
|
||||
device_code = device_data.get("device_code")
|
||||
user_code = device_data.get("user_code")
|
||||
verification_uri = device_data.get("verification_uri")
|
||||
verification_uri_complete = device_data.get("verification_uri_complete")
|
||||
expires_in = device_data.get("expires_in")
|
||||
interval = device_data.get("interval", 5) # Default interval is 5 seconds
|
||||
|
||||
if not device_code or not verification_uri:
|
||||
raise ValueError("Invalid response from device authorization endpoint.")
|
||||
|
||||
# Display instructions to the user and open the browser
|
||||
if verification_uri_complete:
|
||||
logger.info("Opening browser to complete authorization...")
|
||||
logger.info(verification_uri_complete)
|
||||
webbrowser.open(verification_uri_complete)
|
||||
else:
|
||||
logger.info("Please visit the following URL to authorize:")
|
||||
logger.info(verification_uri)
|
||||
logger.info(f"Enter the user code: {user_code}")
|
||||
# Construct the URL for the user to enter the code
|
||||
full_verification_uri = f"{verification_uri}?user_code={user_code}"
|
||||
webbrowser.open(full_verification_uri)
|
||||
|
||||
# Start polling the token endpoint
|
||||
start_time = datetime.now()
|
||||
expires_at = start_time + timedelta(seconds=expires_in)
|
||||
token_data = None
|
||||
while datetime.now() < expires_at:
|
||||
await asyncio.sleep(interval)
|
||||
data = {
|
||||
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
|
||||
"device_code": device_code,
|
||||
"client_id": self.client_id,
|
||||
}
|
||||
if self.client_secret:
|
||||
data["client_secret"] = self.client_secret
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response = await client.post(self.token_endpoint, data=data)
|
||||
if response.status_code == 200:
|
||||
# Successful response
|
||||
token_data = response.json()
|
||||
self.update_token_data(token_data)
|
||||
logger.info("Authentication successful!")
|
||||
break
|
||||
else:
|
||||
error_data = response.json()
|
||||
error = error_data.get("error")
|
||||
if error == "authorization_pending":
|
||||
# Continue polling
|
||||
pass
|
||||
elif error == "slow_down":
|
||||
# Increase interval by 5 seconds
|
||||
interval += 5
|
||||
elif error == "access_denied":
|
||||
logger.error("Access denied by user.")
|
||||
break
|
||||
elif error == "expired_token":
|
||||
logger.error("Device code has expired.")
|
||||
break
|
||||
else:
|
||||
logger.error(f"Error during polling: {error}")
|
||||
break
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"HTTP error during token polling: {e}")
|
||||
break
|
||||
else:
|
||||
logger.error("Device code has expired before authorization.")
|
||||
raise Exception("Device code expired")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="OpenID Connect Authentication Client")
|
||||
parser.add_argument("discovery_url", help="OpenID Connect Discovery URL")
|
||||
parser.add_argument("client_id", help="Client ID")
|
||||
parser.add_argument("--client_secret", help="Client Secret", default="")
|
||||
parser.add_argument(
|
||||
"--auth_flow",
|
||||
choices=["authorization_code", "password", "device_code"],
|
||||
default="authorization_code",
|
||||
help="Authentication flow to use",
|
||||
)
|
||||
parser.add_argument("--username", help="Username for password flow")
|
||||
parser.add_argument("--password", help="Password for password flow")
|
||||
parser.add_argument("--access_token", help="Stored access token")
|
||||
parser.add_argument("--refresh_token", help="Stored refresh token")
|
||||
parser.add_argument("--device_code", help="Device code for device flow")
|
||||
args = parser.parse_args()
|
||||
|
||||
client = OpenIDClient(
|
||||
discovery_url=args.discovery_url,
|
||||
client_id=args.client_id,
|
||||
client_secret=args.client_secret,
|
||||
auth_flow=AuthorizationFlow(args.auth_flow),
|
||||
username=args.username,
|
||||
password=args.password,
|
||||
access_token=args.access_token,
|
||||
refresh_token=args.refresh_token,
|
||||
device_code=args.device_code,
|
||||
)
|
||||
await client.initialize()
|
||||
await client.authenticate()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
26
cashu/wallet/auth/openid_connect/templates/error.html
Normal file
26
cashu/wallet/auth/openid_connect/templates/error.html
Normal file
@@ -0,0 +1,26 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Authentication Error</title>
|
||||
<style>
|
||||
.error {
|
||||
text-align: center;
|
||||
margin-top: 50px;
|
||||
}
|
||||
.error h1 {
|
||||
color: red;
|
||||
}
|
||||
.crossmark {
|
||||
font-size: 100px;
|
||||
color: red;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="error">
|
||||
<div class="crossmark">✖</div>
|
||||
<h1>Authentication Error</h1>
|
||||
<p>{{ error }}</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
33
cashu/wallet/auth/openid_connect/templates/success.html
Normal file
33
cashu/wallet/auth/openid_connect/templates/success.html
Normal file
@@ -0,0 +1,33 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Authentication Successful</title>
|
||||
<style>
|
||||
.success {
|
||||
text-align: center;
|
||||
margin-top: 50px;
|
||||
}
|
||||
.success h1 {
|
||||
color: green;
|
||||
}
|
||||
.checkmark {
|
||||
font-size: 100px;
|
||||
color: green;
|
||||
}
|
||||
.token-data {
|
||||
margin-top: 20px;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="success">
|
||||
<div class="checkmark">✔</div>
|
||||
<h1>Authentication Successful</h1>
|
||||
<h3>You can close this window now.</h3>
|
||||
<!-- <div class="token-data">
|
||||
<h3>User Information:</h3>
|
||||
<pre>{{ token_data | tojson(indent=2) }}</pre>
|
||||
</div> -->
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
@@ -1,6 +1,7 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import asyncio
|
||||
import getpass
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
@@ -41,6 +42,7 @@ from ...wallet.crud import (
|
||||
)
|
||||
from ...wallet.wallet import Wallet as Wallet
|
||||
from ..api.api_server import start_api_server
|
||||
from ..auth.auth import WalletAuth
|
||||
from ..cli.cli_helpers import (
|
||||
get_mint_wallet,
|
||||
get_unit_wallet,
|
||||
@@ -84,6 +86,49 @@ def coro(f):
|
||||
return wrapper
|
||||
|
||||
|
||||
def init_auth_wallet(func):
|
||||
"""Decorator to pass auth_db and auth_keyset_id to the Wallet object."""
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
ctx = args[0] # Assuming the first argument is 'ctx'
|
||||
wallet: Wallet = ctx.obj["WALLET"]
|
||||
db_location = wallet.db.db_location
|
||||
|
||||
auth_wallet = await WalletAuth.with_db(
|
||||
url=ctx.obj["HOST"],
|
||||
db=db_location,
|
||||
)
|
||||
|
||||
requires_auth = await auth_wallet.init_auth_wallet(wallet.mint_info)
|
||||
|
||||
if not requires_auth:
|
||||
logger.debug("Mint does not require clear auth.")
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
# Pass auth_db and auth_keyset_id to the wallet object
|
||||
wallet.auth_db = auth_wallet.db
|
||||
wallet.auth_keyset_id = auth_wallet.keyset_id
|
||||
# pass the mint_info so the wallet doesn't need to re-fetch it
|
||||
wallet.mint_info = auth_wallet.mint_info
|
||||
|
||||
# Pass the auth_wallet to context
|
||||
args[0].obj["AUTH_WALLET"] = auth_wallet
|
||||
|
||||
# Proceed to the original function
|
||||
ret = await func(*args, **kwargs)
|
||||
|
||||
if settings.debug:
|
||||
await auth_wallet.load_proofs(reload=True)
|
||||
logger.debug(
|
||||
f"Auth balance: {auth_wallet.unit.str(auth_wallet.available_balance)}"
|
||||
)
|
||||
|
||||
return ret
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@click.group(cls=NaturalOrderGroup)
|
||||
@click.option(
|
||||
"--host",
|
||||
@@ -176,6 +221,10 @@ async def cli(ctx: Context, host: str, walletname: str, unit: str, tests: bool):
|
||||
ctx.obj["HOST"], db_path, name=walletname, unit=unit
|
||||
)
|
||||
|
||||
# if we have never seen this mint before, we load its information
|
||||
if not wallet.mint_info:
|
||||
await wallet.load_mint()
|
||||
|
||||
assert wallet, "Wallet not found."
|
||||
ctx.obj["WALLET"] = wallet
|
||||
|
||||
@@ -205,6 +254,7 @@ async def cli(ctx: Context, host: str, walletname: str, unit: str, tests: bool):
|
||||
)
|
||||
@click.pass_context
|
||||
@coro
|
||||
@init_auth_wallet
|
||||
async def pay(
|
||||
ctx: Context, invoice: str, amount: Optional[int] = None, yes: bool = False
|
||||
):
|
||||
@@ -291,6 +341,7 @@ async def pay(
|
||||
)
|
||||
@click.pass_context
|
||||
@coro
|
||||
@init_auth_wallet
|
||||
async def invoice(
|
||||
ctx: Context,
|
||||
amount: float,
|
||||
@@ -451,6 +502,7 @@ async def invoice(
|
||||
@cli.command("swap", help="Swap funds between mints.")
|
||||
@click.pass_context
|
||||
@coro
|
||||
@init_auth_wallet
|
||||
async def swap(ctx: Context):
|
||||
print("Select the mint to swap from:")
|
||||
outgoing_wallet: Wallet = await get_mint_wallet(ctx, force_select=True)
|
||||
@@ -621,8 +673,9 @@ async def balance(ctx: Context, verbose):
|
||||
)
|
||||
@click.pass_context
|
||||
@coro
|
||||
@init_auth_wallet
|
||||
async def send_command(
|
||||
ctx,
|
||||
ctx: Context,
|
||||
amount: int,
|
||||
memo: str,
|
||||
nostr: str,
|
||||
@@ -668,6 +721,7 @@ async def send_command(
|
||||
)
|
||||
@click.pass_context
|
||||
@coro
|
||||
@init_auth_wallet
|
||||
async def receive_cli(
|
||||
ctx: Context,
|
||||
token: str,
|
||||
@@ -685,6 +739,8 @@ async def receive_cli(
|
||||
mint_url,
|
||||
os.path.join(settings.cashu_dir, wallet.name),
|
||||
unit=token_obj.unit,
|
||||
auth_db=wallet.auth_db.db_location if wallet.auth_db else None,
|
||||
auth_keyset_id=wallet.auth_keyset_id,
|
||||
)
|
||||
await verify_mint(mint_wallet, mint_url)
|
||||
receive_wallet = await receive(mint_wallet, token_obj)
|
||||
@@ -830,7 +886,7 @@ async def pending(ctx: Context, legacy, number: int, offset: int):
|
||||
@cli.command("lock", help="Generate receiving lock.")
|
||||
@click.pass_context
|
||||
@coro
|
||||
async def lock(ctx):
|
||||
async def lock(ctx: Context):
|
||||
wallet: Wallet = ctx.obj["WALLET"]
|
||||
|
||||
pubkey = await wallet.create_p2pk_pubkey()
|
||||
@@ -851,7 +907,7 @@ async def lock(ctx):
|
||||
@cli.command("locks", help="Show unused receiving locks.")
|
||||
@click.pass_context
|
||||
@coro
|
||||
async def locks(ctx):
|
||||
async def locks(ctx: Context):
|
||||
wallet: Wallet = ctx.obj["WALLET"]
|
||||
# P2PK lock
|
||||
pubkey = await wallet.create_p2pk_pubkey()
|
||||
@@ -899,7 +955,7 @@ async def locks(ctx):
|
||||
)
|
||||
@click.pass_context
|
||||
@coro
|
||||
async def invoices(ctx, paid: bool, unpaid: bool, pending: bool, mint: bool):
|
||||
async def invoices(ctx: Context, paid: bool, unpaid: bool, pending: bool, mint: bool):
|
||||
wallet: Wallet = ctx.obj["WALLET"]
|
||||
|
||||
if paid and unpaid:
|
||||
@@ -1001,7 +1057,7 @@ async def invoices(ctx, paid: bool, unpaid: bool, pending: bool, mint: bool):
|
||||
@cli.command("wallets", help="List of all available wallets.")
|
||||
@click.pass_context
|
||||
@coro
|
||||
async def wallets(ctx):
|
||||
async def wallets(ctx: Context):
|
||||
# list all directories
|
||||
wallets = [
|
||||
d for d in listdir(settings.cashu_dir) if isdir(join(settings.cashu_dir, d))
|
||||
@@ -1013,7 +1069,7 @@ async def wallets(ctx):
|
||||
for w in wallets:
|
||||
wallet = Wallet(ctx.obj["HOST"], os.path.join(settings.cashu_dir, w))
|
||||
try:
|
||||
await wallet.load_proofs()
|
||||
await wallet.load_proofs(reload=True, all_keysets=True)
|
||||
if wallet.proofs and len(wallet.proofs):
|
||||
active_wallet = False
|
||||
if w == ctx.obj["WALLET_NAME"]:
|
||||
@@ -1031,9 +1087,10 @@ async def wallets(ctx):
|
||||
@cli.command("info", help="Information about Cashu wallet.")
|
||||
@click.option("--mint", default=False, is_flag=True, help="Fetch mint information.")
|
||||
@click.option("--mnemonic", default=False, is_flag=True, help="Show your mnemonic.")
|
||||
@click.option("--reload", default=False, is_flag=True, help="Reload mint info.")
|
||||
@click.pass_context
|
||||
@coro
|
||||
async def info(ctx: Context, mint: bool, mnemonic: bool):
|
||||
async def info(ctx: Context, mint: bool, mnemonic: bool, reload: bool):
|
||||
wallet: Wallet = ctx.obj["WALLET"]
|
||||
await wallet.load_keysets_from_db(unit=None)
|
||||
|
||||
@@ -1057,7 +1114,11 @@ async def info(ctx: Context, mint: bool, mnemonic: bool):
|
||||
if mint:
|
||||
wallet.url = mint_url
|
||||
try:
|
||||
mint_info: dict = (await wallet.load_mint_info()).dict()
|
||||
mint_info_obj = await wallet.load_mint_info(reload)
|
||||
if not mint_info_obj:
|
||||
print(" - Mint information not available.")
|
||||
continue
|
||||
mint_info = mint_info_obj.dict()
|
||||
if mint_info:
|
||||
print(f" - Mint name: {mint_info['name']}")
|
||||
if mint_info.get("description"):
|
||||
@@ -1126,6 +1187,7 @@ async def info(ctx: Context, mint: bool, mnemonic: bool):
|
||||
)
|
||||
@click.pass_context
|
||||
@coro
|
||||
@init_auth_wallet
|
||||
async def restore(ctx: Context, to: int, batch: int):
|
||||
wallet: Wallet = ctx.obj["WALLET"]
|
||||
# check if there is already a mnemonic in the database
|
||||
@@ -1160,6 +1222,7 @@ async def restore(ctx: Context, to: int, batch: int):
|
||||
# @click.option("--all", default=False, is_flag=True, help="Execute on all available mints.")
|
||||
@click.pass_context
|
||||
@coro
|
||||
@init_auth_wallet
|
||||
async def selfpay(ctx: Context, all: bool = False):
|
||||
wallet = await get_mint_wallet(ctx, force_select=True)
|
||||
await wallet.load_mint()
|
||||
@@ -1183,3 +1246,46 @@ async def selfpay(ctx: Context, all: bool = False):
|
||||
print(token)
|
||||
token_obj = TokenV4.deserialize(token)
|
||||
await receive(wallet, token_obj)
|
||||
|
||||
|
||||
@cli.command("auth", help="Authenticate with mint.")
|
||||
@click.option("--mint", "-m", default=False, is_flag=True, help="Mint new auth tokens.")
|
||||
@click.option(
|
||||
"--force", "-f", default=False, is_flag=True, help="Force authentication."
|
||||
)
|
||||
@click.option(
|
||||
"--password",
|
||||
"-p",
|
||||
default=False,
|
||||
is_flag=True,
|
||||
help="Use username and password for authentication.",
|
||||
)
|
||||
@click.pass_context
|
||||
@coro
|
||||
async def auth(ctx: Context, mint: bool, force: bool, password: bool):
|
||||
# auth_wallet: WalletAuth = ctx.obj["AUTH_WALLET"]
|
||||
wallet: Wallet = ctx.obj["WALLET"]
|
||||
username = None
|
||||
password_str = None
|
||||
if password:
|
||||
username = input("Enter username: ")
|
||||
password_str = getpass.getpass("Enter password: ")
|
||||
auth_wallet = await WalletAuth.with_db(
|
||||
url=ctx.obj["HOST"],
|
||||
db=wallet.db.db_location,
|
||||
username=username,
|
||||
password=password_str,
|
||||
)
|
||||
|
||||
requires_auth = await auth_wallet.init_auth_wallet(
|
||||
wallet.mint_info, mint_auth_proofs=False, force_auth=force
|
||||
)
|
||||
if not requires_auth:
|
||||
print("Mint does not require authentication.")
|
||||
return
|
||||
|
||||
if mint:
|
||||
new_proofs = await auth_wallet.mint_blind_auth()
|
||||
print(f"Minted {auth_wallet.unit.str(sum_proofs(new_proofs))} auth tokens.")
|
||||
|
||||
print(f"Auth balance: {auth_wallet.unit.str(auth_wallet.available_balance)}")
|
||||
|
||||
@@ -9,6 +9,7 @@ from ..core.base import (
|
||||
MintQuoteState,
|
||||
Proof,
|
||||
WalletKeyset,
|
||||
WalletMint,
|
||||
)
|
||||
from ..core.db import Connection, Database
|
||||
|
||||
@@ -577,3 +578,59 @@ async def store_seed_and_mnemonic(
|
||||
"mnemonic": mnemonic,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def store_mint(
|
||||
db: Database,
|
||||
mint: WalletMint,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> None:
|
||||
await (conn or db).execute(
|
||||
"""
|
||||
INSERT INTO mints
|
||||
(url, info, updated)
|
||||
VALUES (:url, :info, :updated)
|
||||
""",
|
||||
{
|
||||
"url": mint.url,
|
||||
"info": mint.info,
|
||||
"updated": int(time.time()),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def update_mint(
|
||||
db: Database,
|
||||
mint: WalletMint,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> None:
|
||||
await (conn or db).execute(
|
||||
"""
|
||||
UPDATE mints
|
||||
SET info = :info, updated = :updated, access_token = :access_token, refresh_token = :refresh_token, username = :username, password = :password
|
||||
WHERE url = :url
|
||||
""",
|
||||
{
|
||||
"url": mint.url,
|
||||
"info": mint.info,
|
||||
"updated": int(time.time()),
|
||||
"access_token": mint.access_token,
|
||||
"refresh_token": mint.refresh_token,
|
||||
"username": mint.username,
|
||||
"password": mint.password,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def get_mint_by_url(
|
||||
db: Database,
|
||||
url: str,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> Optional[WalletMint]:
|
||||
row = await (conn or db).fetchone(
|
||||
"""
|
||||
SELECT * from mints WHERE url = :url
|
||||
""",
|
||||
{"url": url},
|
||||
)
|
||||
return WalletMint.parse_obj(dict(row)) if row else None
|
||||
|
||||
16
cashu/wallet/errors.py
Normal file
16
cashu/wallet/errors.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class WalletError(Exception):
|
||||
msg: str
|
||||
|
||||
def __init__(self, msg):
|
||||
super().__init__(msg)
|
||||
self.msg = msg
|
||||
|
||||
|
||||
class BalanceTooLowError(WalletError):
|
||||
msg = "Balance too low"
|
||||
|
||||
def __init__(self, msg: Optional[str] = None):
|
||||
super().__init__(msg or self.msg)
|
||||
@@ -52,6 +52,8 @@ async def redeem_TokenV3(wallet: Wallet, token: TokenV3) -> Wallet:
|
||||
t.mint,
|
||||
os.path.join(settings.cashu_dir, wallet.name),
|
||||
unit=token.unit or wallet.unit.name,
|
||||
auth_db=wallet.auth_db.db_location if wallet.auth_db else None,
|
||||
auth_keyset_id=wallet.auth_keyset_id,
|
||||
)
|
||||
keyset_ids = mint_wallet._get_proofs_keyset_ids(t.proofs)
|
||||
logger.trace(f"Keysets in tokens: {' '.join(set(keyset_ids))}")
|
||||
|
||||
@@ -214,7 +214,6 @@ async def m010_add_ids_to_proofs_and_out_to_invoices(db: Database):
|
||||
Columns that store mint and melt id for proofs and invoices.
|
||||
"""
|
||||
async with db.connect() as conn:
|
||||
print("Running wallet migrations")
|
||||
await conn.execute("ALTER TABLE proofs ADD COLUMN mint_id TEXT")
|
||||
await conn.execute("ALTER TABLE proofs ADD COLUMN melt_id TEXT")
|
||||
|
||||
@@ -287,7 +286,8 @@ async def m013_add_mint_and_melt_quote_tables(db: Database):
|
||||
"""
|
||||
)
|
||||
|
||||
async def m013_add_key_to_mint_quote_table(db: Database):
|
||||
|
||||
async def m014_add_key_to_mint_quote_table(db: Database):
|
||||
async with db.connect() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
@@ -295,3 +295,21 @@ async def m013_add_key_to_mint_quote_table(db: Database):
|
||||
ADD COLUMN privkey TEXT DEFAULT NULL;
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
async def m015_add_mints_table(db: Database):
|
||||
async with db.connect() as conn:
|
||||
await conn.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS mints (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
url TEXT NOT NULL,
|
||||
info TEXT NOT NULL,
|
||||
updated TIMESTAMP DEFAULT {db.timestamp_now},
|
||||
access_token TEXT,
|
||||
refresh_token TEXT,
|
||||
username TEXT,
|
||||
password TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from typing import Dict, List, Protocol
|
||||
from typing import Dict, List, Optional, Protocol
|
||||
|
||||
import httpx
|
||||
|
||||
from ..core.base import Proof, Unit, WalletKeyset
|
||||
from ..core.crypto.secp import PrivateKey
|
||||
from ..core.db import Database
|
||||
from ..core.mint_info import MintInfo
|
||||
|
||||
|
||||
class SupportsPrivateKey(Protocol):
|
||||
@@ -28,3 +29,9 @@ class SupportsHttpxClient(Protocol):
|
||||
|
||||
class SupportsMintURL(Protocol):
|
||||
url: str
|
||||
|
||||
|
||||
class SupportsAuth(Protocol):
|
||||
auth_db: Optional[Database] = None
|
||||
auth_keyset_id: Optional[str] = None
|
||||
mint_info: Optional[MintInfo] = None
|
||||
|
||||
@@ -47,7 +47,7 @@ class SubscriptionManager:
|
||||
|
||||
try:
|
||||
msg = JSONRPCNotification.parse_raw(message)
|
||||
logger.debug(f"Received notification: {msg}")
|
||||
logger.trace(f"Received notification: {msg}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing notification: {e}")
|
||||
return
|
||||
|
||||
@@ -12,6 +12,7 @@ from pydantic import ValidationError
|
||||
from cashu.wallet.crud import get_bolt11_melt_quote
|
||||
|
||||
from ..core.base import (
|
||||
AuthProof,
|
||||
BlindedMessage,
|
||||
BlindedSignature,
|
||||
MeltQuoteState,
|
||||
@@ -29,6 +30,8 @@ from ..core.models import (
|
||||
KeysetsResponse,
|
||||
KeysetsResponseKeyset,
|
||||
KeysResponse,
|
||||
PostAuthBlindMintRequest,
|
||||
PostAuthBlindMintResponse,
|
||||
PostCheckStateRequest,
|
||||
PostCheckStateResponse,
|
||||
PostMeltQuoteRequest,
|
||||
@@ -47,8 +50,16 @@ from ..core.models import (
|
||||
)
|
||||
from ..core.settings import settings
|
||||
from ..tor.tor import TorProxy
|
||||
from .crud import (
|
||||
get_proofs,
|
||||
invalidate_proof,
|
||||
)
|
||||
from .protocols import SupportsAuth
|
||||
from .wallet_deprecated import LedgerAPIDeprecated
|
||||
|
||||
GET = "GET"
|
||||
POST = "POST"
|
||||
|
||||
|
||||
def async_set_httpx_client(func):
|
||||
"""
|
||||
@@ -78,7 +89,7 @@ def async_set_httpx_client(func):
|
||||
verify=not settings.debug,
|
||||
proxies=proxies_dict, # type: ignore
|
||||
headers=headers_dict,
|
||||
base_url=self.url,
|
||||
base_url=self.url.rstrip("/"),
|
||||
timeout=None if settings.debug else 60,
|
||||
)
|
||||
return await func(self, *args, **kwargs)
|
||||
@@ -99,10 +110,10 @@ def async_ensure_mint_loaded(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
class LedgerAPI(LedgerAPIDeprecated):
|
||||
class LedgerAPI(LedgerAPIDeprecated, SupportsAuth):
|
||||
tor: TorProxy
|
||||
db: Database # we need the db for melt_deprecated
|
||||
httpx: httpx.AsyncClient
|
||||
api_prefix = "v1"
|
||||
|
||||
def __init__(self, url: str, db: Database):
|
||||
self.url = url
|
||||
@@ -128,7 +139,6 @@ class LedgerAPI(LedgerAPIDeprecated):
|
||||
try:
|
||||
resp_dict = resp.json()
|
||||
except json.JSONDecodeError:
|
||||
# if we can't decode the response, raise for status
|
||||
resp.raise_for_status()
|
||||
return
|
||||
if "detail" in resp_dict:
|
||||
@@ -137,9 +147,49 @@ class LedgerAPI(LedgerAPIDeprecated):
|
||||
if "code" in resp_dict:
|
||||
error_message += f" (Code: {resp_dict['code']})"
|
||||
raise Exception(error_message)
|
||||
# raise for status if no error
|
||||
resp.raise_for_status()
|
||||
|
||||
async def _request(self, method: str, path: str, noprefix=False, **kwargs):
|
||||
if not noprefix:
|
||||
path = join(self.api_prefix, path)
|
||||
if self.mint_info and self.mint_info.requires_blind_auth_path(method, path):
|
||||
if not self.auth_db:
|
||||
raise Exception(
|
||||
"Mint requires blind auth, but no auth database is set."
|
||||
)
|
||||
if not self.auth_keyset_id:
|
||||
raise Exception(
|
||||
"Mint requires blind auth, but no auth keyset id is set."
|
||||
)
|
||||
proofs = await get_proofs(db=self.auth_db, id=self.auth_keyset_id)
|
||||
if not proofs:
|
||||
raise Exception(
|
||||
"Mint requires blind auth, but no blind auth tokens were found."
|
||||
)
|
||||
# select one auth proof
|
||||
proof = proofs[0]
|
||||
auth_token = AuthProof.from_proof(proof).to_base64()
|
||||
kwargs.setdefault("headers", {}).update(
|
||||
{
|
||||
"Blind-auth": f"{auth_token}",
|
||||
}
|
||||
)
|
||||
await invalidate_proof(proof=proof, db=self.auth_db)
|
||||
if self.mint_info and self.mint_info.requires_clear_auth_path(method, path):
|
||||
logger.debug(f"Using clear auth token for {path}")
|
||||
clear_auth_token = kwargs.pop("clear_auth_token")
|
||||
if not clear_auth_token:
|
||||
raise Exception(
|
||||
"Mint requires clear auth, but no clear auth token is set."
|
||||
)
|
||||
kwargs.setdefault("headers", {}).update(
|
||||
{
|
||||
"Clear-auth": f"{clear_auth_token}",
|
||||
}
|
||||
)
|
||||
|
||||
return await self.httpx.request(method, path, **kwargs)
|
||||
|
||||
"""
|
||||
ENDPOINTS
|
||||
"""
|
||||
@@ -157,9 +207,7 @@ class LedgerAPI(LedgerAPIDeprecated):
|
||||
Raises:
|
||||
Exception: If no keys are received from the mint
|
||||
"""
|
||||
resp = await self.httpx.get(
|
||||
join(self.url, "/v1/keys"),
|
||||
)
|
||||
resp = await self._request(GET, "keys")
|
||||
# BEGIN backwards compatibility < 0.15.0
|
||||
# assume the mint has not upgraded yet if we get a 404
|
||||
if resp.status_code == 404:
|
||||
@@ -201,9 +249,7 @@ class LedgerAPI(LedgerAPIDeprecated):
|
||||
Exception: If no keys are received from the mint
|
||||
"""
|
||||
keyset_id_urlsafe = keyset_id.replace("+", "-").replace("/", "_")
|
||||
resp = await self.httpx.get(
|
||||
join(self.url, f"/v1/keys/{keyset_id_urlsafe}"),
|
||||
)
|
||||
resp = await self._request(GET, f"keys/{keyset_id_urlsafe}")
|
||||
# BEGIN backwards compatibility < 0.15.0
|
||||
# assume the mint has not upgraded yet if we get a 404
|
||||
if resp.status_code == 404:
|
||||
@@ -238,9 +284,7 @@ class LedgerAPI(LedgerAPIDeprecated):
|
||||
Raises:
|
||||
Exception: If no keysets are received from the mint
|
||||
"""
|
||||
resp = await self.httpx.get(
|
||||
join(self.url, "/v1/keysets"),
|
||||
)
|
||||
resp = await self._request(GET, "keysets")
|
||||
# BEGIN backwards compatibility < 0.15.0
|
||||
# assume the mint has not upgraded yet if we get a 404
|
||||
if resp.status_code == 404:
|
||||
@@ -265,9 +309,7 @@ class LedgerAPI(LedgerAPIDeprecated):
|
||||
Raises:
|
||||
Exception: If the mint info request fails
|
||||
"""
|
||||
resp = await self.httpx.get(
|
||||
join(self.url, "/v1/info"),
|
||||
)
|
||||
resp = await self._request(GET, "/v1/info", noprefix=True)
|
||||
# BEGIN backwards compatibility < 0.15.0
|
||||
# assume the mint has not upgraded yet if we get a 404
|
||||
if resp.status_code == 404:
|
||||
@@ -305,9 +347,12 @@ class LedgerAPI(LedgerAPIDeprecated):
|
||||
payload = PostMintQuoteRequest(
|
||||
unit=unit.name, amount=amount, description=memo, pubkey=pubkey
|
||||
)
|
||||
resp = await self.httpx.post(
|
||||
join(self.url, "/v1/mint/quote/bolt11"), json=payload.dict()
|
||||
resp = await self._request(
|
||||
POST,
|
||||
"mint/quote/bolt11",
|
||||
json=payload.dict(),
|
||||
)
|
||||
|
||||
# BEGIN backwards compatibility < 0.15.0
|
||||
# assume the mint has not upgraded yet if we get a 404
|
||||
if resp.status_code == 404:
|
||||
@@ -329,9 +374,7 @@ class LedgerAPI(LedgerAPIDeprecated):
|
||||
Returns:
|
||||
PostMintQuoteResponse: Mint Quote Response
|
||||
"""
|
||||
resp = await self.httpx.get(
|
||||
join(self.url, f"/v1/mint/quote/bolt11/{quote}"),
|
||||
)
|
||||
resp = await self._request(GET, f"mint/quote/bolt11/{quote}")
|
||||
self.raise_on_error_request(resp)
|
||||
return_dict = resp.json()
|
||||
return PostMintQuoteResponse.parse_obj(return_dict)
|
||||
@@ -371,8 +414,9 @@ class LedgerAPI(LedgerAPIDeprecated):
|
||||
return res
|
||||
|
||||
payload = outputs_payload.dict(include=_mintrequest_include_fields(outputs)) # type: ignore
|
||||
resp = await self.httpx.post(
|
||||
join(self.url, "/v1/mint/bolt11"),
|
||||
resp = await self._request(
|
||||
POST,
|
||||
"mint/bolt11",
|
||||
json=payload, # type: ignore
|
||||
)
|
||||
# BEGIN backwards compatibility < 0.15.0
|
||||
@@ -383,7 +427,7 @@ class LedgerAPI(LedgerAPIDeprecated):
|
||||
# END backwards compatibility < 0.15.0
|
||||
self.raise_on_error_request(resp)
|
||||
response_dict = resp.json()
|
||||
logger.trace("Lightning invoice checked. POST /v1/mint/bolt11")
|
||||
logger.trace(f"Lightning invoice checked. POST {self.api_prefix}/mint/bolt11")
|
||||
promises = PostMintResponse.parse_obj(response_dict).signatures
|
||||
return promises
|
||||
|
||||
@@ -406,8 +450,9 @@ class LedgerAPI(LedgerAPIDeprecated):
|
||||
unit=unit.name, request=payment_request, options=melt_options
|
||||
)
|
||||
|
||||
resp = await self.httpx.post(
|
||||
join(self.url, "/v1/melt/quote/bolt11"),
|
||||
resp = await self._request(
|
||||
POST,
|
||||
"melt/quote/bolt11",
|
||||
json=payload.dict(),
|
||||
)
|
||||
# BEGIN backwards compatibility < 0.15.0
|
||||
@@ -441,9 +486,7 @@ class LedgerAPI(LedgerAPIDeprecated):
|
||||
Returns:
|
||||
PostMeltQuoteResponse: Melt Quote Response
|
||||
"""
|
||||
resp = await self.httpx.get(
|
||||
join(self.url, f"/v1/melt/quote/bolt11/{quote}"),
|
||||
)
|
||||
resp = await self._request(GET, f"melt/quote/bolt11/{quote}")
|
||||
self.raise_on_error_request(resp)
|
||||
return_dict = resp.json()
|
||||
return PostMeltQuoteResponse.parse_obj(return_dict)
|
||||
@@ -474,8 +517,9 @@ class LedgerAPI(LedgerAPIDeprecated):
|
||||
"outputs": {i: outputs_include for i in range(len(outputs))},
|
||||
}
|
||||
|
||||
resp = await self.httpx.post(
|
||||
join(self.url, "/v1/melt/bolt11"),
|
||||
resp = await self._request(
|
||||
POST,
|
||||
"melt/bolt11",
|
||||
json=payload.dict(include=_meltrequest_include_fields(proofs, outputs)), # type: ignore
|
||||
timeout=None,
|
||||
)
|
||||
@@ -523,7 +567,7 @@ class LedgerAPI(LedgerAPIDeprecated):
|
||||
outputs: List[BlindedMessage],
|
||||
) -> List[BlindedSignature]:
|
||||
"""Consume proofs and create new promises based on amount split."""
|
||||
logger.debug("Calling split. POST /v1/swap")
|
||||
logger.debug(f"Calling split. POST {self.api_prefix}/swap")
|
||||
split_payload = PostSwapRequest(inputs=proofs, outputs=outputs)
|
||||
|
||||
# construct payload
|
||||
@@ -541,8 +585,9 @@ class LedgerAPI(LedgerAPIDeprecated):
|
||||
"inputs": {i: proofs_include for i in range(len(proofs))},
|
||||
}
|
||||
|
||||
resp = await self.httpx.post(
|
||||
join(self.url, "/v1/swap"),
|
||||
resp = await self._request(
|
||||
POST,
|
||||
"swap",
|
||||
json=split_payload.dict(include=_splitrequest_include_fields(proofs)), # type: ignore
|
||||
)
|
||||
# BEGIN backwards compatibility < 0.15.0
|
||||
@@ -568,8 +613,9 @@ class LedgerAPI(LedgerAPIDeprecated):
|
||||
Checks whether the secrets in proofs are already spent or not and returns a list of booleans.
|
||||
"""
|
||||
payload = PostCheckStateRequest(Ys=[p.Y for p in proofs])
|
||||
resp = await self.httpx.post(
|
||||
join(self.url, "/v1/checkstate"),
|
||||
resp = await self._request(
|
||||
POST,
|
||||
"checkstate",
|
||||
json=payload.dict(),
|
||||
)
|
||||
# BEGIN backwards compatibility < 0.15.0
|
||||
@@ -595,10 +641,7 @@ class LedgerAPI(LedgerAPIDeprecated):
|
||||
"Received HTTP Error 422. Attempting state check with < 0.16.0 compatibility."
|
||||
)
|
||||
payload_secrets = {"secrets": [p.secret for p in proofs]}
|
||||
resp_secrets = await self.httpx.post(
|
||||
join(self.url, "/v1/checkstate"),
|
||||
json=payload_secrets,
|
||||
)
|
||||
resp_secrets = await self._request(POST, "checkstate", json=payload_secrets)
|
||||
self.raise_on_error(resp_secrets)
|
||||
states = [
|
||||
ProofState(Y=p.Y, state=ProofSpentState(s["state"]))
|
||||
@@ -619,7 +662,7 @@ class LedgerAPI(LedgerAPIDeprecated):
|
||||
Asks the mint to restore promises corresponding to outputs.
|
||||
"""
|
||||
payload = PostMintRequest(quote="restore", outputs=outputs)
|
||||
resp = await self.httpx.post(join(self.url, "/v1/restore"), json=payload.dict())
|
||||
resp = await self._request(POST, "restore", json=payload.dict())
|
||||
# BEGIN backwards compatibility < 0.15.0
|
||||
# assume the mint has not upgraded yet if we get a 404
|
||||
if resp.status_code == 404:
|
||||
@@ -637,3 +680,21 @@ class LedgerAPI(LedgerAPIDeprecated):
|
||||
# END backwards compatibility < 0.15.1
|
||||
|
||||
return returnObj.outputs, returnObj.signatures
|
||||
|
||||
@async_set_httpx_client
|
||||
async def blind_mint_blind_auth(
|
||||
self, clear_auth_token: str, outputs: List[BlindedMessage]
|
||||
) -> List[BlindedSignature]:
|
||||
"""
|
||||
Asks the mint to mint blind auth tokens. Needs to provide a clear auth token.
|
||||
"""
|
||||
payload = PostAuthBlindMintRequest(outputs=outputs)
|
||||
resp = await self._request(
|
||||
POST,
|
||||
"mint",
|
||||
json=payload.dict(),
|
||||
clear_auth_token=clear_auth_token,
|
||||
)
|
||||
self.raise_on_error_request(resp)
|
||||
response_dict = resp.json()
|
||||
return PostAuthBlindMintResponse.parse_obj(response_dict).signatures
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import copy
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
@@ -17,6 +18,7 @@ from ..core.base import (
|
||||
Proof,
|
||||
Unit,
|
||||
WalletKeyset,
|
||||
WalletMint,
|
||||
)
|
||||
from ..core.crypto import b_dhke
|
||||
from ..core.crypto.secp import PrivateKey, PublicKey
|
||||
@@ -30,6 +32,7 @@ from ..core.helpers import (
|
||||
)
|
||||
from ..core.json_rpc.base import JSONRPCSubscriptionKinds
|
||||
from ..core.migrations import migrate_databases
|
||||
from ..core.mint_info import MintInfo
|
||||
from ..core.models import (
|
||||
PostCheckStateResponse,
|
||||
PostMeltQuoteResponse,
|
||||
@@ -43,6 +46,7 @@ from .crud import (
|
||||
bump_secret_derivation,
|
||||
get_bolt11_mint_quote,
|
||||
get_keysets,
|
||||
get_mint_by_url,
|
||||
get_proofs,
|
||||
invalidate_proof,
|
||||
secret_used,
|
||||
@@ -50,14 +54,16 @@ from .crud import (
|
||||
store_bolt11_melt_quote,
|
||||
store_bolt11_mint_quote,
|
||||
store_keyset,
|
||||
store_mint,
|
||||
store_proof,
|
||||
update_bolt11_melt_quote,
|
||||
update_bolt11_mint_quote,
|
||||
update_keyset,
|
||||
update_mint,
|
||||
update_proof,
|
||||
)
|
||||
from .errors import BalanceTooLowError
|
||||
from .htlc import WalletHTLC
|
||||
from .mint_info import MintInfo
|
||||
from .p2pk import WalletP2PK
|
||||
from .proofs import WalletProofs
|
||||
from .secrets import WalletSecrets
|
||||
@@ -108,21 +114,41 @@ class Wallet(
|
||||
db: Database
|
||||
bip32: BIP32
|
||||
# private_key: Optional[PrivateKey] = None
|
||||
auth_db: Optional[Database] = None
|
||||
auth_keyset_id: Optional[str] = None
|
||||
|
||||
def __init__(self, url: str, db: str, name: str = "wallet", unit: str = "sat"):
|
||||
def __init__(
|
||||
self,
|
||||
url: str,
|
||||
db: str,
|
||||
name: str = "wallet",
|
||||
unit: str = "sat",
|
||||
auth_db: Optional[str] = None,
|
||||
auth_keyset_id: Optional[str] = None,
|
||||
):
|
||||
"""A Cashu wallet.
|
||||
|
||||
Args:
|
||||
url (str): URL of the mint.
|
||||
db (str): Path to the database directory.
|
||||
name (str, optional): Name of the wallet database file. Defaults to "wallet".
|
||||
unit (str, optional): Unit of the wallet. Defaults to "sat".
|
||||
auth_db (Optional[str], optional): Path to the auth database directory. Defaults to None.
|
||||
auth_keyset_id (Optional[str], optional): Keyset ID of the auth keyset. Defaults to None.
|
||||
"""
|
||||
self.db = Database("wallet", db)
|
||||
self.db = Database(name, db)
|
||||
self.proofs: List[Proof] = []
|
||||
self.name = name
|
||||
self.unit = Unit[unit]
|
||||
url = sanitize_url(url)
|
||||
|
||||
# if this is an auth wallet
|
||||
if (auth_db and not auth_keyset_id) or (not auth_db and auth_keyset_id):
|
||||
raise Exception("Both auth_db and auth_keyset_id must be provided.")
|
||||
if auth_db and auth_keyset_id:
|
||||
self.auth_db = Database("auth", auth_db)
|
||||
self.auth_keyset_id = auth_keyset_id
|
||||
|
||||
super().__init__(url=url, db=self.db)
|
||||
logger.debug("Wallet initialized")
|
||||
logger.debug(f"Mint URL: {url}")
|
||||
@@ -137,7 +163,10 @@ class Wallet(
|
||||
name: str = "wallet",
|
||||
skip_db_read: bool = False,
|
||||
unit: str = "sat",
|
||||
auth_db: Optional[str] = None,
|
||||
auth_keyset_id: Optional[str] = None,
|
||||
load_all_keysets: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initializes a wallet with a database and initializes the private key.
|
||||
|
||||
@@ -151,12 +180,22 @@ class Wallet(
|
||||
unit (str, optional): Unit of the wallet. Defaults to "sat".
|
||||
load_all_keysets (bool, optional): If true, all keysets are loaded from the database.
|
||||
Defaults to False.
|
||||
auth_db (Optional[str], optional): Path to the auth database directory. Defaults to None.
|
||||
auth_keyset_id (Optional[str], optional): Keyset ID of the auth keyset. Defaults to None.
|
||||
kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
Wallet: Initialized wallet.
|
||||
"""
|
||||
logger.trace(f"Initializing wallet with database: {db}")
|
||||
self = cls(url=url, db=db, name=name, unit=unit)
|
||||
self = cls(
|
||||
url=url,
|
||||
db=db,
|
||||
name=name,
|
||||
unit=unit,
|
||||
auth_db=auth_db,
|
||||
auth_keyset_id=auth_keyset_id,
|
||||
)
|
||||
await self._migrate_database()
|
||||
|
||||
if skip_db_read:
|
||||
@@ -172,8 +211,13 @@ class Wallet(
|
||||
self.keysets = {k.id: k for k in keysets_active_unit}
|
||||
else:
|
||||
self.keysets = {k.id: k for k in keysets_list}
|
||||
keysets_str = " ".join([f"{i} {k.unit}" for i, k in self.keysets.items()])
|
||||
logger.debug(f"Loaded keysets: {keysets_str}")
|
||||
|
||||
if self.keysets:
|
||||
keysets_str = " ".join([f"{i} {k.unit}" for i, k in self.keysets.items()])
|
||||
logger.debug(f"Loaded keysets: {keysets_str}")
|
||||
|
||||
await self.load_mint_info(offline=True)
|
||||
|
||||
return self
|
||||
|
||||
async def _migrate_database(self):
|
||||
@@ -185,12 +229,63 @@ class Wallet(
|
||||
|
||||
# ---------- API ----------
|
||||
|
||||
async def load_mint_info(self) -> MintInfo:
|
||||
"""Loads the mint info from the mint."""
|
||||
mint_info_resp = await self._get_info()
|
||||
self.mint_info = MintInfo(**mint_info_resp.dict())
|
||||
logger.debug(f"Mint info: {self.mint_info}")
|
||||
return self.mint_info
|
||||
async def load_mint_info(self, reload=False, offline=False) -> MintInfo | None:
|
||||
"""Loads the mint info from the mint.
|
||||
|
||||
Args:
|
||||
reload (bool, optional): If True, the mint info is reloaded from the mint. Defaults to False.
|
||||
offline (bool, optional): If True, the mint info is not loaded from the mint. Defaults to False.
|
||||
"""
|
||||
# if self.mint_info and not reload:
|
||||
# return self.mint_info
|
||||
|
||||
# read mint info from db
|
||||
if reload:
|
||||
if offline:
|
||||
raise Exception("Cannot reload mint info offline.")
|
||||
logger.debug("Forcing reload of mint info.")
|
||||
mint_info_resp = await self._get_info()
|
||||
self.mint_info = MintInfo(**mint_info_resp.dict())
|
||||
|
||||
wallet_mint_db = await get_mint_by_url(url=self.url, db=self.db)
|
||||
if not wallet_mint_db:
|
||||
if self.mint_info:
|
||||
logger.debug("Storing mint info in db.")
|
||||
await store_mint(
|
||||
db=self.db,
|
||||
mint=WalletMint(
|
||||
url=self.url, info=json.dumps(self.mint_info.dict())
|
||||
),
|
||||
)
|
||||
else:
|
||||
if offline:
|
||||
return None
|
||||
logger.debug("Loading mint info from mint.")
|
||||
mint_info_resp = await self._get_info()
|
||||
self.mint_info = MintInfo(**mint_info_resp.dict())
|
||||
if not wallet_mint_db:
|
||||
logger.debug("Storing mint info in db.")
|
||||
await store_mint(
|
||||
db=self.db,
|
||||
mint=WalletMint(
|
||||
url=self.url, info=json.dumps(self.mint_info.dict())
|
||||
),
|
||||
)
|
||||
return self.mint_info
|
||||
elif (
|
||||
self.mint_info
|
||||
and not json.dumps(self.mint_info.dict()) == wallet_mint_db.info
|
||||
):
|
||||
logger.debug("Updating mint info in db.")
|
||||
await update_mint(
|
||||
db=self.db,
|
||||
mint=WalletMint(url=self.url, info=json.dumps(self.mint_info.dict())),
|
||||
)
|
||||
return self.mint_info
|
||||
else:
|
||||
logger.debug("Loading mint info from db.")
|
||||
self.mint_info = MintInfo.from_json_str(wallet_mint_db.info)
|
||||
return self.mint_info
|
||||
|
||||
async def load_mint_keysets(self, force_old_keysets=False):
|
||||
"""Loads all keyset of the mint and makes sure we have them all in the database.
|
||||
@@ -304,11 +399,11 @@ class Wallet(
|
||||
force_old_keysets (bool, optional): If true, old deprecated base64 keysets are not ignored. This is necessary for restoring tokens from old base64 keysets.
|
||||
Defaults to False.
|
||||
"""
|
||||
logger.trace("Loading mint.")
|
||||
logger.trace(f"Loading mint {self.url}")
|
||||
await self.load_mint_keysets(force_old_keysets)
|
||||
await self.activate_keyset(keyset_id)
|
||||
try:
|
||||
await self.load_mint_info()
|
||||
await self.load_mint_info(reload=True)
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not load mint info: {e}")
|
||||
pass
|
||||
@@ -979,7 +1074,7 @@ class Wallet(
|
||||
# select proofs that are not reserved and are in the active keysets of the mint
|
||||
proofs = self.active_proofs(proofs)
|
||||
if sum_proofs(proofs) < amount:
|
||||
raise Exception("balance too low.")
|
||||
raise BalanceTooLowError()
|
||||
|
||||
# coin selection for potentially offline sending
|
||||
send_proofs = self.coinselect(proofs, amount, include_fees=include_fees)
|
||||
@@ -1037,7 +1132,7 @@ class Wallet(
|
||||
# select proofs that are not reserved and are in the active keysets of the mint
|
||||
proofs = self.active_proofs(proofs)
|
||||
if sum_proofs(proofs) < amount:
|
||||
raise Exception("balance too low.")
|
||||
raise BalanceTooLowError()
|
||||
|
||||
# coin selection for swapping, needs to include fees
|
||||
swap_proofs = self.coinselect(proofs, amount, include_fees=True)
|
||||
|
||||
@@ -144,6 +144,8 @@ class LedgerAPIDeprecated(SupportsHttpxClient, SupportsMintURL):
|
||||
mint_info = GetInfoResponse(
|
||||
**mint_info_deprecated.dict(exclude={"parameter", "nuts", "contact"})
|
||||
)
|
||||
# monkeypatch nuts
|
||||
mint_info.nuts = {}
|
||||
return mint_info
|
||||
|
||||
@async_set_httpx_client
|
||||
@@ -261,7 +263,7 @@ class LedgerAPIDeprecated(SupportsHttpxClient, SupportsMintURL):
|
||||
paid=False,
|
||||
state=MintQuoteState.unpaid.value,
|
||||
expiry=decoded_invoice.date + (decoded_invoice.expiry or 0),
|
||||
pubkey=None
|
||||
pubkey=None,
|
||||
)
|
||||
|
||||
@async_set_httpx_client
|
||||
|
||||
7
keycloak/.env.example
Normal file
7
keycloak/.env.example
Normal file
@@ -0,0 +1,7 @@
|
||||
POSTGRES_DB=keycloak_db
|
||||
POSTGRES_USER=keycloak_db_user
|
||||
POSTGRES_PASSWORD=keycloak_db_user_password
|
||||
KEYCLOAK_ADMIN=admin
|
||||
KEYCLOAK_ADMIN_PASSWORD=password
|
||||
KC_HOSTNAME=localhost
|
||||
KC_HOSTNAME_PORT=8080
|
||||
129
keycloak/README.md
Normal file
129
keycloak/README.md
Normal file
@@ -0,0 +1,129 @@
|
||||
## Docker compose
|
||||
|
||||
This docker-compose starts a new keycloak instance. Set up the server as you wish, add realms, users etc. We will then export the data and restore an instance with the exported data.
|
||||
|
||||
We will modify this file later to start the server with the backup data.
|
||||
|
||||
```
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:16.4
|
||||
volumes:
|
||||
- ./postgres_data:/var/lib/postgresql/data
|
||||
environment:
|
||||
POSTGRES_DB: ${POSTGRES_DB}
|
||||
POSTGRES_USER: ${POSTGRES_USER}
|
||||
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
|
||||
networks:
|
||||
- keycloak_network
|
||||
|
||||
keycloak:
|
||||
image: quay.io/keycloak/keycloak:25.0.6
|
||||
command: start
|
||||
environment:
|
||||
KC_HOSTNAME: localhost
|
||||
KC_HOSTNAME_PORT: 8080
|
||||
KC_HOSTNAME_STRICT_BACKCHANNEL: false
|
||||
KC_HTTP_ENABLED: true
|
||||
KC_HOSTNAME_STRICT_HTTPS: false
|
||||
KC_HEALTH_ENABLED: true
|
||||
KEYCLOAK_ADMIN: ${KEYCLOAK_ADMIN}
|
||||
KEYCLOAK_ADMIN_PASSWORD: ${KEYCLOAK_ADMIN_PASSWORD}
|
||||
KC_DB: postgres
|
||||
KC_DB_URL: jdbc:postgresql://postgres/${POSTGRES_DB}
|
||||
KC_DB_USERNAME: ${POSTGRES_USER}
|
||||
KC_DB_PASSWORD: ${POSTGRES_PASSWORD}
|
||||
ports:
|
||||
- 8080:8080
|
||||
restart: always
|
||||
depends_on:
|
||||
- postgres
|
||||
networks:
|
||||
- keycloak_network
|
||||
|
||||
volumes:
|
||||
postgres_data:
|
||||
driver: local
|
||||
|
||||
networks:
|
||||
keycloak_network:
|
||||
driver: bridge
|
||||
```
|
||||
|
||||
## Backup
|
||||
|
||||
Export realm and users from running container:
|
||||
|
||||
```
|
||||
docker exec keycloak-keycloak-1 \
|
||||
/opt/keycloak/bin/kc.sh export \
|
||||
--dir /opt/keycloak/data/export \
|
||||
--users different_files \
|
||||
--http-management-port 46566
|
||||
```
|
||||
|
||||
Copy export out of the docker
|
||||
|
||||
```
|
||||
docker cp keycloak-keycloak-1:/opt/keycloak/data/export ./keycloak-export
|
||||
```
|
||||
|
||||
## Restore
|
||||
|
||||
Use this docker-compose.yml to start keycloak with the exported backup:
|
||||
|
||||
```
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:16.4
|
||||
volumes:
|
||||
- ./postgres_data:/var/lib/postgresql/data
|
||||
environment:
|
||||
POSTGRES_DB: ${POSTGRES_DB}
|
||||
POSTGRES_USER: ${POSTGRES_USER}
|
||||
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
|
||||
networks:
|
||||
- keycloak_network
|
||||
|
||||
keycloak:
|
||||
image: quay.io/keycloak/keycloak:25.0.6
|
||||
command: start --import-realm
|
||||
volumes:
|
||||
- ./keycloak-export:/opt/keycloak/data/import
|
||||
environment:
|
||||
KC_HOSTNAME: localhost
|
||||
KC_HOSTNAME_PORT: 8080
|
||||
KC_HOSTNAME_STRICT_BACKCHANNEL: false
|
||||
KC_HTTP_ENABLED: true
|
||||
KC_HOSTNAME_STRICT_HTTPS: false
|
||||
KC_HEALTH_ENABLED: true
|
||||
KEYCLOAK_ADMIN: ${KEYCLOAK_ADMIN}
|
||||
KEYCLOAK_ADMIN_PASSWORD: ${KEYCLOAK_ADMIN_PASSWORD}
|
||||
KC_DB: postgres
|
||||
KC_DB_URL: jdbc:postgresql://postgres/${POSTGRES_DB}
|
||||
KC_DB_USERNAME: ${POSTGRES_USER}
|
||||
KC_DB_PASSWORD: ${POSTGRES_PASSWORD}
|
||||
ports:
|
||||
- 8080:8080
|
||||
restart: always
|
||||
depends_on:
|
||||
- postgres
|
||||
networks:
|
||||
- keycloak_network
|
||||
|
||||
volumes:
|
||||
postgres_data:
|
||||
driver: local
|
||||
|
||||
networks:
|
||||
keycloak_network:
|
||||
driver: bridge
|
||||
```
|
||||
|
||||
Difference to first docker-compose is only the following part:
|
||||
|
||||
```
|
||||
command: start --import-realm
|
||||
volumes:
|
||||
- ./keycloak-export:/opt/keycloak/data/import
|
||||
```
|
||||
45
keycloak/docker-compose.yml
Normal file
45
keycloak/docker-compose.yml
Normal file
@@ -0,0 +1,45 @@
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:16.4
|
||||
volumes:
|
||||
- ./postgres_data:/var/lib/postgresql/data
|
||||
environment:
|
||||
POSTGRES_DB: ${POSTGRES_DB}
|
||||
POSTGRES_USER: ${POSTGRES_USER}
|
||||
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
|
||||
networks:
|
||||
- keycloak_network
|
||||
|
||||
keycloak:
|
||||
image: quay.io/keycloak/keycloak:25.0.6
|
||||
command: start --import-realm
|
||||
volumes:
|
||||
- ./keycloak-export:/opt/keycloak/data/import
|
||||
environment:
|
||||
KC_HOSTNAME: ${KC_HOSTNAME}
|
||||
KC_HOSTNAME_PORT: ${KC_HOSTNAME_PORT}
|
||||
KC_HOSTNAME_STRICT_BACKCHANNEL: false
|
||||
KC_HTTP_ENABLED: true
|
||||
KC_HOSTNAME_STRICT_HTTPS: true
|
||||
KC_HEALTH_ENABLED: true
|
||||
KEYCLOAK_ADMIN: ${KEYCLOAK_ADMIN}
|
||||
KEYCLOAK_ADMIN_PASSWORD: ${KEYCLOAK_ADMIN_PASSWORD}
|
||||
KC_DB: postgres
|
||||
KC_DB_URL: jdbc:postgresql://postgres/${POSTGRES_DB}
|
||||
KC_DB_USERNAME: ${POSTGRES_USER}
|
||||
KC_DB_PASSWORD: ${POSTGRES_PASSWORD}
|
||||
ports:
|
||||
- 8080:8080
|
||||
restart: always
|
||||
depends_on:
|
||||
- postgres
|
||||
networks:
|
||||
- keycloak_network
|
||||
|
||||
volumes:
|
||||
postgres_data:
|
||||
driver: local
|
||||
|
||||
networks:
|
||||
keycloak_network:
|
||||
driver: bridge
|
||||
1099
poetry.lock
generated
1099
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -39,9 +39,11 @@ googleapis-common-protos = "^1.63.2"
|
||||
mypy-protobuf = "^3.6.0"
|
||||
types-protobuf = "^5.27.0.20240626"
|
||||
grpcio-tools = "^1.65.1"
|
||||
pyjwt = "^2.9.0"
|
||||
redis = "^5.1.1"
|
||||
brotli = "^1.1.0"
|
||||
zstandard = "^0.23.0"
|
||||
jinja2 = "^3.1.5"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pytest-asyncio = "^0.24.0"
|
||||
|
||||
@@ -51,6 +51,7 @@ settings.mint_lnd_enable_mpp = True
|
||||
settings.mint_clnrest_enable_mpp = True
|
||||
settings.mint_input_fee_ppk = 0
|
||||
settings.db_connection_pool = True
|
||||
# settings.mint_require_auth = False
|
||||
|
||||
assert "test" in settings.cashu_dir
|
||||
shutil.rmtree(settings.cashu_dir, ignore_errors=True)
|
||||
|
||||
45
tests/keycloak_data/docker-compose-restore.yml
Normal file
45
tests/keycloak_data/docker-compose-restore.yml
Normal file
@@ -0,0 +1,45 @@
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:16.4
|
||||
volumes:
|
||||
- ./postgres_data:/var/lib/postgresql/data
|
||||
environment:
|
||||
POSTGRES_DB: cashu
|
||||
POSTGRES_USER: cashu
|
||||
POSTGRES_PASSWORD: cashu
|
||||
networks:
|
||||
- keycloak_network
|
||||
|
||||
keycloak:
|
||||
image: quay.io/keycloak/keycloak:25.0.6
|
||||
command: start --import-realm
|
||||
volumes:
|
||||
- ./keycloak-export:/opt/keycloak/data/import
|
||||
environment:
|
||||
KC_HOSTNAME: localhost
|
||||
KC_HOSTNAME_PORT: 8080
|
||||
KC_HOSTNAME_STRICT_BACKCHANNEL: false
|
||||
KC_HTTP_ENABLED: true
|
||||
KC_HOSTNAME_STRICT_HTTPS: false
|
||||
KC_HEALTH_ENABLED: true
|
||||
KEYCLOAK_ADMIN: admin
|
||||
KEYCLOAK_ADMIN_PASSWORD: admin
|
||||
KC_DB: postgres
|
||||
KC_DB_URL: jdbc:postgresql://postgres/cashu
|
||||
KC_DB_USERNAME: cashu
|
||||
KC_DB_PASSWORD: cashu
|
||||
ports:
|
||||
- 8080:8080
|
||||
restart: always
|
||||
depends_on:
|
||||
- postgres
|
||||
networks:
|
||||
- keycloak_network
|
||||
|
||||
volumes:
|
||||
postgres_data:
|
||||
driver: local
|
||||
|
||||
networks:
|
||||
keycloak_network:
|
||||
driver: bridge
|
||||
2021
tests/keycloak_data/keycloak-export/master-realm.json
Normal file
2021
tests/keycloak_data/keycloak-export/master-realm.json
Normal file
File diff suppressed because it is too large
Load Diff
26
tests/keycloak_data/keycloak-export/master-users-0.json
Normal file
26
tests/keycloak_data/keycloak-export/master-users-0.json
Normal file
@@ -0,0 +1,26 @@
|
||||
{
|
||||
"realm" : "master",
|
||||
"users" : [ {
|
||||
"id" : "0ff227f7-c163-4fca-9ae4-c8751c725421",
|
||||
"username" : "admin",
|
||||
"emailVerified" : false,
|
||||
"createdTimestamp" : 1727128354842,
|
||||
"enabled" : true,
|
||||
"totp" : false,
|
||||
"credentials" : [ {
|
||||
"id" : "11a5f9ed-19c9-4164-be31-28ce6e23955b",
|
||||
"type" : "password",
|
||||
"createdDate" : 1727128354904,
|
||||
"secretData" : "{\"value\":\"s/6M2/FCFd1fOyHJRMvOLvKM7e2JIOC6LZ3ovFVkGi8=\",\"salt\":\"Zjn7ChOL5688O84xf1ElGA==\",\"additionalParameters\":{}}",
|
||||
"credentialData" : "{\"hashIterations\":5,\"algorithm\":\"argon2\",\"additionalParameters\":{\"hashLength\":[\"32\"],\"memory\":[\"7168\"],\"type\":[\"id\"],\"version\":[\"1.3\"],\"parallelism\":[\"1\"]}}"
|
||||
} ],
|
||||
"disableableCredentialTypes" : [ ],
|
||||
"requiredActions" : [ ],
|
||||
"realmRoles" : [ "default-roles-master", "admin" ],
|
||||
"clientRoles" : {
|
||||
"nutshell-realm" : [ "query-realms", "query-users", "manage-identity-providers", "manage-authorization", "view-identity-providers", "view-realm", "view-authorization", "query-clients", "manage-clients", "create-client", "view-events", "manage-events", "manage-realm", "manage-users", "view-users", "view-clients", "query-groups" ]
|
||||
},
|
||||
"notBefore" : 0,
|
||||
"groups" : [ ]
|
||||
} ]
|
||||
}
|
||||
1902
tests/keycloak_data/keycloak-export/nutshell-realm.json
Normal file
1902
tests/keycloak_data/keycloak-export/nutshell-realm.json
Normal file
File diff suppressed because it is too large
Load Diff
53
tests/keycloak_data/keycloak-export/nutshell-users-0.json
Normal file
53
tests/keycloak_data/keycloak-export/nutshell-users-0.json
Normal file
@@ -0,0 +1,53 @@
|
||||
{
|
||||
"realm" : "nutshell",
|
||||
"users" : [ {
|
||||
"id" : "c4fc742a-700f-4c83-96f2-8777c8bb56d1",
|
||||
"username" : "asd@asd.com",
|
||||
"firstName" : "asd",
|
||||
"lastName" : "asd",
|
||||
"email" : "asd@asd.com",
|
||||
"emailVerified" : false,
|
||||
"createdTimestamp" : 1727128876722,
|
||||
"enabled" : true,
|
||||
"totp" : false,
|
||||
"credentials" : [ {
|
||||
"id" : "23ea2b79-9c09-4133-b53b-2708258da890",
|
||||
"type" : "password",
|
||||
"createdDate" : 1727128876754,
|
||||
"secretData" : "{\"value\":\"fDXqE3IjxS5uIYfn9eYgW5GwokWvGsg2wWY0lOgeYyE=\",\"salt\":\"Wlb5f8yPTh4QreuC99b7Zg==\",\"additionalParameters\":{}}",
|
||||
"credentialData" : "{\"hashIterations\":5,\"algorithm\":\"argon2\",\"additionalParameters\":{\"hashLength\":[\"32\"],\"memory\":[\"7168\"],\"type\":[\"id\"],\"version\":[\"1.3\"],\"parallelism\":[\"1\"]}}"
|
||||
} ],
|
||||
"disableableCredentialTypes" : [ ],
|
||||
"requiredActions" : [ ],
|
||||
"realmRoles" : [ "default-roles-nutshell" ],
|
||||
"clientConsents" : [ {
|
||||
"clientId" : "cashu-client",
|
||||
"grantedClientScopes" : [ "email", "roles", "profile" ],
|
||||
"createdDate" : 1732651444894,
|
||||
"lastUpdatedDate" : 1732651444908
|
||||
} ],
|
||||
"notBefore" : 0,
|
||||
"groups" : [ ]
|
||||
}, {
|
||||
"id" : "43a16bd6-f5c5-4dfa-bcd4-6a5540564797",
|
||||
"username" : "callebtc@protonmail.com",
|
||||
"firstName" : "asdasd",
|
||||
"lastName" : "asdasdasdasd",
|
||||
"email" : "callebtc@protonmail.com",
|
||||
"emailVerified" : false,
|
||||
"createdTimestamp" : 1732639511706,
|
||||
"enabled" : true,
|
||||
"totp" : false,
|
||||
"credentials" : [ ],
|
||||
"disableableCredentialTypes" : [ ],
|
||||
"requiredActions" : [ ],
|
||||
"federatedIdentities" : [ {
|
||||
"identityProvider" : "github",
|
||||
"userId" : "93376500",
|
||||
"userName" : "callebtc"
|
||||
} ],
|
||||
"realmRoles" : [ "default-roles-nutshell" ],
|
||||
"notBefore" : 0,
|
||||
"groups" : [ ]
|
||||
} ]
|
||||
}
|
||||
@@ -183,14 +183,14 @@ async def test_mint(wallet1: Wallet):
|
||||
assert wallet1.balance == 64
|
||||
|
||||
# verify that proofs in proofs_used db have the same mint_id as the invoice in the db
|
||||
mint_quote = await get_bolt11_mint_quote(db=wallet1.db, quote=mint_quote.quote)
|
||||
assert mint_quote
|
||||
mint_quote_2 = await get_bolt11_mint_quote(db=wallet1.db, quote=mint_quote.quote)
|
||||
assert mint_quote_2
|
||||
proofs_minted = await get_proofs(
|
||||
db=wallet1.db, mint_id=mint_quote.quote, table="proofs"
|
||||
db=wallet1.db, mint_id=mint_quote_2.quote, table="proofs"
|
||||
)
|
||||
assert len(proofs_minted) == len(expected_proof_amounts)
|
||||
assert all([p.amount in expected_proof_amounts for p in proofs_minted])
|
||||
assert all([p.mint_id == mint_quote.quote for p in proofs_minted])
|
||||
assert all([p.mint_id == mint_quote_2.quote for p in proofs_minted])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -356,7 +356,7 @@ async def test_swap_to_send_more_than_balance(wallet1: Wallet):
|
||||
await wallet1.mint(64, quote_id=mint_quote.quote)
|
||||
await assert_err(
|
||||
wallet1.swap_to_send(wallet1.proofs, 128, set_reserved=True),
|
||||
"balance too low.",
|
||||
"Balance too low",
|
||||
)
|
||||
assert wallet1.balance == 64
|
||||
assert wallet1.available_balance == 64
|
||||
|
||||
251
tests/test_wallet_auth.py
Normal file
251
tests/test_wallet_auth.py
Normal file
@@ -0,0 +1,251 @@
|
||||
import hashlib
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from cashu.core.base import Unit
|
||||
from cashu.core.crypto.keys import random_hash
|
||||
from cashu.core.crypto.secp import PrivateKey
|
||||
from cashu.core.errors import (
|
||||
BlindAuthFailedError,
|
||||
BlindAuthRateLimitExceededError,
|
||||
ClearAuthFailedError,
|
||||
)
|
||||
from cashu.core.settings import settings
|
||||
from cashu.wallet.auth.auth import WalletAuth
|
||||
from cashu.wallet.wallet import Wallet
|
||||
from tests.conftest import SERVER_ENDPOINT
|
||||
from tests.helpers import assert_err
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def wallet():
|
||||
dirpath = Path("test_data/wallet")
|
||||
if dirpath.exists() and dirpath.is_dir():
|
||||
shutil.rmtree(dirpath)
|
||||
wallet = await Wallet.with_db(
|
||||
url=SERVER_ENDPOINT,
|
||||
db="test_data/wallet",
|
||||
name="wallet",
|
||||
)
|
||||
await wallet.load_mint()
|
||||
yield wallet
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not settings.mint_require_auth,
|
||||
reason="settings.mint_require_auth is False",
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_wallet_auth_password(wallet: Wallet):
|
||||
auth_wallet = await WalletAuth.with_db(
|
||||
url=wallet.url,
|
||||
db=wallet.db.db_location,
|
||||
username="asd@asd.com",
|
||||
password="asdasd",
|
||||
)
|
||||
|
||||
requires_auth = await auth_wallet.init_auth_wallet(
|
||||
wallet.mint_info, mint_auth_proofs=False
|
||||
)
|
||||
assert requires_auth
|
||||
|
||||
# expect JWT (CAT) with format ey*.ey*
|
||||
assert auth_wallet.oidc_client.access_token
|
||||
assert auth_wallet.oidc_client.access_token.split(".")[0].startswith("ey")
|
||||
assert auth_wallet.oidc_client.access_token.split(".")[1].startswith("ey")
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not settings.mint_require_auth,
|
||||
reason="settings.mint_require_auth is False",
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_wallet_auth_wrong_password(wallet: Wallet):
|
||||
auth_wallet = await WalletAuth.with_db(
|
||||
url=wallet.url,
|
||||
db=wallet.db.db_location,
|
||||
username="asd@asd.com",
|
||||
password="wrong_password",
|
||||
)
|
||||
|
||||
await assert_err(auth_wallet.init_auth_wallet(wallet.mint_info), "401 Unauthorized")
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not settings.mint_require_auth,
|
||||
reason="settings.mint_require_auth is False",
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_wallet_auth_mint(wallet: Wallet):
|
||||
auth_wallet = await WalletAuth.with_db(
|
||||
url=wallet.url,
|
||||
db=wallet.db.db_location,
|
||||
username="asd@asd.com",
|
||||
password="asdasd",
|
||||
)
|
||||
|
||||
requires_auth = await auth_wallet.init_auth_wallet(wallet.mint_info)
|
||||
assert requires_auth
|
||||
|
||||
await auth_wallet.load_proofs()
|
||||
assert len(auth_wallet.proofs) == auth_wallet.mint_info.bat_max_mint
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not settings.mint_require_auth,
|
||||
reason="settings.mint_require_auth is False",
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_wallet_auth_mint_manually(wallet: Wallet):
|
||||
auth_wallet = await WalletAuth.with_db(
|
||||
url=wallet.url,
|
||||
db=wallet.db.db_location,
|
||||
username="asd@asd.com",
|
||||
password="asdasd",
|
||||
)
|
||||
|
||||
requires_auth = await auth_wallet.init_auth_wallet(
|
||||
wallet.mint_info, mint_auth_proofs=False
|
||||
)
|
||||
assert requires_auth
|
||||
assert len(auth_wallet.proofs) == 0
|
||||
|
||||
await auth_wallet.mint_blind_auth()
|
||||
assert len(auth_wallet.proofs) == auth_wallet.mint_info.bat_max_mint
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not settings.mint_require_auth,
|
||||
reason="settings.mint_require_auth is False",
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_wallet_auth_mint_manually_invalid_cat(wallet: Wallet):
|
||||
auth_wallet = await WalletAuth.with_db(
|
||||
url=wallet.url,
|
||||
db=wallet.db.db_location,
|
||||
username="asd@asd.com",
|
||||
password="asdasd",
|
||||
)
|
||||
|
||||
requires_auth = await auth_wallet.init_auth_wallet(
|
||||
wallet.mint_info, mint_auth_proofs=False
|
||||
)
|
||||
assert requires_auth
|
||||
assert len(auth_wallet.proofs) == 0
|
||||
|
||||
# invalidate CAT in the database
|
||||
auth_wallet.oidc_client.access_token = random_hash()
|
||||
|
||||
# this is the code executed in auth_wallet.mint_blind_auth():
|
||||
clear_auth_token = auth_wallet.oidc_client.access_token
|
||||
if not clear_auth_token:
|
||||
raise Exception("No clear auth token available.")
|
||||
|
||||
amounts = auth_wallet.mint_info.bat_max_mint * [1] # 1 AUTH tokens
|
||||
secrets = [hashlib.sha256(os.urandom(32)).hexdigest() for _ in amounts]
|
||||
rs = [PrivateKey(privkey=os.urandom(32), raw=True) for _ in amounts]
|
||||
outputs, rs = auth_wallet._construct_outputs(amounts, secrets, rs)
|
||||
|
||||
# should fail because of invalid CAT
|
||||
await assert_err(
|
||||
auth_wallet.blind_mint_blind_auth(clear_auth_token, outputs),
|
||||
ClearAuthFailedError.detail,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not settings.mint_require_auth,
|
||||
reason="settings.mint_require_auth is False",
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_wallet_auth_invoice(wallet: Wallet):
|
||||
# should fail, wallet error
|
||||
await assert_err(wallet.mint_quote(10, Unit.sat), "Mint requires blind auth")
|
||||
|
||||
auth_wallet = await WalletAuth.with_db(
|
||||
url=wallet.url,
|
||||
db=wallet.db.db_location,
|
||||
username="asd@asd.com",
|
||||
password="asdasd",
|
||||
)
|
||||
requires_auth = await auth_wallet.init_auth_wallet(wallet.mint_info)
|
||||
assert requires_auth
|
||||
|
||||
await auth_wallet.load_proofs()
|
||||
assert len(auth_wallet.proofs) == auth_wallet.mint_info.bat_max_mint
|
||||
|
||||
wallet.auth_db = auth_wallet.db
|
||||
wallet.auth_keyset_id = auth_wallet.keyset_id
|
||||
|
||||
# should succeed
|
||||
await wallet.mint_quote(10, Unit.sat)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not settings.mint_require_auth,
|
||||
reason="settings.mint_require_auth is False",
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_wallet_auth_invoice_invalid_bat(wallet: Wallet):
|
||||
# should fail, wallet error
|
||||
await assert_err(wallet.mint_quote(10, Unit.sat), "Mint requires blind auth")
|
||||
|
||||
auth_wallet = await WalletAuth.with_db(
|
||||
url=wallet.url,
|
||||
db=wallet.db.db_location,
|
||||
username="asd@asd.com",
|
||||
password="asdasd",
|
||||
)
|
||||
requires_auth = await auth_wallet.init_auth_wallet(wallet.mint_info)
|
||||
assert requires_auth
|
||||
|
||||
await auth_wallet.load_proofs()
|
||||
assert len(auth_wallet.proofs) == auth_wallet.mint_info.bat_max_mint
|
||||
|
||||
# invalidate blind auth proofs
|
||||
for p in auth_wallet.proofs:
|
||||
await auth_wallet.db.execute(
|
||||
f"UPDATE proofs SET secret = '{random_hash()}' WHERE secret = '{p.secret}'"
|
||||
)
|
||||
|
||||
wallet.auth_db = auth_wallet.db
|
||||
wallet.auth_keyset_id = auth_wallet.keyset_id
|
||||
|
||||
# blind auth failed
|
||||
await assert_err(wallet.mint_quote(10, Unit.sat), BlindAuthFailedError.detail)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not settings.mint_require_auth,
|
||||
reason="settings.mint_require_auth is False",
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_wallet_auth_rate_limit(wallet: Wallet):
|
||||
auth_wallet = await WalletAuth.with_db(
|
||||
url=wallet.url,
|
||||
db=wallet.db.db_location,
|
||||
username="asd@asd.com",
|
||||
password="asdasd",
|
||||
)
|
||||
requires_auth = await auth_wallet.init_auth_wallet(
|
||||
wallet.mint_info, mint_auth_proofs=False
|
||||
)
|
||||
assert requires_auth
|
||||
|
||||
errored = False
|
||||
for _ in range(100):
|
||||
try:
|
||||
await auth_wallet.mint_blind_auth()
|
||||
except Exception as e:
|
||||
assert BlindAuthRateLimitExceededError.detail in str(e)
|
||||
errored = True
|
||||
break
|
||||
|
||||
assert errored
|
||||
|
||||
# should have minted at least twice
|
||||
assert len(auth_wallet.proofs) > auth_wallet.mint_info.bat_max_mint
|
||||
@@ -54,7 +54,7 @@ async def init_wallet():
|
||||
wallet = await Wallet.with_db(
|
||||
url=settings.mint_url,
|
||||
db="test_data/test_cli_wallet",
|
||||
name="wallet",
|
||||
name="test_cli_wallet",
|
||||
)
|
||||
await wallet.load_proofs()
|
||||
return wallet
|
||||
@@ -411,7 +411,7 @@ def test_wallets(cli_prefix):
|
||||
print("WALLETS")
|
||||
# on github this is empty
|
||||
if len(result.output):
|
||||
assert "test_cli_wallet" in result.output
|
||||
assert "wallet" in result.output
|
||||
assert result.exit_code == 0
|
||||
|
||||
|
||||
@@ -474,7 +474,7 @@ def test_send_too_much(mint, cli_prefix):
|
||||
cli,
|
||||
[*cli_prefix, "send", "100000"],
|
||||
)
|
||||
assert "balance too low" in str(result.exception)
|
||||
assert "Balance too low" in str(result.exception)
|
||||
|
||||
|
||||
def test_receive_tokenv3(mint, cli_prefix):
|
||||
|
||||
Reference in New Issue
Block a user