mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-28 03:24:21 +01:00
refactor: used recommended_models function as default model (#209)
This commit is contained in:
@@ -157,6 +157,11 @@ class AnthropicProvider(Provider):
|
||||
|
||||
return message, usage
|
||||
|
||||
@staticmethod
|
||||
def recommended_models() -> tuple[str, str]:
|
||||
"""Return the recommended model and processor for this provider"""
|
||||
return "claude-3-5-sonnet-20241022", "claude-3-5-sonnet-20241022"
|
||||
|
||||
@retry_procedure
|
||||
def _post(self, payload: dict) -> httpx.Response:
|
||||
response = self.client.post(ANTHROPIC_HOST, json=payload)
|
||||
|
||||
@@ -50,6 +50,11 @@ class Provider(ABC):
|
||||
"""Generate the next message using the specified model"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def recommended_models() -> tuple[str, str]:
|
||||
"""Return the recommended model and processor for this provider"""
|
||||
return "gpt-4o", "gpt-4o-mini"
|
||||
|
||||
|
||||
class MissingProviderEnvVariableError(Exception):
|
||||
def __init__(self, env_variable: str, provider: str, instructions_url: Optional[str] = None) -> None:
|
||||
|
||||
@@ -90,6 +90,11 @@ class DatabricksProvider(Provider):
|
||||
usage = self.get_usage(response)
|
||||
return message, usage
|
||||
|
||||
@staticmethod
|
||||
def recommended_models() -> tuple[str, str]:
|
||||
"""Return the recommended model and processor for this provider"""
|
||||
return "databricks-meta-llama-3-1-70b-instruct", "databricks-meta-llama-3-1-70b-instruct"
|
||||
|
||||
@retry_procedure
|
||||
def _post(self, model: str, payload: dict) -> httpx.Response:
|
||||
response = self.client.post(
|
||||
|
||||
@@ -165,3 +165,8 @@ class GoogleProvider(Provider):
|
||||
def _post(self, payload: dict, model: str) -> httpx.Response:
|
||||
response = self.client.post("models/" + model + ":generateContent", json=payload)
|
||||
return raise_for_status(response).json()
|
||||
|
||||
@staticmethod
|
||||
def recommended_models() -> tuple[str, str]:
|
||||
"""Return the recommended model and processor for this provider"""
|
||||
return "gemini-1.5-flash", "gemini-1.5-flash"
|
||||
|
||||
@@ -43,3 +43,8 @@ ollama:
|
||||
# When served by Ollama, the OpenAI API is available at the path "v1/".
|
||||
client = httpx.Client(base_url=ollama_url + "v1/", timeout=timeout)
|
||||
return cls(client)
|
||||
|
||||
@staticmethod
def recommended_models() -> tuple[str, str]:
    """Return the recommended (model, processor) pair for Ollama.

    The module-level OLLAMA_MODEL constant fills both roles.
    """
    model = OLLAMA_MODEL
    return model, model
|
||||
|
||||
@@ -6,7 +6,6 @@ from rich import print
|
||||
from rich.panel import Panel
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
from exchange.providers.ollama import OLLAMA_MODEL
|
||||
|
||||
from goose.profile import Profile
|
||||
from goose.utils import load_plugins
|
||||
@@ -90,19 +89,5 @@ def default_model_configuration() -> tuple[str, str, str]:
|
||||
pass
|
||||
else:
|
||||
provider = RECOMMENDED_DEFAULT_PROVIDER
|
||||
recommended = {
|
||||
"ollama": (OLLAMA_MODEL, OLLAMA_MODEL),
|
||||
"openai": ("gpt-4o", "gpt-4o-mini"),
|
||||
"anthropic": (
|
||||
"claude-3-5-sonnet-20241022",
|
||||
"claude-3-5-sonnet-20241022",
|
||||
),
|
||||
"databricks": (
|
||||
# TODO when function calling is first rec should be: "databricks-meta-llama-3-1-405b-instruct"
|
||||
"databricks-meta-llama-3-1-70b-instruct",
|
||||
"databricks-meta-llama-3-1-70b-instruct",
|
||||
),
|
||||
"google": ("gemini-1.5-flash", "gemini-1.5-flash"),
|
||||
}
|
||||
processor, accelerator = recommended.get(provider, ("gpt-4o", "gpt-4o-mini"))
|
||||
processor, accelerator = providers.get(provider).recommended_models()
|
||||
return provider, processor, accelerator
|
||||
|
||||
Reference in New Issue
Block a user