Mint: invalidate and generate promises in single db transaction for split (#374)

* test for spending output again

* first gernerate (which can fail) then invalidate (db and memory)

* use external get_db_connection function to be compatible with existing Database class in LNbits
This commit is contained in:
callebtc
2023-12-02 22:54:28 +01:00
committed by GitHub
parent 0ec3af9bf1
commit 34a2e7e5da
4 changed files with 75 additions and 36 deletions

View File

@@ -202,3 +202,26 @@ def lock_table(db: Database, table: str) -> str:
elif db.type == SQLITE:
return "BEGIN EXCLUSIVE TRANSACTION;"
return "<nothing>"
@asynccontextmanager
async def get_db_connection(db: Database, conn: Optional[Connection] = None):
"""Either yield the existing database connection or create a new one.
Note: This should be implemented as Database.get_db_connection(self, conn) but
since we want to use it in LNbits, we can't change the Database class their.
Args:
db (Database): Database object.
conn (Optional[Connection], optional): Connection object. Defaults to None.
Yields:
Connection: Connection object.
"""
if conn is not None:
# Yield the existing connection
yield conn
else:
# Create and yield a new connection
async with db.connect() as new_conn:
yield new_conn

View File

@@ -17,7 +17,7 @@ from ..core.base import (
from ..core.crypto import b_dhke
from ..core.crypto.keys import derive_pubkey, random_hash
from ..core.crypto.secp import PublicKey
from ..core.db import Connection, Database
from ..core.db import Connection, Database, get_db_connection
from ..core.errors import (
KeysetError,
KeysetNotFoundError,
@@ -151,17 +151,17 @@ class Ledger(LedgerVerification, LedgerSpendingConditions, LedgerLightning):
# ------- ECASH -------
async def _invalidate_proofs(self, proofs: List[Proof]) -> None:
async def _invalidate_proofs(
self, proofs: List[Proof], conn: Optional[Connection] = None
) -> None:
"""Adds secrets of proofs to the list of known secrets and stores them in the db.
Removes proofs from pending table. This is executed if the ecash has been redeemed.
Args:
proofs (List[Proof]): Proofs to add to known secret table.
"""
# Mark proofs as used and prepare new promises
secrets = set([p.secret for p in proofs])
self.secrets_used |= secrets
async with self.db.connect() as conn:
async with get_db_connection(self.db, conn) as conn:
# store in db
for p in proofs:
await self.crud.invalidate_proof(proof=p, db=self.db, conn=conn)
@@ -450,14 +450,12 @@ class Ledger(LedgerVerification, LedgerSpendingConditions, LedgerLightning):
proofs: List[Proof],
outputs: List[BlindedMessage],
keyset: Optional[MintKeyset] = None,
amount: Optional[int] = None, # backwards compatibility < 0.13.0
):
"""Consumes proofs and prepares new promises based on the amount split. Used for splitting tokens
Before sending or for redeeming tokens for new ones that have been received by another wallet.
Args:
proofs (List[Proof]): Proofs to be invalidated for the split.
amount (int): Amount at which the split should happen.
outputs (List[BlindedMessage]): New outputs that should be signed in return.
keyset (Optional[MintKeyset], optional): Keyset to use. Uses default keyset if not given. Defaults to None.
@@ -471,33 +469,18 @@ class Ledger(LedgerVerification, LedgerSpendingConditions, LedgerLightning):
await self._set_proofs_pending(proofs)
try:
# explicitly check that amount of inputs is equal to amount of outputs
# note: we check this again in verify_inputs_and_outputs but only if any
# outputs are provided at all. To make sure of that before calling
# verify_inputs_and_outputs, we check it here.
self._verify_equation_balanced(proofs, outputs)
# verify spending inputs, outputs, and spending conditions
await self.verify_inputs_and_outputs(proofs, outputs)
# BEGIN backwards compatibility < 0.13.0
if amount is not None:
logger.debug(
"Split: Client provided `amount` - backwards compatibility response"
" pre 0.13.0"
)
# split outputs according to amount
total = sum_proofs(proofs)
if amount > total:
raise Exception("split amount is higher than the total sum.")
outs_fst = amount_split(total - amount)
B_fst = [od for od in outputs[: len(outs_fst)]]
B_snd = [od for od in outputs[len(outs_fst) :]]
# Mark proofs as used and prepare new promises
await self._invalidate_proofs(proofs)
prom_fst = await self._generate_promises(B_fst, keyset)
prom_snd = await self._generate_promises(B_snd, keyset)
promises = prom_fst + prom_snd
# END backwards compatibility < 0.13.0
else:
# Mark proofs as used and prepare new promises
await self._invalidate_proofs(proofs)
promises = await self._generate_promises(outputs, keyset)
async with get_db_connection(self.db) as conn:
promises = await self._generate_promises(outputs, keyset, conn)
await self._invalidate_proofs(proofs, conn)
except Exception as e:
logger.trace(f"split failed: {e}")
@@ -535,7 +518,10 @@ class Ledger(LedgerVerification, LedgerSpendingConditions, LedgerLightning):
# ------- BLIND SIGNATURES -------
async def _generate_promises(
self, B_s: List[BlindedMessage], keyset: Optional[MintKeyset] = None
self,
B_s: List[BlindedMessage],
keyset: Optional[MintKeyset] = None,
conn: Optional[Connection] = None,
) -> list[BlindedSignature]:
"""Generates a promises (Blind signatures) for given amount and returns a pair (amount, C').
@@ -557,7 +543,7 @@ class Ledger(LedgerVerification, LedgerSpendingConditions, LedgerLightning):
promises.append((B_, amount, C_, e, s))
signatures = []
async with self.db.connect() as conn:
async with get_db_connection(self.db, conn) as conn:
for promise in promises:
B_, amount, C_, e, s = promise
logger.trace(f"crud: _generate_promise storing promise for {amount}")

View File

@@ -258,9 +258,7 @@ async def split(
logger.trace(f"> POST /split: {payload}")
assert payload.outputs, Exception("no outputs provided.")
promises = await ledger.split(
proofs=payload.proofs, outputs=payload.outputs, amount=payload.amount
)
promises = await ledger.split(proofs=payload.proofs, outputs=payload.outputs)
if payload.amount:
# BEGIN backwards compatibility < 0.13

View File

@@ -123,6 +123,38 @@ async def test_split_with_input_more_than_outputs(wallet1: Wallet, ledger: Ledge
print(keep_proofs, send_proofs)
@pytest.mark.asyncio
async def test_split_twice_with_same_outputs(wallet1: Wallet, ledger: Ledger):
invoice = await wallet1.request_mint(128)
pay_if_regtest(invoice.bolt11)
await wallet1.mint(128, [64, 64], id=invoice.id)
inputs1 = wallet1.proofs[:1]
inputs2 = wallet1.proofs[1:]
output_amounts = [64]
secrets, rs, derivation_paths = await wallet1.generate_n_secrets(
len(output_amounts)
)
outputs, rs = wallet1._construct_outputs(output_amounts, secrets, rs)
await ledger.split(proofs=inputs1, outputs=outputs)
# try to spend other proofs with the same outputs again
await assert_err(
ledger.split(proofs=inputs2, outputs=outputs),
"UNIQUE constraint failed: promises.B_b",
)
# try to spend inputs2 again with new outputs
output_amounts = [64]
secrets, rs, derivation_paths = await wallet1.generate_n_secrets(
len(output_amounts)
)
outputs, rs = wallet1._construct_outputs(output_amounts, secrets, rs)
await ledger.split(proofs=inputs2, outputs=outputs)
@pytest.mark.asyncio
async def test_check_proof_state(wallet1: Wallet, ledger: Ledger):
invoice = await wallet1.request_mint(64)