diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 33455d0..f32ad3d 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -25,6 +25,7 @@ jobs: - name: Run mint env: LIGHTNING: False + MINT_PRIVATE_KEY: "testingkey" MINT_SERVER_HOST: 0.0.0.0 MINT_SERVER_PORT: 3338 run: | diff --git a/cashu/core/base.py b/cashu/core/base.py index 0aa0244..d54913a 100644 --- a/cashu/core/base.py +++ b/cashu/core/base.py @@ -1,12 +1,10 @@ from sqlite3 import Row -from typing import List, Union +from typing import Any, Dict, List, Union from pydantic import BaseModel - -class CashuError(BaseModel): - code = "000" - error = "CashuError" +from cashu.core.crypto import derive_keys, derive_keyset_id, derive_pubkeys +from cashu.core.secp import PrivateKey, PublicKey class P2SHScript(BaseModel): @@ -25,6 +23,7 @@ class P2SHScript(BaseModel): class Proof(BaseModel): + id: str = "" amount: int secret: str = "" C: str @@ -44,6 +43,7 @@ class Proof(BaseModel): send_id=row[4] or "", time_created=row[5] or "", time_reserved=row[6] or "", + id=row[7] or "", ) @classmethod @@ -60,10 +60,10 @@ class Proof(BaseModel): ) def to_dict(self): - return dict(amount=self.amount, secret=self.secret, C=self.C) + return dict(id=self.id, amount=self.amount, secret=self.secret, C=self.C) def to_dict_no_secret(self): - return dict(amount=self.amount, C=self.C) + return dict(id=self.id, amount=self.amount, C=self.C) def __getitem__(self, key): return self.__getattribute__(key) @@ -95,17 +95,20 @@ class Invoice(BaseModel): class BlindedMessage(BaseModel): + id: str = "" amount: int B_: str class BlindedSignature(BaseModel): + id: str = "" amount: int C_: str @classmethod def from_dict(cls, d: dict): return cls( + id=d["id"], amount=d["amount"], C_=d["C_"], ) @@ -165,3 +168,129 @@ class MeltRequest(BaseModel): proofs: List[Proof] amount: int = None # deprecated invoice: str + + +class KeyBase(BaseModel): + id: str + amount: int + pubkey: str + + @classmethod + def from_row(cls, row: Row): + if row is None: + return cls + return cls( + id=row[0], + amount=int(row[1]), + pubkey=row[2], + ) + + +class WalletKeyset: + id: str + public_keys: Dict[int, PublicKey] + mint_url: Union[str, None] = None + valid_from: Union[str, None] = None + valid_to: Union[str, None] = None + first_seen: Union[str, None] = None + active: bool = True + + def __init__( + self, + pubkeys: Dict[int, PublicKey] = None, + mint_url=None, + id=None, + valid_from=None, + valid_to=None, + first_seen=None, + active=None, + ): + self.id = id + self.valid_from = valid_from + self.valid_to = valid_to + self.first_seen = first_seen + self.active = active + self.mint_url = mint_url + if pubkeys: + self.public_keys = pubkeys + self.id = derive_keyset_id(self.public_keys) + + @classmethod + def from_row(cls, row: Row): + if row is None: + return cls + return cls( + id=row[0], + mint_url=row[1], + valid_from=row[2], + valid_to=row[3], + first_seen=row[4], + active=row[5], + ) + + +class MintKeyset: + id: str + derivation_path: str + private_keys: Dict[int, PrivateKey] + public_keys: Dict[int, PublicKey] = None + valid_from: Union[str, None] = None + valid_to: Union[str, None] = None + first_seen: Union[str, None] = None + active: bool = True + + def __init__( + self, + id=None, + valid_from=None, + valid_to=None, + first_seen=None, + active=None, + seed: Union[None, str] = None, + derivation_path: str = "0", + ): + self.derivation_path = derivation_path + self.id = id + self.valid_from = valid_from + self.valid_to = valid_to + self.first_seen = first_seen + self.active = active + # generate keys from seed + if seed: + self.generate_keys(seed) + + def generate_keys(self, seed): + self.private_keys = derive_keys(seed, self.derivation_path) + self.public_keys = derive_pubkeys(self.private_keys) + self.id = derive_keyset_id(self.public_keys) + + @classmethod + def from_row(cls, row: Row): + if row is None: + return cls + # fix to convert byte to string, unclear why this is necessary + id = row[0].decode("ascii") if type(row[0]) == bytes else row[0] + return cls( + id=id, + derivation_path=row[1], + valid_from=row[2], + valid_to=row[3], + first_seen=row[4], + active=row[5], + ) + + def get_keybase(self): + return { + k: KeyBase(id=self.id, amount=k, pubkey=v.serialize().hex()) + for k, v in self.public_keys.items() + } + + +class MintKeysets: + keysets: Dict[str, MintKeyset] + + def __init__(self, keysets: List[MintKeyset]): + self.keysets: Dict[str, MintKeyset] = {k.id: k for k in keysets} + + def get_ids(self): + return [k for k, _ in self.keysets.items()] diff --git a/cashu/core/crypto.py b/cashu/core/crypto.py new file mode 100644 index 0000000..379d3ed --- /dev/null +++ b/cashu/core/crypto.py @@ -0,0 +1,35 @@ +import base64 +import hashlib +from typing import Dict, List + +from cashu.core.secp import PrivateKey, PublicKey +from cashu.core.settings import MAX_ORDER + + +def derive_keys(master_key: str, derivation_path: str = ""): + """ + Deterministic derivation of keys for 2^n values. + TODO: Implement BIP32. + """ + return { + 2 + ** i: PrivateKey( + hashlib.sha256((str(master_key) + derivation_path + str(i)).encode("utf-8")) + .hexdigest() + .encode("utf-8")[:32], + raw=True, + ) + for i in range(MAX_ORDER) + } + + +def derive_pubkeys(keys: Dict[int, PrivateKey]): + return {amt: keys[amt].pubkey for amt in [2**i for i in range(MAX_ORDER)]} + + +def derive_keyset_id(keys: Dict[str, PublicKey]): + """Deterministic derivation keyset_id from set of public keys.""" + pubkeys_concat = "".join([p.serialize().hex() for _, p in keys.items()]) + return base64.b64encode(hashlib.sha256((pubkeys_concat).encode("utf-8")).digest())[ + :12 + ] diff --git a/cashu/core/errors.py b/cashu/core/errors.py new file mode 100644 index 0000000..f770896 --- /dev/null +++ b/cashu/core/errors.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel + + +class CashuError(BaseModel): + code = "000" + error = "CashuError" + + +# class CashuError(Exception, BaseModel): +# code = "000" +# error = "CashuError" + + +# class MintException(CashuError): +# code = 100 +# error = "Mint" + + +# class LightningException(MintException): +# code = 200 +# error = "Lightning" + + +# class InvoiceNotPaidException(LightningException): +# code = 201 +# error = "invoice not paid." diff --git a/cashu/core/migrations.py b/cashu/core/migrations.py index 6beaa91..cc5041f 100644 --- a/cashu/core/migrations.py +++ b/cashu/core/migrations.py @@ -1,7 +1,5 @@ import re -from loguru import logger - from cashu.core.db import COCKROACH, POSTGRES, SQLITE, Database diff --git a/cashu/core/settings.py b/cashu/core/settings.py index 68a74ba..1f23b89 100644 --- a/cashu/core/settings.py +++ b/cashu/core/settings.py @@ -48,4 +48,4 @@ LNBITS_ENDPOINT = env.str("LNBITS_ENDPOINT", default=None) LNBITS_KEY = env.str("LNBITS_KEY", default=None) MAX_ORDER = 64 -VERSION = "0.2.6" +VERSION = "0.3.0" diff --git a/cashu/mint/crud.py b/cashu/mint/crud.py index a608967..19d0f0c 100644 --- a/cashu/mint/crud.py +++ b/cashu/mint/crud.py @@ -1,7 +1,6 @@ -import secrets from typing import Optional -from cashu.core.base import Invoice, Proof +from cashu.core.base import Invoice, MintKeyset, Proof from cashu.core.db import Connection, Database @@ -111,3 +110,57 @@ async def update_lightning_invoice( hash, ), ) + + +async def store_keyset( + keyset: MintKeyset, + mint_url: str = None, + db: Database = None, + conn: Optional[Connection] = None, +): + + await (conn or db).execute( + """ + INSERT INTO keysets + (id, derivation_path, valid_from, valid_to, first_seen, active) + VALUES (?, ?, ?, ?, ?, ?) + """, + ( + keyset.id, + keyset.derivation_path, + keyset.valid_from, + keyset.valid_to, + keyset.first_seen, + True, + ), + ) + + +async def get_keyset( + id: str = None, + derivation_path: str = None, + db: Database = None, + conn: Optional[Connection] = None, +): + clauses = [] + values = [] + clauses.append("active = ?") + values.append(True) + if id: + clauses.append("id = ?") + values.append(id) + if derivation_path: + clauses.append("derivation_path = ?") + values.append(derivation_path) + where = "" + if clauses: + where = f"WHERE {' AND '.join(clauses)}" + + rows = await (conn or db).fetchall( + f""" + SELECT * from keysets + {where} + """, + tuple(values), + ) + return [MintKeyset.from_row(row) for row in rows] diff --git a/cashu/mint/ledger.py b/cashu/mint/ledger.py index 6895bf5..8d3c558 100644 --- a/cashu/mint/ledger.py +++ b/cashu/mint/ledger.py @@ -2,26 +2,35 @@ Implementation of https://gist.github.com/phyro/935badc682057f418842c72961cf096c """ -import hashlib import math -from inspect import signature -from signal import signal -from typing import List, Set +from typing import Dict, List, Set + +from loguru import logger import cashu.core.b_dhke as b_dhke import cashu.core.bolt11 as bolt11 -from cashu.core.base import BlindedMessage, BlindedSignature, Invoice, Proof +from cashu.core.base import ( + BlindedMessage, + BlindedSignature, + Invoice, + MintKeyset, + MintKeysets, + Proof, +) +from cashu.core.crypto import derive_keys, derive_keyset_id, derive_pubkeys from cashu.core.db import Database from cashu.core.helpers import fee_reserve from cashu.core.script import verify_script -from cashu.core.secp import PrivateKey, PublicKey +from cashu.core.secp import PublicKey from cashu.core.settings import LIGHTNING, MAX_ORDER from cashu.core.split import amount_split from cashu.lightning import WALLET from cashu.mint.crud import ( + get_keyset, get_lightning_invoice, get_proofs_used, invalidate_proof, + store_keyset, store_lightning_invoice, store_promise, update_lightning_invoice, @@ -29,34 +38,40 @@ from cashu.mint.crud import ( class Ledger: - def __init__(self, secret_key: str, db: str): + def __init__(self, secret_key: str, db: str, derivation_path=""): self.proofs_used: Set[str] = set() - self.master_key = secret_key - self.keys = self._derive_keys(self.master_key) - self.pub_keys = self._derive_pubkeys(self.keys) + self.derivation_path = derivation_path self.db: Database = Database("mint", db) async def load_used_proofs(self): self.proofs_used = set(await get_proofs_used(db=self.db)) - @staticmethod - def _derive_keys(master_key: str): - """Deterministic derivation of keys for 2^n values.""" - return { - 2 - ** i: PrivateKey( - hashlib.sha256((str(master_key) + str(i)).encode("utf-8")) - .hexdigest() - .encode("utf-8")[:32], - raw=True, - ) - for i in range(MAX_ORDER) - } + async def init_keysets(self): + """Loads all past keysets and stores the active one if not already in db""" + # generate current keyset from seed and current derivation path + self.keyset = MintKeyset( + seed=self.master_key, derivation_path=self.derivation_path + ) + # check if current keyset is stored in db and store if not + current_keyset_local: List[MintKeyset] = await get_keyset( + id=self.keyset.id, db=self.db + ) + if not len(current_keyset_local): + logger.debug(f"Storing keyset {self.keyset.id}") + await store_keyset(keyset=self.keyset, db=self.db) - @staticmethod - def _derive_pubkeys(keys: List[PrivateKey]): - return {amt: keys[amt].pubkey for amt in [2**i for i in range(MAX_ORDER)]} + # load all past keysets from db + # this needs two steps because the types of tmp_keysets and the argument of MintKeysets() are different + tmp_keysets: List[MintKeyset] = await get_keyset(db=self.db) + self.keysets = MintKeysets(tmp_keysets) + logger.debug(f"Keysets {self.keysets.keysets}") + # generate all derived keys from stored derivation paths of past keysets + for _, v in self.keysets.keysets.items(): + v.generate_keys(self.master_key) + + if len(self.keysets.keysets): + logger.debug(f"Loaded {len(self.keysets.keysets)} keysets from db.") async def _generate_promises(self, amounts: List[int], B_s: List[str]): """Generates promises that sum to the given amount.""" @@ -67,7 +82,7 @@ class Ledger: async def _generate_promise(self, amount: int, B_: PublicKey): """Generates a promise for given amount and returns a pair (amount, C').""" - secret_key = self.keys[amount] # Get the correct key + secret_key = self.keyset.private_keys[amount] # Get the correct key C_ = b_dhke.step2_bob(B_, secret_key) await store_promise( amount, B_=B_.serialize().hex(), C_=C_.serialize().hex(), db=self.db @@ -92,7 +107,13 @@ class Ledger: """Verifies that the proof of promise was issued by this ledger.""" if not self._check_spendable(proof): raise Exception(f"tokens already spent. Secret: {proof.secret}") - secret_key = self.keys[proof.amount] # Get the correct key to check against + # if no keyset id is given in proof, assume the current one + if not proof.id: + secret_key = self.keyset.private_keys[proof.amount] + else: + # use the appropriate active keyset for this proof.id + secret_key = self.keysets.keysets[proof.id].private_keys[proof.amount] + C = PublicKey(bytes.fromhex(proof.C), raw=True) return b_dhke.verify(secret_key, C, proof.secret) @@ -123,7 +144,7 @@ class Ledger: assert len(proof.secret.split(":")) == 3, "secret format wrong." assert proof.secret.split(":")[1] == str( txin_p2sh_address - ), f"secret does not contain correct P2SH address: {proof.secret.split(':')[1]}!={txin_p2sh_address}." + ), f"secret does not contain correct P2SH address: {proof.secret.split(':')[1]} is not {txin_p2sh_address}." return valid def _verify_outputs(self, total: int, amount: int, outputs: List[BlindedMessage]): @@ -221,10 +242,13 @@ class Ledger: for p in proofs: await invalidate_proof(p, db=self.db) - # Public methods - def get_pubkeys(self): + def _serialize_pubkeys(self): """Returns public keys for possible amounts.""" - return {a: p.serialize().hex() for a, p in self.pub_keys.items()} + return {a: p.serialize().hex() for a, p in self.keyset.public_keys.items()} + + # Public methods + def get_keyset(self): + return self._serialize_pubkeys() async def request_mint(self, amount): """Returns Lightning invoice and stores it in the db.""" diff --git a/cashu/mint/migrations.py b/cashu/mint/migrations.py index 967b9d3..15c8233 100644 --- a/cashu/mint/migrations.py +++ b/cashu/mint/migrations.py @@ -85,3 +85,36 @@ async def m001_initial(db: Database): ); """ ) + + +async def m003_mint_keysets(db: Database): + """ + Stores mint keysets from different mints and epochs. + """ + await db.execute( + f""" + CREATE TABLE IF NOT EXISTS keysets ( + id TEXT NOT NULL, + derivation_path TEXT, + valid_from TIMESTAMP DEFAULT {db.timestamp_now}, + valid_to TIMESTAMP DEFAULT {db.timestamp_now}, + first_seen TIMESTAMP DEFAULT {db.timestamp_now}, + active BOOL DEFAULT TRUE, + + UNIQUE (derivation_path) + + ); + """ + ) + await db.execute( + f""" + CREATE TABLE IF NOT EXISTS mint_pubkeys ( + id TEXT NOT NULL, + amount INTEGER NOT NULL, + pubkey TEXT NOT NULL, + + UNIQUE (id, pubkey) + + ); + """ + ) diff --git a/cashu/mint/router.py b/cashu/mint/router.py index 9974826..ae78859 100644 --- a/cashu/mint/router.py +++ b/cashu/mint/router.py @@ -4,7 +4,6 @@ from fastapi import APIRouter from secp256k1 import PublicKey from cashu.core.base import ( - CashuError, CheckFeesRequest, CheckFeesResponse, CheckRequest, @@ -15,6 +14,7 @@ from cashu.core.base import ( PostSplitResponse, SplitRequest, ) +from cashu.core.errors import CashuError from cashu.mint import ledger router: APIRouter = APIRouter() @@ -23,7 +23,13 @@ router: APIRouter = APIRouter() @router.get("/keys") def keys(): """Get the public keys of the mint""" - return ledger.get_pubkeys() + return ledger.get_keyset() + + +@router.get("/keysets") +def keysets(): + """Get all active keysets of the mint""" + return {"keysets": ledger.keysets.get_ids()} @router.get("/mint") diff --git a/cashu/mint/startup.py b/cashu/mint/startup.py index 1207091..f148e2c 100644 --- a/cashu/mint/startup.py +++ b/cashu/mint/startup.py @@ -2,16 +2,19 @@ import asyncio from loguru import logger +from cashu.core.migrations import migrate_databases from cashu.core.settings import CASHU_DIR, LIGHTNING from cashu.lightning import WALLET -from cashu.mint.migrations import m001_initial +from cashu.mint import migrations from . import ledger async def load_ledger(): - await asyncio.wait([m001_initial(ledger.db)]) + await migrate_databases(ledger.db, migrations) + # await asyncio.wait([m001_initial(ledger.db)]) await ledger.load_used_proofs() + await ledger.init_keysets() if LIGHTNING: error_message, balance = await WALLET.status() diff --git a/cashu/wallet/cli.py b/cashu/wallet/cli.py index 8ddd0e3..2ce2272 100755 --- a/cashu/wallet/cli.py +++ b/cashu/wallet/cli.py @@ -78,7 +78,7 @@ def coro(f): @coro async def mint(ctx, amount: int, hash: str): wallet: Wallet = ctx.obj["WALLET"] - wallet.load_mint() + await wallet.load_mint() wallet.status() if not LIGHTNING: r = await wallet.mint(amount) @@ -123,7 +123,7 @@ async def mint(ctx, amount: int, hash: str): @coro async def pay(ctx, invoice: str): wallet: Wallet = ctx.obj["WALLET"] - wallet.load_mint() + await wallet.load_mint() wallet.status() decoded_invoice: Invoice = bolt11.decode(invoice) # check if it's an internal payment @@ -148,7 +148,18 @@ async def pay(ctx, invoice: str): @coro async def balance(ctx): wallet: Wallet = ctx.obj["WALLET"] - wallet.status() + keyset_balances = wallet.balance_per_keyset() + if len(keyset_balances) > 1: + print(f"You have balances in {len(keyset_balances)} keysets:") + print("") + for k, v in keyset_balances.items(): + print( + f"Keyset: {k or 'undefined'} Balance: {v['balance']} sat (available: {v['available']})" + ) + print("") + print( + f"Balance: {wallet.balance} sat (available: {wallet.available_balance} sat in {len([p for p in wallet.proofs if not p.reserved])} tokens)" + ) @cli.command("send", help="Send coins.") @@ -164,7 +175,7 @@ async def send(ctx, amount: int, lock: str): if lock and len(lock.split("P2SH:")) == 2: p2sh = True wallet: Wallet = ctx.obj["WALLET"] - wallet.load_mint() + await wallet.load_mint() wallet.status() _, send_proofs = await wallet.split_to_send(wallet.proofs, amount, lock) await wallet.set_reserved(send_proofs, reserved=True) @@ -182,7 +193,7 @@ async def send(ctx, amount: int, lock: str): @coro async def receive(ctx, coin: str, lock: str): wallet: Wallet = ctx.obj["WALLET"] - wallet.load_mint() + await wallet.load_mint() wallet.status() if lock: # load the script and signature of this address from the database @@ -192,7 +203,7 @@ async def receive(ctx, coin: str, lock: str): address_split = lock.split("P2SH:")[1] p2shscripts = await get_unused_locks(address_split, db=wallet.db) - assert len(p2shscripts) == 1 + assert len(p2shscripts) == 1, Exception("lock not found.") script = p2shscripts[0].script signature = p2shscripts[0].signature else: @@ -212,7 +223,7 @@ async def receive(ctx, coin: str, lock: str): @coro async def burn(ctx, coin: str, all: bool, force: bool): wallet: Wallet = ctx.obj["WALLET"] - wallet.load_mint() + await wallet.load_mint() if not (all or coin or force) or (coin and all): print( "Error: enter a coin or use --all to burn all pending coins or --force to check all coins." @@ -239,7 +250,7 @@ async def burn(ctx, coin: str, all: bool, force: bool): @coro async def pending(ctx): wallet: Wallet = ctx.obj["WALLET"] - wallet.load_mint() + await wallet.load_mint() reserved_proofs = await get_reserved_proofs(wallet.db) if len(reserved_proofs): print(f"--------------------------\n") diff --git a/cashu/wallet/crud.py b/cashu/wallet/crud.py index da29451..12d8401 100644 --- a/cashu/wallet/crud.py +++ b/cashu/wallet/crud.py @@ -1,7 +1,7 @@ import time from typing import Any, List, Optional -from cashu.core.base import P2SHScript, Proof +from cashu.core.base import KeyBase, P2SHScript, Proof, WalletKeyset from cashu.core.db import Connection, Database @@ -14,10 +14,10 @@ async def store_proof( await (conn or db).execute( """ INSERT INTO proofs - (amount, C, secret, time_created) - VALUES (?, ?, ?, ?) + (id, amount, C, secret, time_created) + VALUES (?, ?, ?, ?, ?) """, - (proof.amount, str(proof.C), str(proof.secret), int(time.time())), + (proof.id, proof.amount, str(proof.C), str(proof.secret), int(time.time())), ) @@ -65,10 +65,10 @@ async def invalidate_proof( await (conn or db).execute( """ INSERT INTO proofs_used - (amount, C, secret, time_used) - VALUES (?, ?, ?, ?) + (amount, C, secret, time_used, id) + VALUES (?, ?, ?, ?, ?) """, - (proof.amount, str(proof.C), str(proof.secret), int(time.time())), + (proof.amount, str(proof.C), str(proof.secret), int(time.time()), proof.id), ) @@ -180,3 +180,57 @@ async def update_p2sh_used( f"UPDATE proofs SET {', '.join(clauses)} WHERE address = ?", (*values, str(p2sh.address)), ) + + +async def store_keyset( + keyset: WalletKeyset, + mint_url: str = None, + db: Database = None, + conn: Optional[Connection] = None, +): + + await (conn or db).execute( + """ + INSERT INTO keysets + (id, mint_url, valid_from, valid_to, first_seen, active) + VALUES (?, ?, ?, ?, ?, ?) + """, + ( + keyset.id, + mint_url or keyset.mint_url, + keyset.valid_from, + keyset.valid_to, + keyset.first_seen, + True, + ), + ) + + +async def get_keyset( + id: str = None, + mint_url: str = None, + db: Database = None, + conn: Optional[Connection] = None, +): + clauses = [] + values = [] + clauses.append("active = ?") + values.append(True) + if id: + clauses.append("id = ?") + values.append(id) + if mint_url: + clauses.append("mint_url = ?") + values.append(mint_url) + where = "" + if clauses: + where = f"WHERE {' AND '.join(clauses)}" + + row = await (conn or db).fetchone( + f""" + SELECT * from keysets + {where} + """, + tuple(values), + ) + return WalletKeyset.from_row(row) if row is not None else None diff --git a/cashu/wallet/migrations.py b/cashu/wallet/migrations.py index 68b4b64..0a0a73e 100644 --- a/cashu/wallet/migrations.py +++ b/cashu/wallet/migrations.py @@ -98,3 +98,39 @@ async def m004_p2sh_locks(db: Database): ); """ ) + + +async def m005_wallet_keysets(db: Database): + """ + Stores mint keysets from different mints and epochs. + """ + await db.execute( + f""" + CREATE TABLE IF NOT EXISTS keysets ( + id TEXT NOT NULL, + mint_url TEXT NOT NULL, + valid_from TIMESTAMP DEFAULT {db.timestamp_now}, + valid_to TIMESTAMP DEFAULT {db.timestamp_now}, + first_seen TIMESTAMP DEFAULT {db.timestamp_now}, + active BOOL DEFAULT TRUE, + + UNIQUE (id, mint_url) + + ); + """ + ) + # await db.execute( + # f""" + # CREATE TABLE IF NOT EXISTS mint_pubkeys ( + # id TEXT NOT NULL, + # amount INTEGER NOT NULL, + # pubkey TEXT NOT NULL, + + # UNIQUE (id, pubkey) + + # ); + # """ + # ) + + await db.execute("ALTER TABLE proofs ADD COLUMN id TEXT") + await db.execute("ALTER TABLE proofs_used ADD COLUMN id TEXT") diff --git a/cashu/wallet/wallet.py b/cashu/wallet/wallet.py index fc1fb84..4c150e0 100644 --- a/cashu/wallet/wallet.py +++ b/cashu/wallet/wallet.py @@ -2,7 +2,8 @@ import base64 import json import secrets as scrts import uuid -from typing import List +from itertools import groupby +from typing import Dict, List import requests from loguru import logger @@ -18,6 +19,7 @@ from cashu.core.base import ( P2SHScript, Proof, SplitRequest, + WalletKeyset, ) from cashu.core.db import Database from cashu.core.script import ( @@ -30,9 +32,11 @@ from cashu.core.secp import PublicKey from cashu.core.settings import DEBUG from cashu.core.split import amount_split from cashu.wallet.crud import ( + get_keyset, get_proofs, invalidate_proof, secret_used, + store_keyset, store_p2sh, store_proof, update_proof_reserved, @@ -40,18 +44,27 @@ from cashu.wallet.crud import ( class LedgerAPI: + keys: Dict[int, str] + keyset: str + def __init__(self, url): self.url = url - @staticmethod - def _get_keys(url): - resp = requests.get(url + "/keys") - resp.raise_for_status() - data = resp.json() - return { + async def _get_keys(self, url): + resp = requests.get(url + "/keys").json() + keys = resp + assert len(keys), Exception("did not receive any keys") + keyset_keys = { int(amt): PublicKey(bytes.fromhex(val), raw=True) - for amt, val in data.items() + for amt, val in keys.items() } + keyset = WalletKeyset(pubkeys=keyset_keys, mint_url=url) + return keyset + + async def _get_keysets(self, url): + keysets = requests.get(url + "/keysets").json() + assert len(keysets), Exception("did not receive any keysets") + return keysets @staticmethod def _get_output_split(amount): @@ -71,7 +84,12 @@ class LedgerAPI: for promise, secret, r in zip(promises, secrets, rs): C_ = PublicKey(bytes.fromhex(promise.C_), raw=True) C = b_dhke.step3_alice(C_, r, self.keys[promise.amount]) - proof = Proof(amount=promise.amount, C=C.serialize().hex(), secret=secret) + proof = Proof( + id=self.keyset_id, + amount=promise.amount, + C=C.serialize().hex(), + secret=secret, + ) proofs.append(proof) return proofs @@ -80,12 +98,32 @@ class LedgerAPI: """Returns base64 encoded random string.""" return scrts.token_urlsafe(randombits // 8) - def _load_mint(self): + async def _load_mint(self): + """ + Loads the current keys and the active keyset of the map. + """ assert len( self.url ), "Ledger not initialized correctly: mint URL not specified yet. " - self.keys = self._get_keys(self.url) - assert len(self.keys) > 0, "did not receive keys from mint." + # get current keyset + keyset = await self._get_keys(self.url) + logger.debug(f"Current mint keyset: {keyset.id}") + # get all active keysets + keysets = await self._get_keysets(self.url) + logger.debug(f"Mint keysets: {keysets}") + + # check if current keyset is in db + keyset_local: WalletKeyset = await get_keyset(keyset.id, db=self.db) + if keyset_local is None: + await store_keyset(keyset=keyset, db=self.db) + + # store current keyset + assert len(keyset.public_keys) > 0, "did not receive keys from mint." + self.keys = keyset.public_keys + self.keyset_id = keyset.id + + # store active keysets + self.keysets = keysets["keysets"] def request_mint(self, amount): """Requests a mint from the server and returns Lightning invoice.""" @@ -177,9 +215,19 @@ class LedgerAPI: await self._check_used_secrets(secrets) payloads, rs = self._construct_outputs(amounts, secrets) split_payload = SplitRequest(proofs=proofs, amount=amount, outputs=payloads) + + def _splitrequest_include_fields(proofs): + """strips away fields from the model that aren't necessary for the /split""" + proofs_include = {"id", "amount", "secret", "C", "script"} + return { + "amount": ..., + "outputs": ..., + "proofs": {i: proofs_include for i in range(len(proofs))}, + } + resp = requests.post( self.url + "/split", - json=split_payload.dict(), + json=split_payload.dict(include=_splitrequest_include_fields(proofs)), ) resp.raise_for_status() try: @@ -225,9 +273,19 @@ class LedgerAPI: async def pay_lightning(self, proofs: List[Proof], invoice: str): payload = MeltRequest(proofs=proofs, invoice=invoice) + + def _meltequest_include_fields(proofs): + """strips away fields from the model that aren't necessary for the /melt""" + proofs_include = {"id", "amount", "secret", "C", "script"} + return { + "amount": ..., + "invoice": ..., + "proofs": {i: proofs_include for i in range(len(proofs))}, + } + resp = requests.post( self.url + "/melt", - json=payload.dict(), + json=payload.dict(include=_meltequest_include_fields(proofs)), ) resp.raise_for_status() @@ -244,8 +302,8 @@ class Wallet(LedgerAPI): self.proofs: List[Proof] = [] self.name = name - def load_mint(self): - super()._load_mint() + async def load_mint(self): + await super()._load_mint() async def load_proofs(self): self.proofs = await get_proofs(db=self.db) @@ -254,6 +312,16 @@ class Wallet(LedgerAPI): for proof in proofs: await store_proof(proof, db=self.db) + @staticmethod + def _sum_proofs(proofs: List[Proof], available_only=False): + if available_only: + return sum([p.amount for p in proofs if not p.reserved]) + return sum([p.amount for p in proofs]) + + @staticmethod + def _get_proofs_per_keyset(proofs: List[Proof]): + return {key: list(group) for key, group in groupby(proofs, lambda p: p.id)} + async def request_mint(self, amount): return super().request_mint(amount) @@ -319,14 +387,28 @@ class Wallet(LedgerAPI): ).decode() return token + async def _get_spendable_proofs(self, proofs: List[Proof]): + """ + Selects proofs that can be used with the current mint. + Chooses: + 1) Proofs that are not marked as reserved + 2) Proofs that have a keyset id that is in self.keysets (active keysets of mint) - !!! optional for backwards compatibility with legacy clients + """ + proofs = [ + p for p in proofs if p.id in self.keysets or not p.id + ] # "or not p.id" is for backwards compatibility with proofs without a keyset id + proofs = [p for p in proofs if not p.reserved] + return proofs + async def split_to_send(self, proofs: List[Proof], amount, scnd_secret: str = None): """Like self.split but only considers non-reserved tokens.""" if scnd_secret: logger.debug(f"Spending conditions: {scnd_secret}") - if len([p for p in proofs if not p.reserved]) <= 0: + spendable_proofs = await self._get_spendable_proofs(proofs) + if sum([p.amount for p in spendable_proofs]) < amount: raise Exception("balance too low.") return await self.split( - [p for p in proofs if not p.reserved], amount, scnd_secret + [p for p in spendable_proofs if not p.reserved], amount, scnd_secret ) async def set_reserved(self, proofs: List[Proof], reserved: bool): @@ -382,5 +464,14 @@ class Wallet(LedgerAPI): f"Balance: {self.balance} sat (available: {self.available_balance} sat in {len([p for p in self.proofs if not p.reserved])} tokens)" ) + def balance_per_keyset(self): + return { + key: { + "balance": self._sum_proofs(proofs), + "available": self._sum_proofs(proofs, available_only=True), + } + for key, proofs in self._get_proofs_per_keyset(self.proofs).items() + } + def proof_amounts(self): return [p["amount"] for p in sorted(self.proofs, key=lambda p: p["amount"])] diff --git a/pyproject.toml b/pyproject.toml index 62eb8f5..1c4d8a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "cashu" -version = "0.2.6" +version = "0.3.0" description = "Ecash wallet and mint." authors = ["calle "] license = "MIT" diff --git a/setup.py b/setup.py index 2cc3078..d06a8c9 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ entry_points = {"console_scripts": ["cashu = cashu.wallet.cli:cli"]} setuptools.setup( name="cashu", - version="0.2.6", + version="0.3.0", description="Ecash wallet and mint with Bitcoin Lightning support", long_description=long_description, long_description_content_type="text/markdown", diff --git a/tests/test_wallet.py b/tests/test_wallet.py index 36bb3c5..c14cc7c 100644 --- a/tests/test_wallet.py +++ b/tests/test_wallet.py @@ -31,12 +31,12 @@ def assert_amt(proofs, expected): async def run_test(): wallet1 = Wallet1(SERVER_ENDPOINT, "data/wallet1", "wallet1") await migrate_databases(wallet1.db, migrations) - wallet1.load_mint() + await wallet1.load_mint() wallet1.status() wallet2 = Wallet2(SERVER_ENDPOINT, "data/wallet2", "wallet2") await migrate_databases(wallet2.db, migrations) - wallet2.load_mint() + await wallet2.load_mint() wallet2.status() proofs = []