mirror of
https://github.com/aljazceru/nutshell.git
synced 2026-02-03 07:44:21 +01:00
[FIX] Wallet sort outputs before swapping (#648)
* sort proofs * outputs-ordering * mypy fix * clean up * test if output amounts are sorted * clean up test --------- Co-authored-by: callebtc <93376500+callebtc@users.noreply.github.com>
This commit is contained in:
@@ -456,7 +456,7 @@ class Wallet(
|
||||
# sort by increasing amount
|
||||
amounts_we_want.sort()
|
||||
|
||||
logger.debug(
|
||||
logger.trace(
|
||||
f"Amounts we have: {[(a, amounts_we_have.count(a)) for a in set(amounts_we_have)]}"
|
||||
)
|
||||
amounts: list[int] = []
|
||||
@@ -470,7 +470,7 @@ class Wallet(
|
||||
amounts += amount_split(remaining_amount)
|
||||
amounts.sort()
|
||||
|
||||
logger.debug(f"Amounts we want: {amounts}")
|
||||
logger.trace(f"Amounts we want: {amounts}")
|
||||
if sum(amounts) != amount:
|
||||
raise Exception(f"Amounts do not sum to {amount}.")
|
||||
|
||||
@@ -643,7 +643,7 @@ class Wallet(
|
||||
proofs = self.add_witnesses_to_proofs(proofs)
|
||||
|
||||
input_fees = self.get_fees_for_proofs(proofs)
|
||||
logger.debug(f"Input fees: {input_fees}")
|
||||
logger.trace(f"Input fees: {input_fees}")
|
||||
# create a suitable amounts to keep and send.
|
||||
keep_outputs, send_outputs = self.determine_output_amounts(
|
||||
proofs,
|
||||
@@ -674,8 +674,22 @@ class Wallet(
|
||||
# potentially add witnesses to outputs based on what requirement the proofs indicate
|
||||
outputs = self.add_witnesses_to_outputs(proofs, outputs)
|
||||
|
||||
# sort outputs by amount, remember original order
|
||||
sorted_outputs_with_indices = sorted(
|
||||
enumerate(outputs), key=lambda p: p[1].amount
|
||||
)
|
||||
original_indices, sorted_outputs = zip(*sorted_outputs_with_indices)
|
||||
|
||||
# Call swap API
|
||||
promises = await super().split(proofs, outputs)
|
||||
sorted_promises = await super().split(proofs, sorted_outputs)
|
||||
|
||||
# sort promises back to original order
|
||||
promises = [
|
||||
promise
|
||||
for _, promise in sorted(
|
||||
zip(original_indices, sorted_promises), key=lambda x: x[0]
|
||||
)
|
||||
]
|
||||
|
||||
# Construct proofs from returned promises (i.e., unblind the signatures)
|
||||
new_proofs = await self._construct_proofs(
|
||||
|
||||
57
tests/test_wallet_requests.py
Normal file
57
tests/test_wallet_requests.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import respx
|
||||
from httpx import Request, Response
|
||||
|
||||
from cashu.core.base import BlindedSignature
|
||||
from cashu.core.crypto.b_dhke import hash_to_curve
|
||||
from cashu.wallet.wallet import Wallet
|
||||
from cashu.wallet.wallet import Wallet as Wallet1
|
||||
from tests.conftest import SERVER_ENDPOINT
|
||||
from tests.helpers import pay_if_regtest
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def wallet1(mint):
|
||||
wallet1 = await Wallet1.with_db(
|
||||
url=SERVER_ENDPOINT,
|
||||
db="test_data/wallet1",
|
||||
name="wallet1",
|
||||
)
|
||||
await wallet1.load_mint()
|
||||
yield wallet1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_swap_outputs_are_sorted(wallet1: Wallet):
|
||||
await wallet1.load_mint()
|
||||
mint_quote = await wallet1.request_mint(16)
|
||||
await pay_if_regtest(mint_quote.request)
|
||||
await wallet1.mint(16, quote_id=mint_quote.quote, split=[16])
|
||||
assert wallet1.balance == 16
|
||||
|
||||
test_url = f"{wallet1.url}/v1/swap"
|
||||
key = hash_to_curve("test".encode("utf-8"))
|
||||
mock_blind_signature = BlindedSignature(
|
||||
id=wallet1.keyset_id,
|
||||
amount=8,
|
||||
C_=key.serialize().hex(),
|
||||
)
|
||||
mock_response_data = {"signatures": [mock_blind_signature.dict()]}
|
||||
with respx.mock() as mock:
|
||||
route = mock.post(test_url).mock(
|
||||
return_value=Response(200, json=mock_response_data)
|
||||
)
|
||||
await wallet1.select_to_send(wallet1.proofs, 5)
|
||||
|
||||
assert route.called
|
||||
assert route.call_count == 1
|
||||
request: Request = route.calls[0].request
|
||||
assert request.method == "POST"
|
||||
assert request.url == test_url
|
||||
request_data = json.loads(request.content.decode("utf-8"))
|
||||
output_amounts = [o["amount"] for o in request_data["outputs"]]
|
||||
# assert that output amounts are sorted
|
||||
assert output_amounts == sorted(output_amounts)
|
||||
Reference in New Issue
Block a user