diff --git a/cashu/mint/crud.py b/cashu/mint/crud.py index 64cc5a1..6ff2a63 100644 --- a/cashu/mint/crud.py +++ b/cashu/mint/crud.py @@ -119,6 +119,7 @@ class LedgerCrud(ABC): @abstractmethod async def get_balance( self, + keyset: MintKeyset, db: Database, conn: Optional[Connection] = None, ) -> int: ... @@ -668,15 +669,22 @@ class LedgerCrudSqlite(LedgerCrud): async def get_balance( self, + keyset: MintKeyset, db: Database, conn: Optional[Connection] = None, ) -> int: row = await (conn or db).fetchone( f""" - SELECT * from {db.table_with_schema('balance')} - """ + SELECT balance FROM {db.table_with_schema('balance')} + WHERE keyset = :keyset + """, + { + "keyset": keyset.id, + }, ) - assert row, "Balance not found" + + if row is None: + return 0 # sqlalchemy index of first element key = next(iter(row)) diff --git a/cashu/mint/ledger.py b/cashu/mint/ledger.py index 8a89ab6..27e55e1 100644 --- a/cashu/mint/ledger.py +++ b/cashu/mint/ledger.py @@ -322,9 +322,9 @@ class Ledger(LedgerVerification, LedgerSpendingConditions, LedgerTasks, LedgerFe raise KeysetError("no public keys for this keyset") return {a: p.serialize().hex() for a, p in keyset.public_keys.items()} - async def get_balance(self) -> int: + async def get_balance(self, keyset: MintKeyset) -> int: """Returns the balance of the mint.""" - return await self.crud.get_balance(db=self.db) + return await self.crud.get_balance(keyset=keyset, db=self.db) # ------- ECASH ------- @@ -451,8 +451,13 @@ class Ledger(LedgerVerification, LedgerSpendingConditions, LedgerTasks, LedgerFe ): raise NotAllowedError("Backend does not support descriptions.") - if settings.mint_max_balance: - balance = await self.get_balance() + # MINT_MAX_BALANCE refers to sat (for now) + if settings.mint_max_balance and unit == Unit.sat: + # get next active keyset for unit + active_keyset: MintKeyset = next( + filter(lambda k: k.active and k.unit == unit, self.keysets.values()) + ) + balance = await self.get_balance(active_keyset) if balance + quote_request.amount > settings.mint_max_balance: raise NotAllowedError("Mint has reached maximum balance.") diff --git a/cashu/mint/migrations.py b/cashu/mint/migrations.py index b5bedd8..79f21dc 100644 --- a/cashu/mint/migrations.py +++ b/cashu/mint/migrations.py @@ -868,3 +868,43 @@ async def m025_add_amounts_to_keysets(db: Database): await conn.execute( f"UPDATE {db.table_with_schema('keysets')} SET amounts = '[]'" ) + + +async def m026_keyset_specific_balance_views(db: Database): + async with db.connect() as conn: + await drop_balance_views(db, conn) + await conn.execute( + f""" + CREATE VIEW {db.table_with_schema('balance_issued')} AS + SELECT id AS keyset, COALESCE(s, 0) AS balance FROM ( + SELECT id, SUM(amount) AS s + FROM {db.table_with_schema('promises')} + WHERE amount > 0 + GROUP BY id + ); + """ + ) + await conn.execute( + f""" + CREATE VIEW {db.table_with_schema('balance_redeemed')} AS + SELECT id AS keyset, COALESCE(s, 0) AS balance FROM ( + SELECT id, SUM(amount) AS s + FROM {db.table_with_schema('proofs_used')} + WHERE amount > 0 + GROUP BY id + ); + """ + ) + await conn.execute( + f""" + CREATE VIEW {db.table_with_schema('balance')} AS + SELECT keyset, s_issued - s_used AS balance FROM ( + SELECT bi.keyset AS keyset, + bi.balance AS s_issued, + COALESCE(bu.balance, 0) AS s_used + FROM {db.table_with_schema('balance_issued')} bi + LEFT OUTER JOIN {db.table_with_schema('balance_redeemed')} bu + ON bi.keyset = bu.keyset + ); + """ + ) diff --git a/tests/test_mint.py b/tests/test_mint.py index 9d09729..dc0b837 100644 --- a/tests/test_mint.py +++ b/tests/test_mint.py @@ -2,7 +2,7 @@ from typing import List import pytest -from cashu.core.base import BlindedMessage, Proof +from cashu.core.base import BlindedMessage, MintKeyset, Proof, Unit from cashu.core.crypto.b_dhke import step1_alice from cashu.core.helpers import calculate_number_of_blank_outputs from cashu.core.models import PostMintQuoteRequest @@ -218,7 +218,11 @@ async def test_generate_change_promises_returns_empty_if_no_outputs(ledger: Ledg @pytest.mark.asyncio async def test_get_balance(ledger: Ledger): - balance = await ledger.get_balance() + unit = Unit["sat"] + active_keyset: MintKeyset = next( + filter(lambda k: k.active and k.unit == unit, ledger.keysets.values()) + ) + balance = await ledger.get_balance(active_keyset) assert balance == 0