From c444a063c135c804b41c6270fe8dbc01a62c56d4 Mon Sep 17 00:00:00 2001 From: callebtc <93376500+callebtc@users.noreply.github.com> Date: Fri, 13 Oct 2023 21:58:42 +0200 Subject: [PATCH] Wallet: Cache keysets (#333) * accept new keyset id calculation, cache keysets, remove duplicate representations * remove comments * check if keyset is present for DLEQ * load keys from db if available * store new keyset * make mypy happy * make mypy happy * precommit --- .pre-commit-config.yaml | 9 +++- cashu/wallet/crud.py | 2 +- cashu/wallet/wallet.py | 114 ++++++++++++++++++++++++++-------------- tests/test_wallet.py | 8 +-- 4 files changed, 88 insertions(+), 45 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 22f3090..00795a8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,4 +1,4 @@ -exclude: '^cashu/nostr/.*' +exclude: "^cashu/nostr/.*" repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.3.0 @@ -20,4 +20,9 @@ repos: rev: v0.0.283 hooks: - id: ruff - args: [ --fix, --exit-non-zero-on-fix ] + args: [--fix, --exit-non-zero-on-fix] + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.6.0 + hooks: + - id: mypy + args: [--ignore-missing] diff --git a/cashu/wallet/crud.py b/cashu/wallet/crud.py index e37a9bc..9101aa4 100644 --- a/cashu/wallet/crud.py +++ b/cashu/wallet/crud.py @@ -151,7 +151,7 @@ async def get_keyset( mint_url: str = "", db: Optional[Database] = None, conn: Optional[Connection] = None, -): +) -> Optional[WalletKeyset]: clauses = [] values: List[Any] = [] clauses.append("active = ?") diff --git a/cashu/wallet/wallet.py b/cashu/wallet/wallet.py index d75fa8d..dab92e4 100644 --- a/cashu/wallet/wallet.py +++ b/cashu/wallet/wallet.py @@ -99,9 +99,10 @@ def async_set_requests(func): class LedgerAPI(object): - keys: WalletKeyset # holds current keys of mint - keyset_id: str # holds id of current keyset - public_keys: Dict[int, PublicKey] # holds public keys of + keyset_id: str # holds current keyset id + keysets: Dict[str, WalletKeyset] # holds keysets + mint_keyset_ids: List[str] # holds active keyset ids of the mint + mint_info: GetInfoResponse # holds info about mint tor: TorProxy s: requests.Session @@ -111,6 +112,7 @@ class LedgerAPI(object): self.url = url self.s = requests.Session() self.db = db + self.keysets = {} @async_set_requests async def _init_s(self): @@ -137,7 +139,7 @@ class LedgerAPI(object): # raise for status if no error resp.raise_for_status() - async def _load_mint_keys(self, keyset_id: str = "") -> WalletKeyset: + async def _load_mint_keys(self, keyset_id: Optional[str] = None) -> None: """Loads keys from mint and stores them in the database. Args: @@ -152,31 +154,54 @@ class LedgerAPI(object): self.url ), "Ledger not initialized correctly: mint URL not specified yet. " + keyset_local: Union[WalletKeyset, None] = None if keyset_id: - # get requested keyset + # check if current keyset is in db + keyset_local = await get_keyset(keyset_id, db=self.db) + if keyset_local: + logger.debug(f"Found keyset {keyset_id} in database.") + else: + logger.debug( + f"Cannot find keyset {keyset_id} in database. Loading keyset from" + " mint." + ) + keyset = keyset_local + + if keyset_local is None and keyset_id: + # get requested keyset from mint keyset = await self._get_keys_of_keyset(self.url, keyset_id) else: # get current keyset keyset = await self._get_keys(self.url) - assert keyset.public_keys + assert keyset assert keyset.id assert len(keyset.public_keys) > 0, "did not receive keys from mint." - # check if current keyset is in db - keyset_local: Optional[WalletKeyset] = await get_keyset(keyset.id, db=self.db) - # if not, store it - if keyset_local is None: - logger.debug(f"Storing new mint keyset: {keyset.id}") - await store_keyset(keyset=keyset, db=self.db) + if keyset_id and keyset_id != keyset.id: + # NOTE: Because of the upcoming change of how to calculate keyset ids + # with version 0.14.0, we overwrite the calculated keyset id with the + # requested one. This is a temporary fix and should be removed once all + # ecash is transitioned to 0.14.0. + logger.debug( + f"Keyset ID mismatch: {keyset_id} != {keyset.id}. This can happen due" + " to a version upgrade." + ) + keyset.id = keyset_id or keyset.id - self.keys = keyset - assert self.keys.public_keys - self.public_keys = self.keys.public_keys - assert self.keys.id - self.keyset_id = self.keys.id - logger.debug(f"Current mint keyset: {self.keys.id}") - return self.keys + # if the keyset is not in the database, store it + if keyset_local is None: + keyset_local_from_mint = await get_keyset(keyset.id, db=self.db) + if not keyset_local_from_mint: + logger.debug(f"Storing new mint keyset: {keyset.id}") + await store_keyset(keyset=keyset, db=self.db) + + # set current keyset id + self.keyset_id = keyset.id + logger.debug(f"Current mint keyset: {self.keyset_id}") + + # add keyset to keysets dict + self.keysets[keyset.id] = keyset async def _load_mint_keysets(self) -> List[str]: """Loads the keyset IDs of the mint. @@ -191,11 +216,13 @@ class LedgerAPI(object): try: mint_keysets = await self._get_keyset_ids(self.url) except Exception: - assert self.keys.id, "could not get keysets from mint, and do not have keys" + assert self.keysets[ + self.keyset_id + ].id, "could not get keysets from mint, and do not have keys" pass - self.keysets = mint_keysets or [self.keys.id] - logger.debug(f"Mint keysets: {self.keysets}") - return self.keysets + self.mint_keyset_ids = mint_keysets or [self.keysets[self.keyset_id].id] + logger.debug(f"Mint keysets: {self.mint_keyset_ids}") + return self.mint_keyset_ids async def _load_mint_info(self) -> GetInfoResponse: """Loads the mint info from the mint.""" @@ -207,7 +234,7 @@ class LedgerAPI(object): """ Loads the public keys of the mint. Either gets the keys for the specified `keyset_id` or gets the keys of the active keyset from the mint. - Gets the active keyset ids of the mint and stores in `self.keysets`. + Gets the active keyset ids of the mint and stores in `self.mint_keyset_ids`. """ await self._load_mint_keys(keyset_id) await self._load_mint_keysets() @@ -218,7 +245,9 @@ class LedgerAPI(object): pass if keyset_id: - assert keyset_id in self.keysets, f"keyset {keyset_id} not active on mint" + assert ( + keyset_id in self.mint_keyset_ids + ), f"keyset {keyset_id} not active on mint" async def _check_used_secrets(self, secrets): """Checks if any of the secrets have already been used""" @@ -565,7 +594,7 @@ class Wallet(LedgerAPI, WalletP2PK, WalletHTLC, WalletSecrets): async def load_mint(self, keyset_id: str = ""): """Load a mint's keys with a given keyset_id if specified or else loads the active keyset of the mint into self.keys. - Also loads all keyset ids into self.keysets. + Also loads all keyset ids into self.mint_keyset_ids. Args: keyset_id (str, optional): _description_. Defaults to "". @@ -815,14 +844,17 @@ class Wallet(LedgerAPI, WalletP2PK, WalletHTLC, WalletSecrets): logger.trace("No DLEQ proof in proof.") return logger.trace("Verifying DLEQ proof.") - assert self.keys.public_keys + assert proof.id + assert ( + proof.id in self.keysets + ), f"Keyset {proof.id} not known, can not verify DLEQ." if not b_dhke.carol_verify_dleq( secret_msg=proof.secret, C=PublicKey(bytes.fromhex(proof.C), raw=True), r=PrivateKey(bytes.fromhex(proof.dleq.r), raw=True), e=PrivateKey(bytes.fromhex(proof.dleq.e), raw=True), s=PrivateKey(bytes.fromhex(proof.dleq.s), raw=True), - A=self.keys.public_keys[proof.amount], + A=self.keysets[proof.id].public_keys[proof.amount], ): raise Exception("DLEQ proof invalid.") else: @@ -852,13 +884,15 @@ class Wallet(LedgerAPI, WalletP2PK, WalletHTLC, WalletSecrets): logger.trace("Constructing proofs.") proofs: List[Proof] = [] for promise, secret, r, path in zip(promises, secrets, rs, derivation_paths): - logger.trace(f"Creating proof with keyset {self.keyset_id} = {promise.id}") - assert ( - self.keyset_id == promise.id - ), "our keyset id does not match promise id." + if promise.id not in self.keysets: + # we don't have the keyset for this promise, so we load it + await self._load_mint_keys(promise.id) + assert promise.id in self.keysets, "Could not load keyset." C_ = PublicKey(bytes.fromhex(promise.C_), raw=True) - C = b_dhke.step3_alice(C_, r, self.public_keys[promise.amount]) + C = b_dhke.step3_alice( + C_, r, self.keysets[promise.id].public_keys[promise.amount] + ) B_, r = b_dhke.step1_alice(secret, r) # recompute B_ for dleq proofs proof = Proof( @@ -950,7 +984,7 @@ class Wallet(LedgerAPI, WalletP2PK, WalletHTLC, WalletSecrets): if id is None: continue keyset_crud = await get_keyset(id=id, db=self.db) - assert keyset_crud is not None, "keyset not found" + assert keyset_crud is not None, f"keyset {id} not found" keyset: WalletKeyset = keyset_crud assert keyset.mint_url if keyset.mint_url not in ret: @@ -1101,7 +1135,7 @@ class Wallet(LedgerAPI, WalletP2PK, WalletHTLC, WalletSecrets): Rules: 1) Proofs that are not marked as reserved - 2) Proofs that have a keyset id that is in self.keysets (all active keysets of mint) + 2) Proofs that have a keyset id that is in self.mint_keyset_ids (all active keysets of mint) 3) Include all proofs that have an older keyset than the current keyset of the mint (to get rid of old epochs). 4) If the target amount is not reached, add proofs of the current keyset until it is. """ @@ -1111,19 +1145,23 @@ class Wallet(LedgerAPI, WalletP2PK, WalletHTLC, WalletSecrets): proofs = [p for p in proofs if not p.reserved] # select proofs that are in the active keysets of the mint - proofs = [p for p in proofs if p.id in self.keysets or not p.id] + proofs = [p for p in proofs if p.id in self.mint_keyset_ids or not p.id] # check that enough spendable proofs exist if sum_proofs(proofs) < amount_to_send: raise Exception("balance too low.") # add all proofs that have an older keyset than the current keyset of the mint - proofs_old_epochs = [p for p in proofs if p.id != self.keys.id] + proofs_old_epochs = [ + p for p in proofs if p.id != self.keysets[self.keyset_id].id + ] send_proofs += proofs_old_epochs # coinselect based on amount only from the current keyset # start with the proofs with the largest amount and add them until the target amount is reached - proofs_current_epoch = [p for p in proofs if p.id == self.keys.id] + proofs_current_epoch = [ + p for p in proofs if p.id == self.keysets[self.keyset_id].id + ] sorted_proofs_of_current_keyset = sorted( proofs_current_epoch, key=lambda p: p.amount ) diff --git a/tests/test_wallet.py b/tests/test_wallet.py index b851933..575e0bd 100644 --- a/tests/test_wallet.py +++ b/tests/test_wallet.py @@ -89,8 +89,8 @@ async def wallet3(mint): @pytest.mark.asyncio async def test_get_keys(wallet1: Wallet): - assert wallet1.keys.public_keys - assert len(wallet1.keys.public_keys) == settings.max_order + assert wallet1.keysets[wallet1.keyset_id].public_keys + assert len(wallet1.keysets[wallet1.keyset_id].public_keys) == settings.max_order keyset = await wallet1._get_keys(wallet1.url) assert keyset.id is not None assert keyset.id == "1cCNIAZ2X/w1" @@ -100,8 +100,8 @@ async def test_get_keys(wallet1: Wallet): @pytest.mark.asyncio async def test_get_keyset(wallet1: Wallet): - assert wallet1.keys.public_keys - assert len(wallet1.keys.public_keys) == settings.max_order + assert wallet1.keysets[wallet1.keyset_id].public_keys + assert len(wallet1.keysets[wallet1.keyset_id].public_keys) == settings.max_order # let's get the keys first so we can get a keyset ID that we use later keys1 = await wallet1._get_keys(wallet1.url) # gets the keys of a specific keyset