diff --git a/cashu/mint/app.py b/cashu/mint/app.py index 900d216..940da73 100644 --- a/cashu/mint/app.py +++ b/cashu/mint/app.py @@ -8,12 +8,28 @@ from cashu.core.settings import DEBUG, VERSION from starlette_context.middleware import RawContextMiddleware from starlette.middleware import Middleware +from starlette_context import context from .router import router from .startup import load_ledger +from starlette.middleware.base import BaseHTTPMiddleware + + +class CustomHeaderMiddleware(BaseHTTPMiddleware): + """ + Middleware for starlette that can set the context from request headers + """ + + async def dispatch(self, request, call_next): + context["version"] = request.headers.get("Client-version") + response = await call_next(request) + response.headers["Custom"] = "Example" + return response + + def create_app(config_object="core.settings") -> FastAPI: def configure_logger() -> None: class Formatter: @@ -53,6 +69,7 @@ def create_app(config_object="core.settings") -> FastAPI: Middleware( RawContextMiddleware, ), + Middleware(CustomHeaderMiddleware), ] app = FastAPI( diff --git a/cashu/mint/ledger.py b/cashu/mint/ledger.py index cd77872..874d64f 100644 --- a/cashu/mint/ledger.py +++ b/cashu/mint/ledger.py @@ -118,7 +118,7 @@ class Ledger: # backwards compatibility with old hash_to_curve # old clients do not send a version - if not context.get("version"): + if not context.get("client-version"): return legacy.verify_pre_0_3_3(secret_key, C, proof.secret) return b_dhke.verify(secret_key, C, proof.secret) diff --git a/cashu/mint/router.py b/cashu/mint/router.py index 8d95062..db64ea0 100644 --- a/cashu/mint/router.py +++ b/cashu/mint/router.py @@ -77,8 +77,6 @@ async def melt(request: Request, payload: MeltRequest): """ Requests tokens to be destroyed and sent out via Lightning. """ - context["version"] = request.headers.get("Client-version") - print(context["version"]) ok, preimage = await ledger.melt(payload.proofs, payload.invoice) resp = GetMeltResponse(paid=ok, preimage=preimage) return resp @@ -107,8 +105,6 @@ async def split(request: Request, payload: SplitRequest): Requetst a set of tokens with amount "total" to be split into two newly minted sets with amount "split" and "total-split". """ - context["version"] = request.headers.get("Client-version") - print(context["version"]) proofs = payload.proofs amount = payload.amount outputs = payload.outputs.blinded_messages if payload.outputs else None