embedding models picker improvement

This commit is contained in:
2025-09-13 15:37:42 +02:00
parent e97d197d8b
commit 9c64177925
7 changed files with 177 additions and 91 deletions

View File

@@ -57,15 +57,24 @@ async def get_cached_models() -> List[Dict[str, Any]]:
# Convert ModelInfo objects to dict format for compatibility
models = []
for model_info in model_infos:
models.append({
model_dict = {
"id": model_info.id,
"object": model_info.object,
"created": model_info.created or int(time.time()),
"owned_by": model_info.owned_by,
# Add frontend-expected fields
"name": getattr(model_info, 'name', model_info.id), # Use name if available, fallback to id
"provider": getattr(model_info, 'provider', model_info.owned_by) # Use provider if available, fallback to owned_by
})
"provider": getattr(model_info, 'provider', model_info.owned_by), # Use provider if available, fallback to owned_by
"capabilities": model_info.capabilities,
"context_window": model_info.context_window,
"max_output_tokens": model_info.max_output_tokens,
"supports_streaming": model_info.supports_streaming,
"supports_function_calling": model_info.supports_function_calling
}
# Include tasks field if present
if model_info.tasks:
model_dict["tasks"] = model_info.tasks
models.append(model_dict)
# Update cache
_models_cache["data"] = models

View File

@@ -138,6 +138,7 @@ class ModelInfo(BaseModel):
max_output_tokens: Optional[int] = Field(None, description="Maximum output tokens")
supports_streaming: bool = Field(False, description="Whether model supports streaming")
supports_function_calling: bool = Field(False, description="Whether model supports function calling")
tasks: Optional[List[str]] = Field(None, description="Model tasks (e.g., generate, embed, vision)")
class ProviderStatus(BaseModel):

View File

@@ -134,18 +134,39 @@ class PrivateModeProvider(BaseLLMProvider):
if not model_id:
continue
# Extract model capabilities from PrivateMode response
# Extract all information directly from API response
# Determine capabilities based on tasks field
tasks = model_data.get("tasks", [])
capabilities = []
# All PrivateMode models have TEE capability
capabilities.append("tee")
# Add capabilities based on tasks
if "generate" in tasks:
capabilities.append("chat")
if "embed" in tasks or "embedding" in tasks:
capabilities.append("embeddings")
if "vision" in tasks:
capabilities.append("vision")
# Check for function calling support in the API response
supports_function_calling = model_data.get("supports_function_calling", False)
if supports_function_calling:
capabilities.append("function_calling")
model_info = ModelInfo(
id=model_id,
object="model",
created=model_data.get("created", int(time.time())),
owned_by="privatemode",
owned_by=model_data.get("owned_by", "privatemode"),
provider=self.provider_name,
capabilities=self._get_model_capabilities(model_id),
context_window=self._get_model_context_window(model_id),
max_output_tokens=self._get_model_max_output(model_id),
supports_streaming=True, # PrivateMode supports streaming
supports_function_calling=self._supports_function_calling(model_id)
capabilities=capabilities,
context_window=model_data.get("context_window"),
max_output_tokens=model_data.get("max_output_tokens"),
supports_streaming=model_data.get("supports_streaming", True),
supports_function_calling=supports_function_calling,
tasks=tasks # Pass through tasks field from PrivateMode API
)
models.append(model_info)
@@ -453,68 +474,6 @@ class PrivateModeProvider(BaseLLMProvider):
details={"error": str(e)}
)
def _get_model_capabilities(self, model_id: str) -> List[str]:
"""Get capabilities for a specific model"""
capabilities = ["chat"]
# PrivateMode supports embeddings for most models
if "embed" in model_id.lower() or model_id in [
"privatemode-llama-3.1-405b", "privatemode-llama-3.1-70b",
"privatemode-claude-3.5-sonnet", "privatemode-gpt-4o"
]:
capabilities.append("embeddings")
# TEE protection is available for all PrivateMode models
capabilities.append("tee")
return capabilities
def _get_model_context_window(self, model_id: str) -> Optional[int]:
"""Get context window size for a specific model"""
context_windows = {
"privatemode-llama-3.1-405b": 128000,
"privatemode-llama-3.1-70b": 128000,
"privatemode-llama-3.1-8b": 128000,
"privatemode-llama-3-70b": 8192,
"privatemode-llama-3-8b": 8192,
"privatemode-claude-3.5-sonnet": 200000,
"privatemode-claude-3-haiku": 200000,
"privatemode-gpt-4o": 128000,
"privatemode-gpt-4o-mini": 128000,
"privatemode-gemini-1.5-pro": 2000000,
"privatemode-gemini-1.5-flash": 1000000
}
return context_windows.get(model_id, 8192) # Default to 8K
def _get_model_max_output(self, model_id: str) -> Optional[int]:
"""Get max output tokens for a specific model"""
max_outputs = {
"privatemode-llama-3.1-405b": 8192,
"privatemode-llama-3.1-70b": 8192,
"privatemode-llama-3.1-8b": 8192,
"privatemode-llama-3-70b": 4096,
"privatemode-llama-3-8b": 4096,
"privatemode-claude-3.5-sonnet": 8192,
"privatemode-claude-3-haiku": 4096,
"privatemode-gpt-4o": 16384,
"privatemode-gpt-4o-mini": 16384,
"privatemode-gemini-1.5-pro": 8192,
"privatemode-gemini-1.5-flash": 8192
}
return max_outputs.get(model_id, 4096) # Default to 4K
def _supports_function_calling(self, model_id: str) -> bool:
"""Check if model supports function calling"""
function_calling_models = [
"privatemode-gpt-4o", "privatemode-gpt-4o-mini",
"privatemode-claude-3.5-sonnet", "privatemode-claude-3-haiku",
"privatemode-gemini-1.5-pro", "privatemode-gemini-1.5-flash"
]
return model_id in function_calling_models
async def cleanup(self):
"""Cleanup PrivateMode provider resources"""
await super().cleanup()

