mirror of
https://github.com/aljazceru/enclava.git
synced 2025-12-17 23:44:24 +01:00
258 lines
9.8 KiB
Python
258 lines
9.8 KiB
Python
"""
|
|
LLM Service Data Models
|
|
|
|
Pydantic models for LLM requests and responses.
|
|
"""
|
|
|
|
from typing import Dict, List, Optional, Any, Union
|
|
from pydantic import BaseModel, Field, validator
|
|
from datetime import datetime
|
|
|
|
|
|
class ToolCall(BaseModel):
|
|
"""Tool call in a message"""
|
|
|
|
id: str = Field(..., description="Tool call identifier")
|
|
type: str = Field("function", description="Tool call type")
|
|
function: Dict[str, Any] = Field(..., description="Function call details")
|
|
|
|
|
|
class ChatMessage(BaseModel):
|
|
"""Individual chat message"""
|
|
|
|
role: str = Field(..., description="Message role (system, user, assistant)")
|
|
content: Optional[str] = Field(None, description="Message content")
|
|
name: Optional[str] = Field(None, description="Optional message name")
|
|
tool_calls: Optional[List[ToolCall]] = Field(
|
|
None, description="Tool calls in this message"
|
|
)
|
|
tool_call_id: Optional[str] = Field(
|
|
None, description="Tool call ID for tool responses"
|
|
)
|
|
|
|
@validator("role")
|
|
def validate_role(cls, v):
|
|
allowed_roles = {"system", "user", "assistant", "function", "tool"}
|
|
if v not in allowed_roles:
|
|
raise ValueError(f"Role must be one of {allowed_roles}")
|
|
return v
|
|
|
|
|
|
class ChatRequest(BaseModel):
|
|
"""Chat completion request"""
|
|
|
|
model: str = Field(..., description="Model identifier")
|
|
messages: List[ChatMessage] = Field(..., description="Chat messages")
|
|
temperature: Optional[float] = Field(
|
|
0.7, ge=0.0, le=2.0, description="Sampling temperature"
|
|
)
|
|
max_tokens: Optional[int] = Field(
|
|
None, ge=1, le=32000, description="Maximum tokens to generate"
|
|
)
|
|
top_p: Optional[float] = Field(
|
|
1.0, ge=0.0, le=1.0, description="Nucleus sampling parameter"
|
|
)
|
|
top_k: Optional[int] = Field(None, ge=1, description="Top-k sampling parameter")
|
|
frequency_penalty: Optional[float] = Field(
|
|
0.0, ge=-2.0, le=2.0, description="Frequency penalty"
|
|
)
|
|
presence_penalty: Optional[float] = Field(
|
|
0.0, ge=-2.0, le=2.0, description="Presence penalty"
|
|
)
|
|
stop: Optional[Union[str, List[str]]] = Field(None, description="Stop sequences")
|
|
stream: Optional[bool] = Field(False, description="Stream response")
|
|
tools: Optional[List[Dict[str, Any]]] = Field(
|
|
None, description="Available tools for function calling"
|
|
)
|
|
tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(
|
|
None, description="Tool choice preference"
|
|
)
|
|
user_id: str = Field(..., description="User identifier")
|
|
api_key_id: int = Field(..., description="API key identifier")
|
|
metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata")
|
|
|
|
@validator("messages")
|
|
def validate_messages(cls, v):
|
|
if not v:
|
|
raise ValueError("Messages cannot be empty")
|
|
return v
|
|
|
|
|
|
class TokenUsage(BaseModel):
|
|
"""Token usage information"""
|
|
|
|
prompt_tokens: int = Field(..., description="Tokens in the prompt")
|
|
completion_tokens: int = Field(..., description="Tokens in the completion")
|
|
total_tokens: int = Field(..., description="Total tokens used")
|
|
|
|
|
|
class ChatChoice(BaseModel):
|
|
"""Chat completion choice"""
|
|
|
|
index: int = Field(..., description="Choice index")
|
|
message: ChatMessage = Field(..., description="Generated message")
|
|
finish_reason: Optional[str] = Field(
|
|
None, description="Reason for completion finish"
|
|
)
|
|
|
|
|
|
class ChatResponse(BaseModel):
|
|
"""Chat completion response"""
|
|
|
|
id: str = Field(..., description="Response identifier")
|
|
object: str = Field("chat.completion", description="Object type")
|
|
created: int = Field(..., description="Creation timestamp")
|
|
model: str = Field(..., description="Model used")
|
|
provider: str = Field(..., description="Provider used")
|
|
choices: List[ChatChoice] = Field(..., description="Generated choices")
|
|
usage: Optional[TokenUsage] = Field(None, description="Token usage")
|
|
system_fingerprint: Optional[str] = Field(None, description="System fingerprint")
|
|
|
|
# Security fields maintained for backward compatibility
|
|
security_check: Optional[bool] = Field(
|
|
None, description="Whether security check passed"
|
|
)
|
|
risk_score: Optional[float] = Field(None, description="Security risk score")
|
|
detected_patterns: Optional[List[str]] = Field(
|
|
None, description="Detected security patterns"
|
|
)
|
|
|
|
# Performance metrics
|
|
latency_ms: Optional[float] = Field(
|
|
None, description="Response latency in milliseconds"
|
|
)
|
|
provider_latency_ms: Optional[float] = Field(
|
|
None, description="Provider-specific latency"
|
|
)
|
|
|
|
|
|
class EmbeddingRequest(BaseModel):
|
|
"""Embedding generation request"""
|
|
|
|
model: str = Field(..., description="Embedding model identifier")
|
|
input: Union[str, List[str]] = Field(..., description="Text to embed")
|
|
encoding_format: Optional[str] = Field("float", description="Encoding format")
|
|
dimensions: Optional[int] = Field(None, ge=1, description="Number of dimensions")
|
|
user_id: str = Field(..., description="User identifier")
|
|
api_key_id: int = Field(..., description="API key identifier")
|
|
metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata")
|
|
|
|
@validator("input")
|
|
def validate_input(cls, v):
|
|
if isinstance(v, str):
|
|
if not v.strip():
|
|
raise ValueError("Input text cannot be empty")
|
|
elif isinstance(v, list):
|
|
if not v or not all(isinstance(item, str) and item.strip() for item in v):
|
|
raise ValueError(
|
|
"Input list cannot be empty and must contain non-empty strings"
|
|
)
|
|
return v
|
|
|
|
|
|
class EmbeddingData(BaseModel):
|
|
"""Single embedding data"""
|
|
|
|
object: str = Field("embedding", description="Object type")
|
|
index: int = Field(..., description="Embedding index")
|
|
embedding: List[float] = Field(..., description="Embedding vector")
|
|
|
|
|
|
class EmbeddingResponse(BaseModel):
|
|
"""Embedding generation response"""
|
|
|
|
object: str = Field("list", description="Object type")
|
|
data: List[EmbeddingData] = Field(..., description="Embedding data")
|
|
model: str = Field(..., description="Model used")
|
|
provider: str = Field(..., description="Provider used")
|
|
usage: Optional[TokenUsage] = Field(None, description="Token usage")
|
|
|
|
# Security fields maintained for backward compatibility
|
|
security_check: Optional[bool] = Field(
|
|
None, description="Whether security check passed"
|
|
)
|
|
risk_score: Optional[float] = Field(None, description="Security risk score")
|
|
detected_patterns: Optional[List[str]] = Field(
|
|
None, description="Detected security patterns"
|
|
)
|
|
|
|
# Performance metrics
|
|
latency_ms: Optional[float] = Field(
|
|
None, description="Response latency in milliseconds"
|
|
)
|
|
provider_latency_ms: Optional[float] = Field(
|
|
None, description="Provider-specific latency"
|
|
)
|
|
|
|
|
|
class ModelInfo(BaseModel):
|
|
"""Model information"""
|
|
|
|
id: str = Field(..., description="Model identifier")
|
|
object: str = Field("model", description="Object type")
|
|
created: Optional[int] = Field(None, description="Creation timestamp")
|
|
owned_by: str = Field(..., description="Model owner")
|
|
provider: str = Field(..., description="Provider name")
|
|
capabilities: List[str] = Field(
|
|
default_factory=list, description="Model capabilities"
|
|
)
|
|
context_window: Optional[int] = Field(None, description="Context window size")
|
|
max_output_tokens: Optional[int] = Field(None, description="Maximum output tokens")
|
|
supports_streaming: bool = Field(
|
|
False, description="Whether model supports streaming"
|
|
)
|
|
supports_function_calling: bool = Field(
|
|
False, description="Whether model supports function calling"
|
|
)
|
|
tasks: Optional[List[str]] = Field(
|
|
None, description="Model tasks (e.g., generate, embed, vision)"
|
|
)
|
|
|
|
|
|
class ProviderStatus(BaseModel):
|
|
"""Provider health status"""
|
|
|
|
provider: str = Field(..., description="Provider name")
|
|
status: str = Field(..., description="Status (healthy, degraded, unavailable)")
|
|
latency_ms: Optional[float] = Field(None, description="Average latency")
|
|
success_rate: Optional[float] = Field(None, description="Success rate (0.0 to 1.0)")
|
|
last_check: datetime = Field(..., description="Last health check timestamp")
|
|
error_message: Optional[str] = Field(None, description="Error message if unhealthy")
|
|
models_available: List[str] = Field(
|
|
default_factory=list, description="Available models"
|
|
)
|
|
|
|
|
|
class LLMMetrics(BaseModel):
|
|
"""LLM service metrics"""
|
|
|
|
total_requests: int = Field(0, description="Total requests processed")
|
|
successful_requests: int = Field(0, description="Successful requests")
|
|
failed_requests: int = Field(0, description="Failed requests")
|
|
average_latency_ms: float = Field(0.0, description="Average response latency")
|
|
provider_metrics: Dict[str, Dict[str, Any]] = Field(
|
|
default_factory=dict, description="Per-provider metrics"
|
|
)
|
|
last_updated: datetime = Field(
|
|
default_factory=datetime.utcnow, description="Last metrics update"
|
|
)
|
|
|
|
|
|
class ResilienceConfig(BaseModel):
|
|
"""Configuration for resilience patterns"""
|
|
|
|
max_retries: int = Field(3, ge=0, le=10, description="Maximum retry attempts")
|
|
retry_delay_ms: int = Field(
|
|
1000, ge=100, le=30000, description="Initial retry delay"
|
|
)
|
|
retry_exponential_base: float = Field(
|
|
2.0, ge=1.1, le=5.0, description="Exponential backoff base"
|
|
)
|
|
timeout_ms: int = Field(30000, ge=1000, le=300000, description="Request timeout")
|
|
circuit_breaker_threshold: int = Field(
|
|
5, ge=1, le=50, description="Circuit breaker failure threshold"
|
|
)
|
|
circuit_breaker_reset_timeout_ms: int = Field(
|
|
60000, ge=10000, le=600000, description="Circuit breaker reset timeout"
|
|
)
|