""" LLM Service Configuration Configuration management for LLM providers and service settings. """ import os from typing import Dict, List, Optional, Any from pydantic import BaseModel, Field, validator from dataclasses import dataclass from app.core.config import settings from .models import ResilienceConfig class ProviderConfig(BaseModel): """Configuration for an LLM provider""" name: str = Field(..., description="Provider name") enabled: bool = Field(True, description="Whether provider is enabled") base_url: str = Field(..., description="Provider base URL") api_key_env_var: str = Field(..., description="Environment variable for API key") default_model: Optional[str] = Field(None, description="Default model for this provider") supported_models: List[str] = Field(default_factory=list, description="List of supported models") capabilities: List[str] = Field(default_factory=list, description="Provider capabilities") priority: int = Field(1, description="Provider priority (lower = higher priority)") # Rate limiting max_requests_per_minute: Optional[int] = Field(None, description="Max requests per minute") max_requests_per_hour: Optional[int] = Field(None, description="Max requests per hour") # Model-specific settings supports_streaming: bool = Field(False, description="Whether provider supports streaming") supports_function_calling: bool = Field(False, description="Whether provider supports function calling") max_context_window: Optional[int] = Field(None, description="Maximum context window size") max_output_tokens: Optional[int] = Field(None, description="Maximum output tokens") # Resilience configuration resilience: ResilienceConfig = Field(default_factory=ResilienceConfig, description="Resilience settings") @validator('priority') def validate_priority(cls, v): if v < 1: raise ValueError("Priority must be >= 1") return v class LLMServiceConfig(BaseModel): """Main LLM service configuration""" # Global settings default_provider: str = Field("privatemode", description="Default provider to use") enable_detailed_logging: bool = Field(False, description="Enable detailed request/response logging") enable_security_checks: bool = Field(True, description="Enable security validation") enable_metrics_collection: bool = Field(True, description="Enable metrics collection") # Security settings security_risk_threshold: float = Field(0.8, ge=0.0, le=1.0, description="Risk threshold for blocking") security_warning_threshold: float = Field(0.6, ge=0.0, le=1.0, description="Risk threshold for warnings") max_prompt_length: int = Field(50000, ge=1000, description="Maximum prompt length") max_response_length: int = Field(32000, ge=1000, description="Maximum response length") # Performance settings default_timeout_ms: int = Field(30000, ge=1000, le=300000, description="Default request timeout") max_concurrent_requests: int = Field(100, ge=1, le=1000, description="Maximum concurrent requests") # Provider configurations providers: Dict[str, ProviderConfig] = Field(default_factory=dict, description="Provider configurations") # Model routing (model_name -> provider_name) model_routing: Dict[str, str] = Field(default_factory=dict, description="Model to provider routing") @validator('security_risk_threshold') def validate_risk_threshold(cls, v, values): warning_threshold = values.get('security_warning_threshold', 0.6) if v <= warning_threshold: raise ValueError("Risk threshold must be greater than warning threshold") return v def create_default_config() -> LLMServiceConfig: """Create default LLM service 
configuration""" # PrivateMode.ai configuration (via proxy) # Models will be fetched dynamically from proxy /models endpoint privatemode_config = ProviderConfig( name="privatemode", enabled=True, base_url=settings.PRIVATEMODE_PROXY_URL, api_key_env_var="PRIVATEMODE_API_KEY", default_model="privatemode-latest", supported_models=[], # Will be populated dynamically from proxy capabilities=["chat", "embeddings", "tee"], priority=1, max_requests_per_minute=100, max_requests_per_hour=2000, supports_streaming=True, supports_function_calling=True, max_context_window=128000, max_output_tokens=8192, resilience=ResilienceConfig( max_retries=3, retry_delay_ms=1000, timeout_ms=60000, # PrivateMode may be slower due to TEE circuit_breaker_threshold=5, circuit_breaker_reset_timeout_ms=120000 ) ) # Create main configuration config = LLMServiceConfig( default_provider="privatemode", enable_detailed_logging=settings.LOG_LLM_PROMPTS, enable_security_checks=settings.API_SECURITY_ENABLED, security_risk_threshold=settings.API_SECURITY_RISK_THRESHOLD, security_warning_threshold=settings.API_SECURITY_WARNING_THRESHOLD, providers={ "privatemode": privatemode_config }, model_routing={} # Will be populated dynamically from provider models ) return config @dataclass class EnvironmentVariables: """Environment variables used by LLM service""" # Provider API keys PRIVATEMODE_API_KEY: Optional[str] = None OPENAI_API_KEY: Optional[str] = None ANTHROPIC_API_KEY: Optional[str] = None GOOGLE_API_KEY: Optional[str] = None # Service settings LOG_LLM_PROMPTS: bool = False def __post_init__(self): """Load values from environment""" self.PRIVATEMODE_API_KEY = os.getenv("PRIVATEMODE_API_KEY") self.OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") self.ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY") self.GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY") self.LOG_LLM_PROMPTS = os.getenv("LOG_LLM_PROMPTS", "false").lower() == "true" def get_api_key(self, provider_name: str) -> Optional[str]: """Get API key for a specific provider""" key_mapping = { "privatemode": self.PRIVATEMODE_API_KEY, "openai": self.OPENAI_API_KEY, "anthropic": self.ANTHROPIC_API_KEY, "google": self.GOOGLE_API_KEY } return key_mapping.get(provider_name.lower()) def validate_required_keys(self, enabled_providers: List[str]) -> List[str]: """Validate that required API keys are present""" missing_keys = [] for provider in enabled_providers: if not self.get_api_key(provider): missing_keys.append(f"{provider.upper()}_API_KEY") return missing_keys class ConfigurationManager: """Manages LLM service configuration""" def __init__(self): self._config: Optional[LLMServiceConfig] = None self._env_vars = EnvironmentVariables() def get_config(self) -> LLMServiceConfig: """Get current configuration""" if self._config is None: self._config = create_default_config() self._validate_configuration() return self._config def update_config(self, config: LLMServiceConfig): """Update configuration""" self._config = config self._validate_configuration() def get_provider_config(self, provider_name: str) -> Optional[ProviderConfig]: """Get configuration for a specific provider""" config = self.get_config() return config.providers.get(provider_name) def get_provider_for_model(self, model_name: str) -> Optional[str]: """Get provider name for a specific model""" config = self.get_config() return config.model_routing.get(model_name) def get_enabled_providers(self) -> List[str]: """Get list of enabled providers""" config = self.get_config() return [name for name, provider in config.providers.items() if 
class ConfigurationManager:
    """Manages LLM service configuration"""

    def __init__(self):
        self._config: Optional[LLMServiceConfig] = None
        self._env_vars = EnvironmentVariables()

    def get_config(self) -> LLMServiceConfig:
        """Get the current configuration, creating the default lazily"""
        if self._config is None:
            self._config = create_default_config()
            self._validate_configuration()
        return self._config

    def update_config(self, config: LLMServiceConfig):
        """Replace the current configuration"""
        self._config = config
        self._validate_configuration()

    def get_provider_config(self, provider_name: str) -> Optional[ProviderConfig]:
        """Get configuration for a specific provider"""
        config = self.get_config()
        return config.providers.get(provider_name)

    def get_provider_for_model(self, model_name: str) -> Optional[str]:
        """Get the provider name for a specific model"""
        config = self.get_config()
        return config.model_routing.get(model_name)

    def get_enabled_providers(self) -> List[str]:
        """Get the list of enabled providers"""
        config = self.get_config()
        return [name for name, provider in config.providers.items() if provider.enabled]

    def get_api_key(self, provider_name: str) -> Optional[str]:
        """Get the API key for a provider"""
        return self._env_vars.get_api_key(provider_name)

    def _validate_configuration(self):
        """Validate the current configuration"""
        if not self._config:
            return

        # Check for enabled providers without API keys
        enabled_providers = self.get_enabled_providers()
        missing_keys = self._env_vars.validate_required_keys(enabled_providers)
        if missing_keys:
            logger.warning(f"Missing API keys for enabled providers: {', '.join(missing_keys)}")

        # Validate that the default provider is enabled
        default_provider = self._config.default_provider
        if default_provider not in enabled_providers:
            raise ValueError(f"Default provider '{default_provider}' is not enabled")

        # Validate that model routing points to enabled providers
        invalid_routes = []
        for model, provider in self._config.model_routing.items():
            if provider not in enabled_providers:
                invalid_routes.append(f"{model} -> {provider}")
        if invalid_routes:
            logger.warning(f"Model routes point to disabled providers: {', '.join(invalid_routes)}")

    async def refresh_provider_models(self, provider_name: str, models: List[str]):
        """Update supported models for a provider dynamically"""
        if not self._config:
            return

        provider_config = self._config.providers.get(provider_name)
        if not provider_config:
            return

        # Update supported models
        provider_config.supported_models = models

        # Update model routing - map all models to this provider
        for model in models:
            self._config.model_routing[model] = provider_name

        logger.info(f"Updated {provider_name} with {len(models)} models: {models}")

    async def get_all_available_models(self) -> Dict[str, List[str]]:
        """Get all available models grouped by provider"""
        config = self.get_config()
        models_by_provider = {}
        for provider_name, provider_config in config.providers.items():
            if provider_config.enabled:
                models_by_provider[provider_name] = provider_config.supported_models
        return models_by_provider

    def get_model_provider_mapping(self) -> Dict[str, str]:
        """Get the current model-to-provider mapping"""
        config = self.get_config()
        return config.model_routing.copy()


# Global configuration manager
config_manager = ConfigurationManager()
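
if __name__ == "__main__":
    # Minimal usage sketch (not part of the service): exercises the global
    # manager with the bundled privatemode provider. Assumes app.core.config
    # is importable; a missing PRIVATEMODE_API_KEY only logs a warning.
    import asyncio

    logging.basicConfig(level=logging.INFO)
    cfg = config_manager.get_config()
    print(f"Default provider: {cfg.default_provider}")
    print(f"Enabled providers: {config_manager.get_enabled_providers()}")

    # Simulate a dynamic model refresh, as a provider adapter would after
    # querying the proxy's /models endpoint. The model name is hypothetical.
    asyncio.run(config_manager.refresh_provider_models(
        "privatemode", ["privatemode-latest"]))
    print(f"Model routing: {config_manager.get_model_provider_mapping()}")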