Improve embedding models picker: derive PrivateMode model capabilities and limits from the API response instead of hardcoded tables

2025-09-13 15:37:42 +02:00
parent e97d197d8b
commit 9c64177925
7 changed files with 177 additions and 91 deletions


@@ -134,18 +134,39 @@ class PrivateModeProvider(BaseLLMProvider):
                 if not model_id:
                     continue
-                # Extract model capabilities from PrivateMode response
+                # Extract all information directly from API response
+                # Determine capabilities based on tasks field
+                tasks = model_data.get("tasks", [])
+                capabilities = []
+
+                # All PrivateMode models have TEE capability
+                capabilities.append("tee")
+
+                # Add capabilities based on tasks
+                if "generate" in tasks:
+                    capabilities.append("chat")
+                if "embed" in tasks or "embedding" in tasks:
+                    capabilities.append("embeddings")
+                if "vision" in tasks:
+                    capabilities.append("vision")
+
+                # Check for function calling support in the API response
+                supports_function_calling = model_data.get("supports_function_calling", False)
+                if supports_function_calling:
+                    capabilities.append("function_calling")
+
                 model_info = ModelInfo(
                     id=model_id,
                     object="model",
                     created=model_data.get("created", int(time.time())),
-                    owned_by="privatemode",
+                    owned_by=model_data.get("owned_by", "privatemode"),
                     provider=self.provider_name,
-                    capabilities=self._get_model_capabilities(model_id),
-                    context_window=self._get_model_context_window(model_id),
-                    max_output_tokens=self._get_model_max_output(model_id),
-                    supports_streaming=True,  # PrivateMode supports streaming
-                    supports_function_calling=self._supports_function_calling(model_id)
+                    capabilities=capabilities,
+                    context_window=model_data.get("context_window"),
+                    max_output_tokens=model_data.get("max_output_tokens"),
+                    supports_streaming=model_data.get("supports_streaming", True),
+                    supports_function_calling=supports_function_calling,
+                    tasks=tasks  # Pass through tasks field from PrivateMode API
                 )
                 models.append(model_info)
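The hunk above replaces the hardcoded lookups (deleted in the next hunk) with fields read straight from the /models response. A minimal standalone sketch of that mapping, using a hypothetical response item; the field names `tasks` and `supports_function_calling` are the ones the diff reads, but the sample values are made up:

from typing import Any, Dict, List

def derive_capabilities(model_data: Dict[str, Any]) -> List[str]:
    """Sketch of the new mapping: capabilities come from the API's
    `tasks` field rather than a per-model lookup table."""
    tasks = model_data.get("tasks", [])
    capabilities = ["tee"]  # every PrivateMode model runs in a TEE
    if "generate" in tasks:
        capabilities.append("chat")
    if "embed" in tasks or "embedding" in tasks:
        capabilities.append("embeddings")
    if "vision" in tasks:
        capabilities.append("vision")
    if model_data.get("supports_function_calling", False):
        capabilities.append("function_calling")
    return capabilities

# Hypothetical response item -- for illustration only
sample = {"id": "example-embed-model", "tasks": ["embed"]}
assert derive_capabilities(sample) == ["tee", "embeddings"]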
@@ -453,68 +474,6 @@ class PrivateModeProvider(BaseLLMProvider):
                 details={"error": str(e)}
             )
-
-    def _get_model_capabilities(self, model_id: str) -> List[str]:
-        """Get capabilities for a specific model"""
-        capabilities = ["chat"]
-
-        # PrivateMode supports embeddings for most models
-        if "embed" in model_id.lower() or model_id in [
-            "privatemode-llama-3.1-405b", "privatemode-llama-3.1-70b",
-            "privatemode-claude-3.5-sonnet", "privatemode-gpt-4o"
-        ]:
-            capabilities.append("embeddings")
-
-        # TEE protection is available for all PrivateMode models
-        capabilities.append("tee")
-
-        return capabilities
-
-    def _get_model_context_window(self, model_id: str) -> Optional[int]:
-        """Get context window size for a specific model"""
-        context_windows = {
-            "privatemode-llama-3.1-405b": 128000,
-            "privatemode-llama-3.1-70b": 128000,
-            "privatemode-llama-3.1-8b": 128000,
-            "privatemode-llama-3-70b": 8192,
-            "privatemode-llama-3-8b": 8192,
-            "privatemode-claude-3.5-sonnet": 200000,
-            "privatemode-claude-3-haiku": 200000,
-            "privatemode-gpt-4o": 128000,
-            "privatemode-gpt-4o-mini": 128000,
-            "privatemode-gemini-1.5-pro": 2000000,
-            "privatemode-gemini-1.5-flash": 1000000
-        }
-
-        return context_windows.get(model_id, 8192)  # Default to 8K
-
-    def _get_model_max_output(self, model_id: str) -> Optional[int]:
-        """Get max output tokens for a specific model"""
-        max_outputs = {
-            "privatemode-llama-3.1-405b": 8192,
-            "privatemode-llama-3.1-70b": 8192,
-            "privatemode-llama-3.1-8b": 8192,
-            "privatemode-llama-3-70b": 4096,
-            "privatemode-llama-3-8b": 4096,
-            "privatemode-claude-3.5-sonnet": 8192,
-            "privatemode-claude-3-haiku": 4096,
-            "privatemode-gpt-4o": 16384,
-            "privatemode-gpt-4o-mini": 16384,
-            "privatemode-gemini-1.5-pro": 8192,
-            "privatemode-gemini-1.5-flash": 8192
-        }
-
-        return max_outputs.get(model_id, 4096)  # Default to 4K
-
-    def _supports_function_calling(self, model_id: str) -> bool:
-        """Check if model supports function calling"""
-        function_calling_models = [
-            "privatemode-gpt-4o", "privatemode-gpt-4o-mini",
-            "privatemode-claude-3.5-sonnet", "privatemode-claude-3-haiku",
-            "privatemode-gemini-1.5-pro", "privatemode-gemini-1.5-flash"
-        ]
-
-        return model_id in function_calling_models

     async def cleanup(self):
         """Cleanup PrivateMode provider resources"""
         await super().cleanup()
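One behavioral consequence worth noting: the deleted helpers always returned integer fallbacks (8K context, 4K output) for unknown models, while the new code passes `model_data.get(...)` straight through, so `context_window` and `max_output_tokens` can now be None when the API omits them. A hedged sketch of how a caller might guard against that; the fallback value here is an assumption that mirrors the old 8K default:

from typing import Optional

def effective_context_window(context_window: Optional[int],
                             fallback: int = 8192) -> int:
    """Callers that relied on the provider's old hardcoded default
    may need their own fallback now that the field can be None."""
    return context_window if context_window is not None else fallback

print(effective_context_window(None))    # -> 8192 (field missing from API)
print(effective_context_window(200000))  # -> 200000 (field present)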