diff --git a/packages/exchange/src/exchange/providers/anthropic.py b/packages/exchange/src/exchange/providers/anthropic.py
index c98c6d43..9f4b72d7 100644
--- a/packages/exchange/src/exchange/providers/anthropic.py
+++ b/packages/exchange/src/exchange/providers/anthropic.py
@@ -157,6 +157,11 @@ class AnthropicProvider(Provider):
 
         return message, usage
 
+    @staticmethod
+    def recommended_models() -> tuple[str, str]:
+        """Return the recommended model and processor for this provider"""
+        return "claude-3-5-sonnet-20241022", "claude-3-5-sonnet-20241022"
+
     @retry_procedure
     def _post(self, payload: dict) -> httpx.Response:
         response = self.client.post(ANTHROPIC_HOST, json=payload)
diff --git a/packages/exchange/src/exchange/providers/base.py b/packages/exchange/src/exchange/providers/base.py
index 3d48135d..3ff17591 100644
--- a/packages/exchange/src/exchange/providers/base.py
+++ b/packages/exchange/src/exchange/providers/base.py
@@ -50,6 +50,11 @@ class Provider(ABC):
         """Generate the next message using the specified model"""
         pass
 
+    @staticmethod
+    def recommended_models() -> tuple[str, str]:
+        """Return the recommended model and processor for this provider"""
+        return "gpt-4o", "gpt-4o-mini"
+
 
 class MissingProviderEnvVariableError(Exception):
     def __init__(self, env_variable: str, provider: str, instructions_url: Optional[str] = None) -> None:
diff --git a/packages/exchange/src/exchange/providers/databricks.py b/packages/exchange/src/exchange/providers/databricks.py
index b8f92dca..517ccee6 100644
--- a/packages/exchange/src/exchange/providers/databricks.py
+++ b/packages/exchange/src/exchange/providers/databricks.py
@@ -90,6 +90,12 @@ class DatabricksProvider(Provider):
         usage = self.get_usage(response)
         return message, usage
 
+    @staticmethod
+    def recommended_models() -> tuple[str, str]:
+        """Return the recommended model and processor for this provider"""
+        # TODO when function calling is ready the first rec should be: "databricks-meta-llama-3-1-405b-instruct"
+        return "databricks-meta-llama-3-1-70b-instruct", "databricks-meta-llama-3-1-70b-instruct"
+
     @retry_procedure
     def _post(self, model: str, payload: dict) -> httpx.Response:
         response = self.client.post(
diff --git a/packages/exchange/src/exchange/providers/google.py b/packages/exchange/src/exchange/providers/google.py
index 76ccd7a9..bfb1faf0 100644
--- a/packages/exchange/src/exchange/providers/google.py
+++ b/packages/exchange/src/exchange/providers/google.py
@@ -165,3 +165,8 @@ class GoogleProvider(Provider):
     def _post(self, payload: dict, model: str) -> httpx.Response:
         response = self.client.post("models/" + model + ":generateContent", json=payload)
         return raise_for_status(response).json()
+
+    @staticmethod
+    def recommended_models() -> tuple[str, str]:
+        """Return the recommended model and processor for this provider"""
+        return "gemini-1.5-flash", "gemini-1.5-flash"
diff --git a/packages/exchange/src/exchange/providers/ollama.py b/packages/exchange/src/exchange/providers/ollama.py
index 1f8e8fe5..23079315 100644
--- a/packages/exchange/src/exchange/providers/ollama.py
+++ b/packages/exchange/src/exchange/providers/ollama.py
@@ -43,3 +43,8 @@ ollama:
         # When served by Ollama, the OpenAI API is available at the path "v1/".
         client = httpx.Client(base_url=ollama_url + "v1/", timeout=timeout)
         return cls(client)
+
+    @staticmethod
+    def recommended_models() -> tuple[str, str]:
+        """Return the recommended model and processor for this provider"""
+        return OLLAMA_MODEL, OLLAMA_MODEL
diff --git a/src/goose/cli/config.py b/src/goose/cli/config.py
index f0cd55e5..8b9e4fe9 100644
--- a/src/goose/cli/config.py
+++ b/src/goose/cli/config.py
@@ -6,7 +6,6 @@
 from rich import print
 from rich.panel import Panel
 from ruamel.yaml import YAML
 
-from exchange.providers.ollama import OLLAMA_MODEL
 from goose.profile import Profile
 from goose.utils import load_plugins
@@ -90,19 +89,5 @@ def default_model_configuration() -> tuple[str, str, str]:
         pass
     else:
         provider = RECOMMENDED_DEFAULT_PROVIDER
-        recommended = {
-            "ollama": (OLLAMA_MODEL, OLLAMA_MODEL),
-            "openai": ("gpt-4o", "gpt-4o-mini"),
-            "anthropic": (
-                "claude-3-5-sonnet-20241022",
-                "claude-3-5-sonnet-20241022",
-            ),
-            "databricks": (
-                # TODO when function calling is first rec should be: "databricks-meta-llama-3-1-405b-instruct"
-                "databricks-meta-llama-3-1-70b-instruct",
-                "databricks-meta-llama-3-1-70b-instruct",
-            ),
-            "google": ("gemini-1.5-flash", "gemini-1.5-flash"),
-        }
-        processor, accelerator = recommended.get(provider, ("gpt-4o", "gpt-4o-mini"))
+        processor, accelerator = providers[provider].recommended_models()
     return provider, processor, accelerator