keyset working

This commit is contained in:
callebtc
2022-10-08 13:56:01 +02:00
parent b411fcc79b
commit 6ce15da527
12 changed files with 299 additions and 138 deletions

View File

@@ -1,5 +1,9 @@
from sqlite3 import Row from sqlite3 import Row
from typing import List, Union, Dict from typing import List, Union, Dict, Any
from cashu.core.crypto import derive_keyset_id, derive_keys, derive_pubkeys
from cashu.core.secp import PrivateKey, PublicKey
from loguru import logger
from pydantic import BaseModel from pydantic import BaseModel
@@ -9,12 +13,10 @@ class CashuError(BaseModel):
error = "CashuError" error = "CashuError"
class Keyset(BaseModel): class KeyBase(BaseModel):
id: str id: str
keys: Dict amount: int
mint_url: Union[str, None] = None pubkey: str
first_seen: Union[str, None] = None
active: bool = True
@classmethod @classmethod
def from_row(cls, row: Row): def from_row(cls, row: Row):
@@ -22,13 +24,55 @@ class Keyset(BaseModel):
return cls return cls
return cls( return cls(
id=row[0], id=row[0],
keys=row[1], amount=int(row[1]),
mint_url=row[2], pubkey=row[2],
first_seen=row[3],
active=row[4],
) )
class Keyset:
id: str
private_keys: Dict[int, PrivateKey]
public_keys: Dict[int, PublicKey]
mint_url: Union[str, None] = None
valid_from: Union[str, None] = None
valid_to: Union[str, None] = None
first_seen: Union[str, None] = None
active: bool = True
def __init__(
self,
seed: Union[None, str] = None,
derivation_path: str = "0",
pubkeys: Union[None, Dict[int, PublicKey]] = None,
):
if seed:
self.private_keys = derive_keys(seed, derivation_path)
self.public_keys = derive_pubkeys(self.private_keys)
if pubkeys:
self.public_keys = pubkeys
self.id = derive_keyset_id(self.public_keys)
logger.debug(f"Mint keyset id: {self.id}")
@classmethod
def from_row(cls, row: Row):
if row is None:
return cls
return cls(
id=row[0],
mint_url=row[1],
valid_from=row[2],
valid_to=row[3],
first_seen=row[4],
active=row[5],
)
def get_keybase(self):
return {
k: KeyBase(id=self.id, amount=k, pubkey=v.serialize().hex())
for k, v in self.public_keys.items()
}
class P2SHScript(BaseModel): class P2SHScript(BaseModel):
script: str script: str
signature: str signature: str
@@ -65,6 +109,7 @@ class Proof(BaseModel):
send_id=row[4] or "", send_id=row[4] or "",
time_created=row[5] or "", time_created=row[5] or "",
time_reserved=row[6] or "", time_reserved=row[6] or "",
id=row[7] or "",
) )
@classmethod @classmethod
@@ -116,17 +161,20 @@ class Invoice(BaseModel):
class BlindedMessage(BaseModel): class BlindedMessage(BaseModel):
id: str = ""
amount: int amount: int
B_: str B_: str
class BlindedSignature(BaseModel): class BlindedSignature(BaseModel):
id: str = ""
amount: int amount: int
C_: str C_: str
@classmethod @classmethod
def from_dict(cls, d: dict): def from_dict(cls, d: dict):
return cls( return cls(
id=d["id"],
amount=d["amount"], amount=d["amount"],
C_=d["C_"], C_=d["C_"],
) )

74
cashu/core/crud.py Normal file
View File

