# enclava/backend/app/services/llm/config.py
"""
LLM Service Configuration
Configuration management for LLM providers and service settings.
"""
import logging
import os
from dataclasses import dataclass
from typing import Dict, List, Optional

from pydantic import BaseModel, Field, validator

from app.core.config import settings
from .models import ResilienceConfig

logger = logging.getLogger(__name__)

class ProviderConfig(BaseModel):
"""Configuration for an LLM provider"""
name: str = Field(..., description="Provider name")
enabled: bool = Field(True, description="Whether provider is enabled")
base_url: str = Field(..., description="Provider base URL")
api_key_env_var: str = Field(..., description="Environment variable for API key")
default_model: Optional[str] = Field(None, description="Default model for this provider")
supported_models: List[str] = Field(default_factory=list, description="List of supported models")
capabilities: List[str] = Field(default_factory=list, description="Provider capabilities")
priority: int = Field(1, description="Provider priority (lower = higher priority)")
# Rate limiting
max_requests_per_minute: Optional[int] = Field(None, description="Max requests per minute")
max_requests_per_hour: Optional[int] = Field(None, description="Max requests per hour")
# Model-specific settings
supports_streaming: bool = Field(False, description="Whether provider supports streaming")
supports_function_calling: bool = Field(False, description="Whether provider supports function calling")
max_context_window: Optional[int] = Field(None, description="Maximum context window size")
max_output_tokens: Optional[int] = Field(None, description="Maximum output tokens")
# Resilience configuration
resilience: ResilienceConfig = Field(default_factory=ResilienceConfig, description="Resilience settings")
@validator('priority')
def validate_priority(cls, v):
if v < 1:
raise ValueError("Priority must be >= 1")
return v
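
# Illustrative sketch only: a hypothetical OpenAI entry showing which fields
# are required (name, base_url, api_key_env_var) and which are optional.
# None of these values ship with the default configuration below.
#
#   openai_config = ProviderConfig(
#       name="openai",
#       base_url="https://api.openai.com/v1",
#       api_key_env_var="OPENAI_API_KEY",
#       default_model="gpt-4o",
#       capabilities=["chat", "embeddings"],
#       priority=2,  # lower number = higher priority
#       supports_streaming=True,
#   )
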
class LLMServiceConfig(BaseModel):
"""Main LLM service configuration"""
# Global settings
default_provider: str = Field("privatemode", description="Default provider to use")
enable_detailed_logging: bool = Field(False, description="Enable detailed request/response logging")
enable_security_checks: bool = Field(True, description="Enable security validation")
enable_metrics_collection: bool = Field(True, description="Enable metrics collection")
# Security settings
security_risk_threshold: float = Field(0.8, ge=0.0, le=1.0, description="Risk threshold for blocking")
security_warning_threshold: float = Field(0.6, ge=0.0, le=1.0, description="Risk threshold for warnings")
max_prompt_length: int = Field(50000, ge=1000, description="Maximum prompt length")
max_response_length: int = Field(32000, ge=1000, description="Maximum response length")
# Performance settings
default_timeout_ms: int = Field(30000, ge=1000, le=300000, description="Default request timeout")
max_concurrent_requests: int = Field(100, ge=1, le=1000, description="Maximum concurrent requests")
# Provider configurations
providers: Dict[str, ProviderConfig] = Field(default_factory=dict, description="Provider configurations")
# Model routing (model_name -> provider_name)
model_routing: Dict[str, str] = Field(default_factory=dict, description="Model to provider routing")
    @validator('security_warning_threshold')
    def validate_warning_threshold(cls, v, values):
        """Warning threshold must stay below the blocking threshold.

        The cross-field check lives on the later-declared field because
        pydantic validators only see previously validated fields in `values`;
        attaching it to `security_risk_threshold` would always compare
        against the fallback default instead of the configured value.
        """
        risk_threshold = values.get('security_risk_threshold')
        if risk_threshold is not None and v >= risk_threshold:
            raise ValueError("Warning threshold must be less than risk threshold")
        return v
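
# For example, the defaults (risk=0.8, warning=0.6) validate, while a config
# with security_warning_threshold=0.9 and security_risk_threshold=0.8 raises
# a ValidationError, since warnings are meant to fire before a block does.
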
def create_default_config() -> LLMServiceConfig:
"""Create default LLM service configuration"""
# PrivateMode.ai configuration (via proxy)
# Models will be fetched dynamically from proxy /models endpoint
privatemode_config = ProviderConfig(
name="privatemode",
enabled=True,
base_url=settings.PRIVATEMODE_PROXY_URL,
api_key_env_var="PRIVATEMODE_API_KEY",
default_model="privatemode-latest",
supported_models=[], # Will be populated dynamically from proxy
capabilities=["chat", "embeddings", "tee"],
priority=1,
max_requests_per_minute=100,
max_requests_per_hour=2000,
supports_streaming=True,
supports_function_calling=True,
max_context_window=128000,
max_output_tokens=8192,
resilience=ResilienceConfig(
max_retries=3,
retry_delay_ms=1000,
timeout_ms=60000, # PrivateMode may be slower due to TEE
circuit_breaker_threshold=5,
circuit_breaker_reset_timeout_ms=120000
)
)
# Create main configuration
config = LLMServiceConfig(
default_provider="privatemode",
enable_detailed_logging=settings.LOG_LLM_PROMPTS,
enable_security_checks=settings.API_SECURITY_ENABLED,
security_risk_threshold=settings.API_SECURITY_RISK_THRESHOLD,
security_warning_threshold=settings.API_SECURITY_WARNING_THRESHOLD,
providers={
"privatemode": privatemode_config
},
model_routing={} # Will be populated dynamically from provider models
)
return config
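
# Note: the default config starts with an empty model_routing table; it is
# filled in by ConfigurationManager.refresh_provider_models() once the
# provider's /models endpoint has been queried.
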
@dataclass
class EnvironmentVariables:
"""Environment variables used by LLM service"""
# Provider API keys
PRIVATEMODE_API_KEY: Optional[str] = None
OPENAI_API_KEY: Optional[str] = None
ANTHROPIC_API_KEY: Optional[str] = None
GOOGLE_API_KEY: Optional[str] = None
# Service settings
LOG_LLM_PROMPTS: bool = False
def __post_init__(self):
"""Load values from environment"""
self.PRIVATEMODE_API_KEY = os.getenv("PRIVATEMODE_API_KEY")
self.OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
self.ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
self.GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
self.LOG_LLM_PROMPTS = os.getenv("LOG_LLM_PROMPTS", "false").lower() == "true"
def get_api_key(self, provider_name: str) -> Optional[str]:
"""Get API key for a specific provider"""
key_mapping = {
"privatemode": self.PRIVATEMODE_API_KEY,
"openai": self.OPENAI_API_KEY,
"anthropic": self.ANTHROPIC_API_KEY,
"google": self.GOOGLE_API_KEY
}
return key_mapping.get(provider_name.lower())
def validate_required_keys(self, enabled_providers: List[str]) -> List[str]:
"""Validate that required API keys are present"""
missing_keys = []
for provider in enabled_providers:
if not self.get_api_key(provider):
missing_keys.append(f"{provider.upper()}_API_KEY")
return missing_keys
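
# Illustrative usage (a hypothetical startup check, not called anywhere here):
#
#   env = EnvironmentVariables()
#   missing = env.validate_required_keys(["privatemode"])
#   if missing:
#       raise RuntimeError(f"Set these variables first: {', '.join(missing)}")
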
class ConfigurationManager:
"""Manages LLM service configuration"""
def __init__(self):
self._config: Optional[LLMServiceConfig] = None
self._env_vars = EnvironmentVariables()
def get_config(self) -> LLMServiceConfig:
"""Get current configuration"""
if self._config is None:
self._config = create_default_config()
self._validate_configuration()
return self._config
def update_config(self, config: LLMServiceConfig):
"""Update configuration"""
self._config = config
self._validate_configuration()
def get_provider_config(self, provider_name: str) -> Optional[ProviderConfig]:
"""Get configuration for a specific provider"""
config = self.get_config()
return config.providers.get(provider_name)
def get_provider_for_model(self, model_name: str) -> Optional[str]:
"""Get provider name for a specific model"""
config = self.get_config()
return config.model_routing.get(model_name)
def get_enabled_providers(self) -> List[str]:
"""Get list of enabled providers"""
config = self.get_config()
return [name for name, provider in config.providers.items() if provider.enabled]
def get_api_key(self, provider_name: str) -> Optional[str]:
"""Get API key for provider"""
return self._env_vars.get_api_key(provider_name)
def _validate_configuration(self):
"""Validate current configuration"""
if not self._config:
return
# Check for enabled providers without API keys
enabled_providers = self.get_enabled_providers()
missing_keys = self._env_vars.validate_required_keys(enabled_providers)
        if missing_keys:
            logger.warning(f"Missing API keys for enabled providers: {', '.join(missing_keys)}")
# Validate default provider is enabled
default_provider = self._config.default_provider
if default_provider not in enabled_providers:
raise ValueError(f"Default provider '{default_provider}' is not enabled")
# Validate model routing points to enabled providers
invalid_routes = []
for model, provider in self._config.model_routing.items():
if provider not in enabled_providers:
invalid_routes.append(f"{model} -> {provider}")
        if invalid_routes:
            logger.warning(f"Model routes point to disabled providers: {', '.join(invalid_routes)}")
async def refresh_provider_models(self, provider_name: str, models: List[str]):
"""Update supported models for a provider dynamically"""
if not self._config:
return
provider_config = self._config.providers.get(provider_name)
if not provider_config:
return
# Update supported models
provider_config.supported_models = models
# Update model routing - map all models to this provider
for model in models:
self._config.model_routing[model] = provider_name
        logger.info(f"Updated {provider_name} with {len(models)} models: {models}")
async def get_all_available_models(self) -> Dict[str, List[str]]:
"""Get all available models grouped by provider"""
config = self.get_config()
models_by_provider = {}
for provider_name, provider_config in config.providers.items():
if provider_config.enabled:
models_by_provider[provider_name] = provider_config.supported_models
return models_by_provider
def get_model_provider_mapping(self) -> Dict[str, str]:
"""Get current model to provider mapping"""
config = self.get_config()
return config.model_routing.copy()
# Global configuration manager
config_manager = ConfigurationManager()
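

if __name__ == "__main__":
    # Minimal smoke test of the configuration manager, assuming
    # app.core.config is importable in this environment. In production the
    # model list is fetched dynamically from the proxy's /models endpoint;
    # "privatemode-latest" below is just the configured default model name.
    import asyncio

    async def _demo():
        cfg = config_manager.get_config()
        print(f"Default provider: {cfg.default_provider}")
        print(f"Enabled providers: {config_manager.get_enabled_providers()}")
        # Simulate a dynamic model refresh from the proxy
        await config_manager.refresh_provider_models(
            "privatemode", ["privatemode-latest"]
        )
        print(f"Model routing: {config_manager.get_model_provider_mapping()}")

    asyncio.run(_demo())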