mirror of
https://github.com/aljazceru/enclava.git
synced 2025-12-17 07:24:34 +01:00
226 lines
7.5 KiB
Python
226 lines
7.5 KiB
Python
"""
|
|
Base LLM Provider Interface
|
|
|
|
Abstract base class for all LLM providers.
|
|
"""
|
|
|
|
from abc import ABC, abstractmethod
|
|
from typing import List, Dict, Any, Optional, AsyncGenerator
|
|
import logging
|
|
|
|
from ..models import (
|
|
ChatRequest, ChatResponse, EmbeddingRequest, EmbeddingResponse,
|
|
ModelInfo, ProviderStatus
|
|
)
|
|
from ..config import ProviderConfig
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class BaseLLMProvider(ABC):
|
|
"""Abstract base class for LLM providers"""
|
|
|
|
def __init__(self, config: ProviderConfig, api_key: str):
|
|
"""
|
|
Initialize provider
|
|
|
|
Args:
|
|
config: Provider configuration
|
|
api_key: Decrypted API key for the provider
|
|
"""
|
|
self.config = config
|
|
self.api_key = api_key
|
|
self.name = config.name
|
|
self._session = None
|
|
|
|
logger.info(f"Initializing {self.name} provider")
|
|
|
|
@property
|
|
@abstractmethod
|
|
def provider_name(self) -> str:
|
|
"""Get provider name"""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def health_check(self) -> ProviderStatus:
|
|
"""
|
|
Check provider health status
|
|
|
|
Returns:
|
|
ProviderStatus with current health information
|
|
"""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def get_models(self) -> List[ModelInfo]:
|
|
"""
|
|
Get list of available models
|
|
|
|
Returns:
|
|
List of available models with their capabilities
|
|
"""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def create_chat_completion(self, request: ChatRequest) -> ChatResponse:
|
|
"""
|
|
Create chat completion
|
|
|
|
Args:
|
|
request: Chat completion request
|
|
|
|
Returns:
|
|
Chat completion response
|
|
|
|
Raises:
|
|
ProviderError: If provider-specific error occurs
|
|
SecurityError: If security validation fails
|
|
ValidationError: If request validation fails
|
|
"""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def create_chat_completion_stream(self, request: ChatRequest) -> AsyncGenerator[Dict[str, Any], None]:
|
|
"""
|
|
Create streaming chat completion
|
|
|
|
Args:
|
|
request: Chat completion request with stream=True
|
|
|
|
Yields:
|
|
Streaming response chunks
|
|
|
|
Raises:
|
|
ProviderError: If provider-specific error occurs
|
|
SecurityError: If security validation fails
|
|
ValidationError: If request validation fails
|
|
"""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def create_embedding(self, request: EmbeddingRequest) -> EmbeddingResponse:
|
|
"""
|
|
Create embeddings
|
|
|
|
Args:
|
|
request: Embedding generation request
|
|
|
|
Returns:
|
|
Embedding response
|
|
|
|
Raises:
|
|
ProviderError: If provider-specific error occurs
|
|
SecurityError: If security validation fails
|
|
ValidationError: If request validation fails
|
|
"""
|
|
pass
|
|
|
|
async def initialize(self):
|
|
"""Initialize provider resources (override if needed)"""
|
|
pass
|
|
|
|
async def cleanup(self):
|
|
"""Cleanup provider resources"""
|
|
if self._session and hasattr(self._session, 'close'):
|
|
await self._session.close()
|
|
logger.debug(f"Cleaned up session for {self.name} provider")
|
|
|
|
def supports_model(self, model_name: str) -> bool:
|
|
"""Check if provider supports a specific model"""
|
|
return model_name in self.config.supported_models
|
|
|
|
def supports_capability(self, capability: str) -> bool:
|
|
"""Check if provider supports a specific capability"""
|
|
return capability in self.config.capabilities
|
|
|
|
def get_model_info(self, model_name: str) -> Optional[ModelInfo]:
|
|
"""Get information about a specific model (override for provider-specific info)"""
|
|
if not self.supports_model(model_name):
|
|
return None
|
|
|
|
return ModelInfo(
|
|
id=model_name,
|
|
object="model",
|
|
owned_by=self.name,
|
|
provider=self.name,
|
|
capabilities=self.config.capabilities,
|
|
context_window=self.config.max_context_window,
|
|
max_output_tokens=self.config.max_output_tokens,
|
|
supports_streaming=self.config.supports_streaming,
|
|
supports_function_calling=self.config.supports_function_calling
|
|
)
|
|
|
|
def _validate_request(self, request: Any):
|
|
"""Base request validation (override for provider-specific validation)"""
|
|
if hasattr(request, 'model') and not self.supports_model(request.model):
|
|
from ..exceptions import ValidationError
|
|
raise ValidationError(
|
|
f"Model '{request.model}' not supported by provider '{self.name}'",
|
|
field="model"
|
|
)
|
|
|
|
def _create_headers(self, additional_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:
|
|
"""Create HTTP headers for requests"""
|
|
headers = {
|
|
"Content-Type": "application/json",
|
|
"Authorization": f"Bearer {self.api_key}",
|
|
"User-Agent": f"Enclava-LLM-Service/{self.name}"
|
|
}
|
|
|
|
if additional_headers:
|
|
headers.update(additional_headers)
|
|
|
|
return headers
|
|
|
|
def _handle_http_error(self, status_code: int, response_text: str, provider_context: str = ""):
|
|
"""Handle HTTP errors consistently across providers"""
|
|
from ..exceptions import ProviderError, RateLimitError, ValidationError
|
|
|
|
context = f"{self.name} {provider_context}".strip()
|
|
|
|
if status_code == 401:
|
|
raise ProviderError(
|
|
f"Authentication failed for {context}",
|
|
provider=self.name,
|
|
error_code="AUTHENTICATION_ERROR",
|
|
details={"status_code": status_code, "response": response_text}
|
|
)
|
|
elif status_code == 403:
|
|
raise ProviderError(
|
|
f"Access forbidden for {context}",
|
|
provider=self.name,
|
|
error_code="AUTHORIZATION_ERROR",
|
|
details={"status_code": status_code, "response": response_text}
|
|
)
|
|
elif status_code == 429:
|
|
raise RateLimitError(
|
|
f"Rate limit exceeded for {context}",
|
|
error_code="RATE_LIMIT_ERROR",
|
|
details={"status_code": status_code, "response": response_text, "provider": self.name}
|
|
)
|
|
elif status_code == 400:
|
|
raise ValidationError(
|
|
f"Bad request for {context}: {response_text}",
|
|
error_code="BAD_REQUEST",
|
|
details={"status_code": status_code, "response": response_text}
|
|
)
|
|
elif 500 <= status_code < 600:
|
|
raise ProviderError(
|
|
f"Server error for {context}: {response_text}",
|
|
provider=self.name,
|
|
error_code="SERVER_ERROR",
|
|
details={"status_code": status_code, "response": response_text}
|
|
)
|
|
else:
|
|
raise ProviderError(
|
|
f"HTTP error {status_code} for {context}: {response_text}",
|
|
provider=self.name,
|
|
error_code="HTTP_ERROR",
|
|
details={"status_code": status_code, "response": response_text}
|
|
)
|
|
|
|
def __str__(self) -> str:
|
|
return f"{self.__class__.__name__}(name={self.name})"
|
|
|
|
def __repr__(self) -> str:
|
|
return f"{self.__class__.__name__}(name={self.name}, enabled={self.config.enabled})" |