@@ -0,0 +1,74 @@
from typing import Optional
from cashu.core.base import Keyset, KeyBase
from cashu.core.db import Connection, Database
async def store_keyset(
keyset: Keyset,
mint_url: str,
db: Database,
conn: Optional[Connection] = None,
):
await (conn or db).execute(
"""
INSERT INTO keysets
(id, mint_url, valid_from, valid_to, first_seen, active)
VALUES (?, ?, ?, ?, ?, ?)
""",
(
keyset.id,
mint_url,
keyset.valid_from,
keyset.valid_to,
keyset.first_seen,
True,
),
)
async def get_keyset(
id: str = None,
mint_url: str = None,
db: Database = None,
conn: Optional[Connection] = None,
):
clauses = []
values = []
clauses.append("active = ?")
values.append(True)
if id:
clauses.append("id = ?")
values.append(id)
if mint_url:
clauses.append("mint_url = ?")
values.append(mint_url)
where = ""
if clauses:
where = f"WHERE {' AND '.join(clauses)}"
row = await (conn or db).fetchone(
f"""
SELECT * from keysets
{where}
""",
tuple(values),
)
return Keyset.from_row(row)
async def store_mint_pubkey(
key: KeyBase,
db: Database,
conn: Optional[Connection] = None,
):
await (conn or db).execute(
"""
INSERT INTO mint_pubkeys
(id, amount, pubkey)
VALUES (?, ?, ?)
""",
(key.id, key.amount, key.pubkey),
)

31
cashu/core/crypto.py Normal file
View File

@@ -0,0 +1,31 @@
import hashlib
from typing import Dict, List
from cashu.core.secp import PrivateKey, PublicKey
from cashu.core.settings import MAX_ORDER
def derive_keys(master_key: str, derivation_path: str = ""):
"""
Deterministic derivation of keys for 2^n values.
TODO: Implement BIP32.
"""
return {
2
** i: PrivateKey(
hashlib.sha256((str(master_key) + derivation_path + str(i)).encode("utf-8"))
.hexdigest()
.encode("utf-8")[:32],
raw=True,
)
for i in range(MAX_ORDER)
}
def derive_pubkeys(keys: Dict[int, PrivateKey]):
return {amt: keys[amt].pubkey for amt in [2**i for i in range(MAX_ORDER)]}
def derive_keyset_id(keys: Dict[str, PublicKey]):
"""Deterministic derivation keyset_id from set of public keys."""
pubkeys_concat = "".join([p.serialize().hex() for _, p in keys.items()])
return hashlib.sha256((pubkeys_concat).encode("utf-8")).hexdigest()[:16]

21
cashu/core/errors.py Normal file
View File

@@ -0,0 +1,21 @@
from pydantic import BaseModel
class CashuError(Exception, BaseModel):
code = "000"
error = "CashuError"
class MintException(CashuError):
code = 100
error = "Mint"
class LightningException(MintException):
code = 200
error = "Lightning"
class InvoiceNotPaidException(LightningException):
code = 201
error = "invoice not paid."

View File

@@ -1,7 +1,5 @@
import re import re
from loguru import logger
from cashu.core.db import COCKROACH, POSTGRES, SQLITE, Database from cashu.core.db import COCKROACH, POSTGRES, SQLITE, Database

View File

@@ -1,4 +1,3 @@
import secrets
from typing import Optional from typing import Optional
from cashu.core.base import Invoice, Proof from cashu.core.base import Invoice, Proof

View File