View File

@@ -73,6 +73,7 @@ interface Model {
max_output_tokens?: number;
supports_streaming?: boolean;
supports_function_calling?: boolean;
tasks?: string[]; // Added tasks field from PrivateMode API
}
interface NewApiKeyData {

View File

@@ -623,7 +623,7 @@ export function ChatbotManager() {
<SelectContent>
{ragCollections.map((collection) => (
<SelectItem key={collection.id} value={collection.id}>
<div>
<div className="text-foreground">
<div className="font-medium">{collection.name}</div>
<div className="text-sm text-muted-foreground">
{collection.document_count} documents
@@ -889,7 +889,7 @@ export function ChatbotManager() {
<SelectContent>
{ragCollections.map((collection) => (
<SelectItem key={collection.id} value={collection.id}>
<div>
<div className="text-foreground">
<div className="font-medium">{collection.name}</div>
<div className="text-sm text-muted-foreground">
{collection.document_count} documents

View File

@@ -1,6 +1,6 @@
"use client"
import { useState } from 'react'
import { useState, useEffect } from 'react'
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card'
import { Button } from '@/components/ui/button'
import { Textarea } from '@/components/ui/textarea'
@@ -14,6 +14,23 @@ import { Download, Zap, Calculator, BarChart3, AlertCircle } from 'lucide-react'
import { useToast } from '@/hooks/use-toast'
import { apiClient } from '@/lib/api-client'
interface Model {
id: string
object: string
created?: number
owned_by?: string
permission?: any[]
root?: string
parent?: string
provider?: string
capabilities?: string[]
context_window?: number
max_output_tokens?: number
supports_streaming?: boolean
supports_function_calling?: boolean
tasks?: string[]
}
interface EmbeddingResult {
text: string
embedding: number[]
@@ -31,7 +48,7 @@ interface SessionStats {
export default function EmbeddingPlayground() {
const [text, setText] = useState('')
const [model, setModel] = useState('text-embedding-ada-002')
const [model, setModel] = useState('')
const [encodingFormat, setEncodingFormat] = useState('float')
const [isLoading, setIsLoading] = useState(false)
const [results, setResults] = useState<EmbeddingResult[]>([])
@@ -43,8 +60,51 @@ export default function EmbeddingPlayground() {
})
const [selectedResult, setSelectedResult] = useState<EmbeddingResult | null>(null)
const [comparisonMode, setComparisonMode] = useState(false)
const [embeddingModels, setEmbeddingModels] = useState<Model[]>([])
const [loadingModels, setLoadingModels] = useState(true)
const { toast } = useToast()
// Fetch available embedding models
useEffect(() => {
const fetchModels = async () => {
try {
setLoadingModels(true)
const response = await apiClient.get('/api-internal/v1/llm/models')
if (response.data) {
// Filter models that support embeddings based on tasks field
const models = response.data.filter((model: Model) => {
// Check if model has embed or embedding in tasks
if (model.tasks && Array.isArray(model.tasks)) {
return model.tasks.includes('embed') || model.tasks.includes('embedding')
}
// Fallback: check if model ID contains embedding patterns
const modelId = model.id.toLowerCase()
return modelId.includes('embed') || modelId.includes('text-embedding')
})
setEmbeddingModels(models)
// Set default model if available
if (models.length > 0 && !model) {
setModel(models[0].id)
}
}
} catch (error) {
console.error('Failed to fetch models:', error)
toast({
title: "Error",
description: "Failed to load embedding models",
variant: "destructive"
})
} finally {
setLoadingModels(false)
}
}
fetchModels()
}, [])
const handleGenerateEmbedding = async () => {
if (!text.trim()) {
toast({
@@ -55,6 +115,15 @@ export default function EmbeddingPlayground() {
return
}
if (!model) {
toast({
title: "Error",
description: "Please select an embedding model",
variant: "destructive"
})
return
}
setIsLoading(true)
try {
const data = await apiClient.post('/api-internal/v1/llm/embeddings', {
@@ -93,13 +162,34 @@ export default function EmbeddingPlayground() {
}
}
const calculateCost = (tokens: number, model: string): number => {
const calculateCost = (tokens: number, modelId: string): number => {
// Known rates for common embedding models
const rates: { [key: string]: number } = {
'text-embedding-ada-002': 0.0001,
'text-embedding-3-small': 0.00002,
'text-embedding-3-large': 0.00013
'text-embedding-3-large': 0.00013,
'privatemode-text-embedding-ada-002': 0.0001,
'privatemode-text-embedding-3-small': 0.00002,
'privatemode-text-embedding-3-large': 0.00013
}
return (tokens / 1000) * (rates[model] || 0.0001)
// Check for exact match first
if (rates[modelId]) {
return (tokens / 1000) * rates[modelId]
}
// Check for pattern matches (e.g., if model contains these patterns)
const modelLower = modelId.toLowerCase()
if (modelLower.includes('ada-002')) {
return (tokens / 1000) * 0.0001
} else if (modelLower.includes('3-small')) {
return (tokens / 1000) * 0.00002
} else if (modelLower.includes('3-large')) {
return (tokens / 1000) * 0.00013
}
// Default rate for unknown models
return (tokens / 1000) * 0.0001
}
const updateSessionStats = (result: EmbeddingResult) => {
@@ -157,14 +247,27 @@ export default function EmbeddingPlayground() {
<div className="grid grid-cols-1 md:grid-cols-2 gap-4">
<div>
<label className="text-sm font-medium mb-2 block">Model</label>
<Select value={model} onValueChange={setModel}>
<Select value={model} onValueChange={setModel} disabled={loadingModels}>
<SelectTrigger>
<SelectValue />
<SelectValue placeholder={loadingModels ? "Loading models..." : "Select a model"} />
</SelectTrigger>
<SelectContent>
<SelectItem value="text-embedding-ada-002">text-embedding-ada-002</SelectItem>
<SelectItem value="text-embedding-3-small">text-embedding-3-small</SelectItem>
<SelectItem value="text-embedding-3-large">text-embedding-3-large</SelectItem>
{embeddingModels.length === 0 && !loadingModels ? (
<SelectItem value="no-models" disabled>
No embedding models available
</SelectItem>
) : (
embeddingModels.map((embModel) => (
<SelectItem key={embModel.id} value={embModel.id}>
{embModel.id}
{embModel.owned_by && embModel.owned_by !== 'unknown' && (
<span className="text-muted-foreground ml-2">
({embModel.owned_by})
</span>
)}
</SelectItem>
))
)}
</SelectContent>
</Select>
</div>

View File

@@ -23,6 +23,7 @@ interface Model {
max_output_tokens?: number
supports_streaming?: boolean
supports_function_calling?: boolean
tasks?: string[] // Added tasks field from PrivateMode API
}
interface ProviderStatus {
@@ -97,11 +98,21 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
return 'Unknown'
}
const getModelType = (modelId: string): 'chat' | 'embedding' | 'other' => {
const getModelType = (model: Model): 'chat' | 'embedding' | 'other' => {
// Check if model has tasks field from PrivateMode or other providers
if (model.tasks && Array.isArray(model.tasks)) {
// Models with "generate" task are chat models
if (model.tasks.includes('generate')) return 'chat'
// Models with "embed" task are embedding models
if (model.tasks.includes('embed') || model.tasks.includes('embedding')) return 'embedding'
}
// Fallback to ID-based detection for models without tasks field
const modelId = model.id
if (modelId.includes('embedding') || modelId.includes('embed')) return 'embedding'
if (modelId.includes('whisper') || modelId.includes('speech')) return 'other' // Audio models
// PrivateMode and other chat models
// PrivateMode and other chat models by ID pattern
if (
modelId.startsWith('privatemode-llama') ||
modelId.startsWith('privatemode-claude') ||
@@ -114,14 +125,16 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
modelId.includes('llama') ||
modelId.includes('gemma') ||
modelId.includes('qwen') ||
modelId.includes('mistral') ||
modelId.includes('command') ||
modelId.includes('latest')
) return 'chat'
return 'other'
}
const getModelCategory = (modelId: string): string => {
const type = getModelType(modelId)
const getModelCategory = (model: Model): string => {
const type = getModelType(model)
switch (type) {
case 'chat': return 'Chat Completion'
case 'embedding': return 'Text Embedding'
@@ -132,7 +145,7 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
const filteredModels = models.filter(model => {
if (filter === 'all') return true
return getModelType(model.id) === filter
return getModelType(model) === filter
})
const groupedModels = filteredModels.reduce((acc, model) => {
@@ -255,7 +268,7 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
<span>{model.id}</span>
<div className="flex gap-1">
<Badge variant="outline" className="text-xs">
{getModelCategory(model.id)}
{getModelCategory(model)}
</Badge>
{model.supports_streaming && (
<Badge variant="secondary" className="text-xs">
@@ -299,7 +312,7 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
</div>
<div>
<span className="font-medium">Type:</span>
<div className="text-muted-foreground">{getModelCategory(selectedModel.id)}</div>
<div className="text-muted-foreground">{getModelCategory(selectedModel)}</div>
</div>
<div>
<span className="font-medium">Object:</span>