Add get quote API to wallet + check proof states in batches (#637)

* add get quote api to wallet

* wrong string

* test before pushing

* fix tests for deprecated api only

* sigh
This commit is contained in:
callebtc
2024-10-08 18:12:10 +02:00
committed by GitHub
parent cd39e18916
commit 4490cc6fce
6 changed files with 102 additions and 22 deletions

View File

@@ -381,8 +381,12 @@ async def invoice(
while time.time() < check_until and not paid: while time.time() < check_until and not paid:
await asyncio.sleep(5) await asyncio.sleep(5)
try: try:
await wallet.mint(amount, split=optional_split, id=invoice.id) mint_quote_resp = await wallet.get_mint_quote(invoice.id)
paid = True if mint_quote_resp.state == MintQuoteState.paid.value:
await wallet.mint(amount, split=optional_split, id=invoice.id)
paid = True
else:
print(".", end="", flush=True)
except Exception as e: except Exception as e:
# TODO: user error codes! # TODO: user error codes!
if "not paid" in str(e): if "not paid" in str(e):
@@ -710,12 +714,7 @@ async def burn(ctx: Context, token: str, all: bool, force: bool, delete: str):
if delete: if delete:
await wallet.invalidate(proofs) await wallet.invalidate(proofs)
else: else:
# invalidate proofs in batches await wallet.invalidate(proofs, check_spendable=True)
for _proofs in [
proofs[i : i + settings.proofs_batch_size]
for i in range(0, len(proofs), settings.proofs_batch_size)
]:
await wallet.invalidate(_proofs, check_spendable=True)
await print_balance(ctx) await print_balance(ctx)
@@ -1024,7 +1023,9 @@ async def info(ctx: Context, mint: bool, mnemonic: bool):
if mint_info.get("time"): if mint_info.get("time"):
print(f" - Server time: {mint_info['time']}") print(f" - Server time: {mint_info['time']}")
if mint_info.get("nuts"): if mint_info.get("nuts"):
nuts_str = ', '.join([f"NUT-{k}" for k in mint_info['nuts'].keys()]) nuts_str = ", ".join(
[f"NUT-{k}" for k in mint_info["nuts"].keys()]
)
print(f" - Supported NUTS: {nuts_str}") print(f" - Supported NUTS: {nuts_str}")
print("") print("")
except Exception as e: except Exception as e:

View File

@@ -170,7 +170,7 @@ class LedgerAPI(LedgerAPIDeprecated):
keys_dict: dict = resp.json() keys_dict: dict = resp.json()
assert len(keys_dict), Exception("did not receive any keys") assert len(keys_dict), Exception("did not receive any keys")
keys = KeysResponse.parse_obj(keys_dict) keys = KeysResponse.parse_obj(keys_dict)
keysets_str = ' '.join([f"{k.id} ({k.unit})" for k in keys.keysets]) keysets_str = " ".join([f"{k.id} ({k.unit})" for k in keys.keysets])
logger.debug(f"Received {len(keys.keysets)} keysets from mint: {keysets_str}.") logger.debug(f"Received {len(keys.keysets)} keysets from mint: {keysets_str}.")
ret = [ ret = [
WalletKeyset( WalletKeyset(
@@ -312,6 +312,24 @@ class LedgerAPI(LedgerAPIDeprecated):
return_dict = resp.json() return_dict = resp.json()
return PostMintQuoteResponse.parse_obj(return_dict) return PostMintQuoteResponse.parse_obj(return_dict)
@async_set_httpx_client
@async_ensure_mint_loaded
async def get_mint_quote(self, quote: str) -> PostMintQuoteResponse:
"""Returns an existing mint quote from the server.
Args:
quote (str): Quote ID
Returns:
PostMintQuoteResponse: Mint Quote Response
"""
resp = await self.httpx.get(
join(self.url, f"/v1/mint/quote/bolt11/{quote}"),
)
self.raise_on_error_request(resp)
return_dict = resp.json()
return PostMintQuoteResponse.parse_obj(return_dict)
@async_set_httpx_client @async_set_httpx_client
@async_ensure_mint_loaded @async_ensure_mint_loaded
async def mint( async def mint(
@@ -400,6 +418,24 @@ class LedgerAPI(LedgerAPIDeprecated):
return_dict = resp.json() return_dict = resp.json()
return PostMeltQuoteResponse.parse_obj(return_dict) return PostMeltQuoteResponse.parse_obj(return_dict)
@async_set_httpx_client
@async_ensure_mint_loaded
async def get_melt_quote(self, quote: str) -> PostMeltQuoteResponse:
"""Returns an existing melt quote from the server.
Args:
quote (str): Quote ID
Returns:
PostMeltQuoteResponse: Melt Quote Response
"""
resp = await self.httpx.get(
join(self.url, f"/v1/melt/quote/bolt11/{quote}"),
)
self.raise_on_error_request(resp)
return_dict = resp.json()
return PostMeltQuoteResponse.parse_obj(return_dict)
@async_set_httpx_client @async_set_httpx_client
@async_ensure_mint_loaded @async_ensure_mint_loaded
async def melt( async def melt(

View File

@@ -162,7 +162,7 @@ class Wallet(
self.keysets = {k.id: k for k in keysets_active_unit} self.keysets = {k.id: k for k in keysets_active_unit}
else: else:
self.keysets = {k.id: k for k in keysets_list} self.keysets = {k.id: k for k in keysets_list}
keysets_str = ' '.join([f"{i} {k.unit}" for i, k in self.keysets.items()]) keysets_str = " ".join([f"{i} {k.unit}" for i, k in self.keysets.items()])
logger.debug(f"Loaded keysets: {keysets_str}") logger.debug(f"Loaded keysets: {keysets_str}")
return self return self
@@ -351,10 +351,9 @@ class Wallet(
for keyset_id in self.keysets: for keyset_id in self.keysets:
proofs = await get_proofs(db=self.db, id=keyset_id, conn=conn) proofs = await get_proofs(db=self.db, id=keyset_id, conn=conn)
self.proofs.extend(proofs) self.proofs.extend(proofs)
keysets_str = ' '.join([f"{k.id} ({k.unit})" for k in self.keysets.values()]) keysets_str = " ".join([f"{k.id} ({k.unit})" for k in self.keysets.values()])
logger.trace(f"Proofs loaded for keysets: {keysets_str}") logger.trace(f"Proofs loaded for keysets: {keysets_str}")
async def load_keysets_from_db( async def load_keysets_from_db(
self, url: Union[str, None] = "", unit: Union[str, None] = "" self, url: Union[str, None] = "", unit: Union[str, None] = ""
): ):
@@ -1020,10 +1019,15 @@ class Wallet(
""" """
invalidated_proofs: List[Proof] = [] invalidated_proofs: List[Proof] = []
if check_spendable: if check_spendable:
proof_states = await self.check_proof_state(proofs) # checks proofs in batches
for i, state in enumerate(proof_states.states): for _proofs in [
if state.spent: proofs[i : i + settings.proofs_batch_size]
invalidated_proofs.append(proofs[i]) for i in range(0, len(proofs), settings.proofs_batch_size)
]:
proof_states = await self.check_proof_state(proofs)
for i, state in enumerate(proof_states.states):
if state.spent:
invalidated_proofs.append(proofs[i])
else: else:
invalidated_proofs = proofs invalidated_proofs = proofs
@@ -1033,9 +1037,12 @@ class Wallet(
f" {self.unit.str(sum_proofs(invalidated_proofs))}." f" {self.unit.str(sum_proofs(invalidated_proofs))}."
) )
async with self.db.connect() as conn: for p in invalidated_proofs:
for p in invalidated_proofs: try:
await invalidate_proof(p, db=self.db, conn=conn) # mark proof as spent
await invalidate_proof(p, db=self.db)
except Exception as e:
logger.error(f"DB error while invalidating proof: {e}")
invalidate_secrets = [p.secret for p in invalidated_proofs] invalidate_secrets = [p.secret for p in invalidated_proofs]
self.proofs = list( self.proofs = list(

View File

@@ -4,6 +4,7 @@ import pytest_asyncio
from cashu.core.base import MeltQuoteState from cashu.core.base import MeltQuoteState
from cashu.core.helpers import sum_proofs from cashu.core.helpers import sum_proofs
from cashu.core.models import PostMeltQuoteRequest, PostMintQuoteRequest from cashu.core.models import PostMeltQuoteRequest, PostMintQuoteRequest
from cashu.core.settings import settings
from cashu.mint.ledger import Ledger from cashu.mint.ledger import Ledger
from cashu.wallet.wallet import Wallet from cashu.wallet.wallet import Wallet
from cashu.wallet.wallet import Wallet as Wallet1 from cashu.wallet.wallet import Wallet as Wallet1
@@ -55,6 +56,13 @@ async def test_melt_internal(wallet1: Wallet, ledger: Ledger):
assert melt_quote.amount == 64 assert melt_quote.amount == 64
assert melt_quote.fee_reserve == 0 assert melt_quote.fee_reserve == 0
if not settings.debug_mint_only_deprecated:
melt_quote_response_pre_payment = await wallet1.get_melt_quote(melt_quote.quote)
assert (
not melt_quote_response_pre_payment.state == MeltQuoteState.paid.value
), "melt quote should not be paid"
assert melt_quote_response_pre_payment.amount == 64
melt_quote_pre_payment = await ledger.get_melt_quote(melt_quote.quote) melt_quote_pre_payment = await ledger.get_melt_quote(melt_quote.quote)
assert not melt_quote_pre_payment.paid, "melt quote should not be paid" assert not melt_quote_pre_payment.paid, "melt quote should not be paid"
assert melt_quote_pre_payment.unpaid assert melt_quote_pre_payment.unpaid
@@ -89,6 +97,13 @@ async def test_melt_external(wallet1: Wallet, ledger: Ledger):
PostMeltQuoteRequest(request=invoice_payment_request, unit="sat") PostMeltQuoteRequest(request=invoice_payment_request, unit="sat")
) )
if not settings.debug_mint_only_deprecated:
melt_quote_response_pre_payment = await wallet1.get_melt_quote(melt_quote.quote)
assert (
melt_quote_response_pre_payment.state == MeltQuoteState.unpaid.value
), "melt quote should not be paid"
assert melt_quote_response_pre_payment.amount == melt_quote.amount
melt_quote_pre_payment = await ledger.get_melt_quote(melt_quote.quote) melt_quote_pre_payment = await ledger.get_melt_quote(melt_quote.quote)
assert not melt_quote_pre_payment.paid, "melt quote should not be paid" assert not melt_quote_pre_payment.paid, "melt quote should not be paid"
assert melt_quote_pre_payment.unpaid assert melt_quote_pre_payment.unpaid
@@ -109,7 +124,12 @@ async def test_mint_internal(wallet1: Wallet, ledger: Ledger):
mint_quote = await ledger.get_mint_quote(invoice.id) mint_quote = await ledger.get_mint_quote(invoice.id)
assert mint_quote.paid, "mint quote should be paid" assert mint_quote.paid, "mint quote should be paid"
assert mint_quote.paid
if not settings.debug_mint_only_deprecated:
mint_quote_resp = await wallet1.get_mint_quote(invoice.id)
assert (
mint_quote_resp.state == MeltQuoteState.paid.value
), "mint quote should be paid"
output_amounts = [128] output_amounts = [128]
secrets, rs, derivation_paths = await wallet1.generate_n_secrets( secrets, rs, derivation_paths = await wallet1.generate_n_secrets(
@@ -139,6 +159,10 @@ async def test_mint_external(wallet1: Wallet, ledger: Ledger):
assert not mint_quote.paid, "mint quote already paid" assert not mint_quote.paid, "mint quote already paid"
assert mint_quote.unpaid assert mint_quote.unpaid
if not settings.debug_mint_only_deprecated:
mint_quote_resp = await wallet1.get_mint_quote(quote.quote)
assert not mint_quote_resp.paid, "mint quote should not be paid"
await assert_err( await assert_err(
wallet1.mint(128, id=quote.quote), wallet1.mint(128, id=quote.quote),
"quote not paid", "quote not paid",

View File

@@ -4,7 +4,7 @@ from typing import List, Union
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from cashu.core.base import Proof from cashu.core.base import MintQuoteState, Proof
from cashu.core.errors import CashuError, KeysetNotFoundError from cashu.core.errors import CashuError, KeysetNotFoundError
from cashu.core.helpers import sum_proofs from cashu.core.helpers import sum_proofs
from cashu.core.settings import settings from cashu.core.settings import settings
@@ -168,6 +168,11 @@ async def test_request_mint(wallet1: Wallet):
async def test_mint(wallet1: Wallet): async def test_mint(wallet1: Wallet):
invoice = await wallet1.request_mint(64) invoice = await wallet1.request_mint(64)
await pay_if_regtest(invoice.bolt11) await pay_if_regtest(invoice.bolt11)
if not settings.debug_mint_only_deprecated:
quote_resp = await wallet1.get_mint_quote(invoice.id)
assert quote_resp.request == invoice.bolt11
assert quote_resp.state == MintQuoteState.paid.value
expected_proof_amounts = wallet1.split_wallet_state(64) expected_proof_amounts = wallet1.split_wallet_state(64)
await wallet1.mint(64, id=invoice.id) await wallet1.mint(64, id=invoice.id)
assert wallet1.balance == 64 assert wallet1.balance == 64
@@ -307,6 +312,10 @@ async def test_melt(wallet1: Wallet):
assert total_amount == 64 assert total_amount == 64
assert quote.fee_reserve == 0 assert quote.fee_reserve == 0
if not settings.debug_mint_only_deprecated:
quote_resp = await wallet1.get_melt_quote(quote.quote)
assert quote_resp.amount == quote.amount
_, send_proofs = await wallet1.swap_to_send(wallet1.proofs, total_amount) _, send_proofs = await wallet1.swap_to_send(wallet1.proofs, total_amount)
melt_response = await wallet1.melt( melt_response = await wallet1.melt(

View File

@@ -110,6 +110,9 @@ def test_balance(cli_prefix):
@pytest.mark.skipif(is_regtest, reason="only works with FakeWallet") @pytest.mark.skipif(is_regtest, reason="only works with FakeWallet")
def test_invoice(mint, cli_prefix): def test_invoice(mint, cli_prefix):
if settings.debug_mint_only_deprecated:
pytest.skip("only works with v1 API")
runner = CliRunner() runner = CliRunner()
result = runner.invoke( result = runner.invoke(
cli, cli,