Fix: Cast keyset keys (amount) to int (#368)

* load keys as integers

* add tests

* make format

* revert format from newer branch
This commit is contained in:
callebtc
2023-11-24 14:47:36 -03:00
committed by GitHub
parent b519c7db34
commit f7d0126805
3 changed files with 37 additions and 7 deletions

View File

@@ -348,6 +348,11 @@ class WalletKeyset:
self.public_keys = public_keys self.public_keys = public_keys
# overwrite id by deriving it from the public keys # overwrite id by deriving it from the public keys
self.id = derive_keyset_id(self.public_keys) self.id = derive_keyset_id(self.public_keys)
logger.trace(f"Derived keyset id {self.id} from public keys.")
if id and id != self.id:
logger.warning(
f"WARNING: Keyset id {self.id} does not match the given id {id}."
)
def serialize(self): def serialize(self):
return json.dumps( return json.dumps(
@@ -356,9 +361,9 @@ class WalletKeyset:
@classmethod @classmethod
def from_row(cls, row: Row): def from_row(cls, row: Row):
def deserialize(serialized: str): def deserialize(serialized: str) -> Dict[int, PublicKey]:
return { return {
amount: PublicKey(bytes.fromhex(hex_key), raw=True) int(amount): PublicKey(bytes.fromhex(hex_key), raw=True)
for amount, hex_key in dict(json.loads(serialized)).items() for amount, hex_key in dict(json.loads(serialized)).items()
} }

View File

@@ -171,21 +171,24 @@ class LedgerAPI(object):
keyset_local: Union[WalletKeyset, None] = None keyset_local: Union[WalletKeyset, None] = None
if keyset_id: if keyset_id:
# check if current keyset is in db # check if current keyset is in db
logger.trace(f"Checking if keyset {keyset_id} is in database.")
keyset_local = await get_keyset(keyset_id, db=self.db) keyset_local = await get_keyset(keyset_id, db=self.db)
if keyset_local: if keyset_local:
logger.debug(f"Found keyset {keyset_id} in database.") logger.trace(f"Found keyset {keyset_id} in database.")
else: else:
logger.debug( logger.trace(
f"Cannot find keyset {keyset_id} in database. Loading keyset from" f"Could not find keyset {keyset_id} in database. Loading keyset"
" mint." " from mint."
) )
keyset = keyset_local keyset = keyset_local
if keyset_local is None and keyset_id: if keyset_local is None and keyset_id:
# get requested keyset from mint # get requested keyset from mint
logger.trace(f"Getting keyset {keyset_id} from mint.")
keyset = await self._get_keys_of_keyset(self.url, keyset_id) keyset = await self._get_keys_of_keyset(self.url, keyset_id)
else: else:
# get current keyset # get current keyset
logger.trace("Getting current keyset from mint.")
keyset = await self._get_keys(self.url) keyset = await self._get_keys(self.url)
assert keyset assert keyset

View File

@@ -1,3 +1,4 @@
import copy
import shutil import shutil
from pathlib import Path from pathlib import Path
from typing import List, Union from typing import List, Union
@@ -9,7 +10,7 @@ from cashu.core.base import Proof
from cashu.core.errors import CashuError, KeysetNotFoundError from cashu.core.errors import CashuError, KeysetNotFoundError
from cashu.core.helpers import sum_proofs from cashu.core.helpers import sum_proofs
from cashu.core.settings import settings from cashu.core.settings import settings
from cashu.wallet.crud import get_lightning_invoice, get_proofs from cashu.wallet.crud import get_keyset, get_lightning_invoice, get_proofs
from cashu.wallet.wallet import Wallet from cashu.wallet.wallet import Wallet
from cashu.wallet.wallet import Wallet as Wallet1 from cashu.wallet.wallet import Wallet as Wallet1
from cashu.wallet.wallet import Wallet as Wallet2 from cashu.wallet.wallet import Wallet as Wallet2
@@ -114,6 +115,27 @@ async def test_get_keyset(wallet1: Wallet):
assert len(keys1.public_keys) == len(keys2.public_keys) assert len(keys1.public_keys) == len(keys2.public_keys)
@pytest.mark.asyncio
async def test_get_keyset_from_db(wallet1: Wallet):
# first load it from the mint
# await wallet1._load_mint_keys()
# NOTE: conftest already called wallet.load_mint() which got the keys from the mint
keyset1 = copy.copy(wallet1.keysets[wallet1.keyset_id])
# then load it from the db
await wallet1._load_mint_keys()
keyset2 = copy.copy(wallet1.keysets[wallet1.keyset_id])
assert keyset1.public_keys == keyset2.public_keys
assert keyset1.id == keyset2.id
# load it directly from the db
keyset3 = await get_keyset(db=wallet1.db, id=keyset1.id)
assert keyset3
assert keyset1.public_keys == keyset3.public_keys
assert keyset1.id == keyset3.id
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_info(wallet1: Wallet): async def test_get_info(wallet1: Wallet):
info = await wallet1._get_info(wallet1.url) info = await wallet1._get_info(wallet1.url)