set context version in http middleware

This commit is contained in:
callebtc
2022-10-10 21:36:29 +02:00
parent 0cd24fcf31
commit 13b9cd17bd
3 changed files with 18 additions and 5 deletions

View File

@@ -8,12 +8,28 @@ from cashu.core.settings import DEBUG, VERSION
from starlette_context.middleware import RawContextMiddleware
from starlette.middleware import Middleware
from starlette_context import context
from .router import router
from .startup import load_ledger
from starlette.middleware.base import BaseHTTPMiddleware
class CustomHeaderMiddleware(BaseHTTPMiddleware):
"""
Middleware for starlette that can set the context from request headers
"""
async def dispatch(self, request, call_next):
context["version"] = request.headers.get("Client-version")
response = await call_next(request)
response.headers["Custom"] = "Example"
return response
def create_app(config_object="core.settings") -> FastAPI:
def configure_logger() -> None:
class Formatter:
@@ -53,6 +69,7 @@ def create_app(config_object="core.settings") -> FastAPI:
Middleware(
RawContextMiddleware,
),
Middleware(CustomHeaderMiddleware),
]
app = FastAPI(

View File

@@ -118,7 +118,7 @@ class Ledger:
# backwards compatibility with old hash_to_curve
# old clients do not send a version
if not context.get("version"):
if not context.get("client-version"):
return legacy.verify_pre_0_3_3(secret_key, C, proof.secret)
return b_dhke.verify(secret_key, C, proof.secret)

View File

@@ -77,8 +77,6 @@ async def melt(request: Request, payload: MeltRequest):
"""
Requests tokens to be destroyed and sent out via Lightning.
"""
context["version"] = request.headers.get("Client-version")
print(context["version"])
ok, preimage = await ledger.melt(payload.proofs, payload.invoice)
resp = GetMeltResponse(paid=ok, preimage=preimage)
return resp
@@ -107,8 +105,6 @@ async def split(request: Request, payload: SplitRequest):
Requetst a set of tokens with amount "total" to be split into two
newly minted sets with amount "split" and "total-split".
"""
context["version"] = request.headers.get("Client-version")
print(context["version"])
proofs = payload.proofs
amount = payload.amount
outputs = payload.outputs.blinded_messages if payload.outputs else None