embedding models picker improvement

This commit is contained in:
2025-09-13 15:37:42 +02:00
parent e97d197d8b
commit 9c64177925
7 changed files with 177 additions and 91 deletions

View File

@@ -57,15 +57,24 @@ async def get_cached_models() -> List[Dict[str, Any]]:
# Convert ModelInfo objects to dict format for compatibility # Convert ModelInfo objects to dict format for compatibility
models = [] models = []
for model_info in model_infos: for model_info in model_infos:
models.append({ model_dict = {
"id": model_info.id, "id": model_info.id,
"object": model_info.object, "object": model_info.object,
"created": model_info.created or int(time.time()), "created": model_info.created or int(time.time()),
"owned_by": model_info.owned_by, "owned_by": model_info.owned_by,
# Add frontend-expected fields # Add frontend-expected fields
"name": getattr(model_info, 'name', model_info.id), # Use name if available, fallback to id "name": getattr(model_info, 'name', model_info.id), # Use name if available, fallback to id
"provider": getattr(model_info, 'provider', model_info.owned_by) # Use provider if available, fallback to owned_by "provider": getattr(model_info, 'provider', model_info.owned_by), # Use provider if available, fallback to owned_by
}) "capabilities": model_info.capabilities,
"context_window": model_info.context_window,
"max_output_tokens": model_info.max_output_tokens,
"supports_streaming": model_info.supports_streaming,
"supports_function_calling": model_info.supports_function_calling
}
# Include tasks field if present
if model_info.tasks:
model_dict["tasks"] = model_info.tasks
models.append(model_dict)
# Update cache # Update cache
_models_cache["data"] = models _models_cache["data"] = models

View File

@@ -138,6 +138,7 @@ class ModelInfo(BaseModel):
max_output_tokens: Optional[int] = Field(None, description="Maximum output tokens") max_output_tokens: Optional[int] = Field(None, description="Maximum output tokens")
supports_streaming: bool = Field(False, description="Whether model supports streaming") supports_streaming: bool = Field(False, description="Whether model supports streaming")
supports_function_calling: bool = Field(False, description="Whether model supports function calling") supports_function_calling: bool = Field(False, description="Whether model supports function calling")
tasks: Optional[List[str]] = Field(None, description="Model tasks (e.g., generate, embed, vision)")
class ProviderStatus(BaseModel): class ProviderStatus(BaseModel):

View File

@@ -134,18 +134,39 @@ class PrivateModeProvider(BaseLLMProvider):
if not model_id: if not model_id:
continue continue
# Extract model capabilities from PrivateMode response # Extract all information directly from API response
# Determine capabilities based on tasks field
tasks = model_data.get("tasks", [])
capabilities = []
# All PrivateMode models have TEE capability
capabilities.append("tee")
# Add capabilities based on tasks
if "generate" in tasks:
capabilities.append("chat")
if "embed" in tasks or "embedding" in tasks:
capabilities.append("embeddings")
if "vision" in tasks:
capabilities.append("vision")
# Check for function calling support in the API response
supports_function_calling = model_data.get("supports_function_calling", False)
if supports_function_calling:
capabilities.append("function_calling")
model_info = ModelInfo( model_info = ModelInfo(
id=model_id, id=model_id,
object="model", object="model",
created=model_data.get("created", int(time.time())), created=model_data.get("created", int(time.time())),
owned_by="privatemode", owned_by=model_data.get("owned_by", "privatemode"),
provider=self.provider_name, provider=self.provider_name,
capabilities=self._get_model_capabilities(model_id), capabilities=capabilities,
context_window=self._get_model_context_window(model_id), context_window=model_data.get("context_window"),
max_output_tokens=self._get_model_max_output(model_id), max_output_tokens=model_data.get("max_output_tokens"),
supports_streaming=True, # PrivateMode supports streaming supports_streaming=model_data.get("supports_streaming", True),
supports_function_calling=self._supports_function_calling(model_id) supports_function_calling=supports_function_calling,
tasks=tasks # Pass through tasks field from PrivateMode API
) )
models.append(model_info) models.append(model_info)
@@ -453,68 +474,6 @@ class PrivateModeProvider(BaseLLMProvider):
details={"error": str(e)} details={"error": str(e)}
) )
def _get_model_capabilities(self, model_id: str) -> List[str]:
"""Get capabilities for a specific model"""
capabilities = ["chat"]
# PrivateMode supports embeddings for most models
if "embed" in model_id.lower() or model_id in [
"privatemode-llama-3.1-405b", "privatemode-llama-3.1-70b",
"privatemode-claude-3.5-sonnet", "privatemode-gpt-4o"
]:
capabilities.append("embeddings")
# TEE protection is available for all PrivateMode models
capabilities.append("tee")
return capabilities
def _get_model_context_window(self, model_id: str) -> Optional[int]:
"""Get context window size for a specific model"""
context_windows = {
"privatemode-llama-3.1-405b": 128000,
"privatemode-llama-3.1-70b": 128000,
"privatemode-llama-3.1-8b": 128000,
"privatemode-llama-3-70b": 8192,
"privatemode-llama-3-8b": 8192,
"privatemode-claude-3.5-sonnet": 200000,
"privatemode-claude-3-haiku": 200000,
"privatemode-gpt-4o": 128000,
"privatemode-gpt-4o-mini": 128000,
"privatemode-gemini-1.5-pro": 2000000,
"privatemode-gemini-1.5-flash": 1000000
}
return context_windows.get(model_id, 8192) # Default to 8K
def _get_model_max_output(self, model_id: str) -> Optional[int]:
"""Get max output tokens for a specific model"""
max_outputs = {
"privatemode-llama-3.1-405b": 8192,
"privatemode-llama-3.1-70b": 8192,
"privatemode-llama-3.1-8b": 8192,
"privatemode-llama-3-70b": 4096,
"privatemode-llama-3-8b": 4096,
"privatemode-claude-3.5-sonnet": 8192,
"privatemode-claude-3-haiku": 4096,
"privatemode-gpt-4o": 16384,
"privatemode-gpt-4o-mini": 16384,
"privatemode-gemini-1.5-pro": 8192,
"privatemode-gemini-1.5-flash": 8192
}
return max_outputs.get(model_id, 4096) # Default to 4K
def _supports_function_calling(self, model_id: str) -> bool:
"""Check if model supports function calling"""
function_calling_models = [
"privatemode-gpt-4o", "privatemode-gpt-4o-mini",
"privatemode-claude-3.5-sonnet", "privatemode-claude-3-haiku",
"privatemode-gemini-1.5-pro", "privatemode-gemini-1.5-flash"
]
return model_id in function_calling_models
async def cleanup(self): async def cleanup(self):
"""Cleanup PrivateMode provider resources""" """Cleanup PrivateMode provider resources"""
await super().cleanup() await super().cleanup()

