From 34a2e7e5daaa1e14cbe2129cf8dcd4f1704b2982 Mon Sep 17 00:00:00 2001 From: callebtc <93376500+callebtc@users.noreply.github.com> Date: Sat, 2 Dec 2023 22:54:28 +0100 Subject: [PATCH] Mint: invalidate and generate promises in single db transaction for split (#374) * test for spending output again * first gernerate (which can fail) then invalidate (db and memory) * use external get_db_connection function to be compatible with existing Database class in LNbits --- cashu/core/db.py | 23 ++++++++++++++++ cashu/mint/ledger.py | 52 +++++++++++++---------------------- cashu/mint/router.py | 4 +-- tests/test_mint_operations.py | 32 +++++++++++++++++++++ 4 files changed, 75 insertions(+), 36 deletions(-) diff --git a/cashu/core/db.py b/cashu/core/db.py index b4337cc..1f96910 100644 --- a/cashu/core/db.py +++ b/cashu/core/db.py @@ -202,3 +202,26 @@ def lock_table(db: Database, table: str) -> str: elif db.type == SQLITE: return "BEGIN EXCLUSIVE TRANSACTION;" return "" + + +@asynccontextmanager +async def get_db_connection(db: Database, conn: Optional[Connection] = None): + """Either yield the existing database connection or create a new one. + + Note: This should be implemented as Database.get_db_connection(self, conn) but + since we want to use it in LNbits, we can't change the Database class their. + + Args: + db (Database): Database object. + conn (Optional[Connection], optional): Connection object. Defaults to None. + + Yields: + Connection: Connection object. + """ + if conn is not None: + # Yield the existing connection + yield conn + else: + # Create and yield a new connection + async with db.connect() as new_conn: + yield new_conn diff --git a/cashu/mint/ledger.py b/cashu/mint/ledger.py index 94b837b..56531a6 100644 --- a/cashu/mint/ledger.py +++ b/cashu/mint/ledger.py @@ -17,7 +17,7 @@ from ..core.base import ( from ..core.crypto import b_dhke from ..core.crypto.keys import derive_pubkey, random_hash from ..core.crypto.secp import PublicKey -from ..core.db import Connection, Database +from ..core.db import Connection, Database, get_db_connection from ..core.errors import ( KeysetError, KeysetNotFoundError, @@ -151,17 +151,17 @@ class Ledger(LedgerVerification, LedgerSpendingConditions, LedgerLightning): # ------- ECASH ------- - async def _invalidate_proofs(self, proofs: List[Proof]) -> None: + async def _invalidate_proofs( + self, proofs: List[Proof], conn: Optional[Connection] = None + ) -> None: """Adds secrets of proofs to the list of known secrets and stores them in the db. - Removes proofs from pending table. This is executed if the ecash has been redeemed. Args: proofs (List[Proof]): Proofs to add to known secret table. """ - # Mark proofs as used and prepare new promises secrets = set([p.secret for p in proofs]) self.secrets_used |= secrets - async with self.db.connect() as conn: + async with get_db_connection(self.db, conn) as conn: # store in db for p in proofs: await self.crud.invalidate_proof(proof=p, db=self.db, conn=conn) @@ -450,14 +450,12 @@ class Ledger(LedgerVerification, LedgerSpendingConditions, LedgerLightning): proofs: List[Proof], outputs: List[BlindedMessage], keyset: Optional[MintKeyset] = None, - amount: Optional[int] = None, # backwards compatibility < 0.13.0 ): """Consumes proofs and prepares new promises based on the amount split. Used for splitting tokens Before sending or for redeeming tokens for new ones that have been received by another wallet. Args: proofs (List[Proof]): Proofs to be invalidated for the split. - amount (int): Amount at which the split should happen. outputs (List[BlindedMessage]): New outputs that should be signed in return. keyset (Optional[MintKeyset], optional): Keyset to use. Uses default keyset if not given. Defaults to None. @@ -471,33 +469,18 @@ class Ledger(LedgerVerification, LedgerSpendingConditions, LedgerLightning): await self._set_proofs_pending(proofs) try: + # explicitly check that amount of inputs is equal to amount of outputs + # note: we check this again in verify_inputs_and_outputs but only if any + # outputs are provided at all. To make sure of that before calling + # verify_inputs_and_outputs, we check it here. + self._verify_equation_balanced(proofs, outputs) # verify spending inputs, outputs, and spending conditions await self.verify_inputs_and_outputs(proofs, outputs) - # BEGIN backwards compatibility < 0.13.0 - if amount is not None: - logger.debug( - "Split: Client provided `amount` - backwards compatibility response" - " pre 0.13.0" - ) - # split outputs according to amount - total = sum_proofs(proofs) - if amount > total: - raise Exception("split amount is higher than the total sum.") - outs_fst = amount_split(total - amount) - B_fst = [od for od in outputs[: len(outs_fst)]] - B_snd = [od for od in outputs[len(outs_fst) :]] - - # Mark proofs as used and prepare new promises - await self._invalidate_proofs(proofs) - prom_fst = await self._generate_promises(B_fst, keyset) - prom_snd = await self._generate_promises(B_snd, keyset) - promises = prom_fst + prom_snd - # END backwards compatibility < 0.13.0 - else: - # Mark proofs as used and prepare new promises - await self._invalidate_proofs(proofs) - promises = await self._generate_promises(outputs, keyset) + # Mark proofs as used and prepare new promises + async with get_db_connection(self.db) as conn: + promises = await self._generate_promises(outputs, keyset, conn) + await self._invalidate_proofs(proofs, conn) except Exception as e: logger.trace(f"split failed: {e}") @@ -535,7 +518,10 @@ class Ledger(LedgerVerification, LedgerSpendingConditions, LedgerLightning): # ------- BLIND SIGNATURES ------- async def _generate_promises( - self, B_s: List[BlindedMessage], keyset: Optional[MintKeyset] = None + self, + B_s: List[BlindedMessage], + keyset: Optional[MintKeyset] = None, + conn: Optional[Connection] = None, ) -> list[BlindedSignature]: """Generates a promises (Blind signatures) for given amount and returns a pair (amount, C'). @@ -557,7 +543,7 @@ class Ledger(LedgerVerification, LedgerSpendingConditions, LedgerLightning): promises.append((B_, amount, C_, e, s)) signatures = [] - async with self.db.connect() as conn: + async with get_db_connection(self.db, conn) as conn: for promise in promises: B_, amount, C_, e, s = promise logger.trace(f"crud: _generate_promise storing promise for {amount}") diff --git a/cashu/mint/router.py b/cashu/mint/router.py index b556954..90fc648 100644 --- a/cashu/mint/router.py +++ b/cashu/mint/router.py @@ -258,9 +258,7 @@ async def split( logger.trace(f"> POST /split: {payload}") assert payload.outputs, Exception("no outputs provided.") - promises = await ledger.split( - proofs=payload.proofs, outputs=payload.outputs, amount=payload.amount - ) + promises = await ledger.split(proofs=payload.proofs, outputs=payload.outputs) if payload.amount: # BEGIN backwards compatibility < 0.13 diff --git a/tests/test_mint_operations.py b/tests/test_mint_operations.py index 7bcf2ff..464e252 100644 --- a/tests/test_mint_operations.py +++ b/tests/test_mint_operations.py @@ -123,6 +123,38 @@ async def test_split_with_input_more_than_outputs(wallet1: Wallet, ledger: Ledge print(keep_proofs, send_proofs) +@pytest.mark.asyncio +async def test_split_twice_with_same_outputs(wallet1: Wallet, ledger: Ledger): + invoice = await wallet1.request_mint(128) + pay_if_regtest(invoice.bolt11) + await wallet1.mint(128, [64, 64], id=invoice.id) + inputs1 = wallet1.proofs[:1] + inputs2 = wallet1.proofs[1:] + + output_amounts = [64] + secrets, rs, derivation_paths = await wallet1.generate_n_secrets( + len(output_amounts) + ) + outputs, rs = wallet1._construct_outputs(output_amounts, secrets, rs) + + await ledger.split(proofs=inputs1, outputs=outputs) + + # try to spend other proofs with the same outputs again + await assert_err( + ledger.split(proofs=inputs2, outputs=outputs), + "UNIQUE constraint failed: promises.B_b", + ) + + # try to spend inputs2 again with new outputs + output_amounts = [64] + secrets, rs, derivation_paths = await wallet1.generate_n_secrets( + len(output_amounts) + ) + outputs, rs = wallet1._construct_outputs(output_amounts, secrets, rs) + + await ledger.split(proofs=inputs2, outputs=outputs) + + @pytest.mark.asyncio async def test_check_proof_state(wallet1: Wallet, ledger: Ledger): invoice = await wallet1.request_mint(64)