mirror of
https://github.com/aljazceru/nutshell.git
synced 2025-12-20 02:24:20 +01:00
Blind authentication (#675)
* auth server * cleaning up * auth ledger class * class variables -> instance variables * annotations * add models and api route * custom amount and api prefix * add auth db * blind auth token working * jwt working * clean up * JWT works * using openid connect server * use oauth server with password flow * new realm * add keycloak docker * hopefully not garbage * auth works * auth kinda working * fix cli * auth works for send and receive * pass auth_db to Wallet * auth in info * refactor * fix supported * cache mint info * fix settings and endpoints * add description to .env.example * track changes for openid connect client * store mint in db * store credentials * clean up v1_api.py * load mint info into auth wallet * fix first login * authenticate if refresh token fails * clear auth also middleware * use regex * add cli command * pw works * persist keyset amounts * add errors.py * do not start auth server if disabled in config * upadte poetry * disvoery url * fix test * support device code flow * adopt latest spec changes * fix code flow * mint max bat dynamic * mypy ignore * fix test * do not serialize amount in authproof * all auth flows working * fix tests * submodule * refactor * test * dont sleep * test * add wallet auth tests * test differently * test only keycloak for now * fix creds * daemon * fix test * install everything * install jinja * delete wallet for every test * auth: use global rate limiter * test auth rate limit * keycloak hostname * move keycloak test data * reactivate all tests * add readme * load proofs * remove unused code * remove unused code * implement change suggestions by ok300 * add error codes * test errors
This commit is contained in:
18
.env.example
18
.env.example
@@ -133,3 +133,21 @@ LIGHTNING_RESERVE_FEE_MIN=2000
|
|||||||
# MINT_GLOBAL_RATE_LIMIT_PER_MINUTE=60
|
# MINT_GLOBAL_RATE_LIMIT_PER_MINUTE=60
|
||||||
# Determines the number of transactions (mint, melt, swap) allowed per minute per IP
|
# Determines the number of transactions (mint, melt, swap) allowed per minute per IP
|
||||||
# MINT_TRANSACTION_RATE_LIMIT_PER_MINUTE=20
|
# MINT_TRANSACTION_RATE_LIMIT_PER_MINUTE=20
|
||||||
|
|
||||||
|
# Authentication
|
||||||
|
# These settings allow you to enable blind authentication to limit the user of your mint to a group of authenticated users.
|
||||||
|
# To use this, you need to set up an OpenID Connect provider like Keycloak, Auth0, or Hydra.
|
||||||
|
# - Add the client ID "cashu-client"
|
||||||
|
# - Enable the ES256 and RS256 algorithms for this client
|
||||||
|
# - If you want to use the authorization flow, you must add the redirect URI "http://localhost:33388/callback".
|
||||||
|
# - To support other wallets, use the well-known list of allowed redirect URIs here: https://...TODO.md
|
||||||
|
#
|
||||||
|
# Turn on authentication
|
||||||
|
# MINT_REQUIRE_AUTH=TRUE
|
||||||
|
# OpenID Connect discovery URL of the authentication provider
|
||||||
|
# MINT_AUTH_OICD_DISCOVERY_URL=http://localhost:8080/realms/nutshell/.well-known/openid-configuration
|
||||||
|
# MINT_AUTH_OICD_CLIENT_ID=cashu-client
|
||||||
|
# Number of authentication attempts allowed per minute per user
|
||||||
|
# MINT_AUTH_RATE_LIMIT_PER_MINUTE=5
|
||||||
|
# Maximum number of blind auth tokens per authentication request
|
||||||
|
# MINT_AUTH_MAX_BLIND_TOKENS=100
|
||||||
|
|||||||
15
.github/workflows/ci.yml
vendored
15
.github/workflows/ci.yml
vendored
@@ -42,6 +42,21 @@ jobs:
|
|||||||
poetry-version: ${{ matrix.poetry-version }}
|
poetry-version: ${{ matrix.poetry-version }}
|
||||||
mint-database: ${{ matrix.mint-database }}
|
mint-database: ${{ matrix.mint-database }}
|
||||||
|
|
||||||
|
tests_keycloak_auth:
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
os: [ubuntu-latest]
|
||||||
|
python-version: ["3.10"]
|
||||||
|
poetry-version: ["1.8.5"]
|
||||||
|
mint-database: ["./test_data/test_mint", "postgres://cashu:cashu@localhost:5432/cashu"]
|
||||||
|
uses: ./.github/workflows/tests_keycloak_auth.yml
|
||||||
|
with:
|
||||||
|
os: ${{ matrix.os }}
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
poetry-version: ${{ matrix.poetry-version }}
|
||||||
|
mint-database: ${{ matrix.mint-database }}
|
||||||
|
|
||||||
regtest:
|
regtest:
|
||||||
uses: ./.github/workflows/regtest.yml
|
uses: ./.github/workflows/regtest.yml
|
||||||
strategy:
|
strategy:
|
||||||
|
|||||||
77
.github/workflows/tests_keycloak_auth.yml
vendored
Normal file
77
.github/workflows/tests_keycloak_auth.yml
vendored
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
name: tests_keycloak
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_call:
|
||||||
|
inputs:
|
||||||
|
python-version:
|
||||||
|
default: "3.10.4"
|
||||||
|
type: string
|
||||||
|
poetry-version:
|
||||||
|
default: "1.8.5"
|
||||||
|
type: string
|
||||||
|
mint-database:
|
||||||
|
default: ""
|
||||||
|
type: string
|
||||||
|
os:
|
||||||
|
default: "ubuntu-latest"
|
||||||
|
type: string
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
poetry:
|
||||||
|
name: Auth tests with Keycloak (db ${{ inputs.mint-database }})
|
||||||
|
runs-on: ${{ inputs.os }}
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@v2
|
||||||
|
|
||||||
|
- name: Prepare environment
|
||||||
|
uses: ./.github/actions/prepare
|
||||||
|
with:
|
||||||
|
python-version: ${{ inputs.python-version }}
|
||||||
|
poetry-version: ${{ inputs.poetry-version }}
|
||||||
|
|
||||||
|
- name: Start PostgreSQL service
|
||||||
|
if: contains(inputs.mint-database, 'postgres')
|
||||||
|
run: |
|
||||||
|
docker run -d --name postgres \
|
||||||
|
-e POSTGRES_USER=cashu \
|
||||||
|
-e POSTGRES_PASSWORD=cashu \
|
||||||
|
-e POSTGRES_DB=cashu \
|
||||||
|
-p 5432:5432 postgres:16.4
|
||||||
|
until docker exec postgres pg_isready; do sleep 1; done
|
||||||
|
|
||||||
|
- name: Prepare environment
|
||||||
|
uses: ./.github/actions/prepare
|
||||||
|
with:
|
||||||
|
python-version: ${{ inputs.python-version }}
|
||||||
|
poetry-version: ${{ inputs.poetry-version }}
|
||||||
|
|
||||||
|
- name: Start Keycloak with Backup
|
||||||
|
run: |
|
||||||
|
docker compose -f tests/keycloak_data/docker-compose-restore.yml up -d
|
||||||
|
until docker logs $(docker ps -q --filter "ancestor=quay.io/keycloak/keycloak:25.0.6") | grep "Keycloak 25.0.6 on JVM (powered by Quarkus 3.8.5) started"; do sleep 1; done
|
||||||
|
|
||||||
|
- name: Verify Keycloak Import
|
||||||
|
run: |
|
||||||
|
docker logs $(docker ps -q --filter "ancestor=quay.io/keycloak/keycloak:25.0.6") | grep "Imported"
|
||||||
|
|
||||||
|
- name: Run tests
|
||||||
|
env:
|
||||||
|
MINT_BACKEND_BOLT11_SAT: FakeWallet
|
||||||
|
WALLET_NAME: test_wallet
|
||||||
|
MINT_HOST: localhost
|
||||||
|
MINT_PORT: 3337
|
||||||
|
MINT_TEST_DATABASE: ${{ inputs.mint-database }}
|
||||||
|
TOR: false
|
||||||
|
MINT_REQUIRE_AUTH: TRUE
|
||||||
|
MINT_AUTH_OICD_DISCOVERY_URL: http://localhost:8080/realms/nutshell/.well-known/openid-configuration
|
||||||
|
MINT_AUTH_OICD_CLIENT_ID: cashu-client
|
||||||
|
run: |
|
||||||
|
poetry run pytest tests/test_wallet_auth.py -v --cov=mint --cov-report=xml
|
||||||
|
|
||||||
|
- name: Stop and clean up Docker Compose
|
||||||
|
run: |
|
||||||
|
docker compose -f tests/keycloak_data/docker-compose-restore.yml down
|
||||||
|
|
||||||
|
- name: Upload coverage to Codecov
|
||||||
|
uses: codecov/codecov-action@v3
|
||||||
@@ -6,7 +6,7 @@ from abc import ABC, abstractmethod
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from sqlite3 import Row
|
from sqlite3 import Row
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, ClassVar, Dict, List, Optional, Union
|
||||||
|
|
||||||
import cbor2
|
import cbor2
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
@@ -19,7 +19,7 @@ from .crypto.aes import AESCipher
|
|||||||
from .crypto.b_dhke import hash_to_curve
|
from .crypto.b_dhke import hash_to_curve
|
||||||
from .crypto.keys import (
|
from .crypto.keys import (
|
||||||
derive_keys,
|
derive_keys,
|
||||||
derive_keys_sha256,
|
derive_keys_deprecated_pre_0_15,
|
||||||
derive_keyset_id,
|
derive_keyset_id,
|
||||||
derive_keyset_id_deprecated,
|
derive_keyset_id_deprecated,
|
||||||
derive_pubkeys,
|
derive_pubkeys,
|
||||||
@@ -173,6 +173,9 @@ class Proof(BaseModel):
|
|||||||
|
|
||||||
return return_dict
|
return return_dict
|
||||||
|
|
||||||
|
def to_base64(self):
|
||||||
|
return base64.b64encode(cbor2.dumps(self.to_dict(include_dleq=True))).decode()
|
||||||
|
|
||||||
def to_dict_no_dleq(self):
|
def to_dict_no_dleq(self):
|
||||||
# dictionary without the fields that don't need to be send to Carol
|
# dictionary without the fields that don't need to be send to Carol
|
||||||
return dict(id=self.id, amount=self.amount, secret=self.secret, C=self.C)
|
return dict(id=self.id, amount=self.amount, secret=self.secret, C=self.C)
|
||||||
@@ -541,6 +544,7 @@ class Unit(Enum):
|
|||||||
usd = 2
|
usd = 2
|
||||||
eur = 3
|
eur = 3
|
||||||
btc = 4
|
btc = 4
|
||||||
|
auth = 999
|
||||||
|
|
||||||
def str(self, amount: int) -> str:
|
def str(self, amount: int) -> str:
|
||||||
if self == Unit.sat:
|
if self == Unit.sat:
|
||||||
@@ -553,6 +557,8 @@ class Unit(Enum):
|
|||||||
return f"{amount/100:.2f} EUR"
|
return f"{amount/100:.2f} EUR"
|
||||||
elif self == Unit.btc:
|
elif self == Unit.btc:
|
||||||
return f"{amount/1e8:.8f} BTC"
|
return f"{amount/1e8:.8f} BTC"
|
||||||
|
elif self == Unit.auth:
|
||||||
|
return f"{amount} AUTH"
|
||||||
else:
|
else:
|
||||||
raise Exception("Invalid unit")
|
raise Exception("Invalid unit")
|
||||||
|
|
||||||
@@ -724,6 +730,7 @@ class MintKeyset:
|
|||||||
valid_to: Optional[str] = None
|
valid_to: Optional[str] = None
|
||||||
first_seen: Optional[str] = None
|
first_seen: Optional[str] = None
|
||||||
version: Optional[str] = None
|
version: Optional[str] = None
|
||||||
|
amounts: List[int]
|
||||||
|
|
||||||
duplicate_keyset_id: Optional[str] = None # BACKWARDS COMPATIBILITY < 0.15.0
|
duplicate_keyset_id: Optional[str] = None # BACKWARDS COMPATIBILITY < 0.15.0
|
||||||
|
|
||||||
@@ -734,6 +741,7 @@ class MintKeyset:
|
|||||||
seed: Optional[str] = None,
|
seed: Optional[str] = None,
|
||||||
encrypted_seed: Optional[str] = None,
|
encrypted_seed: Optional[str] = None,
|
||||||
seed_encryption_method: Optional[str] = None,
|
seed_encryption_method: Optional[str] = None,
|
||||||
|
amounts: Optional[List[int]] = None,
|
||||||
valid_from: Optional[str] = None,
|
valid_from: Optional[str] = None,
|
||||||
valid_to: Optional[str] = None,
|
valid_to: Optional[str] = None,
|
||||||
first_seen: Optional[str] = None,
|
first_seen: Optional[str] = None,
|
||||||
@@ -762,6 +770,12 @@ class MintKeyset:
|
|||||||
|
|
||||||
assert self.seed, "seed not set"
|
assert self.seed, "seed not set"
|
||||||
|
|
||||||
|
if amounts:
|
||||||
|
self.amounts = amounts
|
||||||
|
else:
|
||||||
|
# use 2^n amounts by default
|
||||||
|
self.amounts = [2**i for i in range(settings.max_order)]
|
||||||
|
|
||||||
self.id = id
|
self.id = id
|
||||||
self.valid_from = valid_from
|
self.valid_from = valid_from
|
||||||
self.valid_to = valid_to
|
self.valid_to = valid_to
|
||||||
@@ -805,6 +819,24 @@ class MintKeyset:
|
|||||||
|
|
||||||
logger.trace(f"Loaded keyset id: {self.id} ({self.unit.name})")
|
logger.trace(f"Loaded keyset id: {self.id} ({self.unit.name})")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_row(cls, row: Row):
|
||||||
|
return cls(
|
||||||
|
id=row["id"],
|
||||||
|
derivation_path=row["derivation_path"],
|
||||||
|
seed=row["seed"],
|
||||||
|
encrypted_seed=row["encrypted_seed"],
|
||||||
|
seed_encryption_method=row["seed_encryption_method"],
|
||||||
|
valid_from=row["valid_from"],
|
||||||
|
valid_to=row["valid_to"],
|
||||||
|
first_seen=row["first_seen"],
|
||||||
|
active=row["active"],
|
||||||
|
unit=row["unit"],
|
||||||
|
version=row["version"],
|
||||||
|
input_fee_ppk=row["input_fee_ppk"],
|
||||||
|
amounts=json.loads(row["amounts"]),
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def public_keys_hex(self) -> Dict[int, str]:
|
def public_keys_hex(self) -> Dict[int, str]:
|
||||||
assert self.public_keys, "public keys not set"
|
assert self.public_keys, "public keys not set"
|
||||||
@@ -830,23 +862,27 @@ class MintKeyset:
|
|||||||
self.private_keys = derive_keys_backwards_compatible_insecure_pre_0_12(
|
self.private_keys = derive_keys_backwards_compatible_insecure_pre_0_12(
|
||||||
self.seed, self.derivation_path
|
self.seed, self.derivation_path
|
||||||
)
|
)
|
||||||
self.public_keys = derive_pubkeys(self.private_keys) # type: ignore
|
self.public_keys = derive_pubkeys(self.private_keys, self.amounts) # type: ignore
|
||||||
logger.trace(
|
logger.trace(
|
||||||
f"WARNING: Using weak key derivation for keyset {self.id} (backwards"
|
f"WARNING: Using weak key derivation for keyset {self.id} (backwards"
|
||||||
" compatibility < 0.12)"
|
" compatibility < 0.12)"
|
||||||
)
|
)
|
||||||
self.id = id_in_db or derive_keyset_id_deprecated(self.public_keys) # type: ignore
|
self.id = id_in_db or derive_keyset_id_deprecated(self.public_keys) # type: ignore
|
||||||
elif self.version_tuple < (0, 15):
|
elif self.version_tuple < (0, 15):
|
||||||
self.private_keys = derive_keys_sha256(self.seed, self.derivation_path)
|
self.private_keys = derive_keys_deprecated_pre_0_15(
|
||||||
|
self.seed, self.amounts, self.derivation_path
|
||||||
|
)
|
||||||
logger.trace(
|
logger.trace(
|
||||||
f"WARNING: Using non-bip32 derivation for keyset {self.id} (backwards"
|
f"WARNING: Using non-bip32 derivation for keyset {self.id} (backwards"
|
||||||
" compatibility < 0.15)"
|
" compatibility < 0.15)"
|
||||||
)
|
)
|
||||||
self.public_keys = derive_pubkeys(self.private_keys) # type: ignore
|
self.public_keys = derive_pubkeys(self.private_keys, self.amounts) # type: ignore
|
||||||
self.id = id_in_db or derive_keyset_id_deprecated(self.public_keys) # type: ignore
|
self.id = id_in_db or derive_keyset_id_deprecated(self.public_keys) # type: ignore
|
||||||
else:
|
else:
|
||||||
self.private_keys = derive_keys(self.seed, self.derivation_path)
|
self.private_keys = derive_keys(
|
||||||
self.public_keys = derive_pubkeys(self.private_keys) # type: ignore
|
self.seed, self.derivation_path, self.amounts
|
||||||
|
)
|
||||||
|
self.public_keys = derive_pubkeys(self.private_keys, self.amounts) # type: ignore
|
||||||
self.id = id_in_db or derive_keyset_id(self.public_keys) # type: ignore
|
self.id = id_in_db or derive_keyset_id(self.public_keys) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
@@ -1254,3 +1290,48 @@ class TokenV4(Token):
|
|||||||
t=[TokenV4Token(**t) for t in token_dict["t"]],
|
t=[TokenV4Token(**t) for t in token_dict["t"]],
|
||||||
d=token_dict.get("d", None),
|
d=token_dict.get("d", None),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AuthProof(BaseModel):
|
||||||
|
"""
|
||||||
|
Blind authentication token
|
||||||
|
"""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
secret: str # secret
|
||||||
|
C: str # signature
|
||||||
|
amount: int = 1 # default amount
|
||||||
|
|
||||||
|
prefix: ClassVar[str] = "authA"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_proof(cls, proof: Proof):
|
||||||
|
return cls(id=proof.id, secret=proof.secret, C=proof.C)
|
||||||
|
|
||||||
|
def to_base64(self):
|
||||||
|
serialize_dict = self.dict()
|
||||||
|
serialize_dict.pop("amount", None)
|
||||||
|
return (
|
||||||
|
self.prefix + base64.b64encode(json.dumps(serialize_dict).encode()).decode()
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_base64(cls, base64_str: str):
|
||||||
|
assert base64_str.startswith(cls.prefix), Exception(
|
||||||
|
f"Token prefix not valid. Expected {cls.prefix}."
|
||||||
|
)
|
||||||
|
base64_str = base64_str[len(cls.prefix) :]
|
||||||
|
return cls.parse_obj(json.loads(base64.b64decode(base64_str).decode()))
|
||||||
|
|
||||||
|
def to_proof(self):
|
||||||
|
return Proof(id=self.id, secret=self.secret, C=self.C, amount=self.amount)
|
||||||
|
|
||||||
|
|
||||||
|
class WalletMint(BaseModel):
|
||||||
|
url: str
|
||||||
|
info: str
|
||||||
|
updated: Optional[str] = None
|
||||||
|
access_token: Optional[str] = None
|
||||||
|
refresh_token: Optional[str] = None
|
||||||
|
username: Optional[str] = None
|
||||||
|
password: Optional[str] = None
|
||||||
|
|||||||
@@ -1,54 +1,56 @@
|
|||||||
import base64
|
import base64
|
||||||
import hashlib
|
import hashlib
|
||||||
import random
|
import random
|
||||||
from typing import Dict
|
from typing import Dict, List
|
||||||
|
|
||||||
from bip32 import BIP32
|
from bip32 import BIP32
|
||||||
|
|
||||||
from ..settings import settings
|
|
||||||
from .secp import PrivateKey, PublicKey
|
from .secp import PrivateKey, PublicKey
|
||||||
|
|
||||||
|
|
||||||
def derive_keys(mnemonic: str, derivation_path: str):
|
def derive_keys(mnemonic: str, derivation_path: str, amounts: List[int]):
|
||||||
"""
|
"""
|
||||||
Deterministic derivation of keys for 2^n values.
|
Deterministic derivation of keys for 2^n values.
|
||||||
"""
|
"""
|
||||||
bip32 = BIP32.from_seed(mnemonic.encode())
|
bip32 = BIP32.from_seed(mnemonic.encode())
|
||||||
orders_str = [f"/{i}'" for i in range(settings.max_order)]
|
orders_str = [f"/{a}'" for a in range(len(amounts))]
|
||||||
return {
|
return {
|
||||||
2**i: PrivateKey(
|
a: PrivateKey(
|
||||||
bip32.get_privkey_from_path(derivation_path + orders_str[i]),
|
bip32.get_privkey_from_path(derivation_path + orders_str[i]),
|
||||||
raw=True,
|
raw=True,
|
||||||
)
|
)
|
||||||
for i in range(settings.max_order)
|
for i, a in enumerate(amounts)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def derive_keys_sha256(seed: str, derivation_path: str = ""):
|
def derive_keys_deprecated_pre_0_15(
|
||||||
|
seed: str, amounts: List[int], derivation_path: str = ""
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Deterministic derivation of keys for 2^n values.
|
Deterministic derivation of keys for 2^n values.
|
||||||
TODO: Implement BIP32.
|
|
||||||
"""
|
"""
|
||||||
return {
|
return {
|
||||||
2**i: PrivateKey(
|
a: PrivateKey(
|
||||||
hashlib.sha256((seed + derivation_path + str(i)).encode("utf-8")).digest()[
|
hashlib.sha256((seed + derivation_path + str(i)).encode("utf-8")).digest()[
|
||||||
:32
|
:32
|
||||||
],
|
],
|
||||||
raw=True,
|
raw=True,
|
||||||
)
|
)
|
||||||
for i in range(settings.max_order)
|
for i, a in enumerate(amounts)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def derive_pubkey(seed: str):
|
def derive_pubkey(seed: str) -> PublicKey:
|
||||||
return PrivateKey(
|
pubkey = PrivateKey(
|
||||||
hashlib.sha256((seed).encode("utf-8")).digest()[:32],
|
hashlib.sha256((seed).encode("utf-8")).digest()[:32],
|
||||||
raw=True,
|
raw=True,
|
||||||
).pubkey
|
).pubkey
|
||||||
|
assert pubkey
|
||||||
|
return pubkey
|
||||||
|
|
||||||
|
|
||||||
def derive_pubkeys(keys: Dict[int, PrivateKey]):
|
def derive_pubkeys(keys: Dict[int, PrivateKey], amounts: List[int]):
|
||||||
return {amt: keys[amt].pubkey for amt in [2**i for i in range(settings.max_order)]}
|
return {amt: keys[amt].pubkey for amt in amounts}
|
||||||
|
|
||||||
|
|
||||||
def derive_keyset_id(keys: Dict[int, PublicKey]):
|
def derive_keyset_id(keys: Dict[int, PublicKey]):
|
||||||
|
|||||||
@@ -339,11 +339,21 @@ class Database(Compat):
|
|||||||
raise Exception("Timestamp is None")
|
raise Exception("Timestamp is None")
|
||||||
return timestamp
|
return timestamp
|
||||||
|
|
||||||
def to_timestamp(self, timestamp_str: str) -> Union[str, datetime.datetime]:
|
def to_timestamp(
|
||||||
if not timestamp_str:
|
self, timestamp: Union[str, datetime.datetime]
|
||||||
timestamp_str = self.timestamp_now_str()
|
) -> Union[str, datetime.datetime]:
|
||||||
|
if not timestamp:
|
||||||
|
timestamp = self.timestamp_now_str()
|
||||||
if self.type in {POSTGRES, COCKROACH}:
|
if self.type in {POSTGRES, COCKROACH}:
|
||||||
return datetime.datetime.strptime(timestamp_str, "%Y-%m-%d %H:%M:%S")
|
# return datetime.datetime
|
||||||
|
if isinstance(timestamp, datetime.datetime):
|
||||||
|
return timestamp
|
||||||
|
elif isinstance(timestamp, str):
|
||||||
|
return datetime.datetime.strptime(timestamp, "%Y-%m-%d %H:%M:%S")
|
||||||
elif self.type == SQLITE:
|
elif self.type == SQLITE:
|
||||||
return timestamp_str
|
# return str
|
||||||
|
if isinstance(timestamp, datetime.datetime):
|
||||||
|
return timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
elif isinstance(timestamp, str):
|
||||||
|
return timestamp
|
||||||
return "<nothing>"
|
return "<nothing>"
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ class NotAllowedError(CashuError):
|
|||||||
def __init__(self, detail: Optional[str] = None, code: Optional[int] = None):
|
def __init__(self, detail: Optional[str] = None, code: Optional[int] = None):
|
||||||
super().__init__(detail or self.detail, code=code or self.code)
|
super().__init__(detail or self.detail, code=code or self.code)
|
||||||
|
|
||||||
|
|
||||||
class OutputsAlreadySignedError(CashuError):
|
class OutputsAlreadySignedError(CashuError):
|
||||||
detail = "outputs have already been signed before."
|
detail = "outputs have already been signed before."
|
||||||
code = 10002
|
code = 10002
|
||||||
@@ -25,6 +26,7 @@ class OutputsAlreadySignedError(CashuError):
|
|||||||
def __init__(self, detail: Optional[str] = None, code: Optional[int] = None):
|
def __init__(self, detail: Optional[str] = None, code: Optional[int] = None):
|
||||||
super().__init__(detail or self.detail, code=code or self.code)
|
super().__init__(detail or self.detail, code=code or self.code)
|
||||||
|
|
||||||
|
|
||||||
class InvalidProofsError(CashuError):
|
class InvalidProofsError(CashuError):
|
||||||
detail = "proofs could not be verified"
|
detail = "proofs could not be verified"
|
||||||
code = 10003
|
code = 10003
|
||||||
@@ -32,6 +34,7 @@ class InvalidProofsError(CashuError):
|
|||||||
def __init__(self, detail: Optional[str] = None, code: Optional[int] = None):
|
def __init__(self, detail: Optional[str] = None, code: Optional[int] = None):
|
||||||
super().__init__(detail or self.detail, code=code or self.code)
|
super().__init__(detail or self.detail, code=code or self.code)
|
||||||
|
|
||||||
|
|
||||||
class TransactionError(CashuError):
|
class TransactionError(CashuError):
|
||||||
detail = "transaction error"
|
detail = "transaction error"
|
||||||
code = 11000
|
code = 11000
|
||||||
@@ -76,12 +79,14 @@ class TransactionUnitError(TransactionError):
|
|||||||
def __init__(self, detail):
|
def __init__(self, detail):
|
||||||
super().__init__(detail, code=self.code)
|
super().__init__(detail, code=self.code)
|
||||||
|
|
||||||
|
|
||||||
class TransactionAmountExceedsLimitError(TransactionError):
|
class TransactionAmountExceedsLimitError(TransactionError):
|
||||||
code = 11006
|
code = 11006
|
||||||
|
|
||||||
def __init__(self, detail):
|
def __init__(self, detail):
|
||||||
super().__init__(detail, code=self.code)
|
super().__init__(detail, code=self.code)
|
||||||
|
|
||||||
|
|
||||||
class KeysetError(CashuError):
|
class KeysetError(CashuError):
|
||||||
detail = "keyset error"
|
detail = "keyset error"
|
||||||
code = 12000
|
code = 12000
|
||||||
@@ -113,7 +118,7 @@ class QuoteNotPaidError(CashuError):
|
|||||||
code = 20001
|
code = 20001
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(self.detail, code=2001)
|
super().__init__(self.detail, code=self.code)
|
||||||
|
|
||||||
|
|
||||||
class QuoteSignatureInvalidError(CashuError):
|
class QuoteSignatureInvalidError(CashuError):
|
||||||
@@ -121,7 +126,7 @@ class QuoteSignatureInvalidError(CashuError):
|
|||||||
code = 20008
|
code = 20008
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(self.detail, code=20008)
|
super().__init__(self.detail, code=self.code)
|
||||||
|
|
||||||
|
|
||||||
class QuoteRequiresPubkeyError(CashuError):
|
class QuoteRequiresPubkeyError(CashuError):
|
||||||
@@ -129,4 +134,52 @@ class QuoteRequiresPubkeyError(CashuError):
|
|||||||
code = 20009
|
code = 20009
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(self.detail, code=20009)
|
super().__init__(self.detail, code=self.code)
|
||||||
|
|
||||||
|
|
||||||
|
class ClearAuthRequiredError(CashuError):
|
||||||
|
detail = "Endpoint requires clear auth"
|
||||||
|
code = 80001
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(self.detail, code=self.code)
|
||||||
|
|
||||||
|
|
||||||
|
class ClearAuthFailedError(CashuError):
|
||||||
|
detail = "Clear authentication failed"
|
||||||
|
code = 80002
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(self.detail, code=self.code)
|
||||||
|
|
||||||
|
|
||||||
|
class BlindAuthRequiredError(CashuError):
|
||||||
|
detail = "Endpoint requires blind auth"
|
||||||
|
code = 81001
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(self.detail, code=self.code)
|
||||||
|
|
||||||
|
|
||||||
|
class BlindAuthFailedError(CashuError):
|
||||||
|
detail = "Blind authentication failed"
|
||||||
|
code = 81002
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(self.detail, code=self.code)
|
||||||
|
|
||||||
|
|
||||||
|
class BlindAuthAmountExceededError(CashuError):
|
||||||
|
detail = "Maximum blind auth amount exceeded"
|
||||||
|
code = 81003
|
||||||
|
|
||||||
|
def __init__(self, detail: Optional[str] = None):
|
||||||
|
super().__init__(detail or self.detail, code=self.code)
|
||||||
|
|
||||||
|
|
||||||
|
class BlindAuthRateLimitExceededError(CashuError):
|
||||||
|
detail = "Blind auth token mint rate limit exceeded"
|
||||||
|
code = 81004
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(self.detail, code=self.code)
|
||||||
|
|||||||
125
cashu/core/mint_info.py
Normal file
125
cashu/core/mint_info.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
import json
|
||||||
|
import re
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from .base import Method, Unit
|
||||||
|
from .models import MintInfoContact, MintInfoProtectedEndpoint, Nut15MppSupport
|
||||||
|
from .nuts.nuts import BLIND_AUTH_NUT, CLEAR_AUTH_NUT, MPP_NUT, WEBSOCKETS_NUT
|
||||||
|
|
||||||
|
|
||||||
|
class MintInfo(BaseModel):
|
||||||
|
name: Optional[str]
|
||||||
|
pubkey: Optional[str]
|
||||||
|
version: Optional[str]
|
||||||
|
description: Optional[str]
|
||||||
|
description_long: Optional[str]
|
||||||
|
contact: Optional[List[MintInfoContact]]
|
||||||
|
motd: Optional[str]
|
||||||
|
icon_url: Optional[str]
|
||||||
|
time: Optional[int]
|
||||||
|
nuts: Dict[int, Any]
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"{self.name} ({self.description})"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_json_str(cls, json_str: str):
|
||||||
|
return cls.parse_obj(json.loads(json_str))
|
||||||
|
|
||||||
|
def supports_nut(self, nut: int) -> bool:
|
||||||
|
if self.nuts is None:
|
||||||
|
return False
|
||||||
|
return nut in self.nuts
|
||||||
|
|
||||||
|
def supports_mpp(self, method: str, unit: Unit) -> bool:
|
||||||
|
if not self.nuts:
|
||||||
|
return False
|
||||||
|
nut_15 = self.nuts.get(MPP_NUT)
|
||||||
|
if not nut_15 or not self.supports_nut(MPP_NUT) or not nut_15.get("methods"):
|
||||||
|
return False
|
||||||
|
|
||||||
|
for entry in nut_15["methods"]:
|
||||||
|
entry_obj = Nut15MppSupport.parse_obj(entry)
|
||||||
|
if entry_obj.method == method and entry_obj.unit == unit.name:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def supports_websocket_mint_quote(self, method: Method, unit: Unit) -> bool:
|
||||||
|
if not self.nuts or not self.supports_nut(WEBSOCKETS_NUT):
|
||||||
|
return False
|
||||||
|
websocket_settings = self.nuts[WEBSOCKETS_NUT]
|
||||||
|
if not websocket_settings or "supported" not in websocket_settings:
|
||||||
|
return False
|
||||||
|
websocket_supported = websocket_settings["supported"]
|
||||||
|
for entry in websocket_supported:
|
||||||
|
if entry["method"] == method.name and entry["unit"] == unit.name:
|
||||||
|
if "bolt11_mint_quote" in entry["commands"]:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def requires_clear_auth(self) -> bool:
|
||||||
|
return self.supports_nut(CLEAR_AUTH_NUT)
|
||||||
|
|
||||||
|
def oidc_discovery_url(self) -> str:
|
||||||
|
if not self.requires_clear_auth():
|
||||||
|
raise Exception(
|
||||||
|
"Could not get OIDC discovery URL. Mint info does not support clear auth."
|
||||||
|
)
|
||||||
|
return self.nuts[CLEAR_AUTH_NUT]["openid_discovery"]
|
||||||
|
|
||||||
|
def oidc_client_id(self) -> str:
|
||||||
|
if not self.requires_clear_auth():
|
||||||
|
raise Exception(
|
||||||
|
"Could not get client_id. Mint info does not support clear auth."
|
||||||
|
)
|
||||||
|
return self.nuts[CLEAR_AUTH_NUT]["client_id"]
|
||||||
|
|
||||||
|
def required_clear_auth_endpoints(self) -> List[MintInfoProtectedEndpoint]:
|
||||||
|
if not self.requires_clear_auth():
|
||||||
|
return []
|
||||||
|
return [
|
||||||
|
MintInfoProtectedEndpoint.parse_obj(e)
|
||||||
|
for e in self.nuts[CLEAR_AUTH_NUT]["protected_endpoints"]
|
||||||
|
]
|
||||||
|
|
||||||
|
def requires_clear_auth_path(self, method: str, path: str) -> bool:
|
||||||
|
if not self.requires_clear_auth():
|
||||||
|
return False
|
||||||
|
path = "/" + path if not path.startswith("/") else path
|
||||||
|
for endpoint in self.required_clear_auth_endpoints():
|
||||||
|
if method == endpoint.method and re.match(endpoint.path, path):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def requires_blind_auth(self) -> bool:
|
||||||
|
return self.supports_nut(BLIND_AUTH_NUT)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def bat_max_mint(self) -> int:
|
||||||
|
if not self.requires_blind_auth():
|
||||||
|
raise Exception(
|
||||||
|
"Could not get max mint. Mint info does not support blind auth."
|
||||||
|
)
|
||||||
|
if not self.nuts[BLIND_AUTH_NUT].get("bat_max_mint"):
|
||||||
|
raise Exception("Could not get max mint. bat_max_mint not set.")
|
||||||
|
return self.nuts[BLIND_AUTH_NUT]["bat_max_mint"]
|
||||||
|
|
||||||
|
def required_blind_auth_paths(self) -> List[MintInfoProtectedEndpoint]:
|
||||||
|
if not self.requires_blind_auth():
|
||||||
|
return []
|
||||||
|
return [
|
||||||
|
MintInfoProtectedEndpoint.parse_obj(e)
|
||||||
|
for e in self.nuts[BLIND_AUTH_NUT]["protected_endpoints"]
|
||||||
|
]
|
||||||
|
|
||||||
|
def requires_blind_auth_path(self, method: str, path: str) -> bool:
|
||||||
|
if not self.requires_blind_auth():
|
||||||
|
return False
|
||||||
|
path = "/" + path if not path.startswith("/") else path
|
||||||
|
for endpoint in self.required_blind_auth_paths():
|
||||||
|
if method == endpoint.method and re.match(endpoint.path, path):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
@@ -38,6 +38,11 @@ class MintInfoContact(BaseModel):
|
|||||||
info: str
|
info: str
|
||||||
|
|
||||||
|
|
||||||
|
class MintInfoProtectedEndpoint(BaseModel):
|
||||||
|
method: str
|
||||||
|
path: str
|
||||||
|
|
||||||
|
|
||||||
class GetInfoResponse(BaseModel):
|
class GetInfoResponse(BaseModel):
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
pubkey: Optional[str] = None
|
pubkey: Optional[str] = None
|
||||||
@@ -57,7 +62,7 @@ class GetInfoResponse(BaseModel):
|
|||||||
# BEGIN DEPRECATED: NUT-06 contact field change
|
# BEGIN DEPRECATED: NUT-06 contact field change
|
||||||
# NUT-06 PR: https://github.com/cashubtc/nuts/pull/117
|
# NUT-06 PR: https://github.com/cashubtc/nuts/pull/117
|
||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
def preprocess_deprecated_contact_field(cls, values):
|
def preprocess_deprecated_contact_field(cls, values: dict):
|
||||||
if "contact" in values and values["contact"]:
|
if "contact" in values and values["contact"]:
|
||||||
if isinstance(values["contact"][0], list):
|
if isinstance(values["contact"][0], list):
|
||||||
values["contact"] = [
|
values["contact"] = [
|
||||||
@@ -346,3 +351,16 @@ class PostRestoreResponse(BaseModel):
|
|||||||
def __init__(self, **data):
|
def __init__(self, **data):
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
self.promises = self.signatures
|
self.promises = self.signatures
|
||||||
|
|
||||||
|
|
||||||
|
# ------- API: BLIND AUTH -------
|
||||||
|
class PostAuthBlindMintRequest(BaseModel):
|
||||||
|
outputs: List[BlindedMessage] = Field(
|
||||||
|
...,
|
||||||
|
max_items=settings.mint_max_request_length,
|
||||||
|
description="Blinded messages for creating blind auth tokens.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PostAuthBlindMintResponse(BaseModel):
|
||||||
|
signatures: List[BlindedSignature] = []
|
||||||
|
|||||||
@@ -14,3 +14,5 @@ MPP_NUT = 15
|
|||||||
WEBSOCKETS_NUT = 17
|
WEBSOCKETS_NUT = 17
|
||||||
CACHE_NUT = 19
|
CACHE_NUT = 19
|
||||||
MINT_QUOTE_SIGNATURE_NUT = 20
|
MINT_QUOTE_SIGNATURE_NUT = 20
|
||||||
|
CLEAR_AUTH_NUT = 21
|
||||||
|
BLIND_AUTH_NUT = 22
|
||||||
|
|||||||
@@ -68,6 +68,8 @@ class MintSettings(CashuSettings):
|
|||||||
class MintDeprecationFlags(MintSettings):
|
class MintDeprecationFlags(MintSettings):
|
||||||
mint_inactivate_base64_keysets: bool = Field(default=False)
|
mint_inactivate_base64_keysets: bool = Field(default=False)
|
||||||
|
|
||||||
|
auth_database: str = Field(default="data/mint")
|
||||||
|
|
||||||
|
|
||||||
class MintBackends(MintSettings):
|
class MintBackends(MintSettings):
|
||||||
mint_lightning_backend: str = Field(default="") # deprecated
|
mint_lightning_backend: str = Field(default="") # deprecated
|
||||||
@@ -231,6 +233,27 @@ class CoreLightningRestFundingSource(MintSettings):
|
|||||||
mint_corelightning_rest_cert: Optional[str] = Field(default=None)
|
mint_corelightning_rest_cert: Optional[str] = Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
|
class AuthSettings(MintSettings):
|
||||||
|
mint_require_auth: bool = Field(default=False)
|
||||||
|
mint_auth_oicd_discovery_url: Optional[str] = Field(default=None)
|
||||||
|
mint_auth_oicd_client_id: str = Field(default="cashu-client")
|
||||||
|
mint_auth_rate_limit_per_minute: int = Field(
|
||||||
|
default=5,
|
||||||
|
title="Auth rate limit per minute",
|
||||||
|
description="Number of requests a user can authenticate per minute.",
|
||||||
|
)
|
||||||
|
mint_auth_max_blind_tokens: int = Field(default=100, gt=0)
|
||||||
|
mint_require_clear_auth_paths: List[List[str]] = [
|
||||||
|
["POST", "/v1/auth/blind/mint"],
|
||||||
|
]
|
||||||
|
mint_require_blind_auth_paths: List[List[str]] = [
|
||||||
|
["POST", "/v1/swap"],
|
||||||
|
["POST", "/v1/mint/quote/bolt11"],
|
||||||
|
["POST", "/v1/mint/bolt11"],
|
||||||
|
["POST", "/v1/melt/bolt11"],
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class MintRedisCache(MintSettings):
|
class MintRedisCache(MintSettings):
|
||||||
mint_redis_cache_enabled: bool = Field(default=False)
|
mint_redis_cache_enabled: bool = Field(default=False)
|
||||||
mint_redis_cache_url: Optional[str] = Field(default=None)
|
mint_redis_cache_url: Optional[str] = Field(default=None)
|
||||||
@@ -246,6 +269,7 @@ class Settings(
|
|||||||
FakeWalletSettings,
|
FakeWalletSettings,
|
||||||
MintLimits,
|
MintLimits,
|
||||||
MintBackends,
|
MintBackends,
|
||||||
|
AuthSettings,
|
||||||
MintRedisCache,
|
MintRedisCache,
|
||||||
MintDeprecationFlags,
|
MintDeprecationFlags,
|
||||||
MintSettings,
|
MintSettings,
|
||||||
|
|||||||
@@ -13,10 +13,10 @@ from starlette.requests import Request
|
|||||||
from ..core.errors import CashuError
|
from ..core.errors import CashuError
|
||||||
from ..core.logging import configure_logger
|
from ..core.logging import configure_logger
|
||||||
from ..core.settings import settings
|
from ..core.settings import settings
|
||||||
|
from .auth.router import auth_router
|
||||||
from .router import redis, router
|
from .router import redis, router
|
||||||
from .router_deprecated import router_deprecated
|
from .router_deprecated import router_deprecated
|
||||||
from .startup import shutdown_mint as shutdown_mint_init
|
from .startup import shutdown_mint, start_auth, start_mint
|
||||||
from .startup import start_mint_init
|
|
||||||
|
|
||||||
if settings.debug_profiling:
|
if settings.debug_profiling:
|
||||||
pass
|
pass
|
||||||
@@ -29,7 +29,9 @@ from .middleware import add_middlewares, request_validation_exception_handler
|
|||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(_: FastAPI) -> AsyncIterator[None]:
|
async def lifespan(_: FastAPI) -> AsyncIterator[None]:
|
||||||
await start_mint_init()
|
await start_mint()
|
||||||
|
if settings.mint_require_auth:
|
||||||
|
await start_auth()
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
@@ -38,7 +40,7 @@ async def lifespan(_: FastAPI) -> AsyncIterator[None]:
|
|||||||
finally:
|
finally:
|
||||||
try:
|
try:
|
||||||
await redis.disconnect()
|
await redis.disconnect()
|
||||||
await shutdown_mint_init()
|
await shutdown_mint()
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.info("CancelledError during shutdown, shutting down forcefully")
|
logger.info("CancelledError during shutdown, shutting down forcefully")
|
||||||
|
|
||||||
@@ -110,3 +112,6 @@ if settings.debug_mint_only_deprecated:
|
|||||||
else:
|
else:
|
||||||
app.include_router(router=router, tags=["Mint"])
|
app.include_router(router=router, tags=["Mint"])
|
||||||
app.include_router(router=router_deprecated, tags=["Deprecated"], deprecated=True)
|
app.include_router(router=router_deprecated, tags=["Deprecated"], deprecated=True)
|
||||||
|
|
||||||
|
if settings.mint_require_auth:
|
||||||
|
app.include_router(auth_router, tags=["Auth"])
|
||||||
|
|||||||
13
cashu/mint/auth/base.py
Normal file
13
cashu/mint/auth/base.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class User:
|
||||||
|
id: str
|
||||||
|
last_access: Optional[datetime.datetime]
|
||||||
|
|
||||||
|
def __init__(self, id: str, last_access: Optional[datetime.datetime] = None):
|
||||||
|
self.id = id
|
||||||
|
if isinstance(last_access, int):
|
||||||
|
last_access = datetime.datetime.fromtimestamp(last_access)
|
||||||
|
self.last_access = last_access
|
||||||
669
cashu/mint/auth/crud.py
Normal file
669
cashu/mint/auth/crud.py
Normal file
@@ -0,0 +1,669 @@
|
|||||||
|
import json
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from ...core.base import (
|
||||||
|
BlindedSignature,
|
||||||
|
MeltQuote,
|
||||||
|
MintKeyset,
|
||||||
|
MintQuote,
|
||||||
|
Proof,
|
||||||
|
)
|
||||||
|
from ...core.db import (
|
||||||
|
Connection,
|
||||||
|
Database,
|
||||||
|
)
|
||||||
|
from .base import User
|
||||||
|
|
||||||
|
|
||||||
|
class AuthLedgerCrud(ABC):
|
||||||
|
"""
|
||||||
|
Database interface for Nutshell auth ledger.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def create_user(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
db: Database,
|
||||||
|
user: User,
|
||||||
|
conn: Optional[Connection] = None,
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_user(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
db: Database,
|
||||||
|
user_id: str,
|
||||||
|
conn: Optional[Connection] = None,
|
||||||
|
) -> Optional[User]: ...
|
||||||
|
|
||||||
|
async def update_user(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
db: Database,
|
||||||
|
user_id: str,
|
||||||
|
conn: Optional[Connection] = None,
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_keyset(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
db: Database,
|
||||||
|
id: str = "",
|
||||||
|
derivation_path: str = "",
|
||||||
|
seed: str = "",
|
||||||
|
conn: Optional[Connection] = None,
|
||||||
|
) -> List[MintKeyset]: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_proofs_used(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
Ys: List[str],
|
||||||
|
db: Database,
|
||||||
|
conn: Optional[Connection] = None,
|
||||||
|
) -> List[Proof]: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def invalidate_proof(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
db: Database,
|
||||||
|
proof: Proof,
|
||||||
|
quote_id: Optional[str] = None,
|
||||||
|
conn: Optional[Connection] = None,
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_proofs_pending(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
Ys: List[str],
|
||||||
|
db: Database,
|
||||||
|
conn: Optional[Connection] = None,
|
||||||
|
) -> List[Proof]: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def set_proof_pending(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
db: Database,
|
||||||
|
proof: Proof,
|
||||||
|
quote_id: Optional[str] = None,
|
||||||
|
conn: Optional[Connection] = None,
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def unset_proof_pending(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
proof: Proof,
|
||||||
|
db: Database,
|
||||||
|
conn: Optional[Connection] = None,
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def store_keyset(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
db: Database,
|
||||||
|
keyset: MintKeyset,
|
||||||
|
conn: Optional[Connection] = None,
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def store_promise(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
db: Database,
|
||||||
|
amount: int,
|
||||||
|
b_: str,
|
||||||
|
c_: str,
|
||||||
|
id: str,
|
||||||
|
e: str = "",
|
||||||
|
s: str = "",
|
||||||
|
conn: Optional[Connection] = None,
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_promise(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
db: Database,
|
||||||
|
b_: str,
|
||||||
|
conn: Optional[Connection] = None,
|
||||||
|
) -> Optional[BlindedSignature]: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_promises(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
db: Database,
|
||||||
|
b_s: List[str],
|
||||||
|
conn: Optional[Connection] = None,
|
||||||
|
) -> List[BlindedSignature]: ...
|
||||||
|
|
||||||
|
|
||||||
|
class AuthLedgerCrudSqlite(AuthLedgerCrud):
|
||||||
|
"""Implementation of AuthLedgerCrud for sqlite.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
AuthLedgerCrud (ABC): Abstract base class for AuthLedgerCrud.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def create_user(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
db: Database,
|
||||||
|
user: User,
|
||||||
|
conn: Optional[Connection] = None,
|
||||||
|
) -> None:
|
||||||
|
await (conn or db).execute(
|
||||||
|
f"""
|
||||||
|
INSERT INTO {db.table_with_schema('users')}
|
||||||
|
(id)
|
||||||
|
VALUES (:id)
|
||||||
|
""",
|
||||||
|
{"id": user.id},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_user(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
db: Database,
|
||||||
|
user_id: str,
|
||||||
|
conn: Optional[Connection] = None,
|
||||||
|
) -> Optional[User]:
|
||||||
|
row = await (conn or db).fetchone(
|
||||||
|
f"""
|
||||||
|
SELECT * from {db.table_with_schema('users')}
|
||||||
|
WHERE id = :user_id
|
||||||
|
""",
|
||||||
|
{"user_id": user_id},
|
||||||
|
)
|
||||||
|
return User(**row) if row else None
|
||||||
|
|
||||||
|
async def update_user(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
db: Database,
|
||||||
|
user_id: str,
|
||||||
|
conn: Optional[Connection] = None,
|
||||||
|
) -> None:
|
||||||
|
await (conn or db).execute(
|
||||||
|
f"""
|
||||||
|
UPDATE {db.table_with_schema('users')}
|
||||||
|
SET last_access = :last_access
|
||||||
|
WHERE id = :user_id
|
||||||
|
""",
|
||||||
|
{
|
||||||
|
"last_access": db.to_timestamp(db.timestamp_now_str()),
|
||||||
|
"user_id": user_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def store_promise(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
db: Database,
|
||||||
|
amount: int,
|
||||||
|
b_: str,
|
||||||
|
c_: str,
|
||||||
|
id: str,
|
||||||
|
e: str = "",
|
||||||
|
s: str = "",
|
||||||
|
conn: Optional[Connection] = None,
|
||||||
|
) -> None:
|
||||||
|
await (conn or db).execute(
|
||||||
|
f"""
|
||||||
|
INSERT INTO {db.table_with_schema('promises')}
|
||||||
|
(amount, b_, c_, dleq_e, dleq_s, id, created)
|
||||||
|
VALUES (:amount, :b_, :c_, :dleq_e, :dleq_s, :id, :created)
|
||||||
|
""",
|
||||||
|
{
|
||||||
|
"amount": amount,
|
||||||
|
"b_": b_,
|
||||||
|
"c_": c_,
|
||||||
|
"dleq_e": e,
|
||||||
|
"dleq_s": s,
|
||||||
|
"id": id,
|
||||||
|
"created": db.to_timestamp(db.timestamp_now_str()),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_promise(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
db: Database,
|
||||||
|
b_: str,
|
||||||
|
conn: Optional[Connection] = None,
|
||||||
|
) -> Optional[BlindedSignature]:
|
||||||
|
row = await (conn or db).fetchone(
|
||||||
|
f"""
|
||||||
|
SELECT * from {db.table_with_schema('promises')}
|
||||||
|
WHERE b_ = :b_
|
||||||
|
""",
|
||||||
|
{"b_": str(b_)},
|
||||||
|
)
|
||||||
|
return BlindedSignature.from_row(row) if row else None
|
||||||
|
|
||||||
|
async def get_promises(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
db: Database,
|
||||||
|
b_s: List[str],
|
||||||
|
conn: Optional[Connection] = None,
|
||||||
|
) -> List[BlindedSignature]:
|
||||||
|
rows = await (conn or db).fetchall(
|
||||||
|
f"""
|
||||||
|
SELECT * from {db.table_with_schema('promises')}
|
||||||
|
WHERE b_ IN ({','.join([':b_' + str(i) for i in range(len(b_s))])})
|
||||||
|
""",
|
||||||
|
{f"b_{i}": b_s[i] for i in range(len(b_s))},
|
||||||
|
)
|
||||||
|
return [BlindedSignature.from_row(r) for r in rows] if rows else []
|
||||||
|
|
||||||
|
async def invalidate_proof(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
db: Database,
|
||||||
|
proof: Proof,
|
||||||
|
quote_id: Optional[str] = None,
|
||||||
|
conn: Optional[Connection] = None,
|
||||||
|
) -> None:
|
||||||
|
await (conn or db).execute(
|
||||||
|
f"""
|
||||||
|
INSERT INTO {db.table_with_schema('proofs_used')}
|
||||||
|
(amount, c, secret, y, id, witness, created, melt_quote)
|
||||||
|
VALUES (:amount, :c, :secret, :y, :id, :witness, :created, :melt_quote)
|
||||||
|
""",
|
||||||
|
{
|
||||||
|
"amount": proof.amount,
|
||||||
|
"c": proof.C,
|
||||||
|
"secret": proof.secret,
|
||||||
|
"y": proof.Y,
|
||||||
|
"id": proof.id,
|
||||||
|
"witness": proof.witness,
|
||||||
|
"created": db.to_timestamp(db.timestamp_now_str()),
|
||||||
|
"melt_quote": quote_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_all_melt_quotes_from_pending_proofs(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
db: Database,
|
||||||
|
conn: Optional[Connection] = None,
|
||||||
|
) -> List[MeltQuote]:
|
||||||
|
rows = await (conn or db).fetchall(
|
||||||
|
f"""
|
||||||
|
SELECT * from {db.table_with_schema('melt_quotes')} WHERE quote in (SELECT DISTINCT melt_quote FROM {db.table_with_schema('proofs_pending')})
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
return [MeltQuote.from_row(r) for r in rows]
|
||||||
|
|
||||||
|
async def get_pending_proofs_for_quote(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
quote_id: str,
|
||||||
|
db: Database,
|
||||||
|
conn: Optional[Connection] = None,
|
||||||
|
) -> List[Proof]:
|
||||||
|
rows = await (conn or db).fetchall(
|
||||||
|
f"""
|
||||||
|
SELECT * from {db.table_with_schema('proofs_pending')}
|
||||||
|
WHERE melt_quote = :quote_id
|
||||||
|
""",
|
||||||
|
{"quote_id": quote_id},
|
||||||
|
)
|
||||||
|
return [Proof(**r) for r in rows]
|
||||||
|
|
||||||
|
async def get_proofs_pending(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
Ys: List[str],
|
||||||
|
db: Database,
|
||||||
|
conn: Optional[Connection] = None,
|
||||||
|
) -> List[Proof]:
|
||||||
|
query = f"""
|
||||||
|
SELECT * from {db.table_with_schema('proofs_pending')}
|
||||||
|
WHERE y IN ({','.join([':y_' + str(i) for i in range(len(Ys))])})
|
||||||
|
"""
|
||||||
|
values = {f"y_{i}": Ys[i] for i in range(len(Ys))}
|
||||||
|
rows = await (conn or db).fetchall(query, values)
|
||||||
|
return [Proof(**r) for r in rows]
|
||||||
|
|
||||||
|
async def set_proof_pending(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
db: Database,
|
||||||
|
proof: Proof,
|
||||||
|
quote_id: Optional[str] = None,
|
||||||
|
conn: Optional[Connection] = None,
|
||||||
|
) -> None:
|
||||||
|
await (conn or db).execute(
|
||||||
|
f"""
|
||||||
|
INSERT INTO {db.table_with_schema('proofs_pending')}
|
||||||
|
(amount, c, secret, y, id, witness, created, melt_quote)
|
||||||
|
VALUES (:amount, :c, :secret, :y, :id, :witness, :created, :melt_quote)
|
||||||
|
""",
|
||||||
|
{
|
||||||
|
"amount": proof.amount,
|
||||||
|
"c": proof.C,
|
||||||
|
"secret": proof.secret,
|
||||||
|
"y": proof.Y,
|
||||||
|
"id": proof.id,
|
||||||
|
"witness": proof.witness,
|
||||||
|
"created": db.to_timestamp(db.timestamp_now_str()),
|
||||||
|
"melt_quote": quote_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def unset_proof_pending(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
proof: Proof,
|
||||||
|
db: Database,
|
||||||
|
conn: Optional[Connection] = None,
|
||||||
|
) -> None:
|
||||||
|
await (conn or db).execute(
|
||||||
|
f"""
|
||||||
|
DELETE FROM {db.table_with_schema('proofs_pending')}
|
||||||
|
WHERE secret = :secret
|
||||||
|
""",
|
||||||
|
{"secret": proof.secret},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def store_mint_quote(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
quote: MintQuote,
|
||||||
|
db: Database,
|
||||||
|
conn: Optional[Connection] = None,
|
||||||
|
) -> None:
|
||||||
|
await (conn or db).execute(
|
||||||
|
f"""
|
||||||
|
INSERT INTO {db.table_with_schema('mint_quotes')}
|
||||||
|
(quote, method, request, checking_id, unit, amount, issued, paid, state, created_time, paid_time)
|
||||||
|
VALUES (:quote, :method, :request, :checking_id, :unit, :amount, :issued, :paid, :state, :created_time, :paid_time)
|
||||||
|
""",
|
||||||
|
{
|
||||||
|
"quote": quote.quote,
|
||||||
|
"method": quote.method,
|
||||||
|
"request": quote.request,
|
||||||
|
"checking_id": quote.checking_id,
|
||||||
|
"unit": quote.unit,
|
||||||
|
"amount": quote.amount,
|
||||||
|
"issued": quote.issued,
|
||||||
|
"paid": quote.paid,
|
||||||
|
"state": quote.state.name,
|
||||||
|
"created_time": db.to_timestamp(
|
||||||
|
db.timestamp_from_seconds(quote.created_time) or ""
|
||||||
|
),
|
||||||
|
"paid_time": db.to_timestamp(
|
||||||
|
db.timestamp_from_seconds(quote.paid_time) or ""
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_mint_quote(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
quote_id: Optional[str] = None,
|
||||||
|
checking_id: Optional[str] = None,
|
||||||
|
request: Optional[str] = None,
|
||||||
|
db: Database,
|
||||||
|
conn: Optional[Connection] = None,
|
||||||
|
) -> Optional[MintQuote]:
|
||||||
|
clauses = []
|
||||||
|
values: Dict[str, Any] = {}
|
||||||
|
if quote_id:
|
||||||
|
clauses.append("quote = :quote_id")
|
||||||
|
values["quote_id"] = quote_id
|
||||||
|
if checking_id:
|
||||||
|
clauses.append("checking_id = :checking_id")
|
||||||
|
values["checking_id"] = checking_id
|
||||||
|
if request:
|
||||||
|
clauses.append("request = :request")
|
||||||
|
values["request"] = request
|
||||||
|
if not any(clauses):
|
||||||
|
raise ValueError("No search criteria")
|
||||||
|
where = f"WHERE {' AND '.join(clauses)}"
|
||||||
|
row = await (conn or db).fetchone(
|
||||||
|
f"""
|
||||||
|
SELECT * from {db.table_with_schema('mint_quotes')}
|
||||||
|
{where}
|
||||||
|
""",
|
||||||
|
values,
|
||||||
|
)
|
||||||
|
if row is None:
|
||||||
|
return None
|
||||||
|
return MintQuote.from_row(row) if row else None
|
||||||
|
|
||||||
|
async def get_mint_quote_by_request(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
request: str,
|
||||||
|
db: Database,
|
||||||
|
conn: Optional[Connection] = None,
|
||||||
|
) -> Optional[MintQuote]:
|
||||||
|
row = await (conn or db).fetchone(
|
||||||
|
f"""
|
||||||
|
SELECT * from {db.table_with_schema('mint_quotes')}
|
||||||
|
WHERE request = :request
|
||||||
|
""",
|
||||||
|
{"request": request},
|
||||||
|
)
|
||||||
|
return MintQuote.from_row(row) if row else None
|
||||||
|
|
||||||
|
async def update_mint_quote(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
quote: MintQuote,
|
||||||
|
db: Database,
|
||||||
|
conn: Optional[Connection] = None,
|
||||||
|
) -> None:
|
||||||
|
await (conn or db).execute(
|
||||||
|
f"UPDATE {db.table_with_schema('mint_quotes')} SET issued = :issued, paid = :paid, state = :state, paid_time = :paid_time WHERE quote = :quote",
|
||||||
|
{
|
||||||
|
"issued": quote.issued,
|
||||||
|
"paid": quote.paid,
|
||||||
|
"state": quote.state.name,
|
||||||
|
"paid_time": db.to_timestamp(
|
||||||
|
db.timestamp_from_seconds(quote.paid_time) or ""
|
||||||
|
),
|
||||||
|
"quote": quote.quote,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def store_melt_quote(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
quote: MeltQuote,
|
||||||
|
db: Database,
|
||||||
|
conn: Optional[Connection] = None,
|
||||||
|
) -> None:
|
||||||
|
await (conn or db).execute(
|
||||||
|
f"""
|
||||||
|
INSERT INTO {db.table_with_schema('melt_quotes')}
|
||||||
|
(quote, method, request, checking_id, unit, amount, fee_reserve, paid, state, created_time, paid_time, fee_paid, proof, change, expiry)
|
||||||
|
VALUES (:quote, :method, :request, :checking_id, :unit, :amount, :fee_reserve, :paid, :state, :created_time, :paid_time, :fee_paid, :proof, :change, :expiry)
|
||||||
|
""",
|
||||||
|
{
|
||||||
|
"quote": quote.quote,
|
||||||
|
"method": quote.method,
|
||||||
|
"request": quote.request,
|
||||||
|
"checking_id": quote.checking_id,
|
||||||
|
"unit": quote.unit,
|
||||||
|
"amount": quote.amount,
|
||||||
|
"fee_reserve": quote.fee_reserve or 0,
|
||||||
|
"paid": quote.paid,
|
||||||
|
"state": quote.state.name,
|
||||||
|
"created_time": db.to_timestamp(
|
||||||
|
db.timestamp_from_seconds(quote.created_time) or ""
|
||||||
|
),
|
||||||
|
"paid_time": db.to_timestamp(
|
||||||
|
db.timestamp_from_seconds(quote.paid_time) or ""
|
||||||
|
),
|
||||||
|
"fee_paid": quote.fee_paid,
|
||||||
|
"proof": quote.payment_preimage,
|
||||||
|
"change": json.dumps(quote.change) if quote.change else None,
|
||||||
|
"expiry": db.to_timestamp(
|
||||||
|
db.timestamp_from_seconds(quote.expiry) or ""
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_melt_quote(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
quote_id: Optional[str] = None,
|
||||||
|
checking_id: Optional[str] = None,
|
||||||
|
request: Optional[str] = None,
|
||||||
|
db: Database,
|
||||||
|
conn: Optional[Connection] = None,
|
||||||
|
) -> Optional[MeltQuote]:
|
||||||
|
clauses = []
|
||||||
|
values: Dict[str, Any] = {}
|
||||||
|
if quote_id:
|
||||||
|
clauses.append("quote = :quote_id")
|
||||||
|
values["quote_id"] = quote_id
|
||||||
|
if checking_id:
|
||||||
|
clauses.append("checking_id = :checking_id")
|
||||||
|
values["checking_id"] = checking_id
|
||||||
|
if request:
|
||||||
|
clauses.append("request = :request")
|
||||||
|
values["request"] = request
|
||||||
|
if not any(clauses):
|
||||||
|
raise ValueError("No search criteria")
|
||||||
|
where = f"WHERE {' AND '.join(clauses)}"
|
||||||
|
row = await (conn or db).fetchone(
|
||||||
|
f"""
|
||||||
|
SELECT * from {db.table_with_schema('melt_quotes')}
|
||||||
|
{where}
|
||||||
|
""",
|
||||||
|
values,
|
||||||
|
)
|
||||||
|
if row is None:
|
||||||
|
return None
|
||||||
|
return MeltQuote.from_row(row) if row else None
|
||||||
|
|
||||||
|
async def update_melt_quote(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
quote: MeltQuote,
|
||||||
|
db: Database,
|
||||||
|
conn: Optional[Connection] = None,
|
||||||
|
) -> None:
|
||||||
|
await (conn or db).execute(
|
||||||
|
f"""
|
||||||
|
UPDATE {db.table_with_schema('melt_quotes')} SET paid = :paid, state = :state, fee_paid = :fee_paid, paid_time = :paid_time, proof = :proof, change = :change WHERE quote = :quote
|
||||||
|
""",
|
||||||
|
{
|
||||||
|
"paid": quote.paid,
|
||||||
|
"state": quote.state.name,
|
||||||
|
"fee_paid": quote.fee_paid,
|
||||||
|
"paid_time": db.to_timestamp(
|
||||||
|
db.timestamp_from_seconds(quote.paid_time) or ""
|
||||||
|
),
|
||||||
|
"proof": quote.payment_preimage,
|
||||||
|
"change": (
|
||||||
|
json.dumps([s.dict() for s in quote.change])
|
||||||
|
if quote.change
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
"quote": quote.quote,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def store_keyset(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
db: Database,
|
||||||
|
keyset: MintKeyset,
|
||||||
|
conn: Optional[Connection] = None,
|
||||||
|
) -> None:
|
||||||
|
await (conn or db).execute(
|
||||||
|
f"""
|
||||||
|
INSERT INTO {db.table_with_schema('keysets')}
|
||||||
|
(id, seed, encrypted_seed, seed_encryption_method, derivation_path, valid_from, valid_to, first_seen, active, version, unit, input_fee_ppk)
|
||||||
|
VALUES (:id, :seed, :encrypted_seed, :seed_encryption_method, :derivation_path, :valid_from, :valid_to, :first_seen, :active, :version, :unit, :input_fee_ppk)
|
||||||
|
""",
|
||||||
|
{
|
||||||
|
"id": keyset.id,
|
||||||
|
"seed": keyset.seed,
|
||||||
|
"encrypted_seed": keyset.encrypted_seed,
|
||||||
|
"seed_encryption_method": keyset.seed_encryption_method,
|
||||||
|
"derivation_path": keyset.derivation_path,
|
||||||
|
"valid_from": db.to_timestamp(
|
||||||
|
keyset.valid_from or db.timestamp_now_str()
|
||||||
|
),
|
||||||
|
"valid_to": db.to_timestamp(keyset.valid_to or db.timestamp_now_str()),
|
||||||
|
"first_seen": db.to_timestamp(
|
||||||
|
keyset.first_seen or db.timestamp_now_str()
|
||||||
|
),
|
||||||
|
"active": True,
|
||||||
|
"version": keyset.version,
|
||||||
|
"unit": keyset.unit.name,
|
||||||
|
"input_fee_ppk": keyset.input_fee_ppk,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_keyset(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
db: Database,
|
||||||
|
id: Optional[str] = None,
|
||||||
|
derivation_path: Optional[str] = None,
|
||||||
|
seed: Optional[str] = None,
|
||||||
|
unit: Optional[str] = None,
|
||||||
|
active: Optional[bool] = None,
|
||||||
|
conn: Optional[Connection] = None,
|
||||||
|
) -> List[MintKeyset]:
|
||||||
|
clauses = []
|
||||||
|
values: Dict = {}
|
||||||
|
if active is not None:
|
||||||
|
clauses.append("active = :active")
|
||||||
|
values["active"] = active
|
||||||
|
if id is not None:
|
||||||
|
clauses.append("id = :id")
|
||||||
|
values["id"] = id
|
||||||
|
if derivation_path is not None:
|
||||||
|
clauses.append("derivation_path = :derivation_path")
|
||||||
|
values["derivation_path"] = derivation_path
|
||||||
|
if seed is not None:
|
||||||
|
clauses.append("seed = :seed")
|
||||||
|
values["seed"] = seed
|
||||||
|
if unit is not None:
|
||||||
|
clauses.append("unit = :unit")
|
||||||
|
values["unit"] = unit
|
||||||
|
where = ""
|
||||||
|
if clauses:
|
||||||
|
where = f"WHERE {' AND '.join(clauses)}"
|
||||||
|
|
||||||
|
rows = await (conn or db).fetchall( # type: ignore
|
||||||
|
f"""
|
||||||
|
SELECT * from {db.table_with_schema('keysets')}
|
||||||
|
{where}
|
||||||
|
""",
|
||||||
|
values,
|
||||||
|
)
|
||||||
|
return [MintKeyset(**row) for row in rows]
|
||||||
|
|
||||||
|
async def get_proofs_used(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
Ys: List[str],
|
||||||
|
db: Database,
|
||||||
|
conn: Optional[Connection] = None,
|
||||||
|
) -> List[Proof]:
|
||||||
|
query = f"""
|
||||||
|
SELECT * from {db.table_with_schema('proofs_used')}
|
||||||
|
WHERE y IN ({','.join([':y_' + str(i) for i in range(len(Ys))])})
|
||||||
|
"""
|
||||||
|
values = {f"y_{i}": Ys[i] for i in range(len(Ys))}
|
||||||
|
rows = await (conn or db).fetchall(query, values)
|
||||||
|
return [Proof(**r) for r in rows] if rows else []
|
||||||
100
cashu/mint/auth/migrations.py
Normal file
100
cashu/mint/auth/migrations.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
from ...core.db import Connection, Database
|
||||||
|
|
||||||
|
|
||||||
|
async def m000_create_migrations_table(conn: Connection):
|
||||||
|
await conn.execute(
|
||||||
|
f"""
|
||||||
|
CREATE TABLE IF NOT EXISTS {conn.table_with_schema('dbversions')} (
|
||||||
|
db TEXT PRIMARY KEY,
|
||||||
|
version INT NOT NULL
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def m001_initial(db: Database):
|
||||||
|
async with db.connect() as conn:
|
||||||
|
await conn.execute(
|
||||||
|
f"""
|
||||||
|
CREATE TABLE IF NOT EXISTS {db.table_with_schema('users')} (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
last_access TIMESTAMP,
|
||||||
|
|
||||||
|
UNIQUE (id)
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
# columns: (id, seed, encrypted_seed, seed_encryption_method, derivation_path, valid_from, valid_to, first_seen, active, version, unit, input_fee_ppk)
|
||||||
|
await conn.execute(
|
||||||
|
f"""
|
||||||
|
CREATE TABLE IF NOT EXISTS {db.table_with_schema('keysets')} (
|
||||||
|
id TEXT NOT NULL,
|
||||||
|
seed TEXT NOT NULL,
|
||||||
|
encrypted_seed TEXT,
|
||||||
|
seed_encryption_method TEXT,
|
||||||
|
derivation_path TEXT NOT NULL,
|
||||||
|
valid_from TIMESTAMP NOT NULL DEFAULT {db.timestamp_now},
|
||||||
|
valid_to TIMESTAMP NOT NULL DEFAULT {db.timestamp_now},
|
||||||
|
first_seen TIMESTAMP NOT NULL DEFAULT {db.timestamp_now},
|
||||||
|
active BOOL DEFAULT TRUE,
|
||||||
|
version TEXT,
|
||||||
|
unit TEXT NOT NULL,
|
||||||
|
input_fee_ppk INT,
|
||||||
|
amounts TEXT,
|
||||||
|
|
||||||
|
UNIQUE (derivation_path)
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
await conn.execute(
|
||||||
|
f"""
|
||||||
|
CREATE TABLE IF NOT EXISTS {db.table_with_schema('promises')} (
|
||||||
|
id TEXT NOT NULL,
|
||||||
|
amount {db.big_int} NOT NULL,
|
||||||
|
b_ TEXT NOT NULL,
|
||||||
|
c_ TEXT NOT NULL,
|
||||||
|
dleq_e TEXT,
|
||||||
|
dleq_s TEXT,
|
||||||
|
created TIMESTAMP,
|
||||||
|
|
||||||
|
UNIQUE (b_)
|
||||||
|
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
await conn.execute(
|
||||||
|
f"""
|
||||||
|
CREATE TABLE IF NOT EXISTS {db.table_with_schema('proofs_used')} (
|
||||||
|
id TEXT NOT NULL,
|
||||||
|
amount {db.big_int} NOT NULL,
|
||||||
|
c TEXT NOT NULL,
|
||||||
|
secret TEXT NOT NULL,
|
||||||
|
y TEXT NOT NULL,
|
||||||
|
witness TEXT,
|
||||||
|
created TIMESTAMP,
|
||||||
|
melt_quote TEXT,
|
||||||
|
|
||||||
|
UNIQUE (secret)
|
||||||
|
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
await conn.execute(
|
||||||
|
f"""
|
||||||
|
CREATE TABLE IF NOT EXISTS {db.table_with_schema('proofs_pending')} (
|
||||||
|
id TEXT NOT NULL,
|
||||||
|
amount {db.big_int} NOT NULL,
|
||||||
|
c TEXT NOT NULL,
|
||||||
|
secret TEXT NOT NULL,
|
||||||
|
y TEXT NOT NULL,
|
||||||
|
witness TEXT,
|
||||||
|
created TIMESTAMP,
|
||||||
|
melt_quote TEXT,
|
||||||
|
|
||||||
|
UNIQUE (secret)
|
||||||
|
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
106
cashu/mint/auth/router.py
Normal file
106
cashu/mint/auth/router.py
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
from fastapi import APIRouter, Request
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from ...core.errors import KeysetNotFoundError
|
||||||
|
from ...core.models import (
|
||||||
|
KeysetsResponse,
|
||||||
|
KeysetsResponseKeyset,
|
||||||
|
KeysResponse,
|
||||||
|
KeysResponseKeyset,
|
||||||
|
PostAuthBlindMintRequest,
|
||||||
|
PostAuthBlindMintResponse,
|
||||||
|
)
|
||||||
|
from ...mint.startup import auth_ledger
|
||||||
|
|
||||||
|
auth_router: APIRouter = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@auth_router.get(
|
||||||
|
"/v1/auth/blind/keys",
|
||||||
|
name="Mint public keys",
|
||||||
|
summary="Get the public keys of the newest mint keyset",
|
||||||
|
response_description=(
|
||||||
|
"All supported token values their associated public keys for all active keysets"
|
||||||
|
),
|
||||||
|
response_model=KeysResponse,
|
||||||
|
)
|
||||||
|
async def keys():
|
||||||
|
"""This endpoint returns a dictionary of all supported token values of the mint and their associated public key."""
|
||||||
|
logger.trace("> GET /v1/auth/blind/keys")
|
||||||
|
keyset = auth_ledger.keyset
|
||||||
|
keyset_for_response = []
|
||||||
|
for keyset in auth_ledger.keysets.values():
|
||||||
|
if keyset.active:
|
||||||
|
keyset_for_response.append(
|
||||||
|
KeysResponseKeyset(
|
||||||
|
id=keyset.id,
|
||||||
|
unit=keyset.unit.name,
|
||||||
|
keys={k: v for k, v in keyset.public_keys_hex.items()},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return KeysResponse(keysets=keyset_for_response)
|
||||||
|
|
||||||
|
|
||||||
|
@auth_router.get(
|
||||||
|
"/v1/auth/blind/keys/{keyset_id}",
|
||||||
|
name="Keyset public keys",
|
||||||
|
summary="Public keys of a specific keyset",
|
||||||
|
response_description=(
|
||||||
|
"All supported token values of the mint and their associated"
|
||||||
|
" public key for a specific keyset."
|
||||||
|
),
|
||||||
|
response_model=KeysResponse,
|
||||||
|
)
|
||||||
|
async def keyset_keys(keyset_id: str) -> KeysResponse:
|
||||||
|
"""
|
||||||
|
Get the public keys of the mint from a specific keyset id.
|
||||||
|
"""
|
||||||
|
logger.trace(f"> GET /v1/auth/blind/keys/{keyset_id}")
|
||||||
|
|
||||||
|
keyset = auth_ledger.keysets.get(keyset_id)
|
||||||
|
if keyset is None:
|
||||||
|
raise KeysetNotFoundError(keyset_id)
|
||||||
|
|
||||||
|
keyset_for_response = KeysResponseKeyset(
|
||||||
|
id=keyset.id,
|
||||||
|
unit=keyset.unit.name,
|
||||||
|
keys={k: v for k, v in keyset.public_keys_hex.items()},
|
||||||
|
)
|
||||||
|
return KeysResponse(keysets=[keyset_for_response])
|
||||||
|
|
||||||
|
|
||||||
|
@auth_router.get(
|
||||||
|
"/v1/auth/blind/keysets",
|
||||||
|
name="Active keysets",
|
||||||
|
summary="Get all active keyset id of the mind",
|
||||||
|
response_model=KeysetsResponse,
|
||||||
|
response_description="A list of all active keyset ids of the mint.",
|
||||||
|
)
|
||||||
|
async def keysets() -> KeysetsResponse:
|
||||||
|
"""This endpoint returns a list of keysets that the mint currently supports and will accept tokens from."""
|
||||||
|
logger.trace("> GET /v1/auth/blind/keysets")
|
||||||
|
keysets = []
|
||||||
|
for id, keyset in auth_ledger.keysets.items():
|
||||||
|
keysets.append(
|
||||||
|
KeysetsResponseKeyset(
|
||||||
|
id=keyset.id,
|
||||||
|
unit=keyset.unit.name,
|
||||||
|
active=keyset.active,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return KeysetsResponse(keysets=keysets)
|
||||||
|
|
||||||
|
|
||||||
|
@auth_router.post(
|
||||||
|
"/v1/auth/blind/mint",
|
||||||
|
name="Mint blind auth tokens",
|
||||||
|
summary="Mint blind auth tokens for a user.",
|
||||||
|
response_model=PostAuthBlindMintResponse,
|
||||||
|
)
|
||||||
|
async def auth_blind_mint(
|
||||||
|
request_data: PostAuthBlindMintRequest, request: Request
|
||||||
|
) -> PostAuthBlindMintResponse:
|
||||||
|
signatures = await auth_ledger.mint_blind_auth(
|
||||||
|
outputs=request_data.outputs, user=request.state.user
|
||||||
|
)
|
||||||
|
return PostAuthBlindMintResponse(signatures=signatures)
|
||||||
235
cashu/mint/auth/server.py
Normal file
235
cashu/mint/auth/server.py
Normal file
@@ -0,0 +1,235 @@
|
|||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import jwt
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from ...core.base import AuthProof
|
||||||
|
from ...core.db import Database
|
||||||
|
from ...core.errors import (
|
||||||
|
BlindAuthAmountExceededError,
|
||||||
|
BlindAuthFailedError,
|
||||||
|
BlindAuthRateLimitExceededError,
|
||||||
|
ClearAuthFailedError,
|
||||||
|
)
|
||||||
|
from ...core.models import BlindedMessage, BlindedSignature
|
||||||
|
from ...core.settings import settings
|
||||||
|
from ..crud import LedgerCrudSqlite
|
||||||
|
from ..ledger import Ledger
|
||||||
|
from ..limit import assert_limit
|
||||||
|
from .base import User
|
||||||
|
from .crud import AuthLedgerCrud, AuthLedgerCrudSqlite
|
||||||
|
|
||||||
|
|
||||||
|
class AuthLedger(Ledger):
|
||||||
|
auth_crud: AuthLedgerCrud
|
||||||
|
jwks_url: str
|
||||||
|
jwks_client: jwt.PyJWKClient
|
||||||
|
issuer: str
|
||||||
|
oicd_discovery_json: dict
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
db: Database,
|
||||||
|
seed: str,
|
||||||
|
seed_decryption_key: Optional[str] = None,
|
||||||
|
derivation_path="",
|
||||||
|
amounts: Optional[List[int]] = None,
|
||||||
|
crud=LedgerCrudSqlite(),
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
db=db,
|
||||||
|
seed=seed,
|
||||||
|
backends=None,
|
||||||
|
seed_decryption_key=seed_decryption_key,
|
||||||
|
derivation_path=derivation_path,
|
||||||
|
crud=crud,
|
||||||
|
amounts=amounts,
|
||||||
|
)
|
||||||
|
self.oicd_discovery_url = settings.mint_auth_oicd_discovery_url or ""
|
||||||
|
|
||||||
|
async def init_auth(self):
|
||||||
|
if not self.oicd_discovery_url:
|
||||||
|
raise Exception("Missing OpenID Connect discovery URL.")
|
||||||
|
logger.info(f"Initializing OpenID Connect: {self.oicd_discovery_url}")
|
||||||
|
self.oicd_discovery_json = self._get_oicd_discovery_json()
|
||||||
|
self.jwks_url = self.oicd_discovery_json["jwks_uri"]
|
||||||
|
self.jwks_client = jwt.PyJWKClient(self.jwks_url)
|
||||||
|
logger.info(f"Getting JWKS from: {self.jwks_url}")
|
||||||
|
self.auth_crud = AuthLedgerCrudSqlite()
|
||||||
|
self.issuer: str = self.oicd_discovery_json["issuer"]
|
||||||
|
logger.info(f"Initialized OpenID Connect: {self.issuer}")
|
||||||
|
|
||||||
|
def _get_oicd_discovery_json(self) -> dict:
|
||||||
|
resp = httpx.get(self.oicd_discovery_url)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json()
|
||||||
|
|
||||||
|
def _verify_oicd_issuer(self, clear_auth_token: str) -> None:
|
||||||
|
"""Verify the issuer of the clear-auth token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
clear_auth_token (str): JWT token.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: Invalid issuer.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
decoded = jwt.decode(
|
||||||
|
clear_auth_token,
|
||||||
|
options={"verify_signature": False},
|
||||||
|
)
|
||||||
|
issuer = decoded["iss"]
|
||||||
|
if issuer != self.issuer:
|
||||||
|
raise Exception(f"Invalid issuer: {issuer}. Expected: {self.issuer}")
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def _verify_decode_jwt(self, clear_auth_token: str) -> Any:
|
||||||
|
"""Verify the clear-auth JWT token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
clear_auth_token (str): JWT token.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
jwt.ExpiredSignatureError: Token has expired.
|
||||||
|
jwt.InvalidSignatureError: Invalid signature.
|
||||||
|
jwt.InvalidTokenError: Invalid token.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any: Decoded JWT.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Use PyJWKClient to fetch the appropriate key based on the token's header
|
||||||
|
signing_key = self.jwks_client.get_signing_key_from_jwt(clear_auth_token)
|
||||||
|
decoded = jwt.decode(
|
||||||
|
clear_auth_token,
|
||||||
|
signing_key.key,
|
||||||
|
algorithms=["RS256", "ES256"],
|
||||||
|
options={"verify_aud": False},
|
||||||
|
issuer=self.issuer,
|
||||||
|
)
|
||||||
|
logger.trace(f"Decoded JWT: {decoded}")
|
||||||
|
except jwt.ExpiredSignatureError as e:
|
||||||
|
logger.error("Token has expired")
|
||||||
|
raise e
|
||||||
|
except jwt.InvalidSignatureError as e:
|
||||||
|
logger.error("Invalid signature")
|
||||||
|
raise e
|
||||||
|
except jwt.InvalidTokenError as e:
|
||||||
|
logger.error("Invalid token")
|
||||||
|
raise e
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
return decoded
|
||||||
|
|
||||||
|
async def _get_user(self, decoded_token: Any) -> User:
|
||||||
|
"""Get the user from the decoded token. If the user does not exist, create a new one.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
decoded_token (Any): decoded JWT from PyJWT.decode
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User: User object
|
||||||
|
"""
|
||||||
|
user_id = decoded_token["sub"]
|
||||||
|
user = await self.auth_crud.get_user(user_id=user_id, db=self.db)
|
||||||
|
if not user:
|
||||||
|
logger.info(f"Creating new user: {user_id}")
|
||||||
|
user = User(id=user_id)
|
||||||
|
await self.auth_crud.create_user(user=user, db=self.db)
|
||||||
|
return user
|
||||||
|
|
||||||
|
async def verify_clear_auth(self, clear_auth_token: str) -> User:
|
||||||
|
"""Verify the clear-auth JWT token and return the user.
|
||||||
|
|
||||||
|
Checks:
|
||||||
|
- Token not expired.
|
||||||
|
- Token signature valid.
|
||||||
|
- User exists.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
auth_token (str): JWT token.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User: Authenticated user.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self._verify_oicd_issuer(clear_auth_token)
|
||||||
|
decoded = self._verify_decode_jwt(clear_auth_token)
|
||||||
|
user = await self._get_user(decoded)
|
||||||
|
except Exception:
|
||||||
|
raise ClearAuthFailedError()
|
||||||
|
|
||||||
|
logger.info(f"User authenticated: {user.id}")
|
||||||
|
try:
|
||||||
|
assert_limit(user.id)
|
||||||
|
except Exception:
|
||||||
|
raise BlindAuthRateLimitExceededError()
|
||||||
|
|
||||||
|
return user
|
||||||
|
|
||||||
|
async def mint_blind_auth(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
outputs: List[BlindedMessage],
|
||||||
|
user: User,
|
||||||
|
) -> List[BlindedSignature]:
|
||||||
|
"""Mints auth tokens. Returns a list of promises.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
outputs (List[BlindedMessage]): Outputs to sign.
|
||||||
|
user (User): Authenticated user.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: Invalid auth.
|
||||||
|
Exception: Output verification failed.
|
||||||
|
Exception: Output quota exceeded.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[BlindedSignature]: List of blinded signatures.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if len(outputs) > settings.mint_auth_max_blind_tokens:
|
||||||
|
raise BlindAuthAmountExceededError(
|
||||||
|
f"Too many outputs. You can only mint {settings.mint_auth_max_blind_tokens} tokens."
|
||||||
|
)
|
||||||
|
|
||||||
|
await self._verify_outputs(outputs)
|
||||||
|
promises = await self._generate_promises(outputs)
|
||||||
|
|
||||||
|
# update last_access timestamp of the user
|
||||||
|
await self.auth_crud.update_user(user_id=user.id, db=self.db)
|
||||||
|
|
||||||
|
return promises
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def verify_blind_auth(self, blind_auth_token):
|
||||||
|
"""Wrapper context that puts blind auth tokens into pending list and
|
||||||
|
melts them if the wrapped call succeeds. If it fails, the blind auth
|
||||||
|
token is not invalidated.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
blind_auth_token (str): Blind auth token.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: Blind auth token validation failed.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
proof = AuthProof.from_base64(blind_auth_token).to_proof()
|
||||||
|
await self.verify_inputs_and_outputs(proofs=[proof])
|
||||||
|
await self.db_write._verify_spent_proofs_and_set_pending([proof])
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Blind auth error: {e}")
|
||||||
|
raise BlindAuthFailedError()
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
await self._invalidate_proofs(proofs=[proof])
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Blind auth error: {e}")
|
||||||
|
raise BlindAuthFailedError()
|
||||||
|
finally:
|
||||||
|
await self.db_write._unset_proofs_pending([proof])
|
||||||
@@ -642,8 +642,8 @@ class LedgerCrudSqlite(LedgerCrud):
|
|||||||
await (conn or db).execute(
|
await (conn or db).execute(
|
||||||
f"""
|
f"""
|
||||||
INSERT INTO {db.table_with_schema('keysets')}
|
INSERT INTO {db.table_with_schema('keysets')}
|
||||||
(id, seed, encrypted_seed, seed_encryption_method, derivation_path, valid_from, valid_to, first_seen, active, version, unit, input_fee_ppk)
|
(id, seed, encrypted_seed, seed_encryption_method, derivation_path, valid_from, valid_to, first_seen, active, version, unit, input_fee_ppk, amounts)
|
||||||
VALUES (:id, :seed, :encrypted_seed, :seed_encryption_method, :derivation_path, :valid_from, :valid_to, :first_seen, :active, :version, :unit, :input_fee_ppk)
|
VALUES (:id, :seed, :encrypted_seed, :seed_encryption_method, :derivation_path, :valid_from, :valid_to, :first_seen, :active, :version, :unit, :input_fee_ppk, :amounts)
|
||||||
""",
|
""",
|
||||||
{
|
{
|
||||||
"id": keyset.id,
|
"id": keyset.id,
|
||||||
@@ -662,6 +662,7 @@ class LedgerCrudSqlite(LedgerCrud):
|
|||||||
"version": keyset.version,
|
"version": keyset.version,
|
||||||
"unit": keyset.unit.name,
|
"unit": keyset.unit.name,
|
||||||
"input_fee_ppk": keyset.input_fee_ppk,
|
"input_fee_ppk": keyset.input_fee_ppk,
|
||||||
|
"amounts": json.dumps(keyset.amounts),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -720,7 +721,7 @@ class LedgerCrudSqlite(LedgerCrud):
|
|||||||
""",
|
""",
|
||||||
values,
|
values,
|
||||||
)
|
)
|
||||||
return [MintKeyset(**row) for row in rows]
|
return [MintKeyset.from_row(row) for row in rows] # type: ignore
|
||||||
|
|
||||||
async def update_keyset(
|
async def update_keyset(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -1,12 +1,17 @@
|
|||||||
from typing import Any, Dict, List, Union
|
from typing import Any, Dict, List, Union
|
||||||
|
|
||||||
from ..core.base import Method
|
from ..core.base import Method
|
||||||
|
from ..core.mint_info import MintInfo
|
||||||
from ..core.models import (
|
from ..core.models import (
|
||||||
MeltMethodSetting,
|
MeltMethodSetting,
|
||||||
|
MintInfoContact,
|
||||||
|
MintInfoProtectedEndpoint,
|
||||||
MintMethodSetting,
|
MintMethodSetting,
|
||||||
)
|
)
|
||||||
from ..core.nuts.nuts import (
|
from ..core.nuts.nuts import (
|
||||||
|
BLIND_AUTH_NUT,
|
||||||
CACHE_NUT,
|
CACHE_NUT,
|
||||||
|
CLEAR_AUTH_NUT,
|
||||||
DLEQ_NUT,
|
DLEQ_NUT,
|
||||||
FEE_RETURN_NUT,
|
FEE_RETURN_NUT,
|
||||||
HTLC_NUT,
|
HTLC_NUT,
|
||||||
@@ -21,10 +26,46 @@ from ..core.nuts.nuts import (
|
|||||||
WEBSOCKETS_NUT,
|
WEBSOCKETS_NUT,
|
||||||
)
|
)
|
||||||
from ..core.settings import settings
|
from ..core.settings import settings
|
||||||
from ..mint.protocols import SupportsBackends
|
from ..mint.protocols import SupportsBackends, SupportsPubkey
|
||||||
|
|
||||||
|
_VERSION_PREFIX = "Nutshell"
|
||||||
|
_SUPPORTED = "supported"
|
||||||
|
_METHOD = "method"
|
||||||
|
_UNIT = "unit"
|
||||||
|
_BOLT11 = "bolt11"
|
||||||
|
_MPP = "mpp"
|
||||||
|
_COMMANDS = "commands"
|
||||||
|
_BOLT11_MINT_QUOTE = "bolt11_mint_quote"
|
||||||
|
_BOLT11_MELT_QUOTE = "bolt11_melt_quote"
|
||||||
|
_PROOF_STATE = "proof_state"
|
||||||
|
_PROTECTED_ENDPOINTS = "protected_endpoints"
|
||||||
|
_BAT_MAX_MINT = "bat_max_mint"
|
||||||
|
_OPENID_DISCOVERY = "openid_discovery"
|
||||||
|
_CLIENT_ID = "client_id"
|
||||||
|
|
||||||
|
|
||||||
class LedgerFeatures(SupportsBackends):
|
class LedgerFeatures(SupportsBackends, SupportsPubkey):
|
||||||
|
@property
|
||||||
|
def mint_info(self) -> MintInfo:
|
||||||
|
contact_info = [
|
||||||
|
MintInfoContact(method=m, info=i)
|
||||||
|
for m, i in settings.mint_info_contact
|
||||||
|
if m and i
|
||||||
|
]
|
||||||
|
return MintInfo(
|
||||||
|
name=settings.mint_info_name,
|
||||||
|
pubkey=self.pubkey.serialize().hex() if self.pubkey else None,
|
||||||
|
version=f"{_VERSION_PREFIX}/{settings.version}",
|
||||||
|
description=settings.mint_info_description,
|
||||||
|
description_long=settings.mint_info_description_long,
|
||||||
|
contact=contact_info,
|
||||||
|
nuts=self.mint_features,
|
||||||
|
icon_url=settings.mint_info_icon_url,
|
||||||
|
motd=settings.mint_info_motd,
|
||||||
|
time=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
def mint_features(self) -> Dict[int, Union[List[Any], Dict[str, Any]]]:
|
def mint_features(self) -> Dict[int, Union[List[Any], Dict[str, Any]]]:
|
||||||
mint_features = self.create_mint_features()
|
mint_features = self.create_mint_features()
|
||||||
mint_features = self.add_supported_features(mint_features)
|
mint_features = self.add_supported_features(mint_features)
|
||||||
@@ -100,30 +141,62 @@ class LedgerFeatures(SupportsBackends):
|
|||||||
# specify which websocket features are supported
|
# specify which websocket features are supported
|
||||||
# these two are supported by default
|
# these two are supported by default
|
||||||
websocket_features: Dict[str, List[Dict[str, Union[str, List[str]]]]] = {
|
websocket_features: Dict[str, List[Dict[str, Union[str, List[str]]]]] = {
|
||||||
"supported": []
|
_SUPPORTED: []
|
||||||
}
|
}
|
||||||
# we check the backend to see if "bolt11_mint_quote" is supported as well
|
# we check the backend to see if "bolt11_mint_quote" is supported as well
|
||||||
for method, unit_dict in self.backends.items():
|
for method, unit_dict in self.backends.items():
|
||||||
if method == Method["bolt11"]:
|
if method == Method[_BOLT11]:
|
||||||
for unit in unit_dict.keys():
|
for unit in unit_dict.keys():
|
||||||
websocket_features["supported"].append(
|
websocket_features[_SUPPORTED].append(
|
||||||
{
|
{
|
||||||
"method": method.name,
|
_METHOD: method.name,
|
||||||
"unit": unit.name,
|
_UNIT: unit.name,
|
||||||
"commands": ["bolt11_melt_quote", "proof_state"],
|
_COMMANDS: [_BOLT11_MELT_QUOTE, _PROOF_STATE],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
if unit_dict[unit].supports_incoming_payment_stream:
|
if unit_dict[unit].supports_incoming_payment_stream:
|
||||||
supported_features: List[str] = list(
|
supported_features: List[str] = list(
|
||||||
websocket_features["supported"][-1]["commands"]
|
websocket_features[_SUPPORTED][-1][_COMMANDS]
|
||||||
)
|
)
|
||||||
websocket_features["supported"][-1]["commands"] = (
|
websocket_features[_SUPPORTED][-1][_COMMANDS] = (
|
||||||
supported_features + ["bolt11_mint_quote"]
|
supported_features + [_BOLT11_MINT_QUOTE]
|
||||||
)
|
)
|
||||||
|
|
||||||
if websocket_features:
|
if websocket_features:
|
||||||
mint_features[WEBSOCKETS_NUT] = websocket_features
|
mint_features[WEBSOCKETS_NUT] = websocket_features
|
||||||
|
|
||||||
|
# signal authentication features
|
||||||
|
if settings.mint_require_auth:
|
||||||
|
if not settings.mint_auth_oicd_discovery_url:
|
||||||
|
raise Exception(
|
||||||
|
"Missing OpenID Connect discovery URL: MINT_AUTH_OICD_DISCOVERY_URL"
|
||||||
|
)
|
||||||
|
clear_auth_features: Dict[str, Union[bool, str, List[str]]] = {
|
||||||
|
_OPENID_DISCOVERY: settings.mint_auth_oicd_discovery_url,
|
||||||
|
_CLIENT_ID: settings.mint_auth_oicd_client_id,
|
||||||
|
_PROTECTED_ENDPOINTS: [],
|
||||||
|
}
|
||||||
|
|
||||||
|
for endpoint in [
|
||||||
|
MintInfoProtectedEndpoint(method=e[0], path=e[1])
|
||||||
|
for e in settings.mint_require_clear_auth_paths
|
||||||
|
]:
|
||||||
|
clear_auth_features[_PROTECTED_ENDPOINTS].append(endpoint.dict()) # type: ignore
|
||||||
|
|
||||||
|
mint_features[CLEAR_AUTH_NUT] = clear_auth_features
|
||||||
|
|
||||||
|
blind_auth_features: Dict[str, Union[bool, int, str, List[str]]] = {
|
||||||
|
_BAT_MAX_MINT: settings.mint_auth_max_blind_tokens,
|
||||||
|
_PROTECTED_ENDPOINTS: [],
|
||||||
|
}
|
||||||
|
for endpoint in [
|
||||||
|
MintInfoProtectedEndpoint(method=e[0], path=e[1])
|
||||||
|
for e in settings.mint_require_blind_auth_paths
|
||||||
|
]:
|
||||||
|
blind_auth_features[_PROTECTED_ENDPOINTS].append(endpoint.dict()) # type: ignore
|
||||||
|
|
||||||
|
mint_features[BLIND_AUTH_NUT] = blind_auth_features
|
||||||
|
|
||||||
return mint_features
|
return mint_features
|
||||||
|
|
||||||
def add_cache_features(
|
def add_cache_features(
|
||||||
|
|||||||
@@ -75,16 +75,26 @@ class Ledger(LedgerVerification, LedgerSpendingConditions, LedgerTasks, LedgerFe
|
|||||||
db_read: DbReadHelper
|
db_read: DbReadHelper
|
||||||
invoice_listener_tasks: List[asyncio.Task] = []
|
invoice_listener_tasks: List[asyncio.Task] = []
|
||||||
disable_melt: bool = False
|
disable_melt: bool = False
|
||||||
|
pubkey: PublicKey
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
*,
|
||||||
db: Database,
|
db: Database,
|
||||||
seed: str,
|
seed: str,
|
||||||
backends: Mapping[Method, Mapping[Unit, LightningBackend]],
|
|
||||||
seed_decryption_key: Optional[str] = None,
|
|
||||||
derivation_path="",
|
derivation_path="",
|
||||||
|
amounts: Optional[List[int]] = None,
|
||||||
|
backends: Optional[Mapping[Method, Mapping[Unit, LightningBackend]]] = None,
|
||||||
|
seed_decryption_key: Optional[str] = None,
|
||||||
crud=LedgerCrudSqlite(),
|
crud=LedgerCrudSqlite(),
|
||||||
):
|
) -> None:
|
||||||
|
self.keysets: Dict[str, MintKeyset] = {}
|
||||||
|
self.backends: Mapping[Method, Mapping[Unit, LightningBackend]] = {}
|
||||||
|
self.events = LedgerEventManager()
|
||||||
|
self.db_read: DbReadHelper
|
||||||
|
self.locks: Dict[str, asyncio.Lock] = {} # holds multiprocessing locks
|
||||||
|
self.invoice_listener_tasks: List[asyncio.Task] = []
|
||||||
|
|
||||||
if not seed:
|
if not seed:
|
||||||
raise Exception("seed not set")
|
raise Exception("seed not set")
|
||||||
|
|
||||||
@@ -103,24 +113,33 @@ class Ledger(LedgerVerification, LedgerSpendingConditions, LedgerTasks, LedgerFe
|
|||||||
|
|
||||||
self.db = db
|
self.db = db
|
||||||
self.crud = crud
|
self.crud = crud
|
||||||
self.backends = backends
|
|
||||||
|
if backends:
|
||||||
|
self.backends = backends
|
||||||
|
|
||||||
|
if amounts:
|
||||||
|
self.amounts = amounts
|
||||||
|
else:
|
||||||
|
self.amounts = [2**n for n in range(settings.max_order)]
|
||||||
|
|
||||||
self.pubkey = derive_pubkey(self.seed)
|
self.pubkey = derive_pubkey(self.seed)
|
||||||
self.db_read = DbReadHelper(self.db, self.crud)
|
self.db_read = DbReadHelper(self.db, self.crud)
|
||||||
self.db_write = DbWriteHelper(self.db, self.crud, self.events, self.db_read)
|
self.db_write = DbWriteHelper(self.db, self.crud, self.events, self.db_read)
|
||||||
|
|
||||||
# ------- STARTUP -------
|
# ------- STARTUP -------
|
||||||
|
|
||||||
async def startup_ledger(self):
|
async def startup_ledger(self) -> None:
|
||||||
await self._startup_ledger()
|
await self._startup_keysets()
|
||||||
|
await self._check_backends()
|
||||||
await self._check_pending_proofs_and_melt_quotes()
|
await self._check_pending_proofs_and_melt_quotes()
|
||||||
self.invoice_listener_tasks = await self.dispatch_listeners()
|
self.invoice_listener_tasks = await self.dispatch_listeners()
|
||||||
|
|
||||||
async def _startup_ledger(self):
|
async def _startup_keysets(self) -> None:
|
||||||
await self.init_keysets()
|
await self.init_keysets()
|
||||||
|
|
||||||
for derivation_path in settings.mint_derivation_path_list:
|
for derivation_path in settings.mint_derivation_path_list:
|
||||||
await self.activate_keyset(derivation_path=derivation_path)
|
await self.activate_keyset(derivation_path=derivation_path)
|
||||||
|
|
||||||
|
async def _check_backends(self) -> None:
|
||||||
for method in self.backends:
|
for method in self.backends:
|
||||||
for unit in self.backends[method]:
|
for unit in self.backends[method]:
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -139,7 +158,7 @@ class Ledger(LedgerVerification, LedgerSpendingConditions, LedgerTasks, LedgerFe
|
|||||||
|
|
||||||
logger.info(f"Data dir: {settings.cashu_dir}")
|
logger.info(f"Data dir: {settings.cashu_dir}")
|
||||||
|
|
||||||
async def shutdown_ledger(self):
|
async def shutdown_ledger(self) -> None:
|
||||||
await self.db.engine.dispose()
|
await self.db.engine.dispose()
|
||||||
for task in self.invoice_listener_tasks:
|
for task in self.invoice_listener_tasks:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
@@ -169,57 +188,65 @@ class Ledger(LedgerVerification, LedgerSpendingConditions, LedgerTasks, LedgerFe
|
|||||||
version: Optional[str] = None,
|
version: Optional[str] = None,
|
||||||
autosave=True,
|
autosave=True,
|
||||||
) -> MintKeyset:
|
) -> MintKeyset:
|
||||||
"""Load the keyset for a derivation path if it already exists. If not generate new one and store in the db.
|
"""
|
||||||
|
Load an existing keyset for the specified derivation path or generate a new one if it doesn't exist.
|
||||||
|
Optionally store the newly created keyset in the database.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
derivation_path (_type_): Derivation path from which the keyset is generated.
|
derivation_path (str): Derivation path for keyset generation.
|
||||||
autosave (bool, optional): Store newly-generated keyset if not already in database. Defaults to True.
|
seed (Optional[str], optional): Seed value. Defaults to None.
|
||||||
|
version (Optional[str], optional): Version identifier. Defaults to None.
|
||||||
|
autosave (bool, optional): Whether to store the keyset if newly created. Defaults to True.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
MintKeyset: Keyset
|
MintKeyset: The activated keyset.
|
||||||
"""
|
"""
|
||||||
if not derivation_path:
|
if not derivation_path:
|
||||||
raise Exception("derivation path not set")
|
raise ValueError("Derivation path must be provided.")
|
||||||
|
|
||||||
seed = seed or self.seed
|
seed = seed or self.seed
|
||||||
tmp_keyset_local = MintKeyset(
|
version = version or settings.version
|
||||||
|
# Initialize a temporary keyset to derive the ID
|
||||||
|
temp_keyset = MintKeyset(
|
||||||
seed=seed,
|
seed=seed,
|
||||||
derivation_path=derivation_path,
|
derivation_path=derivation_path,
|
||||||
version=version or settings.version,
|
version=version,
|
||||||
|
amounts=self.amounts,
|
||||||
)
|
)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Activating keyset for derivation path {derivation_path} with id"
|
f"Activating keyset for derivation path '{derivation_path}' with ID '{temp_keyset.id}'."
|
||||||
f" {tmp_keyset_local.id}."
|
|
||||||
)
|
)
|
||||||
# load the keyset from db
|
|
||||||
logger.trace(f"crud: loading keyset for {derivation_path}")
|
# Attempt to retrieve existing keysets from the database
|
||||||
tmp_keysets_local: List[MintKeyset] = await self.crud.get_keyset(
|
existing_keysets: List[MintKeyset] = await self.crud.get_keyset(
|
||||||
id=tmp_keyset_local.id, db=self.db
|
id=temp_keyset.id, db=self.db
|
||||||
)
|
)
|
||||||
logger.trace(f"crud: loaded {len(tmp_keysets_local)} keysets")
|
logger.trace(
|
||||||
if tmp_keysets_local:
|
f"Retrieved {len(existing_keysets)} keyset(s) for derivation path '{derivation_path}'."
|
||||||
# we have a keyset with this derivation path in the database
|
)
|
||||||
keyset = tmp_keysets_local[0]
|
|
||||||
|
if existing_keysets:
|
||||||
|
keyset = existing_keysets[0]
|
||||||
else:
|
else:
|
||||||
# no keyset for this derivation path yet
|
# Create a new keyset if none exists
|
||||||
# we create a new keyset (keys will be generated at instantiation)
|
|
||||||
keyset = MintKeyset(
|
keyset = MintKeyset(
|
||||||
seed=seed or self.seed,
|
seed=seed,
|
||||||
derivation_path=derivation_path,
|
derivation_path=derivation_path,
|
||||||
version=version or settings.version,
|
amounts=self.amounts,
|
||||||
|
version=version,
|
||||||
input_fee_ppk=settings.mint_input_fee_ppk,
|
input_fee_ppk=settings.mint_input_fee_ppk,
|
||||||
)
|
)
|
||||||
logger.debug(f"Generated new keyset {keyset.id}.")
|
logger.debug(f"Generated new keyset with ID '{keyset.id}'.")
|
||||||
|
|
||||||
if autosave:
|
if autosave:
|
||||||
logger.debug(f"crud: storing new keyset {keyset.id}.")
|
logger.debug(f"Storing new keyset with ID '{keyset.id}'.")
|
||||||
await self.crud.store_keyset(keyset=keyset, db=self.db)
|
await self.crud.store_keyset(keyset=keyset, db=self.db)
|
||||||
logger.trace(f"crud: stored new keyset {keyset.id}.")
|
|
||||||
|
|
||||||
# activate this keyset
|
# Activate the keyset
|
||||||
keyset.active = True
|
keyset.active = True
|
||||||
# load the new keyset in self.keysets
|
|
||||||
self.keysets[keyset.id] = keyset
|
self.keysets[keyset.id] = keyset
|
||||||
|
logger.debug(f"Keyset with ID '{keyset.id}' is now active.")
|
||||||
|
|
||||||
logger.debug(f"Loaded keyset {keyset.id}")
|
|
||||||
return keyset
|
return keyset
|
||||||
|
|
||||||
async def init_keysets(self, autosave: bool = True) -> None:
|
async def init_keysets(self, autosave: bool = True) -> None:
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import WebSocket, status
|
from fastapi import WebSocket, status
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from limits import RateLimitItemPerMinute
|
from limits import RateLimitItemPerMinute
|
||||||
@@ -42,26 +44,26 @@ limiter = Limiter(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def assert_limit(identifier: str):
|
def assert_limit(identifier: str, limit: Optional[int] = None):
|
||||||
"""Custom rate limit handler that accepts a string identifier
|
"""Custom rate limit handler that accepts a string identifier
|
||||||
and raises an exception if the rate limit is exceeded. Uses the
|
and raises an exception if the rate limit is exceeded. Uses the
|
||||||
setting `mint_transaction_rate_limit_per_minute` for the rate limit.
|
setting `mint_transaction_rate_limit_per_minute` for the rate limit.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
identifier (str): The identifier to use for the rate limit. IP address for example.
|
identifier (str): The identifier to use for the rate limit. IP address for example.
|
||||||
|
limit (Optional[int], optional): The rate limit per minute to use. Defaults to None
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
Exception: If the rate limit is exceeded.
|
Exception: If the rate limit is exceeded.
|
||||||
"""
|
"""
|
||||||
global limiter
|
global limiter
|
||||||
|
limit_per_minute = limit or settings.mint_transaction_rate_limit_per_minute
|
||||||
success = limiter._limiter.hit(
|
success = limiter._limiter.hit(
|
||||||
RateLimitItemPerMinute(settings.mint_transaction_rate_limit_per_minute),
|
RateLimitItemPerMinute(limit_per_minute),
|
||||||
identifier,
|
identifier,
|
||||||
)
|
)
|
||||||
if not success:
|
if not success:
|
||||||
logger.warning(
|
logger.warning(f"Rate limit {limit_per_minute}/minute exceeded: {identifier}")
|
||||||
f"Rate limit {settings.mint_transaction_rate_limit_per_minute}/minute exceeded: {identifier}"
|
|
||||||
)
|
|
||||||
raise Exception("Rate limit exceeded")
|
raise Exception("Rate limit exceeded")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,10 @@ from fastapi.exception_handlers import (
|
|||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import (
|
||||||
|
BaseHTTPMiddleware,
|
||||||
|
RequestResponseEndpoint,
|
||||||
|
)
|
||||||
from starlette.middleware.cors import CORSMiddleware
|
from starlette.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
from ..core.settings import settings
|
from ..core.settings import settings
|
||||||
@@ -22,6 +25,8 @@ if settings.debug_profiling:
|
|||||||
from slowapi.errors import RateLimitExceeded
|
from slowapi.errors import RateLimitExceeded
|
||||||
from slowapi.middleware import SlowAPIMiddleware
|
from slowapi.middleware import SlowAPIMiddleware
|
||||||
|
|
||||||
|
from .startup import auth_ledger
|
||||||
|
|
||||||
|
|
||||||
def add_middlewares(app: FastAPI):
|
def add_middlewares(app: FastAPI):
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
@@ -42,6 +47,52 @@ def add_middlewares(app: FastAPI):
|
|||||||
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||||
app.add_middleware(SlowAPIMiddleware)
|
app.add_middleware(SlowAPIMiddleware)
|
||||||
|
|
||||||
|
if settings.mint_require_auth:
|
||||||
|
app.add_middleware(BlindAuthMiddleware)
|
||||||
|
app.add_middleware(ClearAuthMiddleware)
|
||||||
|
|
||||||
|
|
||||||
|
class ClearAuthMiddleware(BaseHTTPMiddleware):
|
||||||
|
async def dispatch(
|
||||||
|
self, request: Request, call_next: RequestResponseEndpoint
|
||||||
|
) -> Response:
|
||||||
|
if (
|
||||||
|
settings.mint_require_auth
|
||||||
|
and auth_ledger.mint_info.requires_clear_auth_path(
|
||||||
|
method=request.method, path=request.url.path
|
||||||
|
)
|
||||||
|
):
|
||||||
|
clear_auth_token = request.headers.get("clear-auth")
|
||||||
|
if not clear_auth_token:
|
||||||
|
raise Exception("Missing clear auth token.")
|
||||||
|
try:
|
||||||
|
user = await auth_ledger.verify_clear_auth(
|
||||||
|
clear_auth_token=clear_auth_token
|
||||||
|
)
|
||||||
|
request.state.user = user
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
|
||||||
|
class BlindAuthMiddleware(BaseHTTPMiddleware):
|
||||||
|
async def dispatch(
|
||||||
|
self, request: Request, call_next: RequestResponseEndpoint
|
||||||
|
) -> Response:
|
||||||
|
if (
|
||||||
|
settings.mint_require_auth
|
||||||
|
and auth_ledger.mint_info.requires_blind_auth_path(
|
||||||
|
method=request.method, path=request.url.path
|
||||||
|
)
|
||||||
|
):
|
||||||
|
blind_auth_token = request.headers.get("blind-auth")
|
||||||
|
if not blind_auth_token:
|
||||||
|
raise Exception("Missing blind auth token.")
|
||||||
|
async with auth_ledger.verify_blind_auth(blind_auth_token):
|
||||||
|
return await call_next(request)
|
||||||
|
else:
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
|
||||||
async def request_validation_exception_handler(
|
async def request_validation_exception_handler(
|
||||||
request: Request, exc: RequestValidationError
|
request: Request, exc: RequestValidationError
|
||||||
@@ -66,11 +117,11 @@ class CompressionMiddleware(BaseHTTPMiddleware):
|
|||||||
response = await call_next(request)
|
response = await call_next(request)
|
||||||
|
|
||||||
# Handle streaming responses differently
|
# Handle streaming responses differently
|
||||||
if response.__class__.__name__ == 'StreamingResponse':
|
if response.__class__.__name__ == "StreamingResponse":
|
||||||
return response
|
return response
|
||||||
|
|
||||||
response_body = b''
|
response_body = b""
|
||||||
async for chunk in response.body_iterator:
|
async for chunk in response.body_iterator: # type: ignore
|
||||||
response_body += chunk
|
response_body += chunk
|
||||||
|
|
||||||
accept_encoding = request.headers.get("Accept-Encoding", "")
|
accept_encoding = request.headers.get("Accept-Encoding", "")
|
||||||
@@ -97,5 +148,5 @@ class CompressionMiddleware(BaseHTTPMiddleware):
|
|||||||
content=content,
|
content=content,
|
||||||
status_code=response.status_code,
|
status_code=response.status_code,
|
||||||
headers=dict(response.headers),
|
headers=dict(response.headers),
|
||||||
media_type=response.media_type
|
media_type=response.media_type,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -858,3 +858,13 @@ async def m024_add_melt_quote_outputs(db: Database):
|
|||||||
ADD COLUMN outputs TEXT DEFAULT NULL
|
ADD COLUMN outputs TEXT DEFAULT NULL
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def m025_add_amounts_to_keysets(db: Database):
|
||||||
|
async with db.connect() as conn:
|
||||||
|
await conn.execute(
|
||||||
|
f"ALTER TABLE {db.table_with_schema('keysets')} ADD COLUMN amounts TEXT"
|
||||||
|
)
|
||||||
|
await conn.execute(
|
||||||
|
f"UPDATE {db.table_with_schema('keysets')} SET amounts = '[]'"
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from typing import Dict, Mapping, Protocol
|
from typing import Dict, Mapping, Protocol
|
||||||
|
|
||||||
from ..core.base import Method, MintKeyset, Unit
|
from ..core.base import Method, MintKeyset, Unit
|
||||||
|
from ..core.crypto.secp import PublicKey
|
||||||
from ..core.db import Database
|
from ..core.db import Database
|
||||||
from ..lightning.base import LightningBackend
|
from ..lightning.base import LightningBackend
|
||||||
from ..mint.crud import LedgerCrud
|
from ..mint.crud import LedgerCrud
|
||||||
@@ -18,6 +19,10 @@ class SupportsBackends(Protocol):
|
|||||||
backends: Mapping[Method, Mapping[Unit, LightningBackend]] = {}
|
backends: Mapping[Method, Mapping[Unit, LightningBackend]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
class SupportsPubkey(Protocol):
|
||||||
|
pubkey: PublicKey
|
||||||
|
|
||||||
|
|
||||||
class SupportsDb(Protocol):
|
class SupportsDb(Protocol):
|
||||||
db: Database
|
db: Database
|
||||||
db_read: DbReadHelper
|
db_read: DbReadHelper
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from ..core.models import (
|
|||||||
KeysetsResponseKeyset,
|
KeysetsResponseKeyset,
|
||||||
KeysResponse,
|
KeysResponse,
|
||||||
KeysResponseKeyset,
|
KeysResponseKeyset,
|
||||||
MintInfoContact,
|
|
||||||
PostCheckStateRequest,
|
PostCheckStateRequest,
|
||||||
PostCheckStateResponse,
|
PostCheckStateResponse,
|
||||||
PostMeltQuoteRequest,
|
PostMeltQuoteRequest,
|
||||||
@@ -44,23 +43,18 @@ redis = RedisCache()
|
|||||||
)
|
)
|
||||||
async def info() -> GetInfoResponse:
|
async def info() -> GetInfoResponse:
|
||||||
logger.trace("> GET /v1/info")
|
logger.trace("> GET /v1/info")
|
||||||
mint_features = ledger.mint_features()
|
mint_info = ledger.mint_info
|
||||||
contact_info = [
|
|
||||||
MintInfoContact(method=m, info=i)
|
|
||||||
for m, i in settings.mint_info_contact
|
|
||||||
if m and i
|
|
||||||
]
|
|
||||||
return GetInfoResponse(
|
return GetInfoResponse(
|
||||||
name=settings.mint_info_name,
|
name=mint_info.name,
|
||||||
pubkey=ledger.pubkey.serialize().hex() if ledger.pubkey else None,
|
pubkey=mint_info.pubkey,
|
||||||
version=f"Nutshell/{settings.version}",
|
version=mint_info.version,
|
||||||
description=settings.mint_info_description,
|
description=mint_info.description,
|
||||||
description_long=settings.mint_info_description_long,
|
description_long=mint_info.description_long,
|
||||||
contact=contact_info,
|
contact=mint_info.contact,
|
||||||
nuts=mint_features,
|
nuts=mint_info.nuts,
|
||||||
icon_url=settings.mint_info_icon_url,
|
icon_url=mint_info.icon_url,
|
||||||
urls=settings.mint_info_urls,
|
urls=settings.mint_info_urls,
|
||||||
motd=settings.mint_info_motd,
|
motd=mint_info.motd,
|
||||||
time=int(time.time()),
|
time=int(time.time()),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -12,7 +12,9 @@ from ..core.db import Database
|
|||||||
from ..core.migrations import migrate_databases
|
from ..core.migrations import migrate_databases
|
||||||
from ..core.settings import settings
|
from ..core.settings import settings
|
||||||
from ..lightning.base import LightningBackend
|
from ..lightning.base import LightningBackend
|
||||||
from ..mint import migrations
|
from ..mint import migrations as mint_migrations
|
||||||
|
from ..mint.auth import migrations as auth_migrations
|
||||||
|
from ..mint.auth.server import AuthLedger
|
||||||
from ..mint.crud import LedgerCrudSqlite
|
from ..mint.crud import LedgerCrudSqlite
|
||||||
from ..mint.ledger import Ledger
|
from ..mint.ledger import Ledger
|
||||||
|
|
||||||
@@ -76,6 +78,15 @@ ledger = Ledger(
|
|||||||
crud=LedgerCrudSqlite(),
|
crud=LedgerCrudSqlite(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# start auth ledger
|
||||||
|
auth_ledger = AuthLedger(
|
||||||
|
db=Database("auth", settings.auth_database),
|
||||||
|
seed="auth seed here",
|
||||||
|
amounts=[1],
|
||||||
|
derivation_path="m/0'/999'/0'",
|
||||||
|
crud=LedgerCrudSqlite(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def rotate_keys(n_seconds=60):
|
async def rotate_keys(n_seconds=60):
|
||||||
"""Rotate keyset epoch every n_seconds.
|
"""Rotate keyset epoch every n_seconds.
|
||||||
@@ -93,8 +104,17 @@ async def rotate_keys(n_seconds=60):
|
|||||||
await asyncio.sleep(n_seconds)
|
await asyncio.sleep(n_seconds)
|
||||||
|
|
||||||
|
|
||||||
async def start_mint_init():
|
async def start_auth():
|
||||||
await migrate_databases(ledger.db, migrations)
|
await migrate_databases(auth_ledger.db, auth_migrations)
|
||||||
|
logger.info("Starting auth ledger.")
|
||||||
|
await auth_ledger.init_keysets()
|
||||||
|
await auth_ledger.init_auth()
|
||||||
|
logger.info("Auth ledger started.")
|
||||||
|
|
||||||
|
|
||||||
|
async def start_mint():
|
||||||
|
await migrate_databases(ledger.db, mint_migrations)
|
||||||
|
logger.info("Starting mint ledger.")
|
||||||
await ledger.startup_ledger()
|
await ledger.startup_ledger()
|
||||||
logger.info("Mint started.")
|
logger.info("Mint started.")
|
||||||
# asyncio.create_task(rotate_keys())
|
# asyncio.create_task(rotate_keys())
|
||||||
|
|||||||
@@ -1,22 +1,14 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from typing import List, Mapping
|
from typing import List
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from ..core.base import Method, MintQuoteState, Unit
|
from ..core.base import MintQuoteState
|
||||||
from ..core.db import Database
|
|
||||||
from ..lightning.base import LightningBackend
|
from ..lightning.base import LightningBackend
|
||||||
from ..mint.crud import LedgerCrud
|
|
||||||
from .events.events import LedgerEventManager
|
|
||||||
from .protocols import SupportsBackends, SupportsDb, SupportsEvents
|
from .protocols import SupportsBackends, SupportsDb, SupportsEvents
|
||||||
|
|
||||||
|
|
||||||
class LedgerTasks(SupportsDb, SupportsBackends, SupportsEvents):
|
class LedgerTasks(SupportsDb, SupportsBackends, SupportsEvents):
|
||||||
backends: Mapping[Method, Mapping[Unit, LightningBackend]] = {}
|
|
||||||
db: Database
|
|
||||||
crud: LedgerCrud
|
|
||||||
events: LedgerEventManager
|
|
||||||
|
|
||||||
async def dispatch_listeners(self) -> List[asyncio.Task]:
|
async def dispatch_listeners(self) -> List[asyncio.Task]:
|
||||||
tasks = []
|
tasks = []
|
||||||
for method, unitbackends in self.backends.items():
|
for method, unitbackends in self.backends.items():
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Dict, List, Literal, Optional, Tuple, Union
|
from typing import List, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
@@ -6,14 +6,13 @@ from ..core.base import (
|
|||||||
BlindedMessage,
|
BlindedMessage,
|
||||||
BlindedSignature,
|
BlindedSignature,
|
||||||
Method,
|
Method,
|
||||||
MintKeyset,
|
|
||||||
MintQuote,
|
MintQuote,
|
||||||
Proof,
|
Proof,
|
||||||
Unit,
|
Unit,
|
||||||
)
|
)
|
||||||
from ..core.crypto import b_dhke
|
from ..core.crypto import b_dhke
|
||||||
from ..core.crypto.secp import PublicKey
|
from ..core.crypto.secp import PublicKey
|
||||||
from ..core.db import Connection, Database
|
from ..core.db import Connection
|
||||||
from ..core.errors import (
|
from ..core.errors import (
|
||||||
InvalidProofsError,
|
InvalidProofsError,
|
||||||
NoSecretInProofsError,
|
NoSecretInProofsError,
|
||||||
@@ -25,11 +24,7 @@ from ..core.errors import (
|
|||||||
)
|
)
|
||||||
from ..core.nuts import nut20
|
from ..core.nuts import nut20
|
||||||
from ..core.settings import settings
|
from ..core.settings import settings
|
||||||
from ..lightning.base import LightningBackend
|
|
||||||
from ..mint.crud import LedgerCrud
|
|
||||||
from .conditions import LedgerSpendingConditions
|
from .conditions import LedgerSpendingConditions
|
||||||
from .db.read import DbReadHelper
|
|
||||||
from .db.write import DbWriteHelper
|
|
||||||
from .protocols import SupportsBackends, SupportsDb, SupportsKeysets
|
from .protocols import SupportsBackends, SupportsDb, SupportsKeysets
|
||||||
|
|
||||||
|
|
||||||
@@ -38,14 +33,6 @@ class LedgerVerification(
|
|||||||
):
|
):
|
||||||
"""Verification functions for the ledger."""
|
"""Verification functions for the ledger."""
|
||||||
|
|
||||||
keyset: MintKeyset
|
|
||||||
keysets: Dict[str, MintKeyset]
|
|
||||||
crud: LedgerCrud
|
|
||||||
db: Database
|
|
||||||
db_read: DbReadHelper
|
|
||||||
db_write: DbWriteHelper
|
|
||||||
lightning: Dict[Unit, LightningBackend]
|
|
||||||
|
|
||||||
async def verify_inputs_and_outputs(
|
async def verify_inputs_and_outputs(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
@@ -55,6 +42,8 @@ class LedgerVerification(
|
|||||||
):
|
):
|
||||||
"""Checks all proofs and outputs for validity.
|
"""Checks all proofs and outputs for validity.
|
||||||
|
|
||||||
|
Warning: Does NOT check if the proofs were already spent. Use `db_write._verify_proofs_spendable` for that.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
proofs (List[Proof]): List of proofs to check.
|
proofs (List[Proof]): List of proofs to check.
|
||||||
outputs (Optional[List[BlindedMessage]], optional): List of outputs to check.
|
outputs (Optional[List[BlindedMessage]], optional): List of outputs to check.
|
||||||
|
|||||||
@@ -22,14 +22,11 @@ class NostrClient:
|
|||||||
relays = [
|
relays = [
|
||||||
"wss://nostr-pub.wellorder.net",
|
"wss://nostr-pub.wellorder.net",
|
||||||
"wss://relay.damus.io",
|
"wss://relay.damus.io",
|
||||||
"wss://nostr.zebedee.cloud",
|
|
||||||
"wss://relay.snort.social",
|
|
||||||
"wss://nostr.fmt.wiz.biz",
|
"wss://nostr.fmt.wiz.biz",
|
||||||
"wss://nos.lol",
|
"wss://nos.lol",
|
||||||
"wss://nostr.oxtr.dev",
|
"wss://nostr.oxtr.dev",
|
||||||
"wss://relay.current.fyi",
|
"wss://relay.current.fyi",
|
||||||
"wss://relay.snort.social",
|
]
|
||||||
] # ["wss://nostr.oxtr.dev"] # ["wss://relay.nostr.info"] "wss://nostr-pub.wellorder.net" "ws://91.237.88.218:2700", "wss://nostrrr.bublina.eu.org", ""wss://nostr-relay.freeberty.net"", , "wss://nostr.oxtr.dev", "wss://relay.nostr.info", "wss://nostr-pub.wellorder.net" , "wss://relayer.fiatjaf.com", "wss://nodestr.fmt.wiz.biz/", "wss://no.str.cr"
|
|
||||||
relay_manager = RelayManager()
|
relay_manager = RelayManager()
|
||||||
private_key: PrivateKey
|
private_key: PrivateKey
|
||||||
public_key: PublicKey
|
public_key: PublicKey
|
||||||
|
|||||||
0
cashu/wallet/auth/__init__.py
Normal file
0
cashu/wallet/auth/__init__.py
Normal file
240
cashu/wallet/auth/auth.py
Normal file
240
cashu/wallet/auth/auth.py
Normal file
@@ -0,0 +1,240 @@
|
|||||||
|
import hashlib
|
||||||
|
import os
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from cashu.core.helpers import sum_proofs
|
||||||
|
from cashu.core.mint_info import MintInfo
|
||||||
|
|
||||||
|
from ...core.base import Proof
|
||||||
|
from ...core.crypto.secp import PrivateKey
|
||||||
|
from ...core.db import Database
|
||||||
|
from ..crud import get_mint_by_url, update_mint
|
||||||
|
from ..wallet import Wallet
|
||||||
|
from .openid_connect.openid_client import AuthorizationFlow, OpenIDClient
|
||||||
|
|
||||||
|
|
||||||
|
class WalletAuth(Wallet):
|
||||||
|
oidc_discovery_url: str
|
||||||
|
oidc_client: OpenIDClient
|
||||||
|
wallet_db: Database
|
||||||
|
auth_flow: AuthorizationFlow
|
||||||
|
username: str | None
|
||||||
|
password: str | None
|
||||||
|
# API prefix for all requests
|
||||||
|
api_prefix = "/v1/auth/blind"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, url: str, db: str, name: str = "auth", unit: str = "auth", **kwargs
|
||||||
|
):
|
||||||
|
"""Authentication wallet.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url (str): Mint url.
|
||||||
|
db (str): Auth wallet db location.
|
||||||
|
wallet_db (str): Wallet db location.
|
||||||
|
name (str, optional): Wallet name. Defaults to "auth".
|
||||||
|
unit (str, optional): Wallet unit. Defaults to "auth".
|
||||||
|
kwargs: Additional keyword arguments.
|
||||||
|
client_id (str, optional): OpenID client id. Defaults to "cashu-client".
|
||||||
|
client_secret (str, optional): OpenID client secret. Defaults to "".
|
||||||
|
username (str, optional): OpenID username. When set, the username and
|
||||||
|
password flow will be used to authenticate. If a username is already
|
||||||
|
stored in the database, it will be used. Will be stored in the
|
||||||
|
database if not already stored.
|
||||||
|
password (str, optional): OpenID password. Used if username is set. Will
|
||||||
|
be read from the database if already stored. Will be stored in the
|
||||||
|
database if not already stored.
|
||||||
|
"""
|
||||||
|
super().__init__(url, db, name, unit)
|
||||||
|
self.client_id = kwargs.get("client_id", "cashu-client")
|
||||||
|
logger.trace(f"client_id: {self.client_id}")
|
||||||
|
self.client_secret = kwargs.get("client_secret", "")
|
||||||
|
self.username = kwargs.get("username")
|
||||||
|
self.password = kwargs.get("password")
|
||||||
|
|
||||||
|
if self.username:
|
||||||
|
if self.password is None:
|
||||||
|
raise Exception("Password must be set if username is set.")
|
||||||
|
self.auth_flow = AuthorizationFlow.PASSWORD
|
||||||
|
else:
|
||||||
|
self.auth_flow = AuthorizationFlow.AUTHORIZATION_CODE
|
||||||
|
# self.auth_flow = AuthorizationFlow.DEVICE_CODE
|
||||||
|
|
||||||
|
self.access_token = kwargs.get("access_token")
|
||||||
|
self.refresh_token = kwargs.get("refresh_token")
|
||||||
|
|
||||||
|
# overload with_db
|
||||||
|
@classmethod
|
||||||
|
async def with_db(cls, *args, **kwargs) -> "WalletAuth":
|
||||||
|
"""Create a new wallet with a database.
|
||||||
|
Keyword arguments:
|
||||||
|
url (str): Mint url.
|
||||||
|
db (str): Wallet db location.
|
||||||
|
name (str, optional): Wallet name. Defaults to "auth".
|
||||||
|
username (str, optional): OpenID username. When set, the username and
|
||||||
|
password flow will be used to authenticate. If a username is already
|
||||||
|
stored in the database, it will be used. Will be stored in the
|
||||||
|
database if not already stored.
|
||||||
|
password (str, optional): OpenID password. Used if username is set. Will
|
||||||
|
be read from the database if already stored. Will be stored in the
|
||||||
|
database if not already stored.
|
||||||
|
client_id (str, optional): OpenID client id. Defaults to "cashu-client".
|
||||||
|
client_secret (str, optional): OpenID client secret. Defaults to "".
|
||||||
|
access_token (str, optional): OpenID access token. Defaults to None.
|
||||||
|
refresh_token (str, optional): OpenID refresh token. Defaults to None.
|
||||||
|
Returns:
|
||||||
|
WalletAuth: WalletAuth instance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
url: str = kwargs.get("url", "")
|
||||||
|
db = kwargs.get("db", "")
|
||||||
|
kwargs["name"] = kwargs.get("name", "auth")
|
||||||
|
name = kwargs["name"]
|
||||||
|
username = kwargs.get("username")
|
||||||
|
password = kwargs.get("password")
|
||||||
|
wallet_db = Database(name, db)
|
||||||
|
|
||||||
|
# run migrations etc
|
||||||
|
kwargs.update(dict(skip_db_read=True))
|
||||||
|
await super().with_db(*args, **kwargs)
|
||||||
|
|
||||||
|
# the wallet might not have been created yet
|
||||||
|
# if it was though, we load the username, password,
|
||||||
|
# access token and refresh token from the database
|
||||||
|
try:
|
||||||
|
mint_db = await get_mint_by_url(wallet_db, url)
|
||||||
|
if mint_db:
|
||||||
|
kwargs.update(
|
||||||
|
{
|
||||||
|
"username": username or mint_db.username,
|
||||||
|
"password": password or mint_db.password,
|
||||||
|
"access_token": mint_db.access_token,
|
||||||
|
"refresh_token": mint_db.refresh_token,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return cls(*args, **kwargs)
|
||||||
|
|
||||||
|
async def init_auth_wallet(
|
||||||
|
self,
|
||||||
|
mint_info: Optional[MintInfo] = None,
|
||||||
|
mint_auth_proofs=True,
|
||||||
|
force_auth=False,
|
||||||
|
) -> bool:
|
||||||
|
"""Initialize authentication wallet.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mint_info (MintInfo, optional): Mint information. If not provided, we load the
|
||||||
|
info from the database or the mint directly. Defaults to None.
|
||||||
|
mint_auth_proofs (bool, optional): Whether to mint auth proofs if necessary.
|
||||||
|
Defaults to True.
|
||||||
|
force_auth (bool, optional): Whether to force authentication. Defaults to False.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: False if the mint does not require clear auth. True otherwise.
|
||||||
|
"""
|
||||||
|
if mint_info:
|
||||||
|
self.mint_info = mint_info
|
||||||
|
await self.load_mint_info()
|
||||||
|
if not self.mint_info.requires_clear_auth():
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Use the blind auth api_prefix for all following requests
|
||||||
|
await self.load_mint_keysets()
|
||||||
|
await self.activate_keyset()
|
||||||
|
await self.load_proofs()
|
||||||
|
|
||||||
|
self.oidc_discovery_url = self.mint_info.oidc_discovery_url()
|
||||||
|
self.client_id = self.mint_info.oidc_client_id()
|
||||||
|
|
||||||
|
# Initialize OpenIDClient
|
||||||
|
self.oidc_client = OpenIDClient(
|
||||||
|
discovery_url=self.oidc_discovery_url,
|
||||||
|
client_id=self.client_id,
|
||||||
|
client_secret=self.client_secret,
|
||||||
|
auth_flow=self.auth_flow,
|
||||||
|
username=self.username,
|
||||||
|
password=self.password,
|
||||||
|
access_token=self.access_token,
|
||||||
|
refresh_token=self.refresh_token,
|
||||||
|
)
|
||||||
|
# Authenticate using OpenIDClient
|
||||||
|
await self.oidc_client.initialize()
|
||||||
|
await self.oidc_client.authenticate(force_authenticate=force_auth)
|
||||||
|
|
||||||
|
await self.store_username_password()
|
||||||
|
await self.store_clear_auth_token()
|
||||||
|
|
||||||
|
if mint_auth_proofs:
|
||||||
|
await self.mint_blind_auth_min_balance()
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def mint_blind_auth_min_balance(self) -> None:
|
||||||
|
"""Mint auth tokens if balance is too low."""
|
||||||
|
MIN_BALANCE = self.mint_info.bat_max_mint
|
||||||
|
|
||||||
|
if self.available_balance < MIN_BALANCE:
|
||||||
|
logger.debug(
|
||||||
|
f"Balance too low. Minting {self.unit.str(MIN_BALANCE)} auth tokens."
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
await self.mint_blind_auth()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error minting auth proofs: {str(e)}")
|
||||||
|
|
||||||
|
async def store_username_password(self) -> None:
|
||||||
|
"""Store the username and password in the database."""
|
||||||
|
if self.username and self.password:
|
||||||
|
mint_db = await get_mint_by_url(self.db, self.url)
|
||||||
|
if not mint_db:
|
||||||
|
raise Exception("Mint not found.")
|
||||||
|
if mint_db.username != self.username or mint_db.password != self.password:
|
||||||
|
mint_db.username = self.username
|
||||||
|
mint_db.password = self.password
|
||||||
|
await update_mint(self.db, mint_db)
|
||||||
|
|
||||||
|
async def store_clear_auth_token(self) -> None:
|
||||||
|
"""Store the access and refresh tokens in the database."""
|
||||||
|
access_token = self.oidc_client.access_token
|
||||||
|
refresh_token = self.oidc_client.refresh_token
|
||||||
|
if not access_token or not refresh_token:
|
||||||
|
raise Exception("Access or refresh token not available.")
|
||||||
|
# Store the tokens in the database
|
||||||
|
mint_db = await get_mint_by_url(self.db, self.url)
|
||||||
|
if not mint_db:
|
||||||
|
raise Exception("Mint not found.")
|
||||||
|
if (
|
||||||
|
mint_db.access_token != access_token
|
||||||
|
or mint_db.refresh_token != refresh_token
|
||||||
|
):
|
||||||
|
mint_db.access_token = access_token
|
||||||
|
mint_db.refresh_token = refresh_token
|
||||||
|
await update_mint(self.db, mint_db)
|
||||||
|
|
||||||
|
async def mint_blind_auth(self) -> List[Proof]:
|
||||||
|
# Ensure access token is valid
|
||||||
|
if self.oidc_client.is_token_expired():
|
||||||
|
await self.oidc_client.refresh_access_token()
|
||||||
|
await self.store_clear_auth_token()
|
||||||
|
clear_auth_token = self.oidc_client.access_token
|
||||||
|
if not clear_auth_token:
|
||||||
|
raise Exception("No clear auth token available.")
|
||||||
|
|
||||||
|
amounts = self.mint_info.bat_max_mint * [1] # 1 AUTH tokens
|
||||||
|
secrets = [hashlib.sha256(os.urandom(32)).hexdigest() for _ in amounts]
|
||||||
|
rs = [PrivateKey(privkey=os.urandom(32), raw=True) for _ in amounts]
|
||||||
|
derivation_paths = ["" for _ in amounts]
|
||||||
|
outputs, rs = self._construct_outputs(amounts, secrets, rs)
|
||||||
|
promises = await self.blind_mint_blind_auth(clear_auth_token, outputs)
|
||||||
|
new_proofs = await self._construct_proofs(
|
||||||
|
promises, secrets, rs, derivation_paths
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"Minted {self.unit.str(sum_proofs(new_proofs))} blind auth proofs."
|
||||||
|
)
|
||||||
|
return new_proofs
|
||||||
487
cashu/wallet/auth/openid_connect/openid_client.py
Normal file
487
cashu/wallet/auth/openid_connect/openid_client.py
Normal file
@@ -0,0 +1,487 @@
|
|||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import secrets
|
||||||
|
import webbrowser
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import jwt
|
||||||
|
import uvicorn
|
||||||
|
from fastapi import FastAPI, HTTPException, Request
|
||||||
|
from fastapi.responses import HTMLResponse
|
||||||
|
from fastapi.templating import Jinja2Templates
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
|
class AuthorizationFlow(Enum):
|
||||||
|
AUTHORIZATION_CODE = "authorization_code"
|
||||||
|
PASSWORD = "password"
|
||||||
|
DEVICE_CODE = "urn:ietf:params:oauth:grant-type:device_code"
|
||||||
|
|
||||||
|
|
||||||
|
class OpenIDClient:
|
||||||
|
"""OpenID Connect client for authentication."""
|
||||||
|
|
||||||
|
oidc_config: Dict[str, Any]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
discovery_url: str,
|
||||||
|
client_id: str,
|
||||||
|
client_secret: str = "",
|
||||||
|
auth_flow: Optional[AuthorizationFlow] = None,
|
||||||
|
username: Optional[str] = None,
|
||||||
|
password: Optional[str] = None,
|
||||||
|
access_token: Optional[str] = None,
|
||||||
|
refresh_token: Optional[str] = None,
|
||||||
|
token_expiration_time: Optional[datetime] = None,
|
||||||
|
device_code: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
self.discovery_url: str = discovery_url
|
||||||
|
self.client_id: str = client_id
|
||||||
|
self.client_secret: str = client_secret
|
||||||
|
self.auth_flow: Optional[AuthorizationFlow] = auth_flow
|
||||||
|
self.username: Optional[str] = username
|
||||||
|
self.password: Optional[str] = password
|
||||||
|
self.access_token: Optional[str] = access_token
|
||||||
|
self.refresh_token: Optional[str] = refresh_token
|
||||||
|
self.token_expiration_time: Optional[datetime] = token_expiration_time
|
||||||
|
self.device_code: Optional[str] = device_code
|
||||||
|
|
||||||
|
self.redirect_uri: str = "http://localhost:33388/callback"
|
||||||
|
self.expected_state: str = secrets.token_urlsafe(16)
|
||||||
|
self.token_response: Dict[str, Any] = {}
|
||||||
|
self.token_event: asyncio.Event = asyncio.Event()
|
||||||
|
self.token_endpoint: str = ""
|
||||||
|
self.authorization_endpoint: str = ""
|
||||||
|
self.introspection_endpoint: Optional[str] = None
|
||||||
|
self.revocation_endpoint: Optional[str] = None
|
||||||
|
self.device_authorization_endpoint: Optional[str] = None
|
||||||
|
self.templates: Jinja2Templates = Jinja2Templates(
|
||||||
|
directory="cashu/wallet/auth/openid_connect/templates"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.app: FastAPI = FastAPI()
|
||||||
|
self.app.state.client = self # Store self in app state
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
"""Initialize the client asynchronously."""
|
||||||
|
await self.fetch_oidc_configuration()
|
||||||
|
await self.determine_auth_flow()
|
||||||
|
|
||||||
|
async def determine_auth_flow(self) -> AuthorizationFlow:
|
||||||
|
"""Determine the authentication flow to use from the oidc configuration.
|
||||||
|
Supported flows are chosen in the following order:
|
||||||
|
- device_code
|
||||||
|
- authorization_code
|
||||||
|
- password
|
||||||
|
"""
|
||||||
|
if not hasattr(self, "oidc_config"):
|
||||||
|
raise ValueError(
|
||||||
|
"OIDC configuration not loaded. Call fetch_oidc_configuration first."
|
||||||
|
)
|
||||||
|
|
||||||
|
supported_flows = self.oidc_config.get("grant_types_supported", [])
|
||||||
|
|
||||||
|
# if self.auth_flow is already set, check if it is supported
|
||||||
|
if self.auth_flow:
|
||||||
|
if self.auth_flow.value not in supported_flows:
|
||||||
|
raise ValueError(
|
||||||
|
f"Authentication flow {self.auth_flow.value} not supported by the OIDC configuration."
|
||||||
|
)
|
||||||
|
return self.auth_flow
|
||||||
|
|
||||||
|
if AuthorizationFlow.DEVICE_CODE.value in supported_flows:
|
||||||
|
self.auth_flow = AuthorizationFlow.DEVICE_CODE
|
||||||
|
elif AuthorizationFlow.AUTHORIZATION_CODE.value in supported_flows:
|
||||||
|
self.auth_flow = AuthorizationFlow.AUTHORIZATION_CODE
|
||||||
|
elif AuthorizationFlow.PASSWORD.value in supported_flows:
|
||||||
|
self.auth_flow = AuthorizationFlow.PASSWORD
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"No supported authentication flows found in the OIDC configuration."
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"Determined authentication flow: {self.auth_flow.value}")
|
||||||
|
return self.auth_flow
|
||||||
|
|
||||||
|
async def fetch_oidc_configuration(self) -> None:
|
||||||
|
"""Fetch OIDC configuration from the discovery URL."""
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.get(self.discovery_url)
|
||||||
|
response.raise_for_status()
|
||||||
|
self.oidc_config = response.json()
|
||||||
|
self.authorization_endpoint = self.oidc_config.get( # type: ignore
|
||||||
|
"authorization_endpoint"
|
||||||
|
)
|
||||||
|
self.token_endpoint = self.oidc_config.get("token_endpoint") # type: ignore
|
||||||
|
self.introspection_endpoint = self.oidc_config.get(
|
||||||
|
"introspection_endpoint"
|
||||||
|
)
|
||||||
|
self.revocation_endpoint = self.oidc_config.get("revocation_endpoint")
|
||||||
|
self.device_authorization_endpoint = self.oidc_config.get(
|
||||||
|
"device_authorization_endpoint"
|
||||||
|
)
|
||||||
|
except httpx.HTTPError as e:
|
||||||
|
logger.error(f"Failed to get OpenID configuration: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def handle_callback(self, request: Request) -> HTMLResponse:
|
||||||
|
"""Endpoint to handle the redirect from the OpenID provider."""
|
||||||
|
params = request.query_params
|
||||||
|
if "error" in params:
|
||||||
|
error_str = params["error"]
|
||||||
|
if "error_description" in params:
|
||||||
|
error_str += f": {params['error_description']}"
|
||||||
|
return self.templates.TemplateResponse(
|
||||||
|
"error.html",
|
||||||
|
{"request": request, "error": error_str},
|
||||||
|
)
|
||||||
|
elif "code" in params and "state" in params:
|
||||||
|
code: str = params["code"]
|
||||||
|
state: str = params["state"]
|
||||||
|
if state != self.expected_state:
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid state parameter")
|
||||||
|
token_data: Dict[str, Any] = await self.exchange_code_for_token(code)
|
||||||
|
self.update_token_data(token_data)
|
||||||
|
response = self.render_success_page(request, token_data)
|
||||||
|
self.token_event.set() # Signal that the token has been received
|
||||||
|
return response
|
||||||
|
else:
|
||||||
|
return self.templates.TemplateResponse(
|
||||||
|
"error.html",
|
||||||
|
{"request": request, "error": "Missing 'code' or 'state' parameter"},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def exchange_code_for_token(self, code: str) -> Dict[str, Any]:
|
||||||
|
"""Exchange the authorization code for tokens."""
|
||||||
|
data: Dict[str, str] = {
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"code": code,
|
||||||
|
"redirect_uri": self.redirect_uri,
|
||||||
|
}
|
||||||
|
headers: Dict[str, str] = {}
|
||||||
|
if self.client_secret:
|
||||||
|
# Use HTTP Basic Auth if client_secret is provided
|
||||||
|
basic_auth: str = f"{self.client_id}:{self.client_secret}"
|
||||||
|
basic_auth_bytes: bytes = basic_auth.encode("ascii")
|
||||||
|
basic_auth_b64: str = base64.b64encode(basic_auth_bytes).decode("ascii")
|
||||||
|
headers["Authorization"] = f"Basic {basic_auth_b64}"
|
||||||
|
else:
|
||||||
|
# Include client_id in the POST body for public clients
|
||||||
|
data["client_id"] = self.client_id
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
try:
|
||||||
|
response = await client.post(
|
||||||
|
self.token_endpoint, data=data, headers=headers
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
except httpx.HTTPError as e:
|
||||||
|
logger.error(f"HTTP error occurred during token exchange: {e}")
|
||||||
|
self.token_event.set()
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def run_server(self) -> None:
|
||||||
|
"""Run the FastAPI server."""
|
||||||
|
config = uvicorn.Config(
|
||||||
|
self.app, host="127.0.0.1", port=33388, log_level="error"
|
||||||
|
)
|
||||||
|
self.server = uvicorn.Server(config)
|
||||||
|
await self.server.serve()
|
||||||
|
|
||||||
|
async def shutdown_server(self) -> None:
|
||||||
|
"""Shut down the uvicorn server."""
|
||||||
|
self.server.should_exit = True
|
||||||
|
|
||||||
|
async def authenticate(self, force_authenticate: bool = False) -> None:
|
||||||
|
"""Start the authentication process."""
|
||||||
|
need_authenticate = force_authenticate
|
||||||
|
if self.access_token and self.refresh_token:
|
||||||
|
# We have a token and a refresh token, check if token is expired
|
||||||
|
if self.is_token_expired():
|
||||||
|
try:
|
||||||
|
await self.refresh_access_token()
|
||||||
|
except httpx.HTTPError:
|
||||||
|
logger.debug("Failed to refresh token.")
|
||||||
|
need_authenticate = True
|
||||||
|
else:
|
||||||
|
logger.debug("Using existing access token.")
|
||||||
|
else:
|
||||||
|
need_authenticate = True
|
||||||
|
|
||||||
|
if need_authenticate:
|
||||||
|
if self.auth_flow == AuthorizationFlow.AUTHORIZATION_CODE:
|
||||||
|
await self.authenticate_with_authorization_code()
|
||||||
|
elif self.auth_flow == AuthorizationFlow.PASSWORD:
|
||||||
|
await self.authenticate_with_password()
|
||||||
|
elif self.auth_flow == AuthorizationFlow.DEVICE_CODE:
|
||||||
|
await self.authenticate_with_device_code()
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown authentication flow: {self.auth_flow}")
|
||||||
|
|
||||||
|
def is_token_expired(self) -> bool:
|
||||||
|
"""Check if the access token is expired."""
|
||||||
|
if not self.access_token:
|
||||||
|
raise ValueError("Access token is not set.")
|
||||||
|
decoded = jwt.decode(self.access_token, options={"verify_signature": False})
|
||||||
|
exp = decoded.get("exp")
|
||||||
|
if not exp:
|
||||||
|
return False
|
||||||
|
return datetime.now() >= datetime.fromtimestamp(exp) - timedelta(minutes=1)
|
||||||
|
|
||||||
|
async def refresh_access_token(self) -> None:
|
||||||
|
"""Refresh the access token using the refresh token."""
|
||||||
|
data = {
|
||||||
|
"grant_type": "refresh_token",
|
||||||
|
"refresh_token": self.refresh_token,
|
||||||
|
"client_id": self.client_id,
|
||||||
|
}
|
||||||
|
if self.client_secret:
|
||||||
|
data["client_secret"] = self.client_secret
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
try:
|
||||||
|
response = await client.post(self.token_endpoint, data=data)
|
||||||
|
response.raise_for_status()
|
||||||
|
token_data = response.json()
|
||||||
|
self.update_token_data(token_data)
|
||||||
|
logger.info("Token refreshed successfully.")
|
||||||
|
except httpx.HTTPError as e:
|
||||||
|
logger.debug(f"Failed to refresh token: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def authenticate_with_authorization_code(self) -> None:
|
||||||
|
"""Authenticate using the authorization code flow."""
|
||||||
|
|
||||||
|
# Set up the route handlers
|
||||||
|
@self.app.get("/callback", response_class=HTMLResponse)
|
||||||
|
async def handle_callback(request: Request) -> Any:
|
||||||
|
print("CALLBACK")
|
||||||
|
return await self.handle_callback(request)
|
||||||
|
|
||||||
|
# Build the authorization URL
|
||||||
|
params = {
|
||||||
|
"response_type": "code",
|
||||||
|
"client_id": self.client_id,
|
||||||
|
"redirect_uri": self.redirect_uri,
|
||||||
|
"scope": "openid",
|
||||||
|
"state": self.expected_state,
|
||||||
|
}
|
||||||
|
auth_url = f"{self.authorization_endpoint}?{urlencode(params)}"
|
||||||
|
|
||||||
|
# Start the web server as an asyncio task
|
||||||
|
server_task = asyncio.create_task(self.run_server())
|
||||||
|
|
||||||
|
# Open the browser or print the URL for the user
|
||||||
|
logger.info("Please open the following URL in your browser to authenticate:")
|
||||||
|
logger.info(auth_url)
|
||||||
|
webbrowser.open(auth_url)
|
||||||
|
|
||||||
|
# Wait for the token response
|
||||||
|
logger.info("Waiting for authentication...")
|
||||||
|
await self.token_event.wait()
|
||||||
|
|
||||||
|
# Use the retrieved tokens
|
||||||
|
if self.token_response:
|
||||||
|
logger.info("Authentication successful!")
|
||||||
|
# logger.info("Token response:")
|
||||||
|
# logger.info(self.token_response)
|
||||||
|
else:
|
||||||
|
logger.error("Authentication failed.")
|
||||||
|
|
||||||
|
# Signal the server to shut down
|
||||||
|
await self.shutdown_server()
|
||||||
|
|
||||||
|
# Wait for the server task to finish
|
||||||
|
await server_task
|
||||||
|
|
||||||
|
def update_token_data(self, token_data: Dict[str, Any]) -> None:
|
||||||
|
self.token_response.update(token_data)
|
||||||
|
self.access_token = token_data.get("access_token")
|
||||||
|
self.refresh_token = token_data.get("refresh_token")
|
||||||
|
if not self.access_token or not self.refresh_token:
|
||||||
|
raise ValueError(
|
||||||
|
"Access token or refresh token not found in token response."
|
||||||
|
)
|
||||||
|
expires_in = token_data.get("expires_in")
|
||||||
|
if expires_in:
|
||||||
|
self.token_expiration_time = datetime.utcnow() + timedelta(
|
||||||
|
seconds=int(expires_in)
|
||||||
|
)
|
||||||
|
refresh_expires_in = token_data.get("refresh_expires_in")
|
||||||
|
if refresh_expires_in:
|
||||||
|
logger.debug(f"Refresh token expires in {refresh_expires_in} seconds.")
|
||||||
|
self.refresh_token_expiration_time = datetime.utcnow() + timedelta(
|
||||||
|
seconds=int(refresh_expires_in)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def authenticate_with_password(self) -> None:
|
||||||
|
"""Authenticate using the resource owner password credentials flow."""
|
||||||
|
if not self.username or not self.password:
|
||||||
|
raise ValueError(
|
||||||
|
'Username and password must be provided. To set a password use: "cashu auth -p"'
|
||||||
|
)
|
||||||
|
data = {
|
||||||
|
"grant_type": "password",
|
||||||
|
"client_id": self.client_id,
|
||||||
|
"username": self.username,
|
||||||
|
"password": self.password,
|
||||||
|
"scope": "openid",
|
||||||
|
}
|
||||||
|
if self.client_secret:
|
||||||
|
data["client_secret"] = self.client_secret
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
try:
|
||||||
|
response = await client.post(self.token_endpoint, data=data)
|
||||||
|
response.raise_for_status()
|
||||||
|
token_data = response.json()
|
||||||
|
self.update_token_data(token_data)
|
||||||
|
logger.info("Authentication successful!")
|
||||||
|
# logger.info("Token response:")
|
||||||
|
# logger.info(self.token_response)
|
||||||
|
except httpx.HTTPError as e:
|
||||||
|
logger.error(f"Failed to obtain token: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def render_success_page(
|
||||||
|
self, request: Request, token_data: Dict[str, Any]
|
||||||
|
) -> HTMLResponse:
|
||||||
|
"""Render an HTML page with a green check mark and user information."""
|
||||||
|
context = {"request": request, "token_data": token_data}
|
||||||
|
return self.templates.TemplateResponse("success.html", context)
|
||||||
|
|
||||||
|
async def authenticate_with_device_code(self) -> None:
|
||||||
|
"""Authenticate using the device code flow."""
|
||||||
|
if not self.device_authorization_endpoint:
|
||||||
|
raise ValueError("Device authorization endpoint not available.")
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"client_id": self.client_id,
|
||||||
|
"scope": "openid",
|
||||||
|
}
|
||||||
|
if self.client_secret:
|
||||||
|
data["client_secret"] = self.client_secret
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
try:
|
||||||
|
response = await client.post(
|
||||||
|
self.device_authorization_endpoint, data=data
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
device_data = response.json()
|
||||||
|
except httpx.HTTPError as e:
|
||||||
|
logger.error(f"Failed to obtain device code: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Extract device code data
|
||||||
|
device_code = device_data.get("device_code")
|
||||||
|
user_code = device_data.get("user_code")
|
||||||
|
verification_uri = device_data.get("verification_uri")
|
||||||
|
verification_uri_complete = device_data.get("verification_uri_complete")
|
||||||
|
expires_in = device_data.get("expires_in")
|
||||||
|
interval = device_data.get("interval", 5) # Default interval is 5 seconds
|
||||||
|
|
||||||
|
if not device_code or not verification_uri:
|
||||||
|
raise ValueError("Invalid response from device authorization endpoint.")
|
||||||
|
|
||||||
|
# Display instructions to the user and open the browser
|
||||||
|
if verification_uri_complete:
|
||||||
|
logger.info("Opening browser to complete authorization...")
|
||||||
|
logger.info(verification_uri_complete)
|
||||||
|
webbrowser.open(verification_uri_complete)
|
||||||
|
else:
|
||||||
|
logger.info("Please visit the following URL to authorize:")
|
||||||
|
logger.info(verification_uri)
|
||||||
|
logger.info(f"Enter the user code: {user_code}")
|
||||||
|
# Construct the URL for the user to enter the code
|
||||||
|
full_verification_uri = f"{verification_uri}?user_code={user_code}"
|
||||||
|
webbrowser.open(full_verification_uri)
|
||||||
|
|
||||||
|
# Start polling the token endpoint
|
||||||
|
start_time = datetime.now()
|
||||||
|
expires_at = start_time + timedelta(seconds=expires_in)
|
||||||
|
token_data = None
|
||||||
|
while datetime.now() < expires_at:
|
||||||
|
await asyncio.sleep(interval)
|
||||||
|
data = {
|
||||||
|
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
|
||||||
|
"device_code": device_code,
|
||||||
|
"client_id": self.client_id,
|
||||||
|
}
|
||||||
|
if self.client_secret:
|
||||||
|
data["client_secret"] = self.client_secret
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
try:
|
||||||
|
response = await client.post(self.token_endpoint, data=data)
|
||||||
|
if response.status_code == 200:
|
||||||
|
# Successful response
|
||||||
|
token_data = response.json()
|
||||||
|
self.update_token_data(token_data)
|
||||||
|
logger.info("Authentication successful!")
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
error_data = response.json()
|
||||||
|
error = error_data.get("error")
|
||||||
|
if error == "authorization_pending":
|
||||||
|
# Continue polling
|
||||||
|
pass
|
||||||
|
elif error == "slow_down":
|
||||||
|
# Increase interval by 5 seconds
|
||||||
|
interval += 5
|
||||||
|
elif error == "access_denied":
|
||||||
|
logger.error("Access denied by user.")
|
||||||
|
break
|
||||||
|
elif error == "expired_token":
|
||||||
|
logger.error("Device code has expired.")
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
logger.error(f"Error during polling: {error}")
|
||||||
|
break
|
||||||
|
except httpx.HTTPError as e:
|
||||||
|
logger.error(f"HTTP error during token polling: {e}")
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
logger.error("Device code has expired before authorization.")
|
||||||
|
raise Exception("Device code expired")
|
||||||
|
|
||||||
|
|
||||||
|
async def main() -> None:
|
||||||
|
parser = argparse.ArgumentParser(description="OpenID Connect Authentication Client")
|
||||||
|
parser.add_argument("discovery_url", help="OpenID Connect Discovery URL")
|
||||||
|
parser.add_argument("client_id", help="Client ID")
|
||||||
|
parser.add_argument("--client_secret", help="Client Secret", default="")
|
||||||
|
parser.add_argument(
|
||||||
|
"--auth_flow",
|
||||||
|
choices=["authorization_code", "password", "device_code"],
|
||||||
|
default="authorization_code",
|
||||||
|
help="Authentication flow to use",
|
||||||
|
)
|
||||||
|
parser.add_argument("--username", help="Username for password flow")
|
||||||
|
parser.add_argument("--password", help="Password for password flow")
|
||||||
|
parser.add_argument("--access_token", help="Stored access token")
|
||||||
|
parser.add_argument("--refresh_token", help="Stored refresh token")
|
||||||
|
parser.add_argument("--device_code", help="Device code for device flow")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
client = OpenIDClient(
|
||||||
|
discovery_url=args.discovery_url,
|
||||||
|
client_id=args.client_id,
|
||||||
|
client_secret=args.client_secret,
|
||||||
|
auth_flow=AuthorizationFlow(args.auth_flow),
|
||||||
|
username=args.username,
|
||||||
|
password=args.password,
|
||||||
|
access_token=args.access_token,
|
||||||
|
refresh_token=args.refresh_token,
|
||||||
|
device_code=args.device_code,
|
||||||
|
)
|
||||||
|
await client.initialize()
|
||||||
|
await client.authenticate()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
26
cashu/wallet/auth/openid_connect/templates/error.html
Normal file
26
cashu/wallet/auth/openid_connect/templates/error.html
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<title>Authentication Error</title>
|
||||||
|
<style>
|
||||||
|
.error {
|
||||||
|
text-align: center;
|
||||||
|
margin-top: 50px;
|
||||||
|
}
|
||||||
|
.error h1 {
|
||||||
|
color: red;
|
||||||
|
}
|
||||||
|
.crossmark {
|
||||||
|
font-size: 100px;
|
||||||
|
color: red;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="error">
|
||||||
|
<div class="crossmark">✖</div>
|
||||||
|
<h1>Authentication Error</h1>
|
||||||
|
<p>{{ error }}</p>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
33
cashu/wallet/auth/openid_connect/templates/success.html
Normal file
33
cashu/wallet/auth/openid_connect/templates/success.html
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<title>Authentication Successful</title>
|
||||||
|
<style>
|
||||||
|
.success {
|
||||||
|
text-align: center;
|
||||||
|
margin-top: 50px;
|
||||||
|
}
|
||||||
|
.success h1 {
|
||||||
|
color: green;
|
||||||
|
}
|
||||||
|
.checkmark {
|
||||||
|
font-size: 100px;
|
||||||
|
color: green;
|
||||||
|
}
|
||||||
|
.token-data {
|
||||||
|
margin-top: 20px;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="success">
|
||||||
|
<div class="checkmark">✔</div>
|
||||||
|
<h1>Authentication Successful</h1>
|
||||||
|
<h3>You can close this window now.</h3>
|
||||||
|
<!-- <div class="token-data">
|
||||||
|
<h3>User Information:</h3>
|
||||||
|
<pre>{{ token_data | tojson(indent=2) }}</pre>
|
||||||
|
</div> -->
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import getpass
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
@@ -41,6 +42,7 @@ from ...wallet.crud import (
|
|||||||
)
|
)
|
||||||
from ...wallet.wallet import Wallet as Wallet
|
from ...wallet.wallet import Wallet as Wallet
|
||||||
from ..api.api_server import start_api_server
|
from ..api.api_server import start_api_server
|
||||||
|
from ..auth.auth import WalletAuth
|
||||||
from ..cli.cli_helpers import (
|
from ..cli.cli_helpers import (
|
||||||
get_mint_wallet,
|
get_mint_wallet,
|
||||||
get_unit_wallet,
|
get_unit_wallet,
|
||||||
@@ -84,6 +86,49 @@ def coro(f):
|
|||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def init_auth_wallet(func):
|
||||||
|
"""Decorator to pass auth_db and auth_keyset_id to the Wallet object."""
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
async def wrapper(*args, **kwargs):
|
||||||
|
ctx = args[0] # Assuming the first argument is 'ctx'
|
||||||
|
wallet: Wallet = ctx.obj["WALLET"]
|
||||||
|
db_location = wallet.db.db_location
|
||||||
|
|
||||||
|
auth_wallet = await WalletAuth.with_db(
|
||||||
|
url=ctx.obj["HOST"],
|
||||||
|
db=db_location,
|
||||||
|
)
|
||||||
|
|
||||||
|
requires_auth = await auth_wallet.init_auth_wallet(wallet.mint_info)
|
||||||
|
|
||||||
|
if not requires_auth:
|
||||||
|
logger.debug("Mint does not require clear auth.")
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
|
# Pass auth_db and auth_keyset_id to the wallet object
|
||||||
|
wallet.auth_db = auth_wallet.db
|
||||||
|
wallet.auth_keyset_id = auth_wallet.keyset_id
|
||||||
|
# pass the mint_info so the wallet doesn't need to re-fetch it
|
||||||
|
wallet.mint_info = auth_wallet.mint_info
|
||||||
|
|
||||||
|
# Pass the auth_wallet to context
|
||||||
|
args[0].obj["AUTH_WALLET"] = auth_wallet
|
||||||
|
|
||||||
|
# Proceed to the original function
|
||||||
|
ret = await func(*args, **kwargs)
|
||||||
|
|
||||||
|
if settings.debug:
|
||||||
|
await auth_wallet.load_proofs(reload=True)
|
||||||
|
logger.debug(
|
||||||
|
f"Auth balance: {auth_wallet.unit.str(auth_wallet.available_balance)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
@click.group(cls=NaturalOrderGroup)
|
@click.group(cls=NaturalOrderGroup)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--host",
|
"--host",
|
||||||
@@ -176,6 +221,10 @@ async def cli(ctx: Context, host: str, walletname: str, unit: str, tests: bool):
|
|||||||
ctx.obj["HOST"], db_path, name=walletname, unit=unit
|
ctx.obj["HOST"], db_path, name=walletname, unit=unit
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# if we have never seen this mint before, we load its information
|
||||||
|
if not wallet.mint_info:
|
||||||
|
await wallet.load_mint()
|
||||||
|
|
||||||
assert wallet, "Wallet not found."
|
assert wallet, "Wallet not found."
|
||||||
ctx.obj["WALLET"] = wallet
|
ctx.obj["WALLET"] = wallet
|
||||||
|
|
||||||
@@ -205,6 +254,7 @@ async def cli(ctx: Context, host: str, walletname: str, unit: str, tests: bool):
|
|||||||
)
|
)
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
@coro
|
@coro
|
||||||
|
@init_auth_wallet
|
||||||
async def pay(
|
async def pay(
|
||||||
ctx: Context, invoice: str, amount: Optional[int] = None, yes: bool = False
|
ctx: Context, invoice: str, amount: Optional[int] = None, yes: bool = False
|
||||||
):
|
):
|
||||||
@@ -291,6 +341,7 @@ async def pay(
|
|||||||
)
|
)
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
@coro
|
@coro
|
||||||
|
@init_auth_wallet
|
||||||
async def invoice(
|
async def invoice(
|
||||||
ctx: Context,
|
ctx: Context,
|
||||||
amount: float,
|
amount: float,
|
||||||
@@ -451,6 +502,7 @@ async def invoice(
|
|||||||
@cli.command("swap", help="Swap funds between mints.")
|
@cli.command("swap", help="Swap funds between mints.")
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
@coro
|
@coro
|
||||||
|
@init_auth_wallet
|
||||||
async def swap(ctx: Context):
|
async def swap(ctx: Context):
|
||||||
print("Select the mint to swap from:")
|
print("Select the mint to swap from:")
|
||||||
outgoing_wallet: Wallet = await get_mint_wallet(ctx, force_select=True)
|
outgoing_wallet: Wallet = await get_mint_wallet(ctx, force_select=True)
|
||||||
@@ -621,8 +673,9 @@ async def balance(ctx: Context, verbose):
|
|||||||
)
|
)
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
@coro
|
@coro
|
||||||
|
@init_auth_wallet
|
||||||
async def send_command(
|
async def send_command(
|
||||||
ctx,
|
ctx: Context,
|
||||||
amount: int,
|
amount: int,
|
||||||
memo: str,
|
memo: str,
|
||||||
nostr: str,
|
nostr: str,
|
||||||
@@ -668,6 +721,7 @@ async def send_command(
|
|||||||
)
|
)
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
@coro
|
@coro
|
||||||
|
@init_auth_wallet
|
||||||
async def receive_cli(
|
async def receive_cli(
|
||||||
ctx: Context,
|
ctx: Context,
|
||||||
token: str,
|
token: str,
|
||||||
@@ -685,6 +739,8 @@ async def receive_cli(
|
|||||||
mint_url,
|
mint_url,
|
||||||
os.path.join(settings.cashu_dir, wallet.name),
|
os.path.join(settings.cashu_dir, wallet.name),
|
||||||
unit=token_obj.unit,
|
unit=token_obj.unit,
|
||||||
|
auth_db=wallet.auth_db.db_location if wallet.auth_db else None,
|
||||||
|
auth_keyset_id=wallet.auth_keyset_id,
|
||||||
)
|
)
|
||||||
await verify_mint(mint_wallet, mint_url)
|
await verify_mint(mint_wallet, mint_url)
|
||||||
receive_wallet = await receive(mint_wallet, token_obj)
|
receive_wallet = await receive(mint_wallet, token_obj)
|
||||||
@@ -830,7 +886,7 @@ async def pending(ctx: Context, legacy, number: int, offset: int):
|
|||||||
@cli.command("lock", help="Generate receiving lock.")
|
@cli.command("lock", help="Generate receiving lock.")
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
@coro
|
@coro
|
||||||
async def lock(ctx):
|
async def lock(ctx: Context):
|
||||||
wallet: Wallet = ctx.obj["WALLET"]
|
wallet: Wallet = ctx.obj["WALLET"]
|
||||||
|
|
||||||
pubkey = await wallet.create_p2pk_pubkey()
|
pubkey = await wallet.create_p2pk_pubkey()
|
||||||
@@ -851,7 +907,7 @@ async def lock(ctx):
|
|||||||
@cli.command("locks", help="Show unused receiving locks.")
|
@cli.command("locks", help="Show unused receiving locks.")
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
@coro
|
@coro
|
||||||
async def locks(ctx):
|
async def locks(ctx: Context):
|
||||||
wallet: Wallet = ctx.obj["WALLET"]
|
wallet: Wallet = ctx.obj["WALLET"]
|
||||||
# P2PK lock
|
# P2PK lock
|
||||||
pubkey = await wallet.create_p2pk_pubkey()
|
pubkey = await wallet.create_p2pk_pubkey()
|
||||||
@@ -899,7 +955,7 @@ async def locks(ctx):
|
|||||||
)
|
)
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
@coro
|
@coro
|
||||||
async def invoices(ctx, paid: bool, unpaid: bool, pending: bool, mint: bool):
|
async def invoices(ctx: Context, paid: bool, unpaid: bool, pending: bool, mint: bool):
|
||||||
wallet: Wallet = ctx.obj["WALLET"]
|
wallet: Wallet = ctx.obj["WALLET"]
|
||||||
|
|
||||||
if paid and unpaid:
|
if paid and unpaid:
|
||||||
@@ -1001,7 +1057,7 @@ async def invoices(ctx, paid: bool, unpaid: bool, pending: bool, mint: bool):
|
|||||||
@cli.command("wallets", help="List of all available wallets.")
|
@cli.command("wallets", help="List of all available wallets.")
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
@coro
|
@coro
|
||||||
async def wallets(ctx):
|
async def wallets(ctx: Context):
|
||||||
# list all directories
|
# list all directories
|
||||||
wallets = [
|
wallets = [
|
||||||
d for d in listdir(settings.cashu_dir) if isdir(join(settings.cashu_dir, d))
|
d for d in listdir(settings.cashu_dir) if isdir(join(settings.cashu_dir, d))
|
||||||
@@ -1013,7 +1069,7 @@ async def wallets(ctx):
|
|||||||
for w in wallets:
|
for w in wallets:
|
||||||
wallet = Wallet(ctx.obj["HOST"], os.path.join(settings.cashu_dir, w))
|
wallet = Wallet(ctx.obj["HOST"], os.path.join(settings.cashu_dir, w))
|
||||||
try:
|
try:
|
||||||
await wallet.load_proofs()
|
await wallet.load_proofs(reload=True, all_keysets=True)
|
||||||
if wallet.proofs and len(wallet.proofs):
|
if wallet.proofs and len(wallet.proofs):
|
||||||
active_wallet = False
|
active_wallet = False
|
||||||
if w == ctx.obj["WALLET_NAME"]:
|
if w == ctx.obj["WALLET_NAME"]:
|
||||||
@@ -1031,9 +1087,10 @@ async def wallets(ctx):
|
|||||||
@cli.command("info", help="Information about Cashu wallet.")
|
@cli.command("info", help="Information about Cashu wallet.")
|
||||||
@click.option("--mint", default=False, is_flag=True, help="Fetch mint information.")
|
@click.option("--mint", default=False, is_flag=True, help="Fetch mint information.")
|
||||||
@click.option("--mnemonic", default=False, is_flag=True, help="Show your mnemonic.")
|
@click.option("--mnemonic", default=False, is_flag=True, help="Show your mnemonic.")
|
||||||
|
@click.option("--reload", default=False, is_flag=True, help="Reload mint info.")
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
@coro
|
@coro
|
||||||
async def info(ctx: Context, mint: bool, mnemonic: bool):
|
async def info(ctx: Context, mint: bool, mnemonic: bool, reload: bool):
|
||||||
wallet: Wallet = ctx.obj["WALLET"]
|
wallet: Wallet = ctx.obj["WALLET"]
|
||||||
await wallet.load_keysets_from_db(unit=None)
|
await wallet.load_keysets_from_db(unit=None)
|
||||||
|
|
||||||
@@ -1057,7 +1114,11 @@ async def info(ctx: Context, mint: bool, mnemonic: bool):
|
|||||||
if mint:
|
if mint:
|
||||||
wallet.url = mint_url
|
wallet.url = mint_url
|
||||||
try:
|
try:
|
||||||
mint_info: dict = (await wallet.load_mint_info()).dict()
|
mint_info_obj = await wallet.load_mint_info(reload)
|
||||||
|
if not mint_info_obj:
|
||||||
|
print(" - Mint information not available.")
|
||||||
|
continue
|
||||||
|
mint_info = mint_info_obj.dict()
|
||||||
if mint_info:
|
if mint_info:
|
||||||
print(f" - Mint name: {mint_info['name']}")
|
print(f" - Mint name: {mint_info['name']}")
|
||||||
if mint_info.get("description"):
|
if mint_info.get("description"):
|
||||||
@@ -1126,6 +1187,7 @@ async def info(ctx: Context, mint: bool, mnemonic: bool):
|
|||||||
)
|
)
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
@coro
|
@coro
|
||||||
|
@init_auth_wallet
|
||||||
async def restore(ctx: Context, to: int, batch: int):
|
async def restore(ctx: Context, to: int, batch: int):
|
||||||
wallet: Wallet = ctx.obj["WALLET"]
|
wallet: Wallet = ctx.obj["WALLET"]
|
||||||
# check if there is already a mnemonic in the database
|
# check if there is already a mnemonic in the database
|
||||||
@@ -1160,6 +1222,7 @@ async def restore(ctx: Context, to: int, batch: int):
|
|||||||
# @click.option("--all", default=False, is_flag=True, help="Execute on all available mints.")
|
# @click.option("--all", default=False, is_flag=True, help="Execute on all available mints.")
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
@coro
|
@coro
|
||||||
|
@init_auth_wallet
|
||||||
async def selfpay(ctx: Context, all: bool = False):
|
async def selfpay(ctx: Context, all: bool = False):
|
||||||
wallet = await get_mint_wallet(ctx, force_select=True)
|
wallet = await get_mint_wallet(ctx, force_select=True)
|
||||||
await wallet.load_mint()
|
await wallet.load_mint()
|
||||||
@@ -1183,3 +1246,46 @@ async def selfpay(ctx: Context, all: bool = False):
|
|||||||
print(token)
|
print(token)
|
||||||
token_obj = TokenV4.deserialize(token)
|
token_obj = TokenV4.deserialize(token)
|
||||||
await receive(wallet, token_obj)
|
await receive(wallet, token_obj)
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command("auth", help="Authenticate with mint.")
|
||||||
|
@click.option("--mint", "-m", default=False, is_flag=True, help="Mint new auth tokens.")
|
||||||
|
@click.option(
|
||||||
|
"--force", "-f", default=False, is_flag=True, help="Force authentication."
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--password",
|
||||||
|
"-p",
|
||||||
|
default=False,
|
||||||
|
is_flag=True,
|
||||||
|
help="Use username and password for authentication.",
|
||||||
|
)
|
||||||
|
@click.pass_context
|
||||||
|
@coro
|
||||||
|
async def auth(ctx: Context, mint: bool, force: bool, password: bool):
|
||||||
|
# auth_wallet: WalletAuth = ctx.obj["AUTH_WALLET"]
|
||||||
|
wallet: Wallet = ctx.obj["WALLET"]
|
||||||
|
username = None
|
||||||
|
password_str = None
|
||||||
|
if password:
|
||||||
|
username = input("Enter username: ")
|
||||||
|
password_str = getpass.getpass("Enter password: ")
|
||||||
|
auth_wallet = await WalletAuth.with_db(
|
||||||
|
url=ctx.obj["HOST"],
|
||||||
|
db=wallet.db.db_location,
|
||||||
|
username=username,
|
||||||
|
password=password_str,
|
||||||
|
)
|
||||||
|
|
||||||
|
requires_auth = await auth_wallet.init_auth_wallet(
|
||||||
|
wallet.mint_info, mint_auth_proofs=False, force_auth=force
|
||||||
|
)
|
||||||
|
if not requires_auth:
|
||||||
|
print("Mint does not require authentication.")
|
||||||
|
return
|
||||||
|
|
||||||
|
if mint:
|
||||||
|
new_proofs = await auth_wallet.mint_blind_auth()
|
||||||
|
print(f"Minted {auth_wallet.unit.str(sum_proofs(new_proofs))} auth tokens.")
|
||||||
|
|
||||||
|
print(f"Auth balance: {auth_wallet.unit.str(auth_wallet.available_balance)}")
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from ..core.base import (
|
|||||||
MintQuoteState,
|
MintQuoteState,
|
||||||
Proof,
|
Proof,
|
||||||
WalletKeyset,
|
WalletKeyset,
|
||||||
|
WalletMint,
|
||||||
)
|
)
|
||||||
from ..core.db import Connection, Database
|
from ..core.db import Connection, Database
|
||||||
|
|
||||||
@@ -577,3 +578,59 @@ async def store_seed_and_mnemonic(
|
|||||||
"mnemonic": mnemonic,
|
"mnemonic": mnemonic,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def store_mint(
|
||||||
|
db: Database,
|
||||||
|
mint: WalletMint,
|
||||||
|
conn: Optional[Connection] = None,
|
||||||
|
) -> None:
|
||||||
|
await (conn or db).execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO mints
|
||||||
|
(url, info, updated)
|
||||||
|
VALUES (:url, :info, :updated)
|
||||||
|
""",
|
||||||
|
{
|
||||||
|
"url": mint.url,
|
||||||
|
"info": mint.info,
|
||||||
|
"updated": int(time.time()),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def update_mint(
|
||||||
|
db: Database,
|
||||||
|
mint: WalletMint,
|
||||||
|
conn: Optional[Connection] = None,
|
||||||
|
) -> None:
|
||||||
|
await (conn or db).execute(
|
||||||
|
"""
|
||||||
|
UPDATE mints
|
||||||
|
SET info = :info, updated = :updated, access_token = :access_token, refresh_token = :refresh_token, username = :username, password = :password
|
||||||
|
WHERE url = :url
|
||||||
|
""",
|
||||||
|
{
|
||||||
|
"url": mint.url,
|
||||||
|
"info": mint.info,
|
||||||
|
"updated": int(time.time()),
|
||||||
|
"access_token": mint.access_token,
|
||||||
|
"refresh_token": mint.refresh_token,
|
||||||
|
"username": mint.username,
|
||||||
|
"password": mint.password,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_mint_by_url(
|
||||||
|
db: Database,
|
||||||
|
url: str,
|
||||||
|
conn: Optional[Connection] = None,
|
||||||
|
) -> Optional[WalletMint]:
|
||||||
|
row = await (conn or db).fetchone(
|
||||||
|
"""
|
||||||
|
SELECT * from mints WHERE url = :url
|
||||||
|
""",
|
||||||
|
{"url": url},
|
||||||
|
)
|
||||||
|
return WalletMint.parse_obj(dict(row)) if row else None
|
||||||
|
|||||||
16
cashu/wallet/errors.py
Normal file
16
cashu/wallet/errors.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class WalletError(Exception):
|
||||||
|
msg: str
|
||||||
|
|
||||||
|
def __init__(self, msg):
|
||||||
|
super().__init__(msg)
|
||||||
|
self.msg = msg
|
||||||
|
|
||||||
|
|
||||||
|
class BalanceTooLowError(WalletError):
|
||||||
|
msg = "Balance too low"
|
||||||
|
|
||||||
|
def __init__(self, msg: Optional[str] = None):
|
||||||
|
super().__init__(msg or self.msg)
|
||||||
@@ -52,6 +52,8 @@ async def redeem_TokenV3(wallet: Wallet, token: TokenV3) -> Wallet:
|
|||||||
t.mint,
|
t.mint,
|
||||||
os.path.join(settings.cashu_dir, wallet.name),
|
os.path.join(settings.cashu_dir, wallet.name),
|
||||||
unit=token.unit or wallet.unit.name,
|
unit=token.unit or wallet.unit.name,
|
||||||
|
auth_db=wallet.auth_db.db_location if wallet.auth_db else None,
|
||||||
|
auth_keyset_id=wallet.auth_keyset_id,
|
||||||
)
|
)
|
||||||
keyset_ids = mint_wallet._get_proofs_keyset_ids(t.proofs)
|
keyset_ids = mint_wallet._get_proofs_keyset_ids(t.proofs)
|
||||||
logger.trace(f"Keysets in tokens: {' '.join(set(keyset_ids))}")
|
logger.trace(f"Keysets in tokens: {' '.join(set(keyset_ids))}")
|
||||||
|
|||||||
@@ -214,7 +214,6 @@ async def m010_add_ids_to_proofs_and_out_to_invoices(db: Database):
|
|||||||
Columns that store mint and melt id for proofs and invoices.
|
Columns that store mint and melt id for proofs and invoices.
|
||||||
"""
|
"""
|
||||||
async with db.connect() as conn:
|
async with db.connect() as conn:
|
||||||
print("Running wallet migrations")
|
|
||||||
await conn.execute("ALTER TABLE proofs ADD COLUMN mint_id TEXT")
|
await conn.execute("ALTER TABLE proofs ADD COLUMN mint_id TEXT")
|
||||||
await conn.execute("ALTER TABLE proofs ADD COLUMN melt_id TEXT")
|
await conn.execute("ALTER TABLE proofs ADD COLUMN melt_id TEXT")
|
||||||
|
|
||||||
@@ -287,7 +286,8 @@ async def m013_add_mint_and_melt_quote_tables(db: Database):
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
async def m013_add_key_to_mint_quote_table(db: Database):
|
|
||||||
|
async def m014_add_key_to_mint_quote_table(db: Database):
|
||||||
async with db.connect() as conn:
|
async with db.connect() as conn:
|
||||||
await conn.execute(
|
await conn.execute(
|
||||||
"""
|
"""
|
||||||
@@ -295,3 +295,21 @@ async def m013_add_key_to_mint_quote_table(db: Database):
|
|||||||
ADD COLUMN privkey TEXT DEFAULT NULL;
|
ADD COLUMN privkey TEXT DEFAULT NULL;
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def m015_add_mints_table(db: Database):
|
||||||
|
async with db.connect() as conn:
|
||||||
|
await conn.execute(
|
||||||
|
f"""
|
||||||
|
CREATE TABLE IF NOT EXISTS mints (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
url TEXT NOT NULL,
|
||||||
|
info TEXT NOT NULL,
|
||||||
|
updated TIMESTAMP DEFAULT {db.timestamp_now},
|
||||||
|
access_token TEXT,
|
||||||
|
refresh_token TEXT,
|
||||||
|
username TEXT,
|
||||||
|
password TEXT
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
from typing import Dict, List, Protocol
|
from typing import Dict, List, Optional, Protocol
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from ..core.base import Proof, Unit, WalletKeyset
|
from ..core.base import Proof, Unit, WalletKeyset
|
||||||
from ..core.crypto.secp import PrivateKey
|
from ..core.crypto.secp import PrivateKey
|
||||||
from ..core.db import Database
|
from ..core.db import Database
|
||||||
|
from ..core.mint_info import MintInfo
|
||||||
|
|
||||||
|
|
||||||
class SupportsPrivateKey(Protocol):
|
class SupportsPrivateKey(Protocol):
|
||||||
@@ -28,3 +29,9 @@ class SupportsHttpxClient(Protocol):
|
|||||||
|
|
||||||
class SupportsMintURL(Protocol):
|
class SupportsMintURL(Protocol):
|
||||||
url: str
|
url: str
|
||||||
|
|
||||||
|
|
||||||
|
class SupportsAuth(Protocol):
|
||||||
|
auth_db: Optional[Database] = None
|
||||||
|
auth_keyset_id: Optional[str] = None
|
||||||
|
mint_info: Optional[MintInfo] = None
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ class SubscriptionManager:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
msg = JSONRPCNotification.parse_raw(message)
|
msg = JSONRPCNotification.parse_raw(message)
|
||||||
logger.debug(f"Received notification: {msg}")
|
logger.trace(f"Received notification: {msg}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error parsing notification: {e}")
|
logger.error(f"Error parsing notification: {e}")
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from pydantic import ValidationError
|
|||||||
from cashu.wallet.crud import get_bolt11_melt_quote
|
from cashu.wallet.crud import get_bolt11_melt_quote
|
||||||
|
|
||||||
from ..core.base import (
|
from ..core.base import (
|
||||||
|
AuthProof,
|
||||||
BlindedMessage,
|
BlindedMessage,
|
||||||
BlindedSignature,
|
BlindedSignature,
|
||||||
MeltQuoteState,
|
MeltQuoteState,
|
||||||
@@ -29,6 +30,8 @@ from ..core.models import (
|
|||||||
KeysetsResponse,
|
KeysetsResponse,
|
||||||
KeysetsResponseKeyset,
|
KeysetsResponseKeyset,
|
||||||
KeysResponse,
|
KeysResponse,
|
||||||
|
PostAuthBlindMintRequest,
|
||||||
|
PostAuthBlindMintResponse,
|
||||||
PostCheckStateRequest,
|
PostCheckStateRequest,
|
||||||
PostCheckStateResponse,
|
PostCheckStateResponse,
|
||||||
PostMeltQuoteRequest,
|
PostMeltQuoteRequest,
|
||||||
@@ -47,8 +50,16 @@ from ..core.models import (
|
|||||||
)
|
)
|
||||||
from ..core.settings import settings
|
from ..core.settings import settings
|
||||||
from ..tor.tor import TorProxy
|
from ..tor.tor import TorProxy
|
||||||
|
from .crud import (
|
||||||
|
get_proofs,
|
||||||
|
invalidate_proof,
|
||||||
|
)
|
||||||
|
from .protocols import SupportsAuth
|
||||||
from .wallet_deprecated import LedgerAPIDeprecated
|
from .wallet_deprecated import LedgerAPIDeprecated
|
||||||
|
|
||||||
|
GET = "GET"
|
||||||
|
POST = "POST"
|
||||||
|
|
||||||
|
|
||||||
def async_set_httpx_client(func):
|
def async_set_httpx_client(func):
|
||||||
"""
|
"""
|
||||||
@@ -78,7 +89,7 @@ def async_set_httpx_client(func):
|
|||||||
verify=not settings.debug,
|
verify=not settings.debug,
|
||||||
proxies=proxies_dict, # type: ignore
|
proxies=proxies_dict, # type: ignore
|
||||||
headers=headers_dict,
|
headers=headers_dict,
|
||||||
base_url=self.url,
|
base_url=self.url.rstrip("/"),
|
||||||
timeout=None if settings.debug else 60,
|
timeout=None if settings.debug else 60,
|
||||||
)
|
)
|
||||||
return await func(self, *args, **kwargs)
|
return await func(self, *args, **kwargs)
|
||||||
@@ -99,10 +110,10 @@ def async_ensure_mint_loaded(func):
|
|||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
class LedgerAPI(LedgerAPIDeprecated):
|
class LedgerAPI(LedgerAPIDeprecated, SupportsAuth):
|
||||||
tor: TorProxy
|
tor: TorProxy
|
||||||
db: Database # we need the db for melt_deprecated
|
|
||||||
httpx: httpx.AsyncClient
|
httpx: httpx.AsyncClient
|
||||||
|
api_prefix = "v1"
|
||||||
|
|
||||||
def __init__(self, url: str, db: Database):
|
def __init__(self, url: str, db: Database):
|
||||||
self.url = url
|
self.url = url
|
||||||
@@ -128,7 +139,6 @@ class LedgerAPI(LedgerAPIDeprecated):
|
|||||||
try:
|
try:
|
||||||
resp_dict = resp.json()
|
resp_dict = resp.json()
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
# if we can't decode the response, raise for status
|
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
return
|
return
|
||||||
if "detail" in resp_dict:
|
if "detail" in resp_dict:
|
||||||
@@ -137,9 +147,49 @@ class LedgerAPI(LedgerAPIDeprecated):
|
|||||||
if "code" in resp_dict:
|
if "code" in resp_dict:
|
||||||
error_message += f" (Code: {resp_dict['code']})"
|
error_message += f" (Code: {resp_dict['code']})"
|
||||||
raise Exception(error_message)
|
raise Exception(error_message)
|
||||||
# raise for status if no error
|
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
|
|
||||||
|
async def _request(self, method: str, path: str, noprefix=False, **kwargs):
|
||||||
|
if not noprefix:
|
||||||
|
path = join(self.api_prefix, path)
|
||||||
|
if self.mint_info and self.mint_info.requires_blind_auth_path(method, path):
|
||||||
|
if not self.auth_db:
|
||||||
|
raise Exception(
|
||||||
|
"Mint requires blind auth, but no auth database is set."
|
||||||
|
)
|
||||||
|
if not self.auth_keyset_id:
|
||||||
|
raise Exception(
|
||||||
|
"Mint requires blind auth, but no auth keyset id is set."
|
||||||
|
)
|
||||||
|
proofs = await get_proofs(db=self.auth_db, id=self.auth_keyset_id)
|
||||||
|
if not proofs:
|
||||||
|
raise Exception(
|
||||||
|
"Mint requires blind auth, but no blind auth tokens were found."
|
||||||
|
)
|
||||||
|
# select one auth proof
|
||||||
|
proof = proofs[0]
|
||||||
|
auth_token = AuthProof.from_proof(proof).to_base64()
|
||||||
|
kwargs.setdefault("headers", {}).update(
|
||||||
|
{
|
||||||
|
"Blind-auth": f"{auth_token}",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
await invalidate_proof(proof=proof, db=self.auth_db)
|
||||||
|
if self.mint_info and self.mint_info.requires_clear_auth_path(method, path):
|
||||||
|
logger.debug(f"Using clear auth token for {path}")
|
||||||
|
clear_auth_token = kwargs.pop("clear_auth_token")
|
||||||
|
if not clear_auth_token:
|
||||||
|
raise Exception(
|
||||||
|
"Mint requires clear auth, but no clear auth token is set."
|
||||||
|
)
|
||||||
|
kwargs.setdefault("headers", {}).update(
|
||||||
|
{
|
||||||
|
"Clear-auth": f"{clear_auth_token}",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return await self.httpx.request(method, path, **kwargs)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
ENDPOINTS
|
ENDPOINTS
|
||||||
"""
|
"""
|
||||||
@@ -157,9 +207,7 @@ class LedgerAPI(LedgerAPIDeprecated):
|
|||||||
Raises:
|
Raises:
|
||||||
Exception: If no keys are received from the mint
|
Exception: If no keys are received from the mint
|
||||||
"""
|
"""
|
||||||
resp = await self.httpx.get(
|
resp = await self._request(GET, "keys")
|
||||||
join(self.url, "/v1/keys"),
|
|
||||||
)
|
|
||||||
# BEGIN backwards compatibility < 0.15.0
|
# BEGIN backwards compatibility < 0.15.0
|
||||||
# assume the mint has not upgraded yet if we get a 404
|
# assume the mint has not upgraded yet if we get a 404
|
||||||
if resp.status_code == 404:
|
if resp.status_code == 404:
|
||||||
@@ -201,9 +249,7 @@ class LedgerAPI(LedgerAPIDeprecated):
|
|||||||
Exception: If no keys are received from the mint
|
Exception: If no keys are received from the mint
|
||||||
"""
|
"""
|
||||||
keyset_id_urlsafe = keyset_id.replace("+", "-").replace("/", "_")
|
keyset_id_urlsafe = keyset_id.replace("+", "-").replace("/", "_")
|
||||||
resp = await self.httpx.get(
|
resp = await self._request(GET, f"keys/{keyset_id_urlsafe}")
|
||||||
join(self.url, f"/v1/keys/{keyset_id_urlsafe}"),
|
|
||||||
)
|
|
||||||
# BEGIN backwards compatibility < 0.15.0
|
# BEGIN backwards compatibility < 0.15.0
|
||||||
# assume the mint has not upgraded yet if we get a 404
|
# assume the mint has not upgraded yet if we get a 404
|
||||||
if resp.status_code == 404:
|
if resp.status_code == 404:
|
||||||
@@ -238,9 +284,7 @@ class LedgerAPI(LedgerAPIDeprecated):
|
|||||||
Raises:
|
Raises:
|
||||||
Exception: If no keysets are received from the mint
|
Exception: If no keysets are received from the mint
|
||||||
"""
|
"""
|
||||||
resp = await self.httpx.get(
|
resp = await self._request(GET, "keysets")
|
||||||
join(self.url, "/v1/keysets"),
|
|
||||||
)
|
|
||||||
# BEGIN backwards compatibility < 0.15.0
|
# BEGIN backwards compatibility < 0.15.0
|
||||||
# assume the mint has not upgraded yet if we get a 404
|
# assume the mint has not upgraded yet if we get a 404
|
||||||
if resp.status_code == 404:
|
if resp.status_code == 404:
|
||||||
@@ -265,9 +309,7 @@ class LedgerAPI(LedgerAPIDeprecated):
|
|||||||
Raises:
|
Raises:
|
||||||
Exception: If the mint info request fails
|
Exception: If the mint info request fails
|
||||||
"""
|
"""
|
||||||
resp = await self.httpx.get(
|
resp = await self._request(GET, "/v1/info", noprefix=True)
|
||||||
join(self.url, "/v1/info"),
|
|
||||||
)
|
|
||||||
# BEGIN backwards compatibility < 0.15.0
|
# BEGIN backwards compatibility < 0.15.0
|
||||||
# assume the mint has not upgraded yet if we get a 404
|
# assume the mint has not upgraded yet if we get a 404
|
||||||
if resp.status_code == 404:
|
if resp.status_code == 404:
|
||||||
@@ -305,9 +347,12 @@ class LedgerAPI(LedgerAPIDeprecated):
|
|||||||
payload = PostMintQuoteRequest(
|
payload = PostMintQuoteRequest(
|
||||||
unit=unit.name, amount=amount, description=memo, pubkey=pubkey
|
unit=unit.name, amount=amount, description=memo, pubkey=pubkey
|
||||||
)
|
)
|
||||||
resp = await self.httpx.post(
|
resp = await self._request(
|
||||||
join(self.url, "/v1/mint/quote/bolt11"), json=payload.dict()
|
POST,
|
||||||
|
"mint/quote/bolt11",
|
||||||
|
json=payload.dict(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# BEGIN backwards compatibility < 0.15.0
|
# BEGIN backwards compatibility < 0.15.0
|
||||||
# assume the mint has not upgraded yet if we get a 404
|
# assume the mint has not upgraded yet if we get a 404
|
||||||
if resp.status_code == 404:
|
if resp.status_code == 404:
|
||||||
@@ -329,9 +374,7 @@ class LedgerAPI(LedgerAPIDeprecated):
|
|||||||
Returns:
|
Returns:
|
||||||
PostMintQuoteResponse: Mint Quote Response
|
PostMintQuoteResponse: Mint Quote Response
|
||||||
"""
|
"""
|
||||||
resp = await self.httpx.get(
|
resp = await self._request(GET, f"mint/quote/bolt11/{quote}")
|
||||||
join(self.url, f"/v1/mint/quote/bolt11/{quote}"),
|
|
||||||
)
|
|
||||||
self.raise_on_error_request(resp)
|
self.raise_on_error_request(resp)
|
||||||
return_dict = resp.json()
|
return_dict = resp.json()
|
||||||
return PostMintQuoteResponse.parse_obj(return_dict)
|
return PostMintQuoteResponse.parse_obj(return_dict)
|
||||||
@@ -371,8 +414,9 @@ class LedgerAPI(LedgerAPIDeprecated):
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
payload = outputs_payload.dict(include=_mintrequest_include_fields(outputs)) # type: ignore
|
payload = outputs_payload.dict(include=_mintrequest_include_fields(outputs)) # type: ignore
|
||||||
resp = await self.httpx.post(
|
resp = await self._request(
|
||||||
join(self.url, "/v1/mint/bolt11"),
|
POST,
|
||||||
|
"mint/bolt11",
|
||||||
json=payload, # type: ignore
|
json=payload, # type: ignore
|
||||||
)
|
)
|
||||||
# BEGIN backwards compatibility < 0.15.0
|
# BEGIN backwards compatibility < 0.15.0
|
||||||
@@ -383,7 +427,7 @@ class LedgerAPI(LedgerAPIDeprecated):
|
|||||||
# END backwards compatibility < 0.15.0
|
# END backwards compatibility < 0.15.0
|
||||||
self.raise_on_error_request(resp)
|
self.raise_on_error_request(resp)
|
||||||
response_dict = resp.json()
|
response_dict = resp.json()
|
||||||
logger.trace("Lightning invoice checked. POST /v1/mint/bolt11")
|
logger.trace(f"Lightning invoice checked. POST {self.api_prefix}/mint/bolt11")
|
||||||
promises = PostMintResponse.parse_obj(response_dict).signatures
|
promises = PostMintResponse.parse_obj(response_dict).signatures
|
||||||
return promises
|
return promises
|
||||||
|
|
||||||
@@ -406,8 +450,9 @@ class LedgerAPI(LedgerAPIDeprecated):
|
|||||||
unit=unit.name, request=payment_request, options=melt_options
|
unit=unit.name, request=payment_request, options=melt_options
|
||||||
)
|
)
|
||||||
|
|
||||||
resp = await self.httpx.post(
|
resp = await self._request(
|
||||||
join(self.url, "/v1/melt/quote/bolt11"),
|
POST,
|
||||||
|
"melt/quote/bolt11",
|
||||||
json=payload.dict(),
|
json=payload.dict(),
|
||||||
)
|
)
|
||||||
# BEGIN backwards compatibility < 0.15.0
|
# BEGIN backwards compatibility < 0.15.0
|
||||||
@@ -441,9 +486,7 @@ class LedgerAPI(LedgerAPIDeprecated):
|
|||||||
Returns:
|
Returns:
|
||||||
PostMeltQuoteResponse: Melt Quote Response
|
PostMeltQuoteResponse: Melt Quote Response
|
||||||
"""
|
"""
|
||||||
resp = await self.httpx.get(
|
resp = await self._request(GET, f"melt/quote/bolt11/{quote}")
|
||||||
join(self.url, f"/v1/melt/quote/bolt11/{quote}"),
|
|
||||||
)
|
|
||||||
self.raise_on_error_request(resp)
|
self.raise_on_error_request(resp)
|
||||||
return_dict = resp.json()
|
return_dict = resp.json()
|
||||||
return PostMeltQuoteResponse.parse_obj(return_dict)
|
return PostMeltQuoteResponse.parse_obj(return_dict)
|
||||||
@@ -474,8 +517,9 @@ class LedgerAPI(LedgerAPIDeprecated):
|
|||||||
"outputs": {i: outputs_include for i in range(len(outputs))},
|
"outputs": {i: outputs_include for i in range(len(outputs))},
|
||||||
}
|
}
|
||||||
|
|
||||||
resp = await self.httpx.post(
|
resp = await self._request(
|
||||||
join(self.url, "/v1/melt/bolt11"),
|
POST,
|
||||||
|
"melt/bolt11",
|
||||||
json=payload.dict(include=_meltrequest_include_fields(proofs, outputs)), # type: ignore
|
json=payload.dict(include=_meltrequest_include_fields(proofs, outputs)), # type: ignore
|
||||||
timeout=None,
|
timeout=None,
|
||||||
)
|
)
|
||||||
@@ -523,7 +567,7 @@ class LedgerAPI(LedgerAPIDeprecated):
|
|||||||
outputs: List[BlindedMessage],
|
outputs: List[BlindedMessage],
|
||||||
) -> List[BlindedSignature]:
|
) -> List[BlindedSignature]:
|
||||||
"""Consume proofs and create new promises based on amount split."""
|
"""Consume proofs and create new promises based on amount split."""
|
||||||
logger.debug("Calling split. POST /v1/swap")
|
logger.debug(f"Calling split. POST {self.api_prefix}/swap")
|
||||||
split_payload = PostSwapRequest(inputs=proofs, outputs=outputs)
|
split_payload = PostSwapRequest(inputs=proofs, outputs=outputs)
|
||||||
|
|
||||||
# construct payload
|
# construct payload
|
||||||
@@ -541,8 +585,9 @@ class LedgerAPI(LedgerAPIDeprecated):
|
|||||||
"inputs": {i: proofs_include for i in range(len(proofs))},
|
"inputs": {i: proofs_include for i in range(len(proofs))},
|
||||||
}
|
}
|
||||||
|
|
||||||
resp = await self.httpx.post(
|
resp = await self._request(
|
||||||
join(self.url, "/v1/swap"),
|
POST,
|
||||||
|
"swap",
|
||||||
json=split_payload.dict(include=_splitrequest_include_fields(proofs)), # type: ignore
|
json=split_payload.dict(include=_splitrequest_include_fields(proofs)), # type: ignore
|
||||||
)
|
)
|
||||||
# BEGIN backwards compatibility < 0.15.0
|
# BEGIN backwards compatibility < 0.15.0
|
||||||
@@ -568,8 +613,9 @@ class LedgerAPI(LedgerAPIDeprecated):
|
|||||||
Checks whether the secrets in proofs are already spent or not and returns a list of booleans.
|
Checks whether the secrets in proofs are already spent or not and returns a list of booleans.
|
||||||
"""
|
"""
|
||||||
payload = PostCheckStateRequest(Ys=[p.Y for p in proofs])
|
payload = PostCheckStateRequest(Ys=[p.Y for p in proofs])
|
||||||
resp = await self.httpx.post(
|
resp = await self._request(
|
||||||
join(self.url, "/v1/checkstate"),
|
POST,
|
||||||
|
"checkstate",
|
||||||
json=payload.dict(),
|
json=payload.dict(),
|
||||||
)
|
)
|
||||||
# BEGIN backwards compatibility < 0.15.0
|
# BEGIN backwards compatibility < 0.15.0
|
||||||
@@ -595,10 +641,7 @@ class LedgerAPI(LedgerAPIDeprecated):
|
|||||||
"Received HTTP Error 422. Attempting state check with < 0.16.0 compatibility."
|
"Received HTTP Error 422. Attempting state check with < 0.16.0 compatibility."
|
||||||
)
|
)
|
||||||
payload_secrets = {"secrets": [p.secret for p in proofs]}
|
payload_secrets = {"secrets": [p.secret for p in proofs]}
|
||||||
resp_secrets = await self.httpx.post(
|
resp_secrets = await self._request(POST, "checkstate", json=payload_secrets)
|
||||||
join(self.url, "/v1/checkstate"),
|
|
||||||
json=payload_secrets,
|
|
||||||
)
|
|
||||||
self.raise_on_error(resp_secrets)
|
self.raise_on_error(resp_secrets)
|
||||||
states = [
|
states = [
|
||||||
ProofState(Y=p.Y, state=ProofSpentState(s["state"]))
|
ProofState(Y=p.Y, state=ProofSpentState(s["state"]))
|
||||||
@@ -619,7 +662,7 @@ class LedgerAPI(LedgerAPIDeprecated):
|
|||||||
Asks the mint to restore promises corresponding to outputs.
|
Asks the mint to restore promises corresponding to outputs.
|
||||||
"""
|
"""
|
||||||
payload = PostMintRequest(quote="restore", outputs=outputs)
|
payload = PostMintRequest(quote="restore", outputs=outputs)
|
||||||
resp = await self.httpx.post(join(self.url, "/v1/restore"), json=payload.dict())
|
resp = await self._request(POST, "restore", json=payload.dict())
|
||||||
# BEGIN backwards compatibility < 0.15.0
|
# BEGIN backwards compatibility < 0.15.0
|
||||||
# assume the mint has not upgraded yet if we get a 404
|
# assume the mint has not upgraded yet if we get a 404
|
||||||
if resp.status_code == 404:
|
if resp.status_code == 404:
|
||||||
@@ -637,3 +680,21 @@ class LedgerAPI(LedgerAPIDeprecated):
|
|||||||
# END backwards compatibility < 0.15.1
|
# END backwards compatibility < 0.15.1
|
||||||
|
|
||||||
return returnObj.outputs, returnObj.signatures
|
return returnObj.outputs, returnObj.signatures
|
||||||
|
|
||||||
|
@async_set_httpx_client
|
||||||
|
async def blind_mint_blind_auth(
|
||||||
|
self, clear_auth_token: str, outputs: List[BlindedMessage]
|
||||||
|
) -> List[BlindedSignature]:
|
||||||
|
"""
|
||||||
|
Asks the mint to mint blind auth tokens. Needs to provide a clear auth token.
|
||||||
|
"""
|
||||||
|
payload = PostAuthBlindMintRequest(outputs=outputs)
|
||||||
|
resp = await self._request(
|
||||||
|
POST,
|
||||||
|
"mint",
|
||||||
|
json=payload.dict(),
|
||||||
|
clear_auth_token=clear_auth_token,
|
||||||
|
)
|
||||||
|
self.raise_on_error_request(resp)
|
||||||
|
response_dict = resp.json()
|
||||||
|
return PostAuthBlindMintResponse.parse_obj(response_dict).signatures
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import copy
|
import copy
|
||||||
|
import json
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||||
@@ -17,6 +18,7 @@ from ..core.base import (
|
|||||||
Proof,
|
Proof,
|
||||||
Unit,
|
Unit,
|
||||||
WalletKeyset,
|
WalletKeyset,
|
||||||
|
WalletMint,
|
||||||
)
|
)
|
||||||
from ..core.crypto import b_dhke
|
from ..core.crypto import b_dhke
|
||||||
from ..core.crypto.secp import PrivateKey, PublicKey
|
from ..core.crypto.secp import PrivateKey, PublicKey
|
||||||
@@ -30,6 +32,7 @@ from ..core.helpers import (
|
|||||||
)
|
)
|
||||||
from ..core.json_rpc.base import JSONRPCSubscriptionKinds
|
from ..core.json_rpc.base import JSONRPCSubscriptionKinds
|
||||||
from ..core.migrations import migrate_databases
|
from ..core.migrations import migrate_databases
|
||||||
|
from ..core.mint_info import MintInfo
|
||||||
from ..core.models import (
|
from ..core.models import (
|
||||||
PostCheckStateResponse,
|
PostCheckStateResponse,
|
||||||
PostMeltQuoteResponse,
|
PostMeltQuoteResponse,
|
||||||
@@ -43,6 +46,7 @@ from .crud import (
|
|||||||
bump_secret_derivation,
|
bump_secret_derivation,
|
||||||
get_bolt11_mint_quote,
|
get_bolt11_mint_quote,
|
||||||
get_keysets,
|
get_keysets,
|
||||||
|
get_mint_by_url,
|
||||||
get_proofs,
|
get_proofs,
|
||||||
invalidate_proof,
|
invalidate_proof,
|
||||||
secret_used,
|
secret_used,
|
||||||
@@ -50,14 +54,16 @@ from .crud import (
|
|||||||
store_bolt11_melt_quote,
|
store_bolt11_melt_quote,
|
||||||
store_bolt11_mint_quote,
|
store_bolt11_mint_quote,
|
||||||
store_keyset,
|
store_keyset,
|
||||||
|
store_mint,
|
||||||
store_proof,
|
store_proof,
|
||||||
update_bolt11_melt_quote,
|
update_bolt11_melt_quote,
|
||||||
update_bolt11_mint_quote,
|
update_bolt11_mint_quote,
|
||||||
update_keyset,
|
update_keyset,
|
||||||
|
update_mint,
|
||||||
update_proof,
|
update_proof,
|
||||||
)
|
)
|
||||||
|
from .errors import BalanceTooLowError
|
||||||
from .htlc import WalletHTLC
|
from .htlc import WalletHTLC
|
||||||
from .mint_info import MintInfo
|
|
||||||
from .p2pk import WalletP2PK
|
from .p2pk import WalletP2PK
|
||||||
from .proofs import WalletProofs
|
from .proofs import WalletProofs
|
||||||
from .secrets import WalletSecrets
|
from .secrets import WalletSecrets
|
||||||
@@ -108,21 +114,41 @@ class Wallet(
|
|||||||
db: Database
|
db: Database
|
||||||
bip32: BIP32
|
bip32: BIP32
|
||||||
# private_key: Optional[PrivateKey] = None
|
# private_key: Optional[PrivateKey] = None
|
||||||
|
auth_db: Optional[Database] = None
|
||||||
|
auth_keyset_id: Optional[str] = None
|
||||||
|
|
||||||
def __init__(self, url: str, db: str, name: str = "wallet", unit: str = "sat"):
|
def __init__(
|
||||||
|
self,
|
||||||
|
url: str,
|
||||||
|
db: str,
|
||||||
|
name: str = "wallet",
|
||||||
|
unit: str = "sat",
|
||||||
|
auth_db: Optional[str] = None,
|
||||||
|
auth_keyset_id: Optional[str] = None,
|
||||||
|
):
|
||||||
"""A Cashu wallet.
|
"""A Cashu wallet.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
url (str): URL of the mint.
|
url (str): URL of the mint.
|
||||||
db (str): Path to the database directory.
|
db (str): Path to the database directory.
|
||||||
name (str, optional): Name of the wallet database file. Defaults to "wallet".
|
name (str, optional): Name of the wallet database file. Defaults to "wallet".
|
||||||
|
unit (str, optional): Unit of the wallet. Defaults to "sat".
|
||||||
|
auth_db (Optional[str], optional): Path to the auth database directory. Defaults to None.
|
||||||
|
auth_keyset_id (Optional[str], optional): Keyset ID of the auth keyset. Defaults to None.
|
||||||
"""
|
"""
|
||||||
self.db = Database("wallet", db)
|
self.db = Database(name, db)
|
||||||
self.proofs: List[Proof] = []
|
self.proofs: List[Proof] = []
|
||||||
self.name = name
|
self.name = name
|
||||||
self.unit = Unit[unit]
|
self.unit = Unit[unit]
|
||||||
url = sanitize_url(url)
|
url = sanitize_url(url)
|
||||||
|
|
||||||
|
# if this is an auth wallet
|
||||||
|
if (auth_db and not auth_keyset_id) or (not auth_db and auth_keyset_id):
|
||||||
|
raise Exception("Both auth_db and auth_keyset_id must be provided.")
|
||||||
|
if auth_db and auth_keyset_id:
|
||||||
|
self.auth_db = Database("auth", auth_db)
|
||||||
|
self.auth_keyset_id = auth_keyset_id
|
||||||
|
|
||||||
super().__init__(url=url, db=self.db)
|
super().__init__(url=url, db=self.db)
|
||||||
logger.debug("Wallet initialized")
|
logger.debug("Wallet initialized")
|
||||||
logger.debug(f"Mint URL: {url}")
|
logger.debug(f"Mint URL: {url}")
|
||||||
@@ -137,7 +163,10 @@ class Wallet(
|
|||||||
name: str = "wallet",
|
name: str = "wallet",
|
||||||
skip_db_read: bool = False,
|
skip_db_read: bool = False,
|
||||||
unit: str = "sat",
|
unit: str = "sat",
|
||||||
|
auth_db: Optional[str] = None,
|
||||||
|
auth_keyset_id: Optional[str] = None,
|
||||||
load_all_keysets: bool = False,
|
load_all_keysets: bool = False,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""Initializes a wallet with a database and initializes the private key.
|
"""Initializes a wallet with a database and initializes the private key.
|
||||||
|
|
||||||
@@ -151,12 +180,22 @@ class Wallet(
|
|||||||
unit (str, optional): Unit of the wallet. Defaults to "sat".
|
unit (str, optional): Unit of the wallet. Defaults to "sat".
|
||||||
load_all_keysets (bool, optional): If true, all keysets are loaded from the database.
|
load_all_keysets (bool, optional): If true, all keysets are loaded from the database.
|
||||||
Defaults to False.
|
Defaults to False.
|
||||||
|
auth_db (Optional[str], optional): Path to the auth database directory. Defaults to None.
|
||||||
|
auth_keyset_id (Optional[str], optional): Keyset ID of the auth keyset. Defaults to None.
|
||||||
|
kwargs: Additional keyword arguments.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Wallet: Initialized wallet.
|
Wallet: Initialized wallet.
|
||||||
"""
|
"""
|
||||||
logger.trace(f"Initializing wallet with database: {db}")
|
logger.trace(f"Initializing wallet with database: {db}")
|
||||||
self = cls(url=url, db=db, name=name, unit=unit)
|
self = cls(
|
||||||
|
url=url,
|
||||||
|
db=db,
|
||||||
|
name=name,
|
||||||
|
unit=unit,
|
||||||
|
auth_db=auth_db,
|
||||||
|
auth_keyset_id=auth_keyset_id,
|
||||||
|
)
|
||||||
await self._migrate_database()
|
await self._migrate_database()
|
||||||
|
|
||||||
if skip_db_read:
|
if skip_db_read:
|
||||||
@@ -172,8 +211,13 @@ class Wallet(
|
|||||||
self.keysets = {k.id: k for k in keysets_active_unit}
|
self.keysets = {k.id: k for k in keysets_active_unit}
|
||||||
else:
|
else:
|
||||||
self.keysets = {k.id: k for k in keysets_list}
|
self.keysets = {k.id: k for k in keysets_list}
|
||||||
keysets_str = " ".join([f"{i} {k.unit}" for i, k in self.keysets.items()])
|
|
||||||
logger.debug(f"Loaded keysets: {keysets_str}")
|
if self.keysets:
|
||||||
|
keysets_str = " ".join([f"{i} {k.unit}" for i, k in self.keysets.items()])
|
||||||
|
logger.debug(f"Loaded keysets: {keysets_str}")
|
||||||
|
|
||||||
|
await self.load_mint_info(offline=True)
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def _migrate_database(self):
|
async def _migrate_database(self):
|
||||||
@@ -185,12 +229,63 @@ class Wallet(
|
|||||||
|
|
||||||
# ---------- API ----------
|
# ---------- API ----------
|
||||||
|
|
||||||
async def load_mint_info(self) -> MintInfo:
|
async def load_mint_info(self, reload=False, offline=False) -> MintInfo | None:
|
||||||
"""Loads the mint info from the mint."""
|
"""Loads the mint info from the mint.
|
||||||
mint_info_resp = await self._get_info()
|
|
||||||
self.mint_info = MintInfo(**mint_info_resp.dict())
|
Args:
|
||||||
logger.debug(f"Mint info: {self.mint_info}")
|
reload (bool, optional): If True, the mint info is reloaded from the mint. Defaults to False.
|
||||||
return self.mint_info
|
offline (bool, optional): If True, the mint info is not loaded from the mint. Defaults to False.
|
||||||
|
"""
|
||||||
|
# if self.mint_info and not reload:
|
||||||
|
# return self.mint_info
|
||||||
|
|
||||||
|
# read mint info from db
|
||||||
|
if reload:
|
||||||
|
if offline:
|
||||||
|
raise Exception("Cannot reload mint info offline.")
|
||||||
|
logger.debug("Forcing reload of mint info.")
|
||||||
|
mint_info_resp = await self._get_info()
|
||||||
|
self.mint_info = MintInfo(**mint_info_resp.dict())
|
||||||
|
|
||||||
|
wallet_mint_db = await get_mint_by_url(url=self.url, db=self.db)
|
||||||
|
if not wallet_mint_db:
|
||||||
|
if self.mint_info:
|
||||||
|
logger.debug("Storing mint info in db.")
|
||||||
|
await store_mint(
|
||||||
|
db=self.db,
|
||||||
|
mint=WalletMint(
|
||||||
|
url=self.url, info=json.dumps(self.mint_info.dict())
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if offline:
|
||||||
|
return None
|
||||||
|
logger.debug("Loading mint info from mint.")
|
||||||
|
mint_info_resp = await self._get_info()
|
||||||
|
self.mint_info = MintInfo(**mint_info_resp.dict())
|
||||||
|
if not wallet_mint_db:
|
||||||
|
logger.debug("Storing mint info in db.")
|
||||||
|
await store_mint(
|
||||||
|
db=self.db,
|
||||||
|
mint=WalletMint(
|
||||||
|
url=self.url, info=json.dumps(self.mint_info.dict())
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return self.mint_info
|
||||||
|
elif (
|
||||||
|
self.mint_info
|
||||||
|
and not json.dumps(self.mint_info.dict()) == wallet_mint_db.info
|
||||||
|
):
|
||||||
|
logger.debug("Updating mint info in db.")
|
||||||
|
await update_mint(
|
||||||
|
db=self.db,
|
||||||
|
mint=WalletMint(url=self.url, info=json.dumps(self.mint_info.dict())),
|
||||||
|
)
|
||||||
|
return self.mint_info
|
||||||
|
else:
|
||||||
|
logger.debug("Loading mint info from db.")
|
||||||
|
self.mint_info = MintInfo.from_json_str(wallet_mint_db.info)
|
||||||
|
return self.mint_info
|
||||||
|
|
||||||
async def load_mint_keysets(self, force_old_keysets=False):
|
async def load_mint_keysets(self, force_old_keysets=False):
|
||||||
"""Loads all keyset of the mint and makes sure we have them all in the database.
|
"""Loads all keyset of the mint and makes sure we have them all in the database.
|
||||||
@@ -304,11 +399,11 @@ class Wallet(
|
|||||||
force_old_keysets (bool, optional): If true, old deprecated base64 keysets are not ignored. This is necessary for restoring tokens from old base64 keysets.
|
force_old_keysets (bool, optional): If true, old deprecated base64 keysets are not ignored. This is necessary for restoring tokens from old base64 keysets.
|
||||||
Defaults to False.
|
Defaults to False.
|
||||||
"""
|
"""
|
||||||
logger.trace("Loading mint.")
|
logger.trace(f"Loading mint {self.url}")
|
||||||
await self.load_mint_keysets(force_old_keysets)
|
await self.load_mint_keysets(force_old_keysets)
|
||||||
await self.activate_keyset(keyset_id)
|
await self.activate_keyset(keyset_id)
|
||||||
try:
|
try:
|
||||||
await self.load_mint_info()
|
await self.load_mint_info(reload=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"Could not load mint info: {e}")
|
logger.debug(f"Could not load mint info: {e}")
|
||||||
pass
|
pass
|
||||||
@@ -979,7 +1074,7 @@ class Wallet(
|
|||||||
# select proofs that are not reserved and are in the active keysets of the mint
|
# select proofs that are not reserved and are in the active keysets of the mint
|
||||||
proofs = self.active_proofs(proofs)
|
proofs = self.active_proofs(proofs)
|
||||||
if sum_proofs(proofs) < amount:
|
if sum_proofs(proofs) < amount:
|
||||||
raise Exception("balance too low.")
|
raise BalanceTooLowError()
|
||||||
|
|
||||||
# coin selection for potentially offline sending
|
# coin selection for potentially offline sending
|
||||||
send_proofs = self.coinselect(proofs, amount, include_fees=include_fees)
|
send_proofs = self.coinselect(proofs, amount, include_fees=include_fees)
|
||||||
@@ -1037,7 +1132,7 @@ class Wallet(
|
|||||||
# select proofs that are not reserved and are in the active keysets of the mint
|
# select proofs that are not reserved and are in the active keysets of the mint
|
||||||
proofs = self.active_proofs(proofs)
|
proofs = self.active_proofs(proofs)
|
||||||
if sum_proofs(proofs) < amount:
|
if sum_proofs(proofs) < amount:
|
||||||
raise Exception("balance too low.")
|
raise BalanceTooLowError()
|
||||||
|
|
||||||
# coin selection for swapping, needs to include fees
|
# coin selection for swapping, needs to include fees
|
||||||
swap_proofs = self.coinselect(proofs, amount, include_fees=True)
|
swap_proofs = self.coinselect(proofs, amount, include_fees=True)
|
||||||
|
|||||||
@@ -144,6 +144,8 @@ class LedgerAPIDeprecated(SupportsHttpxClient, SupportsMintURL):
|
|||||||
mint_info = GetInfoResponse(
|
mint_info = GetInfoResponse(
|
||||||
**mint_info_deprecated.dict(exclude={"parameter", "nuts", "contact"})
|
**mint_info_deprecated.dict(exclude={"parameter", "nuts", "contact"})
|
||||||
)
|
)
|
||||||
|
# monkeypatch nuts
|
||||||
|
mint_info.nuts = {}
|
||||||
return mint_info
|
return mint_info
|
||||||
|
|
||||||
@async_set_httpx_client
|
@async_set_httpx_client
|
||||||
@@ -261,7 +263,7 @@ class LedgerAPIDeprecated(SupportsHttpxClient, SupportsMintURL):
|
|||||||
paid=False,
|
paid=False,
|
||||||
state=MintQuoteState.unpaid.value,
|
state=MintQuoteState.unpaid.value,
|
||||||
expiry=decoded_invoice.date + (decoded_invoice.expiry or 0),
|
expiry=decoded_invoice.date + (decoded_invoice.expiry or 0),
|
||||||
pubkey=None
|
pubkey=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@async_set_httpx_client
|
@async_set_httpx_client
|
||||||
|
|||||||
7
keycloak/.env.example
Normal file
7
keycloak/.env.example
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
POSTGRES_DB=keycloak_db
|
||||||
|
POSTGRES_USER=keycloak_db_user
|
||||||
|
POSTGRES_PASSWORD=keycloak_db_user_password
|
||||||
|
KEYCLOAK_ADMIN=admin
|
||||||
|
KEYCLOAK_ADMIN_PASSWORD=password
|
||||||
|
KC_HOSTNAME=localhost
|
||||||
|
KC_HOSTNAME_PORT=8080
|
||||||
129
keycloak/README.md
Normal file
129
keycloak/README.md
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
## Docker compose
|
||||||
|
|
||||||
|
This docker-compose starts a new keycloak instance. Set up the server as you wish, add realms, users etc. We will then export the data and restore an instance with the exported data.
|
||||||
|
|
||||||
|
We will modify this file later to start the server with the backup data.
|
||||||
|
|
||||||
|
```
|
||||||
|
services:
|
||||||
|
postgres:
|
||||||
|
image: postgres:16.4
|
||||||
|
volumes:
|
||||||
|
- ./postgres_data:/var/lib/postgresql/data
|
||||||
|
environment:
|
||||||
|
POSTGRES_DB: ${POSTGRES_DB}
|
||||||
|
POSTGRES_USER: ${POSTGRES_USER}
|
||||||
|
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
|
||||||
|
networks:
|
||||||
|
- keycloak_network
|
||||||
|
|
||||||
|
keycloak:
|
||||||
|
image: quay.io/keycloak/keycloak:25.0.6
|
||||||
|
command: start
|
||||||
|
environment:
|
||||||
|
KC_HOSTNAME: localhost
|
||||||
|
KC_HOSTNAME_PORT: 8080
|
||||||
|
KC_HOSTNAME_STRICT_BACKCHANNEL: false
|
||||||
|
KC_HTTP_ENABLED: true
|
||||||
|
KC_HOSTNAME_STRICT_HTTPS: false
|
||||||
|
KC_HEALTH_ENABLED: true
|
||||||
|
KEYCLOAK_ADMIN: ${KEYCLOAK_ADMIN}
|
||||||
|
KEYCLOAK_ADMIN_PASSWORD: ${KEYCLOAK_ADMIN_PASSWORD}
|
||||||
|
KC_DB: postgres
|
||||||
|
KC_DB_URL: jdbc:postgresql://postgres/${POSTGRES_DB}
|
||||||
|
KC_DB_USERNAME: ${POSTGRES_USER}
|
||||||
|
KC_DB_PASSWORD: ${POSTGRES_PASSWORD}
|
||||||
|
ports:
|
||||||
|
- 8080:8080
|
||||||
|
restart: always
|
||||||
|
depends_on:
|
||||||
|
- postgres
|
||||||
|
networks:
|
||||||
|
- keycloak_network
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
postgres_data:
|
||||||
|
driver: local
|
||||||
|
|
||||||
|
networks:
|
||||||
|
keycloak_network:
|
||||||
|
driver: bridge
|
||||||
|
```
|
||||||
|
|
||||||
|
## Backup
|
||||||
|
|
||||||
|
Export realm and users from running container:
|
||||||
|
|
||||||
|
```
|
||||||
|
docker exec keycloak-keycloak-1 \
|
||||||
|
/opt/keycloak/bin/kc.sh export \
|
||||||
|
--dir /opt/keycloak/data/export \
|
||||||
|
--users different_files \
|
||||||
|
--http-management-port 46566
|
||||||
|
```
|
||||||
|
|
||||||
|
Copy export out of the docker
|
||||||
|
|
||||||
|
```
|
||||||
|
docker cp keycloak-keycloak-1:/opt/keycloak/data/export ./keycloak-export
|
||||||
|
```
|
||||||
|
|
||||||
|
## Restore
|
||||||
|
|
||||||
|
Use this docker-compose.yml to start keycloak with the exported backup:
|
||||||
|
|
||||||
|
```
|
||||||
|
services:
|
||||||
|
postgres:
|
||||||
|
image: postgres:16.4
|
||||||
|
volumes:
|
||||||
|
- ./postgres_data:/var/lib/postgresql/data
|
||||||
|
environment:
|
||||||
|
POSTGRES_DB: ${POSTGRES_DB}
|
||||||
|
POSTGRES_USER: ${POSTGRES_USER}
|
||||||
|
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
|
||||||
|
networks:
|
||||||
|
- keycloak_network
|
||||||
|
|
||||||
|
keycloak:
|
||||||
|
image: quay.io/keycloak/keycloak:25.0.6
|
||||||
|
command: start --import-realm
|
||||||
|
volumes:
|
||||||
|
- ./keycloak-export:/opt/keycloak/data/import
|
||||||
|
environment:
|
||||||
|
KC_HOSTNAME: localhost
|
||||||
|
KC_HOSTNAME_PORT: 8080
|
||||||
|
KC_HOSTNAME_STRICT_BACKCHANNEL: false
|
||||||
|
KC_HTTP_ENABLED: true
|
||||||
|
KC_HOSTNAME_STRICT_HTTPS: false
|
||||||
|
KC_HEALTH_ENABLED: true
|
||||||
|
KEYCLOAK_ADMIN: ${KEYCLOAK_ADMIN}
|
||||||
|
KEYCLOAK_ADMIN_PASSWORD: ${KEYCLOAK_ADMIN_PASSWORD}
|
||||||
|
KC_DB: postgres
|
||||||
|
KC_DB_URL: jdbc:postgresql://postgres/${POSTGRES_DB}
|
||||||
|
KC_DB_USERNAME: ${POSTGRES_USER}
|
||||||
|
KC_DB_PASSWORD: ${POSTGRES_PASSWORD}
|
||||||
|
ports:
|
||||||
|
- 8080:8080
|
||||||
|
restart: always
|
||||||
|
depends_on:
|
||||||
|
- postgres
|
||||||
|
networks:
|
||||||
|
- keycloak_network
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
postgres_data:
|
||||||
|
driver: local
|
||||||
|
|
||||||
|
networks:
|
||||||
|
keycloak_network:
|
||||||
|
driver: bridge
|
||||||
|
```
|
||||||
|
|
||||||
|
Difference to first docker-compose is only the following part:
|
||||||
|
|
||||||
|
```
|
||||||
|
command: start --import-realm
|
||||||
|
volumes:
|
||||||
|
- ./keycloak-export:/opt/keycloak/data/import
|
||||||
|
```
|
||||||
45
keycloak/docker-compose.yml
Normal file
45
keycloak/docker-compose.yml
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
services:
|
||||||
|
postgres:
|
||||||
|
image: postgres:16.4
|
||||||
|
volumes:
|
||||||
|
- ./postgres_data:/var/lib/postgresql/data
|
||||||
|
environment:
|
||||||
|
POSTGRES_DB: ${POSTGRES_DB}
|
||||||
|
POSTGRES_USER: ${POSTGRES_USER}
|
||||||
|
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
|
||||||
|
networks:
|
||||||
|
- keycloak_network
|
||||||
|
|
||||||
|
keycloak:
|
||||||
|
image: quay.io/keycloak/keycloak:25.0.6
|
||||||
|
command: start --import-realm
|
||||||
|
volumes:
|
||||||
|
- ./keycloak-export:/opt/keycloak/data/import
|
||||||
|
environment:
|
||||||
|
KC_HOSTNAME: ${KC_HOSTNAME}
|
||||||
|
KC_HOSTNAME_PORT: ${KC_HOSTNAME_PORT}
|
||||||
|
KC_HOSTNAME_STRICT_BACKCHANNEL: false
|
||||||
|
KC_HTTP_ENABLED: true
|
||||||
|
KC_HOSTNAME_STRICT_HTTPS: true
|
||||||
|
KC_HEALTH_ENABLED: true
|
||||||
|
KEYCLOAK_ADMIN: ${KEYCLOAK_ADMIN}
|
||||||
|
KEYCLOAK_ADMIN_PASSWORD: ${KEYCLOAK_ADMIN_PASSWORD}
|
||||||
|
KC_DB: postgres
|
||||||
|
KC_DB_URL: jdbc:postgresql://postgres/${POSTGRES_DB}
|
||||||
|
KC_DB_USERNAME: ${POSTGRES_USER}
|
||||||
|
KC_DB_PASSWORD: ${POSTGRES_PASSWORD}
|
||||||
|
ports:
|
||||||
|
- 8080:8080
|
||||||
|
restart: always
|
||||||
|
depends_on:
|
||||||
|
- postgres
|
||||||
|
networks:
|
||||||
|
- keycloak_network
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
postgres_data:
|
||||||
|
driver: local
|
||||||
|
|
||||||
|
networks:
|
||||||
|
keycloak_network:
|
||||||
|
driver: bridge
|
||||||
1099
poetry.lock
generated
1099
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -39,9 +39,11 @@ googleapis-common-protos = "^1.63.2"
|
|||||||
mypy-protobuf = "^3.6.0"
|
mypy-protobuf = "^3.6.0"
|
||||||
types-protobuf = "^5.27.0.20240626"
|
types-protobuf = "^5.27.0.20240626"
|
||||||
grpcio-tools = "^1.65.1"
|
grpcio-tools = "^1.65.1"
|
||||||
|
pyjwt = "^2.9.0"
|
||||||
redis = "^5.1.1"
|
redis = "^5.1.1"
|
||||||
brotli = "^1.1.0"
|
brotli = "^1.1.0"
|
||||||
zstandard = "^0.23.0"
|
zstandard = "^0.23.0"
|
||||||
|
jinja2 = "^3.1.5"
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
pytest-asyncio = "^0.24.0"
|
pytest-asyncio = "^0.24.0"
|
||||||
|
|||||||
@@ -51,6 +51,7 @@ settings.mint_lnd_enable_mpp = True
|
|||||||
settings.mint_clnrest_enable_mpp = True
|
settings.mint_clnrest_enable_mpp = True
|
||||||
settings.mint_input_fee_ppk = 0
|
settings.mint_input_fee_ppk = 0
|
||||||
settings.db_connection_pool = True
|
settings.db_connection_pool = True
|
||||||
|
# settings.mint_require_auth = False
|
||||||
|
|
||||||
assert "test" in settings.cashu_dir
|
assert "test" in settings.cashu_dir
|
||||||
shutil.rmtree(settings.cashu_dir, ignore_errors=True)
|
shutil.rmtree(settings.cashu_dir, ignore_errors=True)
|
||||||
|
|||||||
45
tests/keycloak_data/docker-compose-restore.yml
Normal file
45
tests/keycloak_data/docker-compose-restore.yml
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
services:
|
||||||
|
postgres:
|
||||||
|
image: postgres:16.4
|
||||||
|
volumes:
|
||||||
|
- ./postgres_data:/var/lib/postgresql/data
|
||||||
|
environment:
|
||||||
|
POSTGRES_DB: cashu
|
||||||
|
POSTGRES_USER: cashu
|
||||||
|
POSTGRES_PASSWORD: cashu
|
||||||
|
networks:
|
||||||
|
- keycloak_network
|
||||||
|
|
||||||
|
keycloak:
|
||||||
|
image: quay.io/keycloak/keycloak:25.0.6
|
||||||
|
command: start --import-realm
|
||||||
|
volumes:
|
||||||
|
- ./keycloak-export:/opt/keycloak/data/import
|
||||||
|
environment:
|
||||||
|
KC_HOSTNAME: localhost
|
||||||
|
KC_HOSTNAME_PORT: 8080
|
||||||
|
KC_HOSTNAME_STRICT_BACKCHANNEL: false
|
||||||
|
KC_HTTP_ENABLED: true
|
||||||
|
KC_HOSTNAME_STRICT_HTTPS: false
|
||||||
|
KC_HEALTH_ENABLED: true
|
||||||
|
KEYCLOAK_ADMIN: admin
|
||||||
|
KEYCLOAK_ADMIN_PASSWORD: admin
|
||||||
|
KC_DB: postgres
|
||||||
|
KC_DB_URL: jdbc:postgresql://postgres/cashu
|
||||||
|
KC_DB_USERNAME: cashu
|
||||||
|
KC_DB_PASSWORD: cashu
|
||||||
|
ports:
|
||||||
|
- 8080:8080
|
||||||
|
restart: always
|
||||||
|
depends_on:
|
||||||
|
- postgres
|
||||||
|
networks:
|
||||||
|
- keycloak_network
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
postgres_data:
|
||||||
|
driver: local
|
||||||
|
|
||||||
|
networks:
|
||||||
|
keycloak_network:
|
||||||
|
driver: bridge
|
||||||
2021
tests/keycloak_data/keycloak-export/master-realm.json
Normal file
2021
tests/keycloak_data/keycloak-export/master-realm.json
Normal file
File diff suppressed because it is too large
Load Diff
26
tests/keycloak_data/keycloak-export/master-users-0.json
Normal file
26
tests/keycloak_data/keycloak-export/master-users-0.json
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
{
|
||||||
|
"realm" : "master",
|
||||||
|
"users" : [ {
|
||||||
|
"id" : "0ff227f7-c163-4fca-9ae4-c8751c725421",
|
||||||
|
"username" : "admin",
|
||||||
|
"emailVerified" : false,
|
||||||
|
"createdTimestamp" : 1727128354842,
|
||||||
|
"enabled" : true,
|
||||||
|
"totp" : false,
|
||||||
|
"credentials" : [ {
|
||||||
|
"id" : "11a5f9ed-19c9-4164-be31-28ce6e23955b",
|
||||||
|
"type" : "password",
|
||||||
|
"createdDate" : 1727128354904,
|
||||||
|
"secretData" : "{\"value\":\"s/6M2/FCFd1fOyHJRMvOLvKM7e2JIOC6LZ3ovFVkGi8=\",\"salt\":\"Zjn7ChOL5688O84xf1ElGA==\",\"additionalParameters\":{}}",
|
||||||
|
"credentialData" : "{\"hashIterations\":5,\"algorithm\":\"argon2\",\"additionalParameters\":{\"hashLength\":[\"32\"],\"memory\":[\"7168\"],\"type\":[\"id\"],\"version\":[\"1.3\"],\"parallelism\":[\"1\"]}}"
|
||||||
|
} ],
|
||||||
|
"disableableCredentialTypes" : [ ],
|
||||||
|
"requiredActions" : [ ],
|
||||||
|
"realmRoles" : [ "default-roles-master", "admin" ],
|
||||||
|
"clientRoles" : {
|
||||||
|
"nutshell-realm" : [ "query-realms", "query-users", "manage-identity-providers", "manage-authorization", "view-identity-providers", "view-realm", "view-authorization", "query-clients", "manage-clients", "create-client", "view-events", "manage-events", "manage-realm", "manage-users", "view-users", "view-clients", "query-groups" ]
|
||||||
|
},
|
||||||
|
"notBefore" : 0,
|
||||||
|
"groups" : [ ]
|
||||||
|
} ]
|
||||||
|
}
|
||||||
1902
tests/keycloak_data/keycloak-export/nutshell-realm.json
Normal file
1902
tests/keycloak_data/keycloak-export/nutshell-realm.json
Normal file
File diff suppressed because it is too large
Load Diff
53
tests/keycloak_data/keycloak-export/nutshell-users-0.json
Normal file
53
tests/keycloak_data/keycloak-export/nutshell-users-0.json
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
{
|
||||||
|
"realm" : "nutshell",
|
||||||
|
"users" : [ {
|
||||||
|
"id" : "c4fc742a-700f-4c83-96f2-8777c8bb56d1",
|
||||||
|
"username" : "asd@asd.com",
|
||||||
|
"firstName" : "asd",
|
||||||
|
"lastName" : "asd",
|
||||||
|
"email" : "asd@asd.com",
|
||||||
|
"emailVerified" : false,
|
||||||
|
"createdTimestamp" : 1727128876722,
|
||||||
|
"enabled" : true,
|
||||||
|
"totp" : false,
|
||||||
|
"credentials" : [ {
|
||||||
|
"id" : "23ea2b79-9c09-4133-b53b-2708258da890",
|
||||||
|
"type" : "password",
|
||||||
|
"createdDate" : 1727128876754,
|
||||||
|
"secretData" : "{\"value\":\"fDXqE3IjxS5uIYfn9eYgW5GwokWvGsg2wWY0lOgeYyE=\",\"salt\":\"Wlb5f8yPTh4QreuC99b7Zg==\",\"additionalParameters\":{}}",
|
||||||
|
"credentialData" : "{\"hashIterations\":5,\"algorithm\":\"argon2\",\"additionalParameters\":{\"hashLength\":[\"32\"],\"memory\":[\"7168\"],\"type\":[\"id\"],\"version\":[\"1.3\"],\"parallelism\":[\"1\"]}}"
|
||||||
|
} ],
|
||||||
|
"disableableCredentialTypes" : [ ],
|
||||||
|
"requiredActions" : [ ],
|
||||||
|
"realmRoles" : [ "default-roles-nutshell" ],
|
||||||
|
"clientConsents" : [ {
|
||||||
|
"clientId" : "cashu-client",
|
||||||
|
"grantedClientScopes" : [ "email", "roles", "profile" ],
|
||||||
|
"createdDate" : 1732651444894,
|
||||||
|
"lastUpdatedDate" : 1732651444908
|
||||||
|
} ],
|
||||||
|
"notBefore" : 0,
|
||||||
|
"groups" : [ ]
|
||||||
|
}, {
|
||||||
|
"id" : "43a16bd6-f5c5-4dfa-bcd4-6a5540564797",
|
||||||
|
"username" : "callebtc@protonmail.com",
|
||||||
|
"firstName" : "asdasd",
|
||||||
|
"lastName" : "asdasdasdasd",
|
||||||
|
"email" : "callebtc@protonmail.com",
|
||||||
|
"emailVerified" : false,
|
||||||
|
"createdTimestamp" : 1732639511706,
|
||||||
|
"enabled" : true,
|
||||||
|
"totp" : false,
|
||||||
|
"credentials" : [ ],
|
||||||
|
"disableableCredentialTypes" : [ ],
|
||||||
|
"requiredActions" : [ ],
|
||||||
|
"federatedIdentities" : [ {
|
||||||
|
"identityProvider" : "github",
|
||||||
|
"userId" : "93376500",
|
||||||
|
"userName" : "callebtc"
|
||||||
|
} ],
|
||||||
|
"realmRoles" : [ "default-roles-nutshell" ],
|
||||||
|
"notBefore" : 0,
|
||||||
|
"groups" : [ ]
|
||||||
|
} ]
|
||||||
|
}
|
||||||
@@ -183,14 +183,14 @@ async def test_mint(wallet1: Wallet):
|
|||||||
assert wallet1.balance == 64
|
assert wallet1.balance == 64
|
||||||
|
|
||||||
# verify that proofs in proofs_used db have the same mint_id as the invoice in the db
|
# verify that proofs in proofs_used db have the same mint_id as the invoice in the db
|
||||||
mint_quote = await get_bolt11_mint_quote(db=wallet1.db, quote=mint_quote.quote)
|
mint_quote_2 = await get_bolt11_mint_quote(db=wallet1.db, quote=mint_quote.quote)
|
||||||
assert mint_quote
|
assert mint_quote_2
|
||||||
proofs_minted = await get_proofs(
|
proofs_minted = await get_proofs(
|
||||||
db=wallet1.db, mint_id=mint_quote.quote, table="proofs"
|
db=wallet1.db, mint_id=mint_quote_2.quote, table="proofs"
|
||||||
)
|
)
|
||||||
assert len(proofs_minted) == len(expected_proof_amounts)
|
assert len(proofs_minted) == len(expected_proof_amounts)
|
||||||
assert all([p.amount in expected_proof_amounts for p in proofs_minted])
|
assert all([p.amount in expected_proof_amounts for p in proofs_minted])
|
||||||
assert all([p.mint_id == mint_quote.quote for p in proofs_minted])
|
assert all([p.mint_id == mint_quote_2.quote for p in proofs_minted])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -356,7 +356,7 @@ async def test_swap_to_send_more_than_balance(wallet1: Wallet):
|
|||||||
await wallet1.mint(64, quote_id=mint_quote.quote)
|
await wallet1.mint(64, quote_id=mint_quote.quote)
|
||||||
await assert_err(
|
await assert_err(
|
||||||
wallet1.swap_to_send(wallet1.proofs, 128, set_reserved=True),
|
wallet1.swap_to_send(wallet1.proofs, 128, set_reserved=True),
|
||||||
"balance too low.",
|
"Balance too low",
|
||||||
)
|
)
|
||||||
assert wallet1.balance == 64
|
assert wallet1.balance == 64
|
||||||
assert wallet1.available_balance == 64
|
assert wallet1.available_balance == 64
|
||||||
|
|||||||
251
tests/test_wallet_auth.py
Normal file
251
tests/test_wallet_auth.py
Normal file
@@ -0,0 +1,251 @@
|
|||||||
|
import hashlib
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from cashu.core.base import Unit
|
||||||
|
from cashu.core.crypto.keys import random_hash
|
||||||
|
from cashu.core.crypto.secp import PrivateKey
|
||||||
|
from cashu.core.errors import (
|
||||||
|
BlindAuthFailedError,
|
||||||
|
BlindAuthRateLimitExceededError,
|
||||||
|
ClearAuthFailedError,
|
||||||
|
)
|
||||||
|
from cashu.core.settings import settings
|
||||||
|
from cashu.wallet.auth.auth import WalletAuth
|
||||||
|
from cashu.wallet.wallet import Wallet
|
||||||
|
from tests.conftest import SERVER_ENDPOINT
|
||||||
|
from tests.helpers import assert_err
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(scope="function")
|
||||||
|
async def wallet():
|
||||||
|
dirpath = Path("test_data/wallet")
|
||||||
|
if dirpath.exists() and dirpath.is_dir():
|
||||||
|
shutil.rmtree(dirpath)
|
||||||
|
wallet = await Wallet.with_db(
|
||||||
|
url=SERVER_ENDPOINT,
|
||||||
|
db="test_data/wallet",
|
||||||
|
name="wallet",
|
||||||
|
)
|
||||||
|
await wallet.load_mint()
|
||||||
|
yield wallet
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not settings.mint_require_auth,
|
||||||
|
reason="settings.mint_require_auth is False",
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_wallet_auth_password(wallet: Wallet):
|
||||||
|
auth_wallet = await WalletAuth.with_db(
|
||||||
|
url=wallet.url,
|
||||||
|
db=wallet.db.db_location,
|
||||||
|
username="asd@asd.com",
|
||||||
|
password="asdasd",
|
||||||
|
)
|
||||||
|
|
||||||
|
requires_auth = await auth_wallet.init_auth_wallet(
|
||||||
|
wallet.mint_info, mint_auth_proofs=False
|
||||||
|
)
|
||||||
|
assert requires_auth
|
||||||
|
|
||||||
|
# expect JWT (CAT) with format ey*.ey*
|
||||||
|
assert auth_wallet.oidc_client.access_token
|
||||||
|
assert auth_wallet.oidc_client.access_token.split(".")[0].startswith("ey")
|
||||||
|
assert auth_wallet.oidc_client.access_token.split(".")[1].startswith("ey")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not settings.mint_require_auth,
|
||||||
|
reason="settings.mint_require_auth is False",
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_wallet_auth_wrong_password(wallet: Wallet):
|
||||||
|
auth_wallet = await WalletAuth.with_db(
|
||||||
|
url=wallet.url,
|
||||||
|
db=wallet.db.db_location,
|
||||||
|
username="asd@asd.com",
|
||||||
|
password="wrong_password",
|
||||||
|
)
|
||||||
|
|
||||||
|
await assert_err(auth_wallet.init_auth_wallet(wallet.mint_info), "401 Unauthorized")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not settings.mint_require_auth,
|
||||||
|
reason="settings.mint_require_auth is False",
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_wallet_auth_mint(wallet: Wallet):
|
||||||
|
auth_wallet = await WalletAuth.with_db(
|
||||||
|
url=wallet.url,
|
||||||
|
db=wallet.db.db_location,
|
||||||
|
username="asd@asd.com",
|
||||||
|
password="asdasd",
|
||||||
|
)
|
||||||
|
|
||||||
|
requires_auth = await auth_wallet.init_auth_wallet(wallet.mint_info)
|
||||||
|
assert requires_auth
|
||||||
|
|
||||||
|
await auth_wallet.load_proofs()
|
||||||
|
assert len(auth_wallet.proofs) == auth_wallet.mint_info.bat_max_mint
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not settings.mint_require_auth,
|
||||||
|
reason="settings.mint_require_auth is False",
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_wallet_auth_mint_manually(wallet: Wallet):
|
||||||
|
auth_wallet = await WalletAuth.with_db(
|
||||||
|
url=wallet.url,
|
||||||
|
db=wallet.db.db_location,
|
||||||
|
username="asd@asd.com",
|
||||||
|
password="asdasd",
|
||||||
|
)
|
||||||
|
|
||||||
|
requires_auth = await auth_wallet.init_auth_wallet(
|
||||||
|
wallet.mint_info, mint_auth_proofs=False
|
||||||
|
)
|
||||||
|
assert requires_auth
|
||||||
|
assert len(auth_wallet.proofs) == 0
|
||||||
|
|
||||||
|
await auth_wallet.mint_blind_auth()
|
||||||
|
assert len(auth_wallet.proofs) == auth_wallet.mint_info.bat_max_mint
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not settings.mint_require_auth,
|
||||||
|
reason="settings.mint_require_auth is False",
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_wallet_auth_mint_manually_invalid_cat(wallet: Wallet):
|
||||||
|
auth_wallet = await WalletAuth.with_db(
|
||||||
|
url=wallet.url,
|
||||||
|
db=wallet.db.db_location,
|
||||||
|
username="asd@asd.com",
|
||||||
|
password="asdasd",
|
||||||
|
)
|
||||||
|
|
||||||
|
requires_auth = await auth_wallet.init_auth_wallet(
|
||||||
|
wallet.mint_info, mint_auth_proofs=False
|
||||||
|
)
|
||||||
|
assert requires_auth
|
||||||
|
assert len(auth_wallet.proofs) == 0
|
||||||
|
|
||||||
|
# invalidate CAT in the database
|
||||||
|
auth_wallet.oidc_client.access_token = random_hash()
|
||||||
|
|
||||||
|
# this is the code executed in auth_wallet.mint_blind_auth():
|
||||||
|
clear_auth_token = auth_wallet.oidc_client.access_token
|
||||||
|
if not clear_auth_token:
|
||||||
|
raise Exception("No clear auth token available.")
|
||||||
|
|
||||||
|
amounts = auth_wallet.mint_info.bat_max_mint * [1] # 1 AUTH tokens
|
||||||
|
secrets = [hashlib.sha256(os.urandom(32)).hexdigest() for _ in amounts]
|
||||||
|
rs = [PrivateKey(privkey=os.urandom(32), raw=True) for _ in amounts]
|
||||||
|
outputs, rs = auth_wallet._construct_outputs(amounts, secrets, rs)
|
||||||
|
|
||||||
|
# should fail because of invalid CAT
|
||||||
|
await assert_err(
|
||||||
|
auth_wallet.blind_mint_blind_auth(clear_auth_token, outputs),
|
||||||
|
ClearAuthFailedError.detail,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not settings.mint_require_auth,
|
||||||
|
reason="settings.mint_require_auth is False",
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_wallet_auth_invoice(wallet: Wallet):
|
||||||
|
# should fail, wallet error
|
||||||
|
await assert_err(wallet.mint_quote(10, Unit.sat), "Mint requires blind auth")
|
||||||
|
|
||||||
|
auth_wallet = await WalletAuth.with_db(
|
||||||
|
url=wallet.url,
|
||||||
|
db=wallet.db.db_location,
|
||||||
|
username="asd@asd.com",
|
||||||
|
password="asdasd",
|
||||||
|
)
|
||||||
|
requires_auth = await auth_wallet.init_auth_wallet(wallet.mint_info)
|
||||||
|
assert requires_auth
|
||||||
|
|
||||||
|
await auth_wallet.load_proofs()
|
||||||
|
assert len(auth_wallet.proofs) == auth_wallet.mint_info.bat_max_mint
|
||||||
|
|
||||||
|
wallet.auth_db = auth_wallet.db
|
||||||
|
wallet.auth_keyset_id = auth_wallet.keyset_id
|
||||||
|
|
||||||
|
# should succeed
|
||||||
|
await wallet.mint_quote(10, Unit.sat)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not settings.mint_require_auth,
|
||||||
|
reason="settings.mint_require_auth is False",
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_wallet_auth_invoice_invalid_bat(wallet: Wallet):
|
||||||
|
# should fail, wallet error
|
||||||
|
await assert_err(wallet.mint_quote(10, Unit.sat), "Mint requires blind auth")
|
||||||
|
|
||||||
|
auth_wallet = await WalletAuth.with_db(
|
||||||
|
url=wallet.url,
|
||||||
|
db=wallet.db.db_location,
|
||||||
|
username="asd@asd.com",
|
||||||
|
password="asdasd",
|
||||||
|
)
|
||||||
|
requires_auth = await auth_wallet.init_auth_wallet(wallet.mint_info)
|
||||||
|
assert requires_auth
|
||||||
|
|
||||||
|
await auth_wallet.load_proofs()
|
||||||
|
assert len(auth_wallet.proofs) == auth_wallet.mint_info.bat_max_mint
|
||||||
|
|
||||||
|
# invalidate blind auth proofs
|
||||||
|
for p in auth_wallet.proofs:
|
||||||
|
await auth_wallet.db.execute(
|
||||||
|
f"UPDATE proofs SET secret = '{random_hash()}' WHERE secret = '{p.secret}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
wallet.auth_db = auth_wallet.db
|
||||||
|
wallet.auth_keyset_id = auth_wallet.keyset_id
|
||||||
|
|
||||||
|
# blind auth failed
|
||||||
|
await assert_err(wallet.mint_quote(10, Unit.sat), BlindAuthFailedError.detail)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not settings.mint_require_auth,
|
||||||
|
reason="settings.mint_require_auth is False",
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_wallet_auth_rate_limit(wallet: Wallet):
|
||||||
|
auth_wallet = await WalletAuth.with_db(
|
||||||
|
url=wallet.url,
|
||||||
|
db=wallet.db.db_location,
|
||||||
|
username="asd@asd.com",
|
||||||
|
password="asdasd",
|
||||||
|
)
|
||||||
|
requires_auth = await auth_wallet.init_auth_wallet(
|
||||||
|
wallet.mint_info, mint_auth_proofs=False
|
||||||
|
)
|
||||||
|
assert requires_auth
|
||||||
|
|
||||||
|
errored = False
|
||||||
|
for _ in range(100):
|
||||||
|
try:
|
||||||
|
await auth_wallet.mint_blind_auth()
|
||||||
|
except Exception as e:
|
||||||
|
assert BlindAuthRateLimitExceededError.detail in str(e)
|
||||||
|
errored = True
|
||||||
|
break
|
||||||
|
|
||||||
|
assert errored
|
||||||
|
|
||||||
|
# should have minted at least twice
|
||||||
|
assert len(auth_wallet.proofs) > auth_wallet.mint_info.bat_max_mint
|
||||||
@@ -54,7 +54,7 @@ async def init_wallet():
|
|||||||
wallet = await Wallet.with_db(
|
wallet = await Wallet.with_db(
|
||||||
url=settings.mint_url,
|
url=settings.mint_url,
|
||||||
db="test_data/test_cli_wallet",
|
db="test_data/test_cli_wallet",
|
||||||
name="wallet",
|
name="test_cli_wallet",
|
||||||
)
|
)
|
||||||
await wallet.load_proofs()
|
await wallet.load_proofs()
|
||||||
return wallet
|
return wallet
|
||||||
@@ -411,7 +411,7 @@ def test_wallets(cli_prefix):
|
|||||||
print("WALLETS")
|
print("WALLETS")
|
||||||
# on github this is empty
|
# on github this is empty
|
||||||
if len(result.output):
|
if len(result.output):
|
||||||
assert "test_cli_wallet" in result.output
|
assert "wallet" in result.output
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
|
|
||||||
|
|
||||||
@@ -474,7 +474,7 @@ def test_send_too_much(mint, cli_prefix):
|
|||||||
cli,
|
cli,
|
||||||
[*cli_prefix, "send", "100000"],
|
[*cli_prefix, "send", "100000"],
|
||||||
)
|
)
|
||||||
assert "balance too low" in str(result.exception)
|
assert "Balance too low" in str(result.exception)
|
||||||
|
|
||||||
|
|
||||||
def test_receive_tokenv3(mint, cli_prefix):
|
def test_receive_tokenv3(mint, cli_prefix):
|
||||||
|
|||||||
Reference in New Issue
Block a user