Files
nutshell/cashu/mint/middleware.py
ok300 ee90d840ab Add HTTP compression middleware (#676)
* Add HTTP compression middleware

* Apply fixes from `make format`
2024-11-25 13:08:46 +01:00

102 lines
3.3 KiB
Python

import gzip
import zlib
import brotli
import zstandard as zstd
from fastapi import FastAPI, Request, Response
from fastapi.exception_handlers import (
request_validation_exception_handler as _request_validation_exception_handler,
)
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from loguru import logger
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.cors import CORSMiddleware
from ..core.settings import settings
from .limit import _rate_limit_exceeded_handler, limiter_global
if settings.debug_profiling:
from fastapi_profiler import PyInstrumentProfilerMiddleware
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
def add_middlewares(app: FastAPI):
    """Register the HTTP middlewares used by the mint on ``app``.

    CORS (fully permissive) and response compression are always added.
    The profiler middleware is added only when ``settings.debug_profiling``
    is set, and the rate-limiting middleware (with its limiter and exception
    handler) only when ``settings.mint_rate_limit`` is set.
    """
    cors_options = {
        "allow_origins": ["*"],
        "allow_methods": ["*"],
        "allow_headers": ["*"],
        "expose_headers": ["*"],
    }
    app.add_middleware(CORSMiddleware, **cors_options)
    app.add_middleware(CompressionMiddleware)

    if settings.debug_profiling:
        # The profiler class is only imported when profiling is enabled.
        assert PyInstrumentProfilerMiddleware is not None
        app.add_middleware(PyInstrumentProfilerMiddleware)

    if not settings.mint_rate_limit:
        return

    # Rate limiting needs the limiter on the app state, a handler for the
    # "limit exceeded" error, and the middleware itself.
    app.state.limiter = limiter_global
    app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
    app.add_middleware(SlowAPIMiddleware)
async def request_validation_exception_handler(
    request: Request, exc: RequestValidationError
) -> JSONResponse:
    """
    Wrap FastAPI's default RequestValidationError handler with logging.

    Called when client input fails validation. The validation errors and the
    request's query parameters are logged, then the response is produced by
    FastAPI's default handler.

    Args:
        request: The incoming request that failed validation.
        exc: The validation error raised for that request.

    Returns:
        The response built by FastAPI's default validation error handler.
    """
    # Use the public mapping interface instead of the private ``_dict``
    # attribute, which is an implementation detail of the query-params object.
    query_params = dict(request.query_params)
    detail = {
        "errors": exc.errors(),
        "query_params": query_params,
    }
    # log the error
    logger.error(detail)
    # pass on
    return await _request_validation_exception_handler(request, exc)
class CompressionMiddleware(BaseHTTPMiddleware):
    """Compress response bodies according to the request's Accept-Encoding.

    The response body is buffered in full, compressed with the first coding
    from ``_CODING_PREFERENCE`` that the client accepts, and returned as a
    new ``Response``. Bodies that are empty or already carry a
    ``Content-Encoding`` header are passed through uncompressed.
    """

    # Server-side preference order among the supported content codings.
    _CODING_PREFERENCE = ("br", "zstd", "gzip", "deflate")

    @staticmethod
    def _accepted_codings(accept_encoding):
        """Return the lowercase codings in ``accept_encoding`` with q > 0.

        A coding whose q-value is ``0`` (explicitly refused) or cannot be
        parsed as a number is left out of the result.
        """
        accepted = set()
        for item in accept_encoding.split(","):
            coding, _, params = item.partition(";")
            coding = coding.strip().lower()
            if not coding:
                continue
            weight = 1.0
            for param in params.split(";"):
                key, _, value = param.partition("=")
                if key.strip().lower() == "q":
                    try:
                        weight = float(value.strip())
                    except ValueError:
                        # Malformed weight: do not use this coding.
                        weight = 0.0
            if weight > 0:
                accepted.add(coding)
        return accepted

    @staticmethod
    def _compress(body, coding):
        """Compress ``body`` with ``coding`` (one of ``_CODING_PREFERENCE``)."""
        if coding == "br":
            return brotli.compress(body)
        if coding == "zstd":
            return zstd.ZstdCompressor().compress(body)
        if coding == "gzip":
            return gzip.compress(body)
        # HTTP "deflate" is the zlib container format, which zlib.compress emits.
        return zlib.compress(body)

    async def dispatch(self, request: Request, call_next):
        """Run the downstream app and return its response, compressed if possible.

        Args:
            request: The incoming request; its ``Accept-Encoding`` header
                selects the content coding.
            call_next: Callable that forwards the request downstream and
                returns the response.

        Returns:
            The downstream response untouched when its class is named
            ``StreamingResponse``; otherwise a new ``Response`` with the same
            status code and headers and the (possibly compressed) body.
        """
        response = await call_next(request)

        # Handle streaming responses differently
        # NOTE(review): this matches on the exact class name, so whether the
        # object returned by call_next ever matches depends on the Starlette
        # version in use - confirm against the pinned version.
        if response.__class__.__name__ == "StreamingResponse":
            return response

        # Join once instead of repeated bytes concatenation.
        body = b"".join([chunk async for chunk in response.body_iterator])

        coding = None
        # Never compress an empty body (it would turn e.g. a 204/304 into a
        # response with a payload) and never re-compress encoded content.
        if body and "content-encoding" not in response.headers:
            accepted = self._accepted_codings(
                request.headers.get("Accept-Encoding", "")
            )
            coding = next(
                (name for name in self._CODING_PREFERENCE if name in accepted),
                None,
            )

        if coding is not None:
            body = self._compress(body, coding)
            response.headers["Content-Encoding"] = coding

        if body:
            response.headers["Content-Length"] = str(len(body))
        # Merge into any existing Vary value instead of overwriting it.
        response.headers.add_vary_header("Accept-Encoding")

        rebuilt = Response(
            content=body,
            status_code=response.status_code,
            media_type=response.media_type,
        )
        # Carry the raw header list over verbatim: converting the headers to
        # a dict would collapse repeated headers such as multiple Set-Cookie.
        rebuilt.raw_headers = list(response.headers.raw)
        return rebuilt