Mint: store Y in db (#412)

* storage y db

* for proofs_pending as well

* pending check with Y

* fix pending table

* test_race_pending

* skip race condition test on github

* skip test on github actions

* move test_cli.py -> test_wallet_cli.py

* get full proof from memory

* add domain separation wallet
This commit is contained in:
callebtc
2024-02-10 22:52:55 +01:00
committed by GitHub
parent 1de7abf032
commit 6db4604f99
12 changed files with 264 additions and 48 deletions

View File

@@ -9,6 +9,8 @@ from typing import Any, Dict, List, Optional, Union
from loguru import logger
from pydantic import BaseModel, Field
from cashu.core.crypto.b_dhke import hash_to_curve
from .crypto.aes import AESCipher
from .crypto.keys import (
derive_keys,
@@ -88,8 +90,9 @@ class Proof(BaseModel):
id: Union[None, str] = ""
amount: int = 0
secret: str = "" # secret or message to be blinded and signed
Y: str = "" # hash_to_curve(secret)
C: str = "" # signature on secret, unblinded by wallet
dleq: Union[DLEQWallet, None] = None # DLEQ proof
dleq: Optional[DLEQWallet] = None # DLEQ proof
witness: Union[None, str] = "" # witness for spending condition
# whether this proof is reserved for sending, used for coin management in the wallet
@@ -106,6 +109,11 @@ class Proof(BaseModel):
None # holds the id of the melt operation that destroyed this proof
)
def __init__(self, **data):
super().__init__(**data)
if not self.Y:
self.Y = hash_to_curve(self.secret.encode("utf-8")).serialize().hex()
@classmethod
def from_dict(cls, proof_dict: dict):
if proof_dict.get("dleq") and isinstance(proof_dict["dleq"], str):

View File

@@ -71,6 +71,26 @@ def hash_to_curve(message: bytes) -> PublicKey:
return point
DOMAIN_SEPARATOR = b"Secp256k1_HashToCurve_"
def hash_to_curve_domain_separated(message: bytes) -> PublicKey:
"""Generates a point from the message hash and checks if the point lies on the curve.
If it does not, iteratively tries to compute a new point from the hash."""
point = None
msg_to_hash = DOMAIN_SEPARATOR + message
counter = 0
while point is None:
_hash = hashlib.sha256(msg_to_hash + str(counter).encode()).digest()
try:
# will error if point does not lie on curve
point = PublicKey(b"\x02" + _hash, raw=True)
except Exception:
msg_to_hash = _hash
counter += 1
return point
def step1_alice(
secret_msg: str, blinding_factor: Optional[PrivateKey] = None
) -> tuple[PublicKey, PrivateKey]:
@@ -80,6 +100,15 @@ def step1_alice(
return B_, r
def step1_alice_domain_separated(
secret_msg: str, blinding_factor: Optional[PrivateKey] = None
) -> tuple[PublicKey, PrivateKey]:
Y: PublicKey = hash_to_curve_domain_separated(secret_msg.encode("utf-8"))
r = blinding_factor or PrivateKey()
B_: PublicKey = Y + r.pubkey # type: ignore
return B_, r
def step2_bob(B_: PublicKey, a: PrivateKey) -> Tuple[PublicKey, PrivateKey, PrivateKey]:
C_: PublicKey = B_.mult(a) # type: ignore
# produce dleq proof
@@ -94,7 +123,13 @@ def step3_alice(C_: PublicKey, r: PrivateKey, A: PublicKey) -> PublicKey:
def verify(a: PrivateKey, C: PublicKey, secret_msg: str) -> bool:
Y: PublicKey = hash_to_curve(secret_msg.encode("utf-8"))
return C == Y.mult(a) # type: ignore
valid = C == Y.mult(a) # type: ignore
# BEGIN: BACKWARDS COMPATIBILITY < 0.15.1
if not valid:
Y1: PublicKey = hash_to_curve_domain_separated(secret_msg.encode("utf-8"))
return C == Y1.mult(a) # type: ignore
# END: BACKWARDS COMPATIBILITY < 0.15.1
return valid
def hash_e(*publickeys: PublicKey) -> bytes:
@@ -149,7 +184,14 @@ def carol_verify_dleq(
Y: PublicKey = hash_to_curve(secret_msg.encode("utf-8"))
C_: PublicKey = C + A.mult(r) # type: ignore
B_: PublicKey = Y + r.pubkey # type: ignore
return alice_verify_dleq(B_, C_, e, s, A)
valid = alice_verify_dleq(B_, C_, e, s, A)
# BEGIN: BACKWARDS COMPATIBILITY < 0.15.1
if not valid:
Y1: PublicKey = hash_to_curve_domain_separated(secret_msg.encode("utf-8"))
B_1: PublicKey = Y1 + r.pubkey # type: ignore
return alice_verify_dleq(B_1, C_, e, s, A)
# END: BACKWARDS COMPATIBILITY < 0.15.1
return valid
# Below is a test of a simple positive and negative case

View File

@@ -121,7 +121,7 @@ class WalletSettings(CashuSettings):
mint_port: int = Field(default=3338)
wallet_name: str = Field(default="wallet")
wallet_unit: str = Field(default="sat")
wallet_domain_separation: bool = Field(default=False)
api_port: int = Field(default=4448)
api_host: str = Field(default="127.0.0.1")

View File

@@ -47,8 +47,8 @@ class LedgerCrud(ABC):
async def get_proof_used(
self,
*,
Y: str,
db: Database,
secret: str,
conn: Optional[Connection] = None,
) -> Optional[Proof]: ...
@@ -65,6 +65,7 @@ class LedgerCrud(ABC):
async def get_proofs_pending(
self,
*,
proofs: List[Proof],
db: Database,
conn: Optional[Connection] = None,
) -> List[Proof]: ...
@@ -271,13 +272,14 @@ class LedgerCrudSqlite(LedgerCrud):
await (conn or db).execute(
f"""
INSERT INTO {table_with_schema(db, 'proofs_used')}
(amount, C, secret, id, witness, created)
VALUES (?, ?, ?, ?, ?, ?)
(amount, C, secret, Y, id, witness, created)
VALUES (?, ?, ?, ?, ?, ?, ?)
""",
(
proof.amount,
proof.C,
proof.secret,
proof.Y,
proof.id,
proof.witness,
timestamp_now(db),
@@ -287,12 +289,17 @@ class LedgerCrudSqlite(LedgerCrud):
async def get_proofs_pending(
self,
*,
proofs: List[Proof],
db: Database,
conn: Optional[Connection] = None,
) -> List[Proof]:
rows = await (conn or db).fetchall(f"""
rows = await (conn or db).fetchall(
f"""
SELECT * from {table_with_schema(db, 'proofs_pending')}
""")
WHERE Y IN ({','.join(['?']*len(proofs))})
""",
tuple(proof.Y for proof in proofs),
)
return [Proof(**r) for r in rows]
async def set_proof_pending(
@@ -306,13 +313,14 @@ class LedgerCrudSqlite(LedgerCrud):
await (conn or db).execute(
f"""
INSERT INTO {table_with_schema(db, 'proofs_pending')}
(amount, C, secret, created)
VALUES (?, ?, ?, ?)
(amount, C, secret, Y, created)
VALUES (?, ?, ?, ?, ?)
""",
(
proof.amount,
str(proof.C),
str(proof.secret),
proof.C,
proof.secret,
proof.Y,
timestamp_now(db),
),
)
@@ -590,15 +598,16 @@ class LedgerCrudSqlite(LedgerCrud):
async def get_proof_used(
self,
*,
Y: str,
db: Database,
secret: str,
conn: Optional[Connection] = None,
) -> Optional[Proof]:
row = await (conn or db).fetchone(
f"""
SELECT * from {table_with_schema(db, 'proofs_used')}
WHERE secret = ?
WHERE Y = ?
""",
(secret,),
(Y,),
)
return Proof(**row) if row else None

View File

@@ -233,7 +233,7 @@ class Ledger(LedgerVerification, LedgerSpendingConditions):
proofs (List[Proof]): Proofs to add to known secret table.
conn: (Optional[Connection], optional): Database connection to reuse. Will create a new one if not given. Defaults to None.
"""
self.spent_proofs.update({p.secret: p for p in proofs})
self.spent_proofs.update({p.Y: p for p in proofs})
async with get_db_connection(self.db, conn) as conn:
# store in db
for p in proofs:
@@ -873,7 +873,7 @@ class Ledger(LedgerVerification, LedgerSpendingConditions):
logger.debug("Loading used proofs into memory")
spent_proofs_list = await self.crud.get_spent_proofs(db=self.db) or []
logger.debug(f"Loaded {len(spent_proofs_list)} used proofs")
self.spent_proofs = {p.secret: p for p in spent_proofs_list}
self.spent_proofs = {p.Y: p for p in spent_proofs_list}
async def check_proofs_state(self, secrets: List[str]) -> List[ProofState]:
"""Checks if provided proofs are spend or are pending.
@@ -891,19 +891,25 @@ class Ledger(LedgerVerification, LedgerSpendingConditions):
List[bool]: List of which proof are pending (True if pending, else False)
"""
states: List[ProofState] = []
proofs_spent = await self._get_proofs_spent(secrets)
proofs_pending = await self._get_proofs_pending(secrets)
proofs_spent_idx_secret = await self._get_proofs_spent_idx_secret(secrets)
proofs_pending_idx_secret = await self._get_proofs_pending_idx_secret(secrets)
for secret in secrets:
if secret not in proofs_spent and secret not in proofs_pending:
if (
secret not in proofs_spent_idx_secret
and secret not in proofs_pending_idx_secret
):
states.append(ProofState(secret=secret, state=SpentState.unspent))
elif secret not in proofs_spent and secret in proofs_pending:
elif (
secret not in proofs_spent_idx_secret
and secret in proofs_pending_idx_secret
):
states.append(ProofState(secret=secret, state=SpentState.pending))
else:
states.append(
ProofState(
secret=secret,
state=SpentState.spent,
witness=proofs_spent[secret].witness,
witness=proofs_spent_idx_secret[secret].witness,
)
)
return states
@@ -922,13 +928,13 @@ class Ledger(LedgerVerification, LedgerSpendingConditions):
async with self.proofs_pending_lock:
async with self.db.connect() as conn:
await self._validate_proofs_pending(proofs, conn)
for p in proofs:
try:
try:
for p in proofs:
await self.crud.set_proof_pending(
proof=p, db=self.db, conn=conn
)
except Exception:
raise TransactionError("proofs already pending.")
except Exception:
raise TransactionError("Failed to set proofs pending.")
async def _unset_proofs_pending(self, proofs: List[Proof]) -> None:
"""Deletes proofs from pending table.
@@ -952,8 +958,9 @@ class Ledger(LedgerVerification, LedgerSpendingConditions):
Raises:
Exception: At least one of the proofs is in the pending table.
"""
proofs_pending = await self.crud.get_proofs_pending(db=self.db, conn=conn)
for p in proofs:
for pp in proofs_pending:
if p.secret == pp.secret:
raise TransactionError("proofs are pending.")
assert (
len(
await self.crud.get_proofs_pending(proofs=proofs, db=self.db, conn=conn)
)
== 0
), TransactionError("proofs are pending.")

View File

@@ -1,3 +1,5 @@
from cashu.core.base import Proof
from ..core.db import Connection, Database, table_with_schema, timestamp_now
from ..core.settings import settings
@@ -392,3 +394,102 @@ async def m013_keysets_add_encrypted_seed(db: Database):
f"ALTER TABLE {table_with_schema(db, 'keysets')} ADD COLUMN"
" seed_encryption_method TEXT"
)
async def m014_proofs_add_Y_column(db: Database):
# get all proofs_used and proofs_pending from the database and compute Y for each of them
async with db.connect() as conn:
rows = await conn.fetchall(
f"SELECT * FROM {table_with_schema(db, 'proofs_used')}"
)
# Proof() will compute Y from secret upon initialization
proofs_used = [Proof(**r) for r in rows]
rows = await conn.fetchall(
f"SELECT * FROM {table_with_schema(db, 'proofs_pending')}"
)
proofs_pending = [Proof(**r) for r in rows]
async with db.connect() as conn:
await conn.execute(
f"ALTER TABLE {table_with_schema(db, 'proofs_used')} ADD COLUMN Y TEXT"
)
for proof in proofs_used:
await conn.execute(
f"UPDATE {table_with_schema(db, 'proofs_used')} SET Y = '{proof.Y}'"
f" WHERE secret = '{proof.secret}'"
)
# Copy proofs_used to proofs_used_old and create a new table proofs_used
# with the same columns but with a unique constraint on (Y)
# and copy the data from proofs_used_old to proofs_used, then drop proofs_used_old
await conn.execute(
f"DROP TABLE IF EXISTS {table_with_schema(db, 'proofs_used_old')}"
)
await conn.execute(
f"CREATE TABLE {table_with_schema(db, 'proofs_used_old')} AS"
f" SELECT * FROM {table_with_schema(db, 'proofs_used')}"
)
await conn.execute(f"DROP TABLE {table_with_schema(db, 'proofs_used')}")
await conn.execute(f"""
CREATE TABLE IF NOT EXISTS {table_with_schema(db, 'proofs_used')} (
amount INTEGER NOT NULL,
C TEXT NOT NULL,
secret TEXT NOT NULL,
id TEXT,
Y TEXT,
created TIMESTAMP,
witness TEXT,
UNIQUE (Y)
);
""")
await conn.execute(
f"INSERT INTO {table_with_schema(db, 'proofs_used')} (amount, C, "
"secret, id, Y, created, witness) SELECT amount, C, secret, id, Y,"
f" created, witness FROM {table_with_schema(db, 'proofs_used_old')}"
)
await conn.execute(f"DROP TABLE {table_with_schema(db, 'proofs_used_old')}")
# add column Y to proofs_pending
await conn.execute(
f"ALTER TABLE {table_with_schema(db, 'proofs_pending')} ADD COLUMN Y TEXT"
)
for proof in proofs_pending:
await conn.execute(
f"UPDATE {table_with_schema(db, 'proofs_pending')} SET Y = '{proof.Y}'"
f" WHERE secret = '{proof.secret}'"
)
# Copy proofs_pending to proofs_pending_old and create a new table proofs_pending
# with the same columns but with a unique constraint on (Y)
# and copy the data from proofs_pending_old to proofs_pending, then drop proofs_pending_old
await conn.execute(
f"DROP TABLE IF EXISTS {table_with_schema(db, 'proofs_pending_old')}"
)
await conn.execute(
f"CREATE TABLE {table_with_schema(db, 'proofs_pending_old')} AS"
f" SELECT * FROM {table_with_schema(db, 'proofs_pending')}"
)
await conn.execute(f"DROP TABLE {table_with_schema(db, 'proofs_pending')}")
await conn.execute(f"""
CREATE TABLE IF NOT EXISTS {table_with_schema(db, 'proofs_pending')} (
amount INTEGER NOT NULL,
C TEXT NOT NULL,
secret TEXT NOT NULL,
Y TEXT,
id TEXT,
created TIMESTAMP,
UNIQUE (Y)
);
""")
await conn.execute(
f"INSERT INTO {table_with_schema(db, 'proofs_pending')} (amount, C, "
"secret, Y, id, created) SELECT amount, C, secret, Y, id, created"
f" FROM {table_with_schema(db, 'proofs_pending_old')}"
)
await conn.execute(f"DROP TABLE {table_with_schema(db, 'proofs_pending_old')}")

View File

@@ -51,8 +51,10 @@ class LedgerVerification(LedgerSpendingConditions, SupportsKeysets, SupportsDb):
"""
# Verify inputs
# Verify proofs are spendable
spent_proofs = await self._get_proofs_spent([p.secret for p in proofs])
if not len(spent_proofs) == 0:
if (
not len(await self._get_proofs_spent_idx_secret([p.secret for p in proofs]))
== 0
):
raise TokenAlreadySpentError()
# Verify amounts of inputs
if not all([self._verify_amount(p.amount) for p in proofs]):
@@ -141,27 +143,35 @@ class LedgerVerification(LedgerSpendingConditions, SupportsKeysets, SupportsDb):
result.append(False if promise is None else True)
return result
async def _get_proofs_pending(self, secrets: List[str]) -> Dict[str, Proof]:
async def _get_proofs_pending_idx_secret(
self, secrets: List[str]
) -> Dict[str, Proof]:
"""Returns only those proofs that are pending."""
all_proofs_pending = await self.crud.get_proofs_pending(db=self.db)
all_proofs_pending = await self.crud.get_proofs_pending(
proofs=[Proof(secret=s) for s in secrets], db=self.db
)
proofs_pending = list(filter(lambda p: p.secret in secrets, all_proofs_pending))
proofs_pending_dict = {p.secret: p for p in proofs_pending}
return proofs_pending_dict
async def _get_proofs_spent(self, secrets: List[str]) -> Dict[str, Proof]:
async def _get_proofs_spent_idx_secret(
self, secrets: List[str]
) -> Dict[str, Proof]:
"""Returns all proofs that are spent."""
proofs = [Proof(secret=s) for s in secrets]
proofs_spent: List[Proof] = []
if settings.mint_cache_secrets:
# check used secrets in memory
for secret in secrets:
if secret in self.spent_proofs:
proofs_spent.append(self.spent_proofs[secret])
for proof in proofs:
spent_proof = self.spent_proofs.get(proof.Y)
if spent_proof:
proofs_spent.append(spent_proof)
else:
# check used secrets in database
async with self.db.connect() as conn:
for secret in secrets:
for proof in proofs:
spent_proof = await self.crud.get_proof_used(
db=self.db, secret=secret, conn=conn
db=self.db, Y=proof.Y, conn=conn
)
if spent_proof:
proofs_spent.append(spent_proof)

View File

@@ -1120,7 +1120,14 @@ class Wallet(LedgerAPI, WalletP2PK, WalletHTLC, WalletSecrets):
C = b_dhke.step3_alice(
C_, r, self.keysets[promise.id].public_keys[promise.amount]
)
B_, r = b_dhke.step1_alice(secret, r) # recompute B_ for dleq proofs
# BEGIN: BACKWARDS COMPATIBILITY < 0.15.1
if not settings.wallet_domain_separation:
B_, r = b_dhke.step1_alice(secret, r) # recompute B_ for dleq proofs
# END: BACKWARDS COMPATIBILITY < 0.15.1
else:
B_, r = b_dhke.step1_alice_domain_separated(
secret, r
) # recompute B_ for dleq proofs
proof = Proof(
id=promise.id,
@@ -1183,7 +1190,12 @@ class Wallet(LedgerAPI, WalletP2PK, WalletHTLC, WalletSecrets):
rs_ = [None] * len(amounts) if not rs else rs
rs_return: List[PrivateKey] = []
for secret, amount, r in zip(secrets, amounts, rs_):
B_, r = b_dhke.step1_alice(secret, r or None)
# BEGIN: BACKWARDS COMPATIBILITY < 0.15.1
if not settings.wallet_domain_separation:
B_, r = b_dhke.step1_alice(secret, r or None)
# END: BACKWARDS COMPATIBILITY < 0.15.1
else:
B_, r = b_dhke.step1_alice_domain_separated(secret, r or None)
rs_return.append(r)
output = BlindedMessage(
amount=amount, B_=B_.serialize().hex(), id=self.keyset_id

View File

@@ -1,4 +1,5 @@
import asyncio
import importlib
import multiprocessing
import os
import shutil
@@ -45,7 +46,7 @@ assert "test" in settings.cashu_dir
shutil.rmtree(settings.cashu_dir, ignore_errors=True)
Path(settings.cashu_dir).mkdir(parents=True, exist_ok=True)
from cashu.mint.startup import lightning_backend # noqa
# from cashu.mint.startup import lightning_backend # noqa
@pytest.fixture(scope="session")
@@ -99,7 +100,8 @@ async def ledger():
db_file = os.path.join(settings.mint_database, "mint.sqlite3")
if os.path.exists(db_file):
os.remove(db_file)
wallets_module = importlib.import_module("cashu.lightning")
lightning_backend = getattr(wallets_module, settings.mint_lightning_backend)()
backends = {
Method.bolt11: {Unit.sat: lightning_backend},
}

View File

@@ -30,6 +30,7 @@ WALLET = wallet_class()
is_fake: bool = WALLET.__class__.__name__ == "FakeWallet"
is_regtest: bool = not is_fake
is_deprecated_api_only = settings.debug_mint_only_deprecated
is_github_actions = os.getenv("GITHUB_ACTIONS") == "true"
docker_lightning_cli = [
"docker",

View File

@@ -13,7 +13,13 @@ from cashu.wallet.wallet import Wallet
from cashu.wallet.wallet import Wallet as Wallet1
from cashu.wallet.wallet import Wallet as Wallet2
from tests.conftest import SERVER_ENDPOINT
from tests.helpers import get_real_invoice, is_fake, is_regtest, pay_if_regtest
from tests.helpers import (
get_real_invoice,
is_fake,
is_github_actions,
is_regtest,
pay_if_regtest,
)
async def assert_err(f, msg: Union[str, CashuError]):
@@ -349,12 +355,30 @@ async def test_duplicate_proofs_double_spent(wallet1: Wallet):
doublespend = await wallet1.mint(64, id=invoice.id)
await assert_err(
wallet1.split(wallet1.proofs + doublespend, 20),
"Mint Error: proofs already pending.",
"Mint Error: Failed to set proofs pending.",
)
assert wallet1.balance == 64
assert wallet1.available_balance == 64
@pytest.mark.asyncio
@pytest.mark.skipif(is_github_actions, reason="GITHUB_ACTIONS")
async def test_split_race_condition(wallet1: Wallet):
invoice = await wallet1.request_mint(64)
pay_if_regtest(invoice.bolt11)
await wallet1.mint(64, id=invoice.id)
# run two splits in parallel
import asyncio
await assert_err(
asyncio.gather(
wallet1.split(wallet1.proofs, 20),
wallet1.split(wallet1.proofs, 20),
),
"proofs are pending.",
)
@pytest.mark.asyncio
async def test_send_and_redeem(wallet1: Wallet, wallet2: Wallet):
invoice = await wallet1.request_mint(64)