diff --git a/backend/app/api/v1/llm.py b/backend/app/api/v1/llm.py
index 799b165..c30d797 100644
--- a/backend/app/api/v1/llm.py
+++ b/backend/app/api/v1/llm.py
@@ -57,15 +57,24 @@ async def get_cached_models() -> List[Dict[str, Any]]:
         # Convert ModelInfo objects to dict format for compatibility
         models = []
         for model_info in model_infos:
-            models.append({
+            model_dict = {
                 "id": model_info.id,
                 "object": model_info.object,
                 "created": model_info.created or int(time.time()),
                 "owned_by": model_info.owned_by,
                 # Add frontend-expected fields
                 "name": getattr(model_info, 'name', model_info.id),  # Use name if available, fallback to id
-                "provider": getattr(model_info, 'provider', model_info.owned_by)  # Use provider if available, fallback to owned_by
-            })
+                "provider": getattr(model_info, 'provider', model_info.owned_by),  # Use provider if available, fallback to owned_by
+                "capabilities": model_info.capabilities,
+                "context_window": model_info.context_window,
+                "max_output_tokens": model_info.max_output_tokens,
+                "supports_streaming": model_info.supports_streaming,
+                "supports_function_calling": model_info.supports_function_calling
+            }
+            # Include tasks field if present
+            if model_info.tasks:
+                model_dict["tasks"] = model_info.tasks
+            models.append(model_dict)
 
         # Update cache
         _models_cache["data"] = models
diff --git a/backend/app/services/llm/models.py b/backend/app/services/llm/models.py
index 9cd4ace..903451d 100644
--- a/backend/app/services/llm/models.py
+++ b/backend/app/services/llm/models.py
@@ -138,6 +138,7 @@ class ModelInfo(BaseModel):
     max_output_tokens: Optional[int] = Field(None, description="Maximum output tokens")
     supports_streaming: bool = Field(False, description="Whether model supports streaming")
     supports_function_calling: bool = Field(False, description="Whether model supports function calling")
+    tasks: Optional[List[str]] = Field(None, description="Model tasks (e.g., generate, embed, vision)")
 
 
 class ProviderStatus(BaseModel):
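For reference, a sketch of the shape a single entry from the cached models endpoint takes after this change. This is not part of the patch; the model id and numeric values are invented, and the tasks key only appears when the provider reports tasks, as the conditional above shows.

// Illustrative only: one entry as the updated get_cached_models() would emit it.
// The id, name, and numeric limits below are hypothetical values.
const exampleModelEntry = {
  id: "privatemode-example-chat",       // hypothetical model id
  object: "model",
  created: 1735689600,
  owned_by: "privatemode",
  name: "privatemode-example-chat",
  provider: "privatemode",
  capabilities: ["tee", "chat"],
  context_window: 32768,                // invented limit
  max_output_tokens: 4096,              // invented limit
  supports_streaming: true,
  supports_function_calling: false,
  tasks: ["generate"]                   // omitted entirely when the provider reports no tasks
}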
diff --git a/backend/app/services/llm/providers/privatemode.py b/backend/app/services/llm/providers/privatemode.py
index 71e0d30..63f18ad 100644
--- a/backend/app/services/llm/providers/privatemode.py
+++ b/backend/app/services/llm/providers/privatemode.py
@@ -134,18 +134,39 @@ class PrivateModeProvider(BaseLLMProvider):
                 if not model_id:
                     continue
 
-                # Extract model capabilities from PrivateMode response
+                # Extract all information directly from API response
+                # Determine capabilities based on tasks field
+                tasks = model_data.get("tasks", [])
+                capabilities = []
+
+                # All PrivateMode models have TEE capability
+                capabilities.append("tee")
+
+                # Add capabilities based on tasks
+                if "generate" in tasks:
+                    capabilities.append("chat")
+                if "embed" in tasks or "embedding" in tasks:
+                    capabilities.append("embeddings")
+                if "vision" in tasks:
+                    capabilities.append("vision")
+
+                # Check for function calling support in the API response
+                supports_function_calling = model_data.get("supports_function_calling", False)
+                if supports_function_calling:
+                    capabilities.append("function_calling")
+
                 model_info = ModelInfo(
                     id=model_id,
                     object="model",
                     created=model_data.get("created", int(time.time())),
-                    owned_by="privatemode",
+                    owned_by=model_data.get("owned_by", "privatemode"),
                     provider=self.provider_name,
-                    capabilities=self._get_model_capabilities(model_id),
-                    context_window=self._get_model_context_window(model_id),
-                    max_output_tokens=self._get_model_max_output(model_id),
-                    supports_streaming=True,  # PrivateMode supports streaming
-                    supports_function_calling=self._supports_function_calling(model_id)
+                    capabilities=capabilities,
+                    context_window=model_data.get("context_window"),
+                    max_output_tokens=model_data.get("max_output_tokens"),
+                    supports_streaming=model_data.get("supports_streaming", True),
+                    supports_function_calling=supports_function_calling,
+                    tasks=tasks  # Pass through tasks field from PrivateMode API
                 )
 
                 models.append(model_info)
@@ -453,68 +474,6 @@ class PrivateModeProvider(BaseLLMProvider):
                 details={"error": str(e)}
             )
 
-    def _get_model_capabilities(self, model_id: str) -> List[str]:
-        """Get capabilities for a specific model"""
-        capabilities = ["chat"]
-
-        # PrivateMode supports embeddings for most models
-        if "embed" in model_id.lower() or model_id in [
-            "privatemode-llama-3.1-405b", "privatemode-llama-3.1-70b",
-            "privatemode-claude-3.5-sonnet", "privatemode-gpt-4o"
-        ]:
-            capabilities.append("embeddings")
-
-        # TEE protection is available for all PrivateMode models
-        capabilities.append("tee")
-
-        return capabilities
-
-    def _get_model_context_window(self, model_id: str) -> Optional[int]:
-        """Get context window size for a specific model"""
-        context_windows = {
-            "privatemode-llama-3.1-405b": 128000,
-            "privatemode-llama-3.1-70b": 128000,
-            "privatemode-llama-3.1-8b": 128000,
-            "privatemode-llama-3-70b": 8192,
-            "privatemode-llama-3-8b": 8192,
-            "privatemode-claude-3.5-sonnet": 200000,
-            "privatemode-claude-3-haiku": 200000,
-            "privatemode-gpt-4o": 128000,
-            "privatemode-gpt-4o-mini": 128000,
-            "privatemode-gemini-1.5-pro": 2000000,
-            "privatemode-gemini-1.5-flash": 1000000
-        }
-
-        return context_windows.get(model_id, 8192)  # Default to 8K
-
-    def _get_model_max_output(self, model_id: str) -> Optional[int]:
-        """Get max output tokens for a specific model"""
-        max_outputs = {
-            "privatemode-llama-3.1-405b": 8192,
-            "privatemode-llama-3.1-70b": 8192,
-            "privatemode-llama-3.1-8b": 8192,
-            "privatemode-llama-3-70b": 4096,
-            "privatemode-llama-3-8b": 4096,
-            "privatemode-claude-3.5-sonnet": 8192,
-            "privatemode-claude-3-haiku": 4096,
-            "privatemode-gpt-4o": 16384,
-            "privatemode-gpt-4o-mini": 16384,
-            "privatemode-gemini-1.5-pro": 8192,
-            "privatemode-gemini-1.5-flash": 8192
-        }
-
-        return max_outputs.get(model_id, 4096)  # Default to 4K
-
-    def _supports_function_calling(self, model_id: str) -> bool:
-        """Check if model supports function calling"""
-        function_calling_models = [
-            "privatemode-gpt-4o", "privatemode-gpt-4o-mini",
-            "privatemode-claude-3.5-sonnet", "privatemode-claude-3-haiku",
-            "privatemode-gemini-1.5-pro", "privatemode-gemini-1.5-flash"
-        ]
-
-        return model_id in function_calling_models
-
     async def cleanup(self):
         """Cleanup PrivateMode provider resources"""
        await super().cleanup()
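To make the new mapping concrete, here is a hypothetical PrivateMode /models entry and the values the rewritten parsing above derives from it. This is a sketch, not part of the patch: the keys mirror what the code reads via model_data.get(...), while the concrete values are invented.

// Hypothetical upstream entry (values invented) and the derived ModelInfo fields.
const upstreamEntry = {
  id: "privatemode-example-embed",
  created: 1735689600,
  owned_by: "privatemode",
  tasks: ["embed"],
  context_window: 8192,
  supports_streaming: false,
  supports_function_calling: false
}
// Derived by the parsing above:
//   capabilities       -> ["tee", "embeddings"]  ("tee" always; "embeddings" because of the "embed" task)
//   owned_by           -> "privatemode"          (read from the response, with "privatemode" as the fallback)
//   context_window     -> 8192                   (read directly; None when the API omits it)
//   supports_streaming -> false                  (defaults to true only when the field is absent)
//   tasks              -> ["embed"]              (passed through unchanged to ModelInfo.tasks)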
diff --git a/frontend/src/app/api-keys/page.tsx b/frontend/src/app/api-keys/page.tsx
index 61b6045..e62e2c7 100644
--- a/frontend/src/app/api-keys/page.tsx
+++ b/frontend/src/app/api-keys/page.tsx
@@ -73,6 +73,7 @@ interface Model {
   max_output_tokens?: number;
   supports_streaming?: boolean;
   supports_function_calling?: boolean;
+  tasks?: string[]; // Added tasks field from PrivateMode API
 }
 
 interface NewApiKeyData {
diff --git a/frontend/src/components/chatbot/ChatbotManager.tsx b/frontend/src/components/chatbot/ChatbotManager.tsx
index b99e1d5..66340e0 100644
--- a/frontend/src/components/chatbot/ChatbotManager.tsx
+++ b/frontend/src/components/chatbot/ChatbotManager.tsx
@@ -623,7 +623,7 @@ export function ChatbotManager() {
                       {ragCollections.map((collection) => (
-
+
                           {collection.name}
                           {collection.document_count} documents
@@ -889,7 +889,7 @@ export function ChatbotManager() {
                       {ragCollections.map((collection) => (
-
+
                           {collection.name}
                           {collection.document_count} documents
diff --git a/frontend/src/components/playground/EmbeddingPlayground.tsx b/frontend/src/components/playground/EmbeddingPlayground.tsx
index 9425291..045d80a 100644
--- a/frontend/src/components/playground/EmbeddingPlayground.tsx
+++ b/frontend/src/components/playground/EmbeddingPlayground.tsx
@@ -1,6 +1,6 @@
 "use client"
 
-import { useState } from 'react'
+import { useState, useEffect } from 'react'
 import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card'
 import { Button } from '@/components/ui/button'
 import { Textarea } from '@/components/ui/textarea'
@@ -14,6 +14,23 @@
 import { Download, Zap, Calculator, BarChart3, AlertCircle } from 'lucide-react'
 import { useToast } from '@/hooks/use-toast'
 import { apiClient } from '@/lib/api-client'
+interface Model {
+  id: string
+  object: string
+  created?: number
+  owned_by?: string
+  permission?: any[]
+  root?: string
+  parent?: string
+  provider?: string
+  capabilities?: string[]
+  context_window?: number
+  max_output_tokens?: number
+  supports_streaming?: boolean
+  supports_function_calling?: boolean
+  tasks?: string[]
+}
+
 interface EmbeddingResult {
   text: string
   embedding: number[]
@@ -31,7 +48,7 @@ interface SessionStats {
 
 export default function EmbeddingPlayground() {
   const [text, setText] = useState('')
-  const [model, setModel] = useState('text-embedding-ada-002')
+  const [model, setModel] = useState('')
   const [encodingFormat, setEncodingFormat] = useState('float')
   const [isLoading, setIsLoading] = useState(false)
   const [results, setResults] = useState([])
@@ -43,8 +60,51 @@
   })
   const [selectedResult, setSelectedResult] = useState(null)
   const [comparisonMode, setComparisonMode] = useState(false)
+  const [embeddingModels, setEmbeddingModels] = useState([])
+  const [loadingModels, setLoadingModels] = useState(true)
   const { toast } = useToast()
+
+  // Fetch available embedding models
+  useEffect(() => {
+    const fetchModels = async () => {
+      try {
+        setLoadingModels(true)
+        const response = await apiClient.get('/api-internal/v1/llm/models')
+
+        if (response.data) {
+          // Filter models that support embeddings based on tasks field
+          const models = response.data.filter((model: Model) => {
+            // Check if model has embed or embedding in tasks
+            if (model.tasks && Array.isArray(model.tasks)) {
+              return model.tasks.includes('embed') || model.tasks.includes('embedding')
+            }
+            // Fallback: check if model ID contains embedding patterns
+            const modelId = model.id.toLowerCase()
+            return modelId.includes('embed') || modelId.includes('text-embedding')
+          })
+
+          setEmbeddingModels(models)
+
+          // Set default model if available
+          if (models.length > 0 && !model) {
+            setModel(models[0].id)
+          }
+        }
+      } catch (error) {
+        console.error('Failed to fetch models:', error)
+        toast({
+          title: "Error",
+          description: "Failed to load embedding models",
+          variant: "destructive"
+        })
+      } finally {
+        setLoadingModels(false)
+      }
+    }
+
+    fetchModels()
+  }, [])
 
   const handleGenerateEmbedding = async () => {
     if (!text.trim()) {
       toast({
@@ -55,6 +115,15 @@
       return
     }
 
+    if (!model) {
+      toast({
+        title: "Error",
+        description: "Please select an embedding model",
+        variant: "destructive"
+      })
+      return
+    }
+
     setIsLoading(true)
     try {
       const data = await apiClient.post('/api-internal/v1/llm/embeddings', {
@@ -93,13 +162,34 @@
     }
   }
 
-  const calculateCost = (tokens: number, model: string): number => {
+  const calculateCost = (tokens: number, modelId: string): number => {
+    // Known rates for common embedding models
    const rates: { [key: string]: number } = {
       'text-embedding-ada-002': 0.0001,
       'text-embedding-3-small': 0.00002,
-      'text-embedding-3-large': 0.00013
+      'text-embedding-3-large': 0.00013,
+      'privatemode-text-embedding-ada-002': 0.0001,
+      'privatemode-text-embedding-3-small': 0.00002,
+      'privatemode-text-embedding-3-large': 0.00013
     }
-    return (tokens / 1000) * (rates[model] || 0.0001)
+
+    // Check for exact match first
+    if (rates[modelId]) {
+      return (tokens / 1000) * rates[modelId]
+    }
+
+    // Check for pattern matches (e.g., if model contains these patterns)
+    const modelLower = modelId.toLowerCase()
+    if (modelLower.includes('ada-002')) {
+      return (tokens / 1000) * 0.0001
+    } else if (modelLower.includes('3-small')) {
+      return (tokens / 1000) * 0.00002
+    } else if (modelLower.includes('3-large')) {
+      return (tokens / 1000) * 0.00013
+    }
+
+    // Default rate for unknown models
+    return (tokens / 1000) * 0.0001
   }
 
   const updateSessionStats = (result: EmbeddingResult) => {
@@ -157,14 +247,27 @@
-                  text-embedding-ada-002
-                  text-embedding-3-small
-                  text-embedding-3-large
+                  {embeddingModels.length === 0 && !loadingModels ? (
+                      No embedding models available
+                  ) : (
+                    embeddingModels.map((embModel) => (
+                        {embModel.id}
+                        {embModel.owned_by && embModel.owned_by !== 'unknown' && (
+                            ({embModel.owned_by})
+                        )}
+                    ))
+                  )}
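A small usage sketch of the two pieces of logic added above: the tasks-based embedding filter and the pattern-matching cost estimate. This is not part of the patch; the model list is invented and calculateCost refers to the helper defined in the diff.

// Invented sample data; the predicate mirrors the filter used in the useEffect above.
type SampleModel = { id: string; tasks?: string[] }
const sample: SampleModel[] = [
  { id: "privatemode-example-chat", tasks: ["generate"] },   // excluded: no embed task
  { id: "privatemode-example-embed", tasks: ["embed"] },     // included via tasks
  { id: "legacy-text-embedding-model" }                      // included via ID fallback
]
const embeddable = sample.filter(m =>
  m.tasks && Array.isArray(m.tasks)
    ? m.tasks.includes("embed") || m.tasks.includes("embedding")
    : m.id.toLowerCase().includes("embed") || m.id.toLowerCase().includes("text-embedding")
)
// embeddable -> the second and third entries

// Cost estimate: an id with no exact rate entry, e.g. "custom/text-embedding-3-small",
// falls through to the '3-small' pattern branch, so 1,000 tokens cost
// (1000 / 1000) * 0.00002 = 0.00002 in the same units as the rate table.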
diff --git a/frontend/src/components/playground/ModelSelector.tsx b/frontend/src/components/playground/ModelSelector.tsx
index 8bcfc4e..37fcab4 100644
--- a/frontend/src/components/playground/ModelSelector.tsx
+++ b/frontend/src/components/playground/ModelSelector.tsx
@@ -23,6 +23,7 @@ interface Model {
   max_output_tokens?: number
   supports_streaming?: boolean
   supports_function_calling?: boolean
+  tasks?: string[] // Added tasks field from PrivateMode API
 }
 
 interface ProviderStatus {
@@ -97,11 +98,21 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
     return 'Unknown'
   }
 
-  const getModelType = (modelId: string): 'chat' | 'embedding' | 'other' => {
+  const getModelType = (model: Model): 'chat' | 'embedding' | 'other' => {
+    // Check if model has tasks field from PrivateMode or other providers
+    if (model.tasks && Array.isArray(model.tasks)) {
+      // Models with "generate" task are chat models
+      if (model.tasks.includes('generate')) return 'chat'
+      // Models with "embed" task are embedding models
+      if (model.tasks.includes('embed') || model.tasks.includes('embedding')) return 'embedding'
+    }
+
+    // Fallback to ID-based detection for models without tasks field
+    const modelId = model.id
     if (modelId.includes('embedding') || modelId.includes('embed')) return 'embedding'
     if (modelId.includes('whisper') || modelId.includes('speech')) return 'other' // Audio models
 
-    // PrivateMode and other chat models
+    // PrivateMode and other chat models by ID pattern
     if (
       modelId.startsWith('privatemode-llama') ||
       modelId.startsWith('privatemode-claude') ||
@@ -114,14 +125,16 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
       modelId.includes('llama') ||
       modelId.includes('gemma') ||
       modelId.includes('qwen') ||
+      modelId.includes('mistral') ||
+      modelId.includes('command') ||
       modelId.includes('latest')
     ) return 'chat'
 
     return 'other'
   }
 
-  const getModelCategory = (modelId: string): string => {
-    const type = getModelType(modelId)
+  const getModelCategory = (model: Model): string => {
+    const type = getModelType(model)
     switch (type) {
       case 'chat': return 'Chat Completion'
       case 'embedding': return 'Text Embedding'
@@ -132,7 +145,7 @@
 
   const filteredModels = models.filter(model => {
     if (filter === 'all') return true
-    return getModelType(model.id) === filter
+    return getModelType(model) === filter
   })
 
   const groupedModels = filteredModels.reduce((acc, model) => {
@@ -255,7 +268,7 @@
                         {model.id}
-                        {getModelCategory(model.id)}
+                        {getModelCategory(model)}
                       {model.supports_streaming && (
@@ -299,7 +312,7 @@
                       Type:
-                      {getModelCategory(selectedModel.id)}
+                      {getModelCategory(selectedModel)}
                       Object:
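Finally, a sketch of the resolution order the reworked getModelType above encodes, applied to invented inputs; getModelType itself is the component-internal helper from the diff, so the calls below are illustrative rather than importable.

// Invented examples; tasks are consulted first, the ID heuristics only as a fallback.
// getModelType({ id: "anything", tasks: ["generate"] })          -> 'chat'       (tasks win)
// getModelType({ id: "privatemode-example", tasks: ["embed"] })  -> 'embedding'  (tasks win)
// getModelType({ id: "text-embedding-3-small" })                 -> 'embedding'  (no tasks, ID contains "embedding")
// getModelType({ id: "whisper-1" })                              -> 'other'      (audio model, by ID)
// getModelType({ id: "mistral-large-latest" })                   -> 'chat'       (ID matches the new 'mistral' pattern)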