mirror of
https://github.com/aljazceru/enclava.git
synced 2025-12-18 16:04:28 +01:00
embedding models picker improvement
This commit is contained in:
@@ -134,18 +134,39 @@ class PrivateModeProvider(BaseLLMProvider):
|
||||
if not model_id:
|
||||
continue
|
||||
|
||||
# Extract model capabilities from PrivateMode response
|
||||
# Extract all information directly from API response
|
||||
# Determine capabilities based on tasks field
|
||||
tasks = model_data.get("tasks", [])
|
||||
capabilities = []
|
||||
|
||||
# All PrivateMode models have TEE capability
|
||||
capabilities.append("tee")
|
||||
|
||||
# Add capabilities based on tasks
|
||||
if "generate" in tasks:
|
||||
capabilities.append("chat")
|
||||
if "embed" in tasks or "embedding" in tasks:
|
||||
capabilities.append("embeddings")
|
||||
if "vision" in tasks:
|
||||
capabilities.append("vision")
|
||||
|
||||
# Check for function calling support in the API response
|
||||
supports_function_calling = model_data.get("supports_function_calling", False)
|
||||
if supports_function_calling:
|
||||
capabilities.append("function_calling")
|
||||
|
||||
model_info = ModelInfo(
|
||||
id=model_id,
|
||||
object="model",
|
||||
created=model_data.get("created", int(time.time())),
|
||||
owned_by="privatemode",
|
||||
owned_by=model_data.get("owned_by", "privatemode"),
|
||||
provider=self.provider_name,
|
||||
capabilities=self._get_model_capabilities(model_id),
|
||||
context_window=self._get_model_context_window(model_id),
|
||||
max_output_tokens=self._get_model_max_output(model_id),
|
||||
supports_streaming=True, # PrivateMode supports streaming
|
||||
supports_function_calling=self._supports_function_calling(model_id)
|
||||
capabilities=capabilities,
|
||||
context_window=model_data.get("context_window"),
|
||||
max_output_tokens=model_data.get("max_output_tokens"),
|
||||
supports_streaming=model_data.get("supports_streaming", True),
|
||||
supports_function_calling=supports_function_calling,
|
||||
tasks=tasks # Pass through tasks field from PrivateMode API
|
||||
)
|
||||
models.append(model_info)
|
||||
|
||||
@@ -453,68 +474,6 @@ class PrivateModeProvider(BaseLLMProvider):
|
||||
details={"error": str(e)}
|
||||
)
|
||||
|
||||
def _get_model_capabilities(self, model_id: str) -> List[str]:
|
||||
"""Get capabilities for a specific model"""
|
||||
capabilities = ["chat"]
|
||||
|
||||
# PrivateMode supports embeddings for most models
|
||||
if "embed" in model_id.lower() or model_id in [
|
||||
"privatemode-llama-3.1-405b", "privatemode-llama-3.1-70b",
|
||||
"privatemode-claude-3.5-sonnet", "privatemode-gpt-4o"
|
||||
]:
|
||||
capabilities.append("embeddings")
|
||||
|
||||
# TEE protection is available for all PrivateMode models
|
||||
capabilities.append("tee")
|
||||
|
||||
return capabilities
|
||||
|
||||
def _get_model_context_window(self, model_id: str) -> Optional[int]:
|
||||
"""Get context window size for a specific model"""
|
||||
context_windows = {
|
||||
"privatemode-llama-3.1-405b": 128000,
|
||||
"privatemode-llama-3.1-70b": 128000,
|
||||
"privatemode-llama-3.1-8b": 128000,
|
||||
"privatemode-llama-3-70b": 8192,
|
||||
"privatemode-llama-3-8b": 8192,
|
||||
"privatemode-claude-3.5-sonnet": 200000,
|
||||
"privatemode-claude-3-haiku": 200000,
|
||||
"privatemode-gpt-4o": 128000,
|
||||
"privatemode-gpt-4o-mini": 128000,
|
||||
"privatemode-gemini-1.5-pro": 2000000,
|
||||
"privatemode-gemini-1.5-flash": 1000000
|
||||
}
|
||||
|
||||
return context_windows.get(model_id, 8192) # Default to 8K
|
||||
|
||||
def _get_model_max_output(self, model_id: str) -> Optional[int]:
|
||||
"""Get max output tokens for a specific model"""
|
||||
max_outputs = {
|
||||
"privatemode-llama-3.1-405b": 8192,
|
||||
"privatemode-llama-3.1-70b": 8192,
|
||||
"privatemode-llama-3.1-8b": 8192,
|
||||
"privatemode-llama-3-70b": 4096,
|
||||
"privatemode-llama-3-8b": 4096,
|
||||
"privatemode-claude-3.5-sonnet": 8192,
|
||||
"privatemode-claude-3-haiku": 4096,
|
||||
"privatemode-gpt-4o": 16384,
|
||||
"privatemode-gpt-4o-mini": 16384,
|
||||
"privatemode-gemini-1.5-pro": 8192,
|
||||
"privatemode-gemini-1.5-flash": 8192
|
||||
}
|
||||
|
||||
return max_outputs.get(model_id, 4096) # Default to 4K
|
||||
|
||||
def _supports_function_calling(self, model_id: str) -> bool:
|
||||
"""Check if model supports function calling"""
|
||||
function_calling_models = [
|
||||
"privatemode-gpt-4o", "privatemode-gpt-4o-mini",
|
||||
"privatemode-claude-3.5-sonnet", "privatemode-claude-3-haiku",
|
||||
"privatemode-gemini-1.5-pro", "privatemode-gemini-1.5-flash"
|
||||
]
|
||||
|
||||
return model_id in function_calling_models
|
||||
|
||||
async def cleanup(self):
|
||||
"""Cleanup PrivateMode provider resources"""
|
||||
await super().cleanup()
|
||||
|
||||
Reference in New Issue
Block a user