diff --git a/README.md b/README.md index 4291515..c87d2ff 100644 --- a/README.md +++ b/README.md @@ -199,6 +199,23 @@ MINT_REDIS_CACHE_URL=redis://localhost:6379 ### NUT-21 Authentication with Keycloak Cashu supports clear and blind authentication as defined in [NUT-21](https://github.com/cashubtc/nuts/blob/main/21.md) and [NUT-22](https://github.com/cashubtc/nuts/blob/main/22.md) to limit the use of a mint to a registered set of users. Clear authentication is supported via a OICD provider such as Keycloak. You can set up and run Keycloak instance using the docker compose file `docker/keycloak/docker-compose.yml` in this repository. +### Migrate SQLite mint DB to Postgres +Use the standalone tool at `cashu/mint/sqlite_to_postgres.py` to migrate a mint database from SQLite to Postgres. + +```bash +# 1) optionally reset the target Postgres database (DROPS ALL DATA) +psql -U -h -p -d -c "DROP SCHEMA public CASCADE; CREATE SCHEMA public; GRANT ALL PRIVILEGES ON SCHEMA public TO ;" + +# 2) run migration (inside poetry env) +poetry run python cashu/mint/sqlite_to_postgres.py \ + --sqlite data/mint/mint.sqlite3 \ + --postgres postgres://:@:/ \ + --batch-size 2000 +``` + +- The tool aborts if the Postgres DB appears populated and prints the exact reset command with your connection details. +- After copying, it verifies row counts and compares the `balance` view across both databases. + # Running tests To run the tests in this repository, first install the dev dependencies with ```bash diff --git a/cashu/core/base.py b/cashu/core/base.py index 870e589..f712359 100644 --- a/cashu/core/base.py +++ b/cashu/core/base.py @@ -231,6 +231,10 @@ class BlindedMessage(BaseModel): id: str # Keyset id B_: str # Hex-encoded blinded message + @classmethod + def from_row(cls, row: RowMapping): + return cls(amount=row["amount"], B_=row["b_"], id=row["id"]) + class BlindedMessage_Deprecated(BaseModel): """ @@ -300,7 +304,7 @@ class MeltQuote(LedgerEvent): mint: Optional[str] = None @classmethod - def from_row(cls, row: Row): + def from_row(cls, row: Row, change: Optional[List[BlindedSignature]] = None): try: created_time = int(row["created_time"]) if row["created_time"] else None paid_time = int(row["paid_time"]) if row["paid_time"] else None @@ -314,11 +318,6 @@ class MeltQuote(LedgerEvent): payment_preimage = row.get("payment_preimage") or row.get("proof") # type: ignore - # parse change from row as json - change = None - if "change" in row.keys() and row["change"]: - change = json.loads(row["change"]) - outputs = None if "outputs" in row.keys() and row["outputs"]: outputs = json.loads(row["outputs"]) diff --git a/cashu/core/db.py b/cashu/core/db.py index 47469a2..93ef0ff 100644 --- a/cashu/core/db.py +++ b/cashu/core/db.py @@ -7,7 +7,7 @@ from contextlib import asynccontextmanager from typing import Optional, Union from loguru import logger -from sqlalchemy import text +from sqlalchemy import event, text from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.orm import sessionmaker from sqlalchemy.pool import AsyncAdaptedQueuePool, NullPool @@ -142,6 +142,19 @@ class Database(Compat): kwargs["max_overflow"] = 100 # type: ignore[assignment] self.engine = create_async_engine(database_uri, **kwargs) + + # Ensure SQLite enforces foreign keys on every connection + if self.type == SQLITE: + + @event.listens_for(self.engine.sync_engine, "connect") + def _set_sqlite_pragma(dbapi_connection, connection_record): + try: + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON;") + cursor.close() + except Exception as e: + logger.warning(f"Could not enable SQLite PRAGMA foreign_keys: {e}") + self.async_session = sessionmaker( self.engine, # type: ignore expire_on_commit=False, diff --git a/cashu/mint/auth/server.py b/cashu/mint/auth/server.py index 0158c10..cdd964a 100644 --- a/cashu/mint/auth/server.py +++ b/cashu/mint/auth/server.py @@ -201,7 +201,8 @@ class AuthLedger(Ledger): ) await self._verify_outputs(outputs) - promises = await self._generate_promises(outputs) + await self._store_blinded_messages(outputs) + promises = await self._sign_blinded_messages(outputs) # update last_access timestamp of the user await self.auth_crud.update_user(user_id=user.id, db=self.db) diff --git a/cashu/mint/crud.py b/cashu/mint/crud.py index 28f2356..56e122b 100644 --- a/cashu/mint/crud.py +++ b/cashu/mint/crud.py @@ -6,6 +6,7 @@ from loguru import logger from ..core.base import ( Amount, + BlindedMessage, BlindedSignature, MeltQuote, MintBalanceLogEntry, @@ -151,19 +152,59 @@ class LedgerCrud(ABC): ) -> Tuple[Amount, Amount]: ... @abstractmethod - async def store_promise( + async def store_blinded_message( + self, + *, + db: Database, + amount: int, + b_: str, + id: str, + mint_id: Optional[str] = None, + melt_id: Optional[str] = None, + swap_id: Optional[str] = None, + conn: Optional[Connection] = None, + ) -> None: ... + + @abstractmethod + async def delete_blinded_messages_melt_id( + self, + *, + db: Database, + melt_id: str, + conn: Optional[Connection] = None, + ) -> None: ... + + @abstractmethod + async def update_blinded_message_signature( self, *, db: Database, amount: int, b_: str, c_: str, - id: str, e: str = "", s: str = "", conn: Optional[Connection] = None, ) -> None: ... + @abstractmethod + async def get_blinded_messages_melt_id( + self, + *, + db: Database, + melt_id: str, + conn: Optional[Connection] = None, + ) -> List[BlindedMessage]: ... + + @abstractmethod + async def get_blind_signatures_melt_id( + self, + *, + db: Database, + melt_id: str, + conn: Optional[Connection] = None, + ) -> List[BlindedSignature]: ... + @abstractmethod async def get_promise( self, @@ -285,32 +326,119 @@ class LedgerCrudSqlite(LedgerCrud): LedgerCrud (ABC): Abstract base class for LedgerCrud. """ - async def store_promise( + async def store_blinded_message( + self, + *, + db: Database, + amount: int, + b_: str, + id: str, + mint_id: Optional[str] = None, + melt_id: Optional[str] = None, + swap_id: Optional[str] = None, + conn: Optional[Connection] = None, + ) -> None: + await (conn or db).execute( + f""" + INSERT INTO {db.table_with_schema('promises')} + (amount, b_, id, created, mint_quote, melt_quote, swap_id) + VALUES (:amount, :b_, :id, :created, :mint_quote, :melt_quote, :swap_id) + """, + { + "amount": amount, + "b_": b_, + "id": id, + "created": db.to_timestamp(db.timestamp_now_str()), + "mint_quote": mint_id, + "melt_quote": melt_id, + "swap_id": swap_id, + }, + ) + + async def get_blinded_messages_melt_id( + self, + *, + db: Database, + melt_id: str, + conn: Optional[Connection] = None, + ) -> List[BlindedMessage]: + rows = await (conn or db).fetchall( + f""" + SELECT * from {db.table_with_schema('promises')} + WHERE melt_quote = :melt_id AND c_ IS NULL + """, + {"melt_id": melt_id}, + ) + return [BlindedMessage.from_row(r) for r in rows] if rows else [] + + async def get_blind_signatures_melt_id( + self, + *, + db: Database, + melt_id: str, + conn: Optional[Connection] = None, + ) -> List[BlindedSignature]: + rows = await (conn or db).fetchall( + f""" + SELECT * from {db.table_with_schema('promises')} + WHERE melt_quote = :melt_id AND c_ IS NOT NULL + """, + {"melt_id": melt_id}, + ) + return [BlindedSignature.from_row(r) for r in rows] if rows else [] # type: ignore + + async def delete_blinded_messages_melt_id( + self, + *, + db: Database, + melt_id: str, + conn: Optional[Connection] = None, + ) -> None: + """Deletes a blinded message (promise) that has not been signed yet (c_ is NULL) with the given quote ID.""" + await (conn or db).execute( + f""" + DELETE FROM {db.table_with_schema('promises')} + WHERE melt_quote = :melt_id AND c_ IS NULL + """, + { + "melt_id": melt_id, + }, + ) + + async def update_blinded_message_signature( self, *, db: Database, amount: int, b_: str, c_: str, - id: str, e: str = "", s: str = "", conn: Optional[Connection] = None, ) -> None: + existing = await (conn or db).fetchone( + f""" + SELECT * from {db.table_with_schema('promises')} + WHERE b_ = :b_ + """, + {"b_": str(b_)}, + ) + if existing is None: + raise ValueError("blinded message does not exist") + await (conn or db).execute( f""" - INSERT INTO {db.table_with_schema('promises')} - (amount, b_, c_, dleq_e, dleq_s, id, created) - VALUES (:amount, :b_, :c_, :dleq_e, :dleq_s, :id, :created) + UPDATE {db.table_with_schema('promises')} + SET amount = :amount, c_ = :c_, dleq_e = :dleq_e, dleq_s = :dleq_s, signed_at = :signed_at + WHERE b_ = :b_; """, { - "amount": amount, "b_": b_, + "amount": amount, "c_": c_, "dleq_e": e, "dleq_s": s, - "id": id, - "created": db.to_timestamp(db.timestamp_now_str()), + "signed_at": db.to_timestamp(db.timestamp_now_str()), }, ) @@ -324,7 +452,7 @@ class LedgerCrudSqlite(LedgerCrud): row = await (conn or db).fetchone( f""" SELECT * from {db.table_with_schema('promises')} - WHERE b_ = :b_ + WHERE b_ = :b_ AND c_ IS NOT NULL """, {"b_": str(b_)}, ) @@ -340,7 +468,7 @@ class LedgerCrudSqlite(LedgerCrud): rows = await (conn or db).fetchall( f""" SELECT * from {db.table_with_schema('promises')} - WHERE b_ IN ({','.join([f":b_{i}" for i in range(len(b_s))])}) + WHERE b_ IN ({','.join([f":b_{i}" for i in range(len(b_s))])}) AND c_ IS NOT NULL """, {f"b_{i}": b_s[i] for i in range(len(b_s))}, ) @@ -568,8 +696,8 @@ class LedgerCrudSqlite(LedgerCrud): await (conn or db).execute( f""" INSERT INTO {db.table_with_schema('melt_quotes')} - (quote, method, request, checking_id, unit, amount, fee_reserve, state, paid, created_time, paid_time, fee_paid, proof, outputs, change, expiry) - VALUES (:quote, :method, :request, :checking_id, :unit, :amount, :fee_reserve, :state, :paid, :created_time, :paid_time, :fee_paid, :proof, :outputs, :change, :expiry) + (quote, method, request, checking_id, unit, amount, fee_reserve, state, paid, created_time, paid_time, fee_paid, proof, expiry) + VALUES (:quote, :method, :request, :checking_id, :unit, :amount, :fee_reserve, :state, :paid, :created_time, :paid_time, :fee_paid, :proof, :expiry) """, { "quote": quote.quote, @@ -589,8 +717,6 @@ class LedgerCrudSqlite(LedgerCrud): ), "fee_paid": quote.fee_paid, "proof": quote.payment_preimage, - "outputs": json.dumps(quote.outputs) if quote.outputs else None, - "change": json.dumps(quote.change) if quote.change else None, "expiry": db.to_timestamp( db.timestamp_from_seconds(quote.expiry) or "" ), @@ -627,7 +753,14 @@ class LedgerCrudSqlite(LedgerCrud): """, values, ) - return MeltQuote.from_row(row) if row else None # type: ignore + + change = None + if row: + change = await self.get_blind_signatures_melt_id( + db=db, melt_id=row["quote"], conn=conn + ) + + return MeltQuote.from_row(row, change) if row else None # type: ignore async def get_melt_quote_by_request( self, @@ -654,7 +787,7 @@ class LedgerCrudSqlite(LedgerCrud): ) -> None: await (conn or db).execute( f""" - UPDATE {db.table_with_schema('melt_quotes')} SET state = :state, fee_paid = :fee_paid, paid_time = :paid_time, proof = :proof, outputs = :outputs, change = :change, checking_id = :checking_id WHERE quote = :quote + UPDATE {db.table_with_schema('melt_quotes')} SET state = :state, fee_paid = :fee_paid, paid_time = :paid_time, proof = :proof, checking_id = :checking_id WHERE quote = :quote """, { "state": quote.state.value, @@ -663,16 +796,6 @@ class LedgerCrudSqlite(LedgerCrud): db.timestamp_from_seconds(quote.paid_time) or "" ), "proof": quote.payment_preimage, - "outputs": ( - json.dumps([s.dict() for s in quote.outputs]) - if quote.outputs - else None - ), - "change": ( - json.dumps([s.dict() for s in quote.change]) - if quote.change - else None - ), "quote": quote.quote, "checking_id": quote.checking_id, }, diff --git a/cashu/mint/db/write.py b/cashu/mint/db/write.py index eeb38fe..66d6e30 100644 --- a/cashu/mint/db/write.py +++ b/cashu/mint/db/write.py @@ -3,7 +3,6 @@ from typing import Dict, List, Optional, Union from loguru import logger from ...core.base import ( - BlindedMessage, MeltQuote, MeltQuoteState, MintKeyset, @@ -198,9 +197,7 @@ class DbWriteHelper: await self.events.submit(quote) return quote - async def _set_melt_quote_pending( - self, quote: MeltQuote, outputs: Optional[List[BlindedMessage]] = None - ) -> MeltQuote: + async def _set_melt_quote_pending(self, quote: MeltQuote) -> MeltQuote: """Sets the melt quote as pending. Args: @@ -221,8 +218,6 @@ class DbWriteHelper: raise TransactionError("Melt quote already pending.") # set the quote as pending quote_copy.state = MeltQuoteState.pending - if outputs: - quote_copy.outputs = outputs await self.crud.update_melt_quote(quote=quote_copy, db=self.db, conn=conn) await self.events.submit(quote_copy) @@ -257,21 +252,25 @@ class DbWriteHelper: await self.events.submit(quote_copy) return quote_copy - async def _update_mint_quote_state( - self, quote_id: str, state: MintQuoteState - ): + async def _update_mint_quote_state(self, quote_id: str, state: MintQuoteState): async with self.db.get_connection(lock_table="mint_quotes") as conn: - mint_quote = await self.crud.get_mint_quote(quote_id=quote_id, db=self.db, conn=conn) + mint_quote = await self.crud.get_mint_quote( + quote_id=quote_id, db=self.db, conn=conn + ) if not mint_quote: raise TransactionError("Mint quote not found.") mint_quote.state = state await self.crud.update_mint_quote(quote=mint_quote, db=self.db, conn=conn) - + async def _update_melt_quote_state( - self, quote_id: str, state: MeltQuoteState, + self, + quote_id: str, + state: MeltQuoteState, ): async with self.db.get_connection(lock_table="melt_quotes") as conn: - melt_quote = await self.crud.get_melt_quote(quote_id=quote_id, db=self.db, conn=conn) + melt_quote = await self.crud.get_melt_quote( + quote_id=quote_id, db=self.db, conn=conn + ) if not melt_quote: raise TransactionError("Melt quote not found.") melt_quote.state = state diff --git a/cashu/mint/ledger.py b/cashu/mint/ledger.py index b04e036..c13e529 100644 --- a/cashu/mint/ledger.py +++ b/cashu/mint/ledger.py @@ -250,6 +250,7 @@ class Ledger( fee_provided: int, fee_paid: int, outputs: Optional[List[BlindedMessage]], + melt_id: Optional[str] = None, keyset: Optional[MintKeyset] = None, ) -> List[BlindedSignature]: """Generates a set of new promises (blinded signatures) from a set of blank outputs @@ -305,7 +306,10 @@ class Ledger( outputs[i].amount = return_amounts_sorted[i] # type: ignore if not self._verify_no_duplicate_outputs(outputs): raise TransactionError("duplicate promises.") - return_promises = await self._generate_promises(outputs, keyset) + return_promises = await self._sign_blinded_messages(outputs) + # delete remaining unsigned blank outputs from db + if melt_id: + await self.crud.delete_blinded_messages_melt_id(melt_id=melt_id, db=self.db) return return_promises # ------- TRANSACTIONS ------- @@ -491,8 +495,8 @@ class Ledger( raise TransactionError("quote expired") if not self._verify_mint_quote_witness(quote, outputs, signature): raise QuoteSignatureInvalidError() - - promises = await self._generate_promises(outputs) + await self._store_blinded_messages(outputs, mint_id=quote_id) + promises = await self._sign_blinded_messages(outputs) except Exception as e: await self.db_write._unset_mint_quote_pending( quote_id=quote_id, state=previous_state @@ -726,7 +730,10 @@ class Ledger( pending_proofs, keysets=self.keysets, conn=conn ) # change to compensate wallet for overpaid fees - if melt_quote.outputs: + melt_outputs = await self.crud.get_blinded_messages_melt_id( + melt_id=quote_id, db=self.db + ) + if melt_outputs: total_provided = sum_proofs(pending_proofs) input_fees = self.get_fees_for_proofs(pending_proofs) fee_reserve_provided = ( @@ -735,8 +742,9 @@ class Ledger( return_promises = await self._generate_change_promises( fee_provided=fee_reserve_provided, fee_paid=melt_quote.fee_paid, - outputs=melt_quote.outputs, - keyset=self.keysets[melt_quote.outputs[0].id], + outputs=melt_outputs, + melt_id=quote_id, + keyset=self.keysets[melt_outputs[0].id], ) melt_quote.change = return_promises await self.crud.update_melt_quote(quote=melt_quote, db=self.db) @@ -752,6 +760,9 @@ class Ledger( await self.db_write._unset_proofs_pending( pending_proofs, keysets=self.keysets ) + await self.crud.delete_blinded_messages_melt_id( + melt_id=quote_id, db=self.db + ) return melt_quote @@ -873,8 +884,6 @@ class Ledger( raise TransactionError( f"output unit {outputs_unit.name} does not match quote unit {melt_quote.unit}" ) - # we don't need to set it here, _set_melt_quote_pending will set it in the db - melt_quote.outputs = outputs # verify SIG_ALL signatures message_to_sign = ( @@ -907,7 +916,9 @@ class Ledger( proofs, keysets=self.keysets, quote_id=melt_quote.quote ) previous_state = melt_quote.state - melt_quote = await self.db_write._set_melt_quote_pending(melt_quote, outputs) + melt_quote = await self.db_write._set_melt_quote_pending(melt_quote) + if outputs: + await self._store_blinded_messages(outputs, melt_id=melt_quote.quote) # if the melt corresponds to an internal mint, mark both as paid melt_quote = await self.melt_mint_settle_internally(melt_quote, proofs) @@ -966,6 +977,9 @@ class Ledger( await self.db_write._unset_melt_quote_pending( quote=melt_quote, state=previous_state ) + await self.crud.delete_blinded_messages_melt_id( + melt_id=melt_quote.quote, db=self.db + ) if status.error_message: logger.error( f"Status check error: {status.error_message}" @@ -1011,6 +1025,7 @@ class Ledger( fee_provided=fee_reserve_provided, fee_paid=melt_quote.fee_paid, outputs=outputs, + melt_id=melt_quote.quote, keyset=self.keysets[outputs[0].id], ) @@ -1050,8 +1065,9 @@ class Ledger( ) try: async with self.db.get_connection(lock_table="proofs_pending") as conn: + await self._store_blinded_messages(outputs, keyset=keyset, conn=conn) await self._invalidate_proofs(proofs=proofs, conn=conn) - promises = await self._generate_promises(outputs, keyset, conn) + promises = await self._sign_blinded_messages(outputs, conn) except Exception as e: logger.trace(f"swap failed: {e}") raise e @@ -1081,10 +1097,47 @@ class Ledger( # ------- BLIND SIGNATURES ------- - async def _generate_promises( + async def _store_blinded_messages( self, outputs: List[BlindedMessage], keyset: Optional[MintKeyset] = None, + mint_id: Optional[str] = None, + melt_id: Optional[str] = None, + swap_id: Optional[str] = None, + conn: Optional[Connection] = None, + ) -> None: + """Stores a blinded message in the database. + + Args: + outputs (List[BlindedMessage]): Blinded messages to store. + keyset (Optional[MintKeyset], optional): Keyset to use. Uses default keyset if not given. Defaults to None. + conn: (Optional[Connection], optional): Database connection to reuse. Will create a new one if not given. Defaults to None. + """ + async with self.db.get_connection(conn) as conn: + for output in outputs: + keyset = keyset or self.keysets[output.id] + if output.id not in self.keysets: + raise TransactionError(f"keyset {output.id} not found") + if output.id != keyset.id: + raise TransactionError("keyset id does not match output id") + if not keyset.active: + raise TransactionError("keyset is not active") + logger.trace(f"Storing blinded message with keyset {keyset.id}.") + await self.crud.store_blinded_message( + id=keyset.id, + amount=output.amount, + b_=output.B_, + mint_id=mint_id, + melt_id=melt_id, + swap_id=swap_id, + db=self.db, + conn=conn, + ) + logger.trace(f"Stored blinded message for {output.amount}") + + async def _sign_blinded_messages( + self, + outputs: List[BlindedMessage], conn: Optional[Connection] = None, ) -> list[BlindedSignature]: """Generates a promises (Blind signatures) for given amount and returns a pair (amount, C'). @@ -1107,9 +1160,9 @@ class Ledger( ] = [] for output in outputs: B_ = PublicKey(bytes.fromhex(output.B_), raw=True) - keyset = keyset or self.keysets[output.id] if output.id not in self.keysets: raise TransactionError(f"keyset {output.id} not found") + keyset = self.keysets[output.id] if output.id != keyset.id: raise TransactionError("keyset id does not match output id") if not keyset.active: @@ -1127,9 +1180,8 @@ class Ledger( for promise in promises: keyset_id, B_, amount, C_, e, s = promise logger.trace(f"crud: _generate_promise storing promise for {amount}") - await self.crud.store_promise( + await self.crud.update_blinded_message_signature( amount=amount, - id=keyset_id, b_=B_.serialize().hex(), c_=C_.serialize().hex(), e=e.serialize(), diff --git a/cashu/mint/migrations.py b/cashu/mint/migrations.py index 8824e0e..0dee5b4 100644 --- a/cashu/mint/migrations.py +++ b/cashu/mint/migrations.py @@ -1,4 +1,5 @@ import copy +import json from typing import List from sqlalchemy import RowMapping @@ -26,10 +27,10 @@ async def m001_initial(db: Database): f""" CREATE TABLE IF NOT EXISTS {db.table_with_schema('promises')} ( amount {db.big_int} NOT NULL, - b_b TEXT NOT NULL, - c_b TEXT NOT NULL, + b_ TEXT NOT NULL, + c_ TEXT NOT NULL, - UNIQUE (b_b) + UNIQUE (b_) ); """ @@ -52,11 +53,11 @@ async def m001_initial(db: Database): f""" CREATE TABLE IF NOT EXISTS {db.table_with_schema('invoices')} ( amount {db.big_int} NOT NULL, - pr TEXT NOT NULL, - hash TEXT NOT NULL, + bolt11 TEXT NOT NULL, + id TEXT NOT NULL, issued BOOL NOT NULL, - UNIQUE (hash) + UNIQUE (id) ); """ @@ -78,7 +79,7 @@ async def create_balance_views(db: Database, conn: Connection): SELECT id AS keyset, COALESCE(s, 0) AS balance FROM ( SELECT id, SUM(amount) AS s FROM {db.table_with_schema('promises')} - WHERE amount > 0 + WHERE amount > 0 AND c_ IS NOT NULL GROUP BY id ) AS balance_issued; """ @@ -191,7 +192,7 @@ async def m006_invoices_add_payment_hash(db: Database): " TEXT" ) await conn.execute( - f"UPDATE {db.table_with_schema('invoices')} SET payment_hash = hash" + f"UPDATE {db.table_with_schema('invoices')} SET payment_hash = id" ) @@ -230,16 +231,6 @@ async def m008_promises_dleq(db: Database): async def m009_add_out_to_invoices(db: Database): # column in invoices for marking whether the invoice is incoming (out=False) or outgoing (out=True) async with db.connect() as conn: - # rename column pr to bolt11 - await conn.execute( - f"ALTER TABLE {db.table_with_schema('invoices')} RENAME COLUMN pr TO" - " bolt11" - ) - # rename column hash to payment_hash - await conn.execute( - f"ALTER TABLE {db.table_with_schema('invoices')} RENAME COLUMN hash TO id" - ) - await conn.execute( f"ALTER TABLE {db.table_with_schema('invoices')} ADD COLUMN out BOOL" ) @@ -720,7 +711,7 @@ async def m017_foreign_keys_proof_tables(db: Database): ) await conn.execute( - f"INSERT INTO {db.table_with_schema('promises_new')} (amount, id, b_, c_, dleq_e, dleq_s, created) SELECT amount, id, b_b, c_b, e, s, created FROM {db.table_with_schema('promises')}" + f"INSERT INTO {db.table_with_schema('promises_new')} (amount, id, b_, c_, dleq_e, dleq_s, created) SELECT amount, id, b_, c_, e, s, created FROM {db.table_with_schema('promises')}" ) await conn.execute(f"DROP TABLE {db.table_with_schema('promises')}") await conn.execute( @@ -968,3 +959,170 @@ async def m027_add_balance_to_keysets_and_log_table(db: Database): ); """ ) + + +async def m028_promises_c_allow_null_add_melt_quote(db: Database): + """ + Allow column that stores the c_ to be NULL and add melt_quote to promises. + Insert all change promises from melt_quotes into the promises table. + Drop the change and the outputs columns from melt_quotes. + """ + + # migrate stored melt outputs for pending quotes into promises + async def migrate_stored_melt_outputs_for_pending_quotes( + db: Database, conn: Connection + ): + rows = await conn.fetchall( + f""" + SELECT quote, outputs FROM {db.table_with_schema('melt_quotes')} + WHERE state = :state AND outputs IS NOT NULL + """, + {"state": MeltQuoteState.pending.value}, + ) + for row in rows: + try: + outputs = json.loads(row["outputs"]) if row["outputs"] else [] + except Exception: + outputs = [] + + for o in outputs: + amount = o.get("amount") if isinstance(o, dict) else None + keyset_id = o.get("id") if isinstance(o, dict) else None + b_hex = o.get("B_") if isinstance(o, dict) else None + if amount is None or keyset_id is None or b_hex is None: + continue + await conn.execute( + f""" + INSERT INTO {db.table_with_schema('promises')} + (amount, id, b_, created, mint_quote, melt_quote, swap_id) + VALUES (:amount, :id, :b_, :created, :mint_quote, :melt_quote, :swap_id) + """, + { + "amount": int(amount), + "id": keyset_id, + "b_": b_hex, + "created": db.to_timestamp(db.timestamp_now_str()), + "mint_quote": None, + "melt_quote": row["quote"], + "swap_id": None, + }, + ) + + # remove obsolete columns outputs and change from melt_quotes + async def remove_obsolete_columns_from_melt_quotes(db: Database, conn: Connection): + if conn.type == "SQLITE": + # For SQLite, recreate table without the columns + await conn.execute("PRAGMA foreign_keys=OFF;") + await conn.execute( + f""" + CREATE TABLE IF NOT EXISTS {db.table_with_schema('melt_quotes_new')} ( + quote TEXT NOT NULL, + method TEXT NOT NULL, + request TEXT NOT NULL, + checking_id TEXT NOT NULL, + unit TEXT NOT NULL, + amount {db.big_int} NOT NULL, + fee_reserve {db.big_int}, + paid BOOL NOT NULL, + created_time TIMESTAMP, + paid_time TIMESTAMP, + fee_paid {db.big_int}, + proof TEXT, + state TEXT, + expiry TIMESTAMP, + + UNIQUE (quote) + ); + """ + ) + await conn.execute( + f""" + INSERT INTO {db.table_with_schema('melt_quotes_new')} ( + quote, method, request, checking_id, unit, amount, fee_reserve, paid, created_time, paid_time, fee_paid, proof, state, expiry + ) + SELECT quote, method, request, checking_id, unit, amount, fee_reserve, paid, created_time, paid_time, fee_paid, proof, state, expiry + FROM {db.table_with_schema('melt_quotes')}; + """ + ) + await conn.execute(f"DROP TABLE {db.table_with_schema('melt_quotes')}") + await conn.execute( + f"ALTER TABLE {db.table_with_schema('melt_quotes_new')} RENAME TO {db.table_with_schema('melt_quotes')}" + ) + await conn.execute("PRAGMA foreign_keys=ON;") + else: + # For Postgres/Cockroach, drop the columns directly if they exist + await conn.execute( + f"ALTER TABLE {db.table_with_schema('melt_quotes')} DROP COLUMN IF EXISTS outputs" + ) + await conn.execute( + f"ALTER TABLE {db.table_with_schema('melt_quotes')} DROP COLUMN IF EXISTS change" + ) + + # recreate promises table with columns mint_quote, melt_quote, swap_id and with c_ nullable + async def recreate_promises_table(db: Database, conn: Connection): + if conn.type == "SQLITE": + await conn.execute("PRAGMA foreign_keys=OFF;") + await conn.execute( + f""" + CREATE TABLE IF NOT EXISTS {db.table_with_schema('promises_new')} ( + amount {db.big_int} NOT NULL, + id TEXT, + b_ TEXT NOT NULL, + c_ TEXT, + dleq_e TEXT, + dleq_s TEXT, + created TIMESTAMP, + signed_at TIMESTAMP, + mint_quote TEXT, + melt_quote TEXT, + swap_id TEXT, + + FOREIGN KEY (mint_quote) REFERENCES {db.table_with_schema('mint_quotes')}(quote), + FOREIGN KEY (melt_quote) REFERENCES {db.table_with_schema('melt_quotes')}(quote), + + UNIQUE (b_) + ); + """ + ) + + await conn.execute( + f"INSERT INTO {db.table_with_schema('promises_new')} (amount, id, b_, c_, dleq_e, dleq_s, created, mint_quote, swap_id) " + f"SELECT amount, id, b_, c_, dleq_e, dleq_s, created, mint_quote, swap_id FROM {db.table_with_schema('promises')}" + ) + + await conn.execute(f"DROP TABLE {db.table_with_schema('promises')}") + await conn.execute( + f"ALTER TABLE {db.table_with_schema('promises_new')} RENAME TO {db.table_with_schema('promises')}" + ) + await conn.execute("PRAGMA foreign_keys=ON;") + else: + # add columns melt_quote, signed_at and make column c_ nullable + await conn.execute( + f"ALTER TABLE {db.table_with_schema('promises')} ADD COLUMN melt_quote TEXT" + ) + await conn.execute( + f"ALTER TABLE {db.table_with_schema('promises')} ADD COLUMN signed_at TIMESTAMP" + ) + await conn.execute( + f"ALTER TABLE {db.table_with_schema('promises')} ALTER COLUMN c_ DROP NOT NULL" + ) + # add foreign key constraint to melt_quote + await conn.execute( + f"ALTER TABLE {db.table_with_schema('promises')} ADD CONSTRAINT fk_promises_melt_quote FOREIGN KEY (melt_quote) REFERENCES {db.table_with_schema('melt_quotes')}(quote)" + ) + + async with db.connect() as conn: + # drop the balance views first + await drop_balance_views(db, conn) + + # recreate promises table + await recreate_promises_table(db, conn) + + # migrate stored melt outputs for pending quotes into promises + await migrate_stored_melt_outputs_for_pending_quotes(db, conn) + + # remove obsolete columns from melt_quotes table + await remove_obsolete_columns_from_melt_quotes(db, conn) + + # recreate the balance views + await create_balance_views(db, conn) diff --git a/cashu/mint/sqlite_to_postgres.py b/cashu/mint/sqlite_to_postgres.py new file mode 100644 index 0000000..a760e19 --- /dev/null +++ b/cashu/mint/sqlite_to_postgres.py @@ -0,0 +1,338 @@ +import argparse +import asyncio +import datetime +import os +import re +import sqlite3 +from typing import Any, Dict, Iterable, List, Optional, Tuple +from urllib.parse import urlparse + +# Reuse project DB and migrations to create target schema +from cashu.core.db import Database +from cashu.core.migrations import migrate_databases +from cashu.mint import migrations as mint_migrations + +DEFAULT_BATCH_SIZE = 1000 + + +def _is_int_string(value: str) -> bool: + return bool(re.fullmatch(r"\d+", value)) + + +def _convert_value(value: Any, decl_type: Optional[str]) -> Any: + if value is None: + return None + if not decl_type: + return value + dtype = decl_type.upper() + + if "TIMESTAMP" in dtype: + # SQLite stores timestamps as INT seconds or formatted strings + if isinstance(value, (int, float)): + return datetime.datetime.fromtimestamp(int(value)) + if isinstance(value, str): + if _is_int_string(value): + return datetime.datetime.fromtimestamp(int(value)) + # try parse common format; fallback to raw string + for fmt in ("%Y-%m-%d %H:%M:%S", "%Y-%m-%d %H:%M:%S.%f"): + try: + return datetime.datetime.strptime(value, fmt) + except Exception: + pass + return value + return value + + if dtype in {"BOOL", "BOOLEAN"}: + if isinstance(value, (int, float)): + return bool(value) + if isinstance(value, str) and value.lower() in {"0", "1", "true", "false"}: + return value.lower() in {"1", "true"} + return bool(value) + + # BIGINT/INT: leave as-is; asyncpg will coerce ints + return value + + +def _get_sqlite_tables(conn: sqlite3.Connection) -> List[Tuple[str, str]]: + cur = conn.cursor() + # exclude sqlite internal tables + rows = cur.execute( + "SELECT name, type FROM sqlite_master WHERE type IN ('table','view') AND name NOT LIKE 'sqlite_%' ORDER BY type, name" + ).fetchall() + return [(r[0], r[1]) for r in rows] + + +def _get_table_columns( + conn: sqlite3.Connection, table: str +) -> List[Tuple[str, Optional[str]]]: + cur = conn.cursor() + rows = cur.execute(f"PRAGMA table_info({table})").fetchall() + # rows: cid, name, type, notnull, dflt_value, pk + return [(r[1], r[2]) for r in rows] + + +def _iter_sqlite_rows( + conn: sqlite3.Connection, table: str, batch_size: int +) -> Iterable[List[sqlite3.Row]]: + cur = conn.cursor() + cur.execute(f"SELECT * FROM {table}") + while True: + rows = cur.fetchmany(batch_size) + if not rows: + break + yield rows + + +def _prepare_insert_sql(table: str, columns: List[str]) -> str: + cols = ", ".join(columns) + params = ", ".join(f":{c}" for c in columns) + # Use ON CONFLICT DO NOTHING to make script idempotent on empty DBs + return f"INSERT INTO {table} ({cols}) VALUES ({params}) ON CONFLICT DO NOTHING" + + +async def _ensure_target_schema(pg_url: str) -> Database: + db = Database("mint", pg_url) + await migrate_databases(db, mint_migrations) + return db + + +async def _pg_table_row_count(db: Database, table: str) -> int: + try: + async with db.connect() as conn: + r = await conn.fetchone(f"SELECT COUNT(*) AS c FROM {table}") + return int(r["c"]) if r else 0 + except Exception: + return 0 + + +def _sqlite_table_row_count(conn: sqlite3.Connection, table: str) -> int: + try: + cur = conn.cursor() + return int(cur.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0]) + except Exception: + return 0 + + +async def _precheck_postgres_populated( + pg_url: str, candidate_tables: List[str] +) -> Optional[str]: + db = Database("mint", pg_url) + populated: List[Tuple[str, int]] = [] + for t in candidate_tables: + cnt = await _pg_table_row_count(db, t) + if cnt > 0: + populated.append((t, cnt)) + + if populated: + url = urlparse(pg_url.replace("postgresql+asyncpg://", "postgres://")) + user = url.username or "" + host = url.hostname or "localhost" + port = url.port or 5432 + dbname = (url.path or "/").lstrip("/") or "" + details = ", ".join(f"{t}={c}" for t, c in populated) + info = ( + "Target Postgres database appears to be populated; aborting migration to avoid corruption.\n" + f"Detected rows: {details}.\n" + "To reset the database, connect as the proper user and run:\n" + f'psql -U {user} -h {host} -p {port} -d {dbname} -c "DROP SCHEMA public CASCADE; CREATE SCHEMA public; GRANT ALL PRIVILEGES ON SCHEMA public TO {user};"' + ) + return info + return None + + +async def _compare_balance_views( + sqlite_conn: sqlite3.Connection, pg_db: Database +) -> Tuple[bool, str]: + # Read SQLite balance view + try: + s_rows = sqlite_conn.execute("SELECT keyset, balance FROM balance").fetchall() + sqlite_map = {str(r[0]): int(r[1]) for r in s_rows} + except Exception as e: + return False, f"Failed reading SQLite balance view: {e}" + + # Read Postgres balance view + try: + async with pg_db.connect() as conn: + p_rows = await conn.fetchall("SELECT keyset, balance FROM balance") + pg_map = {str(r["keyset"]): int(r["balance"]) for r in p_rows} + except Exception as e: + return False, f"Failed reading Postgres balance view: {e}" + + if sqlite_map == pg_map: + return True, "Balance views match" + + # Summarize differences + diffs = [] + keys = set(sqlite_map) | set(pg_map) + for k in sorted(keys): + sv = sqlite_map.get(k) + pv = pg_map.get(k) + if sv != pv: + diffs.append(f"{k}: sqlite={sv} postgres={pv}") + if len(diffs) >= 10: + diffs.append("…") + break + return False, "Balance view differs: " + "; ".join(diffs) + + +async def _copy_table( + sqlite_conn: sqlite3.Connection, + pg_db: Database, + table: str, + batch_size: int, +) -> int: + # views are skipped; ensure table exists on target + columns_with_types = _get_table_columns(sqlite_conn, table) + if not columns_with_types: + return 0 + columns = [name for name, _ in columns_with_types] + insert_sql = _prepare_insert_sql(table, columns) + + total = 0 + total_rows = _sqlite_table_row_count(sqlite_conn, table) + printed_done = False + # commit per batch to avoid gigantic transactions + for batch in _iter_sqlite_rows(sqlite_conn, table, batch_size): + payload: List[Dict[str, Any]] = [] + for row in batch: + row_dict = {columns[i]: row[i] for i in range(len(columns))} + normalized: Dict[str, Any] = {} + for col, decl_type in columns_with_types: + normalized[col] = _convert_value(row_dict.get(col), decl_type) + payload.append(normalized) + + if not payload: + continue + async with pg_db.connect() as conn: # new txn per batch + await conn.execute(insert_sql, payload) + total += len(payload) + if total_rows: + pct = int(total * 100 / total_rows) + print(f"[{table}] {total}/{total_rows} ({pct}%)", end="\r", flush=True) + printed_done = True + if printed_done: + print("") + return total + + +def _ordered_tables(existing: Dict[str, str]) -> List[str]: + desired_order = [ + "keysets", + "mint_pubkeys", + "mint_quotes", + "melt_quotes", + "promises", + "proofs_used", + "proofs_pending", + "balance_log", + ] + # Filter desired order by presence + present_ordered = [ + t for t in desired_order if t in existing and existing[t] == "table" + ] + # Append any other base tables not covered yet + rest = [ + t + for t, typ in existing.items() + if typ == "table" and t not in present_ordered and t not in {"dbversions"} + ] + return present_ordered + rest + + +async def migrate_sqlite_to_postgres( + sqlite_path: str, pg_url: str, batch_size: int +) -> None: + if not os.path.exists(sqlite_path): + raise FileNotFoundError(f"SQLite file not found: {sqlite_path}") + + # 1) open sqlite + sqlite_conn = sqlite3.connect(sqlite_path) + sqlite_conn.row_factory = sqlite3.Row + + # decide which tables to check/copy + all_tables = _get_sqlite_tables(sqlite_conn) + table_map = {name: typ for name, typ in all_tables} + skip = {"dbversions", "balance", "balance_issued", "balance_redeemed"} + candidate_tables = [ + t for t, typ in table_map.items() if typ == "table" and t not in skip + ] + + # 2) precheck Postgres not populated + info = await _precheck_postgres_populated(pg_url, candidate_tables) + if info: + print(info) + sqlite_conn.close() + return + + # 3) ensure target schema on postgres + pg_db = await _ensure_target_schema(pg_url) + + # 4) inspect sqlite schema + ordered = _ordered_tables(table_map) + ordered = [t for t in ordered if t not in skip] + + # 5) copy data + for tbl in ordered: + print(f"Copying table: {tbl}") + count = await _copy_table(sqlite_conn, pg_db, tbl, batch_size) + print(f"Copied {count} rows from {tbl}") + + # 6) verification: compare table row counts and balance view + print("Verifying data integrity …") + mismatches: List[str] = [] + for tbl in ordered: + s_cnt = _sqlite_table_row_count(sqlite_conn, tbl) + p_cnt = await _pg_table_row_count(pg_db, tbl) + if s_cnt != p_cnt: + mismatches.append(f"{tbl}: sqlite={s_cnt} postgres={p_cnt}") + ok_balance, balance_msg = await _compare_balance_views(sqlite_conn, pg_db) + + # 7) finalize + await pg_db.engine.dispose() # close connections cleanly + sqlite_conn.close() + + if mismatches: + print("WARNING: Row count mismatches detected:") + for m in mismatches: + print(f" - {m}") + if not ok_balance: + print(f"WARNING: {balance_msg}") + + if not mismatches and ok_balance: + total_rows_copied = sum( + _sqlite_table_row_count(sqlite3.connect(sqlite_path), t) for t in ordered + ) + print( + "Migration successful: all row counts match and balance view is identical.\n" + f"Tables migrated: {len(ordered)}, total rows: {total_rows_copied}." + ) + else: + print("Migration completed with warnings. Review the messages above.") + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Migrate Cashu mint SQLite DB to Postgres" + ) + parser.add_argument("--sqlite", required=True, help="Path to mint.sqlite3 file") + parser.add_argument( + "--postgres", + required=True, + help="Postgres connection string, e.g. postgres://user:pass@host:5432/dbname", + ) + parser.add_argument( + "--batch-size", + type=int, + default=DEFAULT_BATCH_SIZE, + help=f"Batch size for inserts (default {DEFAULT_BATCH_SIZE})", + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + asyncio.run(migrate_sqlite_to_postgres(args.sqlite, args.postgres, args.batch_size)) + + +if __name__ == "__main__": + main() diff --git a/tests/conftest.py b/tests/conftest.py index ed2ecb5..bd29228 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -37,6 +37,8 @@ settings.fakewallet_brr = True settings.fakewallet_delay_outgoing_payment = 0 settings.fakewallet_delay_incoming_payment = 1 settings.fakewallet_stochastic_invoice = False +settings.lightning_fee_percent = 2.0 +settings.lightning_reserve_fee_min = 2000 # msat assert ( settings.mint_test_database != settings.mint_database ), "Test database is the same as the main database" diff --git a/tests/mint/test_mint.py b/tests/mint/test_mint.py index 80d923b..bbfd2c6 100644 --- a/tests/mint/test_mint.py +++ b/tests/mint/test_mint.py @@ -1,235 +1,365 @@ -from typing import List - -import pytest - -from cashu.core.base import BlindedMessage, Proof, Unit -from cashu.core.crypto.b_dhke import step1_alice -from cashu.core.helpers import calculate_number_of_blank_outputs -from cashu.core.models import PostMintQuoteRequest -from cashu.core.settings import settings -from cashu.mint.ledger import Ledger -from tests.helpers import pay_if_regtest - - -async def assert_err(f, msg): - """Compute f() and expect an error message 'msg'.""" - try: - await f - except Exception as exc: - assert exc.args[0] == msg, Exception( - f"Expected error: {msg}, got: {exc.args[0]}" - ) - - -def assert_amt(proofs: List[Proof], expected: int): - """Assert amounts the proofs contain.""" - assert [p.amount for p in proofs] == expected - - -@pytest.mark.asyncio -async def test_pubkeys(ledger: Ledger): - assert ledger.keyset.public_keys - assert ( - ledger.keyset.public_keys[1].serialize().hex() - == "02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104" - ) - assert ( - ledger.keyset.public_keys[2 ** (settings.max_order - 1)].serialize().hex() - == "023c84c0895cc0e827b348ea0a62951ca489a5e436f3ea7545f3c1d5f1bea1c866" - ) - - -@pytest.mark.asyncio -async def test_privatekeys(ledger: Ledger): - assert ledger.keyset.private_keys - assert ( - ledger.keyset.private_keys[1].serialize() - == "8300050453f08e6ead1296bb864e905bd46761beed22b81110fae0751d84604d" - ) - assert ( - ledger.keyset.private_keys[2 ** (settings.max_order - 1)].serialize() - == "b0477644cb3d82ffcc170bc0a76e0409727232e87c5ae51d64a259936228c7be" - ) - - -@pytest.mark.asyncio -async def test_keysets(ledger: Ledger): - assert len(ledger.keysets) - assert len(list(ledger.keysets.keys())) - assert ledger.keyset.id == "009a1f293253e41e" - - -@pytest.mark.asyncio -async def test_get_keyset(ledger: Ledger): - keyset = ledger.get_keyset() - assert isinstance(keyset, dict) - assert len(keyset) == settings.max_order - - -@pytest.mark.asyncio -async def test_mint(ledger: Ledger): - quote = await ledger.mint_quote(PostMintQuoteRequest(amount=8, unit="sat")) - await pay_if_regtest(quote.request) - blinded_messages_mock = [ - BlindedMessage( - amount=8, - B_="02634a2c2b34bec9e8a4aba4361f6bf202d7fa2365379b0840afe249a7a9d71239", - id="009a1f293253e41e", - ) - ] - promises = await ledger.mint(outputs=blinded_messages_mock, quote_id=quote.quote) - assert len(promises) - assert promises[0].amount == 8 - assert ( - promises[0].C_ - == "031422eeffb25319e519c68de000effb294cb362ef713a7cf4832cea7b0452ba6e" - ) - - -@pytest.mark.asyncio -async def test_mint_invalid_quote(ledger: Ledger): - await assert_err( - ledger.get_mint_quote(quote_id="invalid_quote_id"), - "quote not found", - ) - - -@pytest.mark.asyncio -async def test_melt_invalid_quote(ledger: Ledger): - await assert_err( - ledger.get_melt_quote(quote_id="invalid_quote_id"), - "quote not found", - ) - - -@pytest.mark.asyncio -async def test_mint_invalid_blinded_message(ledger: Ledger): - quote = await ledger.mint_quote(PostMintQuoteRequest(amount=8, unit="sat")) - await pay_if_regtest(quote.request) - blinded_messages_mock_invalid_key = [ - BlindedMessage( - amount=8, - B_="02634a2c2b34bec9e8a4aba4361f6bff02d7fa2365379b0840afe249a7a9d71237", - id="009a1f293253e41e", - ) - ] - await assert_err( - ledger.mint(outputs=blinded_messages_mock_invalid_key, quote_id=quote.quote), - "invalid public key", - ) - - -@pytest.mark.asyncio -async def test_generate_promises(ledger: Ledger): - blinded_messages_mock = [ - BlindedMessage( - amount=8, - B_="02634a2c2b34bec9e8a4aba4361f6bf202d7fa2365379b0840afe249a7a9d71239", - id="009a1f293253e41e", - ) - ] - promises = await ledger._generate_promises(blinded_messages_mock) - assert ( - promises[0].C_ - == "031422eeffb25319e519c68de000effb294cb362ef713a7cf4832cea7b0452ba6e" - ) - assert promises[0].amount == 8 - assert promises[0].id == "009a1f293253e41e" - - # DLEQ proof present - assert promises[0].dleq - assert promises[0].dleq.s - assert promises[0].dleq.e - - -@pytest.mark.asyncio -async def test_generate_change_promises(ledger: Ledger): - # Example slightly adapted from NUT-08 because we want to ensure the dynamic change - # token amount works: `n_blank_outputs != n_returned_promises != 4`. - # invoice_amount = 100_000 - fee_reserve = 2_000 - # total_provided = invoice_amount + fee_reserve - actual_fee = 100 - - expected_returned_promises = 7 # Amounts = [4, 8, 32, 64, 256, 512, 1024] - expected_returned_fees = 1900 - - n_blank_outputs = calculate_number_of_blank_outputs(fee_reserve) - blinded_msgs = [step1_alice(str(n)) for n in range(n_blank_outputs)] - outputs = [ - BlindedMessage( - amount=1, - B_=b.serialize().hex(), - id="009a1f293253e41e", - ) - for b, _ in blinded_msgs - ] - - promises = await ledger._generate_change_promises( - fee_provided=fee_reserve, fee_paid=actual_fee, outputs=outputs - ) - - assert len(promises) == expected_returned_promises - assert sum([promise.amount for promise in promises]) == expected_returned_fees - - -@pytest.mark.asyncio -async def test_generate_change_promises_legacy_wallet(ledger: Ledger): - # Check if mint handles a legacy wallet implementation (always sends 4 blank - # outputs) as well. - # invoice_amount = 100_000 - fee_reserve = 2_000 - # total_provided = invoice_amount + fee_reserve - actual_fee = 100 - - expected_returned_promises = 4 # Amounts = [64, 256, 512, 1024] - expected_returned_fees = 1856 - - n_blank_outputs = 4 - blinded_msgs = [step1_alice(str(n)) for n in range(n_blank_outputs)] - outputs = [ - BlindedMessage( - amount=1, - B_=b.serialize().hex(), - id="009a1f293253e41e", - ) - for b, _ in blinded_msgs - ] - - promises = await ledger._generate_change_promises(fee_reserve, actual_fee, outputs) - - assert len(promises) == expected_returned_promises - assert sum([promise.amount for promise in promises]) == expected_returned_fees - - -@pytest.mark.asyncio -async def test_generate_change_promises_returns_empty_if_no_outputs(ledger: Ledger): - # invoice_amount = 100_000 - fee_reserve = 1_000 - # total_provided = invoice_amount + fee_reserve - actual_fee_msat = 100_000 - outputs = None - - promises = await ledger._generate_change_promises( - fee_reserve, actual_fee_msat, outputs - ) - assert len(promises) == 0 - - -@pytest.mark.asyncio -async def test_get_balance(ledger: Ledger): - unit = Unit["sat"] - balance, fees_paid = await ledger.get_balance(unit) - assert balance == 0 - assert fees_paid == 0 - - -@pytest.mark.asyncio -async def test_maximum_balance(ledger: Ledger): - settings.mint_max_balance = 1000 - await ledger.mint_quote(PostMintQuoteRequest(amount=8, unit="sat")) - await assert_err( - ledger.mint_quote(PostMintQuoteRequest(amount=8000, unit="sat")), - "Mint has reached maximum balance.", - ) - settings.mint_max_balance = 0 +from typing import List + +import pytest + +from cashu.core.base import BlindedMessage, Proof, Unit +from cashu.core.crypto.b_dhke import step1_alice +from cashu.core.helpers import calculate_number_of_blank_outputs +from cashu.core.models import PostMeltQuoteRequest, PostMintQuoteRequest +from cashu.core.settings import settings +from cashu.mint.ledger import Ledger +from tests.helpers import pay_if_regtest + + +async def assert_err(f, msg): + """Compute f() and expect an error message 'msg'.""" + try: + await f + except Exception as exc: + assert exc.args[0] == msg, Exception( + f"Expected error: {msg}, got: {exc.args[0]}" + ) + + +def assert_amt(proofs: List[Proof], expected: int): + """Assert amounts the proofs contain.""" + assert [p.amount for p in proofs] == expected + + +@pytest.mark.asyncio +async def test_pubkeys(ledger: Ledger): + assert ledger.keyset.public_keys + assert ( + ledger.keyset.public_keys[1].serialize().hex() + == "02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104" + ) + assert ( + ledger.keyset.public_keys[2 ** (settings.max_order - 1)].serialize().hex() + == "023c84c0895cc0e827b348ea0a62951ca489a5e436f3ea7545f3c1d5f1bea1c866" + ) + + +@pytest.mark.asyncio +async def test_privatekeys(ledger: Ledger): + assert ledger.keyset.private_keys + assert ( + ledger.keyset.private_keys[1].serialize() + == "8300050453f08e6ead1296bb864e905bd46761beed22b81110fae0751d84604d" + ) + assert ( + ledger.keyset.private_keys[2 ** (settings.max_order - 1)].serialize() + == "b0477644cb3d82ffcc170bc0a76e0409727232e87c5ae51d64a259936228c7be" + ) + + +@pytest.mark.asyncio +async def test_keysets(ledger: Ledger): + assert len(ledger.keysets) + assert len(list(ledger.keysets.keys())) + assert ledger.keyset.id == "009a1f293253e41e" + + +@pytest.mark.asyncio +async def test_get_keyset(ledger: Ledger): + keyset = ledger.get_keyset() + assert isinstance(keyset, dict) + assert len(keyset) == settings.max_order + + +@pytest.mark.asyncio +async def test_mint(ledger: Ledger): + quote = await ledger.mint_quote(PostMintQuoteRequest(amount=8, unit="sat")) + await pay_if_regtest(quote.request) + blinded_messages_mock = [ + BlindedMessage( + amount=8, + B_="02634a2c2b34bec9e8a4aba4361f6bf202d7fa2365379b0840afe249a7a9d71239", + id="009a1f293253e41e", + ) + ] + promises = await ledger.mint(outputs=blinded_messages_mock, quote_id=quote.quote) + assert len(promises) + assert promises[0].amount == 8 + assert ( + promises[0].C_ + == "031422eeffb25319e519c68de000effb294cb362ef713a7cf4832cea7b0452ba6e" + ) + + +@pytest.mark.asyncio +async def test_mint_invalid_quote(ledger: Ledger): + await assert_err( + ledger.get_mint_quote(quote_id="invalid_quote_id"), + "quote not found", + ) + + +@pytest.mark.asyncio +async def test_melt_invalid_quote(ledger: Ledger): + await assert_err( + ledger.get_melt_quote(quote_id="invalid_quote_id"), + "quote not found", + ) + + +@pytest.mark.asyncio +async def test_mint_invalid_blinded_message(ledger: Ledger): + quote = await ledger.mint_quote(PostMintQuoteRequest(amount=8, unit="sat")) + await pay_if_regtest(quote.request) + blinded_messages_mock_invalid_key = [ + BlindedMessage( + amount=8, + B_="02634a2c2b34bec9e8a4aba4361f6bff02d7fa2365379b0840afe249a7a9d71237", + id="009a1f293253e41e", + ) + ] + await assert_err( + ledger.mint(outputs=blinded_messages_mock_invalid_key, quote_id=quote.quote), + "invalid public key", + ) + + +@pytest.mark.asyncio +async def test_generate_promises(ledger: Ledger): + blinded_messages_mock = [ + BlindedMessage( + amount=8, + B_="02634a2c2b34bec9e8a4aba4361f6bf202d7fa2365379b0840afe249a7a9d71239", + id="009a1f293253e41e", + ) + ] + await ledger._store_blinded_messages(blinded_messages_mock) + promises = await ledger._sign_blinded_messages(blinded_messages_mock) + assert ( + promises[0].C_ + == "031422eeffb25319e519c68de000effb294cb362ef713a7cf4832cea7b0452ba6e" + ) + assert promises[0].amount == 8 + assert promises[0].id == "009a1f293253e41e" + + # DLEQ proof present + assert promises[0].dleq + assert promises[0].dleq.s + assert promises[0].dleq.e + + +@pytest.mark.asyncio +async def test_generate_change_promises(ledger: Ledger): + # Example slightly adapted from NUT-08 because we want to ensure the dynamic change + # token amount works: `n_blank_outputs != n_returned_promises != 4`. + # invoice_amount = 100_000 + fee_reserve = 2_000 + # total_provided = invoice_amount + fee_reserve + actual_fee = 100 + + expected_returned_promises = 7 # Amounts = [4, 8, 32, 64, 256, 512, 1024] + expected_returned_fees = 1900 + + n_blank_outputs = calculate_number_of_blank_outputs(fee_reserve) + blinded_msgs = [step1_alice(str(n)) for n in range(n_blank_outputs)] + outputs = [ + BlindedMessage( + amount=1, + B_=b.serialize().hex(), + id="009a1f293253e41e", + ) + for b, _ in blinded_msgs + ] + await ledger._store_blinded_messages(outputs) + promises = await ledger._generate_change_promises( + fee_provided=fee_reserve, fee_paid=actual_fee, outputs=outputs + ) + + assert len(promises) == expected_returned_promises + assert sum([promise.amount for promise in promises]) == expected_returned_fees + + +@pytest.mark.asyncio +async def test_generate_change_promises_legacy_wallet(ledger: Ledger): + # Check if mint handles a legacy wallet implementation (always sends 4 blank + # outputs) as well. + # invoice_amount = 100_000 + fee_reserve = 2_000 + # total_provided = invoice_amount + fee_reserve + actual_fee = 100 + + expected_returned_promises = 4 # Amounts = [64, 256, 512, 1024] + expected_returned_fees = 1856 + + n_blank_outputs = 4 + blinded_msgs = [step1_alice(str(n)) for n in range(n_blank_outputs)] + outputs = [ + BlindedMessage( + amount=1, + B_=b.serialize().hex(), + id="009a1f293253e41e", + ) + for b, _ in blinded_msgs + ] + + await ledger._store_blinded_messages(outputs) + promises = await ledger._generate_change_promises(fee_reserve, actual_fee, outputs) + + assert len(promises) == expected_returned_promises + assert sum([promise.amount for promise in promises]) == expected_returned_fees + + +@pytest.mark.asyncio +async def test_generate_change_promises_returns_empty_if_no_outputs(ledger: Ledger): + # invoice_amount = 100_000 + fee_reserve = 1_000 + # total_provided = invoice_amount + fee_reserve + actual_fee_msat = 100_000 + outputs = None + + promises = await ledger._generate_change_promises( + fee_reserve, actual_fee_msat, outputs + ) + assert len(promises) == 0 + + +@pytest.mark.asyncio +async def test_get_balance(ledger: Ledger): + unit = Unit["sat"] + balance, fees_paid = await ledger.get_balance(unit) + assert balance == 0 + assert fees_paid == 0 + + +@pytest.mark.asyncio +async def test_maximum_balance(ledger: Ledger): + settings.mint_max_balance = 1000 + await ledger.mint_quote(PostMintQuoteRequest(amount=8, unit="sat")) + await assert_err( + ledger.mint_quote(PostMintQuoteRequest(amount=8000, unit="sat")), + "Mint has reached maximum balance.", + ) + settings.mint_max_balance = 0 + + +@pytest.mark.asyncio +async def test_generate_change_promises_signs_subset_and_deletes_rest(ledger: Ledger): + from cashu.core.base import BlindedMessage + from cashu.core.crypto.b_dhke import step1_alice + from cashu.core.split import amount_split + + # Create a real melt quote to satisfy FK on promises.melt_quote + mint_quote_resp = await ledger.mint_quote( + PostMintQuoteRequest(amount=64, unit="sat") + ) + melt_quote_resp = await ledger.melt_quote( + PostMeltQuoteRequest(request=mint_quote_resp.request, unit="sat") + ) + melt_id = melt_quote_resp.quote + fee_provided = 2_000 + fee_paid = 100 + overpaid_fee = fee_provided - fee_paid + return_amounts = amount_split(overpaid_fee) + + # Store more blank outputs than needed for the change. + extra_blanks = 3 + n_blank = len(return_amounts) + extra_blanks + blank_outputs = [ + BlindedMessage( + amount=1, + B_=step1_alice(f"change_blank_{i}")[0].serialize().hex(), + id=ledger.keyset.id, + ) + for i in range(n_blank) + ] + await ledger._store_blinded_messages(blank_outputs, melt_id=melt_id) + + # Fetch the stored unsigned blanks (same as melt flow) and run change generation. + stored_outputs = await ledger.crud.get_blinded_messages_melt_id( + db=ledger.db, melt_id=melt_id + ) + assert len(stored_outputs) == n_blank + + promises = await ledger._generate_change_promises( + fee_provided=fee_provided, + fee_paid=fee_paid, + outputs=stored_outputs, + melt_id=melt_id, + keyset=ledger.keyset, + ) + + assert len(promises) == len(return_amounts) + assert sorted(p.amount for p in promises) == sorted(return_amounts) + + # All unsigned blanks should be deleted after signing the subset. + remaining_unsigned = await ledger.crud.get_blinded_messages_melt_id( + db=ledger.db, melt_id=melt_id + ) + assert remaining_unsigned == [] + + # The signed promises should remain in the DB with c_ set. + async with ledger.db.connect() as conn: + rows = await conn.fetchall( + f""" + SELECT amount, c_ FROM {ledger.db.table_with_schema('promises')} + WHERE melt_quote = :melt_id + """, + {"melt_id": melt_id}, + ) + assert len(rows) == len(return_amounts) + assert all(row["c_"] for row in rows) + assert sorted(int(row["amount"]) for row in rows) == sorted(return_amounts) + + +@pytest.mark.asyncio +async def test_generate_change_promises_zero_fee_deletes_all_blanks(ledger: Ledger): + from cashu.core.base import BlindedMessage + from cashu.core.crypto.b_dhke import step1_alice + + # Create a real melt quote to satisfy FK on promises.melt_quote + mint_quote_resp = await ledger.mint_quote( + PostMintQuoteRequest(amount=64, unit="sat") + ) + melt_quote_resp = await ledger.melt_quote( + PostMeltQuoteRequest(request=mint_quote_resp.request, unit="sat") + ) + melt_id = melt_quote_resp.quote + fee_provided = 1_000 + fee_paid = 1_000 # no overpaid fee + n_blank = 4 + blank_outputs = [ + BlindedMessage( + amount=1, + B_=step1_alice(f"no_fee_blank_{i}")[0].serialize().hex(), + id=ledger.keyset.id, + ) + for i in range(n_blank) + ] + await ledger._store_blinded_messages(blank_outputs, melt_id=melt_id) + + stored_outputs = await ledger.crud.get_blinded_messages_melt_id( + db=ledger.db, melt_id=melt_id + ) + assert len(stored_outputs) == n_blank + + promises = await ledger._generate_change_promises( + fee_provided=fee_provided, + fee_paid=fee_paid, + outputs=stored_outputs, + melt_id=melt_id, + keyset=ledger.keyset, + ) + + assert promises == [] + + remaining_unsigned = await ledger.crud.get_blinded_messages_melt_id( + db=ledger.db, melt_id=melt_id + ) + # With zero fee nothing is signed or deleted; blanks stay pending. + assert len(remaining_unsigned) == n_blank + + async with ledger.db.connect() as conn: + rows = await conn.fetchall( + f""" + SELECT amount, c_ FROM {ledger.db.table_with_schema('promises')} + WHERE melt_quote = :melt_id + """, + {"melt_id": melt_id}, + ) + assert len(rows) == n_blank + assert all(row["c_"] is None for row in rows) diff --git a/tests/mint/test_mint_db.py b/tests/mint/test_mint_db.py index 5db7a6a..9004ac5 100644 --- a/tests/mint/test_mint_db.py +++ b/tests/mint/test_mint_db.py @@ -23,9 +23,7 @@ from tests.helpers import ( pay_if_regtest, ) -payment_request = ( - "lnbc1u1p5qeft3sp5jn5cqclnxvucfqtjm8qnlar2vhevcuudpccv7tsuglruj3qm579spp5ygdhy0t7xu53myke8z3z024xhz4kzgk9fcqk64sp0fyeqzhmaswqdqqcqpjrzjq0euzzxv65mts5ngg8c2t3vzz2aeuevy5845jvyqulqucd8c9kkhzrtp55qq63qqqqqqqqqqqqqzwyqqyg9qxpqysgqscprcpnk8whs3askqhgu6z5a4hupyn8du2aahdcf00s5pxrs4g94sv9f95xdn4tu0wec7kfyzj439wu9z27k6m6e3q4ysjquf5agx7gp0eeye4" -) +payment_request = "lnbc1u1p5qeft3sp5jn5cqclnxvucfqtjm8qnlar2vhevcuudpccv7tsuglruj3qm579spp5ygdhy0t7xu53myke8z3z024xhz4kzgk9fcqk64sp0fyeqzhmaswqdqqcqpjrzjq0euzzxv65mts5ngg8c2t3vzz2aeuevy5845jvyqulqucd8c9kkhzrtp55qq63qqqqqqqqqqqqqzwyqqyg9qxpqysgqscprcpnk8whs3askqhgu6z5a4hupyn8du2aahdcf00s5pxrs4g94sv9f95xdn4tu0wec7kfyzj439wu9z27k6m6e3q4ysjquf5agx7gp0eeye4" @pytest_asyncio.fixture(scope="function") @@ -295,30 +293,51 @@ async def test_db_events_add_client(wallet: Wallet, ledger: Ledger): # remove subscription client.remove_subscription("subId") + @pytest.mark.asyncio async def test_db_update_mint_quote_state(wallet: Wallet, ledger: Ledger): mint_quote = await wallet.request_mint(128) - await ledger.db_write._update_mint_quote_state(mint_quote.quote, MintQuoteState.paid) - - mint_quote_db = await ledger.crud.get_mint_quote(quote_id=mint_quote.quote, db=ledger.db) + await ledger.db_write._update_mint_quote_state( + mint_quote.quote, MintQuoteState.paid + ) + + mint_quote_db = await ledger.crud.get_mint_quote( + quote_id=mint_quote.quote, db=ledger.db + ) + assert mint_quote_db assert mint_quote_db.state == MintQuoteState.paid # Update it to issued - await ledger.db_write._update_mint_quote_state(mint_quote_db.quote, MintQuoteState.issued) + await ledger.db_write._update_mint_quote_state( + mint_quote_db.quote, MintQuoteState.issued + ) # Try and revert it back to unpaid - await assert_err(ledger.db_write._update_mint_quote_state(mint_quote_db.quote, MintQuoteState.unpaid), "Cannot change state of an issued mint quote.") + await assert_err( + ledger.db_write._update_mint_quote_state( + mint_quote_db.quote, MintQuoteState.unpaid + ), + "Cannot change state of an issued mint quote.", + ) + @pytest.mark.asyncio -@pytest.mark.skipif( - is_deprecated_api_only, - reason=("Deprecated API") -) +@pytest.mark.skipif(is_deprecated_api_only, reason=("Deprecated API")) async def test_db_update_melt_quote_state(wallet: Wallet, ledger: Ledger): melt_quote = await wallet.melt_quote(payment_request) - await ledger.db_write._update_melt_quote_state(melt_quote.quote, MeltQuoteState.paid) + await ledger.db_write._update_melt_quote_state( + melt_quote.quote, MeltQuoteState.paid + ) - melt_quote_db = await ledger.crud.get_melt_quote(quote_id=melt_quote.quote, db=ledger.db) + melt_quote_db = await ledger.crud.get_melt_quote( + quote_id=melt_quote.quote, db=ledger.db + ) + assert melt_quote_db assert melt_quote_db.state == MeltQuoteState.paid - await assert_err(ledger.db_write._update_melt_quote_state(melt_quote.quote, MeltQuoteState.unpaid), "Cannot change state of a paid melt quote.") \ No newline at end of file + await assert_err( + ledger.db_write._update_melt_quote_state( + melt_quote.quote, MeltQuoteState.unpaid + ), + "Cannot change state of a paid melt quote.", + ) diff --git a/tests/mint/test_mint_db_operations.py b/tests/mint/test_mint_db_operations.py index c6e58af..d81a3e5 100644 --- a/tests/mint/test_mint_db_operations.py +++ b/tests/mint/test_mint_db_operations.py @@ -10,6 +10,7 @@ import pytest_asyncio from cashu.core import db from cashu.core.db import Connection from cashu.core.migrations import backup_database +from cashu.core.models import PostMeltQuoteRequest from cashu.core.settings import settings from cashu.mint.ledger import Ledger from cashu.wallet.wallet import Wallet @@ -353,3 +354,402 @@ async def test_db_lock_table(wallet: Wallet, ledger: Ledger): ), "failed to acquire database lock", ) + + +@pytest.mark.asyncio +async def test_store_and_sign_blinded_message(ledger: Ledger): + # Localized imports to avoid polluting module scope + from cashu.core.crypto.b_dhke import step1_alice, step2_bob + from cashu.core.crypto.secp import PublicKey + + # Arrange: prepare a blinded message tied to current active keyset + amount = 8 + keyset_id = ledger.keyset.id + B_pubkey, _ = step1_alice("test_store_and_sign_blinded_message") + B_hex = B_pubkey.serialize().hex() + + # Act: store the blinded message (unsinged promise row) + await ledger.crud.store_blinded_message( + db=ledger.db, + amount=amount, + b_=B_hex, + id=keyset_id, + ) + + # Act: compute a valid blind signature for the stored row and persist it + private_key_amount = ledger.keyset.private_keys[amount] + B_point = PublicKey(bytes.fromhex(B_hex), raw=True) + C_point, e, s = step2_bob(B_point, private_key_amount) + + await ledger.crud.update_blinded_message_signature( + db=ledger.db, + amount=amount, + b_=B_hex, + c_=C_point.serialize().hex(), + e=e.serialize(), + s=s.serialize(), + ) + + # Assert: row is now a full promise and can be read back via get_promise + promise = await ledger.crud.get_promise(db=ledger.db, b_=B_hex) + assert promise is not None + assert promise.amount == amount + assert promise.C_ == C_point.serialize().hex() + assert promise.id == keyset_id + + +@pytest.mark.asyncio +async def test_get_blinded_messages_by_melt_id(wallet: Wallet, ledger: Ledger): + # Arrange + from cashu.core.crypto.b_dhke import step1_alice + + amount = 8 + keyset_id = ledger.keyset.id + # Create a real melt quote to satisfy FK on promises.melt_quote + mint_quote = await wallet.request_mint(64) + melt_quote = await ledger.melt_quote( + PostMeltQuoteRequest(request=mint_quote.request, unit="sat") + ) + melt_id = melt_quote.quote + + # Create two blinded messages + B1, _ = step1_alice("get_by_melt_id_1") + B2, _ = step1_alice("get_by_melt_id_2") + b1_hex = B1.serialize().hex() + b2_hex = B2.serialize().hex() + + # Persist as unsigned messages with proper melt_id FK + await ledger.crud.store_blinded_message( + db=ledger.db, amount=amount, b_=b1_hex, id=keyset_id, melt_id=melt_id + ) + await ledger.crud.store_blinded_message( + db=ledger.db, amount=amount, b_=b2_hex, id=keyset_id, melt_id=melt_id + ) + + # Act + rows = await ledger.crud.get_blinded_messages_melt_id(db=ledger.db, melt_id=melt_id) + + # Assert + assert len(rows) == 2 + assert {r.B_ for r in rows} == {b1_hex, b2_hex} + assert all(r.id == keyset_id for r in rows) + + +@pytest.mark.asyncio +async def test_delete_blinded_messages_by_melt_id(wallet: Wallet, ledger: Ledger): + from cashu.core.crypto.b_dhke import step1_alice + + amount = 4 + keyset_id = ledger.keyset.id + # Create a real melt quote to satisfy FK on promises.melt_quote + mint_quote = await wallet.request_mint(64) + melt_quote = await ledger.melt_quote( + PostMeltQuoteRequest(request=mint_quote.request, unit="sat") + ) + melt_id = melt_quote.quote + + # Create two blinded messages + B1, _ = step1_alice("delete_by_melt_id_1") + B2, _ = step1_alice("delete_by_melt_id_2") + b1_hex = B1.serialize().hex() + b2_hex = B2.serialize().hex() + + # Persist as unsigned messages + await ledger.crud.store_blinded_message( + db=ledger.db, amount=amount, b_=b1_hex, id=keyset_id, melt_id=melt_id + ) + await ledger.crud.store_blinded_message( + db=ledger.db, amount=amount, b_=b2_hex, id=keyset_id, melt_id=melt_id + ) + + rows_before = await ledger.crud.get_blinded_messages_melt_id( + db=ledger.db, melt_id=melt_id + ) + assert len(rows_before) == 2 + + # Act: delete all unsigned messages for this melt_id + await ledger.crud.delete_blinded_messages_melt_id(db=ledger.db, melt_id=melt_id) + + # Assert: now none left for that melt_id + rows_after = await ledger.crud.get_blinded_messages_melt_id( + db=ledger.db, melt_id=melt_id + ) + assert rows_after == [] + + +@pytest.mark.asyncio +async def test_get_blinded_messages_by_melt_id_filters_signed( + wallet: Wallet, ledger: Ledger +): + from cashu.core.crypto.b_dhke import step1_alice, step2_bob + from cashu.core.crypto.secp import PublicKey + + amount = 2 + keyset_id = ledger.keyset.id + # Create a real melt quote to satisfy FK on promises.melt_quote + mint_quote = await wallet.request_mint(64) + melt_quote = await ledger.melt_quote( + PostMeltQuoteRequest(request=mint_quote.request, unit="sat") + ) + melt_id = melt_quote.quote + + B1, _ = step1_alice("filter_by_melt_id_1") + B2, _ = step1_alice("filter_by_melt_id_2") + b1_hex = B1.serialize().hex() + b2_hex = B2.serialize().hex() + + # Persist two unsigned messages + await ledger.crud.store_blinded_message( + db=ledger.db, amount=amount, b_=b1_hex, id=keyset_id, melt_id=melt_id + ) + await ledger.crud.store_blinded_message( + db=ledger.db, amount=amount, b_=b2_hex, id=keyset_id, melt_id=melt_id + ) + + # Sign one of them (it should no longer be returned by get_blinded_messages_melt_id which filters c_ IS NULL) + priv = ledger.keyset.private_keys[amount] + C_point, e, s = step2_bob(PublicKey(bytes.fromhex(b1_hex), raw=True), priv) + await ledger.crud.update_blinded_message_signature( + db=ledger.db, + amount=amount, + b_=b1_hex, + c_=C_point.serialize().hex(), + e=e.serialize(), + s=s.serialize(), + ) + + # Act + rows = await ledger.crud.get_blinded_messages_melt_id(db=ledger.db, melt_id=melt_id) + + # Assert: only the unsigned one remains (b2_hex) + assert len(rows) == 1 + assert rows[0].B_ == b2_hex + assert rows[0].id == keyset_id + + +@pytest.mark.asyncio +async def test_store_blinded_message(ledger: Ledger): + from cashu.core.crypto.b_dhke import step1_alice + + amount = 8 + keyset_id = ledger.keyset.id + B_pub, _ = step1_alice("test_store_blinded_message") + b_hex = B_pub.serialize().hex() + + # Act: store unsigned blinded message + await ledger.crud.store_blinded_message( + db=ledger.db, amount=amount, b_=b_hex, id=keyset_id + ) + + # Assert: row exists and is unsigned (c_ IS NULL) + async with ledger.db.connect() as conn: + row = await conn.fetchone( + f"SELECT amount, id, b_, c_, created FROM {ledger.db.table_with_schema('promises')} WHERE b_ = :b_", + {"b_": b_hex}, + ) + assert row is not None + assert int(row["amount"]) == amount + assert row["id"] == keyset_id + assert row["b_"] == b_hex + assert row["c_"] is None + assert row["created"] is not None + + +@pytest.mark.asyncio +async def test_update_blinded_message_signature_before_store_blinded_message_errors( + ledger: Ledger, +): + from cashu.core.crypto.b_dhke import step1_alice, step2_bob + from cashu.core.crypto.secp import PublicKey + + amount = 8 + # Generate a blinded message that we will NOT store + B_pub, _ = step1_alice("test_sign_before_store_blinded_message") + b_hex = B_pub.serialize().hex() + + # Create a valid signature tuple for that blinded message + priv = ledger.keyset.private_keys[amount] + C_point, e, s = step2_bob(PublicKey(bytes.fromhex(b_hex), raw=True), priv) + + # Expect a DB-level error; on SQLite/Postgres this is typically a no-op update, so this test is xfail. + await assert_err( + ledger.crud.update_blinded_message_signature( + db=ledger.db, + amount=amount, + b_=b_hex, + c_=C_point.serialize().hex(), + e=e.serialize(), + s=s.serialize(), + ), + "blinded message does not exist", + ) + + +@pytest.mark.asyncio +async def test_store_blinded_message_duplicate_b_(ledger: Ledger): + from cashu.core.crypto.b_dhke import step1_alice + + amount = 2 + keyset_id = ledger.keyset.id + B_pub, _ = step1_alice("test_duplicate_b_") + b_hex = B_pub.serialize().hex() + + # First insert should succeed + await ledger.crud.store_blinded_message( + db=ledger.db, amount=amount, b_=b_hex, id=keyset_id + ) + + +@pytest.mark.asyncio +async def test_get_blind_signatures_by_melt_id_returns_signed( + wallet: Wallet, ledger: Ledger +): + from cashu.core.crypto.b_dhke import step1_alice, step2_bob + from cashu.core.crypto.secp import PublicKey + + amount = 4 + keyset_id = ledger.keyset.id + # Create a real melt quote to satisfy FK on promises.melt_quote + mint_quote = await wallet.request_mint(64) + melt_quote = await ledger.melt_quote( + PostMeltQuoteRequest(request=mint_quote.request, unit="sat") + ) + melt_id = melt_quote.quote + + # Prepare two blinded messages under the same melt_id + B1, _ = step1_alice("signed_promises_by_melt_id_1") + B2, _ = step1_alice("signed_promises_by_melt_id_2") + b1_hex = B1.serialize().hex() + b2_hex = B2.serialize().hex() + + await ledger.crud.store_blinded_message( + db=ledger.db, amount=amount, b_=b1_hex, id=keyset_id, melt_id=melt_id + ) + await ledger.crud.store_blinded_message( + db=ledger.db, amount=amount, b_=b2_hex, id=keyset_id, melt_id=melt_id + ) + + # Sign only one of them -> should be returned by get_blind_signatures_melt_id + priv = ledger.keyset.private_keys[amount] + C_point, e, s = step2_bob(PublicKey(bytes.fromhex(b1_hex), raw=True), priv) + await ledger.crud.update_blinded_message_signature( + db=ledger.db, + amount=amount, + b_=b1_hex, + c_=C_point.serialize().hex(), + e=e.serialize(), + s=s.serialize(), + ) + + # Act + signed = await ledger.crud.get_blind_signatures_melt_id( + db=ledger.db, melt_id=melt_id + ) + + # Assert: only the signed one is returned + assert len(signed) == 1 + assert signed[0].amount == amount + assert signed[0].id == keyset_id + + +@pytest.mark.asyncio +async def test_get_melt_quote_includes_change_signatures( + wallet: Wallet, ledger: Ledger +): + from cashu.core.crypto.b_dhke import step1_alice, step2_bob + from cashu.core.crypto.secp import PublicKey + + amount = 8 + keyset_id = ledger.keyset.id + + # Create melt quote and attach outputs/promises under its melt_id + mint_quote = await wallet.request_mint(64) + melt_quote = await ledger.melt_quote( + PostMeltQuoteRequest(request=mint_quote.request, unit="sat") + ) + + melt_id = melt_quote.quote + + # Create two blinded messages, sign one -> becomes change + B1, _ = step1_alice("melt_quote_change_1") + B2, _ = step1_alice("melt_quote_change_2") + b1_hex = B1.serialize().hex() + b2_hex = B2.serialize().hex() + + await ledger.crud.store_blinded_message( + db=ledger.db, amount=amount, b_=b1_hex, id=keyset_id, melt_id=melt_id + ) + await ledger.crud.store_blinded_message( + db=ledger.db, amount=amount, b_=b2_hex, id=keyset_id, melt_id=melt_id + ) + + # Sign one -> should appear in change loaded by get_melt_quote + priv = ledger.keyset.private_keys[amount] + C_point, e, s = step2_bob(PublicKey(bytes.fromhex(b1_hex), raw=True), priv) + await ledger.crud.update_blinded_message_signature( + db=ledger.db, + amount=amount, + b_=b1_hex, + c_=C_point.serialize().hex(), + e=e.serialize(), + s=s.serialize(), + ) + + # Act + quote_db = await ledger.crud.get_melt_quote(quote_id=melt_id, db=ledger.db) + + # Assert: change contains the signed promise(s) + assert quote_db is not None + assert quote_db.quote == melt_id + assert quote_db.change is not None + assert len(quote_db.change) == 1 + assert quote_db.change[0].amount == amount + assert quote_db.change[0].id == keyset_id + + +@pytest.mark.asyncio +async def test_promises_fk_constraints_enforced(ledger: Ledger): + from cashu.core.crypto.b_dhke import step1_alice + + keyset_id = ledger.keyset.id + B1, _ = step1_alice("fk_check_melt") + B2, _ = step1_alice("fk_check_mint") + b1_hex = B1.serialize().hex() + b2_hex = B2.serialize().hex() + + # Use a single connection and enable FK enforcement on SQLite + async with ledger.db.connect() as conn: + # Fake melt_id should violate FK on promises.melt_quote + await assert_err_multiple( + ledger.crud.store_blinded_message( + db=ledger.db, + amount=1, + b_=b1_hex, + id=keyset_id, + melt_id="nonexistent-melt-id", + conn=conn, + ), + [ + "FOREIGN KEY", # SQLite + "violates foreign key constraint", # Postgres + ], + ) + + async with ledger.db.connect() as conn: + # Fake mint_id should violate FK on promises.mint_quote + await assert_err_multiple( + ledger.crud.store_blinded_message( + db=ledger.db, + amount=1, + b_=b2_hex, + id=keyset_id, + mint_id="nonexistent-mint-id", + conn=conn, + ), + [ + "FOREIGN KEY", # SQLite + "violates foreign key constraint", # Postgres + ], + ) + + # Done. This test only checks FK enforcement paths. diff --git a/tests/wallet/test_wallet_subscription.py b/tests/wallet/test_wallet_subscription.py index 286ce55..4d6a776 100644 --- a/tests/wallet/test_wallet_subscription.py +++ b/tests/wallet/test_wallet_subscription.py @@ -51,13 +51,13 @@ async def test_wallet_subscription_mint(wallet: Wallet): await asyncio.sleep(wait + 2) assert triggered - assert len(msg_stack) == 3 + assert len(msg_stack) >= 3 assert msg_stack[0].payload["state"] == MintQuoteState.unpaid.value assert msg_stack[1].payload["state"] == MintQuoteState.paid.value - assert msg_stack[2].payload["state"] == MintQuoteState.issued.value + assert msg_stack[-1].payload["state"] == MintQuoteState.issued.value @pytest.mark.asyncio @@ -133,7 +133,9 @@ async def test_wallet_subscription_multiple_listeners_receive_updates(wallet: Wa from cashu.wallet.subscriptions import SubscriptionManager subs = SubscriptionManager(wallet.url) - threading.Thread(target=subs.connect, name="SubscriptionManager", daemon=True).start() + threading.Thread( + target=subs.connect, name="SubscriptionManager", daemon=True + ).start() stack1: list[JSONRPCNotficationParams] = [] stack2: list[JSONRPCNotficationParams] = []