diff --git a/cashu/mint/ledger.py b/cashu/mint/ledger.py index a040248..94b837b 100644 --- a/cashu/mint/ledger.py +++ b/cashu/mint/ledger.py @@ -473,8 +473,32 @@ class Ledger(LedgerVerification, LedgerSpendingConditions, LedgerLightning): try: # verify spending inputs, outputs, and spending conditions await self.verify_inputs_and_outputs(proofs, outputs) - # Mark proofs as used and prepare new promises - await self._invalidate_proofs(proofs) + + # BEGIN backwards compatibility < 0.13.0 + if amount is not None: + logger.debug( + "Split: Client provided `amount` - backwards compatibility response" + " pre 0.13.0" + ) + # split outputs according to amount + total = sum_proofs(proofs) + if amount > total: + raise Exception("split amount is higher than the total sum.") + outs_fst = amount_split(total - amount) + B_fst = [od for od in outputs[: len(outs_fst)]] + B_snd = [od for od in outputs[len(outs_fst) :]] + + # Mark proofs as used and prepare new promises + await self._invalidate_proofs(proofs) + prom_fst = await self._generate_promises(B_fst, keyset) + prom_snd = await self._generate_promises(B_snd, keyset) + promises = prom_fst + prom_snd + # END backwards compatibility < 0.13.0 + else: + # Mark proofs as used and prepare new promises + await self._invalidate_proofs(proofs) + promises = await self._generate_promises(outputs, keyset) + except Exception as e: logger.trace(f"split failed: {e}") raise e @@ -482,31 +506,6 @@ class Ledger(LedgerVerification, LedgerSpendingConditions, LedgerLightning): # delete proofs from pending list await self._unset_proofs_pending(proofs) - # BEGIN backwards compatibility < 0.13.0 - if amount is not None: - logger.debug( - "Split: Client provided `amount` - backwards compatibility response pre" - " 0.13.0" - ) - # split outputs according to amount - total = sum_proofs(proofs) - if amount > total: - raise Exception("split amount is higher than the total sum.") - outs_fst = amount_split(total - amount) - B_fst = [od for od in outputs[: len(outs_fst)]] - B_snd = [od for od in outputs[len(outs_fst) :]] - - # generate promises - prom_fst = await self._generate_promises(B_fst, keyset) - prom_snd = await self._generate_promises(B_snd, keyset) - promises = prom_fst + prom_snd - # END backwards compatibility < 0.13.0 - else: - promises = await self._generate_promises(outputs, keyset) - - # verify amounts in produced promises - self._verify_equation_balanced(proofs, promises) - logger.trace("split successful") return promises diff --git a/cashu/mint/verification.py b/cashu/mint/verification.py index 1cbd76b..188b433 100644 --- a/cashu/mint/verification.py +++ b/cashu/mint/verification.py @@ -74,6 +74,9 @@ class LedgerVerification(LedgerSpendingConditions, SupportsKeysets, SupportsDb): if not outputs: return + # Verify input and output amounts + self._verify_equation_balanced(proofs, outputs) + # Verify outputs self._verify_outputs(outputs) @@ -176,6 +179,6 @@ class LedgerVerification(LedgerSpendingConditions, SupportsKeysets, SupportsDb): """ sum_inputs = sum(self._verify_amount(p.amount) for p in proofs) sum_outputs = sum(self._verify_amount(p.amount) for p in outs) - assert ( - sum_outputs - sum_inputs == 0 - ), "inputs do not have same amount as outputs" + assert sum_outputs - sum_inputs == 0, TransactionError( + "inputs do not have same amount as outputs." + ) diff --git a/tests/conftest.py b/tests/conftest.py index 9056213..b2c286a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -60,7 +60,7 @@ async def ledger(): await ledger.load_used_proofs() await ledger.init_keysets() - database_name = "test" + database_name = "mint" if not settings.mint_database.startswith("postgres"): # clear sqlite database diff --git a/tests/test_mint_operations.py b/tests/test_mint_operations.py index e4e9e73..7bcf2ff 100644 --- a/tests/test_mint_operations.py +++ b/tests/test_mint_operations.py @@ -8,6 +8,17 @@ from tests.conftest import SERVER_ENDPOINT from tests.helpers import pay_if_regtest +async def assert_err(f, msg): + """Compute f() and expect an error message 'msg'.""" + try: + await f + except Exception as exc: + if msg not in str(exc.args[0]): + raise Exception(f"Expected error: {msg}, got: {exc.args[0]}") + return + raise Exception(f"Expected error: {msg}, got no error") + + @pytest_asyncio.fixture(scope="function") async def wallet1(mint): wallet1 = await Wallet1.with_db( @@ -58,6 +69,60 @@ async def test_split(wallet1: Wallet, ledger: Ledger): assert [p.amount for p in promises] == [p.amount for p in outputs] +@pytest.mark.asyncio +async def test_split_with_input_less_than_outputs(wallet1: Wallet, ledger: Ledger): + invoice = await wallet1.request_mint(64) + pay_if_regtest(invoice.bolt11) + await wallet1.mint(64, id=invoice.id) + + keep_proofs, send_proofs = await wallet1.split_to_send( + wallet1.proofs, 10, set_reserved=False + ) + + all_send_proofs = send_proofs + keep_proofs + + # generate outputs for all proofs, not only the sent ones + secrets, rs, derivation_paths = await wallet1.generate_n_secrets( + len(all_send_proofs) + ) + outputs, rs = wallet1._construct_outputs( + [p.amount for p in all_send_proofs], secrets, rs + ) + + await assert_err( + ledger.split(proofs=send_proofs, outputs=outputs), + "inputs do not have same amount as outputs.", + ) + + # make sure we can still spend our tokens + keep_proofs, send_proofs = await wallet1.split(wallet1.proofs, 10) + + +@pytest.mark.asyncio +async def test_split_with_input_more_than_outputs(wallet1: Wallet, ledger: Ledger): + invoice = await wallet1.request_mint(128) + pay_if_regtest(invoice.bolt11) + await wallet1.mint(128, id=invoice.id) + + inputs = wallet1.proofs + + # less outputs than inputs + output_amounts = [8] + secrets, rs, derivation_paths = await wallet1.generate_n_secrets( + len(output_amounts) + ) + outputs, rs = wallet1._construct_outputs(output_amounts, secrets, rs) + + await assert_err( + ledger.split(proofs=inputs, outputs=outputs), + "inputs do not have same amount as outputs", + ) + + # make sure we can still spend our tokens + keep_proofs, send_proofs = await wallet1.split(inputs, 10) + print(keep_proofs, send_proofs) + + @pytest.mark.asyncio async def test_check_proof_state(wallet1: Wallet, ledger: Ledger): invoice = await wallet1.request_mint(64) diff --git a/tests/test_wallet_htlc.py b/tests/test_wallet_htlc.py index b64557f..0f6cbff 100644 --- a/tests/test_wallet_htlc.py +++ b/tests/test_wallet_htlc.py @@ -191,7 +191,7 @@ async def test_htlc_redeem_hashlock_wrong_signature_timelock_correct_signature( secret = await wallet1.create_htlc_lock( preimage=preimage, hashlock_pubkey=pubkey_wallet2, - locktime_seconds=5, + locktime_seconds=2, locktime_pubkey=pubkey_wallet1, ) _, send_proofs = await wallet1.split_to_send(wallet1.proofs, 8, secret_lock=secret) @@ -206,7 +206,7 @@ async def test_htlc_redeem_hashlock_wrong_signature_timelock_correct_signature( "Mint Error: HTLC hash lock signatures did not match.", ) - await asyncio.sleep(5) + await asyncio.sleep(2) # should succeed since lock time has passed and we provided wallet1 signature for timelock await wallet1.redeem(send_proofs) @@ -225,7 +225,7 @@ async def test_htlc_redeem_hashlock_wrong_signature_timelock_wrong_signature( secret = await wallet1.create_htlc_lock( preimage=preimage, hashlock_pubkey=pubkey_wallet2, - locktime_seconds=5, + locktime_seconds=2, locktime_pubkey=pubkey_wallet1, ) _, send_proofs = await wallet1.split_to_send(wallet1.proofs, 8, secret_lock=secret) @@ -242,7 +242,7 @@ async def test_htlc_redeem_hashlock_wrong_signature_timelock_wrong_signature( "Mint Error: HTLC hash lock signatures did not match.", ) - await asyncio.sleep(5) + await asyncio.sleep(2) # should fail since lock time has passed and we provided a wrong signature for timelock await assert_err( wallet1.redeem(send_proofs),