View File

@@ -73,6 +73,7 @@ interface Model {
max_output_tokens?: number; max_output_tokens?: number;
supports_streaming?: boolean; supports_streaming?: boolean;
supports_function_calling?: boolean; supports_function_calling?: boolean;
tasks?: string[]; // Added tasks field from PrivateMode API
} }
interface NewApiKeyData { interface NewApiKeyData {

View File

@@ -623,7 +623,7 @@ export function ChatbotManager() {
<SelectContent> <SelectContent>
{ragCollections.map((collection) => ( {ragCollections.map((collection) => (
<SelectItem key={collection.id} value={collection.id}> <SelectItem key={collection.id} value={collection.id}>
<div> <div className="text-foreground">
<div className="font-medium">{collection.name}</div> <div className="font-medium">{collection.name}</div>
<div className="text-sm text-muted-foreground"> <div className="text-sm text-muted-foreground">
{collection.document_count} documents {collection.document_count} documents
@@ -889,7 +889,7 @@ export function ChatbotManager() {
<SelectContent> <SelectContent>
{ragCollections.map((collection) => ( {ragCollections.map((collection) => (
<SelectItem key={collection.id} value={collection.id}> <SelectItem key={collection.id} value={collection.id}>
<div> <div className="text-foreground">
<div className="font-medium">{collection.name}</div> <div className="font-medium">{collection.name}</div>
<div className="text-sm text-muted-foreground"> <div className="text-sm text-muted-foreground">
{collection.document_count} documents {collection.document_count} documents

View File

@@ -1,6 +1,6 @@
"use client" "use client"
import { useState } from 'react' import { useState, useEffect } from 'react'
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card' import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card'
import { Button } from '@/components/ui/button' import { Button } from '@/components/ui/button'
import { Textarea } from '@/components/ui/textarea' import { Textarea } from '@/components/ui/textarea'
@@ -14,6 +14,23 @@ import { Download, Zap, Calculator, BarChart3, AlertCircle } from 'lucide-react'
import { useToast } from '@/hooks/use-toast' import { useToast } from '@/hooks/use-toast'
import { apiClient } from '@/lib/api-client' import { apiClient } from '@/lib/api-client'
interface Model {
id: string
object: string
created?: number
owned_by?: string
permission?: any[]
root?: string
parent?: string
provider?: string
capabilities?: string[]
context_window?: number
max_output_tokens?: number
supports_streaming?: boolean
supports_function_calling?: boolean
tasks?: string[]
}
interface EmbeddingResult { interface EmbeddingResult {
text: string text: string
embedding: number[] embedding: number[]
@@ -31,7 +48,7 @@ interface SessionStats {
export default function EmbeddingPlayground() { export default function EmbeddingPlayground() {
const [text, setText] = useState('') const [text, setText] = useState('')
const [model, setModel] = useState('text-embedding-ada-002') const [model, setModel] = useState('')
const [encodingFormat, setEncodingFormat] = useState('float') const [encodingFormat, setEncodingFormat] = useState('float')
const [isLoading, setIsLoading] = useState(false) const [isLoading, setIsLoading] = useState(false)
const [results, setResults] = useState<EmbeddingResult[]>([]) const [results, setResults] = useState<EmbeddingResult[]>([])
@@ -43,8 +60,51 @@ export default function EmbeddingPlayground() {
}) })
const [selectedResult, setSelectedResult] = useState<EmbeddingResult | null>(null) const [selectedResult, setSelectedResult] = useState<EmbeddingResult | null>(null)
const [comparisonMode, setComparisonMode] = useState(false) const [comparisonMode, setComparisonMode] = useState(false)
const [embeddingModels, setEmbeddingModels] = useState<Model[]>([])
const [loadingModels, setLoadingModels] = useState(true)
const { toast } = useToast() const { toast } = useToast()
// Fetch available embedding models
useEffect(() => {
const fetchModels = async () => {
try {
setLoadingModels(true)
const response = await apiClient.get('/api-internal/v1/llm/models')
if (response.data) {
// Filter models that support embeddings based on tasks field
const models = response.data.filter((model: Model) => {
// Check if model has embed or embedding in tasks
if (model.tasks && Array.isArray(model.tasks)) {
return model.tasks.includes('embed') || model.tasks.includes('embedding')
}
// Fallback: check if model ID contains embedding patterns
const modelId = model.id.toLowerCase()
return modelId.includes('embed') || modelId.includes('text-embedding')
})
setEmbeddingModels(models)
// Set default model if available
if (models.length > 0 && !model) {
setModel(models[0].id)
}
}
} catch (error) {
console.error('Failed to fetch models:', error)
toast({
title: "Error",
description: "Failed to load embedding models",
variant: "destructive"
})
} finally {
setLoadingModels(false)
}
}
fetchModels()
}, [])
const handleGenerateEmbedding = async () => { const handleGenerateEmbedding = async () => {
if (!text.trim()) { if (!text.trim()) {
toast({ toast({
@@ -55,6 +115,15 @@ export default function EmbeddingPlayground() {
return return
} }
if (!model) {
toast({
title: "Error",
description: "Please select an embedding model",
variant: "destructive"
})
return
}
setIsLoading(true) setIsLoading(true)
try { try {
const data = await apiClient.post('/api-internal/v1/llm/embeddings', { const data = await apiClient.post('/api-internal/v1/llm/embeddings', {
@@ -93,13 +162,34 @@ export default function EmbeddingPlayground() {
} }
} }
const calculateCost = (tokens: number, model: string): number => { const calculateCost = (tokens: number, modelId: string): number => {
// Known rates for common embedding models
const rates: { [key: string]: number } = { const rates: { [key: string]: number } = {
'text-embedding-ada-002': 0.0001, 'text-embedding-ada-002': 0.0001,
'text-embedding-3-small': 0.00002, 'text-embedding-3-small': 0.00002,
'text-embedding-3-large': 0.00013 'text-embedding-3-large': 0.00013,
'privatemode-text-embedding-ada-002': 0.0001,
'privatemode-text-embedding-3-small': 0.00002,
'privatemode-text-embedding-3-large': 0.00013
} }
return (tokens / 1000) * (rates[model] || 0.0001)
// Check for exact match first
if (rates[modelId]) {
return (tokens / 1000) * rates[modelId]
}
// Check for pattern matches (e.g., if model contains these patterns)
const modelLower = modelId.toLowerCase()
if (modelLower.includes('ada-002')) {
return (tokens / 1000) * 0.0001
} else if (modelLower.includes('3-small')) {
return (tokens / 1000) * 0.00002
} else if (modelLower.includes('3-large')) {
return (tokens / 1000) * 0.00013
}
// Default rate for unknown models
return (tokens / 1000) * 0.0001
} }
const updateSessionStats = (result: EmbeddingResult) => { const updateSessionStats = (result: EmbeddingResult) => {
@@ -157,14 +247,27 @@ export default function EmbeddingPlayground() {
<div className="grid grid-cols-1 md:grid-cols-2 gap-4"> <div className="grid grid-cols-1 md:grid-cols-2 gap-4">
<div> <div>
<label className="text-sm font-medium mb-2 block">Model</label> <label className="text-sm font-medium mb-2 block">Model</label>
<Select value={model} onValueChange={setModel}> <Select value={model} onValueChange={setModel} disabled={loadingModels}>
<SelectTrigger> <SelectTrigger>
<SelectValue /> <SelectValue placeholder={loadingModels ? "Loading models..." : "Select a model"} />
</SelectTrigger> </SelectTrigger>
<SelectContent> <SelectContent>
<SelectItem value="text-embedding-ada-002">text-embedding-ada-002</SelectItem> {embeddingModels.length === 0 && !loadingModels ? (
<SelectItem value="text-embedding-3-small">text-embedding-3-small</SelectItem> <SelectItem value="no-models" disabled>
<SelectItem value="text-embedding-3-large">text-embedding-3-large</SelectItem> No embedding models available
</SelectItem>
) : (
embeddingModels.map((embModel) => (
<SelectItem key={embModel.id} value={embModel.id}>
{embModel.id}
{embModel.owned_by && embModel.owned_by !== 'unknown' && (
<span className="text-muted-foreground ml-2">
({embModel.owned_by})
</span>
)}
</SelectItem>
))
)}
</SelectContent> </SelectContent>
</Select> </Select>
</div> </div>

View File

@@ -23,6 +23,7 @@ interface Model {
max_output_tokens?: number max_output_tokens?: number
supports_streaming?: boolean supports_streaming?: boolean
supports_function_calling?: boolean supports_function_calling?: boolean
tasks?: string[] // Added tasks field from PrivateMode API
} }
interface ProviderStatus { interface ProviderStatus {
@@ -97,11 +98,21 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
return 'Unknown' return 'Unknown'
} }
const getModelType = (modelId: string): 'chat' | 'embedding' | 'other' => { const getModelType = (model: Model): 'chat' | 'embedding' | 'other' => {
// Check if model has tasks field from PrivateMode or other providers
if (model.tasks && Array.isArray(model.tasks)) {
// Models with "generate" task are chat models
if (model.tasks.includes('generate')) return 'chat'
// Models with "embed" task are embedding models
if (model.tasks.includes('embed') || model.tasks.includes('embedding')) return 'embedding'
}
// Fallback to ID-based detection for models without tasks field
const modelId = model.id
if (modelId.includes('embedding') || modelId.includes('embed')) return 'embedding' if (modelId.includes('embedding') || modelId.includes('embed')) return 'embedding'
if (modelId.includes('whisper') || modelId.includes('speech')) return 'other' // Audio models if (modelId.includes('whisper') || modelId.includes('speech')) return 'other' // Audio models
// PrivateMode and other chat models // PrivateMode and other chat models by ID pattern
if ( if (
modelId.startsWith('privatemode-llama') || modelId.startsWith('privatemode-llama') ||
modelId.startsWith('privatemode-claude') || modelId.startsWith('privatemode-claude') ||
@@ -114,14 +125,16 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
modelId.includes('llama') || modelId.includes('llama') ||
modelId.includes('gemma') || modelId.includes('gemma') ||
modelId.includes('qwen') || modelId.includes('qwen') ||
modelId.includes('mistral') ||
modelId.includes('command') ||
modelId.includes('latest') modelId.includes('latest')
) return 'chat' ) return 'chat'
return 'other' return 'other'
} }
const getModelCategory = (modelId: string): string => { const getModelCategory = (model: Model): string => {
const type = getModelType(modelId) const type = getModelType(model)
switch (type) { switch (type) {
case 'chat': return 'Chat Completion' case 'chat': return 'Chat Completion'
case 'embedding': return 'Text Embedding' case 'embedding': return 'Text Embedding'
@@ -132,7 +145,7 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
const filteredModels = models.filter(model => { const filteredModels = models.filter(model => {
if (filter === 'all') return true if (filter === 'all') return true
return getModelType(model.id) === filter return getModelType(model) === filter
}) })
const groupedModels = filteredModels.reduce((acc, model) => { const groupedModels = filteredModels.reduce((acc, model) => {
@@ -255,7 +268,7 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
<span>{model.id}</span> <span>{model.id}</span>
<div className="flex gap-1"> <div className="flex gap-1">
<Badge variant="outline" className="text-xs"> <Badge variant="outline" className="text-xs">
{getModelCategory(model.id)} {getModelCategory(model)}
</Badge> </Badge>
{model.supports_streaming && ( {model.supports_streaming && (
<Badge variant="secondary" className="text-xs"> <Badge variant="secondary" className="text-xs">
@@ -299,7 +312,7 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
</div> </div>
<div> <div>
<span className="font-medium">Type:</span> <span className="font-medium">Type:</span>
<div className="text-muted-foreground">{getModelCategory(selectedModel.id)}</div> <div className="text-muted-foreground">{getModelCategory(selectedModel)}</div>
</div> </div>
<div> <div>
<span className="font-medium">Object:</span> <span className="font-medium">Object:</span>