Use PostRestoreRequest for all restore operations (#483)

* use PostRestoreRequest for all restore operations

* refactor: unit method verification
This commit is contained in:
callebtc
2024-03-21 22:59:47 +01:00
committed by GitHub
parent df2c81ee89
commit f4621345f3
8 changed files with 71 additions and 31 deletions

View File

@@ -535,6 +535,12 @@ class CheckFeesResponse_deprecated(BaseModel):
# ------- API: RESTORE ------- # ------- API: RESTORE -------
class PostRestoreRequest(BaseModel):
outputs: List[BlindedMessage] = Field(
..., max_items=settings.mint_max_request_length
)
class PostRestoreResponse(BaseModel): class PostRestoreResponse(BaseModel):
outputs: List[BlindedMessage] = [] outputs: List[BlindedMessage] = []
signatures: List[BlindedSignature] = [] signatures: List[BlindedSignature] = []

View File

@@ -325,8 +325,11 @@ class Ledger(LedgerVerification, LedgerSpendingConditions):
) )
if settings.mint_peg_out_only: if settings.mint_peg_out_only:
raise NotAllowedError("Mint does not allow minting new tokens.") raise NotAllowedError("Mint does not allow minting new tokens.")
unit = Unit[quote_request.unit]
method = Method.bolt11 unit, method = self._verify_and_get_unit_method(
quote_request.unit, Method.bolt11.name
)
if settings.mint_max_balance: if settings.mint_max_balance:
balance = await self.get_balance() balance = await self.get_balance()
if balance + quote_request.amount > settings.mint_max_balance: if balance + quote_request.amount > settings.mint_max_balance:
@@ -387,10 +390,10 @@ class Ledger(LedgerVerification, LedgerSpendingConditions):
MintQuote: Mint quote object. MintQuote: Mint quote object.
""" """
quote = await self.crud.get_mint_quote(quote_id=quote_id, db=self.db) quote = await self.crud.get_mint_quote(quote_id=quote_id, db=self.db)
assert quote, "quote not found" if not quote:
assert quote.method == Method.bolt11.name, "only bolt11 supported" raise Exception("quote not found")
unit = Unit[quote.unit]
method = Method[quote.method] unit, method = self._verify_and_get_unit_method(quote.unit, quote.method)
if not quote.paid: if not quote.paid:
assert quote.checking_id, "quote has no checking id" assert quote.checking_id, "quote has no checking id"
@@ -471,8 +474,10 @@ class Ledger(LedgerVerification, LedgerSpendingConditions):
Returns: Returns:
PostMeltQuoteResponse: Melt quote response. PostMeltQuoteResponse: Melt quote response.
""" """
unit = Unit[melt_quote.unit] unit, method = self._verify_and_get_unit_method(
method = Method.bolt11 melt_quote.unit, Method.bolt11.name
)
# NOTE: we normalize the request to lowercase to avoid case sensitivity # NOTE: we normalize the request to lowercase to avoid case sensitivity
# This works with Lightning but might not work with other methods # This works with Lightning but might not work with other methods
request = melt_quote.request.lower() request = melt_quote.request.lower()
@@ -557,10 +562,12 @@ class Ledger(LedgerVerification, LedgerSpendingConditions):
MeltQuote: Melt quote object. MeltQuote: Melt quote object.
""" """
melt_quote = await self.crud.get_melt_quote(quote_id=quote_id, db=self.db) melt_quote = await self.crud.get_melt_quote(quote_id=quote_id, db=self.db)
assert melt_quote, "quote not found" if not melt_quote:
assert melt_quote.method == Method.bolt11.name, "only bolt11 supported" raise Exception("quote not found")
unit = Unit[melt_quote.unit]
method = Method[melt_quote.method] unit, method = self._verify_and_get_unit_method(
melt_quote.unit, melt_quote.method
)
# we only check the state with the backend if there is no associated internal # we only check the state with the backend if there is no associated internal
# mint quote for this melt quote # mint quote for this melt quote
@@ -664,8 +671,11 @@ class Ledger(LedgerVerification, LedgerSpendingConditions):
""" """
# get melt quote and check if it was already paid # get melt quote and check if it was already paid
melt_quote = await self.get_melt_quote(quote_id=quote) melt_quote = await self.get_melt_quote(quote_id=quote)
method = Method[melt_quote.method]
unit = Unit[melt_quote.unit] unit, method = self._verify_and_get_unit_method(
melt_quote.unit, melt_quote.method
)
assert not melt_quote.paid, "melt quote already paid" assert not melt_quote.paid, "melt quote already paid"
# make sure that the outputs (for fee return) are in the same unit as the quote # make sure that the outputs (for fee return) are in the same unit as the quote

