mirror of
https://github.com/aljazceru/nutshell.git
synced 2025-12-20 18:44:20 +01:00
[Wallet] DB optimization for faster payments (#250)
* get rid of redundant proof loads * fix test? * fix one test? * api: load_mint for invoice * clean up tests
This commit is contained in:
@@ -2,6 +2,8 @@ import logging
|
||||
import sys
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
# from fastapi_profiler import PyInstrumentProfilerMiddleware
|
||||
from loguru import logger
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
@@ -90,6 +92,9 @@ def create_app(config_object="core.settings") -> FastAPI:
|
||||
},
|
||||
middleware=middleware,
|
||||
)
|
||||
|
||||
# app.add_middleware(PyInstrumentProfilerMiddleware)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@ from fastapi.responses import JSONResponse
|
||||
from ...core.settings import settings
|
||||
from .router import router
|
||||
|
||||
# from fastapi_profiler import PyInstrumentProfilerMiddleware
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
app = FastAPI(
|
||||
@@ -15,6 +17,8 @@ def create_app() -> FastAPI:
|
||||
"url": "https://raw.githubusercontent.com/cashubtc/cashu/main/LICENSE",
|
||||
},
|
||||
)
|
||||
# app.add_middleware(PyInstrumentProfilerMiddleware)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
|
||||
@@ -9,16 +9,12 @@ class PayResponse(BaseModel):
|
||||
amount: int
|
||||
fee: int
|
||||
amount_with_fee: int
|
||||
initial_balance: int
|
||||
balance: int
|
||||
|
||||
|
||||
class InvoiceResponse(BaseModel):
|
||||
amount: Optional[int] = None
|
||||
invoice: Optional[Invoice] = None
|
||||
hash: Optional[str] = None
|
||||
initial_balance: int
|
||||
balance: int
|
||||
|
||||
|
||||
class BalanceResponse(BaseModel):
|
||||
|
||||
@@ -73,20 +73,15 @@ async def pay(
|
||||
global wallet
|
||||
wallet = await load_mint(wallet, mint)
|
||||
|
||||
await wallet.load_proofs()
|
||||
initial_balance = wallet.available_balance
|
||||
total_amount, fee_reserve_sat = await wallet.get_pay_amount_with_fees(invoice)
|
||||
assert total_amount > 0, "amount has to be larger than zero."
|
||||
assert wallet.available_balance >= total_amount, "balance is too low."
|
||||
_, send_proofs = await wallet.split_to_send(wallet.proofs, total_amount)
|
||||
await wallet.pay_lightning(send_proofs, invoice, fee_reserve_sat)
|
||||
await wallet.load_proofs()
|
||||
return PayResponse(
|
||||
amount=total_amount - fee_reserve_sat,
|
||||
fee=fee_reserve_sat,
|
||||
amount_with_fee=total_amount,
|
||||
initial_balance=initial_balance,
|
||||
balance=wallet.available_balance,
|
||||
)
|
||||
|
||||
|
||||
@@ -115,29 +110,22 @@ async def invoice(
|
||||
|
||||
global wallet
|
||||
wallet = await load_mint(wallet, mint)
|
||||
initial_balance = wallet.available_balance
|
||||
if not settings.lightning:
|
||||
r = await wallet.mint(amount, split=optional_split)
|
||||
return InvoiceResponse(
|
||||
amount=amount,
|
||||
balance=wallet.available_balance,
|
||||
initial_balance=initial_balance,
|
||||
)
|
||||
elif amount and not hash:
|
||||
invoice = await wallet.request_mint(amount)
|
||||
return InvoiceResponse(
|
||||
amount=amount,
|
||||
invoice=invoice,
|
||||
balance=wallet.available_balance,
|
||||
initial_balance=initial_balance,
|
||||
)
|
||||
elif amount and hash:
|
||||
await wallet.mint(amount, split=optional_split, hash=hash)
|
||||
return InvoiceResponse(
|
||||
amount=amount,
|
||||
hash=hash,
|
||||
balance=wallet.available_balance,
|
||||
initial_balance=initial_balance,
|
||||
)
|
||||
return
|
||||
|
||||
@@ -171,7 +159,6 @@ async def send_command(
|
||||
),
|
||||
):
|
||||
global wallet
|
||||
await wallet.load_proofs()
|
||||
if not nostr:
|
||||
balance, token = await send(
|
||||
wallet, amount, lock, legacy=False, split=not nosplit
|
||||
|
||||
@@ -85,7 +85,7 @@ def cli(ctx: Context, host: str, walletname: str):
|
||||
ctx.obj["HOST"], os.path.join(settings.cashu_dir, walletname), name=walletname
|
||||
)
|
||||
ctx.obj["WALLET"] = wallet
|
||||
asyncio.run(init_wallet(ctx.obj["WALLET"]))
|
||||
asyncio.run(init_wallet(ctx.obj["WALLET"], load_proofs=False))
|
||||
|
||||
# MUTLIMINT: Select a wallet
|
||||
# only if a command is one of a subset that needs to specify a mint host
|
||||
@@ -96,7 +96,7 @@ def cli(ctx: Context, host: str, walletname: str):
|
||||
ctx.obj["WALLET"] = asyncio.run(
|
||||
get_mint_wallet(ctx)
|
||||
) # select a specific wallet by CLI input
|
||||
asyncio.run(init_wallet(ctx.obj["WALLET"]))
|
||||
asyncio.run(init_wallet(ctx.obj["WALLET"], load_proofs=False))
|
||||
|
||||
|
||||
# https://github.com/pallets/click/issues/85#issuecomment-503464628
|
||||
@@ -134,7 +134,6 @@ async def pay(ctx: Context, invoice: str, yes: bool):
|
||||
return
|
||||
_, send_proofs = await wallet.split_to_send(wallet.proofs, total_amount)
|
||||
await wallet.pay_lightning(send_proofs, invoice, fee_reserve_sat)
|
||||
await wallet.load_proofs()
|
||||
wallet.status()
|
||||
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ async def get_mint_wallet(ctx: Context):
|
||||
name=wallet.name,
|
||||
)
|
||||
# await mint_wallet.load_mint()
|
||||
await mint_wallet.load_proofs()
|
||||
await mint_wallet.load_proofs(reload=True)
|
||||
|
||||
return mint_wallet
|
||||
|
||||
|
||||
@@ -14,10 +14,11 @@ from ..wallet.crud import get_keyset, get_unused_locks
|
||||
from ..wallet.wallet import Wallet as Wallet
|
||||
|
||||
|
||||
async def init_wallet(wallet: Wallet):
|
||||
async def init_wallet(wallet: Wallet, load_proofs: bool = True):
|
||||
"""Performs migrations and loads proofs from db."""
|
||||
await migrate_databases(wallet.db, migrations)
|
||||
await wallet.load_proofs()
|
||||
if load_proofs:
|
||||
await wallet.load_proofs(reload=True)
|
||||
|
||||
|
||||
async def redeem_TokenV3_multimint(
|
||||
@@ -158,7 +159,7 @@ async def receive(
|
||||
print(f"Received {sum_proofs(proofs)} sats")
|
||||
|
||||
# reload main wallet so the balance updates
|
||||
await wallet.load_proofs()
|
||||
await wallet.load_proofs(reload=True)
|
||||
wallet.status()
|
||||
return wallet.available_balance
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import math
|
||||
@@ -499,7 +500,12 @@ class Wallet(LedgerAPI):
|
||||
"""
|
||||
await super()._load_mint(keyset_id)
|
||||
|
||||
async def load_proofs(self):
|
||||
async def load_proofs(self, reload: bool = False):
|
||||
"""Load all proofs from the database."""
|
||||
|
||||
if self.proofs and not reload:
|
||||
logger.debug("Proofs already loaded.")
|
||||
return
|
||||
self.proofs = await get_proofs(db=self.db)
|
||||
|
||||
async def request_mint(self, amount):
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from cashu.core.migrations import migrate_databases
|
||||
@@ -7,41 +9,45 @@ from cashu.core.settings import settings
|
||||
from cashu.wallet import migrations
|
||||
from cashu.wallet.api.app import app
|
||||
from cashu.wallet.wallet import Wallet
|
||||
from tests.conftest import SERVER_ENDPOINT, mint
|
||||
|
||||
|
||||
async def init_wallet():
|
||||
wallet = Wallet(settings.mint_host, "data/wallet", "wallet")
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def wallet(mint):
|
||||
wallet = Wallet(SERVER_ENDPOINT, "data/test_wallet_api", "wallet_api")
|
||||
await migrate_databases(wallet.db, migrations)
|
||||
await wallet.load_proofs()
|
||||
return wallet
|
||||
await wallet.load_mint()
|
||||
wallet.status()
|
||||
yield wallet
|
||||
|
||||
|
||||
def test_invoice(mint):
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoice(wallet: Wallet):
|
||||
with TestClient(app) as client:
|
||||
response = client.post("/invoice?amount=100")
|
||||
assert response.status_code == 200
|
||||
if settings.lightning:
|
||||
assert response.json()["invoice"]
|
||||
else:
|
||||
assert response.json()["balance"]
|
||||
assert response.json()["amount"]
|
||||
|
||||
|
||||
def test_invoice_with_split(mint):
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoice_with_split(wallet: Wallet):
|
||||
with TestClient(app) as client:
|
||||
response = client.post("/invoice?amount=10&split=1")
|
||||
assert response.status_code == 200
|
||||
if settings.lightning:
|
||||
assert response.json()["invoice"]
|
||||
else:
|
||||
assert response.json()["balance"]
|
||||
assert response.json()["amount"]
|
||||
# wallet = asyncio.run(init_wallet())
|
||||
# asyncio.run(wallet.load_proofs())
|
||||
# assert wallet.proof_amounts.count(1) >= 10
|
||||
|
||||
|
||||
def test_balance():
|
||||
@pytest.mark.asyncio
|
||||
async def test_balance():
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/balance")
|
||||
assert response.status_code == 200
|
||||
@@ -50,33 +56,38 @@ def test_balance():
|
||||
assert response.json()["mints"]
|
||||
|
||||
|
||||
def test_send(mint):
|
||||
@pytest.mark.asyncio
|
||||
async def test_send(wallet: Wallet):
|
||||
with TestClient(app) as client:
|
||||
response = client.post("/send?amount=10")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["balance"]
|
||||
|
||||
|
||||
def test_send_without_split(mint):
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_without_split(wallet: Wallet):
|
||||
with TestClient(app) as client:
|
||||
response = client.post("/send?amount=1&nosplit=true")
|
||||
response = client.post("/send?amount=2&nosplit=true")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["balance"]
|
||||
|
||||
|
||||
def test_send_without_split_but_wrong_amount(mint):
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_without_split_but_wrong_amount(wallet: Wallet):
|
||||
with TestClient(app) as client:
|
||||
response = client.post("/send?amount=10&nosplit=true")
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
def test_pending():
|
||||
@pytest.mark.asyncio
|
||||
async def test_pending():
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/pending")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_receive_all(mint):
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_all(wallet: Wallet):
|
||||
with TestClient(app) as client:
|
||||
response = client.post("/receive?all=true")
|
||||
assert response.status_code == 200
|
||||
@@ -84,7 +95,8 @@ def test_receive_all(mint):
|
||||
assert response.json()["balance"]
|
||||
|
||||
|
||||
def test_burn_all(mint):
|
||||
@pytest.mark.asyncio
|
||||
async def test_burn_all(wallet: Wallet):
|
||||
with TestClient(app) as client:
|
||||
response = client.post("/send?amount=20")
|
||||
assert response.status_code == 200
|
||||
@@ -93,7 +105,8 @@ def test_burn_all(mint):
|
||||
assert response.json()["balance"]
|
||||
|
||||
|
||||
def test_pay():
|
||||
@pytest.mark.asyncio
|
||||
async def test_pay():
|
||||
with TestClient(app) as client:
|
||||
invoice = (
|
||||
"lnbc100n1pjzp22cpp58xvjxvagzywky9xz3vurue822aaax"
|
||||
@@ -109,50 +122,60 @@ def test_pay():
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_lock():
|
||||
@pytest.mark.asyncio
|
||||
async def test_lock():
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/lock")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_locks():
|
||||
@pytest.mark.asyncio
|
||||
async def test_locks():
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/locks")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_invoices():
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoices():
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/invoices")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_wallets():
|
||||
@pytest.mark.asyncio
|
||||
async def test_wallets():
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/wallets")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_info():
|
||||
@pytest.mark.asyncio
|
||||
async def test_info():
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/info")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["version"]
|
||||
|
||||
|
||||
def test_flow(mint):
|
||||
@pytest.mark.asyncio
|
||||
async def test_flow(wallet: Wallet):
|
||||
with TestClient(app) as client:
|
||||
if not settings.lightning:
|
||||
response = client.get("/balance")
|
||||
initial_balance = response.json()["balance"]
|
||||
response = client.post("/invoice?amount=100")
|
||||
response = client.get("/balance")
|
||||
assert response.json()["balance"] == initial_balance + 100
|
||||
response = client.post("/send?amount=50")
|
||||
response = client.get("/balance")
|
||||
assert response.json()["balance"] == initial_balance + 50
|
||||
response = client.post("/send?amount=50")
|
||||
response = client.get("/balance")
|
||||
assert response.json()["balance"] == initial_balance
|
||||
response = client.get("/pending")
|
||||
token = response.json()["pending_token"]["0"]["token"]
|
||||
amount = response.json()["pending_token"]["0"]["amount"]
|
||||
response = client.post(f"/receive?token={token}")
|
||||
response = client.get("/balance")
|
||||
assert response.json()["balance"] == initial_balance + amount
|
||||
|
||||
Reference in New Issue
Block a user