@@ -2,20 +2,17 @@
Implementation of https://gist.github.com/phyro/935badc682057f418842c72961cf096c Implementation of https://gist.github.com/phyro/935badc682057f418842c72961cf096c
""" """
import hashlib
import math import math
from inspect import signature
from signal import signal
from typing import List, Set from typing import List, Set
import cashu.core.b_dhke as b_dhke import cashu.core.b_dhke as b_dhke
import cashu.core.bolt11 as bolt11 import cashu.core.bolt11 as bolt11
from cashu.core.base import BlindedMessage, BlindedSignature, Invoice, Proof from cashu.core.base import BlindedMessage, BlindedSignature, Invoice, Proof, Keyset
from cashu.core.crypto import derive_keyset_id from cashu.core.crypto import derive_keyset_id, derive_keys, derive_pubkeys
from cashu.core.db import Database from cashu.core.db import Database
from cashu.core.helpers import fee_reserve from cashu.core.helpers import fee_reserve
from cashu.core.script import verify_script from cashu.core.script import verify_script
from cashu.core.secp import PrivateKey, PublicKey from cashu.core.secp import PublicKey
from cashu.core.settings import LIGHTNING, MAX_ORDER from cashu.core.settings import LIGHTNING, MAX_ORDER
from cashu.core.split import amount_split from cashu.core.split import amount_split
from cashu.lightning import WALLET from cashu.lightning import WALLET
@@ -32,34 +29,13 @@ from cashu.mint.crud import (
class Ledger: class Ledger:
def __init__(self, secret_key: str, db: str): def __init__(self, secret_key: str, db: str):
self.proofs_used: Set[str] = set() self.proofs_used: Set[str] = set()
self.master_key = secret_key self.master_key = secret_key
self.keys = self._derive_keys(self.master_key) self.keyset = Keyset(self.master_key)
self.keyset_id = derive_keyset_id(self.keys)
self.pub_keys = self._derive_pubkeys(self.keys)
self.db: Database = Database("mint", db) self.db: Database = Database("mint", db)
async def load_used_proofs(self): async def load_used_proofs(self):
self.proofs_used = set(await get_proofs_used(db=self.db)) self.proofs_used = set(await get_proofs_used(db=self.db))
@staticmethod
def _derive_keys(master_key: str, keyset_id: str = ""):
"""Deterministic derivation of keys for 2^n values."""
return {
2
** i: PrivateKey(
hashlib.sha256((str(master_key) + str(i) + keyset_id).encode("utf-8"))
.hexdigest()
.encode("utf-8")[:32],
raw=True,
)
for i in range(MAX_ORDER)
}
@staticmethod
def _derive_pubkeys(keys: List[PrivateKey]):
return {amt: keys[amt].pubkey for amt in [2**i for i in range(MAX_ORDER)]}
async def _generate_promises(self, amounts: List[int], B_s: List[str]): async def _generate_promises(self, amounts: List[int], B_s: List[str]):
"""Generates promises that sum to the given amount.""" """Generates promises that sum to the given amount."""
return [ return [
@@ -69,7 +45,7 @@ class Ledger:
async def _generate_promise(self, amount: int, B_: PublicKey): async def _generate_promise(self, amount: int, B_: PublicKey):
"""Generates a promise for given amount and returns a pair (amount, C').""" """Generates a promise for given amount and returns a pair (amount, C')."""
secret_key = self.keys[amount] # Get the correct key secret_key = self.keyset.private_keys[amount] # Get the correct key
C_ = b_dhke.step2_bob(B_, secret_key) C_ = b_dhke.step2_bob(B_, secret_key)
await store_promise( await store_promise(
amount, B_=B_.serialize().hex(), C_=C_.serialize().hex(), db=self.db amount, B_=B_.serialize().hex(), C_=C_.serialize().hex(), db=self.db
@@ -94,7 +70,9 @@ class Ledger:
"""Verifies that the proof of promise was issued by this ledger.""" """Verifies that the proof of promise was issued by this ledger."""
if not self._check_spendable(proof): if not self._check_spendable(proof):
raise Exception(f"tokens already spent. Secret: {proof.secret}") raise Exception(f"tokens already spent. Secret: {proof.secret}")
secret_key = self.keys[proof.amount] # Get the correct key to check against secret_key = self.keyset.private_keys[
proof.amount
] # Get the correct key to check against
C = PublicKey(bytes.fromhex(proof.C), raw=True) C = PublicKey(bytes.fromhex(proof.C), raw=True)
return b_dhke.verify(secret_key, C, proof.secret) return b_dhke.verify(secret_key, C, proof.secret)
@@ -125,7 +103,7 @@ class Ledger:
assert len(proof.secret.split(":")) == 3, "secret format wrong." assert len(proof.secret.split(":")) == 3, "secret format wrong."
assert proof.secret.split(":")[1] == str( assert proof.secret.split(":")[1] == str(
txin_p2sh_address txin_p2sh_address
), f"secret does not contain correct P2SH address: {proof.secret.split(':')[1]}!={txin_p2sh_address}." ), f"secret does not contain correct P2SH address: {proof.secret.split(':')[1]} is not {txin_p2sh_address}."
return valid return valid
def _verify_outputs(self, total: int, amount: int, outputs: List[BlindedMessage]): def _verify_outputs(self, total: int, amount: int, outputs: List[BlindedMessage]):
@@ -223,13 +201,13 @@ class Ledger:
for p in proofs: for p in proofs:
await invalidate_proof(p, db=self.db) await invalidate_proof(p, db=self.db)
# Public methods def _serialize_pubkeys(self):
def get_pubkeys(self):
"""Returns public keys for possible amounts.""" """Returns public keys for possible amounts."""
return {a: p.serialize().hex() for a, p in self.pub_keys.items()} return {a: p.serialize().hex() for a, p in self.keyset.public_keys.items()}
# Public methods
def get_keyset(self): def get_keyset(self):
return {"id": self.keyset_id, "keys": self.get_pubkeys()} return {"id": self.keyset.id, "keys": self._serialize_pubkeys()}
async def request_mint(self, amount): async def request_mint(self, amount):
"""Returns Lightning invoice and stores it in the db.""" """Returns Lightning invoice and stores it in the db."""

View File

@@ -85,3 +85,36 @@ async def m001_initial(db: Database):
); );
""" """
) )
async def m003_mint_keysets(db: Database):
"""
Stores mint keysets from different mints and epochs.
"""
await db.execute(
f"""
CREATE TABLE IF NOT EXISTS keysets (
id TEXT NOT NULL,
mint_url TEXT NOT NULL,
valid_from TIMESTAMP NOT NULL DEFAULT {db.timestamp_now},
valid_to TIMESTAMP NOT NULL DEFAULT {db.timestamp_now},
first_seen TIMESTAMP NOT NULL DEFAULT {db.timestamp_now},
active BOOL NOT NULL DEFAULT TRUE,
UNIQUE (id, mint_url)
);
"""
)
await db.execute(
f"""
CREATE TABLE IF NOT EXISTS mint_pubkeys (
id TEXT NOT NULL,
amount INTEGER NOT NULL,
pubkey TEXT NOT NULL,
UNIQUE (id, pubkey)
);
"""
)

