mirror of
https://github.com/aljazceru/goose.git
synced 2026-02-23 15:34:27 +01:00
chore: Minor changes in providers envs logic (#161)
Co-authored-by: Lam Chau <lam@tbd.email>
This commit is contained in:
@@ -14,19 +14,29 @@ class Usage:
|
||||
total_tokens: int = field(default=None)
|
||||
|
||||
|
||||
class EmptyProviderNameError(Exception):
|
||||
def __init__(self, provider_cls: str) -> None:
|
||||
self.message = f"The provider class '{provider_cls}' has an empty PROVIDER_NAME."
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class Provider(ABC):
|
||||
PROVIDER_NAME: str
|
||||
REQUIRED_ENV_VARS: list[str] = []
|
||||
|
||||
@classmethod
|
||||
def from_env(cls: type["Provider"]) -> "Provider":
|
||||
if not cls.PROVIDER_NAME:
|
||||
raise EmptyProviderNameError(cls.__name__)
|
||||
return cls()
|
||||
|
||||
@classmethod
|
||||
def check_env_vars(cls: type["Provider"], instructions_url: Optional[str] = None) -> None:
|
||||
for env_var in cls.REQUIRED_ENV_VARS:
|
||||
if env_var not in os.environ:
|
||||
raise MissingProviderEnvVariableError(env_var, cls.PROVIDER_NAME, instructions_url)
|
||||
missing_vars = [x for x in cls.REQUIRED_ENV_VARS if x not in os.environ]
|
||||
|
||||
if missing_vars:
|
||||
env_vars = ", ".join(missing_vars)
|
||||
raise MissingProviderEnvVariableError(env_vars, cls.PROVIDER_NAME, instructions_url)
|
||||
|
||||
@abstractmethod
|
||||
def complete(
|
||||
@@ -46,7 +56,7 @@ class MissingProviderEnvVariableError(Exception):
|
||||
self.env_variable = env_variable
|
||||
self.provider = provider
|
||||
self.instructions_url = instructions_url
|
||||
self.message = f"Missing environment variable: {env_variable} for provider {provider}."
|
||||
self.message = f"Missing environment variables: {env_variable} for provider {provider}."
|
||||
if instructions_url:
|
||||
self.message += f"\nPlease see {instructions_url} for instructions"
|
||||
super().__init__(self.message)
|
||||
|
||||
@@ -24,6 +24,7 @@ ollama:
|
||||
- name: developer
|
||||
requires: {{}}
|
||||
"""
|
||||
PROVIDER_NAME = "ollama"
|
||||
|
||||
def __init__(self, client: httpx.Client) -> None:
|
||||
print("PLEASE NOTE: the ollama provider is experimental, use with care")
|
||||
|
||||
@@ -32,7 +32,7 @@ def test_from_env_throw_error_when_missing_api_key():
|
||||
AnthropicProvider.from_env()
|
||||
assert context.value.provider == "anthropic"
|
||||
assert context.value.env_variable == "ANTHROPIC_API_KEY"
|
||||
assert context.value.message == "Missing environment variable: ANTHROPIC_API_KEY for provider anthropic."
|
||||
assert context.value.message == "Missing environment variables: ANTHROPIC_API_KEY for provider anthropic."
|
||||
|
||||
|
||||
def test_anthropic_response_to_text_message() -> None:
|
||||
|
||||
@@ -2,10 +2,10 @@ import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from exchange import Text, ToolUse
|
||||
from exchange.providers.azure import AzureProvider
|
||||
from exchange.providers.base import MissingProviderEnvVariableError
|
||||
|
||||
from .conftest import complete, tools
|
||||
|
||||
AZURE_MODEL = os.getenv("AZURE_MODEL", "gpt-4o-mini")
|
||||
@@ -36,7 +36,7 @@ def test_from_env_throw_error_when_missing_env_var(env_var_name):
|
||||
AzureProvider.from_env()
|
||||
assert context.value.provider == "azure"
|
||||
assert context.value.env_variable == env_var_name
|
||||
assert context.value.message == f"Missing environment variable: {env_var_name} for provider azure."
|
||||
assert context.value.message == f"Missing environment variables: {env_var_name} for provider azure."
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
|
||||
@@ -35,7 +35,7 @@ def test_from_env_throw_error_when_missing_env_var(env_var_name):
|
||||
BedrockProvider.from_env()
|
||||
assert context.value.provider == "bedrock"
|
||||
assert context.value.env_variable == env_var_name
|
||||
assert context.value.message == f"Missing environment variable: {env_var_name} for provider bedrock."
|
||||
assert context.value.message == f"Missing environment variables: {env_var_name} for provider bedrock."
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -28,7 +28,7 @@ def test_from_env_throw_error_when_missing_env_var(env_var_name):
|
||||
DatabricksProvider.from_env()
|
||||
assert context.value.provider == "databricks"
|
||||
assert context.value.env_variable == env_var_name
|
||||
assert f"Missing environment variable: {env_var_name} for provider databricks" in context.value.message
|
||||
assert f"Missing environment variables: {env_var_name} for provider databricks" in context.value.message
|
||||
assert "https://docs.databricks.com" in context.value.message
|
||||
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ def test_from_env_throw_error_when_missing_api_key():
|
||||
GoogleProvider.from_env()
|
||||
assert context.value.provider == "google"
|
||||
assert context.value.env_variable == "GOOGLE_API_KEY"
|
||||
assert "Missing environment variable: GOOGLE_API_KEY for provider google" in context.value.message
|
||||
assert "Missing environment variables: GOOGLE_API_KEY for provider google" in context.value.message
|
||||
assert "https://ai.google.dev/gemini-api/docs/api-key" in context.value.message
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from exchange import Text, ToolUse
|
||||
from exchange.providers.base import MissingProviderEnvVariableError
|
||||
from exchange.providers.openai import OpenAiProvider
|
||||
@@ -17,7 +16,7 @@ def test_from_env_throw_error_when_missing_api_key():
|
||||
OpenAiProvider.from_env()
|
||||
assert context.value.provider == "openai"
|
||||
assert context.value.env_variable == "OPENAI_API_KEY"
|
||||
assert "Missing environment variable: OPENAI_API_KEY for provider openai" in context.value.message
|
||||
assert "Missing environment variables: OPENAI_API_KEY for provider openai" in context.value.message
|
||||
assert "https://platform.openai.com" in context.value.message
|
||||
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ def test_missing_provider_env_variable_error_without_instructions_url():
|
||||
assert error.env_variable == env_variable
|
||||
assert error.provider == provider
|
||||
assert error.instructions_url is None
|
||||
assert error.message == "Missing environment variable: API_KEY for provider TestProvider."
|
||||
assert error.message == "Missing environment variables: API_KEY for provider TestProvider."
|
||||
|
||||
|
||||
def test_missing_provider_env_variable_error_with_instructions_url():
|
||||
@@ -22,6 +22,6 @@ def test_missing_provider_env_variable_error_with_instructions_url():
|
||||
assert error.provider == provider
|
||||
assert error.instructions_url == instructions_url
|
||||
assert error.message == (
|
||||
"Missing environment variable: API_KEY for provider TestProvider.\n"
|
||||
"Missing environment variables: API_KEY for provider TestProvider.\n"
|
||||
"Please see http://example.com/instructions for instructions"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user