diff --git a/cashu/core/base.py b/cashu/core/base.py index 2814bb2..e530065 100644 --- a/cashu/core/base.py +++ b/cashu/core/base.py @@ -833,7 +833,7 @@ class MintKeyset: valid_from=row["valid_from"], valid_to=row["valid_to"], first_seen=row["first_seen"], - active=row["active"], + active=bool(row["active"]), unit=row["unit"], version=row["version"], input_fee_ppk=row["input_fee_ppk"], diff --git a/cashu/mint/crud.py b/cashu/mint/crud.py index 6ff2a63..4c0d20b 100644 --- a/cashu/mint/crud.py +++ b/cashu/mint/crud.py @@ -2,6 +2,8 @@ import json from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional +from loguru import logger + from ..core.base import ( BlindedSignature, MeltQuote, @@ -738,6 +740,7 @@ class LedgerCrudSqlite(LedgerCrud): keyset: MintKeyset, conn: Optional[Connection] = None, ) -> None: + logger.debug(f"Updating keyset {keyset.id}, which has {keyset.active = }") await (conn or db).execute( f""" UPDATE {db.table_with_schema('keysets')} diff --git a/cashu/mint/keysets.py b/cashu/mint/keysets.py new file mode 100644 index 0000000..224f744 --- /dev/null +++ b/cashu/mint/keysets.py @@ -0,0 +1,251 @@ +import base64 +from typing import Dict, List, Optional + +from loguru import logger + +from ..core.base import MintKeyset, Unit +from ..core.crypto.keys import derive_keyset_id +from ..core.errors import KeysetError, KeysetNotFoundError +from ..core.settings import settings +from .protocols import SupportsDb, SupportsKeysets, SupportsSeed + + +class LedgerKeysets(SupportsKeysets, SupportsSeed, SupportsDb): + + # ------- KEYS ------- + + def maybe_update_derivation_path(self, derivation_path: str) -> str: + """ + Check whether `self.derivation_path` was superseded by any of the active keysets loaded into this instance + upon initialization. The superseding derivation must have a greater count (last portion of the derivation path). + If this condition is true, update `self.derivation_path` to match the highest count derivation. + """ + derivation: List[str] = derivation_path.split("/") # type: ignore + counter = int(derivation[-1].replace("'", "")) + for keyset in self.keysets.values(): + if keyset.active: + keyset_derivation_path = keyset.derivation_path.split("/") + keyset_derivation_counter = int(keyset_derivation_path[-1].replace("'", "")) + if ( + keyset_derivation_path[:-1] == derivation[:-1] + and keyset_derivation_counter > counter + ): + derivation_path = keyset.derivation_path + return derivation_path + + async def rotate_next_keyset( + self, + unit: Unit, + max_order: Optional[int], + input_fee_ppk: Optional[int] + ) -> MintKeyset: + """ + This function: + 1. finds the highest counter keyset for `unit` + 2. creates a new derivation path from the old one, increasing the counter by one + 3. creates a new active keyset for the new derivation path + 4. de-activates the old keyset + 5. stores the new keyset to DB + + Args: + unit (Unit): Unit of the keyset. + max_order (Optional[int], optional): The number of keys to generate, which correspond to powers of 2. + input_fee_ppk (Optional[int], optional): The new keyset's fee + Returns: + MintKeyset: Resulting keyset of the rotation + """ + + logger.info(f"Attempting keyset rotation for unit {str(unit)}") + + # Select keyset with the greatest counter + selected_keyset = None + selected_keyset_counter = -1 + for keyset in self.keysets.values(): + if keyset.active and keyset.unit == unit: + keyset_derivation_path = keyset.derivation_path.split("/") + keyset_derivation_counter = int(keyset_derivation_path[-1].replace("'", "")) + if keyset_derivation_counter > selected_keyset_counter: + selected_keyset = keyset + + # If no selected keyset, then there is no keyset for this unit + if not selected_keyset: + logger.error(f"Couldn't find suitable keyset for rotation with unit {str(unit)}") + raise Exception(f"Couldn't find suitable keyset for rotation with unit {str(unit)}") + + logger.info(f"Rotating keyset {selected_keyset.id}") + + # New derivation path is just old derivation path with increased counter + new_derivation_path = selected_keyset.derivation_path.split("/") + new_derivation_path[-1] = str(int(new_derivation_path[-1].replace("'", "")) + 1) + "'" + + # keys amounts for this keyset: if amounts is None we use `self.amounts` + amounts = [2**i for i in range(max_order)] if max_order else self.amounts + + # Generate the keyset + new_keyset = MintKeyset( + derivation_path="/".join(new_derivation_path), + seed=self.seed, + amounts=amounts, + input_fee_ppk=input_fee_ppk + ) + + logger.debug(f"New keyset was generated with Id {new_keyset.id}. Saving...") + await self.crud.store_keyset(keyset=new_keyset, db=self.db) + self.keysets[new_keyset.id] = new_keyset + + logger.debug(f"De-activating keyset {selected_keyset.id}...") + selected_keyset.active = False + await self.crud.update_keyset(keyset=selected_keyset, db=self.db) + self.keysets[selected_keyset.id] = selected_keyset + + logger.debug(f"Keyset {keyset.id} was de-activated") + return new_keyset + + async def activate_keyset( + self, + *, + derivation_path: str, + seed: Optional[str] = None, + version: Optional[str] = None, + autosave=True, + ) -> MintKeyset: + """ + Load an existing keyset for the specified derivation path or generate a new one if it doesn't exist. + Optionally store the newly created keyset in the database. + + Args: + derivation_path (str): Derivation path for keyset generation. + seed (Optional[str], optional): Seed value. Defaults to None. + version (Optional[str], optional): Version identifier. Defaults to None. + autosave (bool, optional): Whether to store the keyset if newly created. Defaults to True. + + Returns: + MintKeyset: The activated keyset. + """ + if not derivation_path: + raise ValueError("Derivation path must be provided.") + + seed = seed or self.seed + version = version or settings.version + # Initialize a temporary keyset to derive the ID + temp_keyset = MintKeyset( + seed=seed, + derivation_path=derivation_path, + version=version, + amounts=self.amounts, + ) + logger.debug( + f"Activating keyset for derivation path '{derivation_path}' with ID '{temp_keyset.id}'." + ) + + # Attempt to retrieve existing keysets from the database + existing_keysets: List[MintKeyset] = await self.crud.get_keyset( + id=temp_keyset.id, db=self.db + ) + logger.trace( + f"Retrieved {len(existing_keysets)} keyset(s) for derivation path '{derivation_path}'." + ) + + if existing_keysets: + keyset = existing_keysets[0] + else: + # Create a new keyset if none exists + keyset = MintKeyset( + seed=seed, + derivation_path=derivation_path, + amounts=self.amounts, + version=version, + input_fee_ppk=settings.mint_input_fee_ppk, + ) + logger.debug(f"Generated new keyset with ID '{keyset.id}'.") + + if autosave: + logger.debug(f"Storing new keyset with ID '{keyset.id}'.") + await self.crud.store_keyset(keyset=keyset, db=self.db) + + # Activate the keyset + keyset.active = True + self.keysets[keyset.id] = keyset + logger.debug(f"Keyset with ID '{keyset.id}' is now active.") + + return keyset + + async def init_keysets(self, autosave: bool = True) -> None: + """Initializes all keysets of the mint from the db. Loads all past keysets from db + and generate their keys. Then activate the current keyset set by self.derivation_path. + + Args: + autosave (bool, optional): Whether the current keyset should be saved if it is + not in the database yet. Will be passed to `self.activate_keyset` where it is + generated from `self.derivation_path`. Defaults to True. + """ + # load all past keysets from db, the keys will be generated at instantiation + tmp_keysets: List[MintKeyset] = await self.crud.get_keyset(db=self.db) + + # add keysets from db to memory + for k in tmp_keysets: + self.keysets[k.id] = k + + logger.info(f"Loaded {len(self.keysets)} keysets from database.") + + # Check if any of the loaded keysets marked as active + # do supersede the one specified in the derivation settings. + # If this is the case update to latest count derivation. + self.derivation_path = self.maybe_update_derivation_path(self.derivation_path) # type: ignore + + # activate the current keyset set by self.derivation_path + # and self.derivation_path is not superseded by any other + # active keyset with same derivation path but higher count + if self.derivation_path: + self.keyset = await self.activate_keyset( + derivation_path=self.derivation_path, autosave=autosave + ) + logger.info(f"Current keyset: {self.keyset.id}") + + # check that we have a least one active keyset + if not any([k.active for k in self.keysets.values()]): + raise KeysetError("No active keyset found.") + + # DEPRECATION 0.16.1 – disable base64 keysets if hex equivalent exists + if settings.mint_inactivate_base64_keysets: + await self.inactivate_base64_keysets() + + async def inactivate_base64_keysets(self) -> None: + """Inactivates all base64 keysets that have a hex equivalent.""" + for keyset in self.keysets.values(): + if not keyset.active or not keyset.public_keys: + continue + # test if the keyset id is a hex string, if not it's base64 + try: + int(keyset.id, 16) + except ValueError: + # verify that it's base64 + try: + _ = base64.b64decode(keyset.id) + except ValueError: + logger.error("Unexpected: keyset id is neither hex nor base64.") + continue + + # verify that we have a hex version of the same keyset by comparing public keys + hex_keyset_id = derive_keyset_id(keys=keyset.public_keys) + if hex_keyset_id not in [k.id for k in self.keysets.values()]: + logger.warning( + f"Keyset {keyset.id} is base64 but we don't have a hex version. Ignoring." + ) + continue + + logger.warning( + f"Keyset {keyset.id} is base64 and has a hex counterpart, setting inactive." + ) + keyset.active = False + self.keysets[keyset.id] = keyset + await self.crud.update_keyset(keyset=keyset, db=self.db) + + def get_keyset(self, keyset_id: Optional[str] = None) -> Dict[int, str]: + """Returns a dictionary of hex public keys of a specific keyset for each supported amount""" + if keyset_id and keyset_id not in self.keysets: + raise KeysetNotFoundError() + keyset = self.keysets[keyset_id] if keyset_id else self.keyset + if not keyset.public_keys: + raise KeysetError("no public keys for this keyset") + return {a: p.serialize().hex() for a, p in keyset.public_keys.items()} \ No newline at end of file diff --git a/cashu/mint/ledger.py b/cashu/mint/ledger.py index b3f21c6..09d81b7 100644 --- a/cashu/mint/ledger.py +++ b/cashu/mint/ledger.py @@ -1,5 +1,4 @@ import asyncio -import base64 import time from typing import Dict, List, Mapping, Optional, Tuple @@ -25,7 +24,6 @@ from ..core.base import ( from ..core.crypto import b_dhke from ..core.crypto.aes import AESCipher from ..core.crypto.keys import ( - derive_keyset_id, derive_pubkey, random_hash, ) @@ -33,8 +31,6 @@ from ..core.crypto.secp import PrivateKey, PublicKey from ..core.db import Connection, Database from ..core.errors import ( CashuError, - KeysetError, - KeysetNotFoundError, LightningError, LightningPaymentFailedError, NotAllowedError, @@ -65,11 +61,12 @@ from .db.read import DbReadHelper from .db.write import DbWriteHelper from .events.events import LedgerEventManager from .features import LedgerFeatures +from .keysets import LedgerKeysets from .tasks import LedgerTasks from .verification import LedgerVerification -class Ledger(LedgerVerification, LedgerSpendingConditions, LedgerTasks, LedgerFeatures): +class Ledger(LedgerVerification, LedgerSpendingConditions, LedgerTasks, LedgerFeatures, LedgerKeysets): backends: Mapping[Method, Mapping[Unit, LightningBackend]] = {} keysets: Dict[str, MintKeyset] = {} events = LedgerEventManager() @@ -139,6 +136,7 @@ class Ledger(LedgerVerification, LedgerSpendingConditions, LedgerTasks, LedgerFe async def _startup_keysets(self) -> None: await self.init_keysets() for derivation_path in settings.mint_derivation_path_list: + derivation_path = self.maybe_update_derivation_path(derivation_path) await self.activate_keyset(derivation_path=derivation_path) async def _run_regular_tasks(self) -> None: @@ -193,150 +191,6 @@ class Ledger(LedgerVerification, LedgerSpendingConditions, LedgerTasks, LedgerFe quote = await self.get_melt_quote(quote_id=quote.quote) logger.info(f"Melt quote {quote.quote} state: {quote.state}") - # ------- KEYS ------- - - async def activate_keyset( - self, - *, - derivation_path: str, - seed: Optional[str] = None, - version: Optional[str] = None, - autosave=True, - ) -> MintKeyset: - """ - Load an existing keyset for the specified derivation path or generate a new one if it doesn't exist. - Optionally store the newly created keyset in the database. - - Args: - derivation_path (str): Derivation path for keyset generation. - seed (Optional[str], optional): Seed value. Defaults to None. - version (Optional[str], optional): Version identifier. Defaults to None. - autosave (bool, optional): Whether to store the keyset if newly created. Defaults to True. - - Returns: - MintKeyset: The activated keyset. - """ - if not derivation_path: - raise ValueError("Derivation path must be provided.") - - seed = seed or self.seed - version = version or settings.version - # Initialize a temporary keyset to derive the ID - temp_keyset = MintKeyset( - seed=seed, - derivation_path=derivation_path, - version=version, - amounts=self.amounts, - ) - logger.debug( - f"Activating keyset for derivation path '{derivation_path}' with ID '{temp_keyset.id}'." - ) - - # Attempt to retrieve existing keysets from the database - existing_keysets: List[MintKeyset] = await self.crud.get_keyset( - id=temp_keyset.id, db=self.db - ) - logger.trace( - f"Retrieved {len(existing_keysets)} keyset(s) for derivation path '{derivation_path}'." - ) - - if existing_keysets: - keyset = existing_keysets[0] - else: - # Create a new keyset if none exists - keyset = MintKeyset( - seed=seed, - derivation_path=derivation_path, - amounts=self.amounts, - version=version, - input_fee_ppk=settings.mint_input_fee_ppk, - ) - logger.debug(f"Generated new keyset with ID '{keyset.id}'.") - - if autosave: - logger.debug(f"Storing new keyset with ID '{keyset.id}'.") - await self.crud.store_keyset(keyset=keyset, db=self.db) - - # Activate the keyset - keyset.active = True - self.keysets[keyset.id] = keyset - logger.debug(f"Keyset with ID '{keyset.id}' is now active.") - - return keyset - - async def init_keysets(self, autosave: bool = True) -> None: - """Initializes all keysets of the mint from the db. Loads all past keysets from db - and generate their keys. Then activate the current keyset set by self.derivation_path. - - Args: - autosave (bool, optional): Whether the current keyset should be saved if it is - not in the database yet. Will be passed to `self.activate_keyset` where it is - generated from `self.derivation_path`. Defaults to True. - """ - # load all past keysets from db, the keys will be generated at instantiation - tmp_keysets: List[MintKeyset] = await self.crud.get_keyset(db=self.db) - - # add keysets from db to memory - for k in tmp_keysets: - self.keysets[k.id] = k - - logger.info(f"Loaded {len(self.keysets)} keysets from database.") - - # activate the current keyset set by self.derivation_path - if self.derivation_path: - self.keyset = await self.activate_keyset( - derivation_path=self.derivation_path, autosave=autosave - ) - logger.info(f"Current keyset: {self.keyset.id}") - - # check that we have a least one active keyset - if not any([k.active for k in self.keysets.values()]): - raise KeysetError("No active keyset found.") - - # DEPRECATION 0.16.1 – disable base64 keysets if hex equivalent exists - if settings.mint_inactivate_base64_keysets: - await self.inactivate_base64_keysets() - - async def inactivate_base64_keysets(self) -> None: - """Inactivates all base64 keysets that have a hex equivalent.""" - for keyset in self.keysets.values(): - if not keyset.active or not keyset.public_keys: - continue - # test if the keyset id is a hex string, if not it's base64 - try: - int(keyset.id, 16) - except ValueError: - # verify that it's base64 - try: - _ = base64.b64decode(keyset.id) - except ValueError: - logger.error("Unexpected: keyset id is neither hex nor base64.") - continue - - # verify that we have a hex version of the same keyset by comparing public keys - hex_keyset_id = derive_keyset_id(keys=keyset.public_keys) - if hex_keyset_id not in [k.id for k in self.keysets.values()]: - logger.warning( - f"Keyset {keyset.id} is base64 but we don't have a hex version. Ignoring." - ) - continue - - logger.warning( - f"Keyset {keyset.id} is base64 and has a hex counterpart, setting inactive." - ) - keyset.active = False - self.keysets[keyset.id] = keyset - await self.crud.update_keyset(keyset=keyset, db=self.db) - - def get_keyset(self, keyset_id: Optional[str] = None) -> Dict[int, str]: - """Returns a dictionary of hex public keys of a specific keyset for each supported amount""" - if keyset_id and keyset_id not in self.keysets: - raise KeysetNotFoundError() - keyset = self.keysets[keyset_id] if keyset_id else self.keyset - if not keyset.public_keys: - raise KeysetError("no public keys for this keyset") - return {a: p.serialize().hex() for a, p in keyset.public_keys.items()} - async def get_balance(self, keyset: MintKeyset) -> int: """Returns the balance of the mint.""" return await self.crud.get_balance(keyset=keyset, db=self.db) diff --git a/cashu/mint/protocols.py b/cashu/mint/protocols.py index 3661b63..1fcd379 100644 --- a/cashu/mint/protocols.py +++ b/cashu/mint/protocols.py @@ -1,4 +1,4 @@ -from typing import Dict, Mapping, Protocol +from typing import Dict, List, Mapping, Protocol from ..core.base import Method, MintKeyset, Unit from ..core.crypto.secp import PublicKey @@ -10,9 +10,14 @@ from .db.write import DbWriteHelper from .events.events import LedgerEventManager +class SupportsSeed(Protocol): + seed: str + class SupportsKeysets(Protocol): + amounts: List[int] keyset: MintKeyset keysets: Dict[str, MintKeyset] + derivation_path: str class SupportsBackends(Protocol): diff --git a/tests/test_mint_keysets.py b/tests/test_mint_keysets.py index 6085ea7..897eb9c 100644 --- a/tests/test_mint_keysets.py +++ b/tests/test_mint_keysets.py @@ -1,7 +1,8 @@ import pytest -from cashu.core.base import MintKeyset +from cashu.core.base import MintKeyset, Unit from cashu.core.settings import settings +from cashu.mint.ledger import Ledger from tests.test_mint_init import DECRYPTON_KEY, DERIVATION_PATH, ENCRYPTED_SEED, SEED @@ -71,3 +72,20 @@ async def test_keyset_0_15_0_encrypted(): == "02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104" ) assert keyset.id == "009a1f293253e41e" + +@pytest.mark.asyncio +async def test_keyset_rotation(ledger: Ledger): + keyset_sat = next(filter(lambda k: k.unit == Unit["sat"] and k.active, ledger.keysets.values())) + new_keyset_sat = await ledger.rotate_next_keyset(unit=Unit["sat"], max_order=20, input_fee_ppk=1) + + keyset_sat_derivation = keyset_sat.derivation_path.split("/") + new_keyset_sat_derivation = keyset_sat.derivation_path.split("/") + + assert keyset_sat_derivation[:-1] == new_keyset_sat_derivation[:-1], "keyset derivation does not match up to the counter branch" + assert int(new_keyset_sat_derivation[-1].replace("'", "")) - int(keyset_sat_derivation[-1].replace("'", "")) == 0, "counters should differ by exactly 1" + + assert new_keyset_sat.input_fee_ppk == 1 + assert len(new_keyset_sat.private_keys.values()) == 20 + + old_keyset = (await ledger.crud.get_keyset(db=ledger.db, id=keyset_sat.id))[0] + assert not old_keyset.active, "old keyset is still active" \ No newline at end of file