View File

@@ -4,13 +4,15 @@ from loguru import logger
from cashu.core.settings import CASHU_DIR, LIGHTNING from cashu.core.settings import CASHU_DIR, LIGHTNING
from cashu.lightning import WALLET from cashu.lightning import WALLET
from cashu.mint.migrations import m001_initial from cashu.mint import migrations
from cashu.core.migrations import migrate_databases
from . import ledger from . import ledger
async def load_ledger(): async def load_ledger():
await asyncio.wait([m001_initial(ledger.db)]) await migrate_databases(ledger.db, migrations)
# await asyncio.wait([m001_initial(ledger.db)])
await ledger.load_used_proofs() await ledger.load_used_proofs()
if LIGHTNING: if LIGHTNING:

View File

@@ -1,7 +1,7 @@
import time import time
from typing import Any, List, Optional, Dict from typing import Any, List, Optional
from cashu.core.base import P2SHScript, Proof, Keyset from cashu.core.base import P2SHScript, Proof
from cashu.core.db import Connection, Database from cashu.core.db import Connection, Database
@@ -14,10 +14,10 @@ async def store_proof(
await (conn or db).execute( await (conn or db).execute(
""" """
INSERT INTO proofs INSERT INTO proofs
(amount, C, secret, time_created) (id, amount, C, secret, time_created)
VALUES (?, ?, ?, ?) VALUES (?, ?, ?, ?, ?)
""", """,
(proof.amount, str(proof.C), str(proof.secret), int(time.time())), (proof.id, proof.amount, str(proof.C), str(proof.secret), int(time.time())),
) )
@@ -65,10 +65,10 @@ async def invalidate_proof(
await (conn or db).execute( await (conn or db).execute(
""" """
INSERT INTO proofs_used INSERT INTO proofs_used
(amount, C, secret, time_used) (amount, C, secret, time_used, id)
VALUES (?, ?, ?, ?) VALUES (?, ?, ?, ?, ?)
""", """,
(proof.amount, str(proof.C), str(proof.secret), int(time.time())), (proof.amount, str(proof.C), str(proof.secret), int(time.time()), proof.id),
) )
@@ -180,50 +180,3 @@ async def update_p2sh_used(
f"UPDATE proofs SET {', '.join(clauses)} WHERE address = ?", f"UPDATE proofs SET {', '.join(clauses)} WHERE address = ?",
(*values, str(p2sh.address)), (*values, str(p2sh.address)),
) )
async def store_keyset(
keyset: Keyset,
mint_url: str,
db: Database,
conn: Optional[Connection] = None,
):
await (conn or db).execute(
"""
INSERT INTO keysets
(id, mint_url, keys, first_seen, active)
VALUES (?, ?, ?, ?, ?)
""",
(keyset.id, mint_url, keyset.keys, keyset.first_seen, True),
)
async def get_keyset(
id: str = None,
mint_url: str = None,
db: Database = None,
conn: Optional[Connection] = None,
):
clauses = []
values = []
clauses.append("active = ?")
values.append(True)
if id:
clauses.append("id = ?")
values.append(id)
if mint_url:
clauses.append("mint_url = ?")
values.append(mint_url)
where = ""
if clauses:
where = f"WHERE {' AND '.join(clauses)}"
row = await (conn or db).fetchone(
f"""
SELECT * from keysets
{where}
""",
tuple(values),
)
return Keyset.from_row(row)

View File

@@ -108,8 +108,9 @@ async def m005_mint_keysets(db: Database):
f""" f"""
CREATE TABLE IF NOT EXISTS keysets ( CREATE TABLE IF NOT EXISTS keysets (
id TEXT NOT NULL, id TEXT NOT NULL,
keys TEXT NOT NULL,
mint_url TEXT NOT NULL, mint_url TEXT NOT NULL,
valid_from TIMESTAMP NOT NULL DEFAULT {db.timestamp_now},
valid_to TIMESTAMP NOT NULL DEFAULT {db.timestamp_now},
first_seen TIMESTAMP NOT NULL DEFAULT {db.timestamp_now}, first_seen TIMESTAMP NOT NULL DEFAULT {db.timestamp_now},
active BOOL NOT NULL DEFAULT TRUE, active BOOL NOT NULL DEFAULT TRUE,
@@ -118,3 +119,18 @@ async def m005_mint_keysets(db: Database):
); );
""" """
) )
await db.execute(
f"""
CREATE TABLE IF NOT EXISTS mint_pubkeys (
id TEXT NOT NULL,
amount INTEGER NOT NULL,
pubkey TEXT NOT NULL,
UNIQUE (id, pubkey)
);
"""
)
await db.execute("ALTER TABLE proofs ADD COLUMN id TEXT")
await db.execute("ALTER TABLE proofs_used ADD COLUMN id TEXT")

