[Wallet] Refactor restore_promises_from_to (#307)

* refactor restore_promises_from_to

* fix mypy

* black

* fix tests
This commit is contained in:
callebtc
2023-08-25 23:50:16 +02:00
committed by GitHub
parent e374d32df7
commit f551624132
3 changed files with 48 additions and 22 deletions

View File

@@ -282,9 +282,9 @@ async def burn(
wallet = await mint_wallet(mint)
if not (all or token or force or delete) or (token and all):
raise Exception(
"enter a token or use --all to burn all pending tokens, --force to check"
" all tokensor --delete with send ID to force-delete pending token from"
" list if mint is unavailable.",
"enter a token or use --all to burn all pending tokens, --force to"
" check all tokensor --delete with send ID to force-delete pending"
" token from list if mint is unavailable.",
)
if all:
# check only those who are flagged as reserved
@@ -414,7 +414,7 @@ async def restore(
if to < 0:
raise Exception("Counter must be positive")
await wallet.load_mint()
await wallet.restore_promises(0, to)
await wallet.restore_promises_from_to(0, to)
await wallet.invalidate(wallet.proofs)
wallet.status()
return RestoreResponse(balance=wallet.available_balance)

View File

@@ -1655,7 +1655,7 @@ class Wallet(LedgerAPI):
n_last_restored_proofs = 0
while stop_counter < to:
print(f"Restoring token {i} to {i + batch}...")
restored_proofs = await self.restore_promises(i, i + batch - 1)
restored_proofs = await self.restore_promises_from_to(i, i + batch - 1)
if len(restored_proofs) == 0:
stop_counter += 1
spendable_proofs = await self.invalidate(restored_proofs)
@@ -1679,7 +1679,9 @@ class Wallet(LedgerAPI):
print("No tokens restored.")
return
async def restore_promises(self, from_counter: int, to_counter: int) -> List[Proof]:
async def restore_promises_from_to(
self, from_counter: int, to_counter: int
) -> List[Proof]:
"""Restores promises from a given range of counters. This is for restoring a wallet from a mnemonic.
Args:
@@ -1698,14 +1700,42 @@ class Wallet(LedgerAPI):
# we generate outptus from deterministic secrets and rs
regenerated_outputs, _ = self._construct_outputs(amounts_dummy, secrets, rs)
# we ask the mint to reissue the promises
# restored_outputs is there so we can match the promises to the secrets and rs
restored_outputs, restored_promises = await super().restore_promises(
regenerated_outputs
proofs = await self.restore_promises(
outputs=regenerated_outputs,
secrets=secrets,
rs=rs,
derivation_paths=derivation_paths,
)
await set_secret_derivation(
db=self.db, keyset_id=self.keyset_id, counter=to_counter + 1
)
return proofs
async def restore_promises(
self,
outputs: List[BlindedMessage],
secrets: List[str],
rs: List[PrivateKey],
derivation_paths: List[str],
) -> List[Proof]:
"""Restores proofs from a list of outputs, secrets, rs and derivation paths.
Args:
outputs (List[BlindedMessage]): Outputs for which we request promises
secrets (List[str]): Secrets generated for the outputs
rs (List[PrivateKey]): Random blinding factors generated for the outputs
derivation_paths (List[str]): Derivation paths for the secrets
Returns:
List[Proof]: List of restored proofs
"""
# restored_outputs is there so we can match the promises to the secrets and rs
restored_outputs, restored_promises = await super().restore_promises(outputs)
# now we need to filter out the secrets and rs that had a match
matching_indices = [
idx
for idx, val in enumerate(regenerated_outputs)
for idx, val in enumerate(outputs)
if val.B_ in [o.B_ for o in restored_outputs]
]
secrets = [secrets[i] for i in matching_indices]
@@ -1721,8 +1751,4 @@ class Wallet(LedgerAPI):
for proof in proofs:
if proof.secret not in [p.secret for p in self.proofs]:
self.proofs.append(proof)
await set_secret_derivation(
db=self.db, keyset_id=self.keyset_id, counter=to_counter + 1
)
return proofs

View File

@@ -387,7 +387,7 @@ async def test_restore_wallet_after_mint(wallet3: Wallet):
await wallet3.load_proofs()
wallet3.proofs = []
assert wallet3.balance == 0
await wallet3.restore_promises(0, 20)
await wallet3.restore_promises_from_to(0, 20)
assert wallet3.balance == 64
@@ -419,7 +419,7 @@ async def test_restore_wallet_after_split_to_send(wallet3: Wallet):
await wallet3.load_proofs()
wallet3.proofs = []
assert wallet3.balance == 0
await wallet3.restore_promises(0, 100)
await wallet3.restore_promises_from_to(0, 100)
assert wallet3.balance == 64 * 2
await wallet3.invalidate(wallet3.proofs)
assert wallet3.balance == 64
@@ -443,7 +443,7 @@ async def test_restore_wallet_after_send_and_receive(wallet3: Wallet, wallet2: W
await wallet3.load_proofs(reload=True)
assert wallet3.proofs == []
assert wallet3.balance == 0
await wallet3.restore_promises(0, 100)
await wallet3.restore_promises_from_to(0, 100)
assert wallet3.balance == 64 + 2 * 32
await wallet3.invalidate(wallet3.proofs)
assert wallet3.balance == 32
@@ -482,7 +482,7 @@ async def test_restore_wallet_after_send_and_self_receive(wallet3: Wallet):
await wallet3.load_proofs(reload=True)
assert wallet3.proofs == []
assert wallet3.balance == 0
await wallet3.restore_promises(0, 100)
await wallet3.restore_promises_from_to(0, 100)
assert wallet3.balance == 64 + 2 * 32 + 32
await wallet3.invalidate(wallet3.proofs)
assert wallet3.balance == 64
@@ -512,7 +512,7 @@ async def test_restore_wallet_after_send_twice(
await wallet3.load_proofs(reload=True)
assert wallet3.proofs == []
assert wallet3.balance == 0
await wallet3.restore_promises(0, 10)
await wallet3.restore_promises_from_to(0, 10)
box.add(wallet3.proofs)
assert wallet3.balance == 5
await wallet3.invalidate(wallet3.proofs)
@@ -532,7 +532,7 @@ async def test_restore_wallet_after_send_twice(
await wallet3.load_proofs(reload=True)
assert wallet3.proofs == []
assert wallet3.balance == 0
await wallet3.restore_promises(0, 15)
await wallet3.restore_promises_from_to(0, 15)
box.add(wallet3.proofs)
assert wallet3.balance == 7
await wallet3.invalidate(wallet3.proofs)
@@ -565,7 +565,7 @@ async def test_restore_wallet_after_send_and_self_receive_nonquadratic_value(
await wallet3.load_proofs(reload=True)
assert wallet3.proofs == []
assert wallet3.balance == 0
await wallet3.restore_promises(0, 20)
await wallet3.restore_promises_from_to(0, 20)
box.add(wallet3.proofs)
assert wallet3.balance == 138
await wallet3.invalidate(wallet3.proofs)
@@ -583,7 +583,7 @@ async def test_restore_wallet_after_send_and_self_receive_nonquadratic_value(
await wallet3.load_proofs(reload=True)
assert wallet3.proofs == []
assert wallet3.balance == 0
await wallet3.restore_promises(0, 50)
await wallet3.restore_promises_from_to(0, 50)
assert wallet3.balance == 182
await wallet3.invalidate(wallet3.proofs)
assert wallet3.balance == 64