* wip

* wip

* model

* refactor wallet transactions

* refactor wallet

* sending with fees works and outputs fill up the wallet

* wip work

* ok

* comments

* receive with amount=0

* correctly import postmeltrequest

* fix melt amount

* tests working

* remove mint_loaded decorator in deprecated wallet api

* wallet works with units

* refactor: melt_quote

* fix fees

* add file

* fees for melt inputs

* set default input fee for internal quotes to 0

* fix coinselect

* coin selection working

* yo

* fix all tests

* clean up

* last commit added fees for inputs for melt transactions - this commit adds a blanace too low exception

* fix fee return and melt quote max allowed amount check during creation of melt quote

* clean up code

* add tests for fees

* add melt tests

* update wallet fee information
This commit is contained in:
callebtc
2024-06-15 16:22:41 +02:00
committed by GitHub
parent d80280e35d
commit d30b1a2777
47 changed files with 2446 additions and 1554 deletions

View File

@@ -4,10 +4,10 @@ import math
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from sqlite3 import Row from sqlite3 import Row
from typing import Any, Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from loguru import logger from loguru import logger
from pydantic import BaseModel, Field from pydantic import BaseModel
from .crypto.aes import AESCipher from .crypto.aes import AESCipher
from .crypto.b_dhke import hash_to_curve from .crypto.b_dhke import hash_to_curve
@@ -45,6 +45,21 @@ class DLEQWallet(BaseModel):
# ------- PROOFS ------- # ------- PROOFS -------
class SpentState(Enum):
unspent = "UNSPENT"
spent = "SPENT"
pending = "PENDING"
def __str__(self):
return self.name
class ProofState(BaseModel):
Y: str
state: SpentState
witness: Optional[str] = None
class HTLCWitness(BaseModel): class HTLCWitness(BaseModel):
preimage: Optional[str] = None preimage: Optional[str] = None
signature: Optional[str] = None signature: Optional[str] = None
@@ -85,8 +100,7 @@ class Proof(BaseModel):
Value token Value token
""" """
# NOTE: None for backwards compatibility for old clients that do not include the keyset id < 0.3 id: str = ""
id: Union[None, str] = ""
amount: int = 0 amount: int = 0
secret: str = "" # secret or message to be blinded and signed secret: str = "" # secret or message to be blinded and signed
Y: str = "" # hash_to_curve(secret) Y: str = "" # hash_to_curve(secret)
@@ -199,11 +213,6 @@ class BlindedMessage_Deprecated(BaseModel):
return P2PKWitness.from_witness(self.witness).signatures return P2PKWitness.from_witness(self.witness).signatures
class BlindedMessages(BaseModel):
# NOTE: not used in Pydantic validation
__root__: List[BlindedMessage] = []
class BlindedSignature(BaseModel): class BlindedSignature(BaseModel):
""" """
Blinded signature or "promise" which is the signature on a `BlindedMessage` Blinded signature or "promise" which is the signature on a `BlindedMessage`
@@ -321,274 +330,6 @@ class MintQuote(BaseModel):
) )
# ------- API -------
# ------- API: INFO -------
class MintMeltMethodSetting(BaseModel):
method: str
unit: str
min_amount: Optional[int] = None
max_amount: Optional[int] = None
class GetInfoResponse(BaseModel):
name: Optional[str] = None
pubkey: Optional[str] = None
version: Optional[str] = None
description: Optional[str] = None
description_long: Optional[str] = None
contact: Optional[List[List[str]]] = None
motd: Optional[str] = None
nuts: Optional[Dict[int, Any]] = None
class Nut15MppSupport(BaseModel):
method: str
unit: str
mpp: bool
class GetInfoResponse_deprecated(BaseModel):
name: Optional[str] = None
pubkey: Optional[str] = None
version: Optional[str] = None
description: Optional[str] = None
description_long: Optional[str] = None
contact: Optional[List[List[str]]] = None
nuts: Optional[List[str]] = None
motd: Optional[str] = None
parameter: Optional[dict] = None
# ------- API: KEYS -------
class KeysResponseKeyset(BaseModel):
id: str
unit: str
keys: Dict[int, str]
class KeysResponse(BaseModel):
keysets: List[KeysResponseKeyset]
class KeysetsResponseKeyset(BaseModel):
id: str
unit: str
active: bool
class KeysetsResponse(BaseModel):
keysets: list[KeysetsResponseKeyset]
class KeysResponse_deprecated(BaseModel):
__root__: Dict[str, str]
class KeysetsResponse_deprecated(BaseModel):
keysets: list[str]
# ------- API: MINT QUOTE -------
class PostMintQuoteRequest(BaseModel):
unit: str = Field(..., max_length=settings.mint_max_request_length) # output unit
amount: int = Field(..., gt=0) # output amount
class PostMintQuoteResponse(BaseModel):
quote: str # quote id
request: str # input payment request
paid: bool # whether the request has been paid
expiry: Optional[int] # expiry of the quote
# ------- API: MINT -------
class PostMintRequest(BaseModel):
quote: str = Field(..., max_length=settings.mint_max_request_length) # quote id
outputs: List[BlindedMessage] = Field(
..., max_items=settings.mint_max_request_length
)
class PostMintResponse(BaseModel):
signatures: List[BlindedSignature] = []
class GetMintResponse_deprecated(BaseModel):
pr: str
hash: str
class PostMintRequest_deprecated(BaseModel):
outputs: List[BlindedMessage_Deprecated] = Field(
..., max_items=settings.mint_max_request_length
)
class PostMintResponse_deprecated(BaseModel):
promises: List[BlindedSignature] = []
# ------- API: MELT QUOTE -------
class PostMeltQuoteRequest(BaseModel):
unit: str = Field(..., max_length=settings.mint_max_request_length) # input unit
request: str = Field(
..., max_length=settings.mint_max_request_length
) # output payment request
amount: Optional[int] = Field(default=None, gt=0) # input amount
class PostMeltQuoteResponse(BaseModel):
quote: str # quote id
amount: int # input amount
fee_reserve: int # input fee reserve
paid: bool # whether the request has been paid
expiry: Optional[int] # expiry of the quote
# ------- API: MELT -------
class PostMeltRequest(BaseModel):
quote: str = Field(..., max_length=settings.mint_max_request_length) # quote id
inputs: List[Proof] = Field(..., max_items=settings.mint_max_request_length)
outputs: Union[List[BlindedMessage], None] = Field(
None, max_items=settings.mint_max_request_length
)
class PostMeltResponse(BaseModel):
paid: Union[bool, None]
payment_preimage: Union[str, None]
change: Union[List[BlindedSignature], None] = None
class PostMeltRequest_deprecated(BaseModel):
proofs: List[Proof] = Field(..., max_items=settings.mint_max_request_length)
pr: str = Field(..., max_length=settings.mint_max_request_length)
outputs: Union[List[BlindedMessage_Deprecated], None] = Field(
None, max_items=settings.mint_max_request_length
)
class PostMeltResponse_deprecated(BaseModel):
paid: Union[bool, None]
preimage: Union[str, None]
change: Union[List[BlindedSignature], None] = None
# ------- API: SPLIT -------
class PostSplitRequest(BaseModel):
inputs: List[Proof] = Field(..., max_items=settings.mint_max_request_length)
outputs: List[BlindedMessage] = Field(
..., max_items=settings.mint_max_request_length
)
class PostSplitResponse(BaseModel):
signatures: List[BlindedSignature]
# deprecated since 0.13.0
class PostSplitRequest_Deprecated(BaseModel):
proofs: List[Proof] = Field(..., max_items=settings.mint_max_request_length)
amount: Optional[int] = None
outputs: List[BlindedMessage_Deprecated] = Field(
..., max_items=settings.mint_max_request_length
)
class PostSplitResponse_Deprecated(BaseModel):
promises: List[BlindedSignature] = []
class PostSplitResponse_Very_Deprecated(BaseModel):
fst: List[BlindedSignature] = []
snd: List[BlindedSignature] = []
deprecated: str = "The amount field is deprecated since 0.13.0"
# ------- API: CHECK -------
class PostCheckStateRequest(BaseModel):
Ys: List[str] = Field(..., max_items=settings.mint_max_request_length)
class SpentState(Enum):
unspent = "UNSPENT"
spent = "SPENT"
pending = "PENDING"
def __str__(self):
return self.name
class ProofState(BaseModel):
Y: str
state: SpentState
witness: Optional[str] = None
class PostCheckStateResponse(BaseModel):
states: List[ProofState] = []
class CheckSpendableRequest_deprecated(BaseModel):
proofs: List[Proof] = Field(..., max_items=settings.mint_max_request_length)
class CheckSpendableResponse_deprecated(BaseModel):
spendable: List[bool]
pending: List[bool]
class CheckFeesRequest_deprecated(BaseModel):
pr: str = Field(..., max_length=settings.mint_max_request_length)
class CheckFeesResponse_deprecated(BaseModel):
fee: Union[int, None]
# ------- API: RESTORE -------
class PostRestoreRequest(BaseModel):
outputs: List[BlindedMessage] = Field(
..., max_items=settings.mint_max_request_length
)
class PostRestoreRequest_Deprecated(BaseModel):
outputs: List[BlindedMessage_Deprecated] = Field(
..., max_items=settings.mint_max_request_length
)
class PostRestoreResponse(BaseModel):
outputs: List[BlindedMessage] = []
signatures: List[BlindedSignature] = []
promises: Optional[List[BlindedSignature]] = [] # deprecated since 0.15.1
# duplicate value of "signatures" for backwards compatibility with old clients < 0.15.1
def __init__(self, **data):
super().__init__(**data)
self.promises = self.signatures
# ------- KEYSETS ------- # ------- KEYSETS -------
@@ -672,6 +413,7 @@ class WalletKeyset:
valid_to: Union[str, None] = None valid_to: Union[str, None] = None
first_seen: Union[str, None] = None first_seen: Union[str, None] = None
active: Union[bool, None] = True active: Union[bool, None] = True
input_fee_ppk: int = 0
def __init__( def __init__(
self, self,
@@ -683,13 +425,14 @@ class WalletKeyset:
valid_to=None, valid_to=None,
first_seen=None, first_seen=None,
active=True, active=True,
use_deprecated_id=False, # BACKWARDS COMPATIBILITY < 0.15.0 input_fee_ppk=0,
): ):
self.valid_from = valid_from self.valid_from = valid_from
self.valid_to = valid_to self.valid_to = valid_to
self.first_seen = first_seen self.first_seen = first_seen
self.active = active self.active = active
self.mint_url = mint_url self.mint_url = mint_url
self.input_fee_ppk = input_fee_ppk
self.public_keys = public_keys self.public_keys = public_keys
# overwrite id by deriving it from the public keys # overwrite id by deriving it from the public keys
@@ -698,19 +441,9 @@ class WalletKeyset:
else: else:
self.id = id self.id = id
# BEGIN BACKWARDS COMPATIBILITY < 0.15.0
if use_deprecated_id:
logger.warning(
"Using deprecated keyset id derivation for backwards compatibility <"
" 0.15.0"
)
self.id = derive_keyset_id_deprecated(self.public_keys)
# END BACKWARDS COMPATIBILITY < 0.15.0
self.unit = Unit[unit] self.unit = Unit[unit]
logger.trace(f"Derived keyset id {self.id} from public keys.") if id and id != self.id:
if id and id != self.id and use_deprecated_id:
logger.warning( logger.warning(
f"WARNING: Keyset id {self.id} does not match the given id {id}." f"WARNING: Keyset id {self.id} does not match the given id {id}."
" Overwriting." " Overwriting."
@@ -743,6 +476,7 @@ class WalletKeyset:
valid_to=row["valid_to"], valid_to=row["valid_to"],
first_seen=row["first_seen"], first_seen=row["first_seen"],
active=row["active"], active=row["active"],
input_fee_ppk=row["input_fee_ppk"],
) )
@@ -756,6 +490,7 @@ class MintKeyset:
active: bool active: bool
unit: Unit unit: Unit
derivation_path: str derivation_path: str
input_fee_ppk: int
seed: Optional[str] = None seed: Optional[str] = None
encrypted_seed: Optional[str] = None encrypted_seed: Optional[str] = None
seed_encryption_method: Optional[str] = None seed_encryption_method: Optional[str] = None
@@ -780,6 +515,7 @@ class MintKeyset:
active: Optional[bool] = None, active: Optional[bool] = None,
unit: Optional[str] = None, unit: Optional[str] = None,
version: Optional[str] = None, version: Optional[str] = None,
input_fee_ppk: Optional[int] = None,
id: str = "", id: str = "",
): ):
self.derivation_path = derivation_path self.derivation_path = derivation_path
@@ -801,6 +537,10 @@ class MintKeyset:
self.first_seen = first_seen self.first_seen = first_seen
self.active = bool(active) if active is not None else False self.active = bool(active) if active is not None else False
self.version = version or settings.version self.version = version or settings.version
self.input_fee_ppk = input_fee_ppk or 0
if self.input_fee_ppk < 0:
raise Exception("Input fee must be non-negative.")
self.version_tuple = tuple( self.version_tuple = tuple(
[int(i) for i in self.version.split(".")] if self.version else [] [int(i) for i in self.version.split(".")] if self.version else []
@@ -930,11 +670,14 @@ class TokenV3(BaseModel):
token: List[TokenV3Token] = [] token: List[TokenV3Token] = []
memo: Optional[str] = None memo: Optional[str] = None
unit: Optional[str] = None
def to_dict(self, include_dleq=False): def to_dict(self, include_dleq=False):
return_dict = dict(token=[t.to_dict(include_dleq) for t in self.token]) return_dict = dict(token=[t.to_dict(include_dleq) for t in self.token])
if self.memo: if self.memo:
return_dict.update(dict(memo=self.memo)) # type: ignore return_dict.update(dict(memo=self.memo)) # type: ignore
if self.unit:
return_dict.update(dict(unit=self.unit)) # type: ignore
return return_dict return return_dict
def get_proofs(self): def get_proofs(self):

View File

@@ -35,12 +35,18 @@ class TokenAlreadySpentError(TransactionError):
super().__init__(self.detail, code=self.code) super().__init__(self.detail, code=self.code)
class TransactionNotBalancedError(TransactionError):
code = 11002
def __init__(self, detail):
super().__init__(detail, code=self.code)
class SecretTooLongError(TransactionError): class SecretTooLongError(TransactionError):
detail = "secret too long"
code = 11003 code = 11003
def __init__(self): def __init__(self, detail="secret too long"):
super().__init__(self.detail, code=self.code) super().__init__(detail, code=self.code)
class NoSecretInProofsError(TransactionError): class NoSecretInProofsError(TransactionError):
@@ -51,6 +57,13 @@ class NoSecretInProofsError(TransactionError):
super().__init__(self.detail, code=self.code) super().__init__(self.detail, code=self.code)
class TransactionUnitError(TransactionError):
code = 11005
def __init__(self, detail):
super().__init__(detail, code=self.code)
class KeysetError(CashuError): class KeysetError(CashuError):
detail = "keyset error" detail = "keyset error"
code = 12000 code = 12000

View File

@@ -3,10 +3,21 @@ import math
from functools import partial, wraps from functools import partial, wraps
from typing import List from typing import List
from ..core.base import BlindedSignature, Proof from ..core.base import Amount, BlindedSignature, Proof, Unit
from ..core.settings import settings from ..core.settings import settings
def amount_summary(proofs: List[Proof], unit: Unit) -> str:
amounts_we_have = [
(amount, len([p for p in proofs if p.amount == amount]))
for amount in set([p.amount for p in proofs])
]
amounts_we_have.sort(key=lambda x: x[0])
return (
f"{', '.join([f'{Amount(unit, a).str()} ({c}x)' for a, c in amounts_we_have])}"
)
def sum_proofs(proofs: List[Proof]): def sum_proofs(proofs: List[Proof]):
return sum([p.amount for p in proofs]) return sum([p.amount for p in proofs])

265
cashu/core/models.py Normal file
View File

@@ -0,0 +1,265 @@
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field
from .base import (
BlindedMessage,
BlindedMessage_Deprecated,
BlindedSignature,
Proof,
ProofState,
)
from .settings import settings
# ------- API -------
# ------- API: INFO -------
class MintMeltMethodSetting(BaseModel):
method: str
unit: str
min_amount: Optional[int] = None
max_amount: Optional[int] = None
class GetInfoResponse(BaseModel):
name: Optional[str] = None
pubkey: Optional[str] = None
version: Optional[str] = None
description: Optional[str] = None
description_long: Optional[str] = None
contact: Optional[List[List[str]]] = None
motd: Optional[str] = None
nuts: Optional[Dict[int, Any]] = None
class Nut15MppSupport(BaseModel):
method: str
unit: str
mpp: bool
class GetInfoResponse_deprecated(BaseModel):
name: Optional[str] = None
pubkey: Optional[str] = None
version: Optional[str] = None
description: Optional[str] = None
description_long: Optional[str] = None
contact: Optional[List[List[str]]] = None
nuts: Optional[List[str]] = None
motd: Optional[str] = None
parameter: Optional[dict] = None
# ------- API: KEYS -------
class KeysResponseKeyset(BaseModel):
id: str
unit: str
keys: Dict[int, str]
class KeysResponse(BaseModel):
keysets: List[KeysResponseKeyset]
class KeysetsResponseKeyset(BaseModel):
id: str
unit: str
active: bool
input_fee_ppk: Optional[int] = None
class KeysetsResponse(BaseModel):
keysets: list[KeysetsResponseKeyset]
class KeysResponse_deprecated(BaseModel):
__root__: Dict[str, str]
class KeysetsResponse_deprecated(BaseModel):
keysets: list[str]
# ------- API: MINT QUOTE -------
class PostMintQuoteRequest(BaseModel):
unit: str = Field(..., max_length=settings.mint_max_request_length) # output unit
amount: int = Field(..., gt=0) # output amount
class PostMintQuoteResponse(BaseModel):
quote: str # quote id
request: str # input payment request
paid: bool # whether the request has been paid
expiry: Optional[int] # expiry of the quote
# ------- API: MINT -------
class PostMintRequest(BaseModel):
quote: str = Field(..., max_length=settings.mint_max_request_length) # quote id
outputs: List[BlindedMessage] = Field(
..., max_items=settings.mint_max_request_length
)
class PostMintResponse(BaseModel):
signatures: List[BlindedSignature] = []
class GetMintResponse_deprecated(BaseModel):
pr: str
hash: str
class PostMintRequest_deprecated(BaseModel):
outputs: List[BlindedMessage_Deprecated] = Field(
..., max_items=settings.mint_max_request_length
)
class PostMintResponse_deprecated(BaseModel):
promises: List[BlindedSignature] = []
# ------- API: MELT QUOTE -------
class PostMeltQuoteRequest(BaseModel):
unit: str = Field(..., max_length=settings.mint_max_request_length) # input unit
request: str = Field(
..., max_length=settings.mint_max_request_length
) # output payment request
amount: Optional[int] = Field(default=None, gt=0) # input amount
class PostMeltQuoteResponse(BaseModel):
quote: str # quote id
amount: int # input amount
fee_reserve: int # input fee reserve
paid: bool # whether the request has been paid
expiry: Optional[int] # expiry of the quote
# ------- API: MELT -------
class PostMeltRequest(BaseModel):
quote: str = Field(..., max_length=settings.mint_max_request_length) # quote id
inputs: List[Proof] = Field(..., max_items=settings.mint_max_request_length)
outputs: Union[List[BlindedMessage], None] = Field(
None, max_items=settings.mint_max_request_length
)
class PostMeltResponse(BaseModel):
paid: Union[bool, None]
payment_preimage: Union[str, None]
change: Union[List[BlindedSignature], None] = None
class PostMeltRequest_deprecated(BaseModel):
proofs: List[Proof] = Field(..., max_items=settings.mint_max_request_length)
pr: str = Field(..., max_length=settings.mint_max_request_length)
outputs: Union[List[BlindedMessage_Deprecated], None] = Field(
None, max_items=settings.mint_max_request_length
)
class PostMeltResponse_deprecated(BaseModel):
paid: Union[bool, None]
preimage: Union[str, None]
change: Union[List[BlindedSignature], None] = None
# ------- API: SPLIT -------
class PostSplitRequest(BaseModel):
inputs: List[Proof] = Field(..., max_items=settings.mint_max_request_length)
outputs: List[BlindedMessage] = Field(
..., max_items=settings.mint_max_request_length
)
class PostSplitResponse(BaseModel):
signatures: List[BlindedSignature]
# deprecated since 0.13.0
class PostSplitRequest_Deprecated(BaseModel):
proofs: List[Proof] = Field(..., max_items=settings.mint_max_request_length)
amount: Optional[int] = None
outputs: List[BlindedMessage_Deprecated] = Field(
..., max_items=settings.mint_max_request_length
)
class PostSplitResponse_Deprecated(BaseModel):
promises: List[BlindedSignature] = []
class PostSplitResponse_Very_Deprecated(BaseModel):
fst: List[BlindedSignature] = []
snd: List[BlindedSignature] = []
deprecated: str = "The amount field is deprecated since 0.13.0"
# ------- API: CHECK -------
class PostCheckStateRequest(BaseModel):
Ys: List[str] = Field(..., max_items=settings.mint_max_request_length)
class PostCheckStateResponse(BaseModel):
states: List[ProofState] = []
class CheckSpendableRequest_deprecated(BaseModel):
proofs: List[Proof] = Field(..., max_items=settings.mint_max_request_length)
class CheckSpendableResponse_deprecated(BaseModel):
spendable: List[bool]
pending: List[bool]
class CheckFeesRequest_deprecated(BaseModel):
pr: str = Field(..., max_length=settings.mint_max_request_length)
class CheckFeesResponse_deprecated(BaseModel):
fee: Union[int, None]
# ------- API: RESTORE -------
class PostRestoreRequest(BaseModel):
outputs: List[BlindedMessage] = Field(
..., max_items=settings.mint_max_request_length
)
class PostRestoreRequest_Deprecated(BaseModel):
outputs: List[BlindedMessage_Deprecated] = Field(
..., max_items=settings.mint_max_request_length
)
class PostRestoreResponse(BaseModel):
outputs: List[BlindedMessage] = []
signatures: List[BlindedSignature] = []
promises: Optional[List[BlindedSignature]] = [] # deprecated since 0.15.1
# duplicate value of "signatures" for backwards compatibility with old clients < 0.15.1
def __init__(self, **data):
super().__init__(**data)
self.promises = self.signatures

View File

@@ -58,6 +58,9 @@ class MintSettings(CashuSettings):
mint_database: str = Field(default="data/mint") mint_database: str = Field(default="data/mint")
mint_test_database: str = Field(default="test_data/test_mint") mint_test_database: str = Field(default="test_data/test_mint")
mint_max_secret_length: int = Field(default=512)
mint_input_fee_ppk: int = Field(default=0)
class MintBackends(MintSettings): class MintBackends(MintSettings):
@@ -170,6 +173,8 @@ class WalletSettings(CashuSettings):
locktime_delta_seconds: int = Field(default=86400) # 1 day locktime_delta_seconds: int = Field(default=86400) # 1 day
proofs_batch_size: int = Field(default=1000) proofs_batch_size: int = Field(default=1000)
wallet_target_amount_count: int = Field(default=3)
class LndRestFundingSource(MintSettings): class LndRestFundingSource(MintSettings):
mint_lnd_rest_endpoint: Optional[str] = Field(default=None) mint_lnd_rest_endpoint: Optional[str] = Field(default=None)

View File

@@ -6,9 +6,9 @@ from pydantic import BaseModel
from ..core.base import ( from ..core.base import (
Amount, Amount,
MeltQuote, MeltQuote,
PostMeltQuoteRequest,
Unit, Unit,
) )
from ..core.models import PostMeltQuoteRequest
class StatusResponse(BaseModel): class StatusResponse(BaseModel):

View File

@@ -11,7 +11,8 @@ from bolt11 import (
) )
from loguru import logger from loguru import logger
from ..core.base import Amount, MeltQuote, PostMeltQuoteRequest, Unit from ..core.base import Amount, MeltQuote, Unit
from ..core.models import PostMeltQuoteRequest
from ..core.settings import settings from ..core.settings import settings
from .base import ( from .base import (
InvoiceResponse, InvoiceResponse,

View File

@@ -10,8 +10,9 @@ from bolt11 import (
) )
from loguru import logger from loguru import logger
from ..core.base import Amount, MeltQuote, PostMeltQuoteRequest, Unit from ..core.base import Amount, MeltQuote, Unit
from ..core.helpers import fee_reserve from ..core.helpers import fee_reserve
from ..core.models import PostMeltQuoteRequest
from ..core.settings import settings from ..core.settings import settings
from .base import ( from .base import (
InvoiceResponse, InvoiceResponse,

View File

@@ -15,8 +15,9 @@ from bolt11 import (
encode, encode,
) )
from ..core.base import Amount, MeltQuote, PostMeltQuoteRequest, Unit from ..core.base import Amount, MeltQuote, Unit
from ..core.helpers import fee_reserve from ..core.helpers import fee_reserve
from ..core.models import PostMeltQuoteRequest
from ..core.settings import settings from ..core.settings import settings
from .base import ( from .base import (
InvoiceResponse, InvoiceResponse,

View File

@@ -6,8 +6,9 @@ from bolt11 import (
decode, decode,
) )
from ..core.base import Amount, MeltQuote, PostMeltQuoteRequest, Unit from ..core.base import Amount, MeltQuote, Unit
from ..core.helpers import fee_reserve from ..core.helpers import fee_reserve
from ..core.models import PostMeltQuoteRequest
from ..core.settings import settings from ..core.settings import settings
from .base import ( from .base import (
InvoiceResponse, InvoiceResponse,

View File

@@ -12,8 +12,9 @@ from bolt11 import (
) )
from loguru import logger from loguru import logger
from ..core.base import Amount, MeltQuote, PostMeltQuoteRequest, Unit from ..core.base import Amount, MeltQuote, Unit
from ..core.helpers import fee_reserve from ..core.helpers import fee_reserve
from ..core.models import PostMeltQuoteRequest
from ..core.settings import settings from ..core.settings import settings
from .base import ( from .base import (
InvoiceResponse, InvoiceResponse,

View File

@@ -4,7 +4,8 @@ from typing import Dict, Optional
import httpx import httpx
from ..core.base import Amount, MeltQuote, PostMeltQuoteRequest, Unit from ..core.base import Amount, MeltQuote, Unit
from ..core.models import PostMeltQuoteRequest
from ..core.settings import settings from ..core.settings import settings
from .base import ( from .base import (
InvoiceResponse, InvoiceResponse,

View File

@@ -34,7 +34,8 @@ class LedgerCrud(ABC):
derivation_path: str = "", derivation_path: str = "",
seed: str = "", seed: str = "",
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> List[MintKeyset]: ... ) -> List[MintKeyset]:
...
@abstractmethod @abstractmethod
async def get_spent_proofs( async def get_spent_proofs(
@@ -42,7 +43,8 @@ class LedgerCrud(ABC):
*, *,
db: Database, db: Database,
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> List[Proof]: ... ) -> List[Proof]:
...
async def get_proof_used( async def get_proof_used(
self, self,
@@ -50,7 +52,8 @@ class LedgerCrud(ABC):
Y: str, Y: str,
db: Database, db: Database,
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> Optional[Proof]: ... ) -> Optional[Proof]:
...
@abstractmethod @abstractmethod
async def invalidate_proof( async def invalidate_proof(
@@ -60,7 +63,8 @@ class LedgerCrud(ABC):
proof: Proof, proof: Proof,
quote_id: Optional[str] = None, quote_id: Optional[str] = None,
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> None: ... ) -> None:
...
@abstractmethod @abstractmethod
async def get_all_melt_quotes_from_pending_proofs( async def get_all_melt_quotes_from_pending_proofs(
@@ -68,7 +72,8 @@ class LedgerCrud(ABC):
*, *,
db: Database, db: Database,
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> List[MeltQuote]: ... ) -> List[MeltQuote]:
...
@abstractmethod @abstractmethod
async def get_pending_proofs_for_quote( async def get_pending_proofs_for_quote(
@@ -77,7 +82,8 @@ class LedgerCrud(ABC):
quote_id: str, quote_id: str,
db: Database, db: Database,
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> List[Proof]: ... ) -> List[Proof]:
...
@abstractmethod @abstractmethod
async def get_proofs_pending( async def get_proofs_pending(
@@ -86,7 +92,8 @@ class LedgerCrud(ABC):
Ys: List[str], Ys: List[str],
db: Database, db: Database,
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> List[Proof]: ... ) -> List[Proof]:
...
@abstractmethod @abstractmethod
async def set_proof_pending( async def set_proof_pending(
@@ -96,7 +103,8 @@ class LedgerCrud(ABC):
proof: Proof, proof: Proof,
quote_id: Optional[str] = None, quote_id: Optional[str] = None,
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> None: ... ) -> None:
...
@abstractmethod @abstractmethod
async def unset_proof_pending( async def unset_proof_pending(
@@ -105,7 +113,8 @@ class LedgerCrud(ABC):
proof: Proof, proof: Proof,
db: Database, db: Database,
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> None: ... ) -> None:
...
@abstractmethod @abstractmethod
async def store_keyset( async def store_keyset(
@@ -114,14 +123,16 @@ class LedgerCrud(ABC):
db: Database, db: Database,
keyset: MintKeyset, keyset: MintKeyset,
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> None: ... ) -> None:
...
@abstractmethod @abstractmethod
async def get_balance( async def get_balance(
self, self,
db: Database, db: Database,
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> int: ... ) -> int:
...
@abstractmethod @abstractmethod
async def store_promise( async def store_promise(
@@ -135,7 +146,8 @@ class LedgerCrud(ABC):
e: str = "", e: str = "",
s: str = "", s: str = "",
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> None: ... ) -> None:
...
@abstractmethod @abstractmethod
async def get_promise( async def get_promise(
@@ -144,7 +156,8 @@ class LedgerCrud(ABC):
db: Database, db: Database,
b_: str, b_: str,
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> Optional[BlindedSignature]: ... ) -> Optional[BlindedSignature]:
...
@abstractmethod @abstractmethod
async def store_mint_quote( async def store_mint_quote(
@@ -153,7 +166,8 @@ class LedgerCrud(ABC):
quote: MintQuote, quote: MintQuote,
db: Database, db: Database,
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> None: ... ) -> None:
...
@abstractmethod @abstractmethod
async def get_mint_quote( async def get_mint_quote(
@@ -162,7 +176,8 @@ class LedgerCrud(ABC):
quote_id: str, quote_id: str,
db: Database, db: Database,
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> Optional[MintQuote]: ... ) -> Optional[MintQuote]:
...
@abstractmethod @abstractmethod
async def get_mint_quote_by_request( async def get_mint_quote_by_request(
@@ -171,7 +186,8 @@ class LedgerCrud(ABC):
request: str, request: str,
db: Database, db: Database,
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> Optional[MintQuote]: ... ) -> Optional[MintQuote]:
...
@abstractmethod @abstractmethod
async def update_mint_quote( async def update_mint_quote(
@@ -180,7 +196,8 @@ class LedgerCrud(ABC):
quote: MintQuote, quote: MintQuote,
db: Database, db: Database,
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> None: ... ) -> None:
...
# @abstractmethod # @abstractmethod
# async def update_mint_quote_paid( # async def update_mint_quote_paid(
@@ -199,7 +216,8 @@ class LedgerCrud(ABC):
quote: MeltQuote, quote: MeltQuote,
db: Database, db: Database,
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> None: ... ) -> None:
...
@abstractmethod @abstractmethod
async def get_melt_quote( async def get_melt_quote(
@@ -209,7 +227,8 @@ class LedgerCrud(ABC):
db: Database, db: Database,
checking_id: Optional[str] = None, checking_id: Optional[str] = None,
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> Optional[MeltQuote]: ... ) -> Optional[MeltQuote]:
...
@abstractmethod @abstractmethod
async def update_melt_quote( async def update_melt_quote(
@@ -218,7 +237,8 @@ class LedgerCrud(ABC):
quote: MeltQuote, quote: MeltQuote,
db: Database, db: Database,
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> None: ... ) -> None:
...
class LedgerCrudSqlite(LedgerCrud): class LedgerCrudSqlite(LedgerCrud):
@@ -586,8 +606,8 @@ class LedgerCrudSqlite(LedgerCrud):
await (conn or db).execute( # type: ignore await (conn or db).execute( # type: ignore
f""" f"""
INSERT INTO {table_with_schema(db, 'keysets')} INSERT INTO {table_with_schema(db, 'keysets')}
(id, seed, encrypted_seed, seed_encryption_method, derivation_path, valid_from, valid_to, first_seen, active, version, unit) (id, seed, encrypted_seed, seed_encryption_method, derivation_path, valid_from, valid_to, first_seen, active, version, unit, input_fee_ppk)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", """,
( (
keyset.id, keyset.id,
@@ -601,6 +621,7 @@ class LedgerCrudSqlite(LedgerCrud):
True, True,
keyset.version, keyset.version,
keyset.unit.name, keyset.unit.name,
keyset.input_fee_ppk,
), ),
) )

View File

@@ -14,9 +14,6 @@ from ..core.base import (
Method, Method,
MintKeyset, MintKeyset,
MintQuote, MintQuote,
PostMeltQuoteRequest,
PostMeltQuoteResponse,
PostMintQuoteRequest,
Proof, Proof,
ProofState, ProofState,
SpentState, SpentState,
@@ -40,6 +37,11 @@ from ..core.errors import (
TransactionError, TransactionError,
) )
from ..core.helpers import sum_proofs from ..core.helpers import sum_proofs
from ..core.models import (
PostMeltQuoteRequest,
PostMeltQuoteResponse,
PostMintQuoteRequest,
)
from ..core.settings import settings from ..core.settings import settings
from ..core.split import amount_split from ..core.split import amount_split
from ..lightning.base import ( from ..lightning.base import (
@@ -216,6 +218,7 @@ class Ledger(LedgerVerification, LedgerSpendingConditions):
seed=seed or self.seed, seed=seed or self.seed,
derivation_path=derivation_path, derivation_path=derivation_path,
version=version or settings.version, version=version or settings.version,
input_fee_ppk=settings.mint_input_fee_ppk,
) )
logger.debug(f"Generated new keyset {keyset.id}.") logger.debug(f"Generated new keyset {keyset.id}.")
if autosave: if autosave:
@@ -298,9 +301,8 @@ class Ledger(LedgerVerification, LedgerSpendingConditions):
async def _generate_change_promises( async def _generate_change_promises(
self, self,
input_amount: int, fee_provided: int,
output_amount: int, fee_paid: int,
output_fee_paid: int,
outputs: Optional[List[BlindedMessage]], outputs: Optional[List[BlindedMessage]],
keyset: Optional[MintKeyset] = None, keyset: Optional[MintKeyset] = None,
) -> List[BlindedSignature]: ) -> List[BlindedSignature]:
@@ -326,14 +328,16 @@ class Ledger(LedgerVerification, LedgerSpendingConditions):
List[BlindedSignature]: Signatures on the outputs. List[BlindedSignature]: Signatures on the outputs.
""" """
# we make sure that the fee is positive # we make sure that the fee is positive
user_fee_paid = input_amount - output_amount overpaid_fee = fee_provided - fee_paid
overpaid_fee = user_fee_paid - output_fee_paid
if overpaid_fee == 0 or outputs is None:
return []
logger.debug( logger.debug(
f"Lightning fee was: {output_fee_paid}. User paid: {user_fee_paid}. " f"Lightning fee was: {fee_paid}. User provided: {fee_provided}. "
f"Returning difference: {overpaid_fee}." f"Returning difference: {overpaid_fee}."
) )
if overpaid_fee > 0 and outputs is not None:
return_amounts = amount_split(overpaid_fee) return_amounts = amount_split(overpaid_fee)
# We return at most as many outputs as were provided or as many as are # We return at most as many outputs as were provided or as many as are
@@ -342,6 +346,7 @@ class Ledger(LedgerVerification, LedgerSpendingConditions):
# we only need as many outputs as we have change to return # we only need as many outputs as we have change to return
outputs = outputs[:n_return_outputs] outputs = outputs[:n_return_outputs]
# we sort the return_amounts in descending order so we only # we sort the return_amounts in descending order so we only
# take the largest values in the next step # take the largest values in the next step
return_amounts_sorted = sorted(return_amounts, reverse=True) return_amounts_sorted = sorted(return_amounts, reverse=True)
@@ -352,8 +357,6 @@ class Ledger(LedgerVerification, LedgerSpendingConditions):
raise TransactionError("duplicate promises.") raise TransactionError("duplicate promises.")
return_promises = await self._generate_promises(outputs, keyset) return_promises = await self._generate_promises(outputs, keyset)
return return_promises return return_promises
else:
return []
# ------- TRANSACTIONS ------- # ------- TRANSACTIONS -------
@@ -488,18 +491,14 @@ class Ledger(LedgerVerification, LedgerSpendingConditions):
logger.trace("called mint") logger.trace("called mint")
await self._verify_outputs(outputs) await self._verify_outputs(outputs)
sum_amount_outputs = sum([b.amount for b in outputs]) sum_amount_outputs = sum([b.amount for b in outputs])
# we already know from _verify_outputs that all outputs have the same unit because they have the same keyset
output_units = set([k.unit for k in [self.keysets[o.id] for o in outputs]]) output_unit = self.keysets[outputs[0].id].unit
if not len(output_units) == 1:
raise TransactionError("outputs have different units")
output_unit = list(output_units)[0]
self.locks[quote_id] = ( self.locks[quote_id] = (
self.locks.get(quote_id) or asyncio.Lock() self.locks.get(quote_id) or asyncio.Lock()
) # create a new lock if it doesn't exist ) # create a new lock if it doesn't exist
async with self.locks[quote_id]: async with self.locks[quote_id]:
quote = await self.get_mint_quote(quote_id=quote_id) quote = await self.get_mint_quote(quote_id=quote_id)
if not quote.paid: if not quote.paid:
raise QuoteNotPaidError() raise QuoteNotPaidError()
if quote.issued: if quote.issued:
@@ -564,14 +563,17 @@ class Ledger(LedgerVerification, LedgerSpendingConditions):
if not mint_quote.checking_id: if not mint_quote.checking_id:
raise TransactionError("mint quote has no checking id") raise TransactionError("mint quote has no checking id")
internal_fee = Amount(unit, 0) # no internal fees
amount = Amount(unit, mint_quote.amount)
payment_quote = PaymentQuoteResponse( payment_quote = PaymentQuoteResponse(
checking_id=mint_quote.checking_id, checking_id=mint_quote.checking_id,
amount=Amount(unit, mint_quote.amount), amount=amount,
fee=Amount(unit, amount=0), fee=internal_fee,
) )
logger.info( logger.info(
f"Issuing internal melt quote: {request} ->" f"Issuing internal melt quote: {request} ->"
f" {mint_quote.quote} ({mint_quote.amount} {mint_quote.unit})" f" {mint_quote.quote} ({amount.str()} + {internal_fee.str()} fees)"
) )
else: else:
# not internal, get payment quote by backend # not internal, get payment quote by backend
@@ -586,6 +588,15 @@ class Ledger(LedgerVerification, LedgerSpendingConditions):
if not payment_quote.fee.unit == unit: if not payment_quote.fee.unit == unit:
raise TransactionError("payment quote fee units do not match") raise TransactionError("payment quote fee units do not match")
# verify that the amount of the proofs is not larger than the maximum allowed
if (
settings.mint_max_peg_out
and payment_quote.amount.to(unit).amount > settings.mint_max_peg_out
):
raise NotAllowedError(
f"Maximum melt amount is {settings.mint_max_peg_out} sat."
)
# We assume that the request is a bolt11 invoice, this works since we # We assume that the request is a bolt11 invoice, this works since we
# support only the bol11 method for now. # support only the bol11 method for now.
invoice_obj = bolt11.decode(melt_quote.request) invoice_obj = bolt11.decode(melt_quote.request)
@@ -667,11 +678,16 @@ class Ledger(LedgerVerification, LedgerSpendingConditions):
return melt_quote return melt_quote
async def melt_mint_settle_internally(self, melt_quote: MeltQuote) -> MeltQuote: async def melt_mint_settle_internally(
self, melt_quote: MeltQuote, proofs: List[Proof]
) -> MeltQuote:
"""Settles a melt quote internally if there is a mint quote with the same payment request. """Settles a melt quote internally if there is a mint quote with the same payment request.
`proofs` are passed to determine the ecash input transaction fees for this melt quote.
Args: Args:
melt_quote (MeltQuote): Melt quote to settle. melt_quote (MeltQuote): Melt quote to settle.
proofs (List[Proof]): Proofs provided for paying the Lightning invoice.
Raises: Raises:
Exception: Melt quote already paid. Exception: Melt quote already paid.
@@ -687,6 +703,7 @@ class Ledger(LedgerVerification, LedgerSpendingConditions):
) )
if not mint_quote: if not mint_quote:
return melt_quote return melt_quote
# we settle the transaction internally # we settle the transaction internally
if melt_quote.paid: if melt_quote.paid:
raise TransactionError("melt quote already paid") raise TransactionError("melt quote already paid")
@@ -715,15 +732,16 @@ class Ledger(LedgerVerification, LedgerSpendingConditions):
f" {mint_quote.quote} ({melt_quote.amount} {melt_quote.unit})" f" {mint_quote.quote} ({melt_quote.amount} {melt_quote.unit})"
) )
# we handle this transaction internally melt_quote.fee_paid = 0 # no internal fees
melt_quote.fee_paid = 0
melt_quote.paid = True melt_quote.paid = True
melt_quote.paid_time = int(time.time()) melt_quote.paid_time = int(time.time())
await self.crud.update_melt_quote(quote=melt_quote, db=self.db)
mint_quote.paid = True mint_quote.paid = True
mint_quote.paid_time = melt_quote.paid_time mint_quote.paid_time = melt_quote.paid_time
await self.crud.update_mint_quote(quote=mint_quote, db=self.db)
async with self.db.connect() as conn:
await self.crud.update_melt_quote(quote=melt_quote, db=self.db, conn=conn)
await self.crud.update_mint_quote(quote=mint_quote, db=self.db, conn=conn)
return melt_quote return melt_quote
@@ -759,6 +777,7 @@ class Ledger(LedgerVerification, LedgerSpendingConditions):
# make sure that the outputs (for fee return) are in the same unit as the quote # make sure that the outputs (for fee return) are in the same unit as the quote
if outputs: if outputs:
# _verify_outputs checks if all outputs have the same unit
await self._verify_outputs(outputs, skip_amount_check=True) await self._verify_outputs(outputs, skip_amount_check=True)
outputs_unit = self.keysets[outputs[0].id].unit outputs_unit = self.keysets[outputs[0].id].unit
if not melt_quote.unit == outputs_unit.name: if not melt_quote.unit == outputs_unit.name:
@@ -768,11 +787,18 @@ class Ledger(LedgerVerification, LedgerSpendingConditions):
# verify that the amount of the input proofs is equal to the amount of the quote # verify that the amount of the input proofs is equal to the amount of the quote
total_provided = sum_proofs(proofs) total_provided = sum_proofs(proofs)
total_needed = melt_quote.amount + (melt_quote.fee_reserve or 0) input_fees = self.get_fees_for_proofs(proofs)
if not total_provided >= total_needed: total_needed = melt_quote.amount + melt_quote.fee_reserve + input_fees
# we need the fees specifically for lightning to return the overpaid fees
fee_reserve_provided = total_provided - melt_quote.amount - input_fees
if total_provided < total_needed:
raise TransactionError( raise TransactionError(
f"not enough inputs provided for melt. Provided: {total_provided}, needed: {total_needed}" f"not enough inputs provided for melt. Provided: {total_provided}, needed: {total_needed}"
) )
if fee_reserve_provided < melt_quote.fee_reserve:
raise TransactionError(
f"not enough fee reserve provided for melt. Provided fee reserve: {fee_reserve_provided}, needed: {melt_quote.fee_reserve}"
)
# verify that the amount of the proofs is not larger than the maximum allowed # verify that the amount of the proofs is not larger than the maximum allowed
if settings.mint_max_peg_out and total_provided > settings.mint_max_peg_out: if settings.mint_max_peg_out and total_provided > settings.mint_max_peg_out:
@@ -789,7 +815,7 @@ class Ledger(LedgerVerification, LedgerSpendingConditions):
await self._set_proofs_pending(proofs, quote_id=melt_quote.quote) await self._set_proofs_pending(proofs, quote_id=melt_quote.quote)
try: try:
# settle the transaction internally if there is a mint quote with the same payment request # settle the transaction internally if there is a mint quote with the same payment request
melt_quote = await self.melt_mint_settle_internally(melt_quote) melt_quote = await self.melt_mint_settle_internally(melt_quote, proofs)
# quote not paid yet (not internal), pay it with the backend # quote not paid yet (not internal), pay it with the backend
if not melt_quote.paid: if not melt_quote.paid:
logger.debug(f"Lightning: pay invoice {melt_quote.request}") logger.debug(f"Lightning: pay invoice {melt_quote.request}")
@@ -822,9 +848,8 @@ class Ledger(LedgerVerification, LedgerSpendingConditions):
return_promises: List[BlindedSignature] = [] return_promises: List[BlindedSignature] = []
if outputs: if outputs:
return_promises = await self._generate_change_promises( return_promises = await self._generate_change_promises(
input_amount=total_provided, fee_provided=fee_reserve_provided,
output_amount=melt_quote.amount, fee_paid=melt_quote.fee_paid,
output_fee_paid=melt_quote.fee_paid,
outputs=outputs, outputs=outputs,
keyset=self.keysets[outputs[0].id], keyset=self.keysets[outputs[0].id],
) )
@@ -898,12 +923,6 @@ class Ledger(LedgerVerification, LedgerSpendingConditions):
b_=output.B_, db=self.db, conn=conn b_=output.B_, db=self.db, conn=conn
) )
if promise is not None: if promise is not None:
# BEGIN backwards compatibility mints pre `m007_proofs_and_promises_store_id`
# add keyset id to promise if not present only if the current keyset
# is the only one ever used
if not promise.id and len(self.keysets) == 1:
promise.id = self.keyset.id
# END backwards compatibility
signatures.append(promise) signatures.append(promise)
return_outputs.append(output) return_outputs.append(output)
logger.trace(f"promise found: {promise}") logger.trace(f"promise found: {promise}")

View File

@@ -763,3 +763,13 @@ async def m018_duplicate_deprecated_keyset_ids(db: Database):
keyset.seed_encryption_method, keyset.seed_encryption_method,
), ),
) )
async def m019_add_fee_to_keysets(db: Database):
async with db.connect() as conn:
await conn.execute(
f"ALTER TABLE {table_with_schema(db, 'keysets')} ADD COLUMN input_fee_ppk INTEGER"
)
await conn.execute(
f"UPDATE {table_with_schema(db, 'keysets')} SET input_fee_ppk = 0"
)

View File

@@ -3,7 +3,8 @@ from typing import Any, Dict, List
from fastapi import APIRouter, Request from fastapi import APIRouter, Request
from loguru import logger from loguru import logger
from ..core.base import ( from ..core.errors import KeysetNotFoundError
from ..core.models import (
GetInfoResponse, GetInfoResponse,
KeysetsResponse, KeysetsResponse,
KeysetsResponseKeyset, KeysetsResponseKeyset,
@@ -25,7 +26,6 @@ from ..core.base import (
PostSplitRequest, PostSplitRequest,
PostSplitResponse, PostSplitResponse,
) )
from ..core.errors import KeysetNotFoundError
from ..core.settings import settings from ..core.settings import settings
from ..mint.startup import ledger from ..mint.startup import ledger
from .limit import limiter from .limit import limiter
@@ -182,7 +182,10 @@ async def keysets() -> KeysetsResponse:
for id, keyset in ledger.keysets.items(): for id, keyset in ledger.keysets.items():
keysets.append( keysets.append(
KeysetsResponseKeyset( KeysetsResponseKeyset(
id=id, unit=keyset.unit.name, active=keyset.active or False id=keyset.id,
unit=keyset.unit.name,
active=keyset.active,
input_fee_ppk=keyset.input_fee_ppk,
) )
) )
return KeysetsResponse(keysets=keysets) return KeysetsResponse(keysets=keysets)

View File

@@ -3,9 +3,9 @@ from typing import Dict, List, Optional
from fastapi import APIRouter, Request from fastapi import APIRouter, Request
from loguru import logger from loguru import logger
from ..core.base import ( from ..core.base import BlindedMessage, BlindedSignature, SpentState
BlindedMessage, from ..core.errors import CashuError
BlindedSignature, from ..core.models import (
CheckFeesRequest_deprecated, CheckFeesRequest_deprecated,
CheckFeesResponse_deprecated, CheckFeesResponse_deprecated,
CheckSpendableRequest_deprecated, CheckSpendableRequest_deprecated,
@@ -25,9 +25,7 @@ from ..core.base import (
PostSplitRequest_Deprecated, PostSplitRequest_Deprecated,
PostSplitResponse_Deprecated, PostSplitResponse_Deprecated,
PostSplitResponse_Very_Deprecated, PostSplitResponse_Very_Deprecated,
SpentState,
) )
from ..core.errors import CashuError
from ..core.settings import settings from ..core.settings import settings
from .limit import limiter from .limit import limiter
from .startup import ledger from .startup import ledger

View File

@@ -1,3 +1,4 @@
import math
from typing import Dict, List, Literal, Optional, Tuple, Union from typing import Dict, List, Literal, Optional, Tuple, Union
from loguru import logger from loguru import logger
@@ -19,6 +20,7 @@ from ..core.errors import (
SecretTooLongError, SecretTooLongError,
TokenAlreadySpentError, TokenAlreadySpentError,
TransactionError, TransactionError,
TransactionUnitError,
) )
from ..core.settings import settings from ..core.settings import settings
from ..lightning.base import LightningBackend from ..lightning.base import LightningBackend
@@ -75,7 +77,8 @@ class LedgerVerification(
if not all([self._verify_input_spending_conditions(p) for p in proofs]): if not all([self._verify_input_spending_conditions(p) for p in proofs]):
raise TransactionError("validation of input spending conditions failed.") raise TransactionError("validation of input spending conditions failed.")
if not outputs: if outputs is None:
# If no outputs are provided, we are melting
return return
# Verify input and output amounts # Verify input and output amounts
@@ -94,7 +97,6 @@ class LedgerVerification(
[ [
self.keysets[p.id].unit == self.keysets[outputs[0].id].unit self.keysets[p.id].unit == self.keysets[outputs[0].id].unit
for p in proofs for p in proofs
if p.id
] ]
): ):
raise TransactionError("input and output keysets have different units.") raise TransactionError("input and output keysets have different units.")
@@ -108,6 +110,8 @@ class LedgerVerification(
): ):
"""Verify that the outputs are valid.""" """Verify that the outputs are valid."""
logger.trace(f"Verifying {len(outputs)} outputs.") logger.trace(f"Verifying {len(outputs)} outputs.")
if not outputs:
raise TransactionError("no outputs provided.")
# Verify all outputs have the same keyset id # Verify all outputs have the same keyset id
if not all([o.id == outputs[0].id for o in outputs]): if not all([o.id == outputs[0].id for o in outputs]):
raise TransactionError("outputs have different keyset ids.") raise TransactionError("outputs have different keyset ids.")
@@ -182,16 +186,14 @@ class LedgerVerification(
"""Verifies that a secret is present and is not too long (DOS prevention).""" """Verifies that a secret is present and is not too long (DOS prevention)."""
if proof.secret is None or proof.secret == "": if proof.secret is None or proof.secret == "":
raise NoSecretInProofsError() raise NoSecretInProofsError()
if len(proof.secret) > 512: if len(proof.secret) > settings.mint_max_secret_length:
raise SecretTooLongError() raise SecretTooLongError(
f"secret too long. max: {settings.mint_max_secret_length}"
)
return True return True
def _verify_proof_bdhke(self, proof: Proof) -> bool: def _verify_proof_bdhke(self, proof: Proof) -> bool:
"""Verifies that the proof of promise was issued by this ledger.""" """Verifies that the proof of promise was issued by this ledger."""
# if no keyset id is given in proof, assume the current one
if not proof.id:
private_key_amount = self.keyset.private_keys[proof.amount]
else:
assert proof.id in self.keysets, f"keyset {proof.id} unknown" assert proof.id in self.keysets, f"keyset {proof.id} unknown"
logger.trace( logger.trace(
f"Validating proof {proof.secret} with keyset" f"Validating proof {proof.secret} with keyset"
@@ -231,23 +233,53 @@ class LedgerVerification(
def _verify_amount(self, amount: int) -> int: def _verify_amount(self, amount: int) -> int:
"""Any amount used should be positive and not larger than 2^MAX_ORDER.""" """Any amount used should be positive and not larger than 2^MAX_ORDER."""
valid = amount > 0 and amount < 2**settings.max_order valid = amount > 0 and amount < 2**settings.max_order
logger.trace(f"Verifying amount {amount} is valid: {valid}")
if not valid: if not valid:
raise NotAllowedError("invalid amount: " + str(amount)) raise NotAllowedError("invalid amount: " + str(amount))
return amount return amount
def _verify_equation_balanced( def _verify_units_match(
self, self,
proofs: List[Proof], proofs: List[Proof],
outs: Union[List[BlindedSignature], List[BlindedMessage]], outs: Union[List[BlindedSignature], List[BlindedMessage]],
) -> Unit:
"""Verifies that the units of the inputs and outputs match."""
units_proofs = [self.keysets[p.id].unit for p in proofs]
units_outputs = [self.keysets[o.id].unit for o in outs if o.id]
if not len(set(units_proofs)) == 1:
raise TransactionUnitError("inputs have different units.")
if not len(set(units_outputs)) == 1:
raise TransactionUnitError("outputs have different units.")
if not units_proofs[0] == units_outputs[0]:
raise TransactionUnitError("input and output keysets have different units.")
return units_proofs[0]
def get_fees_for_proofs(self, proofs: List[Proof]) -> int:
if not len(set([self.keysets[p.id].unit for p in proofs])) == 1:
raise TransactionUnitError("inputs have different units.")
fee = math.ceil(sum([self.keysets[p.id].input_fee_ppk for p in proofs]) / 1000)
return fee
def _verify_equation_balanced(
self,
proofs: List[Proof],
outs: List[BlindedMessage],
) -> None: ) -> None:
"""Verify that Σinputs - Σoutputs = 0. """Verify that Σinputs - Σoutputs = 0.
Outputs can be BlindedSignature or BlindedMessage. Outputs can be BlindedSignature or BlindedMessage.
""" """
if not proofs:
raise TransactionError("no proofs provided.")
if not outs:
raise TransactionError("no outputs provided.")
_ = self._verify_units_match(proofs, outs)
sum_inputs = sum(self._verify_amount(p.amount) for p in proofs) sum_inputs = sum(self._verify_amount(p.amount) for p in proofs)
fees_inputs = self.get_fees_for_proofs(proofs)
sum_outputs = sum(self._verify_amount(p.amount) for p in outs) sum_outputs = sum(self._verify_amount(p.amount) for p in outs)
if not sum_outputs - sum_inputs == 0: if not sum_outputs + fees_inputs - sum_inputs == 0:
raise TransactionError("inputs do not have same amount as outputs.") raise TransactionError(
f"inputs ({sum_inputs}) - fees ({fees_inputs}) vs outputs ({sum_outputs}) are not balanced."
)
def _verify_and_get_unit_method( def _verify_and_get_unit_method(
self, unit_str: str, method_str: str self, unit_str: str, method_str: str

View File

@@ -189,7 +189,7 @@ async def swap(
# pay invoice from outgoing mint # pay invoice from outgoing mint
await outgoing_wallet.load_proofs(reload=True) await outgoing_wallet.load_proofs(reload=True)
quote = await outgoing_wallet.request_melt(invoice.bolt11) quote = await outgoing_wallet.melt_quote(invoice.bolt11)
total_amount = quote.amount + quote.fee_reserve total_amount = quote.amount + quote.fee_reserve
if outgoing_wallet.available_balance < total_amount: if outgoing_wallet.available_balance < total_amount:
raise Exception("balance too low") raise Exception("balance too low")
@@ -237,16 +237,14 @@ async def send_command(
default=None, default=None,
description="Mint URL to send from (None for default mint)", description="Mint URL to send from (None for default mint)",
), ),
nosplit: bool = Query( offline: bool = Query(default=False, description="Force offline send."),
default=False, description="Do not split tokens before sending."
),
): ):
global wallet global wallet
if mint: if mint:
wallet = await mint_wallet(mint) wallet = await mint_wallet(mint)
if not nostr: if not nostr:
balance, token = await send( balance, token = await send(
wallet, amount=amount, lock=lock, legacy=False, split=not nosplit wallet, amount=amount, lock=lock, legacy=False, offline=offline
) )
return SendResponse(balance=balance, token=token) return SendResponse(balance=balance, token=token)
else: else:

View File

@@ -138,7 +138,8 @@ async def cli(ctx: Context, host: str, walletname: str, unit: str, tests: bool):
ctx.ensure_object(dict) ctx.ensure_object(dict)
ctx.obj["HOST"] = host or settings.mint_url ctx.obj["HOST"] = host or settings.mint_url
ctx.obj["UNIT"] = unit ctx.obj["UNIT"] = unit or settings.wallet_unit
unit = ctx.obj["UNIT"]
ctx.obj["WALLET_NAME"] = walletname ctx.obj["WALLET_NAME"] = walletname
settings.wallet_name = walletname settings.wallet_name = walletname
@@ -147,16 +148,18 @@ async def cli(ctx: Context, host: str, walletname: str, unit: str, tests: bool):
# otherwise it will create a mnemonic and store it in the database # otherwise it will create a mnemonic and store it in the database
if ctx.invoked_subcommand == "restore": if ctx.invoked_subcommand == "restore":
wallet = await Wallet.with_db( wallet = await Wallet.with_db(
ctx.obj["HOST"], db_path, name=walletname, skip_db_read=True ctx.obj["HOST"], db_path, name=walletname, skip_db_read=True, unit=unit
) )
else: else:
# # we need to run the migrations before we load the wallet for the first time # # we need to run the migrations before we load the wallet for the first time
# # otherwise the wallet will not be able to generate a new private key and store it # # otherwise the wallet will not be able to generate a new private key and store it
wallet = await Wallet.with_db( wallet = await Wallet.with_db(
ctx.obj["HOST"], db_path, name=walletname, skip_db_read=True ctx.obj["HOST"], db_path, name=walletname, skip_db_read=True, unit=unit
) )
# now with the migrations done, we can load the wallet and generate a new mnemonic if needed # now with the migrations done, we can load the wallet and generate a new mnemonic if needed
wallet = await Wallet.with_db(ctx.obj["HOST"], db_path, name=walletname) wallet = await Wallet.with_db(
ctx.obj["HOST"], db_path, name=walletname, unit=unit
)
assert wallet, "Wallet not found." assert wallet, "Wallet not found."
ctx.obj["WALLET"] = wallet ctx.obj["WALLET"] = wallet
@@ -193,7 +196,7 @@ async def pay(
wallet: Wallet = ctx.obj["WALLET"] wallet: Wallet = ctx.obj["WALLET"]
await wallet.load_mint() await wallet.load_mint()
await print_balance(ctx) await print_balance(ctx)
quote = await wallet.request_melt(invoice, amount) quote = await wallet.melt_quote(invoice, amount)
logger.debug(f"Quote: {quote}") logger.debug(f"Quote: {quote}")
total_amount = quote.amount + quote.fee_reserve total_amount = quote.amount + quote.fee_reserve
if not yes: if not yes:
@@ -214,7 +217,9 @@ async def pay(
if wallet.available_balance < total_amount: if wallet.available_balance < total_amount:
print(" Error: Balance too low.") print(" Error: Balance too low.")
return return
_, send_proofs = await wallet.split_to_send(wallet.proofs, total_amount) send_proofs, fees = await wallet.select_to_send(
wallet.proofs, total_amount, include_fees=True
)
try: try:
melt_response = await wallet.melt( melt_response = await wallet.melt(
send_proofs, invoice, quote.fee_reserve, quote.quote send_proofs, invoice, quote.fee_reserve, quote.quote
@@ -341,11 +346,11 @@ async def swap(ctx: Context):
invoice = await incoming_wallet.request_mint(amount) invoice = await incoming_wallet.request_mint(amount)
# pay invoice from outgoing mint # pay invoice from outgoing mint
quote = await outgoing_wallet.request_melt(invoice.bolt11) quote = await outgoing_wallet.melt_quote(invoice.bolt11)
total_amount = quote.amount + quote.fee_reserve total_amount = quote.amount + quote.fee_reserve
if outgoing_wallet.available_balance < total_amount: if outgoing_wallet.available_balance < total_amount:
raise Exception("balance too low") raise Exception("balance too low")
_, send_proofs = await outgoing_wallet.split_to_send( send_proofs, fees = await outgoing_wallet.select_to_send(
outgoing_wallet.proofs, total_amount, set_reserved=True outgoing_wallet.proofs, total_amount, set_reserved=True
) )
await outgoing_wallet.melt( await outgoing_wallet.melt(
@@ -372,8 +377,9 @@ async def swap(ctx: Context):
@coro @coro
async def balance(ctx: Context, verbose): async def balance(ctx: Context, verbose):
wallet: Wallet = ctx.obj["WALLET"] wallet: Wallet = ctx.obj["WALLET"]
await wallet.load_proofs(unit=False)
unit_balances = wallet.balance_per_unit() unit_balances = wallet.balance_per_unit()
await wallet.load_proofs(reload=True)
if len(unit_balances) > 1 and not ctx.obj["UNIT"]: if len(unit_balances) > 1 and not ctx.obj["UNIT"]:
print(f"You have balances in {len(unit_balances)} units:") print(f"You have balances in {len(unit_balances)} units:")
print("") print("")
@@ -397,7 +403,6 @@ async def balance(ctx: Context, verbose):
await print_mint_balances(wallet) await print_mint_balances(wallet)
await wallet.load_proofs(reload=True)
if verbose: if verbose:
print( print(
f"Balance: {wallet.unit.str(wallet.available_balance)} (pending:" f"Balance: {wallet.unit.str(wallet.available_balance)} (pending:"
@@ -447,11 +452,19 @@ async def balance(ctx: Context, verbose):
"--yes", "-y", default=False, is_flag=True, help="Skip confirmation.", type=bool "--yes", "-y", default=False, is_flag=True, help="Skip confirmation.", type=bool
) )
@click.option( @click.option(
"--nosplit", "--offline",
"-s", "-o",
default=False, default=False,
is_flag=True, is_flag=True,
help="Do not split tokens before sending.", help="Force offline send.",
type=bool,
)
@click.option(
"--include-fees",
"-f",
default=False,
is_flag=True,
help="Include fees for receiving token.",
type=bool, type=bool,
) )
@click.pass_context @click.pass_context
@@ -466,7 +479,8 @@ async def send_command(
legacy: bool, legacy: bool,
verbose: bool, verbose: bool,
yes: bool, yes: bool,
nosplit: bool, offline: bool,
include_fees: bool,
): ):
wallet: Wallet = ctx.obj["WALLET"] wallet: Wallet = ctx.obj["WALLET"]
amount = int(amount * 100) if wallet.unit == Unit.usd else int(amount) amount = int(amount * 100) if wallet.unit == Unit.usd else int(amount)
@@ -476,8 +490,9 @@ async def send_command(
amount=amount, amount=amount,
lock=lock, lock=lock,
legacy=legacy, legacy=legacy,
split=not nosplit, offline=offline,
include_dleq=dleq, include_dleq=dleq,
include_fees=include_fees,
) )
else: else:
await send_nostr( await send_nostr(
@@ -514,7 +529,9 @@ async def receive_cli(
# ask the user if they want to trust the new mints # ask the user if they want to trust the new mints
for mint_url in set([t.mint for t in tokenObj.token if t.mint]): for mint_url in set([t.mint for t in tokenObj.token if t.mint]):
mint_wallet = Wallet( mint_wallet = Wallet(
mint_url, os.path.join(settings.cashu_dir, wallet.name) mint_url,
os.path.join(settings.cashu_dir, wallet.name),
unit=tokenObj.unit or wallet.unit.name,
) )
await verify_mint(mint_wallet, mint_url) await verify_mint(mint_wallet, mint_url)
receive_wallet = await receive(wallet, tokenObj) receive_wallet = await receive(wallet, tokenObj)
@@ -853,6 +870,8 @@ async def wallets(ctx):
@coro @coro
async def info(ctx: Context, mint: bool, mnemonic: bool): async def info(ctx: Context, mint: bool, mnemonic: bool):
wallet: Wallet = ctx.obj["WALLET"] wallet: Wallet = ctx.obj["WALLET"]
await wallet.load_keysets_from_db(unit=None)
print(f"Version: {settings.version}") print(f"Version: {settings.version}")
print(f"Wallet: {ctx.obj['WALLET_NAME']}") print(f"Wallet: {ctx.obj['WALLET_NAME']}")
if settings.debug: if settings.debug:
@@ -861,30 +880,38 @@ async def info(ctx: Context, mint: bool, mnemonic: bool):
mint_list = await list_mints(wallet) mint_list = await list_mints(wallet)
print("Mints:") print("Mints:")
for mint_url in mint_list: for mint_url in mint_list:
print(f" - {mint_url}") print(f" - URL: {mint_url}")
keysets_strs = [
f"ID: {k.id} unit: {k.unit.name} active: {str(bool(k.active)) + ' ' if k.active else str(bool(k.active))} fee (ppk): {k.input_fee_ppk}"
for k in wallet.keysets.values()
]
if keysets_strs:
print(" - Keysets:")
for k in keysets_strs:
print(f" - {k}")
if mint: if mint:
wallet.url = mint_url wallet.url = mint_url
try: try:
mint_info: dict = (await wallet._load_mint_info()).dict() mint_info: dict = (await wallet.load_mint_info()).dict()
print("")
print("---- Mint information ----")
print("")
print(f"Mint URL: {mint_url}")
if mint_info: if mint_info:
print(f"Mint name: {mint_info['name']}") print(f" - Mint name: {mint_info['name']}")
if mint_info.get("description"): if mint_info.get("description"):
print(f"Description: {mint_info['description']}") print(f" - Description: {mint_info['description']}")
if mint_info.get("description_long"): if mint_info.get("description_long"):
print(f"Long description: {mint_info['description_long']}") print(
if mint_info.get("contact"): f" - Long description: {mint_info['description_long']}"
print(f"Contact: {mint_info['contact']}") )
if mint_info.get("contact") and mint_info.get("contact") != [
["", ""]
]:
print(f" - Contact: {mint_info['contact']}")
if mint_info.get("version"): if mint_info.get("version"):
print(f"Version: {mint_info['version']}") print(f" - Version: {mint_info['version']}")
if mint_info.get("motd"): if mint_info.get("motd"):
print(f"Message of the day: {mint_info['motd']}") print(f" - Message of the day: {mint_info['motd']}")
if mint_info.get("nuts"): if mint_info.get("nuts"):
print( print(
"Supported NUTS:" " - Supported NUTS:"
f" {', '.join(['NUT-'+str(k) for k in mint_info['nuts'].keys()])}" f" {', '.join(['NUT-'+str(k) for k in mint_info['nuts'].keys()])}"
) )
print("") print("")
@@ -896,14 +923,16 @@ async def info(ctx: Context, mint: bool, mnemonic: bool):
assert wallet.mnemonic assert wallet.mnemonic
print(f"Mnemonic:\n - {wallet.mnemonic}") print(f"Mnemonic:\n - {wallet.mnemonic}")
if settings.env_file: if settings.env_file:
print(f"Settings: {settings.env_file}") print("Settings:")
print(f" - File: {settings.env_file}")
if settings.tor: if settings.tor:
print(f"Tor enabled: {settings.tor}") print(f"Tor enabled: {settings.tor}")
if settings.nostr_private_key: if settings.nostr_private_key:
try: try:
client = NostrClient(private_key=settings.nostr_private_key, connect=False) client = NostrClient(private_key=settings.nostr_private_key, connect=False)
print(f"Nostr public key: {client.public_key.bech32()}") print("Nostr:")
print(f"Nostr relays: {', '.join(settings.nostr_relays)}") print(f" - Public key: {client.public_key.bech32()}")
print(f" - Relays: {', '.join(settings.nostr_relays)}")
except Exception: except Exception:
print("Nostr: Error. Invalid key.") print("Nostr: Error. Invalid key.")
if settings.socks_proxy: if settings.socks_proxy:
@@ -972,7 +1001,9 @@ async def selfpay(ctx: Context, all: bool = False):
mint_balance_dict = await wallet.balance_per_minturl() mint_balance_dict = await wallet.balance_per_minturl()
mint_balance = int(mint_balance_dict[wallet.url]["available"]) mint_balance = int(mint_balance_dict[wallet.url]["available"])
# send balance once to mark as reserved # send balance once to mark as reserved
await wallet.split_to_send(wallet.proofs, mint_balance, None, set_reserved=True) await wallet.select_to_send(
wallet.proofs, mint_balance, set_reserved=True, include_fees=False
)
# load all reserved proofs (including the one we just sent) # load all reserved proofs (including the one we just sent)
reserved_proofs = await get_reserved_proofs(wallet.db) reserved_proofs = await get_reserved_proofs(wallet.db)
if not len(reserved_proofs): if not len(reserved_proofs):

View File

@@ -12,7 +12,7 @@ from ...wallet.wallet import Wallet as Wallet
async def print_balance(ctx: Context): async def print_balance(ctx: Context):
wallet: Wallet = ctx.obj["WALLET"] wallet: Wallet = ctx.obj["WALLET"]
await wallet.load_proofs(reload=True, unit=wallet.unit) await wallet.load_proofs(reload=True)
print(f"Balance: {wallet.unit.str(wallet.available_balance)}") print(f"Balance: {wallet.unit.str(wallet.available_balance)}")
@@ -24,11 +24,11 @@ async def get_unit_wallet(ctx: Context, force_select: bool = False):
force_select (bool, optional): Force the user to select a unit. Defaults to False. force_select (bool, optional): Force the user to select a unit. Defaults to False.
""" """
wallet: Wallet = ctx.obj["WALLET"] wallet: Wallet = ctx.obj["WALLET"]
await wallet.load_proofs(reload=True, unit=False) await wallet.load_proofs(reload=False)
# show balances per unit # show balances per unit
unit_balances = wallet.balance_per_unit() unit_balances = wallet.balance_per_unit()
if ctx.obj["UNIT"] in [u.name for u in unit_balances] and not force_select: if wallet.unit in [unit_balances.keys()] and not force_select:
wallet.unit = Unit[ctx.obj["UNIT"]] return wallet
elif len(unit_balances) > 1 and not ctx.obj["UNIT"]: elif len(unit_balances) > 1 and not ctx.obj["UNIT"]:
print(f"You have balances in {len(unit_balances)} units:") print(f"You have balances in {len(unit_balances)} units:")
print("") print("")
@@ -68,7 +68,7 @@ async def get_mint_wallet(ctx: Context, force_select: bool = False):
""" """
# we load a dummy wallet so we can check the balance per mint # we load a dummy wallet so we can check the balance per mint
wallet: Wallet = ctx.obj["WALLET"] wallet: Wallet = ctx.obj["WALLET"]
await wallet.load_proofs(reload=True) await wallet.load_proofs(reload=False)
mint_balances = await wallet.balance_per_minturl() mint_balances = await wallet.balance_per_minturl()
if ctx.obj["HOST"] not in mint_balances and not force_select: if ctx.obj["HOST"] not in mint_balances and not force_select:
@@ -102,6 +102,7 @@ async def get_mint_wallet(ctx: Context, force_select: bool = False):
mint_url, mint_url,
os.path.join(settings.cashu_dir, ctx.obj["WALLET_NAME"]), os.path.join(settings.cashu_dir, ctx.obj["WALLET_NAME"]),
name=wallet.name, name=wallet.name,
unit=wallet.unit.name,
) )
await mint_wallet.load_proofs(reload=True) await mint_wallet.load_proofs(reload=True)

View File

@@ -34,6 +34,7 @@ async def store_proof(
async def get_proofs( async def get_proofs(
*, *,
db: Database, db: Database,
id: Optional[str] = "",
melt_id: str = "", melt_id: str = "",
mint_id: str = "", mint_id: str = "",
table: str = "proofs", table: str = "proofs",
@@ -42,6 +43,9 @@ async def get_proofs(
clauses = [] clauses = []
values: List[Any] = [] values: List[Any] = []
if id:
clauses.append("id = ?")
values.append(id)
if melt_id: if melt_id:
clauses.append("melt_id = ?") clauses.append("melt_id = ?")
values.append(melt_id) values.append(melt_id)
@@ -169,8 +173,8 @@ async def store_keyset(
await (conn or db).execute( # type: ignore await (conn or db).execute( # type: ignore
""" """
INSERT INTO keysets INSERT INTO keysets
(id, mint_url, valid_from, valid_to, first_seen, active, public_keys, unit) (id, mint_url, valid_from, valid_to, first_seen, active, public_keys, unit, input_fee_ppk)
VALUES (?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""", """,
( (
keyset.id, keyset.id,
@@ -181,26 +185,29 @@ async def store_keyset(
keyset.active, keyset.active,
keyset.serialize(), keyset.serialize(),
keyset.unit.name, keyset.unit.name,
keyset.input_fee_ppk,
), ),
) )
async def get_keysets( async def get_keysets(
id: str = "", id: str = "",
mint_url: str = "", mint_url: Optional[str] = None,
unit: Optional[str] = None,
db: Optional[Database] = None, db: Optional[Database] = None,
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> List[WalletKeyset]: ) -> List[WalletKeyset]:
clauses = [] clauses = []
values: List[Any] = [] values: List[Any] = []
clauses.append("active = ?")
values.append(True)
if id: if id:
clauses.append("id = ?") clauses.append("id = ?")
values.append(id) values.append(id)
if mint_url: if mint_url:
clauses.append("mint_url = ?") clauses.append("mint_url = ?")
values.append(mint_url) values.append(mint_url)
if unit:
clauses.append("unit = ?")
values.append(unit)
where = "" where = ""
if clauses: if clauses:
where = f"WHERE {' AND '.join(clauses)}" where = f"WHERE {' AND '.join(clauses)}"
@@ -219,6 +226,24 @@ async def get_keysets(
return ret return ret
async def update_keyset(
keyset: WalletKeyset,
db: Database,
conn: Optional[Connection] = None,
) -> None:
await (conn or db).execute(
"""
UPDATE keysets
SET active = ?
WHERE id = ?
""",
(
keyset.active,
keyset.id,
),
)
async def store_lightning_invoice( async def store_lightning_invoice(
db: Database, db: Database,
invoice: Invoice, invoice: Invoice,

View File

@@ -40,23 +40,26 @@ async def redeem_TokenV3_multimint(wallet: Wallet, token: TokenV3) -> Wallet:
Helper function to iterate thruogh a token with multiple mints and redeem them from Helper function to iterate thruogh a token with multiple mints and redeem them from
these mints one keyset at a time. these mints one keyset at a time.
""" """
if not token.unit:
# load unit from wallet keyset db
keysets = await get_keysets(id=token.token[0].proofs[0].id, db=wallet.db)
if keysets:
token.unit = keysets[0].unit.name
for t in token.token: for t in token.token:
assert t.mint, Exception( assert t.mint, Exception(
"redeem_TokenV3_multimint: multimint redeem without URL" "redeem_TokenV3_multimint: multimint redeem without URL"
) )
mint_wallet = await Wallet.with_db( mint_wallet = await Wallet.with_db(
t.mint, os.path.join(settings.cashu_dir, wallet.name) t.mint,
os.path.join(settings.cashu_dir, wallet.name),
unit=token.unit or wallet.unit.name,
) )
keyset_ids = mint_wallet._get_proofs_keysets(t.proofs) keyset_ids = mint_wallet._get_proofs_keysets(t.proofs)
logger.trace(f"Keysets in tokens: {keyset_ids}") logger.trace(f"Keysets in tokens: {' '.join(set(keyset_ids))}")
# loop over all keysets await mint_wallet.load_mint()
for keyset_id in set(keyset_ids): proofs_to_keep, _ = await mint_wallet.redeem(t.proofs)
await mint_wallet.load_mint(keyset_id) print(f"Received {mint_wallet.unit.str(sum_proofs(proofs_to_keep))}")
mint_wallet.unit = mint_wallet.keysets[keyset_id].unit
# redeem proofs of this keyset
redeem_proofs = [p for p in t.proofs if p.id == keyset_id]
_, _ = await mint_wallet.redeem(redeem_proofs)
print(f"Received {mint_wallet.unit.str(sum_proofs(redeem_proofs))}")
# return the last mint_wallet # return the last mint_wallet
return mint_wallet return mint_wallet
@@ -137,19 +140,19 @@ async def receive(
) )
else: else:
# this is very legacy code, virtually any token should have mint information # this is very legacy code, virtually any token should have mint information
# no mint information present, we extract the proofs and use wallet's default mint # no mint information present, we extract the proofs find the mint and unit from the db
# first we load the mint URL from the DB
keyset_in_token = proofs[0].id keyset_in_token = proofs[0].id
assert keyset_in_token assert keyset_in_token
# we get the keyset from the db # we get the keyset from the db
mint_keysets = await get_keysets(id=keyset_in_token, db=wallet.db) mint_keysets = await get_keysets(id=keyset_in_token, db=wallet.db)
assert mint_keysets, Exception(f"we don't know this keyset: {keyset_in_token}") assert mint_keysets, Exception(f"we don't know this keyset: {keyset_in_token}")
mint_keyset = mint_keysets[0] mint_keyset = [k for k in mint_keysets if k.id == keyset_in_token][0]
assert mint_keyset.mint_url, Exception("we don't know this mint's URL") assert mint_keyset.mint_url, Exception("we don't know this mint's URL")
# now we have the URL # now we have the URL
mint_wallet = await Wallet.with_db( mint_wallet = await Wallet.with_db(
mint_keyset.mint_url, mint_keyset.mint_url,
os.path.join(settings.cashu_dir, wallet.name), os.path.join(settings.cashu_dir, wallet.name),
unit=mint_keyset.unit.name or wallet.unit.name,
) )
await mint_wallet.load_mint(keyset_in_token) await mint_wallet.load_mint(keyset_in_token)
_, _ = await mint_wallet.redeem(proofs) _, _ = await mint_wallet.redeem(proofs)
@@ -166,8 +169,9 @@ async def send(
amount: int, amount: int,
lock: str, lock: str,
legacy: bool, legacy: bool,
split: bool = True, offline: bool = False,
include_dleq: bool = False, include_dleq: bool = False,
include_fees: bool = False,
): ):
""" """
Prints token to send to stdout. Prints token to send to stdout.
@@ -191,23 +195,18 @@ async def send(
sig_all=True, sig_all=True,
n_sigs=1, n_sigs=1,
) )
print(f"Secret lock: {secret_lock}")
await wallet.load_proofs() await wallet.load_proofs()
if split:
await wallet.load_mint() await wallet.load_mint()
_, send_proofs = await wallet.split_to_send(
wallet.proofs, amount, secret_lock, set_reserved=True
)
else:
# get a proof with specific amount # get a proof with specific amount
send_proofs = [] send_proofs, fees = await wallet.select_to_send(
for p in wallet.proofs: wallet.proofs,
if not p.reserved and p.amount == amount: amount,
send_proofs = [p] set_reserved=False,
break offline=offline,
assert send_proofs, Exception( include_fees=include_fees,
"No proof with this amount found. Available amounts:"
f" {set([p.amount for p in wallet.proofs])}"
) )
token = await wallet.serialize_proofs( token = await wallet.serialize_proofs(

View File

@@ -55,7 +55,7 @@ class LightningWallet(Wallet):
Returns: Returns:
bool: True if successful bool: True if successful
""" """
quote = await self.request_melt(pr) quote = await self.melt_quote(pr)
total_amount = quote.amount + quote.fee_reserve total_amount = quote.amount + quote.fee_reserve
assert total_amount > 0, "amount is not positive" assert total_amount > 0, "amount is not positive"
if self.available_balance < total_amount: if self.available_balance < total_amount:

View File

@@ -236,3 +236,10 @@ async def m011_keysets_add_unit(db: Database):
# add column for storing the unit of a keyset # add column for storing the unit of a keyset
await conn.execute("ALTER TABLE keysets ADD COLUMN unit TEXT") await conn.execute("ALTER TABLE keysets ADD COLUMN unit TEXT")
await conn.execute("UPDATE keysets SET unit = 'sat'") await conn.execute("UPDATE keysets SET unit = 'sat'")
async def m012_add_fee_to_keysets(db: Database):
async with db.connect() as conn:
# add column for storing the fee of a keyset
await conn.execute("ALTER TABLE keysets ADD COLUMN input_fee_ppk INTEGER")
await conn.execute("UPDATE keysets SET input_fee_ppk = 0")

View File

@@ -2,7 +2,8 @@ from typing import Any, Dict, List, Optional
from pydantic import BaseModel from pydantic import BaseModel
from ..core.base import Nut15MppSupport, Unit from ..core.base import Unit
from ..core.models import Nut15MppSupport
class MintInfo(BaseModel): class MintInfo(BaseModel):

View File

@@ -63,7 +63,7 @@ async def send_nostr(
await wallet.load_mint() await wallet.load_mint()
await wallet.load_proofs() await wallet.load_proofs()
_, send_proofs = await wallet.split_to_send( _, send_proofs = await wallet.split_to_send(
wallet.proofs, amount, set_reserved=True wallet.proofs, amount, set_reserved=True, include_fees=False
) )
token = await wallet.serialize_proofs(send_proofs, include_dleq=include_dleq) token = await wallet.serialize_proofs(send_proofs, include_dleq=include_dleq)

208
cashu/wallet/proofs.py Normal file
View File

@@ -0,0 +1,208 @@
import base64
import json
from itertools import groupby
from typing import Dict, List, Optional
from loguru import logger
from ..core.base import (
Proof,
TokenV2,
TokenV2Mint,
TokenV3,
TokenV3Token,
Unit,
WalletKeyset,
)
from ..core.db import Database
from ..wallet.crud import (
get_keysets,
)
from .protocols import SupportsDb, SupportsKeysets
class WalletProofs(SupportsDb, SupportsKeysets):
keyset_id: str
db: Database
@staticmethod
def _get_proofs_per_keyset(proofs: List[Proof]):
return {
key: list(group) for key, group in groupby(proofs, lambda p: p.id) if key
}
async def _get_proofs_per_minturl(
self, proofs: List[Proof], unit: Optional[Unit] = None
) -> Dict[str, List[Proof]]:
ret: Dict[str, List[Proof]] = {}
keyset_ids = set([p.id for p in proofs])
for id in keyset_ids:
if id is None:
continue
keysets_crud = await get_keysets(id=id, db=self.db)
assert keysets_crud, f"keyset {id} not found"
keyset: WalletKeyset = keysets_crud[0]
if unit and keyset.unit != unit:
continue
assert keyset.mint_url
if keyset.mint_url not in ret:
ret[keyset.mint_url] = [p for p in proofs if p.id == id]
else:
ret[keyset.mint_url].extend([p for p in proofs if p.id == id])
return ret
def _get_proofs_per_unit(self, proofs: List[Proof]) -> Dict[Unit, List[Proof]]:
ret: Dict[Unit, List[Proof]] = {}
for proof in proofs:
if proof.id not in self.keysets:
logger.error(f"Keyset {proof.id} not found in wallet.")
continue
unit = self.keysets[proof.id].unit
if unit not in ret:
ret[unit] = [proof]
else:
ret[unit].append(proof)
return ret
def _get_proofs_keysets(self, proofs: List[Proof]) -> List[str]:
"""Extracts all keyset ids from a list of proofs.
Args:
proofs (List[Proof]): List of proofs to get the keyset id's of
"""
keysets: List[str] = [proof.id for proof in proofs]
return keysets
async def _get_keyset_urls(self, keysets: List[str]) -> Dict[str, List[str]]:
"""Retrieves the mint URLs for a list of keyset id's from the wallet's database.
Returns a dictionary from URL to keyset ID
Args:
keysets (List[str]): List of keysets.
"""
mint_urls: Dict[str, List[str]] = {}
for ks in set(keysets):
keysets_db = await get_keysets(id=ks, db=self.db)
keyset_db = keysets_db[0] if keysets_db else None
if keyset_db and keyset_db.mint_url:
mint_urls[keyset_db.mint_url] = (
mint_urls[keyset_db.mint_url] + [ks]
if mint_urls.get(keyset_db.mint_url)
else [ks]
)
return mint_urls
async def _make_token(
self, proofs: List[Proof], include_mints=True, include_unit=True
) -> TokenV3:
"""
Takes list of proofs and produces a TokenV3 by looking up
the mint URLs by the keyset id from the database.
Args:
proofs (List[Proof]): List of proofs to be included in the token
include_mints (bool, optional): Whether to include the mint URLs in the token. Defaults to True.
Returns:
TokenV3: TokenV3 object
"""
token = TokenV3()
if include_unit:
token.unit = self.unit.name
if include_mints:
# we create a map from mint url to keyset id and then group
# all proofs with their mint url to build a tokenv3
# extract all keysets from proofs
keysets = self._get_proofs_keysets(proofs)
# get all mint URLs for all unique keysets from db
mint_urls = await self._get_keyset_urls(keysets)
# append all url-grouped proofs to token
for url, ids in mint_urls.items():
mint_proofs = [p for p in proofs if p.id in ids]
token.token.append(TokenV3Token(mint=url, proofs=mint_proofs))
else:
token_proofs = TokenV3Token(proofs=proofs)
token.token.append(token_proofs)
return token
async def serialize_proofs(
self, proofs: List[Proof], include_mints=True, include_dleq=False, legacy=False
) -> str:
"""Produces sharable token with proofs and mint information.
Args:
proofs (List[Proof]): List of proofs to be included in the token
include_mints (bool, optional): Whether to include the mint URLs in the token. Defaults to True.
legacy (bool, optional): Whether to produce a legacy V2 token. Defaults to False.
Returns:
str: Serialized Cashu token
"""
if legacy:
# V2 tokens
token_v2 = await self._make_token_v2(proofs, include_mints)
return await self._serialize_token_base64_tokenv2(token_v2)
# # deprecated code for V1 tokens
# proofs_serialized = [p.to_dict() for p in proofs]
# return base64.urlsafe_b64encode(
# json.dumps(proofs_serialized).encode()
# ).decode()
# V3 tokens
token = await self._make_token(proofs, include_mints)
return token.serialize(include_dleq)
async def _make_token_v2(self, proofs: List[Proof], include_mints=True) -> TokenV2:
"""
Takes list of proofs and produces a TokenV2 by looking up
the keyset id and mint URLs from the database.
"""
# build token
token = TokenV2(proofs=proofs)
# add mint information to the token, if requested
if include_mints:
# dummy object to hold information about the mint
mints: Dict[str, TokenV2Mint] = {}
# dummy object to hold all keyset id's we need to fetch from the db later
keysets: List[str] = [proof.id for proof in proofs if proof.id]
# iterate through unique keyset ids
for id in set(keysets):
# load the keyset from the db
keysets_db = await get_keysets(id=id, db=self.db)
keyset_db = keysets_db[0] if keysets_db else None
if keyset_db and keyset_db.mint_url and keyset_db.id:
# we group all mints according to URL
if keyset_db.mint_url not in mints:
mints[keyset_db.mint_url] = TokenV2Mint(
url=keyset_db.mint_url,
ids=[keyset_db.id],
)
else:
# if a mint URL has multiple keysets, append to the already existing list
mints[keyset_db.mint_url].ids.append(keyset_db.id)
if len(mints) > 0:
# add mints grouped by url to the token
token.mints = list(mints.values())
return token
async def _serialize_token_base64_tokenv2(self, token: TokenV2) -> str:
"""
Takes a TokenV2 and serializes it in urlsafe_base64.
Args:
token (TokenV2): TokenV2 object to be serialized
Returns:
str: Serialized token
"""
# encode the token as a base64 string
token_base64 = base64.urlsafe_b64encode(
json.dumps(token.to_dict()).encode()
).decode()
return token_base64

View File

@@ -1,7 +1,8 @@
from typing import Protocol from typing import Dict, Protocol
import httpx import httpx
from ..core.base import Unit, WalletKeyset
from ..core.crypto.secp import PrivateKey from ..core.crypto.secp import PrivateKey
from ..core.db import Database from ..core.db import Database
@@ -15,7 +16,9 @@ class SupportsDb(Protocol):
class SupportsKeysets(Protocol): class SupportsKeysets(Protocol):
keysets: Dict[str, WalletKeyset] # holds keysets
keyset_id: str keyset_id: str
unit: Unit
class SupportsHttpxClient(Protocol): class SupportsHttpxClient(Protocol):

View File

@@ -9,6 +9,7 @@ from mnemonic import Mnemonic
from ..core.crypto.secp import PrivateKey from ..core.crypto.secp import PrivateKey
from ..core.db import Database from ..core.db import Database
from ..core.secret import Secret
from ..core.settings import settings from ..core.settings import settings
from ..wallet.crud import ( from ..wallet.crud import (
bump_secret_derivation, bump_secret_derivation,
@@ -93,19 +94,13 @@ class WalletSecrets(SupportsDb, SupportsKeysets):
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
async def _generate_secret(self) -> str: async def _generate_random_secret(self) -> str:
"""Returns base64 encoded deterministic random string. """Returns base64 encoded deterministic random string.
NOTE: This method should probably retire after `deterministic_secrets`. We are NOTE: This method should probably retire after `deterministic_secrets`. We are
deriving secrets from a counter but don't store the respective blinding factor. deriving secrets from a counter but don't store the respective blinding factor.
We won't be able to restore any ecash generated with these secrets. We won't be able to restore any ecash generated with these secrets.
""" """
# secret_counter = await bump_secret_derivation(db=self.db, keyset_id=keyset_id)
# logger.trace(f"secret_counter: {secret_counter}")
# s, _, _ = await self.generate_determinstic_secret(secret_counter, keyset_id)
# # return s.decode("utf-8")
# return hashlib.sha256(s).hexdigest()
# return random 32 byte hex string # return random 32 byte hex string
return hashlib.sha256(os.urandom(32)).hexdigest() return hashlib.sha256(os.urandom(32)).hexdigest()
@@ -209,3 +204,29 @@ class WalletSecrets(SupportsDb, SupportsKeysets):
rs = [PrivateKey(privkey=s[1], raw=True) for s in secrets_rs_derivationpaths] rs = [PrivateKey(privkey=s[1], raw=True) for s in secrets_rs_derivationpaths]
derivation_paths = [s[2] for s in secrets_rs_derivationpaths] derivation_paths = [s[2] for s in secrets_rs_derivationpaths]
return secrets, rs, derivation_paths return secrets, rs, derivation_paths
async def generate_locked_secrets(
self, send_outputs: List[int], keep_outputs: List[int], secret_lock: Secret
) -> Tuple[List[str], List[PrivateKey], List[str]]:
"""Generates secrets and blinding factors for a transaction with `send_outputs` and `keep_outputs`.
Args:
send_outputs (List[int]): List of amounts to send
keep_outputs (List[int]): List of amounts to keep
Returns:
Tuple[List[str], List[PrivateKey], List[str]]: Secrets, blinding factors, derivation paths
"""
rs: List[PrivateKey] = []
# generate secrets for receiver
secret_locks = [secret_lock.serialize() for i in range(len(send_outputs))]
logger.debug(f"Creating proofs with custom secrets: {secret_locks}")
# append predefined secrets (to send) to random secrets (to keep)
# generate secrets to keep
secrets = [
await self._generate_random_secret() for s in range(len(keep_outputs))
] + secret_locks
# TODO: derive derivation paths from secrets
derivation_paths = ["custom"] * len(secrets)
return secrets, rs, derivation_paths

View File

@@ -0,0 +1,212 @@
import math
import uuid
from typing import Dict, List, Tuple, Union
from loguru import logger
from ..core.base import (
Proof,
Unit,
WalletKeyset,
)
from ..core.db import Database
from ..core.helpers import amount_summary, sum_proofs
from ..wallet.crud import (
update_proof,
)
from .protocols import SupportsDb, SupportsKeysets
class WalletTransactions(SupportsDb, SupportsKeysets):
keysets: Dict[str, WalletKeyset] # holds keysets
keyset_id: str
db: Database
unit: Unit
def get_fees_for_keyset(self, amounts: List[int], keyset: WalletKeyset) -> int:
fees = max(math.ceil(sum([keyset.input_fee_ppk for a in amounts]) / 1000), 0)
return fees
def get_fees_for_proofs(self, proofs: List[Proof]) -> int:
# for each proof, find the keyset with the same id and sum the fees
fees = max(
math.ceil(sum([self.keysets[p.id].input_fee_ppk for p in proofs]) / 1000), 0
)
return fees
def get_fees_for_proofs_ppk(self, proofs: List[Proof]) -> int:
return sum([self.keysets[p.id].input_fee_ppk for p in proofs])
async def _select_proofs_to_send_(
self, proofs: List[Proof], amount_to_send: int, tolerance: int = 0
) -> List[Proof]:
send_proofs: List[Proof] = []
NO_SELECTION: List[Proof] = []
logger.trace(f"proofs: {[p.amount for p in proofs]}")
# sort proofs by amount (descending)
sorted_proofs = sorted(proofs, key=lambda p: p.amount, reverse=True)
# only consider proofs smaller than the amount we want to send (+ tolerance) for coin selection
fee_for_single_proof = self.get_fees_for_proofs([sorted_proofs[0]])
sorted_proofs = [
p
for p in sorted_proofs
if p.amount <= amount_to_send + tolerance + fee_for_single_proof
]
if not sorted_proofs:
logger.info(
f"no small-enough proofs to send. Have: {[p.amount for p in proofs]}"
)
return NO_SELECTION
target_amount = amount_to_send
# compose the target amount from the remaining_proofs
logger.debug(f"sorted_proofs: {[p.amount for p in sorted_proofs]}")
for p in sorted_proofs:
# logger.debug(f"send_proofs: {[p.amount for p in send_proofs]}")
# logger.debug(f"target_amount: {target_amount}")
# logger.debug(f"p.amount: {p.amount}")
if sum_proofs(send_proofs) + p.amount <= target_amount + tolerance:
send_proofs.append(p)
target_amount = amount_to_send + self.get_fees_for_proofs(send_proofs)
if sum_proofs(send_proofs) < amount_to_send:
logger.info("could not select proofs to reach target amount (too little).")
return NO_SELECTION
fees = self.get_fees_for_proofs(send_proofs)
logger.debug(f"Selected sum of proofs: {sum_proofs(send_proofs)}, fees: {fees}")
return send_proofs
async def _select_proofs_to_send(
self,
proofs: List[Proof],
amount_to_send: Union[int, float],
*,
include_fees: bool = True,
) -> List[Proof]:
# check that enough spendable proofs exist
if sum_proofs(proofs) < amount_to_send:
return []
logger.trace(
f"_select_proofs_to_send amount_to_send: {amount_to_send}  amounts we have: {amount_summary(proofs, self.unit)} (sum: {sum_proofs(proofs)})"
)
sorted_proofs = sorted(proofs, key=lambda p: p.amount)
next_bigger = next(
(p for p in sorted_proofs if p.amount > amount_to_send), None
)
smaller_proofs = [p for p in sorted_proofs if p.amount <= amount_to_send]
smaller_proofs = sorted(smaller_proofs, key=lambda p: p.amount, reverse=True)
if not smaller_proofs and next_bigger:
logger.trace(
"> no proofs smaller than amount_to_send, adding next bigger proof"
)
return [next_bigger]
if not smaller_proofs and not next_bigger:
logger.trace("> no proofs to select from")
return []
remainder = amount_to_send
selected_proofs = [smaller_proofs[0]]
fee_ppk = self.get_fees_for_proofs_ppk(selected_proofs) if include_fees else 0
logger.debug(f"adding proof: {smaller_proofs[0].amount} fee: {fee_ppk} ppk")
remainder -= smaller_proofs[0].amount - fee_ppk / 1000
logger.debug(f"remainder: {remainder}")
if remainder > 0:
logger.trace(
f"> selecting more proofs from {amount_summary(smaller_proofs[1:], self.unit)} sum: {sum_proofs(smaller_proofs[1:])} to reach {remainder}"
)
selected_proofs += await self._select_proofs_to_send(
smaller_proofs[1:], remainder, include_fees=include_fees
)
sum_selected_proofs = sum_proofs(selected_proofs)
if sum_selected_proofs < amount_to_send and next_bigger:
logger.trace("> adding next bigger proof")
return [next_bigger]
logger.trace(
f"_select_proofs_to_send - selected proof amounts: {amount_summary(selected_proofs, self.unit)} (sum: {sum_proofs(selected_proofs)})"
)
return selected_proofs
async def _select_proofs_to_split(
self, proofs: List[Proof], amount_to_send: int
) -> Tuple[List[Proof], int]:
"""
Selects proofs that can be used with the current mint. Implements a simple coin selection algorithm.
The algorithm has two objectives: Get rid of all tokens from old epochs and include additional proofs from
the current epoch starting from the proofs with the largest amount.
Rules:
1) Proofs that are not marked as reserved
2) Proofs that have a different keyset than the activated keyset_id of the mint
3) Include all proofs that have an older keyset than the current keyset of the mint (to get rid of old epochs).
4) If the target amount is not reached, add proofs of the current keyset until it is.
Args:
proofs (List[Proof]): List of proofs to select from
amount_to_send (int): Amount to select proofs for
Returns:
List[Proof]: List of proofs to send (including fees)
int: Fees for the transaction
Raises:
Exception: If the balance is too low to send the amount
"""
logger.debug(
f"_select_proofs_to_split - amounts we have: {amount_summary(proofs, self.unit)}"
)
send_proofs: List[Proof] = []
# check that enough spendable proofs exist
if sum_proofs(proofs) < amount_to_send:
raise Exception("balance too low.")
# add all proofs that have an older keyset than the current keyset of the mint
proofs_old_epochs = [
p for p in proofs if p.id != self.keysets[self.keyset_id].id
]
send_proofs += proofs_old_epochs
# coinselect based on amount only from the current keyset
# start with the proofs with the largest amount and add them until the target amount is reached
proofs_current_epoch = [
p for p in proofs if p.id == self.keysets[self.keyset_id].id
]
sorted_proofs_of_current_keyset = sorted(
proofs_current_epoch, key=lambda p: p.amount
)
while sum_proofs(send_proofs) < amount_to_send + self.get_fees_for_proofs(
send_proofs
):
proof_to_add = sorted_proofs_of_current_keyset.pop()
send_proofs.append(proof_to_add)
logger.trace(
f"_select_proofs_to_split  selected proof amounts: {[p.amount for p in send_proofs]}"
)
fees = self.get_fees_for_proofs(send_proofs)
return send_proofs, fees
async def set_reserved(self, proofs: List[Proof], reserved: bool) -> None:
"""Mark a proof as reserved or reset it in the wallet db to avoid reuse when it is sent.
Args:
proofs (List[Proof]): List of proofs to mark as reserved
reserved (bool): Whether to mark the proofs as reserved or not
"""
uuid_str = str(uuid.uuid1())
for proof in proofs:
proof.reserved = True
await update_proof(proof, reserved=reserved, send_id=uuid_str, db=self.db)

539
cashu/wallet/v1_api.py Normal file
View File

@@ -0,0 +1,539 @@
import json
import uuid
from posixpath import join
from typing import List, Optional, Tuple, Union
import bolt11
import httpx
from httpx import Response
from loguru import logger
from ..core.base import (
BlindedMessage,
BlindedSignature,
Proof,
ProofState,
SpentState,
Unit,
WalletKeyset,
)
from ..core.crypto.secp import PublicKey
from ..core.db import Database
from ..core.models import (
CheckFeesResponse_deprecated,
GetInfoResponse,
KeysetsResponse,
KeysetsResponseKeyset,
KeysResponse,
PostCheckStateRequest,
PostCheckStateResponse,
PostMeltQuoteRequest,
PostMeltQuoteResponse,
PostMeltRequest,
PostMeltResponse,
PostMeltResponse_deprecated,
PostMintQuoteRequest,
PostMintQuoteResponse,
PostMintRequest,
PostMintResponse,
PostRestoreResponse,
PostSplitRequest,
PostSplitResponse,
)
from ..core.settings import settings
from ..tor.tor import TorProxy
from .crud import (
get_lightning_invoice,
)
from .wallet_deprecated import LedgerAPIDeprecated
def async_set_httpx_client(func):
"""
Decorator that wraps around any async class method of LedgerAPI that makes
API calls. Sets some HTTP headers and starts a Tor instance if none is
already running and and sets local proxy to use it.
"""
async def wrapper(self, *args, **kwargs):
# set proxy
proxies_dict = {}
proxy_url: Union[str, None] = None
if settings.tor and TorProxy().check_platform():
self.tor = TorProxy(timeout=True)
self.tor.run_daemon(verbose=True)
proxy_url = "socks5://localhost:9050"
elif settings.socks_proxy:
proxy_url = f"socks5://{settings.socks_proxy}"
elif settings.http_proxy:
proxy_url = settings.http_proxy
if proxy_url:
proxies_dict.update({"all://": proxy_url})
headers_dict = {"Client-version": settings.version}
self.httpx = httpx.AsyncClient(
verify=not settings.debug,
proxies=proxies_dict, # type: ignore
headers=headers_dict,
base_url=self.url,
timeout=None if settings.debug else 60,
)
return await func(self, *args, **kwargs)
return wrapper
def async_ensure_mint_loaded(func):
"""Decorator that ensures that the mint is loaded before calling the wrapped
function. If the mint is not loaded, it will be loaded first.
"""
async def wrapper(self, *args, **kwargs):
if not self.keysets:
await self.load_mint()
return await func(self, *args, **kwargs)
return wrapper
class LedgerAPI(LedgerAPIDeprecated, object):
tor: TorProxy
db: Database # we need the db for melt_deprecated
httpx: httpx.AsyncClient
def __init__(self, url: str, db: Database):
self.url = url
self.db = db
@async_set_httpx_client
async def _init_s(self):
"""Dummy function that can be called from outside to use LedgerAPI.s"""
return
@staticmethod
def raise_on_error_request(
resp: Response,
) -> None:
"""Raises an exception if the response from the mint contains an error.
Args:
resp_dict (Response): Response dict (previously JSON) from mint
Raises:
Exception: if the response contains an error
"""
try:
resp_dict = resp.json()
except json.JSONDecodeError:
# if we can't decode the response, raise for status
resp.raise_for_status()
return
if "detail" in resp_dict:
logger.trace(f"Error from mint: {resp_dict}")
error_message = f"Mint Error: {resp_dict['detail']}"
if "code" in resp_dict:
error_message += f" (Code: {resp_dict['code']})"
raise Exception(error_message)
# raise for status if no error
resp.raise_for_status()
"""
ENDPOINTS
"""
@async_set_httpx_client
async def _get_keys(self) -> List[WalletKeyset]:
"""API that gets the current keys of the mint
Args:
url (str): Mint URL
Returns:
WalletKeyset: Current mint keyset
Raises:
Exception: If no keys are received from the mint
"""
resp = await self.httpx.get(
join(self.url, "/v1/keys"),
)
# BEGIN backwards compatibility < 0.15.0
# assume the mint has not upgraded yet if we get a 404
if resp.status_code == 404:
ret = await self._get_keys_deprecated(self.url)
return [ret]
# END backwards compatibility < 0.15.0
self.raise_on_error_request(resp)
keys_dict: dict = resp.json()
assert len(keys_dict), Exception("did not receive any keys")
keys = KeysResponse.parse_obj(keys_dict)
logger.debug(
f"Received {len(keys.keysets)} keysets from mint:"
f" {' '.join([k.id + f' ({k.unit})' for k in keys.keysets])}."
)
ret = [
WalletKeyset(
id=keyset.id,
unit=keyset.unit,
public_keys={
int(amt): PublicKey(bytes.fromhex(val), raw=True)
for amt, val in keyset.keys.items()
},
mint_url=self.url,
)
for keyset in keys.keysets
]
return ret
@async_set_httpx_client
async def _get_keyset(self, keyset_id: str) -> WalletKeyset:
"""API that gets the keys of a specific keyset from the mint.
Args:
keyset_id (str): base64 keyset ID, needs to be urlsafe-encoded before sending to mint (done in this method)
Returns:
WalletKeyset: Keyset with ID keyset_id
Raises:
Exception: If no keys are received from the mint
"""
keyset_id_urlsafe = keyset_id.replace("+", "-").replace("/", "_")
resp = await self.httpx.get(
join(self.url, f"/v1/keys/{keyset_id_urlsafe}"),
)
# BEGIN backwards compatibility < 0.15.0
# assume the mint has not upgraded yet if we get a 404
if resp.status_code == 404:
ret = await self._get_keyset_deprecated(self.url, keyset_id)
return ret
# END backwards compatibility < 0.15.0
self.raise_on_error_request(resp)
keys_dict = resp.json()
assert len(keys_dict), Exception("did not receive any keys")
keys = KeysResponse.parse_obj(keys_dict)
this_keyset = keys.keysets[0]
keyset_keys = {
int(amt): PublicKey(bytes.fromhex(val), raw=True)
for amt, val in this_keyset.keys.items()
}
keyset = WalletKeyset(
id=keyset_id,
unit=this_keyset.unit,
public_keys=keyset_keys,
mint_url=self.url,
)
return keyset
@async_set_httpx_client
async def _get_keysets(self) -> List[KeysetsResponseKeyset]:
"""API that gets a list of all active keysets of the mint.
Returns:
KeysetsResponse (List[str]): List of all active keyset IDs of the mint
Raises:
Exception: If no keysets are received from the mint
"""
resp = await self.httpx.get(
join(self.url, "/v1/keysets"),
)
# BEGIN backwards compatibility < 0.15.0
# assume the mint has not upgraded yet if we get a 404
if resp.status_code == 404:
ret = await self._get_keysets_deprecated(self.url)
return ret
# END backwards compatibility < 0.15.0
self.raise_on_error_request(resp)
keysets_dict = resp.json()
keysets = KeysetsResponse.parse_obj(keysets_dict).keysets
if not keysets:
raise Exception("did not receive any keysets")
return keysets
@async_set_httpx_client
async def _get_info(self) -> GetInfoResponse:
"""API that gets the mint info.
Returns:
GetInfoResponse: Current mint info
Raises:
Exception: If the mint info request fails
"""
resp = await self.httpx.get(
join(self.url, "/v1/info"),
)
# BEGIN backwards compatibility < 0.15.0
# assume the mint has not upgraded yet if we get a 404
if resp.status_code == 404:
ret = await self._get_info_deprecated()
return ret
# END backwards compatibility < 0.15.0
self.raise_on_error_request(resp)
data: dict = resp.json()
mint_info: GetInfoResponse = GetInfoResponse.parse_obj(data)
return mint_info
@async_set_httpx_client
@async_ensure_mint_loaded
async def mint_quote(self, amount: int, unit: Unit) -> PostMintQuoteResponse:
"""Requests a mint quote from the server and returns a payment request.
Args:
amount (int): Amount of tokens to mint
Returns:
PostMintQuoteResponse: Mint Quote Response
Raises:
Exception: If the mint request fails
"""
logger.trace("Requesting mint: GET /v1/mint/bolt11")
payload = PostMintQuoteRequest(unit=unit.name, amount=amount)
resp = await self.httpx.post(
join(self.url, "/v1/mint/quote/bolt11"), json=payload.dict()
)
# BEGIN backwards compatibility < 0.15.0
# assume the mint has not upgraded yet if we get a 404
if resp.status_code == 404:
ret = await self.request_mint_deprecated(amount)
return ret
# END backwards compatibility < 0.15.0
self.raise_on_error_request(resp)
return_dict = resp.json()
return PostMintQuoteResponse.parse_obj(return_dict)
@async_set_httpx_client
@async_ensure_mint_loaded
async def mint(
self, outputs: List[BlindedMessage], quote: str
) -> List[BlindedSignature]:
"""Mints new coins and returns a proof of promise.
Args:
outputs (List[BlindedMessage]): Outputs to mint new tokens with
quote (str): Quote ID.
Returns:
list[Proof]: List of proofs.
Raises:
Exception: If the minting fails
"""
outputs_payload = PostMintRequest(outputs=outputs, quote=quote)
logger.trace("Checking Lightning invoice. POST /v1/mint/bolt11")
def _mintrequest_include_fields(outputs: List[BlindedMessage]):
"""strips away fields from the model that aren't necessary for the /mint"""
outputs_include = {"id", "amount", "B_"}
return {
"quote": ...,
"outputs": {i: outputs_include for i in range(len(outputs))},
}
payload = outputs_payload.dict(include=_mintrequest_include_fields(outputs)) # type: ignore
resp = await self.httpx.post(
join(self.url, "/v1/mint/bolt11"),
json=payload, # type: ignore
)
# BEGIN backwards compatibility < 0.15.0
# assume the mint has not upgraded yet if we get a 404
if resp.status_code == 404:
ret = await self.mint_deprecated(outputs, quote)
return ret
# END backwards compatibility < 0.15.0
self.raise_on_error_request(resp)
response_dict = resp.json()
logger.trace("Lightning invoice checked. POST /v1/mint/bolt11")
promises = PostMintResponse.parse_obj(response_dict).signatures
return promises
@async_set_httpx_client
@async_ensure_mint_loaded
async def melt_quote(
self, payment_request: str, unit: Unit, amount: Optional[int] = None
) -> PostMeltQuoteResponse:
"""Checks whether the Lightning payment is internal."""
invoice_obj = bolt11.decode(payment_request)
assert invoice_obj.amount_msat, "invoice must have amount"
payload = PostMeltQuoteRequest(
unit=unit.name, request=payment_request, amount=amount
)
resp = await self.httpx.post(
join(self.url, "/v1/melt/quote/bolt11"),
json=payload.dict(),
)
# BEGIN backwards compatibility < 0.15.0
# assume the mint has not upgraded yet if we get a 404
if resp.status_code == 404:
ret: CheckFeesResponse_deprecated = await self.check_fees_deprecated(
payment_request
)
quote_id = "deprecated_" + str(uuid.uuid4())
return PostMeltQuoteResponse(
quote=quote_id,
amount=amount or invoice_obj.amount_msat // 1000,
fee_reserve=ret.fee or 0,
paid=False,
expiry=invoice_obj.expiry,
)
# END backwards compatibility < 0.15.0
self.raise_on_error_request(resp)
return_dict = resp.json()
return PostMeltQuoteResponse.parse_obj(return_dict)
@async_set_httpx_client
@async_ensure_mint_loaded
async def melt(
self,
quote: str,
proofs: List[Proof],
outputs: Optional[List[BlindedMessage]],
) -> PostMeltResponse:
"""
Accepts proofs and a lightning invoice to pay in exchange.
"""
payload = PostMeltRequest(quote=quote, inputs=proofs, outputs=outputs)
def _meltrequest_include_fields(
proofs: List[Proof], outputs: List[BlindedMessage]
):
"""strips away fields from the model that aren't necessary for the /melt"""
proofs_include = {"id", "amount", "secret", "C", "witness"}
outputs_include = {"id", "amount", "B_"}
return {
"quote": ...,
"inputs": {i: proofs_include for i in range(len(proofs))},
"outputs": {i: outputs_include for i in range(len(outputs))},
}
resp = await self.httpx.post(
join(self.url, "/v1/melt/bolt11"),
json=payload.dict(include=_meltrequest_include_fields(proofs, outputs)), # type: ignore
timeout=None,
)
# BEGIN backwards compatibility < 0.15.0
# assume the mint has not upgraded yet if we get a 404
if resp.status_code == 404:
invoice = await get_lightning_invoice(id=quote, db=self.db)
assert invoice, f"no invoice found for id {quote}"
ret: PostMeltResponse_deprecated = await self.melt_deprecated(
proofs=proofs, outputs=outputs, invoice=invoice.bolt11
)
return PostMeltResponse(
paid=ret.paid, payment_preimage=ret.preimage, change=ret.change
)
# END backwards compatibility < 0.15.0
self.raise_on_error_request(resp)
return_dict = resp.json()
return PostMeltResponse.parse_obj(return_dict)
@async_set_httpx_client
@async_ensure_mint_loaded
async def split(
self,
proofs: List[Proof],
outputs: List[BlindedMessage],
) -> List[BlindedSignature]:
"""Consume proofs and create new promises based on amount split."""
logger.debug("Calling split. POST /v1/swap")
split_payload = PostSplitRequest(inputs=proofs, outputs=outputs)
# construct payload
def _splitrequest_include_fields(proofs: List[Proof]):
"""strips away fields from the model that aren't necessary for /v1/swap"""
proofs_include = {
"id",
"amount",
"secret",
"C",
"witness",
}
return {
"outputs": ...,
"inputs": {i: proofs_include for i in range(len(proofs))},
}
resp = await self.httpx.post(
join(self.url, "/v1/swap"),
json=split_payload.dict(include=_splitrequest_include_fields(proofs)), # type: ignore
)
# BEGIN backwards compatibility < 0.15.0
# assume the mint has not upgraded yet if we get a 404
if resp.status_code == 404:
ret = await self.split_deprecated(proofs, outputs)
return ret
# END backwards compatibility < 0.15.0
self.raise_on_error_request(resp)
promises_dict = resp.json()
mint_response = PostSplitResponse.parse_obj(promises_dict)
promises = [BlindedSignature(**p.dict()) for p in mint_response.signatures]
if len(promises) == 0:
raise Exception("received no splits.")
return promises
@async_set_httpx_client
@async_ensure_mint_loaded
async def check_proof_state(self, proofs: List[Proof]) -> PostCheckStateResponse:
"""
Checks whether the secrets in proofs are already spent or not and returns a list of booleans.
"""
payload = PostCheckStateRequest(Ys=[p.Y for p in proofs])
resp = await self.httpx.post(
join(self.url, "/v1/checkstate"),
json=payload.dict(),
)
# BEGIN backwards compatibility < 0.15.0
# assume the mint has not upgraded yet if we get a 404
if resp.status_code == 404:
ret = await self.check_proof_state_deprecated(proofs)
# convert CheckSpendableResponse_deprecated to CheckSpendableResponse
states: List[ProofState] = []
for spendable, pending, p in zip(ret.spendable, ret.pending, proofs):
if spendable and not pending:
states.append(ProofState(Y=p.Y, state=SpentState.unspent))
elif spendable and pending:
states.append(ProofState(Y=p.Y, state=SpentState.pending))
else:
states.append(ProofState(Y=p.Y, state=SpentState.spent))
ret = PostCheckStateResponse(states=states)
return ret
# END backwards compatibility < 0.15.0
self.raise_on_error_request(resp)
return PostCheckStateResponse.parse_obj(resp.json())
@async_set_httpx_client
@async_ensure_mint_loaded
async def restore_promises(
self, outputs: List[BlindedMessage]
) -> Tuple[List[BlindedMessage], List[BlindedSignature]]:
"""
Asks the mint to restore promises corresponding to outputs.
"""
payload = PostMintRequest(quote="restore", outputs=outputs)
resp = await self.httpx.post(join(self.url, "/v1/restore"), json=payload.dict())
# BEGIN backwards compatibility < 0.15.0
# assume the mint has not upgraded yet if we get a 404
if resp.status_code == 404:
ret = await self.restore_promises_deprecated(outputs)
return ret
# END backwards compatibility < 0.15.0
self.raise_on_error_request(resp)
response_dict = resp.json()
returnObj = PostRestoreResponse.parse_obj(response_dict)
# BEGIN backwards compatibility < 0.15.1
# if the mint returns promises, duplicate into signatures
if returnObj.promises:
returnObj.signatures = returnObj.promises
# END backwards compatibility < 0.15.1
return returnObj.outputs, returnObj.signatures

File diff suppressed because it is too large Load Diff

View File

@@ -10,6 +10,11 @@ from ..core.base import (
BlindedMessage, BlindedMessage,
BlindedMessage_Deprecated, BlindedMessage_Deprecated,
BlindedSignature, BlindedSignature,
Proof,
WalletKeyset,
)
from ..core.crypto.secp import PublicKey
from ..core.models import (
CheckFeesRequest_deprecated, CheckFeesRequest_deprecated,
CheckFeesResponse_deprecated, CheckFeesResponse_deprecated,
CheckSpendableRequest_deprecated, CheckSpendableRequest_deprecated,
@@ -18,6 +23,7 @@ from ..core.base import (
GetInfoResponse_deprecated, GetInfoResponse_deprecated,
GetMintResponse_deprecated, GetMintResponse_deprecated,
KeysetsResponse_deprecated, KeysetsResponse_deprecated,
KeysetsResponseKeyset,
PostMeltRequest_deprecated, PostMeltRequest_deprecated,
PostMeltResponse_deprecated, PostMeltResponse_deprecated,
PostMintQuoteResponse, PostMintQuoteResponse,
@@ -26,10 +32,7 @@ from ..core.base import (
PostRestoreResponse, PostRestoreResponse,
PostSplitRequest_Deprecated, PostSplitRequest_Deprecated,
PostSplitResponse_Deprecated, PostSplitResponse_Deprecated,
Proof,
WalletKeyset,
) )
from ..core.crypto.secp import PublicKey
from ..core.settings import settings from ..core.settings import settings
from ..tor.tor import TorProxy from ..tor.tor import TorProxy
from .protocols import SupportsHttpxClient, SupportsMintURL from .protocols import SupportsHttpxClient, SupportsMintURL
@@ -78,7 +81,7 @@ def async_ensure_mint_loaded_deprecated(func):
async def wrapper(self, *args, **kwargs): async def wrapper(self, *args, **kwargs):
if not self.keysets: if not self.keysets:
await self._load_mint() await self.load_mint()
return await func(self, *args, **kwargs) return await func(self, *args, **kwargs)
return wrapper return wrapper
@@ -164,9 +167,7 @@ class LedgerAPIDeprecated(SupportsHttpxClient, SupportsMintURL):
return keyset return keyset
@async_set_httpx_client @async_set_httpx_client
async def _get_keys_of_keyset_deprecated( async def _get_keyset_deprecated(self, url: str, keyset_id: str) -> WalletKeyset:
self, url: str, keyset_id: str
) -> WalletKeyset:
"""API that gets the keys of a specific keyset from the mint. """API that gets the keys of a specific keyset from the mint.
@@ -201,8 +202,7 @@ class LedgerAPIDeprecated(SupportsHttpxClient, SupportsMintURL):
return keyset return keyset
@async_set_httpx_client @async_set_httpx_client
@async_ensure_mint_loaded_deprecated async def _get_keysets_deprecated(self, url: str) -> List[KeysetsResponseKeyset]:
async def _get_keyset_ids_deprecated(self, url: str) -> List[str]:
"""API that gets a list of all active keysets of the mint. """API that gets a list of all active keysets of the mint.
Args: Args:
@@ -222,7 +222,11 @@ class LedgerAPIDeprecated(SupportsHttpxClient, SupportsMintURL):
keysets_dict = resp.json() keysets_dict = resp.json()
keysets = KeysetsResponse_deprecated.parse_obj(keysets_dict) keysets = KeysetsResponse_deprecated.parse_obj(keysets_dict)
assert len(keysets.keysets), Exception("did not receive any keysets") assert len(keysets.keysets), Exception("did not receive any keysets")
return keysets.keysets keysets_new = [
KeysetsResponseKeyset(id=id, unit="sat", active=True)
for id in keysets.keysets
]
return keysets_new
@async_set_httpx_client @async_set_httpx_client
@async_ensure_mint_loaded_deprecated @async_ensure_mint_loaded_deprecated

View File

@@ -45,6 +45,7 @@ settings.mint_private_key = "TEST_PRIVATE_KEY"
settings.mint_seed_decryption_key = "" settings.mint_seed_decryption_key = ""
settings.mint_max_balance = 0 settings.mint_max_balance = 0
settings.mint_lnd_enable_mpp = True settings.mint_lnd_enable_mpp = True
settings.mint_input_fee_ppk = 0
assert "test" in settings.cashu_dir assert "test" in settings.cashu_dir
shutil.rmtree(settings.cashu_dir, ignore_errors=True) shutil.rmtree(settings.cashu_dir, ignore_errors=True)

View File

@@ -2,9 +2,10 @@ from typing import List
import pytest import pytest
from cashu.core.base import BlindedMessage, PostMintQuoteRequest, Proof from cashu.core.base import BlindedMessage, Proof
from cashu.core.crypto.b_dhke import step1_alice from cashu.core.crypto.b_dhke import step1_alice
from cashu.core.helpers import calculate_number_of_blank_outputs from cashu.core.helpers import calculate_number_of_blank_outputs
from cashu.core.models import PostMintQuoteRequest
from cashu.core.settings import settings from cashu.core.settings import settings
from cashu.mint.ledger import Ledger from cashu.mint.ledger import Ledger
from tests.helpers import pay_if_regtest from tests.helpers import pay_if_regtest
@@ -129,9 +130,9 @@ async def test_generate_promises(ledger: Ledger):
async def test_generate_change_promises(ledger: Ledger): async def test_generate_change_promises(ledger: Ledger):
# Example slightly adapted from NUT-08 because we want to ensure the dynamic change # Example slightly adapted from NUT-08 because we want to ensure the dynamic change
# token amount works: `n_blank_outputs != n_returned_promises != 4`. # token amount works: `n_blank_outputs != n_returned_promises != 4`.
invoice_amount = 100_000 # invoice_amount = 100_000
fee_reserve = 2_000 fee_reserve = 2_000
total_provided = invoice_amount + fee_reserve # total_provided = invoice_amount + fee_reserve
actual_fee = 100 actual_fee = 100
expected_returned_promises = 7 # Amounts = [4, 8, 32, 64, 256, 512, 1024] expected_returned_promises = 7 # Amounts = [4, 8, 32, 64, 256, 512, 1024]
@@ -149,7 +150,7 @@ async def test_generate_change_promises(ledger: Ledger):
] ]
promises = await ledger._generate_change_promises( promises = await ledger._generate_change_promises(
total_provided, invoice_amount, actual_fee, outputs fee_provided=fee_reserve, fee_paid=actual_fee, outputs=outputs
) )
assert len(promises) == expected_returned_promises assert len(promises) == expected_returned_promises
@@ -160,9 +161,9 @@ async def test_generate_change_promises(ledger: Ledger):
async def test_generate_change_promises_legacy_wallet(ledger: Ledger): async def test_generate_change_promises_legacy_wallet(ledger: Ledger):
# Check if mint handles a legacy wallet implementation (always sends 4 blank # Check if mint handles a legacy wallet implementation (always sends 4 blank
# outputs) as well. # outputs) as well.
invoice_amount = 100_000 # invoice_amount = 100_000
fee_reserve = 2_000 fee_reserve = 2_000
total_provided = invoice_amount + fee_reserve # total_provided = invoice_amount + fee_reserve
actual_fee = 100 actual_fee = 100
expected_returned_promises = 4 # Amounts = [64, 256, 512, 1024] expected_returned_promises = 4 # Amounts = [64, 256, 512, 1024]
@@ -179,9 +180,7 @@ async def test_generate_change_promises_legacy_wallet(ledger: Ledger):
for b, _ in blinded_msgs for b, _ in blinded_msgs
] ]
promises = await ledger._generate_change_promises( promises = await ledger._generate_change_promises(fee_reserve, actual_fee, outputs)
total_provided, invoice_amount, actual_fee, outputs
)
assert len(promises) == expected_returned_promises assert len(promises) == expected_returned_promises
assert sum([promise.amount for promise in promises]) == expected_returned_fees assert sum([promise.amount for promise in promises]) == expected_returned_fees
@@ -189,14 +188,14 @@ async def test_generate_change_promises_legacy_wallet(ledger: Ledger):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_generate_change_promises_returns_empty_if_no_outputs(ledger: Ledger): async def test_generate_change_promises_returns_empty_if_no_outputs(ledger: Ledger):
invoice_amount = 100_000 # invoice_amount = 100_000
fee_reserve = 1_000 fee_reserve = 1_000
total_provided = invoice_amount + fee_reserve # total_provided = invoice_amount + fee_reserve
actual_fee_msat = 100_000 actual_fee_msat = 100_000
outputs = None outputs = None
promises = await ledger._generate_change_promises( promises = await ledger._generate_change_promises(
total_provided, invoice_amount, actual_fee_msat, outputs fee_reserve, actual_fee_msat, outputs
) )
assert len(promises) == 0 assert len(promises) == 0

View File

@@ -3,14 +3,14 @@ import httpx
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from cashu.core.base import ( from cashu.core.base import SpentState
from cashu.core.models import (
GetInfoResponse, GetInfoResponse,
MintMeltMethodSetting, MintMeltMethodSetting,
PostCheckStateRequest, PostCheckStateRequest,
PostCheckStateResponse, PostCheckStateResponse,
PostRestoreRequest, PostRestoreRequest,
PostRestoreResponse, PostRestoreResponse,
SpentState,
) )
from cashu.core.settings import settings from cashu.core.settings import settings
from cashu.mint.ledger import Ledger from cashu.mint.ledger import Ledger
@@ -89,6 +89,7 @@ async def test_api_keysets(ledger: Ledger):
"id": "009a1f293253e41e", "id": "009a1f293253e41e",
"unit": "sat", "unit": "sat",
"active": True, "active": True,
"input_fee_ppk": 0,
}, },
] ]
} }

View File

@@ -2,12 +2,12 @@ import httpx
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from cashu.core.base import ( from cashu.core.base import Proof
from cashu.core.models import (
CheckSpendableRequest_deprecated, CheckSpendableRequest_deprecated,
CheckSpendableResponse_deprecated, CheckSpendableResponse_deprecated,
PostRestoreRequest, PostRestoreRequest,
PostRestoreResponse, PostRestoreResponse,
Proof,
) )
from cashu.mint.ledger import Ledger from cashu.mint.ledger import Ledger
from cashu.wallet.crud import bump_secret_derivation from cashu.wallet.crud import bump_secret_derivation

View File

@@ -1,7 +1,7 @@
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from cashu.core.base import PostMeltQuoteRequest from cashu.core.models import PostMeltQuoteRequest
from cashu.mint.ledger import Ledger from cashu.mint.ledger import Ledger
from cashu.wallet.wallet import Wallet from cashu.wallet.wallet import Wallet
from cashu.wallet.wallet import Wallet as Wallet1 from cashu.wallet.wallet import Wallet as Wallet1

241
tests/test_mint_fees.py Normal file
View File

@@ -0,0 +1,241 @@
from typing import Optional
import pytest
import pytest_asyncio
from cashu.core.helpers import sum_proofs
from cashu.core.models import PostMeltQuoteRequest
from cashu.core.split import amount_split
from cashu.mint.ledger import Ledger
from cashu.wallet.wallet import Wallet
from cashu.wallet.wallet import Wallet as Wallet1
from tests.conftest import SERVER_ENDPOINT
from tests.helpers import get_real_invoice, is_fake, is_regtest, pay_if_regtest
async def assert_err(f, msg):
"""Compute f() and expect an error message 'msg'."""
try:
await f
except Exception as exc:
if msg not in str(exc.args[0]):
raise Exception(f"Expected error: {msg}, got: {exc.args[0]}")
return
raise Exception(f"Expected error: {msg}, got no error")
@pytest_asyncio.fixture(scope="function")
async def wallet1(ledger: Ledger):
wallet1 = await Wallet1.with_db(
url=SERVER_ENDPOINT,
db="test_data/wallet1",
name="wallet1",
)
await wallet1.load_mint()
yield wallet1
def set_ledger_keyset_fees(
fee_ppk: int, ledger: Ledger, wallet: Optional[Wallet] = None
):
for keyset in ledger.keysets.values():
keyset.input_fee_ppk = fee_ppk
if wallet:
for wallet_keyset in wallet.keysets.values():
wallet_keyset.input_fee_ppk = fee_ppk
@pytest.mark.asyncio
async def test_get_fees_for_proofs(wallet1: Wallet, ledger: Ledger):
invoice = await wallet1.request_mint(64)
pay_if_regtest(invoice.bolt11)
await wallet1.mint(64, split=[1] * 64, id=invoice.id)
# two proofs
set_ledger_keyset_fees(100, ledger)
proofs = [wallet1.proofs[0], wallet1.proofs[1]]
fees = ledger.get_fees_for_proofs(proofs)
assert fees == 1
set_ledger_keyset_fees(1234, ledger)
fees = ledger.get_fees_for_proofs(proofs)
assert fees == 3
set_ledger_keyset_fees(0, ledger)
fees = ledger.get_fees_for_proofs(proofs)
assert fees == 0
set_ledger_keyset_fees(1, ledger)
fees = ledger.get_fees_for_proofs(proofs)
assert fees == 1
# ten proofs
ten_proofs = wallet1.proofs[:10]
set_ledger_keyset_fees(100, ledger)
fees = ledger.get_fees_for_proofs(ten_proofs)
assert fees == 1
set_ledger_keyset_fees(101, ledger)
fees = ledger.get_fees_for_proofs(ten_proofs)
assert fees == 2
# three proofs
three_proofs = wallet1.proofs[:3]
set_ledger_keyset_fees(333, ledger)
fees = ledger.get_fees_for_proofs(three_proofs)
assert fees == 1
set_ledger_keyset_fees(334, ledger)
fees = ledger.get_fees_for_proofs(three_proofs)
assert fees == 2
@pytest.mark.asyncio
@pytest.mark.skipif_with_fees(is_regtest, reason="only works with FakeWallet")
async def test_wallet_fee(wallet1: Wallet, ledger: Ledger):
# THIS TEST IS A FAKE, WE SET THE WALLET FEES MANUALLY IN set_ledger_keyset_fees
# It would be better to test if the wallet can get the fees from the mint itself
# but the ledger instance does not update the responses from the `mint` that is running in the background
# so we just pretend here and test really nothing...
# set fees to 100 ppk
set_ledger_keyset_fees(100, ledger, wallet1)
# check if all wallet keysets have the correct fees
for keyset in wallet1.keysets.values():
assert keyset.input_fee_ppk == 100
@pytest.mark.asyncio
async def test_split_with_fees(wallet1: Wallet, ledger: Ledger):
# set fees to 100 ppk
set_ledger_keyset_fees(100, ledger)
invoice = await wallet1.request_mint(64)
pay_if_regtest(invoice.bolt11)
await wallet1.mint(64, id=invoice.id)
send_proofs, _ = await wallet1.select_to_send(wallet1.proofs, 10)
fees = ledger.get_fees_for_proofs(send_proofs)
assert fees == 1
outputs = await wallet1.construct_outputs(amount_split(9))
promises = await ledger.split(proofs=send_proofs, outputs=outputs)
assert len(promises) == len(outputs)
assert [p.amount for p in promises] == [p.amount for p in outputs]
@pytest.mark.asyncio
async def test_split_with_high_fees(wallet1: Wallet, ledger: Ledger):
# set fees to 100 ppk
set_ledger_keyset_fees(1234, ledger)
invoice = await wallet1.request_mint(64)
pay_if_regtest(invoice.bolt11)
await wallet1.mint(64, id=invoice.id)
send_proofs, _ = await wallet1.select_to_send(wallet1.proofs, 10)
fees = ledger.get_fees_for_proofs(send_proofs)
assert fees == 3
outputs = await wallet1.construct_outputs(amount_split(7))
promises = await ledger.split(proofs=send_proofs, outputs=outputs)
assert len(promises) == len(outputs)
assert [p.amount for p in promises] == [p.amount for p in outputs]
@pytest.mark.asyncio
async def test_split_not_enough_fees(wallet1: Wallet, ledger: Ledger):
# set fees to 100 ppk
set_ledger_keyset_fees(100, ledger)
invoice = await wallet1.request_mint(64)
pay_if_regtest(invoice.bolt11)
await wallet1.mint(64, id=invoice.id)
send_proofs, _ = await wallet1.select_to_send(wallet1.proofs, 10)
fees = ledger.get_fees_for_proofs(send_proofs)
assert fees == 1
# with 10 sat input, we request 10 sat outputs but fees are 1 sat so the swap will fail
outputs = await wallet1.construct_outputs(amount_split(10))
await assert_err(
ledger.split(proofs=send_proofs, outputs=outputs), "are not balanced"
)
@pytest.mark.asyncio
@pytest.mark.skipif(is_regtest, reason="only works with FakeWallet")
async def test_melt_internal(wallet1: Wallet, ledger: Ledger):
# set fees to 100 ppk
set_ledger_keyset_fees(100, ledger, wallet1)
# mint twice so we have enough to pay the second invoice back
invoice = await wallet1.request_mint(128)
await wallet1.mint(128, id=invoice.id)
assert wallet1.balance == 128
# create a mint quote so that we can melt to it internally
invoice_to_pay = await wallet1.request_mint(64)
invoice_payment_request = invoice_to_pay.bolt11
melt_quote = await ledger.melt_quote(
PostMeltQuoteRequest(request=invoice_payment_request, unit="sat")
)
assert not melt_quote.paid
assert melt_quote.amount == 64
assert melt_quote.fee_reserve == 0
melt_quote_pre_payment = await ledger.get_melt_quote(melt_quote.quote)
assert not melt_quote_pre_payment.paid, "melt quote should not be paid"
# let's first try to melt without enough funds
send_proofs, fees = await wallet1.select_to_send(wallet1.proofs, 63)
# this should fail because we need 64 + 1 sat fees
assert sum_proofs(send_proofs) == 64
await assert_err(
ledger.melt(proofs=send_proofs, quote=melt_quote.quote),
"not enough inputs provided for melt",
)
# the wallet respects the fees for coin selection
send_proofs, fees = await wallet1.select_to_send(wallet1.proofs, 64)
# includes 1 sat fees
assert sum_proofs(send_proofs) == 65
await ledger.melt(proofs=send_proofs, quote=melt_quote.quote)
melt_quote_post_payment = await ledger.get_melt_quote(melt_quote.quote)
assert melt_quote_post_payment.paid, "melt quote should be paid"
@pytest.mark.asyncio
@pytest.mark.skipif(is_fake, reason="only works with Regtest")
async def test_melt_external_with_fees(wallet1: Wallet, ledger: Ledger):
# set fees to 100 ppk
set_ledger_keyset_fees(100, ledger, wallet1)
# mint twice so we have enough to pay the second invoice back
invoice = await wallet1.request_mint(128)
pay_if_regtest(invoice.bolt11)
await wallet1.mint(128, id=invoice.id)
assert wallet1.balance == 128
invoice_dict = get_real_invoice(64)
invoice_payment_request = invoice_dict["payment_request"]
mint_quote = await wallet1.melt_quote(invoice_payment_request)
total_amount = mint_quote.amount + mint_quote.fee_reserve
send_proofs, fee = await wallet1.select_to_send(wallet1.proofs, total_amount)
melt_quote = await ledger.melt_quote(
PostMeltQuoteRequest(request=invoice_payment_request, unit="sat")
)
melt_quote_pre_payment = await ledger.get_melt_quote(melt_quote.quote)
assert not melt_quote_pre_payment.paid, "melt quote should not be paid"
assert not melt_quote.paid, "melt quote should not be paid"
await ledger.melt(proofs=send_proofs, quote=melt_quote.quote)
melt_quote_post_payment = await ledger.get_melt_quote(melt_quote.quote)
assert melt_quote_post_payment.paid, "melt quote should be paid"

View File

@@ -2,7 +2,8 @@ import pytest
import respx import respx
from httpx import Response from httpx import Response
from cashu.core.base import Amount, MeltQuote, PostMeltQuoteRequest, Unit from cashu.core.base import Amount, MeltQuote, Unit
from cashu.core.models import PostMeltQuoteRequest
from cashu.core.settings import settings from cashu.core.settings import settings
from cashu.lightning.blink import MINIMUM_FEE_MSAT, BlinkWallet # type: ignore from cashu.lightning.blink import MINIMUM_FEE_MSAT, BlinkWallet # type: ignore

View File

@@ -1,8 +1,8 @@
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from cashu.core.base import PostMeltQuoteRequest, PostMintQuoteRequest
from cashu.core.helpers import sum_proofs from cashu.core.helpers import sum_proofs
from cashu.core.models import PostMeltQuoteRequest, PostMintQuoteRequest
from cashu.mint.ledger import Ledger from cashu.mint.ledger import Ledger
from cashu.wallet.wallet import Wallet from cashu.wallet.wallet import Wallet
from cashu.wallet.wallet import Wallet as Wallet1 from cashu.wallet.wallet import Wallet as Wallet1
@@ -155,6 +155,18 @@ async def test_split(wallet1: Wallet, ledger: Ledger):
assert [p.amount for p in promises] == [p.amount for p in outputs] assert [p.amount for p in promises] == [p.amount for p in outputs]
@pytest.mark.asyncio
async def test_split_with_no_outputs(wallet1: Wallet, ledger: Ledger):
invoice = await wallet1.request_mint(64)
pay_if_regtest(invoice.bolt11)
await wallet1.mint(64, id=invoice.id)
_, send_proofs = await wallet1.split_to_send(wallet1.proofs, 10, set_reserved=False)
await assert_err(
ledger.split(proofs=send_proofs, outputs=[]),
"no outputs provided",
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_split_with_input_less_than_outputs(wallet1: Wallet, ledger: Ledger): async def test_split_with_input_less_than_outputs(wallet1: Wallet, ledger: Ledger):
invoice = await wallet1.request_mint(64) invoice = await wallet1.request_mint(64)
@@ -165,19 +177,19 @@ async def test_split_with_input_less_than_outputs(wallet1: Wallet, ledger: Ledge
wallet1.proofs, 10, set_reserved=False wallet1.proofs, 10, set_reserved=False
) )
all_send_proofs = send_proofs + keep_proofs too_many_proofs = send_proofs + send_proofs
# generate outputs for all proofs, not only the sent ones # generate more outputs than inputs
secrets, rs, derivation_paths = await wallet1.generate_n_secrets( secrets, rs, derivation_paths = await wallet1.generate_n_secrets(
len(all_send_proofs) len(too_many_proofs)
) )
outputs, rs = wallet1._construct_outputs( outputs, rs = wallet1._construct_outputs(
[p.amount for p in all_send_proofs], secrets, rs [p.amount for p in too_many_proofs], secrets, rs
) )
await assert_err( await assert_err(
ledger.split(proofs=send_proofs, outputs=outputs), ledger.split(proofs=send_proofs, outputs=outputs),
"inputs do not have same amount as outputs.", "are not balanced",
) )
# make sure we can still spend our tokens # make sure we can still spend our tokens
@@ -201,7 +213,7 @@ async def test_split_with_input_more_than_outputs(wallet1: Wallet, ledger: Ledge
await assert_err( await assert_err(
ledger.split(proofs=inputs, outputs=outputs), ledger.split(proofs=inputs, outputs=outputs),
"inputs do not have same amount as outputs", "are not balanced",
) )
# make sure we can still spend our tokens # make sure we can still spend our tokens
@@ -216,6 +228,9 @@ async def test_split_twice_with_same_outputs(wallet1: Wallet, ledger: Ledger):
inputs1 = wallet1.proofs[:1] inputs1 = wallet1.proofs[:1]
inputs2 = wallet1.proofs[1:] inputs2 = wallet1.proofs[1:]
assert inputs1[0].amount == 64
assert inputs2[0].amount == 64
output_amounts = [64] output_amounts = [64]
secrets, rs, derivation_paths = await wallet1.generate_n_secrets( secrets, rs, derivation_paths = await wallet1.generate_n_secrets(
len(output_amounts) len(output_amounts)

View File

@@ -42,14 +42,14 @@ async def assert_err(f, msg: Union[str, CashuError]):
def assert_amt(proofs: List[Proof], expected: int): def assert_amt(proofs: List[Proof], expected: int):
"""Assert amounts the proofs contain.""" """Assert amounts the proofs contain."""
assert [p.amount for p in proofs] == expected assert sum([p.amount for p in proofs]) == expected
async def reset_wallet_db(wallet: Wallet): async def reset_wallet_db(wallet: Wallet):
await wallet.db.execute("DELETE FROM proofs") await wallet.db.execute("DELETE FROM proofs")
await wallet.db.execute("DELETE FROM proofs_used") await wallet.db.execute("DELETE FROM proofs_used")
await wallet.db.execute("DELETE FROM keysets") await wallet.db.execute("DELETE FROM keysets")
await wallet._load_mint() await wallet.load_mint()
@pytest_asyncio.fixture(scope="function") @pytest_asyncio.fixture(scope="function")
@@ -97,7 +97,7 @@ async def test_get_keyset(wallet1: Wallet):
# gets the keys of a specific keyset # gets the keys of a specific keyset
assert keyset.id is not None assert keyset.id is not None
assert keyset.public_keys is not None assert keyset.public_keys is not None
keys2 = await wallet1._get_keys_of_keyset(keyset.id) keys2 = await wallet1._get_keyset(keyset.id)
assert keys2.public_keys is not None assert keys2.public_keys is not None
assert len(keyset.public_keys) == len(keys2.public_keys) assert len(keyset.public_keys) == len(keys2.public_keys)
@@ -105,12 +105,12 @@ async def test_get_keyset(wallet1: Wallet):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_keyset_from_db(wallet1: Wallet): async def test_get_keyset_from_db(wallet1: Wallet):
# first load it from the mint # first load it from the mint
# await wallet1._load_mint_keys() # await wallet1.activate_keyset()
# NOTE: conftest already called wallet.load_mint() which got the keys from the mint # NOTE: conftest already called wallet.load_mint() which got the keys from the mint
keyset1 = copy.copy(wallet1.keysets[wallet1.keyset_id]) keyset1 = copy.copy(wallet1.keysets[wallet1.keyset_id])
# then load it from the db # then load it from the db
await wallet1._load_mint_keys() await wallet1.activate_keyset()
keyset2 = copy.copy(wallet1.keysets[wallet1.keyset_id]) keyset2 = copy.copy(wallet1.keysets[wallet1.keyset_id])
assert keyset1.public_keys == keyset2.public_keys assert keyset1.public_keys == keyset2.public_keys
@@ -133,17 +133,17 @@ async def test_get_info(wallet1: Wallet):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_nonexistent_keyset(wallet1: Wallet): async def test_get_nonexistent_keyset(wallet1: Wallet):
await assert_err( await assert_err(
wallet1._get_keys_of_keyset("nonexistent"), wallet1._get_keyset("nonexistent"),
KeysetNotFoundError(), KeysetNotFoundError(),
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_keyset_ids(wallet1: Wallet): async def test_get_keysets(wallet1: Wallet):
keysets = await wallet1._get_keyset_ids() keysets = await wallet1._get_keysets()
assert isinstance(keysets, list) assert isinstance(keysets, list)
assert len(keysets) > 0 assert len(keysets) > 0
assert wallet1.keyset_id in keysets assert wallet1.keyset_id in [k.id for k in keysets]
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -156,6 +156,7 @@ async def test_request_mint(wallet1: Wallet):
async def test_mint(wallet1: Wallet): async def test_mint(wallet1: Wallet):
invoice = await wallet1.request_mint(64) invoice = await wallet1.request_mint(64)
pay_if_regtest(invoice.bolt11) pay_if_regtest(invoice.bolt11)
expected_proof_amounts = wallet1.split_wallet_state(64)
await wallet1.mint(64, id=invoice.id) await wallet1.mint(64, id=invoice.id)
assert wallet1.balance == 64 assert wallet1.balance == 64
@@ -168,7 +169,8 @@ async def test_mint(wallet1: Wallet):
proofs_minted = await get_proofs( proofs_minted = await get_proofs(
db=wallet1.db, mint_id=invoice_db.id, table="proofs" db=wallet1.db, mint_id=invoice_db.id, table="proofs"
) )
assert len(proofs_minted) == 1 assert len(proofs_minted) == len(expected_proof_amounts)
assert all([p.amount in expected_proof_amounts for p in proofs_minted])
assert all([p.mint_id == invoice.id for p in proofs_minted]) assert all([p.mint_id == invoice.id for p in proofs_minted])
@@ -212,11 +214,15 @@ async def test_split(wallet1: Wallet):
pay_if_regtest(invoice.bolt11) pay_if_regtest(invoice.bolt11)
await wallet1.mint(64, id=invoice.id) await wallet1.mint(64, id=invoice.id)
assert wallet1.balance == 64 assert wallet1.balance == 64
# the outputs we keep that we expect after the split
expected_proof_amounts = wallet1.split_wallet_state(44)
p1, p2 = await wallet1.split(wallet1.proofs, 20) p1, p2 = await wallet1.split(wallet1.proofs, 20)
assert wallet1.balance == 64 assert wallet1.balance == 64
assert sum_proofs(p1) == 44 assert sum_proofs(p1) == 44
assert [p.amount for p in p1] == [4, 8, 32] # what we keep should have the expected amounts
assert [p.amount for p in p1] == expected_proof_amounts
assert sum_proofs(p2) == 20 assert sum_proofs(p2) == 20
# what we send should be the optimal split
assert [p.amount for p in p2] == [4, 16] assert [p.amount for p in p2] == [4, 16]
assert all([p.id == wallet1.keyset_id for p in p1]) assert all([p.id == wallet1.keyset_id for p in p1])
assert all([p.id == wallet1.keyset_id for p in p2]) assert all([p.id == wallet1.keyset_id for p in p2])
@@ -227,13 +233,19 @@ async def test_split_to_send(wallet1: Wallet):
invoice = await wallet1.request_mint(64) invoice = await wallet1.request_mint(64)
pay_if_regtest(invoice.bolt11) pay_if_regtest(invoice.bolt11)
await wallet1.mint(64, id=invoice.id) await wallet1.mint(64, id=invoice.id)
keep_proofs, spendable_proofs = await wallet1.split_to_send( assert wallet1.balance == 64
# this will select 32 sats and them (nothing to keep)
keep_proofs, send_proofs = await wallet1.split_to_send(
wallet1.proofs, 32, set_reserved=True wallet1.proofs, 32, set_reserved=True
) )
get_spendable = await wallet1._select_proofs_to_send(wallet1.proofs, 32) assert_amt(send_proofs, 32)
assert keep_proofs == get_spendable assert_amt(keep_proofs, 0)
spendable_proofs = await wallet1._select_proofs_to_send(wallet1.proofs, 32)
assert sum_proofs(spendable_proofs) == 32 assert sum_proofs(spendable_proofs) == 32
assert sum_proofs(send_proofs) == 32
assert wallet1.balance == 64 assert wallet1.balance == 64
assert wallet1.available_balance == 32 assert wallet1.available_balance == 32
@@ -271,7 +283,7 @@ async def test_melt(wallet1: Wallet):
invoice_payment_hash = str(invoice.payment_hash) invoice_payment_hash = str(invoice.payment_hash)
invoice_payment_request = invoice.bolt11 invoice_payment_request = invoice.bolt11
quote = await wallet1.request_melt(invoice_payment_request) quote = await wallet1.melt_quote(invoice_payment_request)
total_amount = quote.amount + quote.fee_reserve total_amount = quote.amount + quote.fee_reserve
if is_regtest: if is_regtest:
@@ -421,7 +433,7 @@ async def test_split_invalid_amount(wallet1: Wallet):
await wallet1.mint(64, id=invoice.id) await wallet1.mint(64, id=invoice.id)
await assert_err( await assert_err(
wallet1.split(wallet1.proofs, -1), wallet1.split(wallet1.proofs, -1),
"amount must be positive.", "amount can't be negative",
) )
@@ -436,13 +448,13 @@ async def test_token_state(wallet1: Wallet):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_load_mint_keys_specific_keyset(wallet1: Wallet): async def testactivate_keyset_specific_keyset(wallet1: Wallet):
await wallet1._load_mint_keys() await wallet1.activate_keyset()
assert list(wallet1.keysets.keys()) == ["009a1f293253e41e"] assert list(wallet1.keysets.keys()) == ["009a1f293253e41e"]
await wallet1._load_mint_keys(keyset_id=wallet1.keyset_id) await wallet1.activate_keyset(keyset_id=wallet1.keyset_id)
await wallet1._load_mint_keys(keyset_id="009a1f293253e41e") await wallet1.activate_keyset(keyset_id="009a1f293253e41e")
# expect deprecated keyset id to be present # expect deprecated keyset id to be present
await assert_err( await assert_err(
wallet1._load_mint_keys(keyset_id="nonexistent"), wallet1.activate_keyset(keyset_id="nonexistent"),
KeysetNotFoundError(), KeysetNotFoundError("nonexistent"),
) )

View File

@@ -65,16 +65,16 @@ async def test_send(wallet: Wallet):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_send_without_split(wallet: Wallet): async def test_send_without_split(wallet: Wallet):
with TestClient(app) as client: with TestClient(app) as client:
response = client.post("/send?amount=2&nosplit=true") response = client.post("/send?amount=2&offline=true")
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["balance"] assert response.json()["balance"]
@pytest.mark.skipif(is_regtest, reason="regtest") @pytest.mark.skipif(is_regtest, reason="regtest")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_send_without_split_but_wrong_amount(wallet: Wallet): async def test_send_too_much(wallet: Wallet):
with TestClient(app) as client: with TestClient(app) as client:
response = client.post("/send?amount=10&nosplit=true") response = client.post("/send?amount=110000")
assert response.status_code == 400 assert response.status_code == 400

View File

@@ -175,6 +175,7 @@ def test_invoice_with_split(mint, cli_prefix):
wallet = asyncio.run(init_wallet()) wallet = asyncio.run(init_wallet())
assert wallet.proof_amounts.count(1) >= 10 assert wallet.proof_amounts.count(1) >= 10
@pytest.mark.skipif(not is_fake, reason="only on fakewallet") @pytest.mark.skipif(not is_fake, reason="only on fakewallet")
def test_invoices_with_minting(cli_prefix): def test_invoices_with_minting(cli_prefix):
# arrange # arrange
@@ -223,6 +224,7 @@ def test_invoices_without_minting(cli_prefix):
assert get_invoice_from_invoices_command(result.output)["ID"] == invoice.id assert get_invoice_from_invoices_command(result.output)["ID"] == invoice.id
assert get_invoice_from_invoices_command(result.output)["Paid"] == str(invoice.paid) assert get_invoice_from_invoices_command(result.output)["Paid"] == str(invoice.paid)
@pytest.mark.skipif(not is_fake, reason="only on fakewallet") @pytest.mark.skipif(not is_fake, reason="only on fakewallet")
def test_invoices_with_onlypaid_option(cli_prefix): def test_invoices_with_onlypaid_option(cli_prefix):
# arrange # arrange
@@ -263,6 +265,7 @@ def test_invoices_with_onlypaid_option_without_minting(cli_prefix):
assert result.exit_code == 0 assert result.exit_code == 0
assert "No invoices found." in result.output assert "No invoices found." in result.output
@pytest.mark.skipif(not is_fake, reason="only on fakewallet") @pytest.mark.skipif(not is_fake, reason="only on fakewallet")
def test_invoices_with_onlyunpaid_option(cli_prefix): def test_invoices_with_onlyunpaid_option(cli_prefix):
# arrange # arrange
@@ -322,6 +325,7 @@ def test_invoices_with_both_onlypaid_and_onlyunpaid_options(cli_prefix):
in result.output in result.output
) )
@pytest.mark.skipif(not is_fake, reason="only on fakewallet") @pytest.mark.skipif(not is_fake, reason="only on fakewallet")
def test_invoices_with_pending_option(cli_prefix): def test_invoices_with_pending_option(cli_prefix):
# arrange # arrange
@@ -422,11 +426,11 @@ def test_send_legacy(mint, cli_prefix):
assert token_str.startswith("eyJwcm9v"), "output is not as expected" assert token_str.startswith("eyJwcm9v"), "output is not as expected"
def test_send_without_split(mint, cli_prefix): def test_send_offline(mint, cli_prefix):
runner = CliRunner() runner = CliRunner()
result = runner.invoke( result = runner.invoke(
cli, cli,
[*cli_prefix, "send", "2", "--nosplit"], [*cli_prefix, "send", "2", "--offline"],
) )
assert result.exception is None assert result.exception is None
print("SEND") print("SEND")
@@ -434,13 +438,13 @@ def test_send_without_split(mint, cli_prefix):
assert "cashuA" in result.output, "output does not have a token" assert "cashuA" in result.output, "output does not have a token"
def test_send_without_split_but_wrong_amount(mint, cli_prefix): def test_send_too_much(mint, cli_prefix):
runner = CliRunner() runner = CliRunner()
result = runner.invoke( result = runner.invoke(
cli, cli,
[*cli_prefix, "send", "10", "--nosplit"], [*cli_prefix, "send", "100000"],
) )
assert "No proof with this amount found" in str(result.exception) assert "balance too low" in str(result.exception)
def test_receive_tokenv3(mint, cli_prefix): def test_receive_tokenv3(mint, cli_prefix):

View File

@@ -37,7 +37,7 @@ async def reset_wallet_db(wallet: LightningWallet):
await wallet.db.execute("DELETE FROM proofs") await wallet.db.execute("DELETE FROM proofs")
await wallet.db.execute("DELETE FROM proofs_used") await wallet.db.execute("DELETE FROM proofs_used")
await wallet.db.execute("DELETE FROM keysets") await wallet.db.execute("DELETE FROM keysets")
await wallet._load_mint() await wallet.load_mint()
@pytest_asyncio.fixture(scope="function") @pytest_asyncio.fixture(scope="function")

View File

@@ -42,7 +42,7 @@ async def reset_wallet_db(wallet: Wallet):
await wallet.db.execute("DELETE FROM proofs") await wallet.db.execute("DELETE FROM proofs")
await wallet.db.execute("DELETE FROM proofs_used") await wallet.db.execute("DELETE FROM proofs_used")
await wallet.db.execute("DELETE FROM keysets") await wallet.db.execute("DELETE FROM keysets")
await wallet._load_mint() await wallet.load_mint()
@pytest_asyncio.fixture(scope="function") @pytest_asyncio.fixture(scope="function")
@@ -206,7 +206,7 @@ async def test_restore_wallet_after_split_to_send(wallet3: Wallet):
wallet3.proofs = [] wallet3.proofs = []
assert wallet3.balance == 0 assert wallet3.balance == 0
await wallet3.restore_promises_from_to(0, 100) await wallet3.restore_promises_from_to(0, 100)
assert wallet3.balance == 64 * 2 assert wallet3.balance == 96
await wallet3.invalidate(wallet3.proofs, check_spendable=True) await wallet3.invalidate(wallet3.proofs, check_spendable=True)
assert wallet3.balance == 64 assert wallet3.balance == 64
@@ -233,7 +233,7 @@ async def test_restore_wallet_after_send_and_receive(wallet3: Wallet, wallet2: W
assert wallet3.proofs == [] assert wallet3.proofs == []
assert wallet3.balance == 0 assert wallet3.balance == 0
await wallet3.restore_promises_from_to(0, 100) await wallet3.restore_promises_from_to(0, 100)
assert wallet3.balance == 64 + 2 * 32 assert wallet3.balance == 96
await wallet3.invalidate(wallet3.proofs, check_spendable=True) await wallet3.invalidate(wallet3.proofs, check_spendable=True)
assert wallet3.balance == 32 assert wallet3.balance == 32
@@ -276,7 +276,7 @@ async def test_restore_wallet_after_send_and_self_receive(wallet3: Wallet):
assert wallet3.proofs == [] assert wallet3.proofs == []
assert wallet3.balance == 0 assert wallet3.balance == 0
await wallet3.restore_promises_from_to(0, 100) await wallet3.restore_promises_from_to(0, 100)
assert wallet3.balance == 64 + 2 * 32 + 32 assert wallet3.balance == 128
await wallet3.invalidate(wallet3.proofs, check_spendable=True) await wallet3.invalidate(wallet3.proofs, check_spendable=True)
assert wallet3.balance == 64 assert wallet3.balance == 64
@@ -311,7 +311,7 @@ async def test_restore_wallet_after_send_twice(
assert wallet3.balance == 0 assert wallet3.balance == 0
await wallet3.restore_promises_from_to(0, 10) await wallet3.restore_promises_from_to(0, 10)
box.add(wallet3.proofs) box.add(wallet3.proofs)
assert wallet3.balance == 5 assert wallet3.balance == 4
await wallet3.invalidate(wallet3.proofs, check_spendable=True) await wallet3.invalidate(wallet3.proofs, check_spendable=True)
assert wallet3.balance == 2 assert wallet3.balance == 2
@@ -333,7 +333,7 @@ async def test_restore_wallet_after_send_twice(
assert wallet3.balance == 0 assert wallet3.balance == 0
await wallet3.restore_promises_from_to(0, 15) await wallet3.restore_promises_from_to(0, 15)
box.add(wallet3.proofs) box.add(wallet3.proofs)
assert wallet3.balance == 7 assert wallet3.balance == 6
await wallet3.invalidate(wallet3.proofs, check_spendable=True) await wallet3.invalidate(wallet3.proofs, check_spendable=True)
assert wallet3.balance == 2 assert wallet3.balance == 2
@@ -370,7 +370,7 @@ async def test_restore_wallet_after_send_and_self_receive_nonquadratic_value(
assert wallet3.balance == 0 assert wallet3.balance == 0
await wallet3.restore_promises_from_to(0, 20) await wallet3.restore_promises_from_to(0, 20)
box.add(wallet3.proofs) box.add(wallet3.proofs)
assert wallet3.balance == 138 assert wallet3.balance == 84
await wallet3.invalidate(wallet3.proofs, check_spendable=True) await wallet3.invalidate(wallet3.proofs, check_spendable=True)
assert wallet3.balance == 64 assert wallet3.balance == 64
@@ -389,6 +389,6 @@ async def test_restore_wallet_after_send_and_self_receive_nonquadratic_value(
assert wallet3.proofs == [] assert wallet3.proofs == []
assert wallet3.balance == 0 assert wallet3.balance == 0
await wallet3.restore_promises_from_to(0, 50) await wallet3.restore_promises_from_to(0, 50)
assert wallet3.balance == 182 assert wallet3.balance == 108
await wallet3.invalidate(wallet3.proofs, check_spendable=True) await wallet3.invalidate(wallet3.proofs, check_spendable=True)
assert wallet3.balance == 64 assert wallet3.balance == 64