mirror of
https://github.com/aljazceru/nutshell.git
synced 2025-12-21 19:14:19 +01:00
Mint: invalidate and generate promises in single db transaction for split (#374)
* test for spending output again * first gernerate (which can fail) then invalidate (db and memory) * use external get_db_connection function to be compatible with existing Database class in LNbits
This commit is contained in:
@@ -202,3 +202,26 @@ def lock_table(db: Database, table: str) -> str:
|
||||
elif db.type == SQLITE:
|
||||
return "BEGIN EXCLUSIVE TRANSACTION;"
|
||||
return "<nothing>"
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_db_connection(db: Database, conn: Optional[Connection] = None):
|
||||
"""Either yield the existing database connection or create a new one.
|
||||
|
||||
Note: This should be implemented as Database.get_db_connection(self, conn) but
|
||||
since we want to use it in LNbits, we can't change the Database class their.
|
||||
|
||||
Args:
|
||||
db (Database): Database object.
|
||||
conn (Optional[Connection], optional): Connection object. Defaults to None.
|
||||
|
||||
Yields:
|
||||
Connection: Connection object.
|
||||
"""
|
||||
if conn is not None:
|
||||
# Yield the existing connection
|
||||
yield conn
|
||||
else:
|
||||
# Create and yield a new connection
|
||||
async with db.connect() as new_conn:
|
||||
yield new_conn
|
||||
|
||||
@@ -17,7 +17,7 @@ from ..core.base import (
|
||||
from ..core.crypto import b_dhke
|
||||
from ..core.crypto.keys import derive_pubkey, random_hash
|
||||
from ..core.crypto.secp import PublicKey
|
||||
from ..core.db import Connection, Database
|
||||
from ..core.db import Connection, Database, get_db_connection
|
||||
from ..core.errors import (
|
||||
KeysetError,
|
||||
KeysetNotFoundError,
|
||||
@@ -151,17 +151,17 @@ class Ledger(LedgerVerification, LedgerSpendingConditions, LedgerLightning):
|
||||
|
||||
# ------- ECASH -------
|
||||
|
||||
async def _invalidate_proofs(self, proofs: List[Proof]) -> None:
|
||||
async def _invalidate_proofs(
|
||||
self, proofs: List[Proof], conn: Optional[Connection] = None
|
||||
) -> None:
|
||||
"""Adds secrets of proofs to the list of known secrets and stores them in the db.
|
||||
Removes proofs from pending table. This is executed if the ecash has been redeemed.
|
||||
|
||||
Args:
|
||||
proofs (List[Proof]): Proofs to add to known secret table.
|
||||
"""
|
||||
# Mark proofs as used and prepare new promises
|
||||
secrets = set([p.secret for p in proofs])
|
||||
self.secrets_used |= secrets
|
||||
async with self.db.connect() as conn:
|
||||
async with get_db_connection(self.db, conn) as conn:
|
||||
# store in db
|
||||
for p in proofs:
|
||||
await self.crud.invalidate_proof(proof=p, db=self.db, conn=conn)
|
||||
@@ -450,14 +450,12 @@ class Ledger(LedgerVerification, LedgerSpendingConditions, LedgerLightning):
|
||||
proofs: List[Proof],
|
||||
outputs: List[BlindedMessage],
|
||||
keyset: Optional[MintKeyset] = None,
|
||||
amount: Optional[int] = None, # backwards compatibility < 0.13.0
|
||||
):
|
||||
"""Consumes proofs and prepares new promises based on the amount split. Used for splitting tokens
|
||||
Before sending or for redeeming tokens for new ones that have been received by another wallet.
|
||||
|
||||
Args:
|
||||
proofs (List[Proof]): Proofs to be invalidated for the split.
|
||||
amount (int): Amount at which the split should happen.
|
||||
outputs (List[BlindedMessage]): New outputs that should be signed in return.
|
||||
keyset (Optional[MintKeyset], optional): Keyset to use. Uses default keyset if not given. Defaults to None.
|
||||
|
||||
@@ -471,33 +469,18 @@ class Ledger(LedgerVerification, LedgerSpendingConditions, LedgerLightning):
|
||||
|
||||
await self._set_proofs_pending(proofs)
|
||||
try:
|
||||
# explicitly check that amount of inputs is equal to amount of outputs
|
||||
# note: we check this again in verify_inputs_and_outputs but only if any
|
||||
# outputs are provided at all. To make sure of that before calling
|
||||
# verify_inputs_and_outputs, we check it here.
|
||||
self._verify_equation_balanced(proofs, outputs)
|
||||
# verify spending inputs, outputs, and spending conditions
|
||||
await self.verify_inputs_and_outputs(proofs, outputs)
|
||||
|
||||
# BEGIN backwards compatibility < 0.13.0
|
||||
if amount is not None:
|
||||
logger.debug(
|
||||
"Split: Client provided `amount` - backwards compatibility response"
|
||||
" pre 0.13.0"
|
||||
)
|
||||
# split outputs according to amount
|
||||
total = sum_proofs(proofs)
|
||||
if amount > total:
|
||||
raise Exception("split amount is higher than the total sum.")
|
||||
outs_fst = amount_split(total - amount)
|
||||
B_fst = [od for od in outputs[: len(outs_fst)]]
|
||||
B_snd = [od for od in outputs[len(outs_fst) :]]
|
||||
|
||||
# Mark proofs as used and prepare new promises
|
||||
await self._invalidate_proofs(proofs)
|
||||
prom_fst = await self._generate_promises(B_fst, keyset)
|
||||
prom_snd = await self._generate_promises(B_snd, keyset)
|
||||
promises = prom_fst + prom_snd
|
||||
# END backwards compatibility < 0.13.0
|
||||
else:
|
||||
# Mark proofs as used and prepare new promises
|
||||
await self._invalidate_proofs(proofs)
|
||||
promises = await self._generate_promises(outputs, keyset)
|
||||
# Mark proofs as used and prepare new promises
|
||||
async with get_db_connection(self.db) as conn:
|
||||
promises = await self._generate_promises(outputs, keyset, conn)
|
||||
await self._invalidate_proofs(proofs, conn)
|
||||
|
||||
except Exception as e:
|
||||
logger.trace(f"split failed: {e}")
|
||||
@@ -535,7 +518,10 @@ class Ledger(LedgerVerification, LedgerSpendingConditions, LedgerLightning):
|
||||
# ------- BLIND SIGNATURES -------
|
||||
|
||||
async def _generate_promises(
|
||||
self, B_s: List[BlindedMessage], keyset: Optional[MintKeyset] = None
|
||||
self,
|
||||
B_s: List[BlindedMessage],
|
||||
keyset: Optional[MintKeyset] = None,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> list[BlindedSignature]:
|
||||
"""Generates a promises (Blind signatures) for given amount and returns a pair (amount, C').
|
||||
|
||||
@@ -557,7 +543,7 @@ class Ledger(LedgerVerification, LedgerSpendingConditions, LedgerLightning):
|
||||
promises.append((B_, amount, C_, e, s))
|
||||
|
||||
signatures = []
|
||||
async with self.db.connect() as conn:
|
||||
async with get_db_connection(self.db, conn) as conn:
|
||||
for promise in promises:
|
||||
B_, amount, C_, e, s = promise
|
||||
logger.trace(f"crud: _generate_promise storing promise for {amount}")
|
||||
|
||||
@@ -258,9 +258,7 @@ async def split(
|
||||
logger.trace(f"> POST /split: {payload}")
|
||||
assert payload.outputs, Exception("no outputs provided.")
|
||||
|
||||
promises = await ledger.split(
|
||||
proofs=payload.proofs, outputs=payload.outputs, amount=payload.amount
|
||||
)
|
||||
promises = await ledger.split(proofs=payload.proofs, outputs=payload.outputs)
|
||||
|
||||
if payload.amount:
|
||||
# BEGIN backwards compatibility < 0.13
|
||||
|
||||
@@ -123,6 +123,38 @@ async def test_split_with_input_more_than_outputs(wallet1: Wallet, ledger: Ledge
|
||||
print(keep_proofs, send_proofs)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_split_twice_with_same_outputs(wallet1: Wallet, ledger: Ledger):
|
||||
invoice = await wallet1.request_mint(128)
|
||||
pay_if_regtest(invoice.bolt11)
|
||||
await wallet1.mint(128, [64, 64], id=invoice.id)
|
||||
inputs1 = wallet1.proofs[:1]
|
||||
inputs2 = wallet1.proofs[1:]
|
||||
|
||||
output_amounts = [64]
|
||||
secrets, rs, derivation_paths = await wallet1.generate_n_secrets(
|
||||
len(output_amounts)
|
||||
)
|
||||
outputs, rs = wallet1._construct_outputs(output_amounts, secrets, rs)
|
||||
|
||||
await ledger.split(proofs=inputs1, outputs=outputs)
|
||||
|
||||
# try to spend other proofs with the same outputs again
|
||||
await assert_err(
|
||||
ledger.split(proofs=inputs2, outputs=outputs),
|
||||
"UNIQUE constraint failed: promises.B_b",
|
||||
)
|
||||
|
||||
# try to spend inputs2 again with new outputs
|
||||
output_amounts = [64]
|
||||
secrets, rs, derivation_paths = await wallet1.generate_n_secrets(
|
||||
len(output_amounts)
|
||||
)
|
||||
outputs, rs = wallet1._construct_outputs(output_amounts, secrets, rs)
|
||||
|
||||
await ledger.split(proofs=inputs2, outputs=outputs)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_proof_state(wallet1: Wallet, ledger: Ledger):
|
||||
invoice = await wallet1.request_mint(64)
|
||||
|
||||
Reference in New Issue
Block a user