Add HTTP compression middleware (#676)

* Add HTTP compression middleware

* Apply fixes from `make format`
This commit is contained in:
ok300
2024-11-25 12:08:46 +00:00
committed by GitHub
parent 2b233fd67e
commit ee90d840ab
3 changed files with 298 additions and 4 deletions

View File

@@ -1,12 +1,17 @@
from fastapi import FastAPI
import gzip
import zlib
import brotli
import zstandard as zstd
from fastapi import FastAPI, Request, Response
from fastapi.exception_handlers import (
request_validation_exception_handler as _request_validation_exception_handler,
)
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from loguru import logger
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.cors import CORSMiddleware
from starlette.requests import Request
from ..core.settings import settings
from .limit import _rate_limit_exceeded_handler, limiter_global
@@ -26,6 +31,7 @@ def add_middlewares(app: FastAPI):
allow_headers=["*"],
expose_headers=["*"],
)
app.add_middleware(CompressionMiddleware)
if settings.debug_profiling:
assert PyInstrumentProfilerMiddleware is not None
@@ -53,3 +59,43 @@ async def request_validation_exception_handler(
logger.error(detail)
# pass on
return await _request_validation_exception_handler(request, exc)
class CompressionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
response = await call_next(request)
# Handle streaming responses differently
if response.__class__.__name__ == 'StreamingResponse':
return response
response_body = b''
async for chunk in response.body_iterator:
response_body += chunk
accept_encoding = request.headers.get("Accept-Encoding", "")
content = response_body
if "br" in accept_encoding:
content = brotli.compress(content)
response.headers["Content-Encoding"] = "br"
elif "zstd" in accept_encoding:
compressor = zstd.ZstdCompressor()
content = compressor.compress(content)
response.headers["Content-Encoding"] = "zstd"
elif "gzip" in accept_encoding:
content = gzip.compress(content)
response.headers["Content-Encoding"] = "gzip"
elif "deflate" in accept_encoding:
content = zlib.compress(content)
response.headers["Content-Encoding"] = "deflate"
response.headers["Content-Length"] = str(len(content))
response.headers["Vary"] = "Accept-Encoding"
return Response(
content=content,
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.media_type
)