From c75917cf371158a6b58a848d3baba2b4f15aa8ab Mon Sep 17 00:00:00 2001 From: Elena Zherdeva <107525751+elenazherdeva@users.noreply.github.com> Date: Mon, 28 Oct 2024 12:20:17 -0700 Subject: [PATCH] chore: Minor changes in providers envs logic (#161) Co-authored-by: Lam Chau --- .../exchange/src/exchange/providers/base.py | 18 ++++++++++++++---- .../exchange/src/exchange/providers/ollama.py | 1 + .../exchange/tests/providers/test_anthropic.py | 2 +- .../exchange/tests/providers/test_azure.py | 4 ++-- .../exchange/tests/providers/test_bedrock.py | 2 +- .../tests/providers/test_databricks.py | 2 +- .../exchange/tests/providers/test_google.py | 2 +- .../exchange/tests/providers/test_openai.py | 3 +-- packages/exchange/tests/test_base.py | 4 ++-- 9 files changed, 24 insertions(+), 14 deletions(-) diff --git a/packages/exchange/src/exchange/providers/base.py b/packages/exchange/src/exchange/providers/base.py index 76b1c339..3d48135d 100644 --- a/packages/exchange/src/exchange/providers/base.py +++ b/packages/exchange/src/exchange/providers/base.py @@ -14,19 +14,29 @@ class Usage: total_tokens: int = field(default=None) +class EmptyProviderNameError(Exception): + def __init__(self, provider_cls: str) -> None: + self.message = f"The provider class '{provider_cls}' has an empty PROVIDER_NAME." + super().__init__(self.message) + + class Provider(ABC): PROVIDER_NAME: str REQUIRED_ENV_VARS: list[str] = [] @classmethod def from_env(cls: type["Provider"]) -> "Provider": + if not cls.PROVIDER_NAME: + raise EmptyProviderNameError(cls.__name__) return cls() @classmethod def check_env_vars(cls: type["Provider"], instructions_url: Optional[str] = None) -> None: - for env_var in cls.REQUIRED_ENV_VARS: - if env_var not in os.environ: - raise MissingProviderEnvVariableError(env_var, cls.PROVIDER_NAME, instructions_url) + missing_vars = [x for x in cls.REQUIRED_ENV_VARS if x not in os.environ] + + if missing_vars: + env_vars = ", ".join(missing_vars) + raise MissingProviderEnvVariableError(env_vars, cls.PROVIDER_NAME, instructions_url) @abstractmethod def complete( @@ -46,7 +56,7 @@ class MissingProviderEnvVariableError(Exception): self.env_variable = env_variable self.provider = provider self.instructions_url = instructions_url - self.message = f"Missing environment variable: {env_variable} for provider {provider}." + self.message = f"Missing environment variables: {env_variable} for provider {provider}." if instructions_url: self.message += f"\nPlease see {instructions_url} for instructions" super().__init__(self.message) diff --git a/packages/exchange/src/exchange/providers/ollama.py b/packages/exchange/src/exchange/providers/ollama.py index 51fef510..1f8e8fe5 100644 --- a/packages/exchange/src/exchange/providers/ollama.py +++ b/packages/exchange/src/exchange/providers/ollama.py @@ -24,6 +24,7 @@ ollama: - name: developer requires: {{}} """ + PROVIDER_NAME = "ollama" def __init__(self, client: httpx.Client) -> None: print("PLEASE NOTE: the ollama provider is experimental, use with care") diff --git a/packages/exchange/tests/providers/test_anthropic.py b/packages/exchange/tests/providers/test_anthropic.py index daf92527..c269d706 100644 --- a/packages/exchange/tests/providers/test_anthropic.py +++ b/packages/exchange/tests/providers/test_anthropic.py @@ -32,7 +32,7 @@ def test_from_env_throw_error_when_missing_api_key(): AnthropicProvider.from_env() assert context.value.provider == "anthropic" assert context.value.env_variable == "ANTHROPIC_API_KEY" - assert context.value.message == "Missing environment variable: ANTHROPIC_API_KEY for provider anthropic." + assert context.value.message == "Missing environment variables: ANTHROPIC_API_KEY for provider anthropic." def test_anthropic_response_to_text_message() -> None: diff --git a/packages/exchange/tests/providers/test_azure.py b/packages/exchange/tests/providers/test_azure.py index 44b75d38..5a701473 100644 --- a/packages/exchange/tests/providers/test_azure.py +++ b/packages/exchange/tests/providers/test_azure.py @@ -2,10 +2,10 @@ import os from unittest.mock import patch import pytest - from exchange import Text, ToolUse from exchange.providers.azure import AzureProvider from exchange.providers.base import MissingProviderEnvVariableError + from .conftest import complete, tools AZURE_MODEL = os.getenv("AZURE_MODEL", "gpt-4o-mini") @@ -36,7 +36,7 @@ def test_from_env_throw_error_when_missing_env_var(env_var_name): AzureProvider.from_env() assert context.value.provider == "azure" assert context.value.env_variable == env_var_name - assert context.value.message == f"Missing environment variable: {env_var_name} for provider azure." + assert context.value.message == f"Missing environment variables: {env_var_name} for provider azure." @pytest.mark.vcr() diff --git a/packages/exchange/tests/providers/test_bedrock.py b/packages/exchange/tests/providers/test_bedrock.py index f7b68c03..2b3ea36e 100644 --- a/packages/exchange/tests/providers/test_bedrock.py +++ b/packages/exchange/tests/providers/test_bedrock.py @@ -35,7 +35,7 @@ def test_from_env_throw_error_when_missing_env_var(env_var_name): BedrockProvider.from_env() assert context.value.provider == "bedrock" assert context.value.env_variable == env_var_name - assert context.value.message == f"Missing environment variable: {env_var_name} for provider bedrock." + assert context.value.message == f"Missing environment variables: {env_var_name} for provider bedrock." @pytest.fixture diff --git a/packages/exchange/tests/providers/test_databricks.py b/packages/exchange/tests/providers/test_databricks.py index cd01335a..0b2729dc 100644 --- a/packages/exchange/tests/providers/test_databricks.py +++ b/packages/exchange/tests/providers/test_databricks.py @@ -28,7 +28,7 @@ def test_from_env_throw_error_when_missing_env_var(env_var_name): DatabricksProvider.from_env() assert context.value.provider == "databricks" assert context.value.env_variable == env_var_name - assert f"Missing environment variable: {env_var_name} for provider databricks" in context.value.message + assert f"Missing environment variables: {env_var_name} for provider databricks" in context.value.message assert "https://docs.databricks.com" in context.value.message diff --git a/packages/exchange/tests/providers/test_google.py b/packages/exchange/tests/providers/test_google.py index 401db5a6..3e1028a0 100644 --- a/packages/exchange/tests/providers/test_google.py +++ b/packages/exchange/tests/providers/test_google.py @@ -28,7 +28,7 @@ def test_from_env_throw_error_when_missing_api_key(): GoogleProvider.from_env() assert context.value.provider == "google" assert context.value.env_variable == "GOOGLE_API_KEY" - assert "Missing environment variable: GOOGLE_API_KEY for provider google" in context.value.message + assert "Missing environment variables: GOOGLE_API_KEY for provider google" in context.value.message assert "https://ai.google.dev/gemini-api/docs/api-key" in context.value.message diff --git a/packages/exchange/tests/providers/test_openai.py b/packages/exchange/tests/providers/test_openai.py index db0a5261..e29dadc5 100644 --- a/packages/exchange/tests/providers/test_openai.py +++ b/packages/exchange/tests/providers/test_openai.py @@ -2,7 +2,6 @@ import os from unittest.mock import patch import pytest - from exchange import Text, ToolUse from exchange.providers.base import MissingProviderEnvVariableError from exchange.providers.openai import OpenAiProvider @@ -17,7 +16,7 @@ def test_from_env_throw_error_when_missing_api_key(): OpenAiProvider.from_env() assert context.value.provider == "openai" assert context.value.env_variable == "OPENAI_API_KEY" - assert "Missing environment variable: OPENAI_API_KEY for provider openai" in context.value.message + assert "Missing environment variables: OPENAI_API_KEY for provider openai" in context.value.message assert "https://platform.openai.com" in context.value.message diff --git a/packages/exchange/tests/test_base.py b/packages/exchange/tests/test_base.py index 4aae8bde..46baaebb 100644 --- a/packages/exchange/tests/test_base.py +++ b/packages/exchange/tests/test_base.py @@ -9,7 +9,7 @@ def test_missing_provider_env_variable_error_without_instructions_url(): assert error.env_variable == env_variable assert error.provider == provider assert error.instructions_url is None - assert error.message == "Missing environment variable: API_KEY for provider TestProvider." + assert error.message == "Missing environment variables: API_KEY for provider TestProvider." def test_missing_provider_env_variable_error_with_instructions_url(): @@ -22,6 +22,6 @@ def test_missing_provider_env_variable_error_with_instructions_url(): assert error.provider == provider assert error.instructions_url == instructions_url assert error.message == ( - "Missing environment variable: API_KEY for provider TestProvider.\n" + "Missing environment variables: API_KEY for provider TestProvider.\n" "Please see http://example.com/instructions for instructions" )