Mint: invalidate and generate promises in single db transaction for split (#374)

* test for spending output again

* first gernerate (which can fail) then invalidate (db and memory)

* use external get_db_connection function to be compatible with existing Database class in LNbits
This commit is contained in:
callebtc
2023-12-02 22:54:28 +01:00
committed by GitHub
parent 0ec3af9bf1
commit 34a2e7e5da
4 changed files with 75 additions and 36 deletions

View File

@@ -202,3 +202,26 @@ def lock_table(db: Database, table: str) -> str:
elif db.type == SQLITE: elif db.type == SQLITE:
return "BEGIN EXCLUSIVE TRANSACTION;" return "BEGIN EXCLUSIVE TRANSACTION;"
return "<nothing>" return "<nothing>"
@asynccontextmanager
async def get_db_connection(db: Database, conn: Optional[Connection] = None):
"""Either yield the existing database connection or create a new one.
Note: This should be implemented as Database.get_db_connection(self, conn) but
since we want to use it in LNbits, we can't change the Database class their.
Args:
db (Database): Database object.
conn (Optional[Connection], optional): Connection object. Defaults to None.
Yields:
Connection: Connection object.
"""
if conn is not None:
# Yield the existing connection
yield conn
else:
# Create and yield a new connection
async with db.connect() as new_conn:
yield new_conn

View File

@@ -17,7 +17,7 @@ from ..core.base import (
from ..core.crypto import b_dhke from ..core.crypto import b_dhke
from ..core.crypto.keys import derive_pubkey, random_hash from ..core.crypto.keys import derive_pubkey, random_hash
from ..core.crypto.secp import PublicKey from ..core.crypto.secp import PublicKey
from ..core.db import Connection, Database from ..core.db import Connection, Database, get_db_connection
from ..core.errors import ( from ..core.errors import (
KeysetError, KeysetError,
KeysetNotFoundError, KeysetNotFoundError,
@@ -151,17 +151,17 @@ class Ledger(LedgerVerification, LedgerSpendingConditions, LedgerLightning):
# ------- ECASH ------- # ------- ECASH -------
async def _invalidate_proofs(self, proofs: List[Proof]) -> None: async def _invalidate_proofs(
self, proofs: List[Proof], conn: Optional[Connection] = None
) -> None:
"""Adds secrets of proofs to the list of known secrets and stores them in the db. """Adds secrets of proofs to the list of known secrets and stores them in the db.
Removes proofs from pending table. This is executed if the ecash has been redeemed.
Args: Args:
proofs (List[Proof]): Proofs to add to known secret table. proofs (List[Proof]): Proofs to add to known secret table.
""" """
# Mark proofs as used and prepare new promises
secrets = set([p.secret for p in proofs]) secrets = set([p.secret for p in proofs])
self.secrets_used |= secrets self.secrets_used |= secrets
async with self.db.connect() as conn: async with get_db_connection(self.db, conn) as conn:
# store in db # store in db
for p in proofs: for p in proofs:
await self.crud.invalidate_proof(proof=p, db=self.db, conn=conn) await self.crud.invalidate_proof(proof=p, db=self.db, conn=conn)
@@ -450,14 +450,12 @@ class Ledger(LedgerVerification, LedgerSpendingConditions, LedgerLightning):
proofs: List[Proof], proofs: List[Proof],
outputs: List[BlindedMessage], outputs: List[BlindedMessage],
keyset: Optional[MintKeyset] = None, keyset: Optional[MintKeyset] = None,
amount: Optional[int] = None, # backwards compatibility < 0.13.0
): ):
"""Consumes proofs and prepares new promises based on the amount split. Used for splitting tokens """Consumes proofs and prepares new promises based on the amount split. Used for splitting tokens
Before sending or for redeeming tokens for new ones that have been received by another wallet. Before sending or for redeeming tokens for new ones that have been received by another wallet.
Args: Args:
proofs (List[Proof]): Proofs to be invalidated for the split. proofs (List[Proof]): Proofs to be invalidated for the split.
amount (int): Amount at which the split should happen.
outputs (List[BlindedMessage]): New outputs that should be signed in return. outputs (List[BlindedMessage]): New outputs that should be signed in return.
keyset (Optional[MintKeyset], optional): Keyset to use. Uses default keyset if not given. Defaults to None. keyset (Optional[MintKeyset], optional): Keyset to use. Uses default keyset if not given. Defaults to None.
@@ -471,33 +469,18 @@ class Ledger(LedgerVerification, LedgerSpendingConditions, LedgerLightning):
await self._set_proofs_pending(proofs) await self._set_proofs_pending(proofs)
try: try:
# explicitly check that amount of inputs is equal to amount of outputs
# note: we check this again in verify_inputs_and_outputs but only if any
# outputs are provided at all. To make sure of that before calling
# verify_inputs_and_outputs, we check it here.
self._verify_equation_balanced(proofs, outputs)
# verify spending inputs, outputs, and spending conditions # verify spending inputs, outputs, and spending conditions
await self.verify_inputs_and_outputs(proofs, outputs) await self.verify_inputs_and_outputs(proofs, outputs)
# BEGIN backwards compatibility < 0.13.0 # Mark proofs as used and prepare new promises
if amount is not None: async with get_db_connection(self.db) as conn:
logger.debug( promises = await self._generate_promises(outputs, keyset, conn)
"Split: Client provided `amount` - backwards compatibility response" await self._invalidate_proofs(proofs, conn)
" pre 0.13.0"
)
# split outputs according to amount
total = sum_proofs(proofs)
if amount > total:
raise Exception("split amount is higher than the total sum.")
outs_fst = amount_split(total - amount)
B_fst = [od for od in outputs[: len(outs_fst)]]
B_snd = [od for od in outputs[len(outs_fst) :]]
# Mark proofs as used and prepare new promises
await self._invalidate_proofs(proofs)
prom_fst = await self._generate_promises(B_fst, keyset)
prom_snd = await self._generate_promises(B_snd, keyset)
promises = prom_fst + prom_snd
# END backwards compatibility < 0.13.0
else:
# Mark proofs as used and prepare new promises
await self._invalidate_proofs(proofs)
promises = await self._generate_promises(outputs, keyset)
except Exception as e: except Exception as e:
logger.trace(f"split failed: {e}") logger.trace(f"split failed: {e}")
@@ -535,7 +518,10 @@ class Ledger(LedgerVerification, LedgerSpendingConditions, LedgerLightning):
# ------- BLIND SIGNATURES ------- # ------- BLIND SIGNATURES -------
async def _generate_promises( async def _generate_promises(
self, B_s: List[BlindedMessage], keyset: Optional[MintKeyset] = None self,
B_s: List[BlindedMessage],
keyset: Optional[MintKeyset] = None,
conn: Optional[Connection] = None,
) -> list[BlindedSignature]: ) -> list[BlindedSignature]:
"""Generates a promises (Blind signatures) for given amount and returns a pair (amount, C'). """Generates a promises (Blind signatures) for given amount and returns a pair (amount, C').
@@ -557,7 +543,7 @@ class Ledger(LedgerVerification, LedgerSpendingConditions, LedgerLightning):
promises.append((B_, amount, C_, e, s)) promises.append((B_, amount, C_, e, s))
signatures = [] signatures = []
async with self.db.connect() as conn: async with get_db_connection(self.db, conn) as conn:
for promise in promises: for promise in promises:
B_, amount, C_, e, s = promise B_, amount, C_, e, s = promise
logger.trace(f"crud: _generate_promise storing promise for {amount}") logger.trace(f"crud: _generate_promise storing promise for {amount}")

View File

@@ -258,9 +258,7 @@ async def split(
logger.trace(f"> POST /split: {payload}") logger.trace(f"> POST /split: {payload}")
assert payload.outputs, Exception("no outputs provided.") assert payload.outputs, Exception("no outputs provided.")
promises = await ledger.split( promises = await ledger.split(proofs=payload.proofs, outputs=payload.outputs)
proofs=payload.proofs, outputs=payload.outputs, amount=payload.amount
)
if payload.amount: if payload.amount:
# BEGIN backwards compatibility < 0.13 # BEGIN backwards compatibility < 0.13

View File

@@ -123,6 +123,38 @@ async def test_split_with_input_more_than_outputs(wallet1: Wallet, ledger: Ledge
print(keep_proofs, send_proofs) print(keep_proofs, send_proofs)
@pytest.mark.asyncio
async def test_split_twice_with_same_outputs(wallet1: Wallet, ledger: Ledger):
invoice = await wallet1.request_mint(128)
pay_if_regtest(invoice.bolt11)
await wallet1.mint(128, [64, 64], id=invoice.id)
inputs1 = wallet1.proofs[:1]
inputs2 = wallet1.proofs[1:]
output_amounts = [64]
secrets, rs, derivation_paths = await wallet1.generate_n_secrets(
len(output_amounts)
)
outputs, rs = wallet1._construct_outputs(output_amounts, secrets, rs)
await ledger.split(proofs=inputs1, outputs=outputs)
# try to spend other proofs with the same outputs again
await assert_err(
ledger.split(proofs=inputs2, outputs=outputs),
"UNIQUE constraint failed: promises.B_b",
)
# try to spend inputs2 again with new outputs
output_amounts = [64]
secrets, rs, derivation_paths = await wallet1.generate_n_secrets(
len(output_amounts)
)
outputs, rs = wallet1._construct_outputs(output_amounts, secrets, rs)
await ledger.split(proofs=inputs2, outputs=outputs)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_check_proof_state(wallet1: Wallet, ledger: Ledger): async def test_check_proof_state(wallet1: Wallet, ledger: Ledger):
invoice = await wallet1.request_mint(64) invoice = await wallet1.request_mint(64)