mirror of
https://github.com/aljazceru/enclava.git
synced 2025-12-17 23:44:24 +01:00
embedding models picker improvement
This commit is contained in:
@@ -57,15 +57,24 @@ async def get_cached_models() -> List[Dict[str, Any]]:
|
|||||||
# Convert ModelInfo objects to dict format for compatibility
|
# Convert ModelInfo objects to dict format for compatibility
|
||||||
models = []
|
models = []
|
||||||
for model_info in model_infos:
|
for model_info in model_infos:
|
||||||
models.append({
|
model_dict = {
|
||||||
"id": model_info.id,
|
"id": model_info.id,
|
||||||
"object": model_info.object,
|
"object": model_info.object,
|
||||||
"created": model_info.created or int(time.time()),
|
"created": model_info.created or int(time.time()),
|
||||||
"owned_by": model_info.owned_by,
|
"owned_by": model_info.owned_by,
|
||||||
# Add frontend-expected fields
|
# Add frontend-expected fields
|
||||||
"name": getattr(model_info, 'name', model_info.id), # Use name if available, fallback to id
|
"name": getattr(model_info, 'name', model_info.id), # Use name if available, fallback to id
|
||||||
"provider": getattr(model_info, 'provider', model_info.owned_by) # Use provider if available, fallback to owned_by
|
"provider": getattr(model_info, 'provider', model_info.owned_by), # Use provider if available, fallback to owned_by
|
||||||
})
|
"capabilities": model_info.capabilities,
|
||||||
|
"context_window": model_info.context_window,
|
||||||
|
"max_output_tokens": model_info.max_output_tokens,
|
||||||
|
"supports_streaming": model_info.supports_streaming,
|
||||||
|
"supports_function_calling": model_info.supports_function_calling
|
||||||
|
}
|
||||||
|
# Include tasks field if present
|
||||||
|
if model_info.tasks:
|
||||||
|
model_dict["tasks"] = model_info.tasks
|
||||||
|
models.append(model_dict)
|
||||||
|
|
||||||
# Update cache
|
# Update cache
|
||||||
_models_cache["data"] = models
|
_models_cache["data"] = models
|
||||||
|
|||||||
@@ -138,6 +138,7 @@ class ModelInfo(BaseModel):
|
|||||||
max_output_tokens: Optional[int] = Field(None, description="Maximum output tokens")
|
max_output_tokens: Optional[int] = Field(None, description="Maximum output tokens")
|
||||||
supports_streaming: bool = Field(False, description="Whether model supports streaming")
|
supports_streaming: bool = Field(False, description="Whether model supports streaming")
|
||||||
supports_function_calling: bool = Field(False, description="Whether model supports function calling")
|
supports_function_calling: bool = Field(False, description="Whether model supports function calling")
|
||||||
|
tasks: Optional[List[str]] = Field(None, description="Model tasks (e.g., generate, embed, vision)")
|
||||||
|
|
||||||
|
|
||||||
class ProviderStatus(BaseModel):
|
class ProviderStatus(BaseModel):
|
||||||
|
|||||||
@@ -134,18 +134,39 @@ class PrivateModeProvider(BaseLLMProvider):
|
|||||||
if not model_id:
|
if not model_id:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Extract model capabilities from PrivateMode response
|
# Extract all information directly from API response
|
||||||
|
# Determine capabilities based on tasks field
|
||||||
|
tasks = model_data.get("tasks", [])
|
||||||
|
capabilities = []
|
||||||
|
|
||||||
|
# All PrivateMode models have TEE capability
|
||||||
|
capabilities.append("tee")
|
||||||
|
|
||||||
|
# Add capabilities based on tasks
|
||||||
|
if "generate" in tasks:
|
||||||
|
capabilities.append("chat")
|
||||||
|
if "embed" in tasks or "embedding" in tasks:
|
||||||
|
capabilities.append("embeddings")
|
||||||
|
if "vision" in tasks:
|
||||||
|
capabilities.append("vision")
|
||||||
|
|
||||||
|
# Check for function calling support in the API response
|
||||||
|
supports_function_calling = model_data.get("supports_function_calling", False)
|
||||||
|
if supports_function_calling:
|
||||||
|
capabilities.append("function_calling")
|
||||||
|
|
||||||
model_info = ModelInfo(
|
model_info = ModelInfo(
|
||||||
id=model_id,
|
id=model_id,
|
||||||
object="model",
|
object="model",
|
||||||
created=model_data.get("created", int(time.time())),
|
created=model_data.get("created", int(time.time())),
|
||||||
owned_by="privatemode",
|
owned_by=model_data.get("owned_by", "privatemode"),
|
||||||
provider=self.provider_name,
|
provider=self.provider_name,
|
||||||
capabilities=self._get_model_capabilities(model_id),
|
capabilities=capabilities,
|
||||||
context_window=self._get_model_context_window(model_id),
|
context_window=model_data.get("context_window"),
|
||||||
max_output_tokens=self._get_model_max_output(model_id),
|
max_output_tokens=model_data.get("max_output_tokens"),
|
||||||
supports_streaming=True, # PrivateMode supports streaming
|
supports_streaming=model_data.get("supports_streaming", True),
|
||||||
supports_function_calling=self._supports_function_calling(model_id)
|
supports_function_calling=supports_function_calling,
|
||||||
|
tasks=tasks # Pass through tasks field from PrivateMode API
|
||||||
)
|
)
|
||||||
models.append(model_info)
|
models.append(model_info)
|
||||||
|
|
||||||
@@ -453,68 +474,6 @@ class PrivateModeProvider(BaseLLMProvider):
|
|||||||
details={"error": str(e)}
|
details={"error": str(e)}
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_model_capabilities(self, model_id: str) -> List[str]:
|
|
||||||
"""Get capabilities for a specific model"""
|
|
||||||
capabilities = ["chat"]
|
|
||||||
|
|
||||||
# PrivateMode supports embeddings for most models
|
|
||||||
if "embed" in model_id.lower() or model_id in [
|
|
||||||
"privatemode-llama-3.1-405b", "privatemode-llama-3.1-70b",
|
|
||||||
"privatemode-claude-3.5-sonnet", "privatemode-gpt-4o"
|
|
||||||
]:
|
|
||||||
capabilities.append("embeddings")
|
|
||||||
|
|
||||||
# TEE protection is available for all PrivateMode models
|
|
||||||
capabilities.append("tee")
|
|
||||||
|
|
||||||
return capabilities
|
|
||||||
|
|
||||||
def _get_model_context_window(self, model_id: str) -> Optional[int]:
|
|
||||||
"""Get context window size for a specific model"""
|
|
||||||
context_windows = {
|
|
||||||
"privatemode-llama-3.1-405b": 128000,
|
|
||||||
"privatemode-llama-3.1-70b": 128000,
|
|
||||||
"privatemode-llama-3.1-8b": 128000,
|
|
||||||
"privatemode-llama-3-70b": 8192,
|
|
||||||
"privatemode-llama-3-8b": 8192,
|
|
||||||
"privatemode-claude-3.5-sonnet": 200000,
|
|
||||||
"privatemode-claude-3-haiku": 200000,
|
|
||||||
"privatemode-gpt-4o": 128000,
|
|
||||||
"privatemode-gpt-4o-mini": 128000,
|
|
||||||
"privatemode-gemini-1.5-pro": 2000000,
|
|
||||||
"privatemode-gemini-1.5-flash": 1000000
|
|
||||||
}
|
|
||||||
|
|
||||||
return context_windows.get(model_id, 8192) # Default to 8K
|
|
||||||
|
|
||||||
def _get_model_max_output(self, model_id: str) -> Optional[int]:
|
|
||||||
"""Get max output tokens for a specific model"""
|
|
||||||
max_outputs = {
|
|
||||||
"privatemode-llama-3.1-405b": 8192,
|
|
||||||
"privatemode-llama-3.1-70b": 8192,
|
|
||||||
"privatemode-llama-3.1-8b": 8192,
|
|
||||||
"privatemode-llama-3-70b": 4096,
|
|
||||||
"privatemode-llama-3-8b": 4096,
|
|
||||||
"privatemode-claude-3.5-sonnet": 8192,
|
|
||||||
"privatemode-claude-3-haiku": 4096,
|
|
||||||
"privatemode-gpt-4o": 16384,
|
|
||||||
"privatemode-gpt-4o-mini": 16384,
|
|
||||||
"privatemode-gemini-1.5-pro": 8192,
|
|
||||||
"privatemode-gemini-1.5-flash": 8192
|
|
||||||
}
|
|
||||||
|
|
||||||
return max_outputs.get(model_id, 4096) # Default to 4K
|
|
||||||
|
|
||||||
def _supports_function_calling(self, model_id: str) -> bool:
|
|
||||||
"""Check if model supports function calling"""
|
|
||||||
function_calling_models = [
|
|
||||||
"privatemode-gpt-4o", "privatemode-gpt-4o-mini",
|
|
||||||
"privatemode-claude-3.5-sonnet", "privatemode-claude-3-haiku",
|
|
||||||
"privatemode-gemini-1.5-pro", "privatemode-gemini-1.5-flash"
|
|
||||||
]
|
|
||||||
|
|
||||||
return model_id in function_calling_models
|
|
||||||
|
|
||||||
async def cleanup(self):
|
async def cleanup(self):
|
||||||
"""Cleanup PrivateMode provider resources"""
|
"""Cleanup PrivateMode provider resources"""
|
||||||
await super().cleanup()
|
await super().cleanup()
|
||||||
|
|||||||
@@ -73,6 +73,7 @@ interface Model {
|
|||||||
max_output_tokens?: number;
|
max_output_tokens?: number;
|
||||||
supports_streaming?: boolean;
|
supports_streaming?: boolean;
|
||||||
supports_function_calling?: boolean;
|
supports_function_calling?: boolean;
|
||||||
|
tasks?: string[]; // Added tasks field from PrivateMode API
|
||||||
}
|
}
|
||||||
|
|
||||||
interface NewApiKeyData {
|
interface NewApiKeyData {
|
||||||
|
|||||||
@@ -623,7 +623,7 @@ export function ChatbotManager() {
|
|||||||
<SelectContent>
|
<SelectContent>
|
||||||
{ragCollections.map((collection) => (
|
{ragCollections.map((collection) => (
|
||||||
<SelectItem key={collection.id} value={collection.id}>
|
<SelectItem key={collection.id} value={collection.id}>
|
||||||
<div>
|
<div className="text-foreground">
|
||||||
<div className="font-medium">{collection.name}</div>
|
<div className="font-medium">{collection.name}</div>
|
||||||
<div className="text-sm text-muted-foreground">
|
<div className="text-sm text-muted-foreground">
|
||||||
{collection.document_count} documents
|
{collection.document_count} documents
|
||||||
@@ -889,7 +889,7 @@ export function ChatbotManager() {
|
|||||||
<SelectContent>
|
<SelectContent>
|
||||||
{ragCollections.map((collection) => (
|
{ragCollections.map((collection) => (
|
||||||
<SelectItem key={collection.id} value={collection.id}>
|
<SelectItem key={collection.id} value={collection.id}>
|
||||||
<div>
|
<div className="text-foreground">
|
||||||
<div className="font-medium">{collection.name}</div>
|
<div className="font-medium">{collection.name}</div>
|
||||||
<div className="text-sm text-muted-foreground">
|
<div className="text-sm text-muted-foreground">
|
||||||
{collection.document_count} documents
|
{collection.document_count} documents
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
"use client"
|
"use client"
|
||||||
|
|
||||||
import { useState } from 'react'
|
import { useState, useEffect } from 'react'
|
||||||
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card'
|
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card'
|
||||||
import { Button } from '@/components/ui/button'
|
import { Button } from '@/components/ui/button'
|
||||||
import { Textarea } from '@/components/ui/textarea'
|
import { Textarea } from '@/components/ui/textarea'
|
||||||
@@ -14,6 +14,23 @@ import { Download, Zap, Calculator, BarChart3, AlertCircle } from 'lucide-react'
|
|||||||
import { useToast } from '@/hooks/use-toast'
|
import { useToast } from '@/hooks/use-toast'
|
||||||
import { apiClient } from '@/lib/api-client'
|
import { apiClient } from '@/lib/api-client'
|
||||||
|
|
||||||
|
interface Model {
|
||||||
|
id: string
|
||||||
|
object: string
|
||||||
|
created?: number
|
||||||
|
owned_by?: string
|
||||||
|
permission?: any[]
|
||||||
|
root?: string
|
||||||
|
parent?: string
|
||||||
|
provider?: string
|
||||||
|
capabilities?: string[]
|
||||||
|
context_window?: number
|
||||||
|
max_output_tokens?: number
|
||||||
|
supports_streaming?: boolean
|
||||||
|
supports_function_calling?: boolean
|
||||||
|
tasks?: string[]
|
||||||
|
}
|
||||||
|
|
||||||
interface EmbeddingResult {
|
interface EmbeddingResult {
|
||||||
text: string
|
text: string
|
||||||
embedding: number[]
|
embedding: number[]
|
||||||
@@ -31,7 +48,7 @@ interface SessionStats {
|
|||||||
|
|
||||||
export default function EmbeddingPlayground() {
|
export default function EmbeddingPlayground() {
|
||||||
const [text, setText] = useState('')
|
const [text, setText] = useState('')
|
||||||
const [model, setModel] = useState('text-embedding-ada-002')
|
const [model, setModel] = useState('')
|
||||||
const [encodingFormat, setEncodingFormat] = useState('float')
|
const [encodingFormat, setEncodingFormat] = useState('float')
|
||||||
const [isLoading, setIsLoading] = useState(false)
|
const [isLoading, setIsLoading] = useState(false)
|
||||||
const [results, setResults] = useState<EmbeddingResult[]>([])
|
const [results, setResults] = useState<EmbeddingResult[]>([])
|
||||||
@@ -43,8 +60,51 @@ export default function EmbeddingPlayground() {
|
|||||||
})
|
})
|
||||||
const [selectedResult, setSelectedResult] = useState<EmbeddingResult | null>(null)
|
const [selectedResult, setSelectedResult] = useState<EmbeddingResult | null>(null)
|
||||||
const [comparisonMode, setComparisonMode] = useState(false)
|
const [comparisonMode, setComparisonMode] = useState(false)
|
||||||
|
const [embeddingModels, setEmbeddingModels] = useState<Model[]>([])
|
||||||
|
const [loadingModels, setLoadingModels] = useState(true)
|
||||||
const { toast } = useToast()
|
const { toast } = useToast()
|
||||||
|
|
||||||
|
// Fetch available embedding models
|
||||||
|
useEffect(() => {
|
||||||
|
const fetchModels = async () => {
|
||||||
|
try {
|
||||||
|
setLoadingModels(true)
|
||||||
|
const response = await apiClient.get('/api-internal/v1/llm/models')
|
||||||
|
|
||||||
|
if (response.data) {
|
||||||
|
// Filter models that support embeddings based on tasks field
|
||||||
|
const models = response.data.filter((model: Model) => {
|
||||||
|
// Check if model has embed or embedding in tasks
|
||||||
|
if (model.tasks && Array.isArray(model.tasks)) {
|
||||||
|
return model.tasks.includes('embed') || model.tasks.includes('embedding')
|
||||||
|
}
|
||||||
|
// Fallback: check if model ID contains embedding patterns
|
||||||
|
const modelId = model.id.toLowerCase()
|
||||||
|
return modelId.includes('embed') || modelId.includes('text-embedding')
|
||||||
|
})
|
||||||
|
|
||||||
|
setEmbeddingModels(models)
|
||||||
|
|
||||||
|
// Set default model if available
|
||||||
|
if (models.length > 0 && !model) {
|
||||||
|
setModel(models[0].id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Failed to fetch models:', error)
|
||||||
|
toast({
|
||||||
|
title: "Error",
|
||||||
|
description: "Failed to load embedding models",
|
||||||
|
variant: "destructive"
|
||||||
|
})
|
||||||
|
} finally {
|
||||||
|
setLoadingModels(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fetchModels()
|
||||||
|
}, [])
|
||||||
|
|
||||||
const handleGenerateEmbedding = async () => {
|
const handleGenerateEmbedding = async () => {
|
||||||
if (!text.trim()) {
|
if (!text.trim()) {
|
||||||
toast({
|
toast({
|
||||||
@@ -55,6 +115,15 @@ export default function EmbeddingPlayground() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!model) {
|
||||||
|
toast({
|
||||||
|
title: "Error",
|
||||||
|
description: "Please select an embedding model",
|
||||||
|
variant: "destructive"
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
setIsLoading(true)
|
setIsLoading(true)
|
||||||
try {
|
try {
|
||||||
const data = await apiClient.post('/api-internal/v1/llm/embeddings', {
|
const data = await apiClient.post('/api-internal/v1/llm/embeddings', {
|
||||||
@@ -93,13 +162,34 @@ export default function EmbeddingPlayground() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const calculateCost = (tokens: number, model: string): number => {
|
const calculateCost = (tokens: number, modelId: string): number => {
|
||||||
|
// Known rates for common embedding models
|
||||||
const rates: { [key: string]: number } = {
|
const rates: { [key: string]: number } = {
|
||||||
'text-embedding-ada-002': 0.0001,
|
'text-embedding-ada-002': 0.0001,
|
||||||
'text-embedding-3-small': 0.00002,
|
'text-embedding-3-small': 0.00002,
|
||||||
'text-embedding-3-large': 0.00013
|
'text-embedding-3-large': 0.00013,
|
||||||
|
'privatemode-text-embedding-ada-002': 0.0001,
|
||||||
|
'privatemode-text-embedding-3-small': 0.00002,
|
||||||
|
'privatemode-text-embedding-3-large': 0.00013
|
||||||
}
|
}
|
||||||
return (tokens / 1000) * (rates[model] || 0.0001)
|
|
||||||
|
// Check for exact match first
|
||||||
|
if (rates[modelId]) {
|
||||||
|
return (tokens / 1000) * rates[modelId]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for pattern matches (e.g., if model contains these patterns)
|
||||||
|
const modelLower = modelId.toLowerCase()
|
||||||
|
if (modelLower.includes('ada-002')) {
|
||||||
|
return (tokens / 1000) * 0.0001
|
||||||
|
} else if (modelLower.includes('3-small')) {
|
||||||
|
return (tokens / 1000) * 0.00002
|
||||||
|
} else if (modelLower.includes('3-large')) {
|
||||||
|
return (tokens / 1000) * 0.00013
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default rate for unknown models
|
||||||
|
return (tokens / 1000) * 0.0001
|
||||||
}
|
}
|
||||||
|
|
||||||
const updateSessionStats = (result: EmbeddingResult) => {
|
const updateSessionStats = (result: EmbeddingResult) => {
|
||||||
@@ -157,14 +247,27 @@ export default function EmbeddingPlayground() {
|
|||||||
<div className="grid grid-cols-1 md:grid-cols-2 gap-4">
|
<div className="grid grid-cols-1 md:grid-cols-2 gap-4">
|
||||||
<div>
|
<div>
|
||||||
<label className="text-sm font-medium mb-2 block">Model</label>
|
<label className="text-sm font-medium mb-2 block">Model</label>
|
||||||
<Select value={model} onValueChange={setModel}>
|
<Select value={model} onValueChange={setModel} disabled={loadingModels}>
|
||||||
<SelectTrigger>
|
<SelectTrigger>
|
||||||
<SelectValue />
|
<SelectValue placeholder={loadingModels ? "Loading models..." : "Select a model"} />
|
||||||
</SelectTrigger>
|
</SelectTrigger>
|
||||||
<SelectContent>
|
<SelectContent>
|
||||||
<SelectItem value="text-embedding-ada-002">text-embedding-ada-002</SelectItem>
|
{embeddingModels.length === 0 && !loadingModels ? (
|
||||||
<SelectItem value="text-embedding-3-small">text-embedding-3-small</SelectItem>
|
<SelectItem value="no-models" disabled>
|
||||||
<SelectItem value="text-embedding-3-large">text-embedding-3-large</SelectItem>
|
No embedding models available
|
||||||
|
</SelectItem>
|
||||||
|
) : (
|
||||||
|
embeddingModels.map((embModel) => (
|
||||||
|
<SelectItem key={embModel.id} value={embModel.id}>
|
||||||
|
{embModel.id}
|
||||||
|
{embModel.owned_by && embModel.owned_by !== 'unknown' && (
|
||||||
|
<span className="text-muted-foreground ml-2">
|
||||||
|
({embModel.owned_by})
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
|
</SelectItem>
|
||||||
|
))
|
||||||
|
)}
|
||||||
</SelectContent>
|
</SelectContent>
|
||||||
</Select>
|
</Select>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ interface Model {
|
|||||||
max_output_tokens?: number
|
max_output_tokens?: number
|
||||||
supports_streaming?: boolean
|
supports_streaming?: boolean
|
||||||
supports_function_calling?: boolean
|
supports_function_calling?: boolean
|
||||||
|
tasks?: string[] // Added tasks field from PrivateMode API
|
||||||
}
|
}
|
||||||
|
|
||||||
interface ProviderStatus {
|
interface ProviderStatus {
|
||||||
@@ -97,11 +98,21 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
|
|||||||
return 'Unknown'
|
return 'Unknown'
|
||||||
}
|
}
|
||||||
|
|
||||||
const getModelType = (modelId: string): 'chat' | 'embedding' | 'other' => {
|
const getModelType = (model: Model): 'chat' | 'embedding' | 'other' => {
|
||||||
|
// Check if model has tasks field from PrivateMode or other providers
|
||||||
|
if (model.tasks && Array.isArray(model.tasks)) {
|
||||||
|
// Models with "generate" task are chat models
|
||||||
|
if (model.tasks.includes('generate')) return 'chat'
|
||||||
|
// Models with "embed" task are embedding models
|
||||||
|
if (model.tasks.includes('embed') || model.tasks.includes('embedding')) return 'embedding'
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to ID-based detection for models without tasks field
|
||||||
|
const modelId = model.id
|
||||||
if (modelId.includes('embedding') || modelId.includes('embed')) return 'embedding'
|
if (modelId.includes('embedding') || modelId.includes('embed')) return 'embedding'
|
||||||
if (modelId.includes('whisper') || modelId.includes('speech')) return 'other' // Audio models
|
if (modelId.includes('whisper') || modelId.includes('speech')) return 'other' // Audio models
|
||||||
|
|
||||||
// PrivateMode and other chat models
|
// PrivateMode and other chat models by ID pattern
|
||||||
if (
|
if (
|
||||||
modelId.startsWith('privatemode-llama') ||
|
modelId.startsWith('privatemode-llama') ||
|
||||||
modelId.startsWith('privatemode-claude') ||
|
modelId.startsWith('privatemode-claude') ||
|
||||||
@@ -114,14 +125,16 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
|
|||||||
modelId.includes('llama') ||
|
modelId.includes('llama') ||
|
||||||
modelId.includes('gemma') ||
|
modelId.includes('gemma') ||
|
||||||
modelId.includes('qwen') ||
|
modelId.includes('qwen') ||
|
||||||
|
modelId.includes('mistral') ||
|
||||||
|
modelId.includes('command') ||
|
||||||
modelId.includes('latest')
|
modelId.includes('latest')
|
||||||
) return 'chat'
|
) return 'chat'
|
||||||
|
|
||||||
return 'other'
|
return 'other'
|
||||||
}
|
}
|
||||||
|
|
||||||
const getModelCategory = (modelId: string): string => {
|
const getModelCategory = (model: Model): string => {
|
||||||
const type = getModelType(modelId)
|
const type = getModelType(model)
|
||||||
switch (type) {
|
switch (type) {
|
||||||
case 'chat': return 'Chat Completion'
|
case 'chat': return 'Chat Completion'
|
||||||
case 'embedding': return 'Text Embedding'
|
case 'embedding': return 'Text Embedding'
|
||||||
@@ -132,7 +145,7 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
|
|||||||
|
|
||||||
const filteredModels = models.filter(model => {
|
const filteredModels = models.filter(model => {
|
||||||
if (filter === 'all') return true
|
if (filter === 'all') return true
|
||||||
return getModelType(model.id) === filter
|
return getModelType(model) === filter
|
||||||
})
|
})
|
||||||
|
|
||||||
const groupedModels = filteredModels.reduce((acc, model) => {
|
const groupedModels = filteredModels.reduce((acc, model) => {
|
||||||
@@ -255,7 +268,7 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
|
|||||||
<span>{model.id}</span>
|
<span>{model.id}</span>
|
||||||
<div className="flex gap-1">
|
<div className="flex gap-1">
|
||||||
<Badge variant="outline" className="text-xs">
|
<Badge variant="outline" className="text-xs">
|
||||||
{getModelCategory(model.id)}
|
{getModelCategory(model)}
|
||||||
</Badge>
|
</Badge>
|
||||||
{model.supports_streaming && (
|
{model.supports_streaming && (
|
||||||
<Badge variant="secondary" className="text-xs">
|
<Badge variant="secondary" className="text-xs">
|
||||||
@@ -299,7 +312,7 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
|
|||||||
</div>
|
</div>
|
||||||
<div>
|
<div>
|
||||||
<span className="font-medium">Type:</span>
|
<span className="font-medium">Type:</span>
|
||||||
<div className="text-muted-foreground">{getModelCategory(selectedModel.id)}</div>
|
<div className="text-muted-foreground">{getModelCategory(selectedModel)}</div>
|
||||||
</div>
|
</div>
|
||||||
<div>
|
<div>
|
||||||
<span className="font-medium">Object:</span>
|
<span className="font-medium">Object:</span>
|
||||||
|
|||||||
Reference in New Issue
Block a user