refactor: used recommended_models function as default model (#209)

Lifei Zhou authored on 2024-10-31 13:57:16 +11:00, committed by GitHub
parent aca1a6872d
commit 0ab1966e93
6 changed files with 26 additions and 16 deletions


@@ -157,6 +157,11 @@ class AnthropicProvider(Provider):
         return message, usage
 
+    @staticmethod
+    def recommended_models() -> tuple[str, str]:
+        """Return the recommended model and processor for this provider"""
+        return "claude-3-5-sonnet-20241022", "claude-3-5-sonnet-20241022"
+
     @retry_procedure
     def _post(self, payload: dict) -> httpx.Response:
         response = self.client.post(ANTHROPIC_HOST, json=payload)


@@ -50,6 +50,11 @@ class Provider(ABC):
         """Generate the next message using the specified model"""
         pass
 
+    @staticmethod
+    def recommended_models() -> tuple[str, str]:
+        """Return the recommended model and processor for this provider"""
+        return "gpt-4o", "gpt-4o-mini"
+
 class MissingProviderEnvVariableError(Exception):
     def __init__(self, env_variable: str, provider: str, instructions_url: Optional[str] = None) -> None:
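For context: the default on the base class means any Provider subclass that does not override recommended_models() falls back to the OpenAI pair above. A minimal, self-contained sketch of the pattern this commit introduces (the stand-in base class and ExampleProvider below are illustrative, not code from this repository):

    from abc import ABC


    class Provider(ABC):
        # Stand-in mirroring the base class in this diff.
        @staticmethod
        def recommended_models() -> tuple[str, str]:
            """Return the recommended model and processor for this provider"""
            return "gpt-4o", "gpt-4o-mini"


    class ExampleProvider(Provider):
        # Hypothetical provider overriding the default recommendation.
        @staticmethod
        def recommended_models() -> tuple[str, str]:
            return "example-large-model", "example-small-model"


    # Callers ask the class itself, so no per-provider lookup table is needed.
    assert Provider.recommended_models() == ("gpt-4o", "gpt-4o-mini")
    processor, accelerator = ExampleProvider.recommended_models()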


@@ -90,6 +90,11 @@ class DatabricksProvider(Provider):
         usage = self.get_usage(response)
         return message, usage
 
+    @staticmethod
+    def recommended_models() -> tuple[str, str]:
+        """Return the recommended model and processor for this provider"""
+        return "databricks-meta-llama-3-1-70b-instruct", "databricks-meta-llama-3-1-70b-instruct"
+
     @retry_procedure
     def _post(self, model: str, payload: dict) -> httpx.Response:
         response = self.client.post(


@@ -165,3 +165,8 @@ class GoogleProvider(Provider):
     def _post(self, payload: dict, model: str) -> httpx.Response:
         response = self.client.post("models/" + model + ":generateContent", json=payload)
         return raise_for_status(response).json()
+
+    @staticmethod
+    def recommended_models() -> tuple[str, str]:
+        """Return the recommended model and processor for this provider"""
+        return "gemini-1.5-flash", "gemini-1.5-flash"


@@ -43,3 +43,8 @@ ollama:
         # When served by Ollama, the OpenAI API is available at the path "v1/".
         client = httpx.Client(base_url=ollama_url + "v1/", timeout=timeout)
         return cls(client)
+
+    @staticmethod
+    def recommended_models() -> tuple[str, str]:
+        """Return the recommended model and processor for this provider"""
+        return OLLAMA_MODEL, OLLAMA_MODEL


@@ -6,7 +6,6 @@ from rich import print
 from rich.panel import Panel
 from ruamel.yaml import YAML
-from exchange.providers.ollama import OLLAMA_MODEL
 from goose.profile import Profile
 from goose.utils import load_plugins
@@ -90,19 +89,5 @@ def default_model_configuration() -> tuple[str, str, str]:
         pass
     else:
         provider = RECOMMENDED_DEFAULT_PROVIDER
-    recommended = {
-        "ollama": (OLLAMA_MODEL, OLLAMA_MODEL),
-        "openai": ("gpt-4o", "gpt-4o-mini"),
-        "anthropic": (
-            "claude-3-5-sonnet-20241022",
-            "claude-3-5-sonnet-20241022",
-        ),
-        "databricks": (
-            # TODO when function calling is first rec should be: "databricks-meta-llama-3-1-405b-instruct"
-            "databricks-meta-llama-3-1-70b-instruct",
-            "databricks-meta-llama-3-1-70b-instruct",
-        ),
-        "google": ("gemini-1.5-flash", "gemini-1.5-flash"),
-    }
-    processor, accelerator = recommended.get(provider, ("gpt-4o", "gpt-4o-mini"))
+    processor, accelerator = providers.get(provider).recommended_models()
     return provider, processor, accelerator
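
Net effect of the last hunk: the CLI no longer maintains its own provider-to-model table and instead dispatches to the provider class. One behavioral change implied by the diff: the old recommended.get(provider, ("gpt-4o", "gpt-4o-mini")) had a fallback, whereas providers.get(provider) returns None for an unknown name and the chained call would then raise AttributeError. A hedged sketch of a defensive variant (resolve_recommended is a hypothetical helper, and providers is assumed to map provider names to Provider classes, as the new line implies):

    def resolve_recommended(providers: dict, provider: str) -> tuple[str, str]:
        # Guard the registry lookup so an unknown provider name fails clearly
        # instead of raising AttributeError on None.
        provider_class = providers.get(provider)
        if provider_class is None:
            raise ValueError(f"unknown provider: {provider!r}")
        # Each Provider class reports its own (processor, accelerator) pair.
        return provider_class.recommended_models()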