diff --git a/cashu/core/base.py b/cashu/core/base.py index e530065..c54a610 100644 --- a/cashu/core/base.py +++ b/cashu/core/base.py @@ -340,9 +340,7 @@ class MeltQuote(LedgerEvent): ) @classmethod - def from_resp_wallet( - cls, melt_quote_resp, mint: str, amount: int, unit: str, request: str - ): + def from_resp_wallet(cls, melt_quote_resp, mint: str, unit: str, request: str): # BEGIN: BACKWARDS COMPATIBILITY < 0.16.0: "paid" field to "state" if melt_quote_resp.state is None: if melt_quote_resp.paid is True: @@ -353,10 +351,12 @@ class MeltQuote(LedgerEvent): return cls( quote=melt_quote_resp.quote, method="bolt11", - request=request, + request=melt_quote_resp.request + or request, # BACKWARDS COMPATIBILITY mint response < 0.16.6 checking_id="", - unit=unit, - amount=amount, + unit=melt_quote_resp.unit + or unit, # BACKWARDS COMPATIBILITY mint response < 0.16.6 + amount=melt_quote_resp.amount, fee_reserve=melt_quote_resp.fee_reserve, state=MeltQuoteState(melt_quote_resp.state), mint=mint, @@ -468,8 +468,10 @@ class MintQuote(LedgerEvent): method="bolt11", request=mint_quote_resp.request, checking_id="", - unit=unit, - amount=amount, + unit=mint_quote_resp.unit + or unit, # BACKWARDS COMPATIBILITY mint response < 0.16.6 + amount=mint_quote_resp.amount + or amount, # BACKWARDS COMPATIBILITY mint response < 0.16.6 state=MintQuoteState(mint_quote_resp.state), mint=mint, expiry=mint_quote_resp.expiry, diff --git a/cashu/core/models.py b/cashu/core/models.py index bc6636b..70ca8d2 100644 --- a/cashu/core/models.py +++ b/cashu/core/models.py @@ -142,8 +142,12 @@ class PostMintQuoteRequest(BaseModel): class PostMintQuoteResponse(BaseModel): quote: str # quote id request: str # input payment request - amount: Optional[int] # output amount (optional for backwards compat pre 0.16.6) - unit: Optional[str] # output unit (optional for backwards compat pre 0.16.6) + amount: Optional[ + int + ] # output amount (optional for BACKWARDS COMPAT mint response < 0.16.6) + unit: Optional[ + str + ] # output unit (optional for BACKWARDS COMPAT mint response < 0.16.6) state: Optional[str] # state of the quote (optional for backwards compat) expiry: Optional[int] # expiry of the quote pubkey: Optional[str] = None # NUT-20 quote lock pubkey @@ -225,10 +229,12 @@ class PostMeltQuoteRequest(BaseModel): class PostMeltQuoteResponse(BaseModel): quote: str # quote id amount: int # input amount - unit: Optional[str] # input unit (optional for backwards compat pre 0.16.6) + unit: Optional[ + str + ] # input unit (optional for BACKWARDS COMPAT mint response < 0.16.6) request: Optional[ str - ] # output payment request (optional for backwards compat pre 0.16.6) + ] # output payment request (optional for BACKWARDS COMPAT mint response < 0.16.6) fee_reserve: int # input fee reserve paid: Optional[bool] = ( None # whether the request has been paid # DEPRECATED as per NUT PR #136 diff --git a/cashu/core/settings.py b/cashu/core/settings.py index ece06e1..311780c 100644 --- a/cashu/core/settings.py +++ b/cashu/core/settings.py @@ -199,7 +199,7 @@ class WalletSettings(CashuSettings): ) locktime_delta_seconds: int = Field(default=86400) # 1 day - proofs_batch_size: int = Field(default=1000) + proofs_batch_size: int = Field(default=200) wallet_target_amount_count: int = Field(default=3) diff --git a/cashu/wallet/cli/cli.py b/cashu/wallet/cli/cli.py index 4b81fbe..a75ed6c 100644 --- a/cashu/wallet/cli/cli.py +++ b/cashu/wallet/cli/cli.py @@ -456,8 +456,8 @@ async def invoice( while time.time() < check_until and not paid: await asyncio.sleep(5) try: - mint_quote_resp = await wallet.get_mint_quote(mint_quote.quote) - if mint_quote_resp.state == MintQuoteState.paid.value: + mint_quote = await wallet.get_mint_quote(mint_quote.quote) + if mint_quote.state == MintQuoteState.paid: await wallet.mint( amount, split=optional_split, @@ -527,8 +527,8 @@ async def swap(ctx: Context): mint_quote = await incoming_wallet.request_mint(amount) # pay invoice from outgoing mint - melt_quote_resp = await outgoing_wallet.melt_quote(mint_quote.request) - total_amount = melt_quote_resp.amount + melt_quote_resp.fee_reserve + melt_quote = await outgoing_wallet.melt_quote(mint_quote.request) + total_amount = melt_quote.amount + melt_quote.fee_reserve if outgoing_wallet.available_balance < total_amount: raise Exception("balance too low") send_proofs, fees = await outgoing_wallet.select_to_send( @@ -537,8 +537,8 @@ async def swap(ctx: Context): await outgoing_wallet.melt( send_proofs, mint_quote.request, - melt_quote_resp.fee_reserve, - melt_quote_resp.quote, + melt_quote.fee_reserve, + melt_quote.quote, ) # mint token in incoming mint diff --git a/cashu/wallet/crud.py b/cashu/wallet/crud.py index e52941f..bc34644 100644 --- a/cashu/wallet/crud.py +++ b/cashu/wallet/crud.py @@ -14,6 +14,13 @@ from ..core.base import ( from ..core.db import Connection, Database +class _UnsetType: + pass + + +_UNSET = _UnsetType() + + async def store_proof( proof: Proof, db: Database, @@ -121,31 +128,31 @@ async def invalidate_proof( async def update_proof( proof: Proof, *, - reserved: Optional[bool] = None, - send_id: Optional[str] = None, - mint_id: Optional[str] = None, - melt_id: Optional[str] = None, + reserved: bool | _UnsetType = _UNSET, + send_id: str | None | _UnsetType = _UNSET, + mint_id: str | None | _UnsetType = _UNSET, + melt_id: str | None | _UnsetType = _UNSET, db: Optional[Database] = None, conn: Optional[Connection] = None, ) -> None: clauses = [] values: Dict[str, Any] = {} - if reserved is not None: + if reserved is not _UNSET: clauses.append("reserved = :reserved") values["reserved"] = reserved clauses.append("time_reserved = :time_reserved") values["time_reserved"] = int(time.time()) - if send_id is not None: + if send_id is not _UNSET: clauses.append("send_id = :send_id") values["send_id"] = send_id - if mint_id is not None: + if mint_id is not _UNSET: clauses.append("mint_id = :mint_id") values["mint_id"] = mint_id - if melt_id is not None: + if melt_id is not _UNSET: clauses.append("melt_id = :melt_id") values["melt_id"] = melt_id diff --git a/cashu/wallet/helpers.py b/cashu/wallet/helpers.py index a574b91..6b7f1b9 100644 --- a/cashu/wallet/helpers.py +++ b/cashu/wallet/helpers.py @@ -172,7 +172,7 @@ async def send( print(token) - await wallet.set_reserved(send_proofs, reserved=True) + await wallet.set_reserved_for_send(send_proofs, reserved=True) return wallet.available_balance, token diff --git a/cashu/wallet/transactions.py b/cashu/wallet/transactions.py index 3cde2c4..350ef6b 100644 --- a/cashu/wallet/transactions.py +++ b/cashu/wallet/transactions.py @@ -223,7 +223,23 @@ class WalletTransactions(SupportsDb, SupportsKeysets): return keep_amounts, send_amounts - async def set_reserved(self, proofs: List[Proof], reserved: bool) -> None: + async def set_reserved_for_melt( + self, proofs: List[Proof], reserved: bool, quote_id: str | None + ): + """Sets the proofs as pending for a melt operation (reserved=True) or as not pending (reserved=False) anymore. + + Args: + proofs (List[Proof]): _description_ + reserved (bool): _description_ + quote_id (str | None): _description_ + """ + async with self.db.connect() as conn: + for p in proofs: + p.melt_id = quote_id + p.reserved = reserved + await update_proof(p, reserved=reserved, melt_id=quote_id, conn=conn) + + async def set_reserved_for_send(self, proofs: List[Proof], reserved: bool) -> None: """Mark a proof as reserved or reset it in the wallet db to avoid reuse when it is sent. Args: diff --git a/cashu/wallet/wallet.py b/cashu/wallet/wallet.py index df96156..d3bcfff 100644 --- a/cashu/wallet/wallet.py +++ b/cashu/wallet/wallet.py @@ -44,6 +44,7 @@ from . import migrations from .compat import WalletCompat from .crud import ( bump_secret_derivation, + get_bolt11_melt_quote, get_bolt11_mint_quote, get_keysets, get_mint_by_url, @@ -519,6 +520,42 @@ class Wallet( await store_bolt11_mint_quote(db=self.db, quote=quote) return quote + async def get_mint_quote( + self, + quote_id: str, + ) -> MintQuote: + """Get a mint quote from mint. + + Args: + quote_id (str): Id of the mint quote. + + Returns: + MintQuote: Mint quote. + """ + mint_quote_response = await super().get_mint_quote(quote_id) + mint_quote_local = await get_bolt11_mint_quote(db=self.db, quote=quote_id) + mint_quote = MintQuote.from_resp_wallet( + mint_quote_response, + mint=self.url, + amount=( + mint_quote_response.amount or mint_quote_local.amount + if mint_quote_local + else 0 # BACKWARD COMPATIBILITY mint response < 0.16.6 + ), + unit=( + mint_quote_response.unit or mint_quote_local.unit + if mint_quote_local + else self.unit.name # BACKWARD COMPATIBILITY mint response < 0.16.6 + ), + ) + if mint_quote_local and mint_quote_local.privkey: + mint_quote.privkey = mint_quote_local.privkey + + if not mint_quote_local: + await store_bolt11_mint_quote(db=self.db, quote=mint_quote) + + return mint_quote + async def mint( self, amount: int, @@ -701,7 +738,7 @@ class Wallet( async def melt_quote( self, invoice: str, amount_msat: Optional[int] = None - ) -> PostMeltQuoteResponse: + ) -> MeltQuote: """ Fetches a melt quote from the mint and either uses the amount in the invoice or the amount provided. """ @@ -714,12 +751,77 @@ class Wallet( melt_quote = MeltQuote.from_resp_wallet( melt_quote_resp, self.url, - amount=melt_quote_resp.amount, unit=self.unit.name, request=invoice, ) await store_bolt11_melt_quote(db=self.db, quote=melt_quote) - return melt_quote_resp + melt_quote = MeltQuote.from_resp_wallet( + melt_quote_resp, + self.url, + unit=melt_quote_resp.unit + or self.unit.name, # BACKWARD COMPATIBILITY mint response < 0.16.6 + request=melt_quote_resp.request + or invoice, # BACKWARD COMPATIBILITY mint response < 0.16.6 + ) + return melt_quote + + async def get_melt_quote(self, quote: str) -> Optional[MeltQuote]: + """Fetches a melt quote from the mint and updates proofs in the database. + + Args: + quote (str): Quote ID to fetch. + + Returns: + Optional[MeltQuote]: MeltQuote object. + """ + melt_quote_resp = await super().get_melt_quote(quote) + melt_quote_local = await get_bolt11_melt_quote(db=self.db, quote=quote) + melt_quote = MeltQuote.from_resp_wallet( + melt_quote_resp, + self.url, + unit=( + melt_quote_resp.unit or melt_quote_local.unit + if melt_quote_local + else self.unit.name # BACKWARD COMPATIBILITY mint response < 0.16.6 + ), + request=( + melt_quote_resp.request or melt_quote_local.request + if (melt_quote_local and melt_quote_local.request) + else "None" # BACKWARD COMPATIBILITY mint response < 0.16.6 + ), + ) + + # update database + if not melt_quote_local: + await store_bolt11_melt_quote(db=self.db, quote=melt_quote) + else: + proofs = await get_proofs(db=self.db, melt_id=quote) + if ( + melt_quote.state == MeltQuoteState.paid + and melt_quote_local.state != MeltQuoteState.paid + ): + logger.debug("Updating paid status of melt quote.") + await update_bolt11_melt_quote( + db=self.db, + quote=quote, + state=melt_quote.state, + paid_time=int(time.time()), + payment_preimage=melt_quote.payment_preimage or "", + fee_paid=melt_quote.fee_paid, + ) + # invalidate proofs + if sum_proofs(proofs) == melt_quote.amount + melt_quote.fee_reserve: + await self.invalidate(proofs) + + if melt_quote.change: + logger.warning( + "Melt quote contains change but change is not supported yet." + ) + + if melt_quote.state == MeltQuoteState.unpaid: + logger.debug("Updating unpaid status of melt quote.") + await self.set_reserved_for_melt(proofs, reserved=False, quote_id=None) + return melt_quote async def melt( self, proofs: List[Proof], invoice: str, fee_reserve_sat: int, quote_id: str @@ -732,9 +834,9 @@ class Wallet( fee_reserve_sat (int): Amount of fees to be reserved for the payment. """ + # Make sure we're operating on an independent copy of proofs proofs = copy.copy(proofs) - amount = sum_proofs(proofs) # Generate a number of blank outputs for any overpaid fees. As described in # NUT-08, the mint will imprint these outputs with a value depending on the @@ -749,29 +851,26 @@ class Wallet( n_change_outputs * [1], change_secrets, change_rs ) + await self.set_reserved_for_melt(proofs, reserved=True, quote_id=quote_id) proofs = self.sign_proofs_inplace_melt(proofs, change_outputs, quote_id) + try: + melt_quote_resp = await super().melt(quote_id, proofs, change_outputs) + except Exception as e: + logger.debug(f"Mint error: {e}") + # remove the melt_id in proofs and set reserved to False + await self.set_reserved_for_melt(proofs, reserved=False, quote_id=None) + raise Exception(f"could not pay invoice: {e}") - # store the melt_id in proofs db - async with self.db.connect() as conn: - for p in proofs: - p.melt_id = quote_id - await update_proof(p, melt_id=quote_id, conn=conn) - - melt_quote_resp = await super().melt(quote_id, proofs, change_outputs) melt_quote = MeltQuote.from_resp_wallet( melt_quote_resp, self.url, - amount=amount, unit=self.unit.name, request=invoice, ) # if payment fails if melt_quote.state == MeltQuoteState.unpaid: # remove the melt_id in proofs and set reserved to False - for p in proofs: - p.melt_id = None - p.reserved = False - await update_proof(p, melt_id="", db=self.db) + await self.set_reserved_for_melt(proofs, reserved=False, quote_id=None) raise Exception("could not pay invoice.") elif melt_quote.state == MeltQuoteState.pending: # payment is still pending @@ -996,6 +1095,27 @@ class Wallet( logger.error(proofs) raise e + async def get_spent_proofs_check_states_batched( + self, proofs: List[Proof] + ) -> List[Proof]: + """Checks the state of proofs in batches. + + Args: + proofs (List[Proof]): List of proofs to check. + + Returns: + List[Proof]: List of proofs that are spent. + """ + batch_size = settings.proofs_batch_size + spent_proofs = [] + for i in range(0, len(proofs), batch_size): + batch = proofs[i : i + batch_size] + proof_states = await self.check_proof_state(batch) + for j, state in enumerate(proof_states.states): + if state.spent: + spent_proofs.append(batch[j]) + return spent_proofs + async def invalidate( self, proofs: List[Proof], check_spendable=False ) -> List[Proof]: @@ -1010,15 +1130,9 @@ class Wallet( """ invalidated_proofs: List[Proof] = [] if check_spendable: - # checks proofs in batches - for _proofs in [ - proofs[i : i + settings.proofs_batch_size] - for i in range(0, len(proofs), settings.proofs_batch_size) - ]: - proof_states = await self.check_proof_state(proofs) - for i, state in enumerate(proof_states.states): - if state.spent: - invalidated_proofs.append(proofs[i]) + invalidated_proofs = await self.get_spent_proofs_check_states_batched( + proofs + ) else: invalidated_proofs = proofs @@ -1100,7 +1214,7 @@ class Wallet( + amount_summary(proofs, self.unit) ) if set_reserved: - await self.set_reserved(send_proofs, reserved=True) + await self.set_reserved_for_send(send_proofs, reserved=True) return send_proofs, fees async def swap_to_send( @@ -1153,7 +1267,7 @@ class Wallet( swap_proofs, amount, secret_lock, include_fees=include_fees ) if set_reserved: - await self.set_reserved(send_proofs, reserved=True) + await self.set_reserved_for_send(send_proofs, reserved=True) return keep_proofs, send_proofs # ---------- BALANCE CHECKS ---------- diff --git a/tests/test_mint_operations.py b/tests/test_mint_operations.py index ac353ae..4312a0c 100644 --- a/tests/test_mint_operations.py +++ b/tests/test_mint_operations.py @@ -1,7 +1,7 @@ import pytest import pytest_asyncio -from cashu.core.base import MeltQuoteState +from cashu.core.base import MeltQuoteState, MintQuoteState from cashu.core.helpers import sum_proofs from cashu.core.models import PostMeltQuoteRequest, PostMintQuoteRequest from cashu.core.nuts import nut20 @@ -59,8 +59,9 @@ async def test_melt_internal(wallet1: Wallet, ledger: Ledger): if not settings.debug_mint_only_deprecated: melt_quote_response_pre_payment = await wallet1.get_melt_quote(melt_quote.quote) + assert melt_quote_response_pre_payment assert ( - not melt_quote_response_pre_payment.state == MeltQuoteState.paid.value + not melt_quote_response_pre_payment.state == MeltQuoteState.paid ), "melt quote should not be paid" assert melt_quote_response_pre_payment.amount == 64 @@ -88,20 +89,21 @@ async def test_melt_external(wallet1: Wallet, ledger: Ledger): invoice_dict = get_real_invoice(64) invoice_payment_request = invoice_dict["payment_request"] - mint_quote = await wallet1.melt_quote(invoice_payment_request) - assert not mint_quote.paid, "mint quote should not be paid" - assert mint_quote.state == MeltQuoteState.unpaid.value + melt_quote = await wallet1.melt_quote(invoice_payment_request) + assert not melt_quote.paid, "mint quote should not be paid" + assert melt_quote.state == MeltQuoteState.unpaid - total_amount = mint_quote.amount + mint_quote.fee_reserve - keep_proofs, send_proofs = await wallet1.swap_to_send(wallet1.proofs, total_amount) + total_amount = melt_quote.amount + melt_quote.fee_reserve + _, send_proofs = await wallet1.swap_to_send(wallet1.proofs, total_amount) melt_quote = await ledger.melt_quote( PostMeltQuoteRequest(request=invoice_payment_request, unit="sat") ) if not settings.debug_mint_only_deprecated: melt_quote_response_pre_payment = await wallet1.get_melt_quote(melt_quote.quote) + assert melt_quote_response_pre_payment assert ( - melt_quote_response_pre_payment.state == MeltQuoteState.unpaid.value + melt_quote_response_pre_payment.state == MeltQuoteState.unpaid ), "melt quote should not be paid" assert melt_quote_response_pre_payment.amount == melt_quote.amount @@ -127,10 +129,8 @@ async def test_mint_internal(wallet1: Wallet, ledger: Ledger): assert mint_quote.paid, "mint quote should be paid" if not settings.debug_mint_only_deprecated: - mint_quote_resp = await wallet1.get_mint_quote(mint_quote.quote) - assert ( - mint_quote_resp.state == MeltQuoteState.paid.value - ), "mint quote should be paid" + mint_quote = await wallet1.get_mint_quote(mint_quote.quote) + assert mint_quote.state == MintQuoteState.paid, "mint quote should be paid" output_amounts = [128] secrets, rs, derivation_paths = await wallet1.generate_n_secrets( @@ -163,8 +163,8 @@ async def test_mint_external(wallet1: Wallet, ledger: Ledger): assert mint_quote.unpaid if not settings.debug_mint_only_deprecated: - mint_quote_resp = await wallet1.get_mint_quote(quote.quote) - assert not mint_quote_resp.paid, "mint quote should not be paid" + mint_quote = await wallet1.get_mint_quote(quote.quote) + assert not mint_quote.paid, "mint quote should not be paid" await assert_err( wallet1.mint(128, quote_id=quote.quote), diff --git a/tests/test_wallet.py b/tests/test_wallet.py index 63f84de..eb2aa9e 100644 --- a/tests/test_wallet.py +++ b/tests/test_wallet.py @@ -4,7 +4,7 @@ from typing import List, Union import pytest import pytest_asyncio -from cashu.core.base import MintQuoteState, Proof +from cashu.core.base import MeltQuote, MeltQuoteState, MintQuoteState, Proof from cashu.core.errors import CashuError, KeysetNotFoundError from cashu.core.helpers import sum_proofs from cashu.core.settings import settings @@ -20,6 +20,7 @@ from cashu.wallet.wallet import Wallet as Wallet2 from tests.conftest import SERVER_ENDPOINT from tests.helpers import ( get_real_invoice, + is_deprecated_api_only, is_fake, is_github_actions, is_regtest, @@ -174,9 +175,9 @@ async def test_mint(wallet1: Wallet): mint_quote = await wallet1.request_mint(64) await pay_if_regtest(mint_quote.request) if not settings.debug_mint_only_deprecated: - quote_resp = await wallet1.get_mint_quote(mint_quote.quote) - assert quote_resp.request == mint_quote.request - assert quote_resp.state == MintQuoteState.paid.value + mint_quote = await wallet1.get_mint_quote(mint_quote.quote) + assert mint_quote.request == mint_quote.request + assert mint_quote.state == MintQuoteState.paid expected_proof_amounts = wallet1.split_wallet_state(64) await wallet1.mint(64, quote_id=mint_quote.quote) @@ -314,6 +315,7 @@ async def test_melt(wallet1: Wallet): if not settings.debug_mint_only_deprecated: quote_resp = await wallet1.get_melt_quote(quote.quote) + assert quote_resp assert quote_resp.amount == quote.amount _, send_proofs = await wallet1.swap_to_send(wallet1.proofs, total_amount) @@ -337,7 +339,19 @@ async def test_melt(wallet1: Wallet): melt_quote_db = await get_bolt11_melt_quote( db=wallet1.db, request=invoice_payment_request ) - assert melt_quote_db, "No invoice in db" + assert melt_quote_db, "No melt quote in db" + + # compare melt quote from API against db + if not settings.debug_mint_only_deprecated: + melt_quote_api_resp = await wallet1.get_melt_quote(melt_quote_db.quote) + assert melt_quote_api_resp, "No melt quote from API" + assert melt_quote_api_resp.quote == melt_quote_db.quote, "Wrong quote ID" + assert melt_quote_api_resp.amount == melt_quote_db.amount, "Wrong amount" + assert melt_quote_api_resp.fee_reserve == melt_quote_db.fee_reserve, "Wrong fee" + assert melt_quote_api_resp.request == melt_quote_db.request, "Wrong request" + assert melt_quote_api_resp.state == melt_quote_db.state, "Wrong state" + assert melt_quote_api_resp.unit == melt_quote_db.unit, "Wrong unit" + proofs_used = await get_proofs( db=wallet1.db, melt_id=melt_quote_db.quote, table="proofs_used" ) @@ -350,6 +364,49 @@ async def test_melt(wallet1: Wallet): assert wallet1.balance == 64, "Wrong balance" +@pytest.mark.asyncio +@pytest.mark.skipif(is_deprecated_api_only, reason="Deprecated API only") +async def test_get_melt_quote_state(wallet1: Wallet): + mint_quote = await wallet1.request_mint(128) + await pay_if_regtest(mint_quote.request) + await wallet1.mint(128, quote_id=mint_quote.quote) + invoice_payment_request = "" + if is_regtest: + invoice_dict = get_real_invoice(64) + invoice_payment_request = invoice_dict["payment_request"] + + if is_fake: + mint_quote = await wallet1.request_mint(64) + invoice_payment_request = mint_quote.request + quote = await wallet1.melt_quote(invoice_payment_request) + assert quote.state == MeltQuoteState.unpaid + assert quote.request == invoice_payment_request + total_amount = quote.amount + quote.fee_reserve + _, send_proofs = await wallet1.swap_to_send(wallet1.proofs, total_amount) + melt_response = await wallet1.melt( + proofs=send_proofs, + invoice=invoice_payment_request, + fee_reserve_sat=quote.fee_reserve, + quote_id=quote.quote, + ) + melt_quote_wallet = MeltQuote.from_resp_wallet( + melt_response, + mint="test", + unit=quote.unit or "sat", + request=quote.request or invoice_payment_request, + ) + + # compare melt quote from API against db + melt_quote_api_resp = await wallet1.get_melt_quote(melt_quote_wallet.quote) + assert melt_quote_api_resp, "No melt quote from API" + assert melt_quote_api_resp.quote == melt_quote_wallet.quote, "Wrong quote ID" + assert melt_quote_api_resp.amount == melt_quote_wallet.amount, "Wrong amount" + assert melt_quote_api_resp.fee_reserve == melt_quote_wallet.fee_reserve, "Wrong fee" + assert melt_quote_api_resp.request == melt_quote_wallet.request, "Wrong request" + assert melt_quote_api_resp.state == melt_quote_wallet.state, "Wrong state" + assert melt_quote_api_resp.unit == melt_quote_wallet.unit, "Wrong unit" + + @pytest.mark.asyncio async def test_swap_to_send_more_than_balance(wallet1: Wallet): mint_quote = await wallet1.request_mint(64) @@ -446,6 +503,24 @@ async def test_invalidate_unspent_proofs_with_checking(wallet1: Wallet): assert wallet1.balance == 64 +@pytest.mark.asyncio +async def test_invalidate_batch_many_proofs(wallet1: Wallet): + """Try to invalidate proofs that have not been spent yet but force no check.""" + amount_to_mint = 500 # nutshell default value is 1000 + mint_quote = await wallet1.request_mint(amount_to_mint) + await pay_if_regtest(mint_quote.request) + proofs = await wallet1.mint( + amount_to_mint, quote_id=mint_quote.quote, split=[1] * amount_to_mint + ) + assert len(proofs) == amount_to_mint + + states = await wallet1.check_proof_state(proofs) + assert all([s.unspent for s in states.states]) + spent_proofs = await wallet1.get_spent_proofs_check_states_batched(proofs) + assert len(spent_proofs) == 0 + assert wallet1.balance == amount_to_mint + + @pytest.mark.asyncio async def test_split_invalid_amount(wallet1: Wallet): mint_quote = await wallet1.request_mint(64) diff --git a/tests/test_wallet_regtest.py b/tests/test_wallet_regtest.py index e4cfc6e..e950dba 100644 --- a/tests/test_wallet_regtest.py +++ b/tests/test_wallet_regtest.py @@ -5,6 +5,7 @@ import pytest import pytest_asyncio from cashu.mint.ledger import Ledger +from cashu.wallet.crud import get_proofs from cashu.wallet.wallet import Wallet from tests.conftest import SERVER_ENDPOINT from tests.helpers import ( @@ -104,3 +105,189 @@ async def test_regtest_failed_quote(wallet: Wallet, ledger: Ledger): states = await wallet.check_proof_state(send_proofs) assert all([s.unspent for s in states.states]) + + +@pytest.mark.asyncio +@pytest.mark.skipif(is_fake, reason="only regtest") +async def test_regtest_get_melt_quote_melt_fail_restore_pending_batch_check( + wallet: Wallet, ledger: Ledger +): + # simulates a payment that fails on the mint and whether the wallet is able to + # restore the state of all proofs (set unreserved) + mint_quote = await wallet.request_mint(64) + await pay_if_regtest(mint_quote.request) + await wallet.mint(64, quote_id=mint_quote.quote) + assert wallet.balance == 64 + + # create hodl invoice + preimage, invoice_dict = get_hold_invoice(16) + invoice_payment_request = str(invoice_dict["payment_request"]) + invoice_obj = bolt11.decode(invoice_payment_request) + preimage_hash = invoice_obj.payment_hash + + # wallet pays the invoice + quote = await wallet.melt_quote(invoice_payment_request) + total_amount = quote.amount + quote.fee_reserve + _, send_proofs = await wallet.swap_to_send( + wallet.proofs, total_amount, set_reserved=True + ) + + # verify that the proofs are reserved + proofs_db = await get_proofs(db=wallet.db, melt_id=quote.quote) + assert all([p.reserved for p in proofs_db]) + + asyncio.create_task( + wallet.melt( + proofs=send_proofs, + invoice=invoice_payment_request, + fee_reserve_sat=quote.fee_reserve, + quote_id=quote.quote, + ) + ) + await asyncio.sleep(SLEEP_TIME) + + states = await wallet.check_proof_state(send_proofs) + assert all([s.pending for s in states.states]) + + # fail the payment, melt will unset the proofs as reserved + cancel_invoice(preimage_hash=preimage_hash) + + await asyncio.sleep(SLEEP_TIME) + + # test get_spent_proofs_check_states_batched: verify that no proofs are spent + spent_proofs = await wallet.get_spent_proofs_check_states_batched(send_proofs) + assert len(spent_proofs) == 0 + + proofs_db_later = await get_proofs(db=wallet.db, melt_id=quote.quote) + assert all([p.reserved is False for p in proofs_db_later]) + + +@pytest.mark.asyncio +@pytest.mark.skipif(is_fake, reason="only regtest") +async def test_regtest_get_melt_quote_wallet_crash_melt_fail_restore_pending_batch_check( + wallet: Wallet, ledger: Ledger +): + # simulates a payment failure but the wallet crashed, we confirm that wallet.get_melt_quote() will correctly + # recover the state of the proofs and set them as unreserved + mint_quote = await wallet.request_mint(64) + await pay_if_regtest(mint_quote.request) + await wallet.mint(64, quote_id=mint_quote.quote) + assert wallet.balance == 64 + + # create hodl invoice + preimage, invoice_dict = get_hold_invoice(16) + invoice_payment_request = str(invoice_dict["payment_request"]) + invoice_obj = bolt11.decode(invoice_payment_request) + preimage_hash = invoice_obj.payment_hash + + # wallet pays the invoice + quote = await wallet.melt_quote(invoice_payment_request) + total_amount = quote.amount + quote.fee_reserve + _, send_proofs = await wallet.swap_to_send( + wallet.proofs, total_amount, set_reserved=True + ) + assert len(send_proofs) == 2 + + task = asyncio.create_task( + wallet.melt( + proofs=send_proofs, + invoice=invoice_payment_request, + fee_reserve_sat=quote.fee_reserve, + quote_id=quote.quote, + ) + ) + await asyncio.sleep(SLEEP_TIME) + + # verify that the proofs are reserved + proofs_db = await get_proofs(db=wallet.db, melt_id=quote.quote) + assert len(proofs_db) == 2 + assert all([p.reserved for p in proofs_db]) + + # simulate a and kill the task + task.cancel() + + await asyncio.sleep(SLEEP_TIME) + + states = await wallet.check_proof_state(send_proofs) + assert all([s.pending for s in states.states]) + + # fail the payment, melt will unset the proofs as reserved + cancel_invoice(preimage_hash=preimage_hash) + + await asyncio.sleep(SLEEP_TIME) + + # get the melt quote, this should restore the state of the proofs + melt_quote = await wallet.get_melt_quote(quote.quote) + assert melt_quote + assert melt_quote.unpaid + + # verify that get_melt_quote unset all proofs as not pending anymore + proofs_db_later = await get_proofs(db=wallet.db, melt_id=quote.quote) + assert len(proofs_db_later) == 0 + + +@pytest.mark.asyncio +@pytest.mark.skipif(is_fake, reason="only regtest") +async def test_regtest_wallet_crash_melt_succeed_restore_pending_batch_check( + wallet: Wallet, ledger: Ledger +): + # simulates a payment that succeeds but the wallet crashes in the mean time + # we then call get_spent_proofs_check_states_batched to check the proof states + # and the wallet should then invalidate the reserved proofs + + mint_quote = await wallet.request_mint(64) + await pay_if_regtest(mint_quote.request) + await wallet.mint(64, quote_id=mint_quote.quote) + assert wallet.balance == 64 + + # create hodl invoice + preimage, invoice_dict = get_hold_invoice(16) + invoice_payment_request = str(invoice_dict["payment_request"]) + # invoice_obj = bolt11.decode(invoice_payment_request) + # preimage_hash = invoice_obj.payment_hash + + # wallet pays the invoice + quote = await wallet.melt_quote(invoice_payment_request) + total_amount = quote.amount + quote.fee_reserve + _, send_proofs = await wallet.swap_to_send( + wallet.proofs, total_amount, set_reserved=True + ) + + # verify that the proofs are reserved + proofs_db = await get_proofs(db=wallet.db, melt_id=quote.quote) + assert all([p.reserved for p in proofs_db]) + + task = asyncio.create_task( + wallet.melt( + proofs=send_proofs, + invoice=invoice_payment_request, + fee_reserve_sat=quote.fee_reserve, + quote_id=quote.quote, + ) + ) + await asyncio.sleep(SLEEP_TIME) + + # simulate a and kill the task + task.cancel() + await asyncio.sleep(SLEEP_TIME) + # verify that the proofs are still reserved + proofs_db = await get_proofs(db=wallet.db, melt_id=quote.quote) + assert all([p.reserved for p in proofs_db]) + + # verify that the proofs are still pending + states = await wallet.check_proof_state(send_proofs) + assert all([s.pending for s in states.states]) + + # succeed the payment + settle_invoice(preimage=preimage) + + await asyncio.sleep(SLEEP_TIME) + + # get the melt quote + melt_quote = await wallet.get_melt_quote(quote.quote) + assert melt_quote + assert melt_quote.paid + + # verify that get_melt_quote unset all proofs as not pending anymore + proofs_db_later = await get_proofs(db=wallet.db, melt_id=quote.quote) + assert len(proofs_db_later) == 0