Mint: watchdog balance log and killswitch (#705)

* wip store balance

* store balances in watchdog worker

* move mint_auth_database setting

* auth db

* balances returned as Amount (instead of int)

* add test for balance change on invoice receive

* fix 1 test

* cancel tasks on shutdown

* watchdog can now abort

* remove wallet api server

* fix lndgrpc

* fix lnbits balance

* disable watchdog

* balance lnbits msat

* test db watcher with its own database connection

* init superclass only once

* wip: log balance in keysets table

* check max balance using new keyset balance

* fix test

* fix another test

* store fees in keysets

* format

* cleanup

* shorter

* add keyset migration to auth server

* fix fakewallet

* fix db tests

* fix postgres problems during migration 26 (mint)

* fix cln

* ledger

* working with pending

* super fast watchdog, errors

* test new pipeline

* delete walletapi

* delete unneeded files

* revert workflows
This commit is contained in:
callebtc
2025-05-11 20:29:13 +02:00
committed by GitHub
parent 38bdb9ce76
commit fc0e3fe663
41 changed files with 938 additions and 960 deletions

View File

@@ -71,3 +71,6 @@ docker-build:
cd docker-build cd docker-build
docker buildx build -f Dockerfile -t cashubtc/nutshell:0.15.0 --platform linux/amd64 . docker buildx build -f Dockerfile -t cashubtc/nutshell:0.15.0 --platform linux/amd64 .
# docker push cashubtc/nutshell:0.15.0 # docker push cashubtc/nutshell:0.15.0
clear-postgres:
psql cashu -c "DROP SCHEMA public CASCADE;" -c "CREATE SCHEMA public;" -c "GRANT ALL PRIVILEGES ON SCHEMA public TO cashu;"

View File

@@ -1,4 +1,5 @@
import base64 import base64
import datetime
import json import json
import math import math
import time import time
@@ -11,6 +12,7 @@ from typing import Any, ClassVar, Dict, List, Optional, Union
import cbor2 import cbor2
from loguru import logger from loguru import logger
from pydantic import BaseModel, root_validator from pydantic import BaseModel, root_validator
from sqlalchemy import RowMapping
from cashu.core.json_rpc.base import JSONRPCSubscriptionKinds from cashu.core.json_rpc.base import JSONRPCSubscriptionKinds
@@ -551,7 +553,7 @@ class Unit(Enum):
btc = 4 btc = 4
auth = 999 auth = 999
def str(self, amount: int) -> str: def str(self, amount: int | float) -> str:
if self == Unit.sat: if self == Unit.sat:
return f"{amount} sat" return f"{amount} sat"
elif self == Unit.msat: elif self == Unit.msat:
@@ -631,6 +633,62 @@ class Amount:
def __repr__(self): def __repr__(self):
return self.unit.str(self.amount) return self.unit.str(self.amount)
def __add__(self, other: "Amount | int") -> "Amount":
if isinstance(other, int):
return Amount(self.unit, self.amount + other)
if self.unit != other.unit:
raise Exception("Units must be the same")
return Amount(self.unit, self.amount + other.amount)
def __sub__(self, other: "Amount | int") -> "Amount":
if isinstance(other, int):
return Amount(self.unit, self.amount - other)
if self.unit != other.unit:
raise Exception("Units must be the same")
return Amount(self.unit, self.amount - other.amount)
def __mul__(self, other: int) -> "Amount":
return Amount(self.unit, self.amount * other)
def __eq__(self, other: object) -> bool:
if isinstance(other, int):
return self.amount == other
if isinstance(other, Amount):
if self.unit != other.unit:
raise Exception("Units must be the same")
return self.amount == other.amount
return False
def __lt__(self, other: "Amount | int") -> bool:
if isinstance(other, int):
return self.amount < other
if self.unit != other.unit:
raise Exception("Units must be the same")
return self.amount < other.amount
def __le__(self, other: "Amount | int") -> bool:
if isinstance(other, int):
return self.amount <= other
if self.unit != other.unit:
raise Exception("Units must be the same")
return self.amount <= other.amount
def __gt__(self, other: "Amount | int") -> bool:
if isinstance(other, int):
return self.amount > other
if self.unit != other.unit:
raise Exception("Units must be the same")
return self.amount > other.amount
def __ge__(self, other: "Amount | int") -> bool:
if isinstance(other, int):
return self.amount >= other
if self.unit != other.unit:
raise Exception("Units must be the same")
return self.amount >= other.amount
class Method(Enum): class Method(Enum):
bolt11 = 0 bolt11 = 0
@@ -736,6 +794,7 @@ class MintKeyset:
first_seen: Optional[str] = None first_seen: Optional[str] = None
version: Optional[str] = None version: Optional[str] = None
amounts: List[int] amounts: List[int]
balance: int
duplicate_keyset_id: Optional[str] = None # BACKWARDS COMPATIBILITY < 0.15.0 duplicate_keyset_id: Optional[str] = None # BACKWARDS COMPATIBILITY < 0.15.0
@@ -755,6 +814,8 @@ class MintKeyset:
version: Optional[str] = None, version: Optional[str] = None,
input_fee_ppk: Optional[int] = None, input_fee_ppk: Optional[int] = None,
id: str = "", id: str = "",
balance: int = 0,
fees_paid: int = 0,
): ):
DEFAULT_SEED = "supersecretprivatekey" DEFAULT_SEED = "supersecretprivatekey"
if seed == DEFAULT_SEED: if seed == DEFAULT_SEED:
@@ -787,6 +848,8 @@ class MintKeyset:
self.first_seen = first_seen self.first_seen = first_seen
self.active = bool(active) if active is not None else False self.active = bool(active) if active is not None else False
self.version = version or settings.version self.version = version or settings.version
self.balance = balance
self.fees_paid = fees_paid
self.input_fee_ppk = input_fee_ppk or 0 self.input_fee_ppk = input_fee_ppk or 0
if self.input_fee_ppk < 0: if self.input_fee_ppk < 0:
@@ -840,6 +903,8 @@ class MintKeyset:
version=row["version"], version=row["version"],
input_fee_ppk=row["input_fee_ppk"], input_fee_ppk=row["input_fee_ppk"],
amounts=json.loads(row["amounts"]), amounts=json.loads(row["amounts"]),
balance=row["balance"],
fees_paid=row["fees_paid"],
) )
@property @property
@@ -1343,3 +1408,24 @@ class WalletMint(BaseModel):
refresh_token: Optional[str] = None refresh_token: Optional[str] = None
username: Optional[str] = None username: Optional[str] = None
password: Optional[str] = None password: Optional[str] = None
class MintBalanceLogEntry(BaseModel):
unit: Unit
backend_balance: Amount
keyset_balance: Amount
keyset_fees_paid: Amount
time: datetime.datetime
@classmethod
def from_row(cls, row: RowMapping):
return cls(
unit=Unit[row["unit"]],
backend_balance=Amount(
Unit[row["unit"]],
row["backend_balance"],
),
keyset_balance=Amount(Unit[row["unit"]], row["keyset_balance"]),
keyset_fees_paid=Amount(Unit[row["unit"]], row["keyset_fees_paid"]),
time=row["time"],
)

View File

