mirror of
https://github.com/aljazceru/enclava.git
synced 2025-12-17 23:44:24 +01:00
clean commit
This commit is contained in:
160
backend/app/api/v1/openai_compat.py
Normal file
160
backend/app/api/v1/openai_compat.py
Normal file
@@ -0,0 +1,160 @@
|
||||
"""
|
||||
OpenAI-compatible API endpoints
|
||||
Following the exact OpenAI API specification for compatibility with OpenAI clients
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, List
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.database import get_db
|
||||
from app.api.v1.llm import (
|
||||
get_auth_context, get_cached_models, ModelsResponse, ModelInfo,
|
||||
ChatCompletionRequest, EmbeddingRequest, create_chat_completion as llm_chat_completion,
|
||||
create_embedding as llm_create_embedding
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def openai_error_response(message: str, error_type: str = "invalid_request_error",
|
||||
status_code: int = 400, code: str = None):
|
||||
"""Create OpenAI-compatible error response"""
|
||||
error_data = {
|
||||
"error": {
|
||||
"message": message,
|
||||
"type": error_type,
|
||||
}
|
||||
}
|
||||
if code:
|
||||
error_data["error"]["code"] = code
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status_code,
|
||||
content=error_data
|
||||
)
|
||||
|
||||
|
||||
@router.get("/models", response_model=ModelsResponse)
|
||||
async def list_models(
|
||||
context: Dict[str, Any] = Depends(get_auth_context),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Lists the currently available models, and provides basic information about each one
|
||||
such as the owner and availability.
|
||||
|
||||
This endpoint follows the exact OpenAI API specification:
|
||||
GET /v1/models
|
||||
"""
|
||||
try:
|
||||
# Delegate to the existing LLM models endpoint
|
||||
from app.api.v1.llm import list_models as llm_list_models
|
||||
return await llm_list_models(context, db)
|
||||
except HTTPException as e:
|
||||
# Convert FastAPI HTTPException to OpenAI format
|
||||
if e.status_code == 401:
|
||||
return openai_error_response(
|
||||
"Invalid authentication credentials",
|
||||
"authentication_error",
|
||||
401
|
||||
)
|
||||
elif e.status_code == 403:
|
||||
return openai_error_response(
|
||||
"Insufficient permissions",
|
||||
"permission_error",
|
||||
403
|
||||
)
|
||||
else:
|
||||
return openai_error_response(str(e.detail), "api_error", e.status_code)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in OpenAI models endpoint: {e}")
|
||||
return openai_error_response(
|
||||
"Internal server error",
|
||||
"api_error",
|
||||
500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/chat/completions")
|
||||
async def create_chat_completion(
|
||||
request_body: Request,
|
||||
chat_request: ChatCompletionRequest,
|
||||
context: Dict[str, Any] = Depends(get_auth_context),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Create chat completion - OpenAI compatible endpoint
|
||||
|
||||
This endpoint follows the exact OpenAI API specification:
|
||||
POST /v1/chat/completions
|
||||
"""
|
||||
# Delegate to the existing LLM chat completions endpoint
|
||||
return await llm_chat_completion(request_body, chat_request, context, db)
|
||||
|
||||
|
||||
@router.post("/embeddings")
|
||||
async def create_embedding(
|
||||
request: EmbeddingRequest,
|
||||
context: Dict[str, Any] = Depends(get_auth_context),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Create embedding - OpenAI compatible endpoint
|
||||
|
||||
This endpoint follows the exact OpenAI API specification:
|
||||
POST /v1/embeddings
|
||||
"""
|
||||
# Delegate to the existing LLM embeddings endpoint
|
||||
return await llm_create_embedding(request, context, db)
|
||||
|
||||
|
||||
@router.get("/models/{model_id}")
|
||||
async def retrieve_model(
|
||||
model_id: str,
|
||||
context: Dict[str, Any] = Depends(get_auth_context),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Retrieve model information - OpenAI compatible endpoint
|
||||
|
||||
This endpoint follows the exact OpenAI API specification:
|
||||
GET /v1/models/{model}
|
||||
"""
|
||||
try:
|
||||
# Get all models and find the specific one
|
||||
models = await get_cached_models()
|
||||
|
||||
# Filter models based on API key permissions
|
||||
api_key = context.get("api_key")
|
||||
if api_key and api_key.allowed_models:
|
||||
models = [model for model in models if model.get("id") in api_key.allowed_models]
|
||||
|
||||
# Find the specific model
|
||||
model = next((m for m in models if m.get("id") == model_id), None)
|
||||
|
||||
if not model:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Model '{model_id}' not found"
|
||||
)
|
||||
|
||||
return ModelInfo(
|
||||
id=model.get("id", model_id),
|
||||
object="model",
|
||||
created=model.get("created", 0),
|
||||
owned_by=model.get("owned_by", "system")
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving model {model_id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve model information"
|
||||
)
|
||||
Reference in New Issue
Block a user