refactor db transactions (#571)

This commit is contained in:
callebtc
2024-07-09 15:46:19 +02:00
committed by GitHub
parent 51ae82bee8
commit 539054a7c9
5 changed files with 49 additions and 54 deletions

View File

@@ -57,7 +57,7 @@ These steps help you install Python via pyenv and Poetry. If you already have Po
```bash ```bash
# on ubuntu: # on ubuntu:
sudo apt install -y build-essential pkg-config libffi-dev libpq-dev zlib1g-dev libssl-dev python3-dev libsqlite3-dev ncurses-dev libbz2-dev libreadline-dev lzma-dev sudo apt install -y build-essential pkg-config libffi-dev libpq-dev zlib1g-dev libssl-dev python3-dev libsqlite3-dev ncurses-dev libbz2-dev libreadline-dev lzma-dev liblzma-dev
# install python using pyenv # install python using pyenv
curl https://pyenv.run | bash curl https://pyenv.run | bash

View File

@@ -36,23 +36,15 @@ class LedgerCrud(ABC):
... ...
@abstractmethod @abstractmethod
async def get_spent_proofs( async def get_proofs_used(
self, self,
*, *,
Ys: List[str],
db: Database, db: Database,
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> List[Proof]: ) -> List[Proof]:
... ...
async def get_proof_used(
self,
*,
Y: str,
db: Database,
conn: Optional[Connection] = None,
) -> Optional[Proof]:
...
@abstractmethod @abstractmethod
async def invalidate_proof( async def invalidate_proof(
self, self,
@@ -157,6 +149,16 @@ class LedgerCrud(ABC):
) -> Optional[BlindedSignature]: ) -> Optional[BlindedSignature]:
... ...
@abstractmethod
async def get_promises(
self,
*,
db: Database,
b_s: List[str],
conn: Optional[Connection] = None,
) -> List[BlindedSignature]:
...
@abstractmethod @abstractmethod
async def store_mint_quote( async def store_mint_quote(
self, self,
@@ -294,18 +296,21 @@ class LedgerCrudSqlite(LedgerCrud):
) )
return BlindedSignature.from_row(row) if row else None return BlindedSignature.from_row(row) if row else None
async def get_spent_proofs( async def get_promises(
self, self,
*, *,
db: Database, db: Database,
b_s: List[str],
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> List[Proof]: ) -> List[BlindedSignature]:
rows = await (conn or db).fetchall( rows = await (conn or db).fetchall(
f""" f"""
SELECT * from {db.table_with_schema('proofs_used')} SELECT * from {db.table_with_schema('promises')}
""" WHERE b_ IN ({','.join([':b_' + str(i) for i in range(len(b_s))])})
""",
{f"b_{i}": b_s[i] for i in range(len(b_s))},
) )
return [Proof(**r) for r in rows] if rows else [] return [BlindedSignature.from_row(r) for r in rows] if rows else []
async def invalidate_proof( async def invalidate_proof(
self, self,
@@ -722,18 +727,17 @@ class LedgerCrudSqlite(LedgerCrud):
) )
return [MintKeyset(**row) for row in rows] return [MintKeyset(**row) for row in rows]
async def get_proof_used( async def get_proofs_used(
self, self,
*, *,
Y: str, Ys: List[str],
db: Database, db: Database,
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> Optional[Proof]: ) -> List[Proof]:
row = await (conn or db).fetchone( query = f"""
f""" SELECT * from {db.table_with_schema('proofs_used')}
SELECT * from {db.table_with_schema('proofs_used')} WHERE y IN ({','.join([':y_' + str(i) for i in range(len(Ys))])})
WHERE y = :y """
""", values = {f"y_{i}": Ys[i] for i in range(len(Ys))}
{"y": Y}, rows = await (conn or db).fetchall(query, values)
) return [Proof(**r) for r in rows] if rows else []
return Proof(**row) if row else None

View File

@@ -35,10 +35,8 @@ class DbReadHelper:
proofs_spent_dict: Dict[str, Proof] = {} proofs_spent_dict: Dict[str, Proof] = {}
# check used secrets in database # check used secrets in database
async with self.db.get_connection(conn) as conn: async with self.db.get_connection(conn) as conn:
for Y in Ys: spent_proofs = await self.crud.get_proofs_used(db=self.db, Ys=Ys, conn=conn)
spent_proof = await self.crud.get_proof_used(db=self.db, Y=Y, conn=conn) proofs_spent_dict = {p.Y: p for p in spent_proofs}
if spent_proof:
proofs_spent_dict[Y] = spent_proof
return proofs_spent_dict return proofs_spent_dict
async def get_proofs_states( async def get_proofs_states(

View File

@@ -4,6 +4,8 @@ from ..core.base import Method, MintKeyset, Unit
from ..core.db import Database from ..core.db import Database
from ..lightning.base import LightningBackend from ..lightning.base import LightningBackend
from ..mint.crud import LedgerCrud from ..mint.crud import LedgerCrud
from .db.read import DbReadHelper
from .db.write import DbWriteHelper
from .events.events import LedgerEventManager from .events.events import LedgerEventManager
@@ -18,6 +20,8 @@ class SupportsBackends(Protocol):
class SupportsDb(Protocol): class SupportsDb(Protocol):
db: Database db: Database
db_read: DbReadHelper
db_write: DbWriteHelper
crud: LedgerCrud crud: LedgerCrud

View File

@@ -25,6 +25,8 @@ from ..core.settings import settings
from ..lightning.base import LightningBackend from ..lightning.base import LightningBackend
from ..mint.crud import LedgerCrud from ..mint.crud import LedgerCrud
from .conditions import LedgerSpendingConditions from .conditions import LedgerSpendingConditions
from .db.read import DbReadHelper
from .db.write import DbWriteHelper
from .protocols import SupportsBackends, SupportsDb, SupportsKeysets from .protocols import SupportsBackends, SupportsDb, SupportsKeysets
@@ -37,6 +39,8 @@ class LedgerVerification(
keysets: Dict[str, MintKeyset] keysets: Dict[str, MintKeyset]
crud: LedgerCrud crud: LedgerCrud
db: Database db: Database
db_read: DbReadHelper
db_write: DbWriteHelper
lightning: Dict[Unit, LightningBackend] lightning: Dict[Unit, LightningBackend]
async def verify_inputs_and_outputs( async def verify_inputs_and_outputs(
@@ -64,7 +68,10 @@ class LedgerVerification(
if not proofs: if not proofs:
raise TransactionError("no proofs provided.") raise TransactionError("no proofs provided.")
# Verify proofs are spendable # Verify proofs are spendable
if not len(await self._get_proofs_spent([p.Y for p in proofs], conn)) == 0: if (
not len(await self.db_read._get_proofs_spent([p.Y for p in proofs], conn))
== 0
):
raise TokenAlreadySpentError() raise TokenAlreadySpentError()
# Verify amounts of inputs # Verify amounts of inputs
if not all([self._verify_amount(p.amount) for p in proofs]): if not all([self._verify_amount(p.amount) for p in proofs]):
@@ -156,29 +163,11 @@ class LedgerVerification(
Returns: Returns:
result (List[bool]): Whether outputs are already present in the database. result (List[bool]): Whether outputs are already present in the database.
""" """
result = []
async with self.db.get_connection(conn) as conn: async with self.db.get_connection(conn) as conn:
for output in outputs: promises = await self.crud.get_promises(
promise = await self.crud.get_promise( b_s=[output.B_ for output in outputs], db=self.db, conn=conn
b_=output.B_, db=self.db, conn=conn )
) return [True if promise else False for promise in promises]
result.append(False if promise is None else True)
return result
async def _get_proofs_spent(
self, Ys: List[str], conn: Optional[Connection] = None
) -> Dict[str, Proof]:
"""Returns a dictionary of all proofs that are spent.
The key is the Y=h2c(secret) and the value is the proof.
"""
proofs_spent_dict: Dict[str, Proof] = {}
# check used secrets in database
async with self.db.get_connection(conn=conn) as conn:
for Y in Ys:
spent_proof = await self.crud.get_proof_used(db=self.db, Y=Y, conn=conn)
if spent_proof:
proofs_spent_dict[Y] = spent_proof
return proofs_spent_dict
def _verify_secret_criteria(self, proof: Proof) -> Literal[True]: def _verify_secret_criteria(self, proof: Proof) -> Literal[True]:
"""Verifies that a secret is present and is not too long (DOS prevention).""" """Verifies that a secret is present and is not too long (DOS prevention)."""