View File

@@ -37,9 +37,8 @@ from cashu.wallet.crud import (
store_p2sh, store_p2sh,
store_proof, store_proof,
update_proof_reserved, update_proof_reserved,
store_keyset,
get_keyset,
) )
from cashu.core.crud import store_keyset, get_keyset
class LedgerAPI: class LedgerAPI:
@@ -49,37 +48,18 @@ class LedgerAPI:
def __init__(self, url): def __init__(self, url):
self.url = url self.url = url
<<<<<<< HEAD
def _get_keys(self, url): def _get_keys(self, url):
resp = requests.get(url + "/keys").json() resp = requests.get(url + "/keys").json()
keyset_id = resp["id"] keyset_id = resp["id"]
keys = resp["keys"] keys = resp["keys"]
# return { assert len(keys), Exception("did not receive any keys")
# "id": keyset_id, keyset_keys = {
# "keys": {
# int(amt): PublicKey(bytes.fromhex(val), raw=True)
# for amt, val in keys.items()
# },
# }
keyset_keys = (
{
int(amt): PublicKey(bytes.fromhex(val), raw=True) int(amt): PublicKey(bytes.fromhex(val), raw=True)
for amt, val in keys.items() for amt, val in keys.items()
},
)
print(resp)
return Keyset(id=keyset_id, keys=keyset_keys, mint_url=self.url)
=======
@staticmethod
def _get_keys(url):
resp = requests.get(url + "/keys")
resp.raise_for_status()
data = resp.json()
return {
int(amt): PublicKey(bytes.fromhex(val), raw=True)
for amt, val in data.items()
} }
>>>>>>> main keyset = Keyset(pubkeys=keyset_keys)
assert keyset_id == keyset.id, Exception("mint keyset id not valid.")
return keyset
@staticmethod @staticmethod
def _get_output_split(amount): def _get_output_split(amount):
@@ -121,7 +101,7 @@ class LedgerAPI:
keyset_local: Keyset = await get_keyset(keyset.id, self.url, db=self.db) keyset_local: Keyset = await get_keyset(keyset.id, self.url, db=self.db)
if keyset_local is None: if keyset_local is None:
await store_keyset(keyset=keyset, db=self.db) await store_keyset(keyset=keyset, db=self.db)
self.keys = keyset.keys self.keys = keyset.public_keys
self.keyset_id = keyset.id self.keyset_id = keyset.id
assert len(self.keys) > 0, "did not receive keys from mint." assert len(self.keys) > 0, "did not receive keys from mint."
@@ -215,9 +195,19 @@ class LedgerAPI:
await self._check_used_secrets(secrets) await self._check_used_secrets(secrets)
payloads, rs = self._construct_outputs(amounts, secrets) payloads, rs = self._construct_outputs(amounts, secrets)
split_payload = SplitRequest(proofs=proofs, amount=amount, outputs=payloads) split_payload = SplitRequest(proofs=proofs, amount=amount, outputs=payloads)
def _splitrequest_include_fields(proofs):
"""strips away fields from the model that aren't necessary for the /split"""
proofs_include = {"id", "amount", "secret", "C", "script"}
return {
"amount": ...,
"outputs": ...,
"proofs": {i: proofs_include for i in range(len(proofs))},
}
resp = requests.post( resp = requests.post(
self.url + "/split", self.url + "/split",
json=split_payload.dict(), json=split_payload.dict(include=_splitrequest_include_fields(proofs)),
) )
resp.raise_for_status() resp.raise_for_status()
try: try:
@@ -263,9 +253,19 @@ class LedgerAPI:
async def pay_lightning(self, proofs: List[Proof], invoice: str): async def pay_lightning(self, proofs: List[Proof], invoice: str):
payload = MeltRequest(proofs=proofs, invoice=invoice) payload = MeltRequest(proofs=proofs, invoice=invoice)
def _meltequest_include_fields(proofs):
"""strips away fields from the model that aren't necessary for the /melt"""
proofs_include = {"id", "amount", "secret", "C", "script"}
return {
"amount": ...,
"invoice": ...,
"proofs": {i: proofs_include for i in range(len(proofs))},
}
resp = requests.post( resp = requests.post(
self.url + "/melt", self.url + "/melt",
json=payload.dict(), json=payload.dict(include=_meltequest_include_fields(proofs)),
) )
resp.raise_for_status() resp.raise_for_status()
@@ -357,14 +357,22 @@ class Wallet(LedgerAPI):
).decode() ).decode()
return token return token
async def _get_spendable_proofs(self, proofs: List[Proof]):
print(f"Debug: only loading proofs with id: {self.keyset_id}")
proofs = [p for p in proofs if p.id == self.keyset_id or not p.id]
proofs = [p for p in proofs if not p.reserved]
return proofs
async def split_to_send(self, proofs: List[Proof], amount, scnd_secret: str = None): async def split_to_send(self, proofs: List[Proof], amount, scnd_secret: str = None):
"""Like self.split but only considers non-reserved tokens.""" """Like self.split but only considers non-reserved tokens."""
if scnd_secret: if scnd_secret:
logger.debug(f"Spending conditions: {scnd_secret}") logger.debug(f"Spending conditions: {scnd_secret}")
if len([p for p in proofs if not p.reserved]) <= 0: spendable_proofs = await self._get_spendable_proofs(proofs)
print(f"Balance: {sum([p.amount for p in spendable_proofs])}")
if sum([p.amount for p in spendable_proofs]) < amount:
raise Exception("balance too low.") raise Exception("balance too low.")
return await self.split( return await self.split(
[p for p in proofs if not p.reserved], amount, scnd_secret [p for p in spendable_proofs if not p.reserved], amount, scnd_secret
) )
async def set_reserved(self, proofs: List[Proof], reserved: bool): async def set_reserved(self, proofs: List[Proof], reserved: bool):