[FIX] Wallet sort outputs before swapping (#648)

* sort proofs

* outputs-ordering

* mypy fix

* clean up

* test if output amounts are sorted

* clean up test

---------

Co-authored-by: callebtc <93376500+callebtc@users.noreply.github.com>
This commit is contained in:
lollerfirst
2024-11-05 15:00:37 +01:00
committed by GitHub
parent 9cdfba52a3
commit ed0d25dec7
2 changed files with 75 additions and 4 deletions

View File

@@ -456,7 +456,7 @@ class Wallet(
# sort by increasing amount
amounts_we_want.sort()
logger.debug(
logger.trace(
f"Amounts we have: {[(a, amounts_we_have.count(a)) for a in set(amounts_we_have)]}"
)
amounts: list[int] = []
@@ -470,7 +470,7 @@ class Wallet(
amounts += amount_split(remaining_amount)
amounts.sort()
logger.debug(f"Amounts we want: {amounts}")
logger.trace(f"Amounts we want: {amounts}")
if sum(amounts) != amount:
raise Exception(f"Amounts do not sum to {amount}.")
@@ -643,7 +643,7 @@ class Wallet(
proofs = self.add_witnesses_to_proofs(proofs)
input_fees = self.get_fees_for_proofs(proofs)
logger.debug(f"Input fees: {input_fees}")
logger.trace(f"Input fees: {input_fees}")
# create a suitable amounts to keep and send.
keep_outputs, send_outputs = self.determine_output_amounts(
proofs,
@@ -674,8 +674,22 @@ class Wallet(
# potentially add witnesses to outputs based on what requirement the proofs indicate
outputs = self.add_witnesses_to_outputs(proofs, outputs)
# sort outputs by amount, remember original order
sorted_outputs_with_indices = sorted(
enumerate(outputs), key=lambda p: p[1].amount
)
original_indices, sorted_outputs = zip(*sorted_outputs_with_indices)
# Call swap API
promises = await super().split(proofs, outputs)
sorted_promises = await super().split(proofs, sorted_outputs)
# sort promises back to original order
promises = [
promise
for _, promise in sorted(
zip(original_indices, sorted_promises), key=lambda x: x[0]
)
]
# Construct proofs from returned promises (i.e., unblind the signatures)
new_proofs = await self._construct_proofs(

View File

@@ -0,0 +1,57 @@
import json
import pytest
import pytest_asyncio
import respx
from httpx import Request, Response
from cashu.core.base import BlindedSignature
from cashu.core.crypto.b_dhke import hash_to_curve
from cashu.wallet.wallet import Wallet
from cashu.wallet.wallet import Wallet as Wallet1
from tests.conftest import SERVER_ENDPOINT
from tests.helpers import pay_if_regtest
@pytest_asyncio.fixture(scope="function")
async def wallet1(mint):
wallet1 = await Wallet1.with_db(
url=SERVER_ENDPOINT,
db="test_data/wallet1",
name="wallet1",
)
await wallet1.load_mint()
yield wallet1
@pytest.mark.asyncio
async def test_swap_outputs_are_sorted(wallet1: Wallet):
await wallet1.load_mint()
mint_quote = await wallet1.request_mint(16)
await pay_if_regtest(mint_quote.request)
await wallet1.mint(16, quote_id=mint_quote.quote, split=[16])
assert wallet1.balance == 16
test_url = f"{wallet1.url}/v1/swap"
key = hash_to_curve("test".encode("utf-8"))
mock_blind_signature = BlindedSignature(
id=wallet1.keyset_id,
amount=8,
C_=key.serialize().hex(),
)
mock_response_data = {"signatures": [mock_blind_signature.dict()]}
with respx.mock() as mock:
route = mock.post(test_url).mock(
return_value=Response(200, json=mock_response_data)
)
await wallet1.select_to_send(wallet1.proofs, 5)
assert route.called
assert route.call_count == 1
request: Request = route.calls[0].request
assert request.method == "POST"
assert request.url == test_url
request_data = json.loads(request.content.decode("utf-8"))
output_amounts = [o["amount"] for o in request_data["outputs"]]
# assert that output amounts are sorted
assert output_amounts == sorted(output_amounts)