mirror of https://github.com/aljazceru/enclava.git
embedding models picker improvement
@@ -57,15 +57,24 @@ async def get_cached_models() -> List[Dict[str, Any]]:
     # Convert ModelInfo objects to dict format for compatibility
     models = []
     for model_info in model_infos:
-        models.append({
+        model_dict = {
             "id": model_info.id,
             "object": model_info.object,
             "created": model_info.created or int(time.time()),
             "owned_by": model_info.owned_by,
             # Add frontend-expected fields
             "name": getattr(model_info, 'name', model_info.id),  # Use name if available, fallback to id
-            "provider": getattr(model_info, 'provider', model_info.owned_by)  # Use provider if available, fallback to owned_by
-        })
+            "provider": getattr(model_info, 'provider', model_info.owned_by),  # Use provider if available, fallback to owned_by
+            "capabilities": model_info.capabilities,
+            "context_window": model_info.context_window,
+            "max_output_tokens": model_info.max_output_tokens,
+            "supports_streaming": model_info.supports_streaming,
+            "supports_function_calling": model_info.supports_function_calling
+        }
+        # Include tasks field if present
+        if model_info.tasks:
+            model_dict["tasks"] = model_info.tasks
+        models.append(model_dict)
 
     # Update cache
     _models_cache["data"] = models
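Note on the payload: after this hunk, every entry returned from the models endpoint carries the provider-reported metadata plus an optional tasks list. A minimal sketch of the shape the frontend consumes (TypeScript; the CachedModel name and the sample values are illustrative, not from the commit):

// Illustrative shape of one entry from the models endpoint after this change.
interface CachedModel {
  id: string
  object: string
  created: number
  owned_by: string
  name: string
  provider: string
  capabilities: string[]
  context_window?: number
  max_output_tokens?: number
  supports_streaming: boolean
  supports_function_calling: boolean
  tasks?: string[]  // only present when the provider reports tasks
}

// Hypothetical entry for an embedding-capable PrivateMode model:
const example: CachedModel = {
  id: 'privatemode-embeddings',
  object: 'model',
  created: 1734000000,
  owned_by: 'privatemode',
  name: 'privatemode-embeddings',
  provider: 'privatemode',
  capabilities: ['tee', 'embeddings'],
  supports_streaming: true,
  supports_function_calling: false,
  tasks: ['embed']
}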
@@ -138,6 +138,7 @@ class ModelInfo(BaseModel):
     max_output_tokens: Optional[int] = Field(None, description="Maximum output tokens")
     supports_streaming: bool = Field(False, description="Whether model supports streaming")
     supports_function_calling: bool = Field(False, description="Whether model supports function calling")
+    tasks: Optional[List[str]] = Field(None, description="Model tasks (e.g., generate, embed, vision)")
 
 
 class ProviderStatus(BaseModel):
@@ -134,18 +134,39 @@ class PrivateModeProvider(BaseLLMProvider):
             if not model_id:
                 continue
 
-            # Extract model capabilities from PrivateMode response
+            # Extract all information directly from API response
+            # Determine capabilities based on tasks field
+            tasks = model_data.get("tasks", [])
+            capabilities = []
+
+            # All PrivateMode models have TEE capability
+            capabilities.append("tee")
+
+            # Add capabilities based on tasks
+            if "generate" in tasks:
+                capabilities.append("chat")
+            if "embed" in tasks or "embedding" in tasks:
+                capabilities.append("embeddings")
+            if "vision" in tasks:
+                capabilities.append("vision")
+
+            # Check for function calling support in the API response
+            supports_function_calling = model_data.get("supports_function_calling", False)
+            if supports_function_calling:
+                capabilities.append("function_calling")
+
             model_info = ModelInfo(
                 id=model_id,
                 object="model",
                 created=model_data.get("created", int(time.time())),
-                owned_by="privatemode",
+                owned_by=model_data.get("owned_by", "privatemode"),
                 provider=self.provider_name,
-                capabilities=self._get_model_capabilities(model_id),
-                context_window=self._get_model_context_window(model_id),
-                max_output_tokens=self._get_model_max_output(model_id),
-                supports_streaming=True,  # PrivateMode supports streaming
-                supports_function_calling=self._supports_function_calling(model_id)
+                capabilities=capabilities,
+                context_window=model_data.get("context_window"),
+                max_output_tokens=model_data.get("max_output_tokens"),
+                supports_streaming=model_data.get("supports_streaming", True),
+                supports_function_calling=supports_function_calling,
+                tasks=tasks  # Pass through tasks field from PrivateMode API
             )
             models.append(model_info)
 
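The tasks-driven mapping above replaces the hardcoded per-model helpers deleted in the next hunk. Restated as a pure function for readability (TypeScript for consistency with the frontend code below; the actual implementation is the Python above):

// Same tasks -> capabilities mapping as the Python hunk above, as a pure function.
function capabilitiesFromTasks(tasks: string[], supportsFunctionCalling: boolean): string[] {
  const caps = ['tee']  // every PrivateMode model runs in a TEE
  if (tasks.includes('generate')) caps.push('chat')
  if (tasks.includes('embed') || tasks.includes('embedding')) caps.push('embeddings')
  if (tasks.includes('vision')) caps.push('vision')
  if (supportsFunctionCalling) caps.push('function_calling')
  return caps
}

// capabilitiesFromTasks(['generate', 'vision'], true)
//   -> ['tee', 'chat', 'vision', 'function_calling']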
@@ -453,68 +474,6 @@ class PrivateModeProvider(BaseLLMProvider):
                 details={"error": str(e)}
             )
 
-    def _get_model_capabilities(self, model_id: str) -> List[str]:
-        """Get capabilities for a specific model"""
-        capabilities = ["chat"]
-
-        # PrivateMode supports embeddings for most models
-        if "embed" in model_id.lower() or model_id in [
-            "privatemode-llama-3.1-405b", "privatemode-llama-3.1-70b",
-            "privatemode-claude-3.5-sonnet", "privatemode-gpt-4o"
-        ]:
-            capabilities.append("embeddings")
-
-        # TEE protection is available for all PrivateMode models
-        capabilities.append("tee")
-
-        return capabilities
-
-    def _get_model_context_window(self, model_id: str) -> Optional[int]:
-        """Get context window size for a specific model"""
-        context_windows = {
-            "privatemode-llama-3.1-405b": 128000,
-            "privatemode-llama-3.1-70b": 128000,
-            "privatemode-llama-3.1-8b": 128000,
-            "privatemode-llama-3-70b": 8192,
-            "privatemode-llama-3-8b": 8192,
-            "privatemode-claude-3.5-sonnet": 200000,
-            "privatemode-claude-3-haiku": 200000,
-            "privatemode-gpt-4o": 128000,
-            "privatemode-gpt-4o-mini": 128000,
-            "privatemode-gemini-1.5-pro": 2000000,
-            "privatemode-gemini-1.5-flash": 1000000
-        }
-
-        return context_windows.get(model_id, 8192)  # Default to 8K
-
-    def _get_model_max_output(self, model_id: str) -> Optional[int]:
-        """Get max output tokens for a specific model"""
-        max_outputs = {
-            "privatemode-llama-3.1-405b": 8192,
-            "privatemode-llama-3.1-70b": 8192,
-            "privatemode-llama-3.1-8b": 8192,
-            "privatemode-llama-3-70b": 4096,
-            "privatemode-llama-3-8b": 4096,
-            "privatemode-claude-3.5-sonnet": 8192,
-            "privatemode-claude-3-haiku": 4096,
-            "privatemode-gpt-4o": 16384,
-            "privatemode-gpt-4o-mini": 16384,
-            "privatemode-gemini-1.5-pro": 8192,
-            "privatemode-gemini-1.5-flash": 8192
-        }
-
-        return max_outputs.get(model_id, 4096)  # Default to 4K
-
-    def _supports_function_calling(self, model_id: str) -> bool:
-        """Check if model supports function calling"""
-        function_calling_models = [
-            "privatemode-gpt-4o", "privatemode-gpt-4o-mini",
-            "privatemode-claude-3.5-sonnet", "privatemode-claude-3-haiku",
-            "privatemode-gemini-1.5-pro", "privatemode-gemini-1.5-flash"
-        ]
-
-        return model_id in function_calling_models
-
     async def cleanup(self):
         """Cleanup PrivateMode provider resources"""
         await super().cleanup()
@@ -73,6 +73,7 @@ interface Model {
   max_output_tokens?: number;
   supports_streaming?: boolean;
   supports_function_calling?: boolean;
+  tasks?: string[]; // Added tasks field from PrivateMode API
 }
 
 interface NewApiKeyData {
@@ -623,7 +623,7 @@ export function ChatbotManager() {
                         <SelectContent>
                           {ragCollections.map((collection) => (
                             <SelectItem key={collection.id} value={collection.id}>
-                              <div>
+                              <div className="text-foreground">
                                 <div className="font-medium">{collection.name}</div>
                                 <div className="text-sm text-muted-foreground">
                                   {collection.document_count} documents
@@ -889,7 +889,7 @@ export function ChatbotManager() {
                         <SelectContent>
                           {ragCollections.map((collection) => (
                             <SelectItem key={collection.id} value={collection.id}>
-                              <div>
+                              <div className="text-foreground">
                                 <div className="font-medium">{collection.name}</div>
                                 <div className="text-sm text-muted-foreground">
                                   {collection.document_count} documents
@@ -1,6 +1,6 @@
 "use client"
 
-import { useState } from 'react'
+import { useState, useEffect } from 'react'
 import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card'
 import { Button } from '@/components/ui/button'
 import { Textarea } from '@/components/ui/textarea'
@@ -14,6 +14,23 @@ import { Download, Zap, Calculator, BarChart3, AlertCircle } from 'lucide-react'
 import { useToast } from '@/hooks/use-toast'
 import { apiClient } from '@/lib/api-client'
 
+interface Model {
+  id: string
+  object: string
+  created?: number
+  owned_by?: string
+  permission?: any[]
+  root?: string
+  parent?: string
+  provider?: string
+  capabilities?: string[]
+  context_window?: number
+  max_output_tokens?: number
+  supports_streaming?: boolean
+  supports_function_calling?: boolean
+  tasks?: string[]
+}
+
 interface EmbeddingResult {
   text: string
   embedding: number[]
@@ -31,7 +48,7 @@ interface SessionStats {
 
 export default function EmbeddingPlayground() {
   const [text, setText] = useState('')
-  const [model, setModel] = useState('text-embedding-ada-002')
+  const [model, setModel] = useState('')
   const [encodingFormat, setEncodingFormat] = useState('float')
   const [isLoading, setIsLoading] = useState(false)
   const [results, setResults] = useState<EmbeddingResult[]>([])
@@ -43,8 +60,51 @@ export default function EmbeddingPlayground() {
   })
   const [selectedResult, setSelectedResult] = useState<EmbeddingResult | null>(null)
   const [comparisonMode, setComparisonMode] = useState(false)
+  const [embeddingModels, setEmbeddingModels] = useState<Model[]>([])
+  const [loadingModels, setLoadingModels] = useState(true)
   const { toast } = useToast()
 
+  // Fetch available embedding models
+  useEffect(() => {
+    const fetchModels = async () => {
+      try {
+        setLoadingModels(true)
+        const response = await apiClient.get('/api-internal/v1/llm/models')
+
+        if (response.data) {
+          // Filter models that support embeddings based on tasks field
+          const models = response.data.filter((model: Model) => {
+            // Check if model has embed or embedding in tasks
+            if (model.tasks && Array.isArray(model.tasks)) {
+              return model.tasks.includes('embed') || model.tasks.includes('embedding')
+            }
+            // Fallback: check if model ID contains embedding patterns
+            const modelId = model.id.toLowerCase()
+            return modelId.includes('embed') || modelId.includes('text-embedding')
+          })
+
+          setEmbeddingModels(models)
+
+          // Set default model if available
+          if (models.length > 0 && !model) {
+            setModel(models[0].id)
+          }
+        }
+      } catch (error) {
+        console.error('Failed to fetch models:', error)
+        toast({
+          title: "Error",
+          description: "Failed to load embedding models",
+          variant: "destructive"
+        })
+      } finally {
+        setLoadingModels(false)
+      }
+    }
+
+    fetchModels()
+  }, [])
+
   const handleGenerateEmbedding = async () => {
     if (!text.trim()) {
       toast({
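The filter predicate inside this effect prefers the provider-reported tasks field and only falls back to ID substring matching. Extracted as a hypothetical standalone helper (isEmbeddingModel is not in the commit; Model is the interface added above):

function isEmbeddingModel(model: Model): boolean {
  // Prefer explicit task metadata when the provider supplies it
  if (model.tasks && Array.isArray(model.tasks)) {
    return model.tasks.includes('embed') || model.tasks.includes('embedding')
  }
  // Otherwise guess from the model ID
  const id = model.id.toLowerCase()
  return id.includes('embed') || id.includes('text-embedding')
}

// isEmbeddingModel({ id: 'privatemode-embeddings', object: 'model', tasks: ['embed'] })       // true
// isEmbeddingModel({ id: 'privatemode-llama-3.1-70b', object: 'model', tasks: ['generate'] }) // false (tasks win over ID)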
@@ -55,6 +115,15 @@ export default function EmbeddingPlayground() {
       return
     }
 
+    if (!model) {
+      toast({
+        title: "Error",
+        description: "Please select an embedding model",
+        variant: "destructive"
+      })
+      return
+    }
+
     setIsLoading(true)
     try {
       const data = await apiClient.post('/api-internal/v1/llm/embeddings', {
@@ -93,13 +162,34 @@ export default function EmbeddingPlayground() {
     }
   }
 
-  const calculateCost = (tokens: number, model: string): number => {
+  const calculateCost = (tokens: number, modelId: string): number => {
+    // Known rates for common embedding models
     const rates: { [key: string]: number } = {
       'text-embedding-ada-002': 0.0001,
       'text-embedding-3-small': 0.00002,
-      'text-embedding-3-large': 0.00013
+      'text-embedding-3-large': 0.00013,
+      'privatemode-text-embedding-ada-002': 0.0001,
+      'privatemode-text-embedding-3-small': 0.00002,
+      'privatemode-text-embedding-3-large': 0.00013
     }
-    return (tokens / 1000) * (rates[model] || 0.0001)
+
+    // Check for exact match first
+    if (rates[modelId]) {
+      return (tokens / 1000) * rates[modelId]
+    }
+
+    // Check for pattern matches (e.g., if model contains these patterns)
+    const modelLower = modelId.toLowerCase()
+    if (modelLower.includes('ada-002')) {
+      return (tokens / 1000) * 0.0001
+    } else if (modelLower.includes('3-small')) {
+      return (tokens / 1000) * 0.00002
+    } else if (modelLower.includes('3-large')) {
+      return (tokens / 1000) * 0.00013
+    }
+
+    // Default rate for unknown models
+    return (tokens / 1000) * 0.0001
   }
 
   const updateSessionStats = (result: EmbeddingResult) => {
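Worked examples of the new lookup order (exact match, then substring match, then default), reading the rates as USD per 1K tokens:

// calculateCost(1500, 'text-embedding-3-small')
//   exact match:  (1500 / 1000) * 0.00002 = $0.00003
// calculateCost(2000, 'privatemode-text-embedding-3-large')
//   exact match:  (2000 / 1000) * 0.00013 = $0.00026
// calculateCost(1000, 'some-unlisted-embedder')
//   no match:     (1000 / 1000) * 0.0001  = $0.0001 (default)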
@@ -157,14 +247,27 @@ export default function EmbeddingPlayground() {
           <div className="grid grid-cols-1 md:grid-cols-2 gap-4">
             <div>
               <label className="text-sm font-medium mb-2 block">Model</label>
-              <Select value={model} onValueChange={setModel}>
+              <Select value={model} onValueChange={setModel} disabled={loadingModels}>
                 <SelectTrigger>
-                  <SelectValue />
+                  <SelectValue placeholder={loadingModels ? "Loading models..." : "Select a model"} />
                 </SelectTrigger>
                 <SelectContent>
-                  <SelectItem value="text-embedding-ada-002">text-embedding-ada-002</SelectItem>
-                  <SelectItem value="text-embedding-3-small">text-embedding-3-small</SelectItem>
-                  <SelectItem value="text-embedding-3-large">text-embedding-3-large</SelectItem>
+                  {embeddingModels.length === 0 && !loadingModels ? (
+                    <SelectItem value="no-models" disabled>
+                      No embedding models available
+                    </SelectItem>
+                  ) : (
+                    embeddingModels.map((embModel) => (
+                      <SelectItem key={embModel.id} value={embModel.id}>
+                        {embModel.id}
+                        {embModel.owned_by && embModel.owned_by !== 'unknown' && (
+                          <span className="text-muted-foreground ml-2">
+                            ({embModel.owned_by})
+                          </span>
+                        )}
+                      </SelectItem>
+                    ))
+                  )}
                 </SelectContent>
               </Select>
             </div>
@@ -23,6 +23,7 @@ interface Model {
   max_output_tokens?: number
   supports_streaming?: boolean
   supports_function_calling?: boolean
+  tasks?: string[] // Added tasks field from PrivateMode API
 }
 
 interface ProviderStatus {
@@ -97,11 +98,21 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
     return 'Unknown'
   }
 
-  const getModelType = (modelId: string): 'chat' | 'embedding' | 'other' => {
+  const getModelType = (model: Model): 'chat' | 'embedding' | 'other' => {
+    // Check if model has tasks field from PrivateMode or other providers
+    if (model.tasks && Array.isArray(model.tasks)) {
+      // Models with "generate" task are chat models
+      if (model.tasks.includes('generate')) return 'chat'
+      // Models with "embed" task are embedding models
+      if (model.tasks.includes('embed') || model.tasks.includes('embedding')) return 'embedding'
+    }
+
+    // Fallback to ID-based detection for models without tasks field
+    const modelId = model.id
     if (modelId.includes('embedding') || modelId.includes('embed')) return 'embedding'
     if (modelId.includes('whisper') || modelId.includes('speech')) return 'other' // Audio models
 
-    // PrivateMode and other chat models
+    // PrivateMode and other chat models by ID pattern
     if (
       modelId.startsWith('privatemode-llama') ||
       modelId.startsWith('privatemode-claude') ||
@@ -114,14 +125,16 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
       modelId.includes('llama') ||
       modelId.includes('gemma') ||
       modelId.includes('qwen') ||
      modelId.includes('mistral') ||
       modelId.includes('command') ||
       modelId.includes('latest')
     ) return 'chat'
 
     return 'other'
   }
 
-  const getModelCategory = (modelId: string): string => {
-    const type = getModelType(modelId)
+  const getModelCategory = (model: Model): string => {
+    const type = getModelType(model)
     switch (type) {
       case 'chat': return 'Chat Completion'
       case 'embedding': return 'Text Embedding'
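With getModelType now taking the whole Model, task metadata wins over ID heuristics. Illustrative inputs and expected results (assumed examples; getModelType is defined inside ModelSelector above):

// tasks present -> trusted, regardless of the ID:
//   getModelType({ id: 'my-custom-model', object: 'model', tasks: ['embed'] })              -> 'embedding'
// no tasks -> fall back to ID patterns:
//   getModelType({ id: 'text-embedding-ada-002', object: 'model' })                         -> 'embedding'
//   getModelType({ id: 'privatemode-llama-3.1-70b', object: 'model', tasks: ['generate'] }) -> 'chat'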
@@ -132,7 +145,7 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
 
   const filteredModels = models.filter(model => {
     if (filter === 'all') return true
-    return getModelType(model.id) === filter
+    return getModelType(model) === filter
   })
 
   const groupedModels = filteredModels.reduce((acc, model) => {
@@ -255,7 +268,7 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
                     <span>{model.id}</span>
                     <div className="flex gap-1">
                       <Badge variant="outline" className="text-xs">
-                        {getModelCategory(model.id)}
+                        {getModelCategory(model)}
                       </Badge>
                       {model.supports_streaming && (
                         <Badge variant="secondary" className="text-xs">
@@ -299,7 +312,7 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
                   </div>
                   <div>
                     <span className="font-medium">Type:</span>
-                    <div className="text-muted-foreground">{getModelCategory(selectedModel.id)}</div>
+                    <div className="text-muted-foreground">{getModelCategory(selectedModel)}</div>
                   </div>
                   <div>
                     <span className="font-medium">Object:</span>