Fix duplicate blank outputs during melt (#795)

* wip blank outputs

* wip: working

* store ids for promises correctly

* tests

* fix migraiton

* revert

* fix tests

* fix auth server

* fix last tests

* retroactively change migration, initial and m017_foreign_keys_proof_tables, remove c_b and replace with c_ (same for b_)

* fix constraint

* oops

* msg stack fix

* fix test foreign key constraint

* fix postgres tests

* foreign key constraint test

* should fix psql error

* foreign key constraint sqlite

* rename to update_blinded_message_signature

* drop outputs and change columns from melt_quotes table

* switch migration order

* reorder migrations again

* fix migration

* add tests

* fix postgres migration too

* create signed_at column postgres

* foreign key constraingt promises table

* migration tool

* readme
This commit is contained in:
callebtc
2025-10-19 15:50:47 +02:00
committed by GitHub
parent a5f950a8f8
commit 9fed0f0f07
14 changed files with 1588 additions and 335 deletions

View File

@@ -199,6 +199,23 @@ MINT_REDIS_CACHE_URL=redis://localhost:6379
### NUT-21 Authentication with Keycloak ### NUT-21 Authentication with Keycloak
Cashu supports clear and blind authentication as defined in [NUT-21](https://github.com/cashubtc/nuts/blob/main/21.md) and [NUT-22](https://github.com/cashubtc/nuts/blob/main/22.md) to limit the use of a mint to a registered set of users. Clear authentication is supported via a OICD provider such as Keycloak. You can set up and run Keycloak instance using the docker compose file `docker/keycloak/docker-compose.yml` in this repository. Cashu supports clear and blind authentication as defined in [NUT-21](https://github.com/cashubtc/nuts/blob/main/21.md) and [NUT-22](https://github.com/cashubtc/nuts/blob/main/22.md) to limit the use of a mint to a registered set of users. Clear authentication is supported via a OICD provider such as Keycloak. You can set up and run Keycloak instance using the docker compose file `docker/keycloak/docker-compose.yml` in this repository.
### Migrate SQLite mint DB to Postgres
Use the standalone tool at `cashu/mint/sqlite_to_postgres.py` to migrate a mint database from SQLite to Postgres.
```bash
# 1) optionally reset the target Postgres database (DROPS ALL DATA)
psql -U <user> -h <host> -p <port> -d <database> -c "DROP SCHEMA public CASCADE; CREATE SCHEMA public; GRANT ALL PRIVILEGES ON SCHEMA public TO <user>;"
# 2) run migration (inside poetry env)
poetry run python cashu/mint/sqlite_to_postgres.py \
--sqlite data/mint/mint.sqlite3 \
--postgres postgres://<user>:<pass>@<host>:<port>/<database> \
--batch-size 2000
```
- The tool aborts if the Postgres DB appears populated and prints the exact reset command with your connection details.
- After copying, it verifies row counts and compares the `balance` view across both databases.
# Running tests # Running tests
To run the tests in this repository, first install the dev dependencies with To run the tests in this repository, first install the dev dependencies with
```bash ```bash

View File

@@ -231,6 +231,10 @@ class BlindedMessage(BaseModel):
id: str # Keyset id id: str # Keyset id
B_: str # Hex-encoded blinded message B_: str # Hex-encoded blinded message
@classmethod
def from_row(cls, row: RowMapping):
return cls(amount=row["amount"], B_=row["b_"], id=row["id"])
class BlindedMessage_Deprecated(BaseModel): class BlindedMessage_Deprecated(BaseModel):
""" """
@@ -300,7 +304,7 @@ class MeltQuote(LedgerEvent):
mint: Optional[str] = None mint: Optional[str] = None
@classmethod @classmethod
def from_row(cls, row: Row): def from_row(cls, row: Row, change: Optional[List[BlindedSignature]] = None):
try: try:
created_time = int(row["created_time"]) if row["created_time"] else None created_time = int(row["created_time"]) if row["created_time"] else None
paid_time = int(row["paid_time"]) if row["paid_time"] else None paid_time = int(row["paid_time"]) if row["paid_time"] else None
@@ -314,11 +318,6 @@ class MeltQuote(LedgerEvent):
payment_preimage = row.get("payment_preimage") or row.get("proof") # type: ignore payment_preimage = row.get("payment_preimage") or row.get("proof") # type: ignore
# parse change from row as json
change = None
if "change" in row.keys() and row["change"]:
change = json.loads(row["change"])
outputs = None outputs = None
if "outputs" in row.keys() and row["outputs"]: if "outputs" in row.keys() and row["outputs"]:
outputs = json.loads(row["outputs"]) outputs = json.loads(row["outputs"])

View File

@@ -7,7 +7,7 @@ from contextlib import asynccontextmanager
from typing import Optional, Union from typing import Optional, Union
from loguru import logger from loguru import logger
from sqlalchemy import text from sqlalchemy import event, text
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import AsyncAdaptedQueuePool, NullPool from sqlalchemy.pool import AsyncAdaptedQueuePool, NullPool
@@ -142,6 +142,19 @@ class Database(Compat):
kwargs["max_overflow"] = 100 # type: ignore[assignment] kwargs["max_overflow"] = 100 # type: ignore[assignment]
self.engine = create_async_engine(database_uri, **kwargs) self.engine = create_async_engine(database_uri, **kwargs)
# Ensure SQLite enforces foreign keys on every connection
if self.type == SQLITE:
@event.listens_for(self.engine.sync_engine, "connect")
def _set_sqlite_pragma(dbapi_connection, connection_record):
try:
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON;")
cursor.close()
except Exception as e:
logger.warning(f"Could not enable SQLite PRAGMA foreign_keys: {e}")
self.async_session = sessionmaker( self.async_session = sessionmaker(
self.engine, # type: ignore self.engine, # type: ignore
expire_on_commit=False, expire_on_commit=False,

View File

@@ -201,7 +201,8 @@ class AuthLedger(Ledger):
) )
await self._verify_outputs(outputs) await self._verify_outputs(outputs)
promises = await self._generate_promises(outputs) await self._store_blinded_messages(outputs)
promises = await self._sign_blinded_messages(outputs)
# update last_access timestamp of the user # update last_access timestamp of the user
await self.auth_crud.update_user(user_id=user.id, db=self.db) await self.auth_crud.update_user(user_id=user.id, db=self.db)

View File

@@ -6,6 +6,7 @@ from loguru import logger
from ..core.base import ( from ..core.base import (
Amount, Amount,
BlindedMessage,
BlindedSignature, BlindedSignature,
MeltQuote, MeltQuote,
MintBalanceLogEntry, MintBalanceLogEntry,
@@ -151,19 +152,59 @@ class LedgerCrud(ABC):
) -> Tuple[Amount, Amount]: ... ) -> Tuple[Amount, Amount]: ...
@abstractmethod @abstractmethod
async def store_promise( async def store_blinded_message(
self,
*,
db: Database,
amount: int,
b_: str,
id: str,
mint_id: Optional[str] = None,
melt_id: Optional[str] = None,
swap_id: Optional[str] = None,
conn: Optional[Connection] = None,
) -> None: ...
@abstractmethod
async def delete_blinded_messages_melt_id(
self,
*,
db: Database,
melt_id: str,
conn: Optional[Connection] = None,
) -> None: ...
@abstractmethod
async def update_blinded_message_signature(
self, self,
*, *,
db: Database, db: Database,
amount: int, amount: int,
b_: str, b_: str,
c_: str, c_: str,
id: str,
e: str = "", e: str = "",
s: str = "", s: str = "",
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> None: ... ) -> None: ...
@abstractmethod
async def get_blinded_messages_melt_id(
self,
*,
db: Database,
melt_id: str,
conn: Optional[Connection] = None,
) -> List[BlindedMessage]: ...
@abstractmethod
async def get_blind_signatures_melt_id(
self,
*,
db: Database,
melt_id: str,
conn: Optional[Connection] = None,
) -> List[BlindedSignature]: ...
@abstractmethod @abstractmethod
async def get_promise( async def get_promise(
self, self,
@@ -285,32 +326,119 @@ class LedgerCrudSqlite(LedgerCrud):
LedgerCrud (ABC): Abstract base class for LedgerCrud. LedgerCrud (ABC): Abstract base class for LedgerCrud.
""" """
async def store_promise( async def store_blinded_message(
self,
*,
db: Database,
amount: int,
b_: str,
id: str,
mint_id: Optional[str] = None,
melt_id: Optional[str] = None,
swap_id: Optional[str] = None,
conn: Optional[Connection] = None,
) -> None:
await (conn or db).execute(
f"""
INSERT INTO {db.table_with_schema('promises')}
(amount, b_, id, created, mint_quote, melt_quote, swap_id)
VALUES (:amount, :b_, :id, :created, :mint_quote, :melt_quote, :swap_id)
""",
{
"amount": amount,
"b_": b_,
"id": id,
"created": db.to_timestamp(db.timestamp_now_str()),
"mint_quote": mint_id,
"melt_quote": melt_id,
"swap_id": swap_id,
},
)
async def get_blinded_messages_melt_id(
self,
*,
db: Database,
melt_id: str,
conn: Optional[Connection] = None,
) -> List[BlindedMessage]:
rows = await (conn or db).fetchall(
f"""
SELECT * from {db.table_with_schema('promises')}
WHERE melt_quote = :melt_id AND c_ IS NULL
""",
{"melt_id": melt_id},
)
return [BlindedMessage.from_row(r) for r in rows] if rows else []
async def get_blind_signatures_melt_id(
self,
*,
db: Database,
melt_id: str,
conn: Optional[Connection] = None,
) -> List[BlindedSignature]:
rows = await (conn or db).fetchall(
f"""
SELECT * from {db.table_with_schema('promises')}
WHERE melt_quote = :melt_id AND c_ IS NOT NULL
""",
{"melt_id": melt_id},
)
return [BlindedSignature.from_row(r) for r in rows] if rows else [] # type: ignore
async def delete_blinded_messages_melt_id(
self,
*,
db: Database,
melt_id: str,
conn: Optional[Connection] = None,
) -> None:
"""Deletes a blinded message (promise) that has not been signed yet (c_ is NULL) with the given quote ID."""
await (conn or db).execute(
f"""
DELETE FROM {db.table_with_schema('promises')}
WHERE melt_quote = :melt_id AND c_ IS NULL
""",
{
"melt_id": melt_id,
},
)
async def update_blinded_message_signature(
self, self,
*, *,
db: Database, db: Database,
amount: int, amount: int,
b_: str, b_: str,
c_: str, c_: str,
id: str,
e: str = "", e: str = "",
s: str = "", s: str = "",
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> None: ) -> None:
existing = await (conn or db).fetchone(
f"""
SELECT * from {db.table_with_schema('promises')}
WHERE b_ = :b_
""",
{"b_": str(b_)},
)
if existing is None:
raise ValueError("blinded message does not exist")
await (conn or db).execute( await (conn or db).execute(
f""" f"""
INSERT INTO {db.table_with_schema('promises')} UPDATE {db.table_with_schema('promises')}
(amount, b_, c_, dleq_e, dleq_s, id, created) SET amount = :amount, c_ = :c_, dleq_e = :dleq_e, dleq_s = :dleq_s, signed_at = :signed_at
VALUES (:amount, :b_, :c_, :dleq_e, :dleq_s, :id, :created) WHERE b_ = :b_;
""", """,
{ {
"amount": amount,
"b_": b_, "b_": b_,
"amount": amount,
"c_": c_, "c_": c_,
"dleq_e": e, "dleq_e": e,
"dleq_s": s, "dleq_s": s,
"id": id, "signed_at": db.to_timestamp(db.timestamp_now_str()),
"created": db.to_timestamp(db.timestamp_now_str()),
}, },
) )
@@ -324,7 +452,7 @@ class LedgerCrudSqlite(LedgerCrud):
row = await (conn or db).fetchone( row = await (conn or db).fetchone(
f""" f"""
SELECT * from {db.table_with_schema('promises')} SELECT * from {db.table_with_schema('promises')}
WHERE b_ = :b_ WHERE b_ = :b_ AND c_ IS NOT NULL
""", """,
{"b_": str(b_)}, {"b_": str(b_)},
) )
@@ -340,7 +468,7 @@ class LedgerCrudSqlite(LedgerCrud):
rows = await (conn or db).fetchall( rows = await (conn or db).fetchall(
f""" f"""
SELECT * from {db.table_with_schema('promises')} SELECT * from {db.table_with_schema('promises')}
WHERE b_ IN ({','.join([f":b_{i}" for i in range(len(b_s))])}) WHERE b_ IN ({','.join([f":b_{i}" for i in range(len(b_s))])}) AND c_ IS NOT NULL
""", """,
{f"b_{i}": b_s[i] for i in range(len(b_s))}, {f"b_{i}": b_s[i] for i in range(len(b_s))},
) )
@@ -568,8 +696,8 @@ class LedgerCrudSqlite(LedgerCrud):
await (conn or db).execute( await (conn or db).execute(
f""" f"""
INSERT INTO {db.table_with_schema('melt_quotes')} INSERT INTO {db.table_with_schema('melt_quotes')}
(quote, method, request, checking_id, unit, amount, fee_reserve, state, paid, created_time, paid_time, fee_paid, proof, outputs, change, expiry) (quote, method, request, checking_id, unit, amount, fee_reserve, state, paid, created_time, paid_time, fee_paid, proof, expiry)
VALUES (:quote, :method, :request, :checking_id, :unit, :amount, :fee_reserve, :state, :paid, :created_time, :paid_time, :fee_paid, :proof, :outputs, :change, :expiry) VALUES (:quote, :method, :request, :checking_id, :unit, :amount, :fee_reserve, :state, :paid, :created_time, :paid_time, :fee_paid, :proof, :expiry)
""", """,
{ {
"quote": quote.quote, "quote": quote.quote,
@@ -589,8 +717,6 @@ class LedgerCrudSqlite(LedgerCrud):
), ),
"fee_paid": quote.fee_paid, "fee_paid": quote.fee_paid,
"proof": quote.payment_preimage, "proof": quote.payment_preimage,
"outputs": json.dumps(quote.outputs) if quote.outputs else None,
"change": json.dumps(quote.change) if quote.change else None,
"expiry": db.to_timestamp( "expiry": db.to_timestamp(
db.timestamp_from_seconds(quote.expiry) or "" db.timestamp_from_seconds(quote.expiry) or ""
), ),
@@ -627,7 +753,14 @@ class LedgerCrudSqlite(LedgerCrud):
""", """,
values, values,
) )
return MeltQuote.from_row(row) if row else None # type: ignore
change = None
if row:
change = await self.get_blind_signatures_melt_id(
db=db, melt_id=row["quote"], conn=conn
)
return MeltQuote.from_row(row, change) if row else None # type: ignore
async def get_melt_quote_by_request( async def get_melt_quote_by_request(
self, self,
@@ -654,7 +787,7 @@ class LedgerCrudSqlite(LedgerCrud):
) -> None: ) -> None:
await (conn or db).execute( await (conn or db).execute(
f""" f"""
UPDATE {db.table_with_schema('melt_quotes')} SET state = :state, fee_paid = :fee_paid, paid_time = :paid_time, proof = :proof, outputs = :outputs, change = :change, checking_id = :checking_id WHERE quote = :quote UPDATE {db.table_with_schema('melt_quotes')} SET state = :state, fee_paid = :fee_paid, paid_time = :paid_time, proof = :proof, checking_id = :checking_id WHERE quote = :quote
""", """,
{ {
"state": quote.state.value, "state": quote.state.value,
@@ -663,16 +796,6 @@ class LedgerCrudSqlite(LedgerCrud):
db.timestamp_from_seconds(quote.paid_time) or "" db.timestamp_from_seconds(quote.paid_time) or ""
), ),
"proof": quote.payment_preimage, "proof": quote.payment_preimage,
"outputs": (
json.dumps([s.dict() for s in quote.outputs])
if quote.outputs
else None
),
"change": (
json.dumps([s.dict() for s in quote.change])
if quote.change
else None
),
"quote": quote.quote, "quote": quote.quote,
"checking_id": quote.checking_id, "checking_id": quote.checking_id,
}, },

View File

@@ -3,7 +3,6 @@ from typing import Dict, List, Optional, Union
from loguru import logger from loguru import logger
from ...core.base import ( from ...core.base import (
BlindedMessage,
MeltQuote, MeltQuote,
MeltQuoteState, MeltQuoteState,
MintKeyset, MintKeyset,
@@ -198,9 +197,7 @@ class DbWriteHelper:
await self.events.submit(quote) await self.events.submit(quote)
return quote return quote
async def _set_melt_quote_pending( async def _set_melt_quote_pending(self, quote: MeltQuote) -> MeltQuote:
self, quote: MeltQuote, outputs: Optional[List[BlindedMessage]] = None
) -> MeltQuote:
"""Sets the melt quote as pending. """Sets the melt quote as pending.
Args: Args:
@@ -221,8 +218,6 @@ class DbWriteHelper:
raise TransactionError("Melt quote already pending.") raise TransactionError("Melt quote already pending.")
# set the quote as pending # set the quote as pending
quote_copy.state = MeltQuoteState.pending quote_copy.state = MeltQuoteState.pending
if outputs:
quote_copy.outputs = outputs
await self.crud.update_melt_quote(quote=quote_copy, db=self.db, conn=conn) await self.crud.update_melt_quote(quote=quote_copy, db=self.db, conn=conn)
await self.events.submit(quote_copy) await self.events.submit(quote_copy)
@@ -257,21 +252,25 @@ class DbWriteHelper:
await self.events.submit(quote_copy) await self.events.submit(quote_copy)
return quote_copy return quote_copy
async def _update_mint_quote_state( async def _update_mint_quote_state(self, quote_id: str, state: MintQuoteState):
self, quote_id: str, state: MintQuoteState
):
async with self.db.get_connection(lock_table="mint_quotes") as conn: async with self.db.get_connection(lock_table="mint_quotes") as conn:
mint_quote = await self.crud.get_mint_quote(quote_id=quote_id, db=self.db, conn=conn) mint_quote = await self.crud.get_mint_quote(
quote_id=quote_id, db=self.db, conn=conn
)
if not mint_quote: if not mint_quote:
raise TransactionError("Mint quote not found.") raise TransactionError("Mint quote not found.")
mint_quote.state = state mint_quote.state = state
await self.crud.update_mint_quote(quote=mint_quote, db=self.db, conn=conn) await self.crud.update_mint_quote(quote=mint_quote, db=self.db, conn=conn)
async def _update_melt_quote_state( async def _update_melt_quote_state(
self, quote_id: str, state: MeltQuoteState, self,
quote_id: str,
state: MeltQuoteState,
): ):
async with self.db.get_connection(lock_table="melt_quotes") as conn: async with self.db.get_connection(lock_table="melt_quotes") as conn:
melt_quote = await self.crud.get_melt_quote(quote_id=quote_id, db=self.db, conn=conn) melt_quote = await self.crud.get_melt_quote(
quote_id=quote_id, db=self.db, conn=conn
)
if not melt_quote: if not melt_quote:
raise TransactionError("Melt quote not found.") raise TransactionError("Melt quote not found.")
melt_quote.state = state melt_quote.state = state

View File

@@ -250,6 +250,7 @@ class Ledger(
fee_provided: int, fee_provided: int,
fee_paid: int, fee_paid: int,
outputs: Optional[List[BlindedMessage]], outputs: Optional[List[BlindedMessage]],
melt_id: Optional[str] = None,
keyset: Optional[MintKeyset] = None, keyset: Optional[MintKeyset] = None,
) -> List[BlindedSignature]: ) -> List[BlindedSignature]:
"""Generates a set of new promises (blinded signatures) from a set of blank outputs """Generates a set of new promises (blinded signatures) from a set of blank outputs
@@ -305,7 +306,10 @@ class Ledger(
outputs[i].amount = return_amounts_sorted[i] # type: ignore outputs[i].amount = return_amounts_sorted[i] # type: ignore
if not self._verify_no_duplicate_outputs(outputs): if not self._verify_no_duplicate_outputs(outputs):
raise TransactionError("duplicate promises.") raise TransactionError("duplicate promises.")
return_promises = await self._generate_promises(outputs, keyset) return_promises = await self._sign_blinded_messages(outputs)
# delete remaining unsigned blank outputs from db
if melt_id:
await self.crud.delete_blinded_messages_melt_id(melt_id=melt_id, db=self.db)
return return_promises return return_promises
# ------- TRANSACTIONS ------- # ------- TRANSACTIONS -------
@@ -491,8 +495,8 @@ class Ledger(
raise TransactionError("quote expired") raise TransactionError("quote expired")
if not self._verify_mint_quote_witness(quote, outputs, signature): if not self._verify_mint_quote_witness(quote, outputs, signature):
raise QuoteSignatureInvalidError() raise QuoteSignatureInvalidError()
await self._store_blinded_messages(outputs, mint_id=quote_id)
promises = await self._generate_promises(outputs) promises = await self._sign_blinded_messages(outputs)
except Exception as e: except Exception as e:
await self.db_write._unset_mint_quote_pending( await self.db_write._unset_mint_quote_pending(
quote_id=quote_id, state=previous_state quote_id=quote_id, state=previous_state
@@ -726,7 +730,10 @@ class Ledger(
pending_proofs, keysets=self.keysets, conn=conn pending_proofs, keysets=self.keysets, conn=conn
) )
# change to compensate wallet for overpaid fees # change to compensate wallet for overpaid fees
if melt_quote.outputs: melt_outputs = await self.crud.get_blinded_messages_melt_id(
melt_id=quote_id, db=self.db
)
if melt_outputs:
total_provided = sum_proofs(pending_proofs) total_provided = sum_proofs(pending_proofs)
input_fees = self.get_fees_for_proofs(pending_proofs) input_fees = self.get_fees_for_proofs(pending_proofs)
fee_reserve_provided = ( fee_reserve_provided = (
@@ -735,8 +742,9 @@ class Ledger(
return_promises = await self._generate_change_promises( return_promises = await self._generate_change_promises(
fee_provided=fee_reserve_provided, fee_provided=fee_reserve_provided,
fee_paid=melt_quote.fee_paid, fee_paid=melt_quote.fee_paid,
outputs=melt_quote.outputs, outputs=melt_outputs,
keyset=self.keysets[melt_quote.outputs[0].id], melt_id=quote_id,
keyset=self.keysets[melt_outputs[0].id],
) )
melt_quote.change = return_promises melt_quote.change = return_promises
await self.crud.update_melt_quote(quote=melt_quote, db=self.db) await self.crud.update_melt_quote(quote=melt_quote, db=self.db)
@@ -752,6 +760,9 @@ class Ledger(
await self.db_write._unset_proofs_pending( await self.db_write._unset_proofs_pending(
pending_proofs, keysets=self.keysets pending_proofs, keysets=self.keysets
) )
await self.crud.delete_blinded_messages_melt_id(
melt_id=quote_id, db=self.db
)
return melt_quote return melt_quote
@@ -873,8 +884,6 @@ class Ledger(
raise TransactionError( raise TransactionError(
f"output unit {outputs_unit.name} does not match quote unit {melt_quote.unit}" f"output unit {outputs_unit.name} does not match quote unit {melt_quote.unit}"
) )
# we don't need to set it here, _set_melt_quote_pending will set it in the db
melt_quote.outputs = outputs
# verify SIG_ALL signatures # verify SIG_ALL signatures
message_to_sign = ( message_to_sign = (
@@ -907,7 +916,9 @@ class Ledger(
proofs, keysets=self.keysets, quote_id=melt_quote.quote proofs, keysets=self.keysets, quote_id=melt_quote.quote
) )
previous_state = melt_quote.state previous_state = melt_quote.state
melt_quote = await self.db_write._set_melt_quote_pending(melt_quote, outputs) melt_quote = await self.db_write._set_melt_quote_pending(melt_quote)
if outputs:
await self._store_blinded_messages(outputs, melt_id=melt_quote.quote)
# if the melt corresponds to an internal mint, mark both as paid # if the melt corresponds to an internal mint, mark both as paid
melt_quote = await self.melt_mint_settle_internally(melt_quote, proofs) melt_quote = await self.melt_mint_settle_internally(melt_quote, proofs)
@@ -966,6 +977,9 @@ class Ledger(
await self.db_write._unset_melt_quote_pending( await self.db_write._unset_melt_quote_pending(
quote=melt_quote, state=previous_state quote=melt_quote, state=previous_state
) )
await self.crud.delete_blinded_messages_melt_id(
melt_id=melt_quote.quote, db=self.db
)
if status.error_message: if status.error_message:
logger.error( logger.error(
f"Status check error: {status.error_message}" f"Status check error: {status.error_message}"
@@ -1011,6 +1025,7 @@ class Ledger(
fee_provided=fee_reserve_provided, fee_provided=fee_reserve_provided,
fee_paid=melt_quote.fee_paid, fee_paid=melt_quote.fee_paid,
outputs=outputs, outputs=outputs,
melt_id=melt_quote.quote,
keyset=self.keysets[outputs[0].id], keyset=self.keysets[outputs[0].id],
) )
@@ -1050,8 +1065,9 @@ class Ledger(
) )
try: try:
async with self.db.get_connection(lock_table="proofs_pending") as conn: async with self.db.get_connection(lock_table="proofs_pending") as conn:
await self._store_blinded_messages(outputs, keyset=keyset, conn=conn)
await self._invalidate_proofs(proofs=proofs, conn=conn) await self._invalidate_proofs(proofs=proofs, conn=conn)
promises = await self._generate_promises(outputs, keyset, conn) promises = await self._sign_blinded_messages(outputs, conn)
except Exception as e: except Exception as e:
logger.trace(f"swap failed: {e}") logger.trace(f"swap failed: {e}")
raise e raise e
@@ -1081,10 +1097,47 @@ class Ledger(
# ------- BLIND SIGNATURES ------- # ------- BLIND SIGNATURES -------
async def _generate_promises( async def _store_blinded_messages(
self, self,
outputs: List[BlindedMessage], outputs: List[BlindedMessage],
keyset: Optional[MintKeyset] = None, keyset: Optional[MintKeyset] = None,
mint_id: Optional[str] = None,
melt_id: Optional[str] = None,
swap_id: Optional[str] = None,
conn: Optional[Connection] = None,
) -> None:
"""Stores a blinded message in the database.
Args:
outputs (List[BlindedMessage]): Blinded messages to store.
keyset (Optional[MintKeyset], optional): Keyset to use. Uses default keyset if not given. Defaults to None.
conn: (Optional[Connection], optional): Database connection to reuse. Will create a new one if not given. Defaults to None.
"""
async with self.db.get_connection(conn) as conn:
for output in outputs:
keyset = keyset or self.keysets[output.id]
if output.id not in self.keysets:
raise TransactionError(f"keyset {output.id} not found")
if output.id != keyset.id:
raise TransactionError("keyset id does not match output id")
if not keyset.active:
raise TransactionError("keyset is not active")
logger.trace(f"Storing blinded message with keyset {keyset.id}.")
await self.crud.store_blinded_message(
id=keyset.id,
amount=output.amount,
b_=output.B_,
mint_id=mint_id,
melt_id=melt_id,
swap_id=swap_id,
db=self.db,
conn=conn,
)
logger.trace(f"Stored blinded message for {output.amount}")
async def _sign_blinded_messages(
self,
outputs: List[BlindedMessage],
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> list[BlindedSignature]: ) -> list[BlindedSignature]:
"""Generates a promises (Blind signatures) for given amount and returns a pair (amount, C'). """Generates a promises (Blind signatures) for given amount and returns a pair (amount, C').
@@ -1107,9 +1160,9 @@ class Ledger(
] = [] ] = []
for output in outputs: for output in outputs:
B_ = PublicKey(bytes.fromhex(output.B_), raw=True) B_ = PublicKey(bytes.fromhex(output.B_), raw=True)
keyset = keyset or self.keysets[output.id]
if output.id not in self.keysets: if output.id not in self.keysets:
raise TransactionError(f"keyset {output.id} not found") raise TransactionError(f"keyset {output.id} not found")
keyset = self.keysets[output.id]
if output.id != keyset.id: if output.id != keyset.id:
raise TransactionError("keyset id does not match output id") raise TransactionError("keyset id does not match output id")
if not keyset.active: if not keyset.active:
@@ -1127,9 +1180,8 @@ class Ledger(
for promise in promises: for promise in promises:
keyset_id, B_, amount, C_, e, s = promise keyset_id, B_, amount, C_, e, s = promise
logger.trace(f"crud: _generate_promise storing promise for {amount}") logger.trace(f"crud: _generate_promise storing promise for {amount}")
await self.crud.store_promise( await self.crud.update_blinded_message_signature(
amount=amount, amount=amount,
id=keyset_id,
b_=B_.serialize().hex(), b_=B_.serialize().hex(),
c_=C_.serialize().hex(), c_=C_.serialize().hex(),
e=e.serialize(), e=e.serialize(),

View File

@@ -1,4 +1,5 @@
import copy import copy
import json
from typing import List from typing import List
from sqlalchemy import RowMapping from sqlalchemy import RowMapping
@@ -26,10 +27,10 @@ async def m001_initial(db: Database):
f""" f"""
CREATE TABLE IF NOT EXISTS {db.table_with_schema('promises')} ( CREATE TABLE IF NOT EXISTS {db.table_with_schema('promises')} (
amount {db.big_int} NOT NULL, amount {db.big_int} NOT NULL,
b_b TEXT NOT NULL, b_ TEXT NOT NULL,
c_b TEXT NOT NULL, c_ TEXT NOT NULL,
UNIQUE (b_b) UNIQUE (b_)
); );
""" """
@@ -52,11 +53,11 @@ async def m001_initial(db: Database):
f""" f"""
CREATE TABLE IF NOT EXISTS {db.table_with_schema('invoices')} ( CREATE TABLE IF NOT EXISTS {db.table_with_schema('invoices')} (
amount {db.big_int} NOT NULL, amount {db.big_int} NOT NULL,
pr TEXT NOT NULL, bolt11 TEXT NOT NULL,
hash TEXT NOT NULL, id TEXT NOT NULL,
issued BOOL NOT NULL, issued BOOL NOT NULL,
UNIQUE (hash) UNIQUE (id)
); );
""" """
@@ -78,7 +79,7 @@ async def create_balance_views(db: Database, conn: Connection):
SELECT id AS keyset, COALESCE(s, 0) AS balance FROM ( SELECT id AS keyset, COALESCE(s, 0) AS balance FROM (
SELECT id, SUM(amount) AS s SELECT id, SUM(amount) AS s
FROM {db.table_with_schema('promises')} FROM {db.table_with_schema('promises')}
WHERE amount > 0 WHERE amount > 0 AND c_ IS NOT NULL
GROUP BY id GROUP BY id
) AS balance_issued; ) AS balance_issued;
""" """
@@ -191,7 +192,7 @@ async def m006_invoices_add_payment_hash(db: Database):
" TEXT" " TEXT"
) )
await conn.execute( await conn.execute(
f"UPDATE {db.table_with_schema('invoices')} SET payment_hash = hash" f"UPDATE {db.table_with_schema('invoices')} SET payment_hash = id"
) )
@@ -230,16 +231,6 @@ async def m008_promises_dleq(db: Database):
async def m009_add_out_to_invoices(db: Database): async def m009_add_out_to_invoices(db: Database):
# column in invoices for marking whether the invoice is incoming (out=False) or outgoing (out=True) # column in invoices for marking whether the invoice is incoming (out=False) or outgoing (out=True)
async with db.connect() as conn: async with db.connect() as conn:
# rename column pr to bolt11
await conn.execute(
f"ALTER TABLE {db.table_with_schema('invoices')} RENAME COLUMN pr TO"
" bolt11"
)
# rename column hash to payment_hash
await conn.execute(
f"ALTER TABLE {db.table_with_schema('invoices')} RENAME COLUMN hash TO id"
)
await conn.execute( await conn.execute(
f"ALTER TABLE {db.table_with_schema('invoices')} ADD COLUMN out BOOL" f"ALTER TABLE {db.table_with_schema('invoices')} ADD COLUMN out BOOL"
) )
@@ -720,7 +711,7 @@ async def m017_foreign_keys_proof_tables(db: Database):
) )
await conn.execute( await conn.execute(
f"INSERT INTO {db.table_with_schema('promises_new')} (amount, id, b_, c_, dleq_e, dleq_s, created) SELECT amount, id, b_b, c_b, e, s, created FROM {db.table_with_schema('promises')}" f"INSERT INTO {db.table_with_schema('promises_new')} (amount, id, b_, c_, dleq_e, dleq_s, created) SELECT amount, id, b_, c_, e, s, created FROM {db.table_with_schema('promises')}"
) )
await conn.execute(f"DROP TABLE {db.table_with_schema('promises')}") await conn.execute(f"DROP TABLE {db.table_with_schema('promises')}")
await conn.execute( await conn.execute(
@@ -968,3 +959,170 @@ async def m027_add_balance_to_keysets_and_log_table(db: Database):
); );
""" """
) )
async def m028_promises_c_allow_null_add_melt_quote(db: Database):
"""
Allow column that stores the c_ to be NULL and add melt_quote to promises.
Insert all change promises from melt_quotes into the promises table.
Drop the change and the outputs columns from melt_quotes.
"""
# migrate stored melt outputs for pending quotes into promises
async def migrate_stored_melt_outputs_for_pending_quotes(
db: Database, conn: Connection
):
rows = await conn.fetchall(
f"""
SELECT quote, outputs FROM {db.table_with_schema('melt_quotes')}
WHERE state = :state AND outputs IS NOT NULL
""",
{"state": MeltQuoteState.pending.value},
)
for row in rows:
try:
outputs = json.loads(row["outputs"]) if row["outputs"] else []
except Exception:
outputs = []
for o in outputs:
amount = o.get("amount") if isinstance(o, dict) else None
keyset_id = o.get("id") if isinstance(o, dict) else None
b_hex = o.get("B_") if isinstance(o, dict) else None
if amount is None or keyset_id is None or b_hex is None:
continue
await conn.execute(
f"""
INSERT INTO {db.table_with_schema('promises')}
(amount, id, b_, created, mint_quote, melt_quote, swap_id)
VALUES (:amount, :id, :b_, :created, :mint_quote, :melt_quote, :swap_id)
""",
{
"amount": int(amount),
"id": keyset_id,
"b_": b_hex,
"created": db.to_timestamp(db.timestamp_now_str()),
"mint_quote": None,
"melt_quote": row["quote"],
"swap_id": None,
},
)
# remove obsolete columns outputs and change from melt_quotes
async def remove_obsolete_columns_from_melt_quotes(db: Database, conn: Connection):
if conn.type == "SQLITE":
# For SQLite, recreate table without the columns
await conn.execute("PRAGMA foreign_keys=OFF;")
await conn.execute(
f"""
CREATE TABLE IF NOT EXISTS {db.table_with_schema('melt_quotes_new')} (
quote TEXT NOT NULL,
method TEXT NOT NULL,
request TEXT NOT NULL,
checking_id TEXT NOT NULL,
unit TEXT NOT NULL,
amount {db.big_int} NOT NULL,
fee_reserve {db.big_int},
paid BOOL NOT NULL,
created_time TIMESTAMP,
paid_time TIMESTAMP,
fee_paid {db.big_int},
proof TEXT,
state TEXT,
expiry TIMESTAMP,
UNIQUE (quote)
);
"""
)
await conn.execute(
f"""
INSERT INTO {db.table_with_schema('melt_quotes_new')} (
quote, method, request, checking_id, unit, amount, fee_reserve, paid, created_time, paid_time, fee_paid, proof, state, expiry
)
SELECT quote, method, request, checking_id, unit, amount, fee_reserve, paid, created_time, paid_time, fee_paid, proof, state, expiry
FROM {db.table_with_schema('melt_quotes')};
"""
)
await conn.execute(f"DROP TABLE {db.table_with_schema('melt_quotes')}")
await conn.execute(
f"ALTER TABLE {db.table_with_schema('melt_quotes_new')} RENAME TO {db.table_with_schema('melt_quotes')}"
)
await conn.execute("PRAGMA foreign_keys=ON;")
else:
# For Postgres/Cockroach, drop the columns directly if they exist
await conn.execute(
f"ALTER TABLE {db.table_with_schema('melt_quotes')} DROP COLUMN IF EXISTS outputs"
)
await conn.execute(
f"ALTER TABLE {db.table_with_schema('melt_quotes')} DROP COLUMN IF EXISTS change"
)
# recreate promises table with columns mint_quote, melt_quote, swap_id and with c_ nullable
async def recreate_promises_table(db: Database, conn: Connection):
if conn.type == "SQLITE":
await conn.execute("PRAGMA foreign_keys=OFF;")
await conn.execute(
f"""
CREATE TABLE IF NOT EXISTS {db.table_with_schema('promises_new')} (
amount {db.big_int} NOT NULL,
id TEXT,
b_ TEXT NOT NULL,
c_ TEXT,
dleq_e TEXT,
dleq_s TEXT,
created TIMESTAMP,
signed_at TIMESTAMP,
mint_quote TEXT,
melt_quote TEXT,
swap_id TEXT,
FOREIGN KEY (mint_quote) REFERENCES {db.table_with_schema('mint_quotes')}(quote),
FOREIGN KEY (melt_quote) REFERENCES {db.table_with_schema('melt_quotes')}(quote),
UNIQUE (b_)
);
"""
)
await conn.execute(
f"INSERT INTO {db.table_with_schema('promises_new')} (amount, id, b_, c_, dleq_e, dleq_s, created, mint_quote, swap_id) "
f"SELECT amount, id, b_, c_, dleq_e, dleq_s, created, mint_quote, swap_id FROM {db.table_with_schema('promises')}"
)
await conn.execute(f"DROP TABLE {db.table_with_schema('promises')}")
await conn.execute(
f"ALTER TABLE {db.table_with_schema('promises_new')} RENAME TO {db.table_with_schema('promises')}"
)
await conn.execute("PRAGMA foreign_keys=ON;")
else:
# add columns melt_quote, signed_at and make column c_ nullable
await conn.execute(
f"ALTER TABLE {db.table_with_schema('promises')} ADD COLUMN melt_quote TEXT"
)
await conn.execute(
f"ALTER TABLE {db.table_with_schema('promises')} ADD COLUMN signed_at TIMESTAMP"
)
await conn.execute(
f"ALTER TABLE {db.table_with_schema('promises')} ALTER COLUMN c_ DROP NOT NULL"
)
# add foreign key constraint to melt_quote
await conn.execute(
f"ALTER TABLE {db.table_with_schema('promises')} ADD CONSTRAINT fk_promises_melt_quote FOREIGN KEY (melt_quote) REFERENCES {db.table_with_schema('melt_quotes')}(quote)"
)
async with db.connect() as conn:
# drop the balance views first
await drop_balance_views(db, conn)
# recreate promises table
await recreate_promises_table(db, conn)
# migrate stored melt outputs for pending quotes into promises
await migrate_stored_melt_outputs_for_pending_quotes(db, conn)
# remove obsolete columns from melt_quotes table
await remove_obsolete_columns_from_melt_quotes(db, conn)
# recreate the balance views
await create_balance_views(db, conn)

View File

@@ -0,0 +1,338 @@
import argparse
import asyncio
import datetime
import os
import re
import sqlite3
from typing import Any, Dict, Iterable, List, Optional, Tuple
from urllib.parse import urlparse
# Reuse project DB and migrations to create target schema
from cashu.core.db import Database
from cashu.core.migrations import migrate_databases
from cashu.mint import migrations as mint_migrations
DEFAULT_BATCH_SIZE = 1000
def _is_int_string(value: str) -> bool:
return bool(re.fullmatch(r"\d+", value))
def _convert_value(value: Any, decl_type: Optional[str]) -> Any:
if value is None:
return None
if not decl_type:
return value
dtype = decl_type.upper()
if "TIMESTAMP" in dtype:
# SQLite stores timestamps as INT seconds or formatted strings
if isinstance(value, (int, float)):
return datetime.datetime.fromtimestamp(int(value))
if isinstance(value, str):
if _is_int_string(value):
return datetime.datetime.fromtimestamp(int(value))
# try parse common format; fallback to raw string
for fmt in ("%Y-%m-%d %H:%M:%S", "%Y-%m-%d %H:%M:%S.%f"):
try:
return datetime.datetime.strptime(value, fmt)
except Exception:
pass
return value
return value
if dtype in {"BOOL", "BOOLEAN"}:
if isinstance(value, (int, float)):
return bool(value)
if isinstance(value, str) and value.lower() in {"0", "1", "true", "false"}:
return value.lower() in {"1", "true"}
return bool(value)
# BIGINT/INT: leave as-is; asyncpg will coerce ints
return value
def _get_sqlite_tables(conn: sqlite3.Connection) -> List[Tuple[str, str]]:
cur = conn.cursor()
# exclude sqlite internal tables
rows = cur.execute(
"SELECT name, type FROM sqlite_master WHERE type IN ('table','view') AND name NOT LIKE 'sqlite_%' ORDER BY type, name"
).fetchall()
return [(r[0], r[1]) for r in rows]
def _get_table_columns(
conn: sqlite3.Connection, table: str
) -> List[Tuple[str, Optional[str]]]:
cur = conn.cursor()
rows = cur.execute(f"PRAGMA table_info({table})").fetchall()
# rows: cid, name, type, notnull, dflt_value, pk
return [(r[1], r[2]) for r in rows]
def _iter_sqlite_rows(
conn: sqlite3.Connection, table: str, batch_size: int
) -> Iterable[List[sqlite3.Row]]:
cur = conn.cursor()
cur.execute(f"SELECT * FROM {table}")
while True:
rows = cur.fetchmany(batch_size)
if not rows:
break
yield rows
def _prepare_insert_sql(table: str, columns: List[str]) -> str:
cols = ", ".join(columns)
params = ", ".join(f":{c}" for c in columns)
# Use ON CONFLICT DO NOTHING to make script idempotent on empty DBs
return f"INSERT INTO {table} ({cols}) VALUES ({params}) ON CONFLICT DO NOTHING"
async def _ensure_target_schema(pg_url: str) -> Database:
db = Database("mint", pg_url)
await migrate_databases(db, mint_migrations)
return db
async def _pg_table_row_count(db: Database, table: str) -> int:
try:
async with db.connect() as conn:
r = await conn.fetchone(f"SELECT COUNT(*) AS c FROM {table}")
return int(r["c"]) if r else 0
except Exception:
return 0
def _sqlite_table_row_count(conn: sqlite3.Connection, table: str) -> int:
try:
cur = conn.cursor()
return int(cur.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0])
except Exception:
return 0
async def _precheck_postgres_populated(
pg_url: str, candidate_tables: List[str]
) -> Optional[str]:
db = Database("mint", pg_url)
populated: List[Tuple[str, int]] = []
for t in candidate_tables:
cnt = await _pg_table_row_count(db, t)
if cnt > 0:
populated.append((t, cnt))
if populated:
url = urlparse(pg_url.replace("postgresql+asyncpg://", "postgres://"))
user = url.username or "<user>"
host = url.hostname or "localhost"
port = url.port or 5432
dbname = (url.path or "/").lstrip("/") or "<database>"
details = ", ".join(f"{t}={c}" for t, c in populated)
info = (
"Target Postgres database appears to be populated; aborting migration to avoid corruption.\n"
f"Detected rows: {details}.\n"
"To reset the database, connect as the proper user and run:\n"
f'psql -U {user} -h {host} -p {port} -d {dbname} -c "DROP SCHEMA public CASCADE; CREATE SCHEMA public; GRANT ALL PRIVILEGES ON SCHEMA public TO {user};"'
)
return info
return None
async def _compare_balance_views(
sqlite_conn: sqlite3.Connection, pg_db: Database
) -> Tuple[bool, str]:
# Read SQLite balance view
try:
s_rows = sqlite_conn.execute("SELECT keyset, balance FROM balance").fetchall()
sqlite_map = {str(r[0]): int(r[1]) for r in s_rows}
except Exception as e:
return False, f"Failed reading SQLite balance view: {e}"
# Read Postgres balance view
try:
async with pg_db.connect() as conn:
p_rows = await conn.fetchall("SELECT keyset, balance FROM balance")
pg_map = {str(r["keyset"]): int(r["balance"]) for r in p_rows}
except Exception as e:
return False, f"Failed reading Postgres balance view: {e}"
if sqlite_map == pg_map:
return True, "Balance views match"
# Summarize differences
diffs = []
keys = set(sqlite_map) | set(pg_map)
for k in sorted(keys):
sv = sqlite_map.get(k)
pv = pg_map.get(k)
if sv != pv:
diffs.append(f"{k}: sqlite={sv} postgres={pv}")
if len(diffs) >= 10:
diffs.append("")
break
return False, "Balance view differs: " + "; ".join(diffs)
async def _copy_table(
sqlite_conn: sqlite3.Connection,
pg_db: Database,
table: str,
batch_size: int,
) -> int:
# views are skipped; ensure table exists on target
columns_with_types = _get_table_columns(sqlite_conn, table)
if not columns_with_types:
return 0
columns = [name for name, _ in columns_with_types]
insert_sql = _prepare_insert_sql(table, columns)
total = 0
total_rows = _sqlite_table_row_count(sqlite_conn, table)
printed_done = False
# commit per batch to avoid gigantic transactions
for batch in _iter_sqlite_rows(sqlite_conn, table, batch_size):
payload: List[Dict[str, Any]] = []
for row in batch:
row_dict = {columns[i]: row[i] for i in range(len(columns))}
normalized: Dict[str, Any] = {}
for col, decl_type in columns_with_types:
normalized[col] = _convert_value(row_dict.get(col), decl_type)
payload.append(normalized)
if not payload:
continue
async with pg_db.connect() as conn: # new txn per batch
await conn.execute(insert_sql, payload)
total += len(payload)
if total_rows:
pct = int(total * 100 / total_rows)
print(f"[{table}] {total}/{total_rows} ({pct}%)", end="\r", flush=True)
printed_done = True
if printed_done:
print("")
return total
def _ordered_tables(existing: Dict[str, str]) -> List[str]:
desired_order = [
"keysets",
"mint_pubkeys",
"mint_quotes",
"melt_quotes",
"promises",
"proofs_used",
"proofs_pending",
"balance_log",
]
# Filter desired order by presence
present_ordered = [
t for t in desired_order if t in existing and existing[t] == "table"
]
# Append any other base tables not covered yet
rest = [
t
for t, typ in existing.items()
if typ == "table" and t not in present_ordered and t not in {"dbversions"}
]
return present_ordered + rest
async def migrate_sqlite_to_postgres(
sqlite_path: str, pg_url: str, batch_size: int
) -> None:
if not os.path.exists(sqlite_path):
raise FileNotFoundError(f"SQLite file not found: {sqlite_path}")
# 1) open sqlite
sqlite_conn = sqlite3.connect(sqlite_path)
sqlite_conn.row_factory = sqlite3.Row
# decide which tables to check/copy
all_tables = _get_sqlite_tables(sqlite_conn)
table_map = {name: typ for name, typ in all_tables}
skip = {"dbversions", "balance", "balance_issued", "balance_redeemed"}
candidate_tables = [
t for t, typ in table_map.items() if typ == "table" and t not in skip
]
# 2) precheck Postgres not populated
info = await _precheck_postgres_populated(pg_url, candidate_tables)
if info:
print(info)
sqlite_conn.close()
return
# 3) ensure target schema on postgres
pg_db = await _ensure_target_schema(pg_url)
# 4) inspect sqlite schema
ordered = _ordered_tables(table_map)
ordered = [t for t in ordered if t not in skip]
# 5) copy data
for tbl in ordered:
print(f"Copying table: {tbl}")
count = await _copy_table(sqlite_conn, pg_db, tbl, batch_size)
print(f"Copied {count} rows from {tbl}")
# 6) verification: compare table row counts and balance view
print("Verifying data integrity …")
mismatches: List[str] = []
for tbl in ordered:
s_cnt = _sqlite_table_row_count(sqlite_conn, tbl)
p_cnt = await _pg_table_row_count(pg_db, tbl)
if s_cnt != p_cnt:
mismatches.append(f"{tbl}: sqlite={s_cnt} postgres={p_cnt}")
ok_balance, balance_msg = await _compare_balance_views(sqlite_conn, pg_db)
# 7) finalize
await pg_db.engine.dispose() # close connections cleanly
sqlite_conn.close()
if mismatches:
print("WARNING: Row count mismatches detected:")
for m in mismatches:
print(f" - {m}")
if not ok_balance:
print(f"WARNING: {balance_msg}")
if not mismatches and ok_balance:
total_rows_copied = sum(
_sqlite_table_row_count(sqlite3.connect(sqlite_path), t) for t in ordered
)
print(
"Migration successful: all row counts match and balance view is identical.\n"
f"Tables migrated: {len(ordered)}, total rows: {total_rows_copied}."
)
else:
print("Migration completed with warnings. Review the messages above.")
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Migrate Cashu mint SQLite DB to Postgres"
)
parser.add_argument("--sqlite", required=True, help="Path to mint.sqlite3 file")
parser.add_argument(
"--postgres",
required=True,
help="Postgres connection string, e.g. postgres://user:pass@host:5432/dbname",
)
parser.add_argument(
"--batch-size",
type=int,
default=DEFAULT_BATCH_SIZE,
help=f"Batch size for inserts (default {DEFAULT_BATCH_SIZE})",
)
return parser.parse_args()
def main() -> None:
args = parse_args()
asyncio.run(migrate_sqlite_to_postgres(args.sqlite, args.postgres, args.batch_size))
if __name__ == "__main__":
main()

View File

@@ -37,6 +37,8 @@ settings.fakewallet_brr = True
settings.fakewallet_delay_outgoing_payment = 0 settings.fakewallet_delay_outgoing_payment = 0
settings.fakewallet_delay_incoming_payment = 1 settings.fakewallet_delay_incoming_payment = 1
settings.fakewallet_stochastic_invoice = False settings.fakewallet_stochastic_invoice = False
settings.lightning_fee_percent = 2.0
settings.lightning_reserve_fee_min = 2000 # msat
assert ( assert (
settings.mint_test_database != settings.mint_database settings.mint_test_database != settings.mint_database
), "Test database is the same as the main database" ), "Test database is the same as the main database"

View File

@@ -1,235 +1,365 @@
from typing import List from typing import List
import pytest import pytest
from cashu.core.base import BlindedMessage, Proof, Unit from cashu.core.base import BlindedMessage, Proof, Unit
from cashu.core.crypto.b_dhke import step1_alice from cashu.core.crypto.b_dhke import step1_alice
from cashu.core.helpers import calculate_number_of_blank_outputs from cashu.core.helpers import calculate_number_of_blank_outputs
from cashu.core.models import PostMintQuoteRequest from cashu.core.models import PostMeltQuoteRequest, PostMintQuoteRequest
from cashu.core.settings import settings from cashu.core.settings import settings
from cashu.mint.ledger import Ledger from cashu.mint.ledger import Ledger
from tests.helpers import pay_if_regtest from tests.helpers import pay_if_regtest
async def assert_err(f, msg): async def assert_err(f, msg):
"""Compute f() and expect an error message 'msg'.""" """Compute f() and expect an error message 'msg'."""
try: try:
await f await f
except Exception as exc: except Exception as exc:
assert exc.args[0] == msg, Exception( assert exc.args[0] == msg, Exception(
f"Expected error: {msg}, got: {exc.args[0]}" f"Expected error: {msg}, got: {exc.args[0]}"
) )
def assert_amt(proofs: List[Proof], expected: int): def assert_amt(proofs: List[Proof], expected: int):
"""Assert amounts the proofs contain.""" """Assert amounts the proofs contain."""
assert [p.amount for p in proofs] == expected assert [p.amount for p in proofs] == expected
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_pubkeys(ledger: Ledger): async def test_pubkeys(ledger: Ledger):
assert ledger.keyset.public_keys assert ledger.keyset.public_keys
assert ( assert (
ledger.keyset.public_keys[1].serialize().hex() ledger.keyset.public_keys[1].serialize().hex()
== "02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104" == "02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104"
) )
assert ( assert (
ledger.keyset.public_keys[2 ** (settings.max_order - 1)].serialize().hex() ledger.keyset.public_keys[2 ** (settings.max_order - 1)].serialize().hex()
== "023c84c0895cc0e827b348ea0a62951ca489a5e436f3ea7545f3c1d5f1bea1c866" == "023c84c0895cc0e827b348ea0a62951ca489a5e436f3ea7545f3c1d5f1bea1c866"
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_privatekeys(ledger: Ledger): async def test_privatekeys(ledger: Ledger):
assert ledger.keyset.private_keys assert ledger.keyset.private_keys
assert ( assert (
ledger.keyset.private_keys[1].serialize() ledger.keyset.private_keys[1].serialize()
== "8300050453f08e6ead1296bb864e905bd46761beed22b81110fae0751d84604d" == "8300050453f08e6ead1296bb864e905bd46761beed22b81110fae0751d84604d"
) )
assert ( assert (
ledger.keyset.private_keys[2 ** (settings.max_order - 1)].serialize() ledger.keyset.private_keys[2 ** (settings.max_order - 1)].serialize()
== "b0477644cb3d82ffcc170bc0a76e0409727232e87c5ae51d64a259936228c7be" == "b0477644cb3d82ffcc170bc0a76e0409727232e87c5ae51d64a259936228c7be"
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_keysets(ledger: Ledger): async def test_keysets(ledger: Ledger):
assert len(ledger.keysets) assert len(ledger.keysets)
assert len(list(ledger.keysets.keys())) assert len(list(ledger.keysets.keys()))
assert ledger.keyset.id == "009a1f293253e41e" assert ledger.keyset.id == "009a1f293253e41e"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_keyset(ledger: Ledger): async def test_get_keyset(ledger: Ledger):
keyset = ledger.get_keyset() keyset = ledger.get_keyset()
assert isinstance(keyset, dict) assert isinstance(keyset, dict)
assert len(keyset) == settings.max_order assert len(keyset) == settings.max_order
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_mint(ledger: Ledger): async def test_mint(ledger: Ledger):
quote = await ledger.mint_quote(PostMintQuoteRequest(amount=8, unit="sat")) quote = await ledger.mint_quote(PostMintQuoteRequest(amount=8, unit="sat"))
await pay_if_regtest(quote.request) await pay_if_regtest(quote.request)
blinded_messages_mock = [ blinded_messages_mock = [
BlindedMessage( BlindedMessage(
amount=8, amount=8,
B_="02634a2c2b34bec9e8a4aba4361f6bf202d7fa2365379b0840afe249a7a9d71239", B_="02634a2c2b34bec9e8a4aba4361f6bf202d7fa2365379b0840afe249a7a9d71239",
id="009a1f293253e41e", id="009a1f293253e41e",
) )
] ]
promises = await ledger.mint(outputs=blinded_messages_mock, quote_id=quote.quote) promises = await ledger.mint(outputs=blinded_messages_mock, quote_id=quote.quote)
assert len(promises) assert len(promises)
assert promises[0].amount == 8 assert promises[0].amount == 8
assert ( assert (
promises[0].C_ promises[0].C_
== "031422eeffb25319e519c68de000effb294cb362ef713a7cf4832cea7b0452ba6e" == "031422eeffb25319e519c68de000effb294cb362ef713a7cf4832cea7b0452ba6e"
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_mint_invalid_quote(ledger: Ledger): async def test_mint_invalid_quote(ledger: Ledger):
await assert_err( await assert_err(
ledger.get_mint_quote(quote_id="invalid_quote_id"), ledger.get_mint_quote(quote_id="invalid_quote_id"),
"quote not found", "quote not found",
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_melt_invalid_quote(ledger: Ledger): async def test_melt_invalid_quote(ledger: Ledger):
await assert_err( await assert_err(
ledger.get_melt_quote(quote_id="invalid_quote_id"), ledger.get_melt_quote(quote_id="invalid_quote_id"),
"quote not found", "quote not found",
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_mint_invalid_blinded_message(ledger: Ledger): async def test_mint_invalid_blinded_message(ledger: Ledger):
quote = await ledger.mint_quote(PostMintQuoteRequest(amount=8, unit="sat")) quote = await ledger.mint_quote(PostMintQuoteRequest(amount=8, unit="sat"))
await pay_if_regtest(quote.request) await pay_if_regtest(quote.request)
blinded_messages_mock_invalid_key = [ blinded_messages_mock_invalid_key = [
BlindedMessage( BlindedMessage(
amount=8, amount=8,
B_="02634a2c2b34bec9e8a4aba4361f6bff02d7fa2365379b0840afe249a7a9d71237", B_="02634a2c2b34bec9e8a4aba4361f6bff02d7fa2365379b0840afe249a7a9d71237",
id="009a1f293253e41e", id="009a1f293253e41e",
) )
] ]
await assert_err( await assert_err(
ledger.mint(outputs=blinded_messages_mock_invalid_key, quote_id=quote.quote), ledger.mint(outputs=blinded_messages_mock_invalid_key, quote_id=quote.quote),
"invalid public key", "invalid public key",
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_generate_promises(ledger: Ledger): async def test_generate_promises(ledger: Ledger):
blinded_messages_mock = [ blinded_messages_mock = [
BlindedMessage( BlindedMessage(
amount=8, amount=8,
B_="02634a2c2b34bec9e8a4aba4361f6bf202d7fa2365379b0840afe249a7a9d71239", B_="02634a2c2b34bec9e8a4aba4361f6bf202d7fa2365379b0840afe249a7a9d71239",
id="009a1f293253e41e", id="009a1f293253e41e",
) )
] ]
promises = await ledger._generate_promises(blinded_messages_mock) await ledger._store_blinded_messages(blinded_messages_mock)
assert ( promises = await ledger._sign_blinded_messages(blinded_messages_mock)
promises[0].C_ assert (
== "031422eeffb25319e519c68de000effb294cb362ef713a7cf4832cea7b0452ba6e" promises[0].C_
) == "031422eeffb25319e519c68de000effb294cb362ef713a7cf4832cea7b0452ba6e"
assert promises[0].amount == 8 )
assert promises[0].id == "009a1f293253e41e" assert promises[0].amount == 8
assert promises[0].id == "009a1f293253e41e"
# DLEQ proof present
assert promises[0].dleq # DLEQ proof present
assert promises[0].dleq.s assert promises[0].dleq
assert promises[0].dleq.e assert promises[0].dleq.s
assert promises[0].dleq.e
@pytest.mark.asyncio
async def test_generate_change_promises(ledger: Ledger): @pytest.mark.asyncio
# Example slightly adapted from NUT-08 because we want to ensure the dynamic change async def test_generate_change_promises(ledger: Ledger):
# token amount works: `n_blank_outputs != n_returned_promises != 4`. # Example slightly adapted from NUT-08 because we want to ensure the dynamic change
# invoice_amount = 100_000 # token amount works: `n_blank_outputs != n_returned_promises != 4`.
fee_reserve = 2_000 # invoice_amount = 100_000
# total_provided = invoice_amount + fee_reserve fee_reserve = 2_000
actual_fee = 100 # total_provided = invoice_amount + fee_reserve
actual_fee = 100
expected_returned_promises = 7 # Amounts = [4, 8, 32, 64, 256, 512, 1024]
expected_returned_fees = 1900 expected_returned_promises = 7 # Amounts = [4, 8, 32, 64, 256, 512, 1024]
expected_returned_fees = 1900
n_blank_outputs = calculate_number_of_blank_outputs(fee_reserve)
blinded_msgs = [step1_alice(str(n)) for n in range(n_blank_outputs)] n_blank_outputs = calculate_number_of_blank_outputs(fee_reserve)
outputs = [ blinded_msgs = [step1_alice(str(n)) for n in range(n_blank_outputs)]
BlindedMessage( outputs = [
amount=1, BlindedMessage(
B_=b.serialize().hex(), amount=1,
id="009a1f293253e41e", B_=b.serialize().hex(),
) id="009a1f293253e41e",
for b, _ in blinded_msgs )
] for b, _ in blinded_msgs
]
promises = await ledger._generate_change_promises( await ledger._store_blinded_messages(outputs)
fee_provided=fee_reserve, fee_paid=actual_fee, outputs=outputs promises = await ledger._generate_change_promises(
) fee_provided=fee_reserve, fee_paid=actual_fee, outputs=outputs
)
assert len(promises) == expected_returned_promises
assert sum([promise.amount for promise in promises]) == expected_returned_fees assert len(promises) == expected_returned_promises
assert sum([promise.amount for promise in promises]) == expected_returned_fees
@pytest.mark.asyncio
async def test_generate_change_promises_legacy_wallet(ledger: Ledger): @pytest.mark.asyncio
# Check if mint handles a legacy wallet implementation (always sends 4 blank async def test_generate_change_promises_legacy_wallet(ledger: Ledger):
# outputs) as well. # Check if mint handles a legacy wallet implementation (always sends 4 blank
# invoice_amount = 100_000 # outputs) as well.
fee_reserve = 2_000 # invoice_amount = 100_000
# total_provided = invoice_amount + fee_reserve fee_reserve = 2_000
actual_fee = 100 # total_provided = invoice_amount + fee_reserve
actual_fee = 100
expected_returned_promises = 4 # Amounts = [64, 256, 512, 1024]
expected_returned_fees = 1856 expected_returned_promises = 4 # Amounts = [64, 256, 512, 1024]
expected_returned_fees = 1856
n_blank_outputs = 4
blinded_msgs = [step1_alice(str(n)) for n in range(n_blank_outputs)] n_blank_outputs = 4
outputs = [ blinded_msgs = [step1_alice(str(n)) for n in range(n_blank_outputs)]
BlindedMessage( outputs = [
amount=1, BlindedMessage(
B_=b.serialize().hex(), amount=1,
id="009a1f293253e41e", B_=b.serialize().hex(),
) id="009a1f293253e41e",
for b, _ in blinded_msgs )
] for b, _ in blinded_msgs
]
promises = await ledger._generate_change_promises(fee_reserve, actual_fee, outputs)
await ledger._store_blinded_messages(outputs)
assert len(promises) == expected_returned_promises promises = await ledger._generate_change_promises(fee_reserve, actual_fee, outputs)
assert sum([promise.amount for promise in promises]) == expected_returned_fees
assert len(promises) == expected_returned_promises
assert sum([promise.amount for promise in promises]) == expected_returned_fees
@pytest.mark.asyncio
async def test_generate_change_promises_returns_empty_if_no_outputs(ledger: Ledger):
# invoice_amount = 100_000 @pytest.mark.asyncio
fee_reserve = 1_000 async def test_generate_change_promises_returns_empty_if_no_outputs(ledger: Ledger):
# total_provided = invoice_amount + fee_reserve # invoice_amount = 100_000
actual_fee_msat = 100_000 fee_reserve = 1_000
outputs = None # total_provided = invoice_amount + fee_reserve
actual_fee_msat = 100_000
promises = await ledger._generate_change_promises( outputs = None
fee_reserve, actual_fee_msat, outputs
) promises = await ledger._generate_change_promises(
assert len(promises) == 0 fee_reserve, actual_fee_msat, outputs
)
assert len(promises) == 0
@pytest.mark.asyncio
async def test_get_balance(ledger: Ledger):
unit = Unit["sat"] @pytest.mark.asyncio
balance, fees_paid = await ledger.get_balance(unit) async def test_get_balance(ledger: Ledger):
assert balance == 0 unit = Unit["sat"]
assert fees_paid == 0 balance, fees_paid = await ledger.get_balance(unit)
assert balance == 0
assert fees_paid == 0
@pytest.mark.asyncio
async def test_maximum_balance(ledger: Ledger):
settings.mint_max_balance = 1000 @pytest.mark.asyncio
await ledger.mint_quote(PostMintQuoteRequest(amount=8, unit="sat")) async def test_maximum_balance(ledger: Ledger):
await assert_err( settings.mint_max_balance = 1000
ledger.mint_quote(PostMintQuoteRequest(amount=8000, unit="sat")), await ledger.mint_quote(PostMintQuoteRequest(amount=8, unit="sat"))
"Mint has reached maximum balance.", await assert_err(
) ledger.mint_quote(PostMintQuoteRequest(amount=8000, unit="sat")),
settings.mint_max_balance = 0 "Mint has reached maximum balance.",
)
settings.mint_max_balance = 0
@pytest.mark.asyncio
async def test_generate_change_promises_signs_subset_and_deletes_rest(ledger: Ledger):
from cashu.core.base import BlindedMessage
from cashu.core.crypto.b_dhke import step1_alice
from cashu.core.split import amount_split
# Create a real melt quote to satisfy FK on promises.melt_quote
mint_quote_resp = await ledger.mint_quote(
PostMintQuoteRequest(amount=64, unit="sat")
)
melt_quote_resp = await ledger.melt_quote(
PostMeltQuoteRequest(request=mint_quote_resp.request, unit="sat")
)
melt_id = melt_quote_resp.quote
fee_provided = 2_000
fee_paid = 100
overpaid_fee = fee_provided - fee_paid
return_amounts = amount_split(overpaid_fee)
# Store more blank outputs than needed for the change.
extra_blanks = 3
n_blank = len(return_amounts) + extra_blanks
blank_outputs = [
BlindedMessage(
amount=1,
B_=step1_alice(f"change_blank_{i}")[0].serialize().hex(),
id=ledger.keyset.id,
)
for i in range(n_blank)
]
await ledger._store_blinded_messages(blank_outputs, melt_id=melt_id)
# Fetch the stored unsigned blanks (same as melt flow) and run change generation.
stored_outputs = await ledger.crud.get_blinded_messages_melt_id(
db=ledger.db, melt_id=melt_id
)
assert len(stored_outputs) == n_blank
promises = await ledger._generate_change_promises(
fee_provided=fee_provided,
fee_paid=fee_paid,
outputs=stored_outputs,
melt_id=melt_id,
keyset=ledger.keyset,
)
assert len(promises) == len(return_amounts)
assert sorted(p.amount for p in promises) == sorted(return_amounts)
# All unsigned blanks should be deleted after signing the subset.
remaining_unsigned = await ledger.crud.get_blinded_messages_melt_id(
db=ledger.db, melt_id=melt_id
)
assert remaining_unsigned == []
# The signed promises should remain in the DB with c_ set.
async with ledger.db.connect() as conn:
rows = await conn.fetchall(
f"""
SELECT amount, c_ FROM {ledger.db.table_with_schema('promises')}
WHERE melt_quote = :melt_id
""",
{"melt_id": melt_id},
)
assert len(rows) == len(return_amounts)
assert all(row["c_"] for row in rows)
assert sorted(int(row["amount"]) for row in rows) == sorted(return_amounts)
@pytest.mark.asyncio
async def test_generate_change_promises_zero_fee_deletes_all_blanks(ledger: Ledger):
from cashu.core.base import BlindedMessage
from cashu.core.crypto.b_dhke import step1_alice
# Create a real melt quote to satisfy FK on promises.melt_quote
mint_quote_resp = await ledger.mint_quote(
PostMintQuoteRequest(amount=64, unit="sat")
)
melt_quote_resp = await ledger.melt_quote(
PostMeltQuoteRequest(request=mint_quote_resp.request, unit="sat")
)
melt_id = melt_quote_resp.quote
fee_provided = 1_000
fee_paid = 1_000 # no overpaid fee
n_blank = 4
blank_outputs = [
BlindedMessage(
amount=1,
B_=step1_alice(f"no_fee_blank_{i}")[0].serialize().hex(),
id=ledger.keyset.id,
)
for i in range(n_blank)
]
await ledger._store_blinded_messages(blank_outputs, melt_id=melt_id)
stored_outputs = await ledger.crud.get_blinded_messages_melt_id(
db=ledger.db, melt_id=melt_id
)
assert len(stored_outputs) == n_blank
promises = await ledger._generate_change_promises(
fee_provided=fee_provided,
fee_paid=fee_paid,
outputs=stored_outputs,
melt_id=melt_id,
keyset=ledger.keyset,
)
assert promises == []
remaining_unsigned = await ledger.crud.get_blinded_messages_melt_id(
db=ledger.db, melt_id=melt_id
)
# With zero fee nothing is signed or deleted; blanks stay pending.
assert len(remaining_unsigned) == n_blank
async with ledger.db.connect() as conn:
rows = await conn.fetchall(
f"""
SELECT amount, c_ FROM {ledger.db.table_with_schema('promises')}
WHERE melt_quote = :melt_id
""",
{"melt_id": melt_id},
)
assert len(rows) == n_blank
assert all(row["c_"] is None for row in rows)

View File

@@ -23,9 +23,7 @@ from tests.helpers import (
pay_if_regtest, pay_if_regtest,
) )
payment_request = ( payment_request = "lnbc1u1p5qeft3sp5jn5cqclnxvucfqtjm8qnlar2vhevcuudpccv7tsuglruj3qm579spp5ygdhy0t7xu53myke8z3z024xhz4kzgk9fcqk64sp0fyeqzhmaswqdqqcqpjrzjq0euzzxv65mts5ngg8c2t3vzz2aeuevy5845jvyqulqucd8c9kkhzrtp55qq63qqqqqqqqqqqqqzwyqqyg9qxpqysgqscprcpnk8whs3askqhgu6z5a4hupyn8du2aahdcf00s5pxrs4g94sv9f95xdn4tu0wec7kfyzj439wu9z27k6m6e3q4ysjquf5agx7gp0eeye4"
"lnbc1u1p5qeft3sp5jn5cqclnxvucfqtjm8qnlar2vhevcuudpccv7tsuglruj3qm579spp5ygdhy0t7xu53myke8z3z024xhz4kzgk9fcqk64sp0fyeqzhmaswqdqqcqpjrzjq0euzzxv65mts5ngg8c2t3vzz2aeuevy5845jvyqulqucd8c9kkhzrtp55qq63qqqqqqqqqqqqqzwyqqyg9qxpqysgqscprcpnk8whs3askqhgu6z5a4hupyn8du2aahdcf00s5pxrs4g94sv9f95xdn4tu0wec7kfyzj439wu9z27k6m6e3q4ysjquf5agx7gp0eeye4"
)
@pytest_asyncio.fixture(scope="function") @pytest_asyncio.fixture(scope="function")
@@ -295,30 +293,51 @@ async def test_db_events_add_client(wallet: Wallet, ledger: Ledger):
# remove subscription # remove subscription
client.remove_subscription("subId") client.remove_subscription("subId")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_db_update_mint_quote_state(wallet: Wallet, ledger: Ledger): async def test_db_update_mint_quote_state(wallet: Wallet, ledger: Ledger):
mint_quote = await wallet.request_mint(128) mint_quote = await wallet.request_mint(128)
await ledger.db_write._update_mint_quote_state(mint_quote.quote, MintQuoteState.paid) await ledger.db_write._update_mint_quote_state(
mint_quote.quote, MintQuoteState.paid
mint_quote_db = await ledger.crud.get_mint_quote(quote_id=mint_quote.quote, db=ledger.db) )
mint_quote_db = await ledger.crud.get_mint_quote(
quote_id=mint_quote.quote, db=ledger.db
)
assert mint_quote_db
assert mint_quote_db.state == MintQuoteState.paid assert mint_quote_db.state == MintQuoteState.paid
# Update it to issued # Update it to issued
await ledger.db_write._update_mint_quote_state(mint_quote_db.quote, MintQuoteState.issued) await ledger.db_write._update_mint_quote_state(
mint_quote_db.quote, MintQuoteState.issued
)
# Try and revert it back to unpaid # Try and revert it back to unpaid
await assert_err(ledger.db_write._update_mint_quote_state(mint_quote_db.quote, MintQuoteState.unpaid), "Cannot change state of an issued mint quote.") await assert_err(
ledger.db_write._update_mint_quote_state(
mint_quote_db.quote, MintQuoteState.unpaid
),
"Cannot change state of an issued mint quote.",
)
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.skipif( @pytest.mark.skipif(is_deprecated_api_only, reason=("Deprecated API"))
is_deprecated_api_only,
reason=("Deprecated API")
)
async def test_db_update_melt_quote_state(wallet: Wallet, ledger: Ledger): async def test_db_update_melt_quote_state(wallet: Wallet, ledger: Ledger):
melt_quote = await wallet.melt_quote(payment_request) melt_quote = await wallet.melt_quote(payment_request)
await ledger.db_write._update_melt_quote_state(melt_quote.quote, MeltQuoteState.paid) await ledger.db_write._update_melt_quote_state(
melt_quote.quote, MeltQuoteState.paid
)
melt_quote_db = await ledger.crud.get_melt_quote(quote_id=melt_quote.quote, db=ledger.db) melt_quote_db = await ledger.crud.get_melt_quote(
quote_id=melt_quote.quote, db=ledger.db
)
assert melt_quote_db
assert melt_quote_db.state == MeltQuoteState.paid assert melt_quote_db.state == MeltQuoteState.paid
await assert_err(ledger.db_write._update_melt_quote_state(melt_quote.quote, MeltQuoteState.unpaid), "Cannot change state of a paid melt quote.") await assert_err(
ledger.db_write._update_melt_quote_state(
melt_quote.quote, MeltQuoteState.unpaid
),
"Cannot change state of a paid melt quote.",
)

View File

@@ -10,6 +10,7 @@ import pytest_asyncio
from cashu.core import db from cashu.core import db
from cashu.core.db import Connection from cashu.core.db import Connection
from cashu.core.migrations import backup_database from cashu.core.migrations import backup_database
from cashu.core.models import PostMeltQuoteRequest
from cashu.core.settings import settings from cashu.core.settings import settings
from cashu.mint.ledger import Ledger from cashu.mint.ledger import Ledger
from cashu.wallet.wallet import Wallet from cashu.wallet.wallet import Wallet
@@ -353,3 +354,402 @@ async def test_db_lock_table(wallet: Wallet, ledger: Ledger):
), ),
"failed to acquire database lock", "failed to acquire database lock",
) )
@pytest.mark.asyncio
async def test_store_and_sign_blinded_message(ledger: Ledger):
# Localized imports to avoid polluting module scope
from cashu.core.crypto.b_dhke import step1_alice, step2_bob
from cashu.core.crypto.secp import PublicKey
# Arrange: prepare a blinded message tied to current active keyset
amount = 8
keyset_id = ledger.keyset.id
B_pubkey, _ = step1_alice("test_store_and_sign_blinded_message")
B_hex = B_pubkey.serialize().hex()
# Act: store the blinded message (unsinged promise row)
await ledger.crud.store_blinded_message(
db=ledger.db,
amount=amount,
b_=B_hex,
id=keyset_id,
)
# Act: compute a valid blind signature for the stored row and persist it
private_key_amount = ledger.keyset.private_keys[amount]
B_point = PublicKey(bytes.fromhex(B_hex), raw=True)
C_point, e, s = step2_bob(B_point, private_key_amount)
await ledger.crud.update_blinded_message_signature(
db=ledger.db,
amount=amount,
b_=B_hex,
c_=C_point.serialize().hex(),
e=e.serialize(),
s=s.serialize(),
)
# Assert: row is now a full promise and can be read back via get_promise
promise = await ledger.crud.get_promise(db=ledger.db, b_=B_hex)
assert promise is not None
assert promise.amount == amount
assert promise.C_ == C_point.serialize().hex()
assert promise.id == keyset_id
@pytest.mark.asyncio
async def test_get_blinded_messages_by_melt_id(wallet: Wallet, ledger: Ledger):
# Arrange
from cashu.core.crypto.b_dhke import step1_alice
amount = 8
keyset_id = ledger.keyset.id
# Create a real melt quote to satisfy FK on promises.melt_quote
mint_quote = await wallet.request_mint(64)
melt_quote = await ledger.melt_quote(
PostMeltQuoteRequest(request=mint_quote.request, unit="sat")
)
melt_id = melt_quote.quote
# Create two blinded messages
B1, _ = step1_alice("get_by_melt_id_1")
B2, _ = step1_alice("get_by_melt_id_2")
b1_hex = B1.serialize().hex()
b2_hex = B2.serialize().hex()
# Persist as unsigned messages with proper melt_id FK
await ledger.crud.store_blinded_message(
db=ledger.db, amount=amount, b_=b1_hex, id=keyset_id, melt_id=melt_id
)
await ledger.crud.store_blinded_message(
db=ledger.db, amount=amount, b_=b2_hex, id=keyset_id, melt_id=melt_id
)
# Act
rows = await ledger.crud.get_blinded_messages_melt_id(db=ledger.db, melt_id=melt_id)
# Assert
assert len(rows) == 2
assert {r.B_ for r in rows} == {b1_hex, b2_hex}
assert all(r.id == keyset_id for r in rows)
@pytest.mark.asyncio
async def test_delete_blinded_messages_by_melt_id(wallet: Wallet, ledger: Ledger):
from cashu.core.crypto.b_dhke import step1_alice
amount = 4
keyset_id = ledger.keyset.id
# Create a real melt quote to satisfy FK on promises.melt_quote
mint_quote = await wallet.request_mint(64)
melt_quote = await ledger.melt_quote(
PostMeltQuoteRequest(request=mint_quote.request, unit="sat")
)
melt_id = melt_quote.quote
# Create two blinded messages
B1, _ = step1_alice("delete_by_melt_id_1")
B2, _ = step1_alice("delete_by_melt_id_2")
b1_hex = B1.serialize().hex()
b2_hex = B2.serialize().hex()
# Persist as unsigned messages
await ledger.crud.store_blinded_message(
db=ledger.db, amount=amount, b_=b1_hex, id=keyset_id, melt_id=melt_id
)
await ledger.crud.store_blinded_message(
db=ledger.db, amount=amount, b_=b2_hex, id=keyset_id, melt_id=melt_id
)
rows_before = await ledger.crud.get_blinded_messages_melt_id(
db=ledger.db, melt_id=melt_id
)
assert len(rows_before) == 2
# Act: delete all unsigned messages for this melt_id
await ledger.crud.delete_blinded_messages_melt_id(db=ledger.db, melt_id=melt_id)
# Assert: now none left for that melt_id
rows_after = await ledger.crud.get_blinded_messages_melt_id(
db=ledger.db, melt_id=melt_id
)
assert rows_after == []
@pytest.mark.asyncio
async def test_get_blinded_messages_by_melt_id_filters_signed(
wallet: Wallet, ledger: Ledger
):
from cashu.core.crypto.b_dhke import step1_alice, step2_bob
from cashu.core.crypto.secp import PublicKey
amount = 2
keyset_id = ledger.keyset.id
# Create a real melt quote to satisfy FK on promises.melt_quote
mint_quote = await wallet.request_mint(64)
melt_quote = await ledger.melt_quote(
PostMeltQuoteRequest(request=mint_quote.request, unit="sat")
)
melt_id = melt_quote.quote
B1, _ = step1_alice("filter_by_melt_id_1")
B2, _ = step1_alice("filter_by_melt_id_2")
b1_hex = B1.serialize().hex()
b2_hex = B2.serialize().hex()
# Persist two unsigned messages
await ledger.crud.store_blinded_message(
db=ledger.db, amount=amount, b_=b1_hex, id=keyset_id, melt_id=melt_id
)
await ledger.crud.store_blinded_message(
db=ledger.db, amount=amount, b_=b2_hex, id=keyset_id, melt_id=melt_id
)
# Sign one of them (it should no longer be returned by get_blinded_messages_melt_id which filters c_ IS NULL)
priv = ledger.keyset.private_keys[amount]
C_point, e, s = step2_bob(PublicKey(bytes.fromhex(b1_hex), raw=True), priv)
await ledger.crud.update_blinded_message_signature(
db=ledger.db,
amount=amount,
b_=b1_hex,
c_=C_point.serialize().hex(),
e=e.serialize(),
s=s.serialize(),
)
# Act
rows = await ledger.crud.get_blinded_messages_melt_id(db=ledger.db, melt_id=melt_id)
# Assert: only the unsigned one remains (b2_hex)
assert len(rows) == 1
assert rows[0].B_ == b2_hex
assert rows[0].id == keyset_id
@pytest.mark.asyncio
async def test_store_blinded_message(ledger: Ledger):
from cashu.core.crypto.b_dhke import step1_alice
amount = 8
keyset_id = ledger.keyset.id
B_pub, _ = step1_alice("test_store_blinded_message")
b_hex = B_pub.serialize().hex()
# Act: store unsigned blinded message
await ledger.crud.store_blinded_message(
db=ledger.db, amount=amount, b_=b_hex, id=keyset_id
)
# Assert: row exists and is unsigned (c_ IS NULL)
async with ledger.db.connect() as conn:
row = await conn.fetchone(
f"SELECT amount, id, b_, c_, created FROM {ledger.db.table_with_schema('promises')} WHERE b_ = :b_",
{"b_": b_hex},
)
assert row is not None
assert int(row["amount"]) == amount
assert row["id"] == keyset_id
assert row["b_"] == b_hex
assert row["c_"] is None
assert row["created"] is not None
@pytest.mark.asyncio
async def test_update_blinded_message_signature_before_store_blinded_message_errors(
ledger: Ledger,
):
from cashu.core.crypto.b_dhke import step1_alice, step2_bob
from cashu.core.crypto.secp import PublicKey
amount = 8
# Generate a blinded message that we will NOT store
B_pub, _ = step1_alice("test_sign_before_store_blinded_message")
b_hex = B_pub.serialize().hex()
# Create a valid signature tuple for that blinded message
priv = ledger.keyset.private_keys[amount]
C_point, e, s = step2_bob(PublicKey(bytes.fromhex(b_hex), raw=True), priv)
# Expect a DB-level error; on SQLite/Postgres this is typically a no-op update, so this test is xfail.
await assert_err(
ledger.crud.update_blinded_message_signature(
db=ledger.db,
amount=amount,
b_=b_hex,
c_=C_point.serialize().hex(),
e=e.serialize(),
s=s.serialize(),
),
"blinded message does not exist",
)
@pytest.mark.asyncio
async def test_store_blinded_message_duplicate_b_(ledger: Ledger):
from cashu.core.crypto.b_dhke import step1_alice
amount = 2
keyset_id = ledger.keyset.id
B_pub, _ = step1_alice("test_duplicate_b_")
b_hex = B_pub.serialize().hex()
# First insert should succeed
await ledger.crud.store_blinded_message(
db=ledger.db, amount=amount, b_=b_hex, id=keyset_id
)
@pytest.mark.asyncio
async def test_get_blind_signatures_by_melt_id_returns_signed(
wallet: Wallet, ledger: Ledger
):
from cashu.core.crypto.b_dhke import step1_alice, step2_bob
from cashu.core.crypto.secp import PublicKey
amount = 4
keyset_id = ledger.keyset.id
# Create a real melt quote to satisfy FK on promises.melt_quote
mint_quote = await wallet.request_mint(64)
melt_quote = await ledger.melt_quote(
PostMeltQuoteRequest(request=mint_quote.request, unit="sat")
)
melt_id = melt_quote.quote
# Prepare two blinded messages under the same melt_id
B1, _ = step1_alice("signed_promises_by_melt_id_1")
B2, _ = step1_alice("signed_promises_by_melt_id_2")
b1_hex = B1.serialize().hex()
b2_hex = B2.serialize().hex()
await ledger.crud.store_blinded_message(
db=ledger.db, amount=amount, b_=b1_hex, id=keyset_id, melt_id=melt_id
)
await ledger.crud.store_blinded_message(
db=ledger.db, amount=amount, b_=b2_hex, id=keyset_id, melt_id=melt_id
)
# Sign only one of them -> should be returned by get_blind_signatures_melt_id
priv = ledger.keyset.private_keys[amount]
C_point, e, s = step2_bob(PublicKey(bytes.fromhex(b1_hex), raw=True), priv)
await ledger.crud.update_blinded_message_signature(
db=ledger.db,
amount=amount,
b_=b1_hex,
c_=C_point.serialize().hex(),
e=e.serialize(),
s=s.serialize(),
)
# Act
signed = await ledger.crud.get_blind_signatures_melt_id(
db=ledger.db, melt_id=melt_id
)
# Assert: only the signed one is returned
assert len(signed) == 1
assert signed[0].amount == amount
assert signed[0].id == keyset_id
@pytest.mark.asyncio
async def test_get_melt_quote_includes_change_signatures(
wallet: Wallet, ledger: Ledger
):
from cashu.core.crypto.b_dhke import step1_alice, step2_bob
from cashu.core.crypto.secp import PublicKey
amount = 8
keyset_id = ledger.keyset.id
# Create melt quote and attach outputs/promises under its melt_id
mint_quote = await wallet.request_mint(64)
melt_quote = await ledger.melt_quote(
PostMeltQuoteRequest(request=mint_quote.request, unit="sat")
)
melt_id = melt_quote.quote
# Create two blinded messages, sign one -> becomes change
B1, _ = step1_alice("melt_quote_change_1")
B2, _ = step1_alice("melt_quote_change_2")
b1_hex = B1.serialize().hex()
b2_hex = B2.serialize().hex()
await ledger.crud.store_blinded_message(
db=ledger.db, amount=amount, b_=b1_hex, id=keyset_id, melt_id=melt_id
)
await ledger.crud.store_blinded_message(
db=ledger.db, amount=amount, b_=b2_hex, id=keyset_id, melt_id=melt_id
)
# Sign one -> should appear in change loaded by get_melt_quote
priv = ledger.keyset.private_keys[amount]
C_point, e, s = step2_bob(PublicKey(bytes.fromhex(b1_hex), raw=True), priv)
await ledger.crud.update_blinded_message_signature(
db=ledger.db,
amount=amount,
b_=b1_hex,
c_=C_point.serialize().hex(),
e=e.serialize(),
s=s.serialize(),
)
# Act
quote_db = await ledger.crud.get_melt_quote(quote_id=melt_id, db=ledger.db)
# Assert: change contains the signed promise(s)
assert quote_db is not None
assert quote_db.quote == melt_id
assert quote_db.change is not None
assert len(quote_db.change) == 1
assert quote_db.change[0].amount == amount
assert quote_db.change[0].id == keyset_id
@pytest.mark.asyncio
async def test_promises_fk_constraints_enforced(ledger: Ledger):
from cashu.core.crypto.b_dhke import step1_alice
keyset_id = ledger.keyset.id
B1, _ = step1_alice("fk_check_melt")
B2, _ = step1_alice("fk_check_mint")
b1_hex = B1.serialize().hex()
b2_hex = B2.serialize().hex()
# Use a single connection and enable FK enforcement on SQLite
async with ledger.db.connect() as conn:
# Fake melt_id should violate FK on promises.melt_quote
await assert_err_multiple(
ledger.crud.store_blinded_message(
db=ledger.db,
amount=1,
b_=b1_hex,
id=keyset_id,
melt_id="nonexistent-melt-id",
conn=conn,
),
[
"FOREIGN KEY", # SQLite
"violates foreign key constraint", # Postgres
],
)
async with ledger.db.connect() as conn:
# Fake mint_id should violate FK on promises.mint_quote
await assert_err_multiple(
ledger.crud.store_blinded_message(
db=ledger.db,
amount=1,
b_=b2_hex,
id=keyset_id,
mint_id="nonexistent-mint-id",
conn=conn,
),
[
"FOREIGN KEY", # SQLite
"violates foreign key constraint", # Postgres
],
)
# Done. This test only checks FK enforcement paths.

View File

@@ -51,13 +51,13 @@ async def test_wallet_subscription_mint(wallet: Wallet):
await asyncio.sleep(wait + 2) await asyncio.sleep(wait + 2)
assert triggered assert triggered
assert len(msg_stack) == 3 assert len(msg_stack) >= 3
assert msg_stack[0].payload["state"] == MintQuoteState.unpaid.value assert msg_stack[0].payload["state"] == MintQuoteState.unpaid.value
assert msg_stack[1].payload["state"] == MintQuoteState.paid.value assert msg_stack[1].payload["state"] == MintQuoteState.paid.value
assert msg_stack[2].payload["state"] == MintQuoteState.issued.value assert msg_stack[-1].payload["state"] == MintQuoteState.issued.value
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -133,7 +133,9 @@ async def test_wallet_subscription_multiple_listeners_receive_updates(wallet: Wa
from cashu.wallet.subscriptions import SubscriptionManager from cashu.wallet.subscriptions import SubscriptionManager
subs = SubscriptionManager(wallet.url) subs = SubscriptionManager(wallet.url)
threading.Thread(target=subs.connect, name="SubscriptionManager", daemon=True).start() threading.Thread(
target=subs.connect, name="SubscriptionManager", daemon=True
).start()
stack1: list[JSONRPCNotficationParams] = [] stack1: list[JSONRPCNotficationParams] = []
stack2: list[JSONRPCNotficationParams] = [] stack2: list[JSONRPCNotficationParams] = []