enclava/backend/app/services/llm/providers/base.py
"""
Base LLM Provider Interface
Abstract base class for all LLM providers.
"""
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, AsyncGenerator
import logging
from ..models import (
ChatRequest, ChatResponse, EmbeddingRequest, EmbeddingResponse,
ModelInfo, ProviderStatus
)
from ..config import ProviderConfig
logger = logging.getLogger(__name__)


class BaseLLMProvider(ABC):
    """Abstract base class for LLM providers"""

    def __init__(self, config: ProviderConfig, api_key: str):
        """
        Initialize provider

        Args:
            config: Provider configuration
            api_key: Decrypted API key for the provider
        """
        self.config = config
        self.api_key = api_key
        self.name = config.name
        self._session = None
        logger.info(f"Initializing {self.name} provider")

    @property
    @abstractmethod
    def provider_name(self) -> str:
        """Get provider name"""
        pass

    @abstractmethod
    async def health_check(self) -> ProviderStatus:
        """
        Check provider health status

        Returns:
            ProviderStatus with current health information
        """
        pass

    @abstractmethod
    async def get_models(self) -> List[ModelInfo]:
        """
        Get list of available models

        Returns:
            List of available models with their capabilities
        """
        pass

    @abstractmethod
    async def create_chat_completion(self, request: ChatRequest) -> ChatResponse:
        """
        Create chat completion

        Args:
            request: Chat completion request

        Returns:
            Chat completion response

        Raises:
            ProviderError: If a provider-specific error occurs
            SecurityError: If security validation fails
            ValidationError: If request validation fails
        """
        pass

    @abstractmethod
    async def create_chat_completion_stream(self, request: ChatRequest) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Create streaming chat completion

        Args:
            request: Chat completion request with stream=True

        Yields:
            Streaming response chunks

        Raises:
            ProviderError: If a provider-specific error occurs
            SecurityError: If security validation fails
            ValidationError: If request validation fails
        """
        pass

    @abstractmethod
    async def create_embedding(self, request: EmbeddingRequest) -> EmbeddingResponse:
        """
        Create embeddings

        Args:
            request: Embedding generation request

        Returns:
            Embedding response

        Raises:
            ProviderError: If a provider-specific error occurs
            SecurityError: If security validation fails
            ValidationError: If request validation fails
        """
        pass

    async def initialize(self):
        """Initialize provider resources (override if needed)"""
        pass

    async def cleanup(self):
        """Clean up provider resources"""
        if self._session and hasattr(self._session, 'close'):
            await self._session.close()
            logger.debug(f"Cleaned up session for {self.name} provider")

    def supports_model(self, model_name: str) -> bool:
        """Check if provider supports a specific model"""
        return model_name in self.config.supported_models

    def supports_capability(self, capability: str) -> bool:
        """Check if provider supports a specific capability"""
        return capability in self.config.capabilities

    def get_model_info(self, model_name: str) -> Optional[ModelInfo]:
        """Get information about a specific model (override for provider-specific info)"""
        if not self.supports_model(model_name):
            return None
        return ModelInfo(
            id=model_name,
            object="model",
            owned_by=self.name,
            provider=self.name,
            capabilities=self.config.capabilities,
            context_window=self.config.max_context_window,
            max_output_tokens=self.config.max_output_tokens,
            supports_streaming=self.config.supports_streaming,
            supports_function_calling=self.config.supports_function_calling
        )

    def _validate_request(self, request: Any):
        """Base request validation (override for provider-specific validation)"""
        if hasattr(request, 'model') and not self.supports_model(request.model):
            from ..exceptions import ValidationError
            raise ValidationError(
                f"Model '{request.model}' not supported by provider '{self.name}'",
                field="model"
            )

    def _create_headers(self, additional_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:
        """Create HTTP headers for requests"""
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
            "User-Agent": f"Enclava-LLM-Service/{self.name}"
        }
        if additional_headers:
            headers.update(additional_headers)
        return headers

    def _handle_http_error(self, status_code: int, response_text: str, provider_context: str = ""):
        """Handle HTTP errors consistently across providers"""
        from ..exceptions import ProviderError, RateLimitError, ValidationError

        context = f"{self.name} {provider_context}".strip()

        if status_code == 401:
            raise ProviderError(
                f"Authentication failed for {context}",
                provider=self.name,
                error_code="AUTHENTICATION_ERROR",
                details={"status_code": status_code, "response": response_text}
            )
        elif status_code == 403:
            raise ProviderError(
                f"Access forbidden for {context}",
                provider=self.name,
                error_code="AUTHORIZATION_ERROR",
                details={"status_code": status_code, "response": response_text}
            )
        elif status_code == 429:
            raise RateLimitError(
                f"Rate limit exceeded for {context}",
                error_code="RATE_LIMIT_ERROR",
                details={"status_code": status_code, "response": response_text, "provider": self.name}
            )
        elif status_code == 400:
            raise ValidationError(
                f"Bad request for {context}: {response_text}",
                error_code="BAD_REQUEST",
                details={"status_code": status_code, "response": response_text}
            )
        elif 500 <= status_code < 600:
            raise ProviderError(
                f"Server error for {context}: {response_text}",
                provider=self.name,
                error_code="SERVER_ERROR",
                details={"status_code": status_code, "response": response_text}
            )
        else:
            raise ProviderError(
                f"HTTP error {status_code} for {context}: {response_text}",
                provider=self.name,
                error_code="HTTP_ERROR",
                details={"status_code": status_code, "response": response_text}
            )

    def __str__(self) -> str:
        return f"{self.__class__.__name__}(name={self.name})"

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(name={self.name}, enabled={self.config.enabled})"