View File

@@ -1,6 +1,6 @@
from typing import Dict, Protocol from typing import Dict, Mapping, Protocol
from ..core.base import MintKeyset, Unit from ..core.base import Method, MintKeyset, Unit
from ..core.db import Database from ..core.db import Database
from ..lightning.base import LightningBackend from ..lightning.base import LightningBackend
from ..mint.crud import LedgerCrud from ..mint.crud import LedgerCrud
@@ -11,8 +11,8 @@ class SupportsKeysets(Protocol):
keysets: Dict[str, MintKeyset] keysets: Dict[str, MintKeyset]
class SupportLightning(Protocol): class SupportsBackends(Protocol):
lightning: Dict[Unit, LightningBackend] backends: Mapping[Method, Mapping[Unit, LightningBackend]] = {}
class SupportsDb(Protocol): class SupportsDb(Protocol):

View File

@@ -20,6 +20,7 @@ from ..core.base import (
PostMintQuoteResponse, PostMintQuoteResponse,
PostMintRequest, PostMintRequest,
PostMintResponse, PostMintResponse,
PostRestoreRequest,
PostRestoreResponse, PostRestoreResponse,
PostSplitRequest, PostSplitRequest,
PostSplitResponse, PostSplitResponse,
@@ -358,14 +359,14 @@ async def check_state(
@router.post( @router.post(
"/v1/restore", "/v1/restore",
name="Restore", name="Restore",
summary="Restores a blinded signature from a secret", summary="Restores blind signature for a set of outputs.",
response_model=PostRestoreResponse, response_model=PostRestoreResponse,
response_description=( response_description=(
"Two lists with the first being the list of the provided outputs that " "Two lists with the first being the list of the provided outputs that "
"have an associated blinded signature which is given in the second list." "have an associated blinded signature which is given in the second list."
), ),
) )
async def restore(payload: PostMintRequest) -> PostRestoreResponse: async def restore(payload: PostRestoreRequest) -> PostRestoreResponse:
assert payload.outputs, Exception("no outputs provided.") assert payload.outputs, Exception("no outputs provided.")
outputs, signatures = await ledger.restore(payload.outputs) outputs, signatures = await ledger.restore(payload.outputs)
return PostRestoreResponse(outputs=outputs, signatures=signatures) return PostRestoreResponse(outputs=outputs, signatures=signatures)

View File

@@ -19,6 +19,7 @@ from ..core.base import (
PostMintQuoteRequest, PostMintQuoteRequest,
PostMintRequest_deprecated, PostMintRequest_deprecated,
PostMintResponse_deprecated, PostMintResponse_deprecated,
PostRestoreRequest,
PostRestoreResponse, PostRestoreResponse,
PostSplitRequest_Deprecated, PostSplitRequest_Deprecated,
PostSplitResponse_Deprecated, PostSplitResponse_Deprecated,
@@ -357,7 +358,7 @@ async def check_spendable_deprecated(
), ),
deprecated=True, deprecated=True,
) )
async def restore(payload: PostMintRequest_deprecated) -> PostRestoreResponse: async def restore(payload: PostRestoreRequest) -> PostRestoreResponse:
assert payload.outputs, Exception("no outputs provided.") assert payload.outputs, Exception("no outputs provided.")
outputs, promises = await ledger.restore(payload.outputs) outputs, promises = await ledger.restore(payload.outputs)
return PostRestoreResponse(outputs=outputs, signatures=promises) return PostRestoreResponse(outputs=outputs, signatures=promises)

View File

