mirror of
https://github.com/aljazceru/nutshell.git
synced 2025-12-24 03:54:21 +01:00
Add HTTP compression middleware (#676)
* Add HTTP compression middleware * Apply fixes from `make format`
This commit is contained in:
@@ -1,12 +1,17 @@
|
||||
from fastapi import FastAPI
|
||||
import gzip
|
||||
import zlib
|
||||
|
||||
import brotli
|
||||
import zstandard as zstd
|
||||
from fastapi import FastAPI, Request, Response
|
||||
from fastapi.exception_handlers import (
|
||||
request_validation_exception_handler as _request_validation_exception_handler,
|
||||
)
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
from loguru import logger
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from starlette.requests import Request
|
||||
|
||||
from ..core.settings import settings
|
||||
from .limit import _rate_limit_exceeded_handler, limiter_global
|
||||
@@ -26,6 +31,7 @@ def add_middlewares(app: FastAPI):
|
||||
allow_headers=["*"],
|
||||
expose_headers=["*"],
|
||||
)
|
||||
app.add_middleware(CompressionMiddleware)
|
||||
|
||||
if settings.debug_profiling:
|
||||
assert PyInstrumentProfilerMiddleware is not None
|
||||
@@ -53,3 +59,43 @@ async def request_validation_exception_handler(
|
||||
logger.error(detail)
|
||||
# pass on
|
||||
return await _request_validation_exception_handler(request, exc)
|
||||
|
||||
|
||||
class CompressionMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
response = await call_next(request)
|
||||
|
||||
# Handle streaming responses differently
|
||||
if response.__class__.__name__ == 'StreamingResponse':
|
||||
return response
|
||||
|
||||
response_body = b''
|
||||
async for chunk in response.body_iterator:
|
||||
response_body += chunk
|
||||
|
||||
accept_encoding = request.headers.get("Accept-Encoding", "")
|
||||
content = response_body
|
||||
|
||||
if "br" in accept_encoding:
|
||||
content = brotli.compress(content)
|
||||
response.headers["Content-Encoding"] = "br"
|
||||
elif "zstd" in accept_encoding:
|
||||
compressor = zstd.ZstdCompressor()
|
||||
content = compressor.compress(content)
|
||||
response.headers["Content-Encoding"] = "zstd"
|
||||
elif "gzip" in accept_encoding:
|
||||
content = gzip.compress(content)
|
||||
response.headers["Content-Encoding"] = "gzip"
|
||||
elif "deflate" in accept_encoding:
|
||||
content = zlib.compress(content)
|
||||
response.headers["Content-Encoding"] = "deflate"
|
||||
|
||||
response.headers["Content-Length"] = str(len(content))
|
||||
response.headers["Vary"] = "Accept-Encoding"
|
||||
|
||||
return Response(
|
||||
content=content,
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers),
|
||||
media_type=response.media_type
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user