Fix: Cast keyset keys (amount) to int (#368)

* load keys as integers

* add tests

* make format

* revert format from newer branch
This commit is contained in:
callebtc
2023-11-24 14:47:36 -03:00
committed by GitHub
parent b519c7db34
commit f7d0126805
3 changed files with 37 additions and 7 deletions

View File

@@ -348,6 +348,11 @@ class WalletKeyset:
self.public_keys = public_keys
# overwrite id by deriving it from the public keys
self.id = derive_keyset_id(self.public_keys)
logger.trace(f"Derived keyset id {self.id} from public keys.")
if id and id != self.id:
logger.warning(
f"WARNING: Keyset id {self.id} does not match the given id {id}."
)
def serialize(self):
return json.dumps(
@@ -356,9 +361,9 @@ class WalletKeyset:
@classmethod
def from_row(cls, row: Row):
def deserialize(serialized: str):
def deserialize(serialized: str) -> Dict[int, PublicKey]:
return {
amount: PublicKey(bytes.fromhex(hex_key), raw=True)
int(amount): PublicKey(bytes.fromhex(hex_key), raw=True)
for amount, hex_key in dict(json.loads(serialized)).items()
}

View File

@@ -171,21 +171,24 @@ class LedgerAPI(object):
keyset_local: Union[WalletKeyset, None] = None
if keyset_id:
# check if current keyset is in db
logger.trace(f"Checking if keyset {keyset_id} is in database.")
keyset_local = await get_keyset(keyset_id, db=self.db)
if keyset_local:
logger.debug(f"Found keyset {keyset_id} in database.")
logger.trace(f"Found keyset {keyset_id} in database.")
else:
logger.debug(
f"Cannot find keyset {keyset_id} in database. Loading keyset from"
" mint."
logger.trace(
f"Could not find keyset {keyset_id} in database. Loading keyset"
" from mint."
)
keyset = keyset_local
if keyset_local is None and keyset_id:
# get requested keyset from mint
logger.trace(f"Getting keyset {keyset_id} from mint.")
keyset = await self._get_keys_of_keyset(self.url, keyset_id)
else:
# get current keyset
logger.trace("Getting current keyset from mint.")
keyset = await self._get_keys(self.url)
assert keyset

View File

@@ -1,3 +1,4 @@
import copy
import shutil
from pathlib import Path
from typing import List, Union
@@ -9,7 +10,7 @@ from cashu.core.base import Proof
from cashu.core.errors import CashuError, KeysetNotFoundError
from cashu.core.helpers import sum_proofs
from cashu.core.settings import settings
from cashu.wallet.crud import get_lightning_invoice, get_proofs
from cashu.wallet.crud import get_keyset, get_lightning_invoice, get_proofs
from cashu.wallet.wallet import Wallet
from cashu.wallet.wallet import Wallet as Wallet1
from cashu.wallet.wallet import Wallet as Wallet2
@@ -114,6 +115,27 @@ async def test_get_keyset(wallet1: Wallet):
assert len(keys1.public_keys) == len(keys2.public_keys)
@pytest.mark.asyncio
async def test_get_keyset_from_db(wallet1: Wallet):
# first load it from the mint
# await wallet1._load_mint_keys()
# NOTE: conftest already called wallet.load_mint() which got the keys from the mint
keyset1 = copy.copy(wallet1.keysets[wallet1.keyset_id])
# then load it from the db
await wallet1._load_mint_keys()
keyset2 = copy.copy(wallet1.keysets[wallet1.keyset_id])
assert keyset1.public_keys == keyset2.public_keys
assert keyset1.id == keyset2.id
# load it directly from the db
keyset3 = await get_keyset(db=wallet1.db, id=keyset1.id)
assert keyset3
assert keyset1.public_keys == keyset3.public_keys
assert keyset1.id == keyset3.id
@pytest.mark.asyncio
async def test_get_info(wallet1: Wallet):
info = await wallet1._get_info(wallet1.url)