@@ -1,12 +1,14 @@
from typing import Dict, List, Literal, Optional, Union from typing import Dict, List, Literal, Optional, Tuple, Union
from loguru import logger from loguru import logger
from ..core.base import ( from ..core.base import (
BlindedMessage, BlindedMessage,
BlindedSignature, BlindedSignature,
Method,
MintKeyset, MintKeyset,
Proof, Proof,
Unit,
) )
from ..core.crypto import b_dhke from ..core.crypto import b_dhke
from ..core.crypto.secp import PublicKey from ..core.crypto.secp import PublicKey
@@ -19,12 +21,15 @@ from ..core.errors import (
TransactionError, TransactionError,
) )
from ..core.settings import settings from ..core.settings import settings
from ..lightning.base import LightningBackend
from ..mint.crud import LedgerCrud from ..mint.crud import LedgerCrud
from .conditions import LedgerSpendingConditions from .conditions import LedgerSpendingConditions
from .protocols import SupportsDb, SupportsKeysets from .protocols import SupportsBackends, SupportsDb, SupportsKeysets
class LedgerVerification(LedgerSpendingConditions, SupportsKeysets, SupportsDb): class LedgerVerification(
LedgerSpendingConditions, SupportsKeysets, SupportsDb, SupportsBackends
):
"""Verification functions for the ledger.""" """Verification functions for the ledger."""
keyset: MintKeyset keyset: MintKeyset
@@ -32,6 +37,7 @@ class LedgerVerification(LedgerSpendingConditions, SupportsKeysets, SupportsDb):
spent_proofs: Dict[str, Proof] spent_proofs: Dict[str, Proof]
crud: LedgerCrud crud: LedgerCrud
db: Database db: Database
lightning: Dict[Unit, LightningBackend]
async def verify_inputs_and_outputs( async def verify_inputs_and_outputs(
self, *, proofs: List[Proof], outputs: Optional[List[BlindedMessage]] = None self, *, proofs: List[Proof], outputs: Optional[List[BlindedMessage]] = None
@@ -240,6 +246,22 @@ class LedgerVerification(LedgerSpendingConditions, SupportsKeysets, SupportsDb):
""" """
sum_inputs = sum(self._verify_amount(p.amount) for p in proofs) sum_inputs = sum(self._verify_amount(p.amount) for p in proofs)
sum_outputs = sum(self._verify_amount(p.amount) for p in outs) sum_outputs = sum(self._verify_amount(p.amount) for p in outs)
assert sum_outputs - sum_inputs == 0, TransactionError( if not sum_outputs - sum_inputs == 0:
"inputs do not have same amount as outputs." raise TransactionError("inputs do not have same amount as outputs.")
)
def _verify_and_get_unit_method(
self, unit_str: str, method_str: str
) -> Tuple[Unit, Method]:
"""Verify that the unit is supported by the ledger."""
method = Method[method_str]
unit = Unit[unit_str]
if not any([unit == k.unit for k in self.keysets.values()]):
raise NotAllowedError(f"unit '{unit.name}' not supported in any keyset.")
if not self.backends.get(method) or unit not in self.backends[method]:
raise NotAllowedError(
f"no support for method '{method.name}' with unit '{unit.name}'."
)
return unit, method

View File

@@ -8,7 +8,7 @@ from cashu.core.base import (
MintMeltMethodSetting, MintMeltMethodSetting,
PostCheckStateRequest, PostCheckStateRequest,
PostCheckStateResponse, PostCheckStateResponse,
PostMintRequest, PostRestoreRequest,
PostRestoreResponse, PostRestoreResponse,
SpentState, SpentState,
) )
@@ -430,7 +430,7 @@ async def test_api_restore(ledger: Ledger, wallet: Wallet):
) )
outputs, rs = wallet._construct_outputs([64], secrets, rs) outputs, rs = wallet._construct_outputs([64], secrets, rs)
payload = PostMintRequest(outputs=outputs, quote="placeholder") payload = PostRestoreRequest(outputs=outputs)
response = httpx.post( response = httpx.post(
f"{BASE_URL}/v1/restore", f"{BASE_URL}/v1/restore",
json=payload.dict(), json=payload.dict(),

View File

@@ -5,7 +5,7 @@ import pytest_asyncio
from cashu.core.base import ( from cashu.core.base import (
CheckSpendableRequest_deprecated, CheckSpendableRequest_deprecated,
CheckSpendableResponse_deprecated, CheckSpendableResponse_deprecated,
PostMintRequest, PostRestoreRequest,
PostRestoreResponse, PostRestoreResponse,
Proof, Proof,
) )
@@ -340,7 +340,7 @@ async def test_api_restore(ledger: Ledger, wallet: Wallet):
) )
outputs, rs = wallet._construct_outputs([64], secrets, rs) outputs, rs = wallet._construct_outputs([64], secrets, rs)
payload = PostMintRequest(outputs=outputs, quote="placeholder") payload = PostRestoreRequest(outputs=outputs)
response = httpx.post( response = httpx.post(
f"{BASE_URL}/restore", f"{BASE_URL}/restore",
json=payload.dict(), json=payload.dict(),