@@ -72,11 +72,22 @@ class MintSettings(CashuSettings):
) )
class MintWatchdogSettings(MintSettings):
mint_watchdog_enabled: bool = Field(
default=False,
title="Balance watchdog",
description="The watchdog shuts down the mint if the balance of the mint and the backend do not match.",
)
mint_watchdog_balance_check_interval_seconds: float = Field(default=0.1)
mint_watchdog_ignore_mismatch: bool = Field(
default=False,
description="Ignore watchdog errors and continue running. Use this to recover from a watchdog error.",
)
class MintDeprecationFlags(MintSettings): class MintDeprecationFlags(MintSettings):
mint_inactivate_base64_keysets: bool = Field(default=False) mint_inactivate_base64_keysets: bool = Field(default=False)
auth_database: str = Field(default="data/mint")
class MintBackends(MintSettings): class MintBackends(MintSettings):
mint_lightning_backend: str = Field(default="") # deprecated mint_lightning_backend: str = Field(default="") # deprecated
@@ -153,6 +164,9 @@ class FakeWalletSettings(MintSettings):
fakewallet_payment_state_exception: Optional[bool] = Field(default=False) fakewallet_payment_state_exception: Optional[bool] = Field(default=False)
fakewallet_pay_invoice_state: Optional[str] = Field(default="SETTLED") fakewallet_pay_invoice_state: Optional[str] = Field(default="SETTLED")
fakewallet_pay_invoice_state_exception: Optional[bool] = Field(default=False) fakewallet_pay_invoice_state_exception: Optional[bool] = Field(default=False)
fakewallet_balance_sat: int = Field(default=1337)
fakewallet_balance_usd: int = Field(default=1337)
fakewallet_balance_eur: int = Field(default=1337)
class MintInformation(CashuSettings): class MintInformation(CashuSettings):
@@ -242,6 +256,7 @@ class CoreLightningRestFundingSource(MintSettings):
class AuthSettings(MintSettings): class AuthSettings(MintSettings):
mint_auth_database: str = Field(default="data/mint")
mint_require_auth: bool = Field(default=False) mint_require_auth: bool = Field(default=False)
mint_auth_oicd_discovery_url: Optional[str] = Field(default=None) mint_auth_oicd_discovery_url: Optional[str] = Field(default=None)
mint_auth_oicd_client_id: str = Field(default="cashu-client") mint_auth_oicd_client_id: str = Field(default="cashu-client")
@@ -280,6 +295,7 @@ class Settings(
AuthSettings, AuthSettings,
MintRedisCache, MintRedisCache,
MintDeprecationFlags, MintDeprecationFlags,
MintWatchdogSettings,
MintSettings, MintSettings,
MintInformation, MintInformation,
WalletSettings, WalletSettings,

View File

@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum, auto from enum import Enum, auto
from typing import AsyncGenerator, Coroutine, Optional, Union from typing import AsyncGenerator, Coroutine, Optional
from pydantic import BaseModel from pydantic import BaseModel
@@ -13,7 +13,7 @@ from ..core.models import PostMeltQuoteRequest
class StatusResponse(BaseModel): class StatusResponse(BaseModel):
balance: Union[int, float] balance: Amount
error_message: Optional[str] = None error_message: Optional[str] = None

View File

@@ -90,7 +90,7 @@ class BlinkWallet(LightningBackend):
logger.error(f"Blink API error: {exc}") logger.error(f"Blink API error: {exc}")
return StatusResponse( return StatusResponse(
error_message=f"Failed to connect to {self.endpoint} due to: {exc}", error_message=f"Failed to connect to {self.endpoint} due to: {exc}",
balance=0, balance=Amount(self.unit, 0),
) )
try: try:
@@ -100,7 +100,7 @@ class BlinkWallet(LightningBackend):
error_message=( error_message=(
f"Received invalid response from {self.endpoint}: {r.text}" f"Received invalid response from {self.endpoint}: {r.text}"
), ),
balance=0, balance=Amount(self.unit, 0),
) )
balance = 0 balance = 0
@@ -113,7 +113,7 @@ class BlinkWallet(LightningBackend):
self.wallet_ids[Unit.sat] = wallet_dict["id"] # type: ignore self.wallet_ids[Unit.sat] = wallet_dict["id"] # type: ignore
balance = wallet_dict["balance"] # type: ignore balance = wallet_dict["balance"] # type: ignore
return StatusResponse(error_message=None, balance=balance) return StatusResponse(error_message=None, balance=Amount(self.unit, balance))
async def create_invoice( async def create_invoice(
self, self,

View File

@@ -103,14 +103,14 @@ class CLNRestWallet(LightningBackend):
error_message=( error_message=(
f"Failed to connect to {self.url}, got: '{error_message}...'" f"Failed to connect to {self.url}, got: '{error_message}...'"
), ),
balance=0, balance=Amount(self.unit, 0),
) )
data = r.json() data = r.json()
if len(data) == 0: if len(data) == 0:
return StatusResponse(error_message="no data", balance=0) return StatusResponse(error_message="no data", balance=Amount(self.unit, 0))
balance_msat = int(sum([c["our_amount_msat"] for c in data["channels"]])) balance_msat = int(sum([c["our_amount_msat"] for c in data["channels"]]))
return StatusResponse(balance=balance_msat) return StatusResponse(balance=Amount(self.unit, balance_msat // 1000))
async def create_invoice( async def create_invoice(
self, self,
@@ -289,7 +289,15 @@ class CLNRestWallet(LightningBackend):
data = r.json() data = r.json()
if r.is_error or "message" in data: if r.is_error or "message" in data:
raise Exception("error in cln response") raise Exception("error in cln response")
self.last_pay_index = data["invoices"][-1]["pay_index"] last_invoice_paid_invoice = next(
(i for i in reversed(data["invoices"]) if i["status"] == "paid"), None
)
last_pay_index = (
last_invoice_paid_invoice.get("pay_index")
if last_invoice_paid_invoice
else 0
)
self.last_pay_index = last_pay_index
while True: while True:
try: try:
url = "/v1/waitanyinvoice" url = "/v1/waitanyinvoice"
@@ -308,9 +316,13 @@ class CLNRestWallet(LightningBackend):
raise Exception(inv["message"]) raise Exception(inv["message"])
try: try:
paid = inv["status"] == "paid" paid = inv["status"] == "paid"
self.last_pay_index = inv["pay_index"]
if not paid: if not paid:
continue continue
last_pay_index = inv.get("pay_index")
if not last_pay_index:
logger.error(f"missing pay_index in invoice: {inv}")
raise Exception("missing pay_index in invoice")
self.last_pay_index = last_pay_index
except Exception as e: except Exception as e:
logger.error(f"Error in paid_invoices_stream: {e}") logger.error(f"Error in paid_invoices_stream: {e}")
continue continue
@@ -332,8 +344,8 @@ class CLNRestWallet(LightningBackend):
invoice_obj = decode(melt_quote.request) invoice_obj = decode(melt_quote.request)
assert invoice_obj.amount_msat, "invoice has no amount." assert invoice_obj.amount_msat, "invoice has no amount."
assert invoice_obj.amount_msat > 0, "invoice has 0 amount." assert invoice_obj.amount_msat > 0, "invoice has 0 amount."
amount_msat = melt_quote.mpp_amount if melt_quote.is_mpp else ( amount_msat = (
invoice_obj.amount_msat melt_quote.mpp_amount if melt_quote.is_mpp else (invoice_obj.amount_msat)
) )
fees_msat = fee_reserve(amount_msat) fees_msat = fee_reserve(amount_msat)
fees = Amount(unit=Unit.msat, amount=fees_msat) fees = Amount(unit=Unit.msat, amount=fees_msat)

View File

@@ -96,14 +96,16 @@ class CoreLightningRestWallet(LightningBackend):
error_message=( error_message=(
f"Failed to connect to {self.url}, got: '{error_message}...'" f"Failed to connect to {self.url}, got: '{error_message}...'"
), ),
balance=0, balance=Amount(self.unit, 0),
) )
data = r.json() data = r.json()
if len(data) == 0: if len(data) == 0:
return StatusResponse(error_message="no data", balance=0) return StatusResponse(error_message="no data", balance=Amount(self.unit, 0))
balance_msat = int(sum([c["our_amount_msat"] for c in data["channels"]])) balance_msat = int(sum([c["our_amount_msat"] for c in data["channels"]]))
return StatusResponse(error_message=None, balance=balance_msat) return StatusResponse(
error_message=None, balance=Amount(self.unit, balance_msat // 1000)
)
async def create_invoice( async def create_invoice(
self, self,
@@ -271,9 +273,15 @@ class CoreLightningRestWallet(LightningBackend):
data = r.json() data = r.json()
if r.is_error or "error" in data: if r.is_error or "error" in data:
raise Exception("error in cln response") raise Exception("error in cln response")
if data.get("invoices"): last_invoice_paid_invoice = next(
self.last_pay_index = data["invoices"][-1]["pay_index"] (i for i in reversed(data["invoices"]) if i["status"] == "paid"), None
)
last_pay_index = (
last_invoice_paid_invoice.get("pay_index")
if last_invoice_paid_invoice
else 0
)
self.last_pay_index = last_pay_index
while True: while True:
try: try:
url = f"/v1/invoice/waitAnyInvoice/{self.last_pay_index}" url = f"/v1/invoice/waitAnyInvoice/{self.last_pay_index}"
@@ -285,9 +293,9 @@ class CoreLightningRestWallet(LightningBackend):
raise Exception(inv["error"]["message"]) raise Exception(inv["error"]["message"])
try: try:
paid = inv["status"] == "paid" paid = inv["status"] == "paid"
self.last_pay_index = inv["pay_index"]
if not paid: if not paid:
continue continue
self.last_pay_index = inv["pay_index"]
except Exception: except Exception:
continue continue
logger.trace(f"paid invoice: {inv}") logger.trace(f"paid invoice: {inv}")

View File

@@ -50,6 +50,12 @@ class FakeWallet(LightningBackend):
).hex() ).hex()
supported_units = {Unit.sat, Unit.msat, Unit.usd, Unit.eur} supported_units = {Unit.sat, Unit.msat, Unit.usd, Unit.eur}
balance: Dict[Unit, Amount] = {
Unit.sat: Amount(Unit.sat, settings.fakewallet_balance_sat),
Unit.msat: Amount(Unit.msat, settings.fakewallet_balance_sat * 1000),
Unit.usd: Amount(Unit.usd, settings.fakewallet_balance_usd),
Unit.eur: Amount(Unit.eur, settings.fakewallet_balance_eur),
}
supports_incoming_payment_stream: bool = True supports_incoming_payment_stream: bool = True
supports_description: bool = True supports_description: bool = True
@@ -59,7 +65,10 @@ class FakeWallet(LightningBackend):
self.unit = unit self.unit = unit
async def status(self) -> StatusResponse: async def status(self) -> StatusResponse:
return StatusResponse(error_message=None, balance=1337) return StatusResponse(
error_message=None,
balance=Amount(self.unit, self.balance[self.unit].amount),
)
async def mark_invoice_paid(self, invoice: Bolt11, delay=True) -> None: async def mark_invoice_paid(self, invoice: Bolt11, delay=True) -> None:
if invoice in self.paid_invoices_incoming: if invoice in self.paid_invoices_incoming:
@@ -70,6 +79,25 @@ class FakeWallet(LightningBackend):
await asyncio.sleep(settings.fakewallet_delay_incoming_payment) await asyncio.sleep(settings.fakewallet_delay_incoming_payment)
self.paid_invoices_incoming.append(invoice) self.paid_invoices_incoming.append(invoice)
await self.paid_invoices_queue.put(invoice) await self.paid_invoices_queue.put(invoice)
self.update_balance(invoice, incoming=True)
def update_balance(self, invoice: Bolt11, incoming: bool) -> None:
amount_bolt11 = invoice.amount_msat
assert amount_bolt11, "invoice has no amount."
amount = int(amount_bolt11)
if self.unit == Unit.sat:
amount = amount // 1000
elif self.unit == Unit.usd or self.unit == Unit.eur:
amount = math.ceil(amount / 1e9 * self.fake_btc_price)
elif self.unit == Unit.msat:
amount = amount
else:
raise NotImplementedError()
if incoming:
self.balance[self.unit] += Amount(self.unit, amount)
else:
self.balance[self.unit] -= Amount(self.unit, amount)
def create_dummy_bolt11(self, payment_hash: str) -> Bolt11: def create_dummy_bolt11(self, payment_hash: str) -> Bolt11:
tags = Tags() tags = Tags()
@@ -165,6 +193,8 @@ class FakeWallet(LightningBackend):
await asyncio.sleep(settings.fakewallet_delay_outgoing_payment) await asyncio.sleep(settings.fakewallet_delay_outgoing_payment)
if settings.fakewallet_pay_invoice_state: if settings.fakewallet_pay_invoice_state:
if settings.fakewallet_pay_invoice_state == "SETTLED":
self.update_balance(invoice, incoming=False)
return PaymentResponse( return PaymentResponse(
result=PaymentResult[settings.fakewallet_pay_invoice_state], result=PaymentResult[settings.fakewallet_pay_invoice_state],
checking_id=invoice.payment_hash, checking_id=invoice.payment_hash,
@@ -178,6 +208,7 @@ class FakeWallet(LightningBackend):
else: else:
raise ValueError("Invoice already paid") raise ValueError("Invoice already paid")
self.update_balance(invoice, incoming=False)
return PaymentResponse( return PaymentResponse(
result=PaymentResult.SETTLED, result=PaymentResult.SETTLED,
checking_id=invoice.payment_hash, checking_id=invoice.payment_hash,
@@ -191,9 +222,13 @@ class FakeWallet(LightningBackend):
) )
async def get_invoice_status(self, checking_id: str) -> PaymentStatus: async def get_invoice_status(self, checking_id: str) -> PaymentStatus:
await self.mark_invoice_paid(self.create_dummy_bolt11(checking_id), delay=False) invoice = next(
(i for i in self.created_invoices if i.payment_hash == checking_id), None
) or self.create_dummy_bolt11(checking_id)
paid_chceking_ids = [i.payment_hash for i in self.paid_invoices_incoming] paid_chceking_ids = [i.payment_hash for i in self.paid_invoices_incoming]
if checking_id in paid_chceking_ids: if checking_id in paid_chceking_ids or settings.fakewallet_brr:
await self.mark_invoice_paid(invoice, delay=False)
return PaymentStatus(result=PaymentResult.SETTLED) return PaymentStatus(result=PaymentResult.SETTLED)
else: else:
return PaymentStatus( return PaymentStatus(

View File

@@ -48,14 +48,17 @@ class LNbitsWallet(LightningBackend):
except Exception as exc: except Exception as exc:
return StatusResponse( return StatusResponse(
error_message=f"Failed to connect to {self.endpoint} due to: {exc}", error_message=f"Failed to connect to {self.endpoint} due to: {exc}",
balance=0, balance=Amount(self.unit, 0),
) )
if data.get("detail"): if data.get("detail"):
return StatusResponse( return StatusResponse(
error_message=f"LNbits error: {data['detail']}", balance=0 error_message=f"LNbits error: {data['detail']}",
balance=Amount(self.unit, 0),
) )
return StatusResponse(error_message=None, balance=data["balance"]) return StatusResponse(
error_message=None, balance=Amount(Unit.sat, data["balance"] // 1000)
)
async def create_invoice( async def create_invoice(
self, self,

View File

@@ -103,10 +103,10 @@ class LndRPCWallet(LightningBackend):
r = await lnstub.ChannelBalance(lnrpc.ChannelBalanceRequest()) r = await lnstub.ChannelBalance(lnrpc.ChannelBalanceRequest())
except AioRpcError as e: except AioRpcError as e:
return StatusResponse( return StatusResponse(
error_message=f"Error calling Lnd gRPC: {e}", balance=0 error_message=f"Error calling Lnd gRPC: {e}",
balance=Amount(self.unit, 0),
) )
# NOTE: `balance` field is deprecated. Change this. return StatusResponse(error_message=None, balance=Amount(self.unit, r.balance))
return StatusResponse(error_message=None, balance=r.balance * 1000)
async def create_invoice( async def create_invoice(
self, self,

View File

@@ -112,7 +112,7 @@ class LndRestWallet(LightningBackend):
except (httpx.ConnectError, httpx.RequestError) as exc: except (httpx.ConnectError, httpx.RequestError) as exc:
return StatusResponse( return StatusResponse(
error_message=f"Unable to connect to {self.endpoint}. {exc}", error_message=f"Unable to connect to {self.endpoint}. {exc}",
balance=0, balance=Amount(self.unit, 0),
) )
try: try:
@@ -120,9 +120,13 @@ class LndRestWallet(LightningBackend):
if r.is_error: if r.is_error:
raise Exception raise Exception
except Exception: except Exception:
return StatusResponse(error_message=r.text[:200], balance=0) return StatusResponse(
error_message=r.text[:200], balance=Amount(self.unit, 0)
)
return StatusResponse(error_message=None, balance=int(data["balance"]) * 1000) return StatusResponse(
error_message=None, balance=Amount(self.unit, int(data["balance"]))
)
async def create_invoice( async def create_invoice(
self, self,

View File

@@ -128,7 +128,7 @@ class StrikeWallet(LightningBackend):
except Exception as exc: except Exception as exc:
return StatusResponse( return StatusResponse(
error_message=f"Failed to connect to {self.endpoint} due to: {exc}", error_message=f"Failed to connect to {self.endpoint} due to: {exc}",
balance=0, balance=Amount(self.unit, 0),
) )
try: try:
@@ -138,16 +138,14 @@ class StrikeWallet(LightningBackend):
error_message=( error_message=(
f"Failed to connect to {self.endpoint}, got: '{r.text[:200]}...'" f"Failed to connect to {self.endpoint}, got: '{r.text[:200]}...'"
), ),
balance=0, balance=Amount(self.unit, 0),
) )
for balance in data: for balance in data:
if balance["currency"] == self.currency: if balance["currency"] == self.currency:
return StatusResponse( return StatusResponse(
error_message=None, error_message=None,
balance=Amount.from_float( balance=Amount.from_float(float(balance["total"]), self.unit),
float(balance["total"]), self.unit
).amount,
) )
# if no the unit is USD but no USD balance was found, we try USDT # if no the unit is USD but no USD balance was found, we try USDT
@@ -157,14 +155,12 @@ class StrikeWallet(LightningBackend):
self.currency = USDT self.currency = USDT
return StatusResponse( return StatusResponse(
error_message=None, error_message=None,
balance=Amount.from_float( balance=Amount.from_float(float(balance["total"]), self.unit),
float(balance["total"]), self.unit
).amount,
) )
return StatusResponse( return StatusResponse(
error_message=f"Could not find balance for currency {self.currency}", error_message=f"Could not find balance for currency {self.currency}",
balance=0, balance=Amount(self.unit, 0),
) )
async def create_invoice( async def create_invoice(

View File

@@ -98,3 +98,19 @@ async def m001_initial(db: Database):
); );
""" """
) )
async def m002_add_balance_to_keysets_and_log_table(db: Database):
async with db.connect() as conn:
await conn.execute(
f"""
ALTER TABLE {db.table_with_schema('keysets')}
ADD COLUMN balance INTEGER NOT NULL DEFAULT 0
"""
)
await conn.execute(
f"""
ALTER TABLE {db.table_with_schema('keysets')}
ADD COLUMN fees_paid INTEGER NOT NULL DEFAULT 0
"""
)

View File

@@ -62,6 +62,9 @@ class AuthLedger(Ledger):
logger.info(f"Initialized OpenID Connect: {self.issuer}") logger.info(f"Initialized OpenID Connect: {self.issuer}")
def _get_oicd_discovery_json(self) -> dict: def _get_oicd_discovery_json(self) -> dict:
logger.debug(
f"Getting OpenID Connect discovery JSON from: {self.oicd_discovery_url}"
)
resp = httpx.get(self.oicd_discovery_url) resp = httpx.get(self.oicd_discovery_url)
resp.raise_for_status() resp.raise_for_status()
return resp.json() return resp.json()
@@ -220,7 +223,9 @@ class AuthLedger(Ledger):
try: try:
proof = AuthProof.from_base64(blind_auth_token).to_proof() proof = AuthProof.from_base64(blind_auth_token).to_proof()
await self.verify_inputs_and_outputs(proofs=[proof]) await self.verify_inputs_and_outputs(proofs=[proof])
await self.db_write._verify_spent_proofs_and_set_pending([proof]) await self.db_write._verify_spent_proofs_and_set_pending(
[proof], self.keysets
)
except Exception as e: except Exception as e:
logger.error(f"Blind auth error: {e}") logger.error(f"Blind auth error: {e}")
raise BlindAuthFailedError() raise BlindAuthFailedError()
@@ -232,4 +237,4 @@ class AuthLedger(Ledger):
logger.error(f"Blind auth error: {e}") logger.error(f"Blind auth error: {e}")
raise BlindAuthFailedError() raise BlindAuthFailedError()
finally: finally:
await self.db_write._unset_proofs_pending([proof]) await self.db_write._unset_proofs_pending([proof], self.keysets)

View File

@@ -1,15 +1,18 @@
import json import json
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional, Tuple
from loguru import logger from loguru import logger
from ..core.base import ( from ..core.base import (
Amount,
BlindedSignature, BlindedSignature,
MeltQuote, MeltQuote,
MintBalanceLogEntry,
MintKeyset, MintKeyset,
MintQuote, MintQuote,
Proof, Proof,
Unit,
) )
from ..core.db import ( from ..core.db import (
Connection, Connection,
@@ -31,6 +34,7 @@ class LedgerCrud(ABC):
*, *,
db: Database, db: Database,
id: str = "", id: str = "",
unit: str = "",
derivation_path: str = "", derivation_path: str = "",
seed: str = "", seed: str = "",
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
@@ -118,13 +122,33 @@ class LedgerCrud(ABC):
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> None: ... ) -> None: ...
@abstractmethod
async def bump_keyset_balance(
self,
*,
db: Database,
keyset: MintKeyset,
amount: int,
conn: Optional[Connection] = None,
) -> None: ...
@abstractmethod
async def bump_keyset_fees_paid(
self,
*,
db: Database,
keyset: MintKeyset,
amount: int,
conn: Optional[Connection] = None,
) -> None: ...
@abstractmethod @abstractmethod
async def get_balance( async def get_balance(
self, self,
keyset: MintKeyset, keyset: MintKeyset,
db: Database, db: Database,
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> int: ... ) -> Tuple[Amount, Amount]: ...
@abstractmethod @abstractmethod
async def store_promise( async def store_promise(
@@ -234,6 +258,25 @@ class LedgerCrud(ABC):
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> None: ... ) -> None: ...
@abstractmethod
async def store_balance_log(
self,
backend_balance: Amount,
keyset_balance: Amount,
keyset_fees_paid: Amount,
db: Database,
conn: Optional[Connection] = None,
) -> None: ...
@abstractmethod
async def get_last_balance_log_entry(
self,
*,
unit: Unit,
db: Database,
conn: Optional[Connection] = None,
) -> MintBalanceLogEntry | None: ...
class LedgerCrudSqlite(LedgerCrud): class LedgerCrudSqlite(LedgerCrud):
"""Implementation of LedgerCrud for sqlite. """Implementation of LedgerCrud for sqlite.
@@ -645,8 +688,8 @@ class LedgerCrudSqlite(LedgerCrud):
await (conn or db).execute( await (conn or db).execute(
f""" f"""
INSERT INTO {db.table_with_schema('keysets')} INSERT INTO {db.table_with_schema('keysets')}
(id, seed, encrypted_seed, seed_encryption_method, derivation_path, valid_from, valid_to, first_seen, active, version, unit, input_fee_ppk, amounts) (id, seed, encrypted_seed, seed_encryption_method, derivation_path, valid_from, valid_to, first_seen, active, version, unit, input_fee_ppk, amounts, balance)
VALUES (:id, :seed, :encrypted_seed, :seed_encryption_method, :derivation_path, :valid_from, :valid_to, :first_seen, :active, :version, :unit, :input_fee_ppk, :amounts) VALUES (:id, :seed, :encrypted_seed, :seed_encryption_method, :derivation_path, :valid_from, :valid_to, :first_seen, :active, :version, :unit, :input_fee_ppk, :amounts, :balance)
""", """,
{ {
"id": keyset.id, "id": keyset.id,
@@ -666,31 +709,66 @@ class LedgerCrudSqlite(LedgerCrud):
"unit": keyset.unit.name, "unit": keyset.unit.name,
"input_fee_ppk": keyset.input_fee_ppk, "input_fee_ppk": keyset.input_fee_ppk,
"amounts": json.dumps(keyset.amounts), "amounts": json.dumps(keyset.amounts),
"balance": keyset.balance,
}, },
) )
async def bump_keyset_balance(
self,
*,
db: Database,
keyset: MintKeyset,
amount: int,
conn: Optional[Connection] = None,
) -> None:
await (conn or db).execute(
f"""
UPDATE {db.table_with_schema('keysets')}
SET balance = balance + :amount
WHERE id = :id
""",
{"amount": amount, "id": keyset.id},
)
async def bump_keyset_fees_paid(
self,
*,
db: Database,
keyset: MintKeyset,
amount: int,
conn: Optional[Connection] = None,
) -> None:
await (conn or db).execute(
f"""
UPDATE {db.table_with_schema('keysets')}
SET fees_paid = fees_paid + :amount
WHERE id = :id
""",
{"amount": amount, "id": keyset.id},
)
async def get_balance( async def get_balance(
self, self,
keyset: MintKeyset, keyset: MintKeyset,
db: Database, db: Database,
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> int: ) -> Tuple[Amount, Amount]:
row = await (conn or db).fetchone( row = await (conn or db).fetchone(
f""" f"""
SELECT balance FROM {db.table_with_schema('balance')} SELECT balance, fees_paid FROM {db.table_with_schema('keysets')}
WHERE keyset = :keyset WHERE id = :id
""", """,
{ {
"keyset": keyset.id, "id": keyset.id,
}, },
) )
if row is None: if row is None:
return 0 return Amount(keyset.unit, 0), Amount(keyset.unit, 0)
# sqlalchemy index of first element return Amount(keyset.unit, int(row["balance"])), Amount(
key = next(iter(row)) keyset.unit, int(row["fees_paid"])
return int(row[key]) )
async def get_keyset( async def get_keyset(
self, self,
@@ -764,6 +842,7 @@ class LedgerCrudSqlite(LedgerCrud):
"version": keyset.version, "version": keyset.version,
"unit": keyset.unit.name, "unit": keyset.unit.name,
"input_fee_ppk": keyset.input_fee_ppk, "input_fee_ppk": keyset.input_fee_ppk,
"balance": keyset.balance,
}, },
) )
@@ -781,3 +860,48 @@ class LedgerCrudSqlite(LedgerCrud):
values = {f"y_{i}": Ys[i] for i in range(len(Ys))} values = {f"y_{i}": Ys[i] for i in range(len(Ys))}
rows = await (conn or db).fetchall(query, values) rows = await (conn or db).fetchall(query, values)
return [Proof(**r) for r in rows] if rows else [] return [Proof(**r) for r in rows] if rows else []
async def store_balance_log(
self,
backend_balance: Amount,
keyset_balance: Amount,
keyset_fees_paid: Amount,
db: Database,
conn: Optional[Connection] = None,
):
if backend_balance.unit != keyset_balance.unit:
raise ValueError("Units do not match")
await (conn or db).execute(
f"""
INSERT INTO {db.table_with_schema('balance_log')}
(unit, backend_balance, keyset_balance, keyset_fees_paid, time)
VALUES (:unit, :backend_balance, :keyset_balance, :keyset_fees_paid, :time)
""",
{
"unit": backend_balance.unit.name,
"backend_balance": backend_balance.amount,
"keyset_balance": keyset_balance.amount,
"keyset_fees_paid": keyset_fees_paid.amount,
"time": db.to_timestamp(db.timestamp_now_str()),
},
)
async def get_last_balance_log_entry(
self,
*,
unit: Unit,
db: Database,
conn: Optional[Connection] = None,
) -> MintBalanceLogEntry | None:
row = await (conn or db).fetchone(
f"""
SELECT * from {db.table_with_schema('balance_log')}
WHERE unit = :unit
ORDER BY time DESC
LIMIT 1
""",
{"unit": unit.name},
)
return MintBalanceLogEntry.from_row(row) if row else None

View File

@@ -1,4 +1,4 @@
from typing import List, Optional, Union from typing import Dict, List, Optional, Union
from loguru import logger from loguru import logger
@@ -6,6 +6,7 @@ from ...core.base import (
BlindedMessage, BlindedMessage,
MeltQuote, MeltQuote,
MeltQuoteState, MeltQuoteState,
MintKeyset,
MintQuote, MintQuote,
MintQuoteState, MintQuoteState,
Proof, Proof,
@@ -40,13 +41,17 @@ class DbWriteHelper:
self.db_read = db_read self.db_read = db_read
async def _verify_spent_proofs_and_set_pending( async def _verify_spent_proofs_and_set_pending(
self, proofs: List[Proof], quote_id: Optional[str] = None self,
proofs: List[Proof],
keysets: Dict[str, MintKeyset],
quote_id: Optional[str] = None,
) -> None: ) -> None:
""" """
Method to check if proofs are already spent. If they are not spent, we check if they are pending. Method to check if proofs are already spent. If they are not spent, we check if they are pending.
If they are not pending, we set them as pending. If they are not pending, we set them as pending.
Args: Args:
proofs (List[Proof]): Proofs to add to pending table. proofs (List[Proof]): Proofs to add to pending table.
keysets (Dict[str, MintKeyset]): Keysets of the mint (needed to update keyset balances)
quote_id (Optional[str]): Melt quote ID. If it is not set, we assume the pending tokens to be from a swap. quote_id (Optional[str]): Melt quote ID. If it is not set, we assume the pending tokens to be from a swap.
Raises: Raises:
TransactionError: If any one of the proofs is already spent or pending. TransactionError: If any one of the proofs is already spent or pending.
@@ -67,6 +72,12 @@ class DbWriteHelper:
await self.crud.set_proof_pending( await self.crud.set_proof_pending(
proof=p, db=self.db, quote_id=quote_id, conn=conn proof=p, db=self.db, quote_id=quote_id, conn=conn
) )
await self.crud.bump_keyset_balance(
db=self.db,
keyset=keysets[p.id],
amount=-p.amount,
conn=conn,
)
logger.trace(f"crud: set proof {p.Y} as PENDING") logger.trace(f"crud: set proof {p.Y} as PENDING")
logger.trace("_verify_spent_proofs_and_set_pending released lock") logger.trace("_verify_spent_proofs_and_set_pending released lock")
except Exception as e: except Exception as e:
@@ -75,20 +86,34 @@ class DbWriteHelper:
for p in proofs: for p in proofs:
await self.events.submit(ProofState(Y=p.Y, state=ProofSpentState.pending)) await self.events.submit(ProofState(Y=p.Y, state=ProofSpentState.pending))
async def _unset_proofs_pending(self, proofs: List[Proof], spent=True) -> None: async def _unset_proofs_pending(
self,
proofs: List[Proof],
keysets: Dict[str, MintKeyset],
spent=True,
conn: Optional[Connection] = None,
) -> None:
"""Deletes proofs from pending table. """Deletes proofs from pending table.
Args: Args:
proofs (List[Proof]): Proofs to delete. proofs (List[Proof]): Proofs to delete.
keysets (Dict[str, MintKeyset]): Keysets of the mint (needed to update keyset balances)
spent (bool): Whether the proofs have been spent or not. Defaults to True. spent (bool): Whether the proofs have been spent or not. Defaults to True.
This should be False if the proofs were NOT invalidated before calling this function. This should be False if the proofs were NOT invalidated before calling this function.
It is used to emit the unspent state for the proofs (otherwise the spent state is emitted It is used to emit the unspent state for the proofs (otherwise the spent state is emitted
by the _invalidate_proofs function when the proofs are spent). by the _invalidate_proofs function when the proofs are spent).
conn (Optional[Connection]): Connection to use. If not set, a new connection will be created.
""" """
async with self.db.get_connection() as conn: async with self.db.get_connection(conn) as conn:
for p in proofs: for p in proofs:
logger.trace(f"crud: un-setting proof {p.Y} as PENDING") logger.trace(f"crud: un-setting proof {p.Y} as PENDING")
await self.crud.unset_proof_pending(proof=p, db=self.db, conn=conn) await self.crud.unset_proof_pending(proof=p, db=self.db, conn=conn)
await self.crud.bump_keyset_balance(
db=self.db,
keyset=keysets[p.id],
amount=p.amount,
conn=conn,
)
if not spent: if not spent:
for p in proofs: for p in proofs:

View File

@@ -11,7 +11,6 @@ from .protocols import SupportsDb, SupportsKeysets, SupportsSeed
class LedgerKeysets(SupportsKeysets, SupportsSeed, SupportsDb): class LedgerKeysets(SupportsKeysets, SupportsSeed, SupportsDb):
# ------- KEYS ------- # ------- KEYS -------
def maybe_update_derivation_path(self, derivation_path: str) -> str: def maybe_update_derivation_path(self, derivation_path: str) -> str:
@@ -20,12 +19,14 @@ class LedgerKeysets(SupportsKeysets, SupportsSeed, SupportsDb):
upon initialization. The superseding derivation must have a greater count (last portion of the derivation path). upon initialization. The superseding derivation must have a greater count (last portion of the derivation path).
If this condition is true, update `self.derivation_path` to match the highest count derivation. If this condition is true, update `self.derivation_path` to match the highest count derivation.
""" """
derivation: List[str] = derivation_path.split("/") # type: ignore derivation: List[str] = derivation_path.split("/") # type: ignore
counter = int(derivation[-1].replace("'", "")) counter = int(derivation[-1].replace("'", ""))
for keyset in self.keysets.values(): for keyset in self.keysets.values():
if keyset.active: if keyset.active:
keyset_derivation_path = keyset.derivation_path.split("/") keyset_derivation_path = keyset.derivation_path.split("/")
keyset_derivation_counter = int(keyset_derivation_path[-1].replace("'", "")) keyset_derivation_counter = int(
keyset_derivation_path[-1].replace("'", "")
)
if ( if (
keyset_derivation_path[:-1] == derivation[:-1] keyset_derivation_path[:-1] == derivation[:-1]
and keyset_derivation_counter > counter and keyset_derivation_counter > counter
@@ -34,10 +35,7 @@ class LedgerKeysets(SupportsKeysets, SupportsSeed, SupportsDb):
return derivation_path return derivation_path
async def rotate_next_keyset( async def rotate_next_keyset(
self, self, unit: Unit, max_order: Optional[int], input_fee_ppk: Optional[int]
unit: Unit,
max_order: Optional[int],
input_fee_ppk: Optional[int]
) -> MintKeyset: ) -> MintKeyset:
""" """
This function: This function:
@@ -46,7 +44,7 @@ class LedgerKeysets(SupportsKeysets, SupportsSeed, SupportsDb):
3. creates a new active keyset for the new derivation path 3. creates a new active keyset for the new derivation path
4. de-activates the old keyset 4. de-activates the old keyset
5. stores the new keyset to DB 5. stores the new keyset to DB
Args: Args:
unit (Unit): Unit of the keyset. unit (Unit): Unit of the keyset.
max_order (Optional[int], optional): The number of keys to generate, which correspond to powers of 2. max_order (Optional[int], optional): The number of keys to generate, which correspond to powers of 2.
@@ -63,21 +61,29 @@ class LedgerKeysets(SupportsKeysets, SupportsSeed, SupportsDb):
for keyset in self.keysets.values(): for keyset in self.keysets.values():
if keyset.active and keyset.unit == unit: if keyset.active and keyset.unit == unit:
keyset_derivation_path = keyset.derivation_path.split("/") keyset_derivation_path = keyset.derivation_path.split("/")
keyset_derivation_counter = int(keyset_derivation_path[-1].replace("'", "")) keyset_derivation_counter = int(
keyset_derivation_path[-1].replace("'", "")
)
if keyset_derivation_counter > selected_keyset_counter: if keyset_derivation_counter > selected_keyset_counter:
selected_keyset = keyset selected_keyset = keyset
# If no selected keyset, then there is no keyset for this unit # If no selected keyset, then there is no keyset for this unit
if not selected_keyset: if not selected_keyset:
logger.error(f"Couldn't find suitable keyset for rotation with unit {str(unit)}") logger.error(
raise Exception(f"Couldn't find suitable keyset for rotation with unit {str(unit)}") f"Couldn't find suitable keyset for rotation with unit {str(unit)}"
)
raise Exception(
f"Couldn't find suitable keyset for rotation with unit {str(unit)}"
)
logger.info(f"Rotating keyset {selected_keyset.id}") logger.info(f"Rotating keyset {selected_keyset.id}")
# New derivation path is just old derivation path with increased counter # New derivation path is just old derivation path with increased counter
new_derivation_path = selected_keyset.derivation_path.split("/") new_derivation_path = selected_keyset.derivation_path.split("/")
new_derivation_path[-1] = str(int(new_derivation_path[-1].replace("'", "")) + 1) + "'" new_derivation_path[-1] = (
str(int(new_derivation_path[-1].replace("'", "")) + 1) + "'"
)
# keys amounts for this keyset: if amounts is None we use `self.amounts` # keys amounts for this keyset: if amounts is None we use `self.amounts`
amounts = [2**i for i in range(max_order)] if max_order else self.amounts amounts = [2**i for i in range(max_order)] if max_order else self.amounts
@@ -86,7 +92,7 @@ class LedgerKeysets(SupportsKeysets, SupportsSeed, SupportsDb):
derivation_path="/".join(new_derivation_path), derivation_path="/".join(new_derivation_path),
seed=self.seed, seed=self.seed,
amounts=amounts, amounts=amounts,
input_fee_ppk=input_fee_ppk input_fee_ppk=input_fee_ppk,
) )
logger.debug(f"New keyset was generated with Id {new_keyset.id}. Saving...") logger.debug(f"New keyset was generated with Id {new_keyset.id}. Saving...")
@@ -191,7 +197,7 @@ class LedgerKeysets(SupportsKeysets, SupportsSeed, SupportsDb):
# Check if any of the loaded keysets marked as active # Check if any of the loaded keysets marked as active
# do supersede the one specified in the derivation settings. # do supersede the one specified in the derivation settings.
# If this is the case update to latest count derivation. # If this is the case update to latest count derivation.
self.derivation_path = self.maybe_update_derivation_path(self.derivation_path) # type: ignore self.derivation_path = self.maybe_update_derivation_path(self.derivation_path) # type: ignore
# activate the current keyset set by self.derivation_path # activate the current keyset set by self.derivation_path
# and self.derivation_path is not superseded by any other # and self.derivation_path is not superseded by any other
@@ -248,4 +254,4 @@ class LedgerKeysets(SupportsKeysets, SupportsSeed, SupportsDb):
keyset = self.keysets[keyset_id] if keyset_id else self.keyset keyset = self.keysets[keyset_id] if keyset_id else self.keyset
if not keyset.public_keys: if not keyset.public_keys:
raise KeysetError("no public keys for this keyset") raise KeysetError("no public keys for this keyset")
return {a: p.serialize().hex() for a, p in keyset.public_keys.items()} return {a: p.serialize().hex() for a, p in keyset.public_keys.items()}

View File

@@ -64,6 +64,7 @@ from .features import LedgerFeatures
from .keysets import LedgerKeysets from .keysets import LedgerKeysets
from .tasks import LedgerTasks from .tasks import LedgerTasks
from .verification import LedgerVerification from .verification import LedgerVerification
from .watchdog import LedgerWatchdog
class Ledger( class Ledger(
@@ -71,13 +72,17 @@ class Ledger(
LedgerSpendingConditions, LedgerSpendingConditions,
LedgerTasks, LedgerTasks,
LedgerFeatures, LedgerFeatures,
LedgerWatchdog,
LedgerKeysets, LedgerKeysets,
): ):
backends: Mapping[Method, Mapping[Unit, LightningBackend]] = {} backends: Mapping[Method, Mapping[Unit, LightningBackend]] = {}
keysets: Dict[str, MintKeyset] = {} keysets: Dict[str, MintKeyset] = {}
events = LedgerEventManager() events = LedgerEventManager()
db: Database
db_read: DbReadHelper db_read: DbReadHelper
db_write: DbWriteHelper
invoice_listener_tasks: List[asyncio.Task] = [] invoice_listener_tasks: List[asyncio.Task] = []
watchdog_tasks: List[asyncio.Task] = []
disable_melt: bool = False disable_melt: bool = False
pubkey: PublicKey pubkey: PublicKey
@@ -98,6 +103,7 @@ class Ledger(
self.db_read: DbReadHelper self.db_read: DbReadHelper
self.locks: Dict[str, asyncio.Lock] = {} # holds multiprocessing locks self.locks: Dict[str, asyncio.Lock] = {} # holds multiprocessing locks
self.invoice_listener_tasks: List[asyncio.Task] = [] self.invoice_listener_tasks: List[asyncio.Task] = []
self.watchdog_tasks: List[asyncio.Task] = []
self.regular_tasks: List[asyncio.Task] = [] self.regular_tasks: List[asyncio.Task] = []
if not seed: if not seed:
@@ -131,6 +137,8 @@ class Ledger(
self.db_read = DbReadHelper(self.db, self.crud) self.db_read = DbReadHelper(self.db, self.crud)
self.db_write = DbWriteHelper(self.db, self.crud, self.events, self.db_read) self.db_write = DbWriteHelper(self.db, self.crud, self.events, self.db_read)
LedgerWatchdog.__init__(self)
# ------- STARTUP ------- # ------- STARTUP -------
async def startup_ledger(self) -> None: async def startup_ledger(self) -> None:
@@ -138,6 +146,8 @@ class Ledger(
await self._check_backends() await self._check_backends()
self.regular_tasks.append(asyncio.create_task(self._run_regular_tasks())) self.regular_tasks.append(asyncio.create_task(self._run_regular_tasks()))
self.invoice_listener_tasks = await self.dispatch_listeners() self.invoice_listener_tasks = await self.dispatch_listeners()
if settings.mint_watchdog_enabled:
self.watchdog_tasks = await self.dispatch_watchdogs()
async def _startup_keysets(self) -> None: async def _startup_keysets(self) -> None:
await self.init_keysets() await self.init_keysets()
@@ -168,7 +178,7 @@ class Ledger(
f" working properly: '{status.error_message}'" f" working properly: '{status.error_message}'"
) )
exit(1) exit(1)
logger.info(f"Backend balance: {status.balance} {unit.name}") logger.info(f"Backend balance: {status.balance}")
logger.info(f"Data dir: {settings.cashu_dir}") logger.info(f"Data dir: {settings.cashu_dir}")
@@ -178,6 +188,8 @@ class Ledger(
logger.debug("Shutting down invoice listeners") logger.debug("Shutting down invoice listeners")
for task in self.invoice_listener_tasks: for task in self.invoice_listener_tasks:
task.cancel() task.cancel()
for task in self.watchdog_tasks:
task.cancel()
logger.debug("Shutting down regular tasks") logger.debug("Shutting down regular tasks")
for task in self.regular_tasks: for task in self.regular_tasks:
task.cancel() task.cancel()
@@ -197,10 +209,6 @@ class Ledger(
quote = await self.get_melt_quote(quote_id=quote.quote) quote = await self.get_melt_quote(quote_id=quote.quote)
logger.info(f"Melt quote {quote.quote} state: {quote.state}") logger.info(f"Melt quote {quote.quote} state: {quote.state}")
async def get_balance(self, keyset: MintKeyset) -> int:
"""Returns the balance of the mint."""
return await self.crud.get_balance(keyset=keyset, db=self.db)
# ------- ECASH ------- # ------- ECASH -------
async def _invalidate_proofs( async def _invalidate_proofs(
@@ -216,6 +224,8 @@ class Ledger(
proofs (List[Proof]): Proofs to add to known secret table. proofs (List[Proof]): Proofs to add to known secret table.
conn: (Optional[Connection], optional): Database connection to reuse. Will create a new one if not given. Defaults to None. conn: (Optional[Connection], optional): Database connection to reuse. Will create a new one if not given. Defaults to None.
""" """
# sum_proofs = sum([p.amount for p in proofs])
fees_proofs = self.get_fees_for_proofs(proofs)
async with self.db.get_connection(conn) as conn: async with self.db.get_connection(conn) as conn:
# store in db # store in db
for p in proofs: for p in proofs:
@@ -223,11 +233,17 @@ class Ledger(
await self.crud.invalidate_proof( await self.crud.invalidate_proof(
proof=p, db=self.db, quote_id=quote_id, conn=conn proof=p, db=self.db, quote_id=quote_id, conn=conn
) )
await self.crud.bump_keyset_balance(
keyset=self.keysets[p.id], amount=-p.amount, db=self.db, conn=conn
)
await self.events.submit( await self.events.submit(
ProofState( ProofState(
Y=p.Y, state=ProofSpentState.spent, witness=p.witness or None Y=p.Y, state=ProofSpentState.spent, witness=p.witness or None
) )
) )
await self.crud.bump_keyset_fees_paid(
keyset=self.keyset, amount=fees_proofs, db=self.db, conn=conn
)
async def _generate_change_promises( async def _generate_change_promises(
self, self,
@@ -326,13 +342,10 @@ class Ledger(
): ):
raise NotAllowedError("Backend does not support descriptions.") raise NotAllowedError("Backend does not support descriptions.")
# MINT_MAX_BALANCE refers to sat (for now) # Check maximum balance.
if settings.mint_max_balance and unit == Unit.sat: # TODO: Allow setting MINT_MAX_BALANCE per unit
# get next active keyset for unit if settings.mint_max_balance:
active_keyset: MintKeyset = next( balance, fees_paid = await self.get_unit_balance_and_fees(unit, db=self.db)
filter(lambda k: k.active and k.unit == unit, self.keysets.values())
)
balance = await self.get_balance(active_keyset)
if balance + quote_request.amount > settings.mint_max_balance: if balance + quote_request.amount > settings.mint_max_balance:
raise NotAllowedError("Mint has reached maximum balance.") raise NotAllowedError("Mint has reached maximum balance.")
@@ -545,7 +558,9 @@ class Ledger(
melt_quote.is_mpp melt_quote.is_mpp
and melt_quote.mpp_amount != payment_quote.amount.to(Unit.msat).amount and melt_quote.mpp_amount != payment_quote.amount.to(Unit.msat).amount
): ):
logger.error(f"expected {payment_quote.amount.to(Unit.msat).amount} msat but got {melt_quote.mpp_amount}") logger.error(
f"expected {payment_quote.amount.to(Unit.msat).amount} msat but got {melt_quote.mpp_amount}"
)
raise TransactionError("quote amount not as requested") raise TransactionError("quote amount not as requested")
# make sure the backend returned the amount with a correct unit # make sure the backend returned the amount with a correct unit
if not payment_quote.amount.unit == unit: if not payment_quote.amount.unit == unit:
@@ -697,8 +712,13 @@ class Ledger(
pending_proofs = await self.crud.get_pending_proofs_for_quote( pending_proofs = await self.crud.get_pending_proofs_for_quote(
quote_id=quote_id, db=self.db quote_id=quote_id, db=self.db
) )
await self._invalidate_proofs(proofs=pending_proofs, quote_id=quote_id) async with self.db.get_connection() as conn:
await self.db_write._unset_proofs_pending(pending_proofs) await self._invalidate_proofs(
proofs=pending_proofs, quote_id=quote_id, conn=conn
)
await self.db_write._unset_proofs_pending(
pending_proofs, keysets=self.keysets, conn=conn
)
# change to compensate wallet for overpaid fees # change to compensate wallet for overpaid fees
if melt_quote.outputs: if melt_quote.outputs:
total_provided = sum_proofs(pending_proofs) total_provided = sum_proofs(pending_proofs)
@@ -723,7 +743,9 @@ class Ledger(
pending_proofs = await self.crud.get_pending_proofs_for_quote( pending_proofs = await self.crud.get_pending_proofs_for_quote(
quote_id=quote_id, db=self.db quote_id=quote_id, db=self.db
) )
await self.db_write._unset_proofs_pending(pending_proofs) await self.db_write._unset_proofs_pending(
pending_proofs, keysets=self.keysets
)
return melt_quote return melt_quote
@@ -821,7 +843,7 @@ class Ledger(
e: Lightning payment unsuccessful e: Lightning payment unsuccessful
Returns: Returns:
Tuple[str, List[BlindedMessage]]: Proof of payment and signed outputs for returning overpaid fees to wallet. PostMeltQuoteResponse: Melt quote response.
""" """
# make sure we're allowed to melt # make sure we're allowed to melt
if self.disable_melt and settings.mint_disable_melt_on_error: if self.disable_melt and settings.mint_disable_melt_on_error:
@@ -880,7 +902,7 @@ class Ledger(
# set proofs to pending to avoid race conditions # set proofs to pending to avoid race conditions
await self.db_write._verify_spent_proofs_and_set_pending( await self.db_write._verify_spent_proofs_and_set_pending(
proofs, quote_id=melt_quote.quote proofs, keysets=self.keysets, quote_id=melt_quote.quote
) )
previous_state = melt_quote.state previous_state = melt_quote.state
melt_quote = await self.db_write._set_melt_quote_pending(melt_quote, outputs) melt_quote = await self.db_write._set_melt_quote_pending(melt_quote, outputs)
@@ -936,7 +958,9 @@ class Ledger(
match status.result: match status.result:
case PaymentResult.FAILED | PaymentResult.UNKNOWN: case PaymentResult.FAILED | PaymentResult.UNKNOWN:
# Everything as expected. Payment AND a status check both agree on a failure. We roll back the transaction. # Everything as expected. Payment AND a status check both agree on a failure. We roll back the transaction.
await self.db_write._unset_proofs_pending(proofs) await self.db_write._unset_proofs_pending(
proofs, keysets=self.keysets
)
await self.db_write._unset_melt_quote_pending( await self.db_write._unset_melt_quote_pending(
quote=melt_quote, state=previous_state quote=melt_quote, state=previous_state
) )
@@ -976,7 +1000,7 @@ class Ledger(
# melt was successful (either internal or via backend), invalidate proofs # melt was successful (either internal or via backend), invalidate proofs
await self._invalidate_proofs(proofs=proofs, quote_id=melt_quote.quote) await self._invalidate_proofs(proofs=proofs, quote_id=melt_quote.quote)
await self.db_write._unset_proofs_pending(proofs) await self.db_write._unset_proofs_pending(proofs, keysets=self.keysets)
# prepare change to compensate wallet for overpaid fees # prepare change to compensate wallet for overpaid fees
return_promises: List[BlindedSignature] = [] return_promises: List[BlindedSignature] = []
@@ -1019,7 +1043,9 @@ class Ledger(
logger.trace("swap called") logger.trace("swap called")
# verify spending inputs, outputs, and spending conditions # verify spending inputs, outputs, and spending conditions
await self.verify_inputs_and_outputs(proofs=proofs, outputs=outputs) await self.verify_inputs_and_outputs(proofs=proofs, outputs=outputs)
await self.db_write._verify_spent_proofs_and_set_pending(proofs) await self.db_write._verify_spent_proofs_and_set_pending(
proofs, keysets=self.keysets
)
try: try:
async with self.db.get_connection(lock_table="proofs_pending") as conn: async with self.db.get_connection(lock_table="proofs_pending") as conn:
await self._invalidate_proofs(proofs=proofs, conn=conn) await self._invalidate_proofs(proofs=proofs, conn=conn)
@@ -1029,7 +1055,7 @@ class Ledger(
raise e raise e
finally: finally:
# delete proofs from pending list # delete proofs from pending list
await self.db_write._unset_proofs_pending(proofs) await self.db_write._unset_proofs_pending(proofs, keysets=self.keysets)
logger.trace("swap successful") logger.trace("swap successful")
return promises return promises
@@ -1117,4 +1143,10 @@ class Ledger(
dleq=DLEQ(e=e.serialize(), s=s.serialize()), dleq=DLEQ(e=e.serialize(), s=s.serialize()),
) )
signatures.append(signature) signatures.append(signature)
# bump keyset balance
await self.crud.bump_keyset_balance(
db=self.db, keyset=self.keysets[keyset_id], amount=amount, conn=conn
)
return signatures return signatures

View File

@@ -801,7 +801,7 @@ async def m020_add_state_to_mint_and_melt_quotes(db: Database):
async with db.connect() as conn: async with db.connect() as conn:
rows: List[RowMapping] = await conn.fetchall( rows: List[RowMapping] = await conn.fetchall(
f"SELECT * FROM {db.table_with_schema('mint_quotes')}" f"SELECT * FROM {db.table_with_schema('mint_quotes')}"
) ) # type: ignore
for row in rows: for row in rows:
if row.get("issued"): if row.get("issued"):
state = "issued" state = "issued"
@@ -817,7 +817,7 @@ async def m020_add_state_to_mint_and_melt_quotes(db: Database):
async with db.connect() as conn: async with db.connect() as conn:
rows2: List[RowMapping] = await conn.fetchall( rows2: List[RowMapping] = await conn.fetchall(
f"SELECT * FROM {db.table_with_schema('melt_quotes')}" f"SELECT * FROM {db.table_with_schema('melt_quotes')}"
) ) # type: ignore
for row in rows2: for row in rows2:
if row["paid"]: if row["paid"]:
state = "paid" state = "paid"
@@ -929,3 +929,42 @@ async def m026_keyset_specific_balance_views(db: Database):
await add_missing_id_to_proofs_and_promises(db, conn) await add_missing_id_to_proofs_and_promises(db, conn)
await drop_balance_views(db, conn) await drop_balance_views(db, conn)
await create_balance_views(db, conn) await create_balance_views(db, conn)
async def m027_add_balance_to_keysets_and_log_table(db: Database):
async with db.connect() as conn:
await conn.execute(
f"""
ALTER TABLE {db.table_with_schema('keysets')}
ADD COLUMN balance INTEGER NOT NULL DEFAULT 0
"""
)
await conn.execute(
f"""
ALTER TABLE {db.table_with_schema('keysets')}
ADD COLUMN fees_paid INTEGER NOT NULL DEFAULT 0
"""
)
# copy the balances from the balance view for each keyset
await conn.execute(
f"""
UPDATE {db.table_with_schema('keysets')}
SET balance = COALESCE(b.balance, 0)
FROM (
SELECT keyset, balance
FROM {db.table_with_schema('balance')}
) AS b
WHERE {db.table_with_schema('keysets')}.id = b.keyset
"""
)
await conn.execute(
f"""
CREATE TABLE IF NOT EXISTS {db.table_with_schema('balance_log')} (
unit TEXT NOT NULL,
keyset_balance INTEGER NOT NULL,
keyset_fees_paid INTEGER NOT NULL,
backend_balance INTEGER NOT NULL,
time TIMESTAMP DEFAULT {db.timestamp_now}
);
"""
)

View File

@@ -80,7 +80,7 @@ ledger = Ledger(
# start auth ledger # start auth ledger
auth_ledger = AuthLedger( auth_ledger = AuthLedger(
db=Database("auth", settings.auth_database), db=Database("auth", settings.mint_auth_database),
seed="auth seed here", seed="auth seed here",
amounts=[1], amounts=[1],
derivation_path="m/0'/999'/0'", derivation_path="m/0'/999'/0'",

159
cashu/mint/watchdog.py Normal file
View File

@@ -0,0 +1,159 @@
import asyncio
from typing import List, Optional, Tuple
from loguru import logger
from cashu.core.db import Connection, Database
from ..core.base import Amount, MintBalanceLogEntry, Unit
from ..core.settings import settings
from ..lightning.base import LightningBackend
from .protocols import SupportsBackends, SupportsDb
class LedgerWatchdog(SupportsDb, SupportsBackends):
watcher_db: Database
abort_queue: asyncio.Queue = asyncio.Queue(0)
def __init__(self) -> None:
self.watcher_db = Database(self.db.name, self.db.db_location)
return
async def get_unit_balance_and_fees(
self,
unit: Unit,
db: Database,
conn: Optional[Connection] = None,
) -> Tuple[Amount, Amount]:
keysets = await self.crud.get_keyset(db=db, unit=unit.name, conn=conn)
balance = Amount(unit, 0)
fees_paid = Amount(unit, 0)
for keyset in keysets:
balance_update = await self.crud.get_balance(keyset, db=db, conn=conn)
balance += balance_update[0]
fees_paid += balance_update[1]
return balance, fees_paid
async def dispatch_watchdogs(self) -> List[asyncio.Task]:
tasks = []
for method, unitbackends in self.backends.items():
for unit, backend in unitbackends.items():
tasks.append(
asyncio.create_task(self.dispatch_backend_checker(unit, backend))
)
tasks.append(asyncio.create_task(self.monitor_abort_queue()))
return tasks
async def monitor_abort_queue(self):
while True:
await self.abort_queue.get()
if settings.mint_watchdog_ignore_mismatch:
logger.warning(
"Ignoring balance mismatch due to MINT_WATCHDOG_IGNORE_MISMATCH setting"
)
continue
logger.error(
"Shutting down the mint due to balance mismatch. Fix the balance mismatch and restart the mint or set MINT_WATCHDOG_IGNORE_MISMATCH=True to ignore the mismatch."
)
raise SystemExit
async def get_balance(self, unit: Unit) -> Tuple[Amount, Amount]:
"""Returns the balance of the mint for this unit."""
return await self.get_unit_balance_and_fees(unit=unit, db=self.db)
async def dispatch_backend_checker(
self, unit: Unit, backend: LightningBackend
) -> None:
logger.info(
f"Dispatching backend checker for unit: {unit.name} and backend: {backend.__class__.__name__}"
)
while True:
backend_status = await backend.status()
backend_balance = backend_status.balance
last_balance_log_entry: MintBalanceLogEntry | None = None
async with self.watcher_db.connect() as conn:
last_balance_log_entry = await self.crud.get_last_balance_log_entry(
unit=unit, db=self.watcher_db
)
keyset_balance, keyset_fees_paid = await self.get_unit_balance_and_fees(
unit, db=self.watcher_db, conn=conn
)
logger.debug(f"Last balance log entry: {last_balance_log_entry}")
logger.debug(
f"Backend balance {backend.__class__.__name__}: {backend_balance}"
)
logger.debug(
f"Unit balance {unit.name}: {keyset_balance}, fees paid: {keyset_fees_paid}"
)
ok = await self.check_balances_and_abort(
backend,
last_balance_log_entry,
backend_balance,
keyset_balance,
keyset_fees_paid,
)
if ok or settings.mint_watchdog_ignore_mismatch:
await self.crud.store_balance_log(
backend_balance,
keyset_balance,
keyset_fees_paid,
db=self.db,
conn=conn,
)
await asyncio.sleep(settings.mint_watchdog_balance_check_interval_seconds)
async def check_balances_and_abort(
self,
backend: LightningBackend,
last_balance_log_entry: MintBalanceLogEntry | None,
backend_balance: Amount,
keyset_balance: Amount,
keyset_fees_paid: Amount,
) -> bool:
"""Check if the backend balance and the mint balance match.
If they don't match, log a warning and raise an exception that will shut down the mint.
Returns True if the balances check succeeded, False otherwise.
Args:
backend (LightningBackend): Backend to check the balance against
last_balance_log_entry (MintBalanceLogEntry | None): Last balance log entry in the database
backend_balance (Amount): Balance of the backend
keyset_balance (Amount): Balance of the mint
Returns:
bool: True if the balances check succeeded, False otherwise
"""
if keyset_balance + keyset_fees_paid > backend_balance:
logger.warning(
f"Backend balance {backend.__class__.__name__}: {backend_balance} is smaller than issued unit balance {keyset_balance.unit}: {keyset_balance}"
)
await self.abort_queue.put(True)
return False
if last_balance_log_entry:
last_balance_delta = last_balance_log_entry.backend_balance - (
last_balance_log_entry.keyset_balance
+ last_balance_log_entry.keyset_fees_paid
)
current_balance_delta = backend_balance - (
keyset_balance + keyset_fees_paid
)
if last_balance_delta > current_balance_delta:
logger.warning(
f"Balance delta mismatch: before: {last_balance_delta} - now: {current_balance_delta}"
)
logger.warning(
f"Balances before: backend: {last_balance_log_entry.backend_balance}, issued ecash: {last_balance_log_entry.keyset_balance}, fees earned: {last_balance_log_entry.keyset_fees_paid}"
)
logger.warning(
f"Balances now: backend: {backend_balance}, issued ecash: {keyset_balance}, fees earned: {keyset_fees_paid}"
)
await self.abort_queue.put(True)
return False
return True

View File

@@ -1,9 +0,0 @@
from ...core.base import Token
from ...wallet.crud import get_keysets
async def verify_mints(wallet, tokenObj: Token):
# verify mints
mint = tokenObj.mint
mint_keysets = await get_keysets(mint_url=mint, db=wallet.db)
assert len(mint_keysets), "We don't know this mint."

View File

@@ -1,13 +0,0 @@
import uvicorn
from ...core.settings import settings
def start_api_server(port=settings.api_port, host=settings.api_host):
config = uvicorn.Config(
"cashu.wallet.api.app:app",
port=port,
host=host,
)
server = uvicorn.Server(config)
server.run()

View File

@@ -1,40 +0,0 @@
from fastapi import FastAPI, Request, status
from fastapi.responses import JSONResponse
from loguru import logger
from ...core.settings import settings
from .router import router
# from fastapi_profiler import PyInstrumentProfilerMiddleware
def create_app() -> FastAPI:
app = FastAPI(
title="Cashu Wallet REST API",
description="REST API for Cashu Nutshell",
version=settings.version,
license_info={
"name": "MIT License",
"url": "https://raw.githubusercontent.com/cashubtc/cashu/main/LICENSE",
},
)
# app.add_middleware(PyInstrumentProfilerMiddleware)
return app
app = create_app()
@app.middleware("http")
async def catch_exceptions(request: Request, call_next):
try:
return await call_next(request)
except Exception as e:
logger.error(f"Exception: {e}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST, content={"detail": str(e)}
)
app.include_router(router=router)

View File

@@ -1,71 +0,0 @@
from typing import Dict, List, Optional
from pydantic import BaseModel
from ...core.base import MeltQuote, MintQuote
class SwapResponse(BaseModel):
outgoing_mint: str
incoming_mint: str
mint_quote: MintQuote
balances: Dict
class BalanceResponse(BaseModel):
balance: int
keysets: Optional[Dict] = None
mints: Optional[Dict] = None
class SendResponse(BaseModel):
balance: int
token: str
npub: Optional[str] = None
class ReceiveResponse(BaseModel):
initial_balance: int
balance: int
class BurnResponse(BaseModel):
balance: int
class PendingResponse(BaseModel):
pending_token: Dict
class LockResponse(BaseModel):
P2PK: Optional[str]
class LocksResponse(BaseModel):
locks: List[str]
class InvoicesResponse(BaseModel):
mint_quotes: List[MintQuote]
melt_quotes: List[MeltQuote]
class WalletsResponse(BaseModel):
wallets: Dict
class RestoreResponse(BaseModel):
balance: int
class InfoResponse(BaseModel):
version: str
wallet: str
debug: bool
cashu_dir: str
mint_urls: List[str] = []
settings: Optional[str]
tor: bool
nostr_public_key: Optional[str] = None
nostr_relays: List[str] = []
socks_proxy: Optional[str] = None

View File

@@ -1,471 +0,0 @@
import os
from datetime import datetime
from itertools import groupby, islice
from operator import itemgetter
from os import listdir
from os.path import isdir, join
from typing import Optional
from fastapi import APIRouter, Query
from ...core.base import Token, TokenV3
from ...core.helpers import sum_proofs
from ...core.settings import settings
from ...lightning.base import (
InvoiceResponse,
PaymentResponse,
PaymentStatus,
StatusResponse,
)
from ...nostr.client.client import NostrClient
from ...tor.tor import TorProxy
from ...wallet.crud import (
get_bolt11_melt_quotes,
get_bolt11_mint_quotes,
get_reserved_proofs,
)
from ...wallet.helpers import (
deserialize_token_from_string,
init_wallet,
list_mints,
receive,
send,
)
from ...wallet.nostr import receive_nostr, send_nostr
from ...wallet.wallet import Wallet as Wallet
from ..lightning.lightning import LightningWallet
from .api_helpers import verify_mints
from .responses import (
BalanceResponse,
BurnResponse,
InfoResponse,
InvoicesResponse,
LockResponse,
LocksResponse,
PendingResponse,
ReceiveResponse,
RestoreResponse,
SendResponse,
SwapResponse,
WalletsResponse,
)
router: APIRouter = APIRouter()
async def mint_wallet(
mint_url: Optional[str] = None, raise_connection_error: bool = True
) -> LightningWallet:
lightning_wallet = await LightningWallet.with_db(
mint_url or settings.mint_url,
db=os.path.join(settings.cashu_dir, settings.wallet_name),
name=settings.wallet_name,
)
await lightning_wallet.async_init(raise_connection_error=raise_connection_error)
return lightning_wallet
wallet = LightningWallet(
settings.mint_url,
db=os.path.join(settings.cashu_dir, settings.wallet_name),
name=settings.wallet_name,
)
@router.on_event("startup")
async def start_wallet():
global wallet
wallet = await mint_wallet(settings.mint_url, raise_connection_error=False)
if settings.tor and not TorProxy().check_platform():
raise Exception("tor not working.")
@router.post(
"/lightning/pay_invoice",
name="Pay lightning invoice",
response_model=PaymentResponse,
)
async def pay(
bolt11: str = Query(default=..., description="Lightning invoice to pay"),
mint: str = Query(
default=None,
description="Mint URL to pay from (None for default mint)",
),
) -> PaymentResponse:
global wallet
if mint:
wallet = await mint_wallet(mint)
payment_response = await wallet.pay_invoice(bolt11)
ret = PaymentResponse(**payment_response.dict())
ret.fee = None # TODO: we can't return an Amount object, overwriting
return ret
@router.get(
"/lightning/payment_state",
name="Request lightning invoice",
response_model=PaymentStatus,
)
async def payment_state(
payment_hash: str = Query(default=None, description="Id of paid invoice"),
mint: str = Query(
default=None,
description="Mint URL to create an invoice at (None for default mint)",
),
) -> PaymentStatus:
global wallet
if mint:
wallet = await mint_wallet(mint)
state = await wallet.get_payment_status(payment_hash)
return state
@router.post(
"/lightning/create_invoice",
name="Request lightning invoice",
response_model=InvoiceResponse,
)
async def create_invoice(
amount: int = Query(default=..., description="Amount to request in invoice"),
mint: str = Query(
default=None,
description="Mint URL to create an invoice at (None for default mint)",
),
) -> InvoiceResponse:
global wallet
if mint:
wallet = await mint_wallet(mint)
invoice = await wallet.create_invoice(amount)
return invoice
@router.get(
"/lightning/invoice_state",
name="Request lightning invoice",
response_model=PaymentStatus,
)
async def invoice_state(
payment_request: str = Query(default=None, description="Payment request to check"),
mint: str = Query(
default=None,
description="Mint URL to create an invoice at (None for default mint)",
),
) -> PaymentStatus:
global wallet
if mint:
wallet = await mint_wallet(mint)
state = await wallet.get_invoice_status(payment_request)
return state
@router.get(
"/lightning/balance",
name="Balance",
summary="Display balance.",
response_model=StatusResponse,
)
async def lightning_balance() -> StatusResponse:
try:
await wallet.load_proofs(reload=True)
except Exception as exc:
return StatusResponse(error_message=str(exc), balance=0)
return StatusResponse(error_message=None, balance=wallet.available_balance * 1000)
@router.post(
"/swap",
name="Multi-mint swaps",
summary="Swap funds between mints",
response_model=SwapResponse,
)
async def swap(
amount: int = Query(default=..., description="Amount to swap between mints"),
outgoing_mint: str = Query(default=..., description="URL of outgoing mint"),
incoming_mint: str = Query(default=..., description="URL of incoming mint"),
):
incoming_wallet = await mint_wallet(incoming_mint)
outgoing_wallet = await mint_wallet(outgoing_mint)
if incoming_wallet.url == outgoing_wallet.url:
raise Exception("mints for swap have to be different")
# request invoice from incoming mint
mint_quote = await incoming_wallet.request_mint(amount)
# pay invoice from outgoing mint
await outgoing_wallet.load_proofs(reload=True)
quote = await outgoing_wallet.melt_quote(mint_quote.request)
total_amount = quote.amount + quote.fee_reserve
if outgoing_wallet.available_balance < total_amount:
raise Exception("balance too low")
_, send_proofs = await outgoing_wallet.swap_to_send(
outgoing_wallet.proofs, total_amount, set_reserved=True
)
await outgoing_wallet.melt(
send_proofs, mint_quote.request, quote.fee_reserve, quote.quote
)
# mint token in incoming mint
await incoming_wallet.mint(amount, quote_id=mint_quote.quote)
await incoming_wallet.load_proofs(reload=True)
mint_balances = await incoming_wallet.balance_per_minturl()
return SwapResponse(
outgoing_mint=outgoing_mint,
incoming_mint=incoming_mint,
mint_quote=mint_quote,
balances=mint_balances,
)
@router.get(
"/balance",
name="Balance",
summary="Display balance.",
response_model=BalanceResponse,
)
async def balance():
await wallet.load_proofs(reload=True)
keyset_balances = wallet.balance_per_keyset()
mint_balances = await wallet.balance_per_minturl()
return BalanceResponse(
balance=wallet.available_balance, keysets=keyset_balances, mints=mint_balances
)
@router.post("/send", name="Send tokens", response_model=SendResponse)
async def send_command(
amount: int = Query(default=..., description="Amount to send"),
nostr: str = Query(default=None, description="Send to nostr pubkey"),
lock: str = Query(default=None, description="Lock tokens (P2PK)"),
mint: str = Query(
default=None,
description="Mint URL to send from (None for default mint)",
),
offline: bool = Query(default=False, description="Force offline send."),
):
global wallet
if mint:
wallet = await mint_wallet(mint)
if not nostr:
balance, token = await send(
wallet, amount=amount, lock=lock, legacy=False, offline=offline
)
return SendResponse(balance=balance, token=token)
else:
token, pubkey = await send_nostr(wallet, amount=amount, pubkey=nostr)
return SendResponse(balance=wallet.available_balance, token=token, npub=pubkey)
@router.post("/receive", name="Receive tokens", response_model=ReceiveResponse)
async def receive_command(
token: str = Query(default=None, description="Token to receive"),
nostr: bool = Query(default=False, description="Receive tokens via nostr"),
all: bool = Query(default=False, description="Receive all pending tokens"),
):
wallet = await mint_wallet()
initial_balance = wallet.available_balance
if token:
tokenObj: Token = deserialize_token_from_string(token)
await verify_mints(wallet, tokenObj)
await receive(wallet, tokenObj)
elif nostr:
await receive_nostr(wallet)
elif all:
reserved_proofs = await get_reserved_proofs(wallet.db)
balance = None
if len(reserved_proofs):
for _, value in groupby(reserved_proofs, key=itemgetter("send_id")): # type: ignore
proofs = list(value)
token = await wallet.serialize_proofs(proofs)
tokenObj = deserialize_token_from_string(token)
await verify_mints(wallet, tokenObj)
await receive(wallet, tokenObj)
else:
raise Exception("enter token or use either flag --nostr or --all.")
balance = wallet.available_balance
return ReceiveResponse(initial_balance=initial_balance, balance=balance)
@router.post("/burn", name="Burn spent tokens", response_model=BurnResponse)
async def burn(
token: str = Query(default=None, description="Token to burn"),
all: bool = Query(default=False, description="Burn all spent tokens"),
force: bool = Query(default=False, description="Force check on all tokens."),
delete: str = Query(
default=None,
description="Forcefully delete pending token by send ID if mint is unavailable",
),
mint: str = Query(
default=None,
description="Mint URL to burn from (None for default mint)",
),
):
global wallet
if not delete:
wallet = await mint_wallet(mint)
if not (all or token or force or delete) or (token and all):
raise Exception(
"enter a token or use --all to burn all pending tokens, --force to"
" check all tokens or --delete with send ID to force-delete pending"
" token from list if mint is unavailable.",
)
if all:
# check only those who are flagged as reserved
proofs = await get_reserved_proofs(wallet.db)
elif force:
# check all proofs in db
proofs = wallet.proofs
elif delete:
reserved_proofs = await get_reserved_proofs(wallet.db)
proofs = [proof for proof in reserved_proofs if proof["send_id"] == delete]
else:
# check only the specified ones
tokenObj = TokenV3.deserialize(token)
proofs = tokenObj.proofs
if delete:
await wallet.invalidate(proofs)
else:
await wallet.invalidate(proofs, check_spendable=True)
return BurnResponse(balance=wallet.available_balance)
@router.get("/pending", name="Show pending tokens", response_model=PendingResponse)
async def pending(
number: int = Query(default=None, description="Show only n pending tokens"),
offset: int = Query(
default=0, description="Show pending tokens only starting from offset"
),
):
reserved_proofs = await get_reserved_proofs(wallet.db)
result: dict = {}
if len(reserved_proofs):
sorted_proofs = sorted(reserved_proofs, key=itemgetter("send_id")) # type: ignore
if number:
number += offset
for i, (key, value) in islice(
enumerate(
groupby(
sorted_proofs,
key=itemgetter("send_id"), # type: ignore
)
),
offset,
number,
):
grouped_proofs = list(value)
token = await wallet.serialize_proofs(grouped_proofs)
tokenObj = deserialize_token_from_string(token)
mint = tokenObj.mint
reserved_date = datetime.utcfromtimestamp(
int(grouped_proofs[0].time_reserved) # type: ignore
).strftime("%Y-%m-%d %H:%M:%S")
result.update(
{
f"{i}": {
"amount": sum_proofs(grouped_proofs),
"time": reserved_date,
"ID": key,
"token": token,
"mint": mint,
}
}
)
return PendingResponse(pending_token=result)
@router.get("/lock", name="Generate receiving lock", response_model=LockResponse)
async def lock():
pubkey = await wallet.create_p2pk_pubkey()
return LockResponse(P2PK=pubkey)
@router.get("/locks", name="Show unused receiving locks", response_model=LocksResponse)
async def locks():
pubkey = await wallet.create_p2pk_pubkey()
return LocksResponse(locks=[pubkey])
@router.get(
"/invoices", name="List all pending invoices", response_model=InvoicesResponse
)
async def invoices():
mint_quotes = await get_bolt11_mint_quotes(db=wallet.db)
melt_quotes = await get_bolt11_melt_quotes(db=wallet.db)
return InvoicesResponse(mint_quotes=mint_quotes, melt_quotes=melt_quotes)
@router.get(
"/wallets", name="List all available wallets", response_model=WalletsResponse
)
async def wallets():
wallets = [
d for d in listdir(settings.cashu_dir) if isdir(join(settings.cashu_dir, d))
]
try:
wallets.remove("mint")
except ValueError:
pass
result = {}
for w in wallets:
wallet = Wallet(settings.mint_url, os.path.join(settings.cashu_dir, w), name=w)
try:
await init_wallet(wallet)
if wallet.proofs and len(wallet.proofs):
active_wallet = False
if w == wallet.name:
active_wallet = True
if active_wallet:
result.update(
{
f"{w}": {
"balance": sum_proofs(wallet.proofs),
"available": sum_proofs(
[p for p in wallet.proofs if not p.reserved]
),
}
}
)
except Exception:
pass
return WalletsResponse(wallets=result)
@router.post("/v1/restore", name="Restore wallet", response_model=RestoreResponse)
async def restore(
to: int = Query(default=..., description="Counter to which restore the wallet"),
):
if to < 0:
raise Exception("Counter must be positive")
await wallet.load_mint()
await wallet.restore_promises_from_to(wallet.keyset_id, 0, to)
await wallet.invalidate(wallet.proofs, check_spendable=True)
return RestoreResponse(balance=wallet.available_balance)
@router.get("/info", name="Information about Cashu wallet", response_model=InfoResponse)
async def info():
if settings.nostr_private_key:
try:
client = NostrClient(private_key=settings.nostr_private_key, connect=False)
nostr_public_key = client.private_key.bech32()
nostr_relays = settings.nostr_relays
except Exception:
nostr_public_key = "Invalid key"
nostr_relays = []
else:
nostr_public_key = None
nostr_relays = []
mint_list = await list_mints(wallet)
return InfoResponse(
version=settings.version,
wallet=wallet.name,
debug=settings.debug,
cashu_dir=settings.cashu_dir,
mint_urls=mint_list,
settings=settings.env_file,
tor=settings.tor,
nostr_public_key=nostr_public_key,
nostr_relays=nostr_relays,
socks_proxy=settings.socks_proxy,
)

View File

@@ -42,7 +42,6 @@ from ...wallet.crud import (
get_seed_and_mnemonic, get_seed_and_mnemonic,
) )
from ...wallet.wallet import Wallet as Wallet from ...wallet.wallet import Wallet as Wallet
from ..api.api_server import start_api_server
from ..auth.auth import WalletAuth from ..auth.auth import WalletAuth
from ..cli.cli_helpers import ( from ..cli.cli_helpers import (
get_mint_wallet, get_mint_wallet,
@@ -71,13 +70,6 @@ class NaturalOrderGroup(click.Group):
return self.commands.keys() return self.commands.keys()
def run_api_server(ctx, param, daemon):
if not daemon:
return
start_api_server()
ctx.exit()
# https://github.com/pallets/click/issues/85#issuecomment-503464628 # https://github.com/pallets/click/issues/85#issuecomment-503464628
def coro(f): def coro(f):
@wraps(f) @wraps(f)
@@ -121,9 +113,7 @@ def init_auth_wallet(func):
if settings.debug: if settings.debug:
await auth_wallet.load_proofs(reload=True) await auth_wallet.load_proofs(reload=True)
logger.debug( logger.debug(f"Auth balance: {auth_wallet.available_balance}")
f"Auth balance: {auth_wallet.unit.str(auth_wallet.available_balance)}"
)
return ret return ret
@@ -151,15 +141,6 @@ def init_auth_wallet(func):
default=None, default=None,
help=f"Wallet unit (default: {settings.wallet_unit}).", help=f"Wallet unit (default: {settings.wallet_unit}).",
) )
@click.option(
"--daemon",
"-d",
is_flag=True,
is_eager=True,
expose_value=False,
callback=run_api_server,
help="Start server for wallet REST API",
)
@click.option( @click.option(
"--tests", "--tests",
"-t", "-t",
@@ -263,10 +244,8 @@ async def pay(
await wallet.load_mint() await wallet.load_mint()
await print_balance(ctx) await print_balance(ctx)
payment_hash = bolt11.decode(invoice).payment_hash payment_hash = bolt11.decode(invoice).payment_hash
amount_mpp_msat = None # we assume `amount` to be in sats
if amount: amount_mpp_msat = amount * 1000 if amount else None
# we assume `amount` to be in sats
amount_mpp_msat = amount * 1000
quote = await wallet.melt_quote(invoice, amount_mpp_msat) quote = await wallet.melt_quote(invoice, amount_mpp_msat)
logger.debug(f"Quote: {quote}") logger.debug(f"Quote: {quote}")
total_amount = quote.amount + quote.fee_reserve total_amount = quote.amount + quote.fee_reserve
@@ -291,9 +270,17 @@ async def pay(
assert total_amount > 0, "amount is not positive" assert total_amount > 0, "amount is not positive"
# we need to include fees so we can use the proofs for melting the `total_amount` # we need to include fees so we can use the proofs for melting the `total_amount`
send_proofs, _ = await wallet.select_to_send( send_proofs, _ = await wallet.select_to_send(
wallet.proofs, total_amount, include_fees=True, set_reserved=True wallet.proofs, total_amount, include_fees=True, set_reserved=False
) )
print("Paying Lightning invoice ...", end="", flush=True) print("Paying Lightning invoice ...", end="", flush=True)
assert total_amount > 0, "amount is not positive"
logger.debug(
f"Total amount: {total_amount} available balance: {wallet.available_balance}"
)
if wallet.available_balance < total_amount:
print(" Error: Balance too low.")
return
try: try:
melt_response = await wallet.melt( melt_response = await wallet.melt(
send_proofs, invoice, quote.fee_reserve, quote.quote send_proofs, invoice, quote.fee_reserve, quote.quote
@@ -600,12 +587,12 @@ async def balance(ctx: Context, verbose):
if verbose: if verbose:
print( print(
f"Balance: {wallet.unit.str(wallet.available_balance)} (pending:" f"Balance: {wallet.available_balance} (pending:"
f" {wallet.unit.str(wallet.balance-wallet.available_balance)}) in" f" {wallet.balance-wallet.available_balance}) in"
f" {len([p for p in wallet.proofs if not p.reserved])} tokens" f" {len([p for p in wallet.proofs if not p.reserved])} tokens"
) )
else: else:
print(f"Balance: {wallet.unit.str(wallet.available_balance)}") print(f"Balance: {wallet.available_balance}")
@cli.command("send", help="Send tokens.") @cli.command("send", help="Send tokens.")
@@ -1319,4 +1306,4 @@ async def auth(ctx: Context, mint: bool, force: bool, password: bool):
new_proofs = await auth_wallet.mint_blind_auth() new_proofs = await auth_wallet.mint_blind_auth()
print(f"Minted {auth_wallet.unit.str(sum_proofs(new_proofs))} auth tokens.") print(f"Minted {auth_wallet.unit.str(sum_proofs(new_proofs))} auth tokens.")
print(f"Auth balance: {auth_wallet.unit.str(auth_wallet.available_balance)}") print(f"Auth balance: {auth_wallet.available_balance}")

View File

@@ -23,7 +23,7 @@ from ..helpers import (
async def print_balance(ctx: Context): async def print_balance(ctx: Context):
wallet: Wallet = ctx.obj["WALLET"] wallet: Wallet = ctx.obj["WALLET"]
await wallet.load_proofs(reload=True) await wallet.load_proofs(reload=True)
print(f"Balance: {wallet.unit.str(wallet.available_balance)}") print(f"Balance: {wallet.available_balance}")
async def get_unit_wallet(ctx: Context, force_select: bool = False): async def get_unit_wallet(ctx: Context, force_select: bool = False):

View File

@@ -8,6 +8,7 @@ from bip32 import BIP32
from loguru import logger from loguru import logger
from ..core.base import ( from ..core.base import (
Amount,
BlindedMessage, BlindedMessage,
BlindedSignature, BlindedSignature,
DLEQWallet, DLEQWallet,
@@ -1273,12 +1274,12 @@ class Wallet(
# ---------- BALANCE CHECKS ---------- # ---------- BALANCE CHECKS ----------
@property @property
def balance(self): def balance(self) -> Amount:
return sum_proofs(self.proofs) return Amount(self.unit, sum_proofs(self.proofs))
@property @property
def available_balance(self): def available_balance(self) -> Amount:
return sum_proofs([p for p in self.proofs if not p.reserved]) return Amount(self.unit, sum_proofs([p for p in self.proofs if not p.reserved]))
@property @property
def proof_amounts(self): def proof_amounts(self):

View File

@@ -51,7 +51,8 @@ settings.mint_lnd_enable_mpp = True
settings.mint_clnrest_enable_mpp = True settings.mint_clnrest_enable_mpp = True
settings.mint_input_fee_ppk = 0 settings.mint_input_fee_ppk = 0
settings.db_connection_pool = True settings.db_connection_pool = True
# settings.mint_require_auth = False settings.mint_require_auth = False
settings.mint_watchdog_enabled = False
assert "test" in settings.cashu_dir assert "test" in settings.cashu_dir
shutil.rmtree(settings.cashu_dir, ignore_errors=True) shutil.rmtree(settings.cashu_dir, ignore_errors=True)

View File

@@ -220,4 +220,4 @@ async def pay_if_regtest(bolt11: str) -> None:
pay_real_invoice(bolt11) pay_real_invoice(bolt11)
if is_fake: if is_fake:
await asyncio.sleep(settings.fakewallet_delay_incoming_payment or 0) await asyncio.sleep(settings.fakewallet_delay_incoming_payment or 0)
await asyncio.sleep(0.1) await asyncio.sleep(0.5)

View File

@@ -2,7 +2,7 @@ from typing import List
import pytest import pytest
from cashu.core.base import BlindedMessage, MintKeyset, Proof, Unit from cashu.core.base import BlindedMessage, Proof, Unit
from cashu.core.crypto.b_dhke import step1_alice from cashu.core.crypto.b_dhke import step1_alice
from cashu.core.helpers import calculate_number_of_blank_outputs from cashu.core.helpers import calculate_number_of_blank_outputs
from cashu.core.models import PostMintQuoteRequest from cashu.core.models import PostMintQuoteRequest
@@ -219,11 +219,9 @@ async def test_generate_change_promises_returns_empty_if_no_outputs(ledger: Ledg
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_balance(ledger: Ledger): async def test_get_balance(ledger: Ledger):
unit = Unit["sat"] unit = Unit["sat"]
active_keyset: MintKeyset = next( balance, fees_paid = await ledger.get_balance(unit)
filter(lambda k: k.active and k.unit == unit, ledger.keysets.values())
)
balance = await ledger.get_balance(active_keyset)
assert balance == 0 assert balance == 0
assert fees_paid == 0
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@@ -45,13 +45,13 @@ async def test_mint_proofs_pending(wallet: Wallet, ledger: Ledger):
proofs_states_before_split = await wallet.check_proof_state(proofs) proofs_states_before_split = await wallet.check_proof_state(proofs)
assert all([s.unspent for s in proofs_states_before_split.states]) assert all([s.unspent for s in proofs_states_before_split.states])
await ledger.db_write._verify_spent_proofs_and_set_pending(proofs) await ledger.db_write._verify_spent_proofs_and_set_pending(proofs, ledger.keysets)
proof_states = await wallet.check_proof_state(proofs) proof_states = await wallet.check_proof_state(proofs)
assert all([s.pending for s in proof_states.states]) assert all([s.pending for s in proof_states.states])
await assert_err(wallet.split(wallet.proofs, 20), "proofs are pending.") await assert_err(wallet.split(wallet.proofs, 20), "proofs are pending.")
await ledger.db_write._unset_proofs_pending(proofs) await ledger.db_write._unset_proofs_pending(proofs, ledger.keysets)
await wallet.split(proofs, 20) await wallet.split(proofs, 20)

View File

@@ -75,9 +75,20 @@ async def test_db_tables(ledger: Ledger):
"mint_quotes", "mint_quotes",
"mint_pubkeys", "mint_pubkeys",
"promises", "promises",
"balance_log",
"balance",
"balance_issued",
"balance_redeemed",
] ]
for table in tables_expected:
assert table in tables tables.sort()
tables_expected.sort()
if ledger.db.type == db.SQLITE:
# SQLite does not return views
tables_expected.remove("balance")
tables_expected.remove("balance_issued")
tables_expected.remove("balance_redeemed")
assert tables == tables_expected
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -202,8 +213,12 @@ async def test_db_verify_spent_proofs_and_set_pending_race_condition(
await assert_err_multiple( await assert_err_multiple(
asyncio.gather( asyncio.gather(
ledger.db_write._verify_spent_proofs_and_set_pending(wallet.proofs), ledger.db_write._verify_spent_proofs_and_set_pending(
ledger.db_write._verify_spent_proofs_and_set_pending(wallet.proofs), wallet.proofs, ledger.keysets
),
ledger.db_write._verify_spent_proofs_and_set_pending(
wallet.proofs, ledger.keysets
),
), ),
[ [
"failed to acquire database lock", "failed to acquire database lock",
@@ -228,11 +243,15 @@ async def test_db_verify_spent_proofs_and_set_pending_delayed_no_race_condition(
async def delayed_verify_spent_proofs_and_set_pending(): async def delayed_verify_spent_proofs_and_set_pending():
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
await ledger.db_write._verify_spent_proofs_and_set_pending(wallet.proofs) await ledger.db_write._verify_spent_proofs_and_set_pending(
wallet.proofs, ledger.keysets
)
await assert_err( await assert_err(
asyncio.gather( asyncio.gather(
ledger.db_write._verify_spent_proofs_and_set_pending(wallet.proofs), ledger.db_write._verify_spent_proofs_and_set_pending(
wallet.proofs, ledger.keysets
),
delayed_verify_spent_proofs_and_set_pending(), delayed_verify_spent_proofs_and_set_pending(),
), ),
"proofs are pending", "proofs are pending",
@@ -255,8 +274,12 @@ async def test_db_verify_spent_proofs_and_set_pending_no_race_condition_differen
assert len(wallet.proofs) == 2 assert len(wallet.proofs) == 2
asyncio.gather( asyncio.gather(
ledger.db_write._verify_spent_proofs_and_set_pending(wallet.proofs[:1]), ledger.db_write._verify_spent_proofs_and_set_pending(
ledger.db_write._verify_spent_proofs_and_set_pending(wallet.proofs[1:]), wallet.proofs[:1], ledger.keysets
),
ledger.db_write._verify_spent_proofs_and_set_pending(
wallet.proofs[1:], ledger.keysets
),
) )
@@ -325,6 +348,8 @@ async def test_db_lock_table(wallet: Wallet, ledger: Ledger):
async with ledger.db.connect(lock_table="proofs_pending", lock_timeout=0.1) as conn: async with ledger.db.connect(lock_table="proofs_pending", lock_timeout=0.1) as conn:
assert isinstance(conn, Connection) assert isinstance(conn, Connection)
await assert_err( await assert_err(
ledger.db_write._verify_spent_proofs_and_set_pending(wallet.proofs), ledger.db_write._verify_spent_proofs_and_set_pending(
wallet.proofs, ledger.keysets
),
"failed to acquire database lock", "failed to acquire database lock",
) )

View File

@@ -153,7 +153,7 @@ async def create_pending_melts(
quote=quote, quote=quote,
db=ledger.db, db=ledger.db,
) )
pending_proof = Proof(amount=123, C="asdasd", secret="asdasd", id=quote_id) pending_proof = Proof(amount=123, C="asdasd", secret="asdasd", id=ledger.keyset.id)
await ledger.crud.set_proof_pending( await ledger.crud.set_proof_pending(
db=ledger.db, db=ledger.db,
proof=pending_proof, proof=pending_proof,

View File

@@ -68,7 +68,7 @@ async def create_pending_melts(
quote=quote, quote=quote,
db=ledger.db, db=ledger.db,
) )
pending_proof = Proof(amount=123, C="asdasd", secret="asdasd", id=quote_id) pending_proof = Proof(amount=123, C="asdasd", secret="asdasd", id=ledger.keyset.id)
await ledger.crud.set_proof_pending( await ledger.crud.set_proof_pending(
db=ledger.db, db=ledger.db,
proof=pending_proof, proof=pending_proof,

View File

@@ -59,6 +59,44 @@ async def test_lightning_create_invoice(ledger: Ledger):
assert status.settled assert status.settled
@pytest.mark.asyncio
@pytest.mark.skipif(is_fake, reason="only regtest")
async def test_lightning_create_invoice_balance_change(ledger: Ledger):
invoice_amount = 1000 # sat
invoice = await ledger.backends[Method.bolt11][Unit.sat].create_invoice(
Amount(Unit.sat, invoice_amount)
)
assert invoice.ok
assert invoice.payment_request
assert invoice.checking_id
# TEST 2: check the invoice status
status = await ledger.backends[Method.bolt11][Unit.sat].get_invoice_status(
invoice.checking_id
)
assert status.pending
status = await ledger.backends[Method.bolt11][Unit.sat].status()
balance_before = status.balance
# settle the invoice
await pay_if_regtest(invoice.payment_request)
# cln takes some time to update the balance
await asyncio.sleep(SLEEP_TIME)
# TEST 3: check the invoice status
status = await ledger.backends[Method.bolt11][Unit.sat].get_invoice_status(
invoice.checking_id
)
assert status.settled
status = await ledger.backends[Method.bolt11][Unit.sat].status()
balance_after = status.balance
assert balance_after == balance_before + invoice_amount
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.skipif(is_fake, reason="only regtest") @pytest.mark.skipif(is_fake, reason="only regtest")
async def test_lightning_get_payment_quote(ledger: Ledger): async def test_lightning_get_payment_quote(ledger: Ledger):

162
tests/test_mint_watchdog.py Normal file
View File

@@ -0,0 +1,162 @@
import pytest
import pytest_asyncio
from cashu.core.base import Amount, MeltQuoteState, Method, Unit
from cashu.core.models import PostMeltQuoteRequest
from cashu.core.settings import settings
from cashu.mint.ledger import Ledger
from cashu.wallet.wallet import Wallet
from tests.conftest import SERVER_ENDPOINT
from tests.helpers import (
get_real_invoice,
is_fake,
pay_if_regtest,
)
@pytest_asyncio.fixture(scope="function")
async def wallet():
wallet = await Wallet.with_db(
url=SERVER_ENDPOINT,
db="test_data/wallet",
name="wallet",
)
await wallet.load_mint()
yield wallet
@pytest.mark.asyncio
async def test_check_balances_and_abort(ledger: Ledger):
ok = await ledger.check_balances_and_abort(
ledger.backends[Method.bolt11][Unit.sat],
None,
Amount(Unit.sat, 0),
Amount(Unit.sat, 0),
Amount(Unit.sat, 0),
)
assert ok
@pytest.mark.asyncio
async def test_balance_update_on_mint(wallet: Wallet, ledger: Ledger):
balance_before, fees_paid_before = await ledger.get_unit_balance_and_fees(
Unit.sat, ledger.db
)
mint_quote = await wallet.request_mint(64)
await pay_if_regtest(mint_quote.request)
await wallet.mint(64, quote_id=mint_quote.quote)
assert wallet.balance == 64
balance_after, fees_paid_after = await ledger.get_unit_balance_and_fees(
Unit.sat, ledger.db
)
assert balance_after == balance_before + 64
assert fees_paid_after == fees_paid_before
@pytest.mark.asyncio
@pytest.mark.skipif(is_fake, reason="only works with Regtest")
async def test_balance_update_on_test_melt_internal(wallet: Wallet, ledger: Ledger):
settings.fakewallet_brr = False
# mint twice so we have enough to pay the second invoice back
mint_quote = await wallet.request_mint(128)
await pay_if_regtest(mint_quote.request)
await wallet.mint(128, quote_id=mint_quote.quote)
assert wallet.balance == 128
balance_before, fees_paid_before = await ledger.get_unit_balance_and_fees(
Unit.sat, ledger.db
)
# create a mint quote so that we can melt to it internally
payment_amount = 64
mint_quote_to_pay = await wallet.request_mint(payment_amount)
invoice_payment_request = mint_quote_to_pay.request
melt_quote = await ledger.melt_quote(
PostMeltQuoteRequest(request=invoice_payment_request, unit="sat")
)
if not settings.debug_mint_only_deprecated:
melt_quote_response_pre_payment = await wallet.get_melt_quote(melt_quote.quote)
assert (
not melt_quote_response_pre_payment.state == MeltQuoteState.paid.value
), "melt quote should not be paid"
assert melt_quote_response_pre_payment.amount == payment_amount
melt_quote_pre_payment = await ledger.get_melt_quote(melt_quote.quote)
assert not melt_quote_pre_payment.paid, "melt quote should not be paid"
assert melt_quote_pre_payment.unpaid
_, send_proofs = await wallet.swap_to_send(wallet.proofs, payment_amount)
await ledger.melt(proofs=send_proofs, quote=melt_quote.quote)
await wallet.invalidate(send_proofs, check_spendable=True)
assert wallet.balance == 64
melt_quote_post_payment = await ledger.get_melt_quote(melt_quote.quote)
assert melt_quote_post_payment.paid, "melt quote should be paid"
balance_after, fees_paid_after = await ledger.get_unit_balance_and_fees(
Unit.sat, ledger.db
)
# balance should have dropped
assert balance_after == balance_before - payment_amount
assert fees_paid_after == fees_paid_before
# now mint
await wallet.mint(payment_amount, quote_id=mint_quote_to_pay.quote)
assert wallet.balance == 128
balance_after, fees_paid_after = await ledger.get_unit_balance_and_fees(
Unit.sat, ledger.db
)
# balance should be back
assert balance_after == balance_before
assert fees_paid_after == fees_paid_before
@pytest.mark.asyncio
@pytest.mark.skipif(is_fake, reason="only works with Regtest")
async def test_balance_update_on_melt_external(wallet: Wallet, ledger: Ledger):
# mint twice so we have enough to pay the second invoice back
mint_quote = await wallet.request_mint(128)
await pay_if_regtest(mint_quote.request)
await wallet.mint(128, quote_id=mint_quote.quote)
assert wallet.balance == 128
balance_before, fees_paid_before = await ledger.get_unit_balance_and_fees(
Unit.sat, ledger.db
)
invoice_dict = get_real_invoice(64)
invoice_payment_request = invoice_dict["payment_request"]
mint_quote = await wallet.melt_quote(invoice_payment_request)
total_amount = mint_quote.amount + mint_quote.fee_reserve
_, send_proofs = await wallet.swap_to_send(wallet.proofs, total_amount)
melt_quote = await ledger.melt_quote(
PostMeltQuoteRequest(request=invoice_payment_request, unit="sat")
)
if not settings.debug_mint_only_deprecated:
melt_quote_response_pre_payment = await wallet.get_melt_quote(melt_quote.quote)
assert (
melt_quote_response_pre_payment.state == MeltQuoteState.unpaid.value
), "melt quote should not be paid"
assert melt_quote_response_pre_payment.amount == melt_quote.amount
melt_quote_resp = await ledger.melt(proofs=send_proofs, quote=melt_quote.quote)
fees_paid = melt_quote.fee_reserve - (
sum([b.amount for b in melt_quote_resp.change]) if melt_quote_resp.change else 0
)
melt_quote_post_payment = await ledger.get_melt_quote(melt_quote.quote)
assert melt_quote_post_payment.paid, "melt quote should be paid"
balance_after, fees_paid_after = await ledger.get_unit_balance_and_fees(
Unit.sat, ledger.db
)
assert balance_after == balance_before - 64 - fees_paid
assert fees_paid_after == fees_paid_before

View File

@@ -1,199 +0,0 @@
import asyncio
import pytest
import pytest_asyncio
from fastapi.testclient import TestClient
from cashu.lightning.base import InvoiceResponse, PaymentResult, PaymentStatus
from cashu.wallet.api.app import app
from cashu.wallet.wallet import Wallet
from tests.conftest import SERVER_ENDPOINT
from tests.helpers import is_regtest
@pytest_asyncio.fixture(scope="function")
async def wallet():
wallet = await Wallet.with_db(
url=SERVER_ENDPOINT,
db="test_data/wallet",
name="wallet",
)
await wallet.load_mint()
yield wallet
@pytest.mark.skipif(is_regtest, reason="regtest")
@pytest.mark.asyncio
async def test_invoice(wallet: Wallet):
with TestClient(app) as client:
response = client.post("/lightning/create_invoice?amount=100")
assert response.status_code == 200
invoice_response = InvoiceResponse.parse_obj(response.json())
state = PaymentStatus(result=PaymentResult.PENDING)
while state.pending:
print("checking invoice state")
response2 = client.get(
f"/lightning/invoice_state?payment_request={invoice_response.payment_request}"
)
state = PaymentStatus.parse_obj(response2.json())
await asyncio.sleep(0.1)
print("state:", state)
print("paid")
await wallet.load_proofs()
assert wallet.available_balance >= 100
@pytest.mark.skipif(is_regtest, reason="regtest")
@pytest.mark.asyncio
async def test_balance():
with TestClient(app) as client:
response = client.get("/balance")
assert response.status_code == 200
assert "balance" in response.json()
assert response.json()["keysets"]
assert response.json()["mints"]
@pytest.mark.skipif(is_regtest, reason="regtest")
@pytest.mark.asyncio
async def test_send(wallet: Wallet):
with TestClient(app) as client:
response = client.post("/send?amount=10")
assert response.status_code == 200
assert response.json()["balance"]
@pytest.mark.skipif(is_regtest, reason="regtest")
@pytest.mark.asyncio
async def test_send_without_split(wallet: Wallet):
with TestClient(app) as client:
response = client.post("/send?amount=2&offline=true")
assert response.status_code == 200
assert response.json()["balance"]
@pytest.mark.skipif(is_regtest, reason="regtest")
@pytest.mark.asyncio
async def test_send_too_much(wallet: Wallet):
with TestClient(app) as client:
response = client.post("/send?amount=110000")
assert response.status_code == 400
@pytest.mark.skipif(is_regtest, reason="regtest")
@pytest.mark.asyncio
async def test_pending():
with TestClient(app) as client:
response = client.get("/pending")
assert response.status_code == 200
@pytest.mark.skipif(is_regtest, reason="regtest")
@pytest.mark.asyncio
async def test_receive_all(wallet: Wallet):
with TestClient(app) as client:
response = client.post("/receive?all=true")
assert response.status_code == 200
assert response.json()["initial_balance"]
assert response.json()["balance"]
@pytest.mark.skipif(is_regtest, reason="regtest")
@pytest.mark.asyncio
async def test_burn_all(wallet: Wallet):
with TestClient(app) as client:
response = client.post("/send?amount=20")
assert response.status_code == 200
response = client.post("/burn?all=true")
assert response.status_code == 200
assert response.json()["balance"]
@pytest.mark.skipif(is_regtest, reason="regtest")
@pytest.mark.asyncio
async def test_pay():
with TestClient(app) as client:
invoice = (
"lnbc100n1pjjcqzfdq4gdshx6r4ypjx2ur0wd5hgpp58xvj8yn00d5"
"7uhshwzcwgy9uj3vwf5y2lr5fjf78s4w9l4vhr6xssp5stezsyty9r"
"hv3lat69g4mhqxqun56jyehhkq3y8zufh83xyfkmmq4usaqwrt5q4f"
"adm44g6crckp0hzvuyv9sja7t65hxj0ucf9y46qstkay7gfnwhuxgr"
"krf7djs38rml39l8wpn5ug9shp3n55quxhdecqfwxg23"
)
response = client.post(f"/lightning/pay_invoice?bolt11={invoice}")
assert response.status_code == 200
@pytest.mark.skipif(is_regtest, reason="regtest")
@pytest.mark.asyncio
async def test_lock():
with TestClient(app) as client:
response = client.get("/lock")
assert response.status_code == 200
@pytest.mark.skipif(is_regtest, reason="regtest")
@pytest.mark.asyncio
async def test_locks():
with TestClient(app) as client:
response = client.get("/locks")
assert response.status_code == 200
@pytest.mark.skipif(is_regtest, reason="regtest")
@pytest.mark.asyncio
async def test_invoices():
with TestClient(app) as client:
response = client.get("/invoices")
assert response.status_code == 200
@pytest.mark.skipif(is_regtest, reason="regtest")
@pytest.mark.asyncio
async def test_wallets():
with TestClient(app) as client:
response = client.get("/wallets")
assert response.status_code == 200
@pytest.mark.skipif(is_regtest, reason="regtest")
@pytest.mark.asyncio
async def test_info():
with TestClient(app) as client:
response = client.get("/info")
assert response.status_code == 200
assert response.json()["version"]
@pytest.mark.skipif(is_regtest, reason="regtest")
@pytest.mark.asyncio
async def test_flow(wallet: Wallet):
with TestClient(app) as client:
response = client.get("/balance")
initial_balance = response.json()["balance"]
response = client.post("/lightning/create_invoice?amount=100")
invoice_response = InvoiceResponse.parse_obj(response.json())
state = PaymentStatus(result=PaymentResult.PENDING)
while state.pending:
print("checking invoice state")
response2 = client.get(
f"/lightning/invoice_state?payment_request={invoice_response.payment_request}"
)
state = PaymentStatus.parse_obj(response2.json())
await asyncio.sleep(0.1)
print("state:", state)
response = client.get("/balance")
assert response.json()["balance"] == initial_balance + 100
response = client.post("/send?amount=50")
response = client.get("/balance")
assert response.json()["balance"] == initial_balance + 50
response = client.post("/send?amount=50")
response = client.get("/balance")
assert response.json()["balance"] == initial_balance
response = client.get("/pending")
token = response.json()["pending_token"]["0"]["token"]
amount = response.json()["pending_token"]["0"]["amount"]
response = client.post(f"/receive?token={token}")
response = client.get("/balance")
assert response.json()["balance"] == initial_balance + amount

View File

@@ -115,7 +115,7 @@ def test_balance(cli_prefix):
print("------ BALANCE ------") print("------ BALANCE ------")
print(result.output) print(result.output)
w = asyncio.run(init_wallet()) w = asyncio.run(init_wallet())
assert f"Balance: {w.available_balance} sat" in result.output assert f"Balance: {w.available_balance}" in result.output
assert result.exit_code == 0 assert result.exit_code == 0