From 908af7f1577d569facefd6be2b1d8ddc97e67cc0 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Fri, 4 Oct 2024 11:17:29 +1000 Subject: [PATCH] fix: exit the goose and show the error message when provider environment variable is not set (#103) --- .gitignore | 5 +- .../src/exchange/invalid_choice_error.py | 13 +++++ .../src/exchange/moderators/__init__.py | 6 +- .../src/exchange/providers/__init__.py | 6 +- .../src/exchange/providers/anthropic.py | 8 +-- .../exchange/src/exchange/providers/azure.py | 26 +++------ .../exchange/src/exchange/providers/base.py | 13 ++++- .../src/exchange/providers/bedrock.py | 15 +++-- .../src/exchange/providers/databricks.py | 22 +++---- .../exchange/src/exchange/providers/google.py | 11 +--- .../exchange/src/exchange/providers/openai.py | 9 +-- .../exchange/src/exchange/providers/utils.py | 9 +++ .../tests/providers/test_anthropic.py | 10 ++++ .../exchange/tests/providers/test_azure.py | 30 ++++++++++ .../exchange/tests/providers/test_bedrock.py | 27 +++++++++ .../tests/providers/test_databricks.py | 26 +++++++++ .../exchange/tests/providers/test_google.py | 11 ++++ .../exchange/tests/providers/test_openai.py | 12 ++++ .../exchange/tests/providers/test_provider.py | 18 ++++++ packages/exchange/tests/test_base.py | 27 +++++++++ .../tests/test_invalid_choice_error.py | 13 +++++ packages/exchange/tests/test_moderators.py | 17 ++++++ src/goose/cli/config.py | 7 +-- src/goose/cli/session.py | 25 ++++++-- src/goose/toolkit/__init__.py | 7 ++- tests/cli/test_session.py | 58 +++++++++++++++---- 26 files changed, 345 insertions(+), 86 deletions(-) create mode 100644 packages/exchange/src/exchange/invalid_choice_error.py create mode 100644 packages/exchange/tests/providers/test_provider.py create mode 100644 packages/exchange/tests/test_base.py create mode 100644 packages/exchange/tests/test_invalid_choice_error.py create mode 100644 packages/exchange/tests/test_moderators.py diff --git a/.gitignore b/.gitignore index c9142a70..f799b722 100644 --- a/.gitignore +++ b/.gitignore @@ -121,4 +121,7 @@ celerybeat.pid .vscode # Autogenerated docs files -docs/docs/reference \ No newline at end of file +docs/docs/reference + +# uv lock file +uv.lock diff --git a/packages/exchange/src/exchange/invalid_choice_error.py b/packages/exchange/src/exchange/invalid_choice_error.py new file mode 100644 index 00000000..ffbb9899 --- /dev/null +++ b/packages/exchange/src/exchange/invalid_choice_error.py @@ -0,0 +1,13 @@ +from typing import List + + +class InvalidChoiceError(Exception): + def __init__(self, attribute_name: str, attribute_value: str, available_values: List[str]) -> None: + self.attribute_name = attribute_name + self.attribute_value = attribute_value + self.available_values = available_values + self.message = ( + f"Unknown {attribute_name}: {attribute_value}." + + f" Available {attribute_name}s: {', '.join(available_values)}" + ) + super().__init__(self.message) diff --git a/packages/exchange/src/exchange/moderators/__init__.py b/packages/exchange/src/exchange/moderators/__init__.py index 56b198a7..82d032e4 100644 --- a/packages/exchange/src/exchange/moderators/__init__.py +++ b/packages/exchange/src/exchange/moderators/__init__.py @@ -1,6 +1,7 @@ from functools import cache from typing import Type +from exchange.invalid_choice_error import InvalidChoiceError from exchange.moderators.base import Moderator from exchange.utils import load_plugins from exchange.moderators.passive import PassiveModerator # noqa @@ -10,4 +11,7 @@ from exchange.moderators.summarizer import ContextSummarizer # noqa @cache def get_moderator(name: str) -> Type[Moderator]: - return load_plugins(group="exchange.moderator")[name] + moderators = load_plugins(group="exchange.moderator") + if name not in moderators: + raise InvalidChoiceError("moderator", name, moderators.keys()) + return moderators[name] diff --git a/packages/exchange/src/exchange/providers/__init__.py b/packages/exchange/src/exchange/providers/__init__.py index ac7ed07a..f92d4f76 100644 --- a/packages/exchange/src/exchange/providers/__init__.py +++ b/packages/exchange/src/exchange/providers/__init__.py @@ -1,6 +1,7 @@ from functools import cache from typing import Type +from exchange.invalid_choice_error import InvalidChoiceError from exchange.providers.anthropic import AnthropicProvider # noqa from exchange.providers.base import Provider, Usage # noqa from exchange.providers.databricks import DatabricksProvider # noqa @@ -14,4 +15,7 @@ from exchange.utils import load_plugins @cache def get_provider(name: str) -> Type[Provider]: - return load_plugins(group="exchange.provider")[name] + providers = load_plugins(group="exchange.provider") + if name not in providers: + raise InvalidChoiceError("provider", name, providers.keys()) + return providers[name] diff --git a/packages/exchange/src/exchange/providers/anthropic.py b/packages/exchange/src/exchange/providers/anthropic.py index 154ec5f7..bf052b20 100644 --- a/packages/exchange/src/exchange/providers/anthropic.py +++ b/packages/exchange/src/exchange/providers/anthropic.py @@ -7,8 +7,7 @@ from exchange import Message, Tool from exchange.content import Text, ToolResult, ToolUse from exchange.providers.base import Provider, Usage from tenacity import retry, wait_fixed, stop_after_attempt -from exchange.providers.utils import retry_if_status -from exchange.providers.utils import raise_for_status +from exchange.providers.utils import get_provider_env_value, retry_if_status, raise_for_status ANTHROPIC_HOST = "https://api.anthropic.com/v1/messages" @@ -27,10 +26,7 @@ class AnthropicProvider(Provider): @classmethod def from_env(cls: Type["AnthropicProvider"]) -> "AnthropicProvider": url = os.environ.get("ANTHROPIC_HOST", ANTHROPIC_HOST) - try: - key = os.environ["ANTHROPIC_API_KEY"] - except KeyError: - raise RuntimeError("Failed to get ANTHROPIC_API_KEY from the environment") + key = get_provider_env_value("ANTHROPIC_API_KEY", "anthropic") client = httpx.Client( base_url=url, headers={ diff --git a/packages/exchange/src/exchange/providers/azure.py b/packages/exchange/src/exchange/providers/azure.py index 7bacb9dd..a06a557d 100644 --- a/packages/exchange/src/exchange/providers/azure.py +++ b/packages/exchange/src/exchange/providers/azure.py @@ -1,9 +1,11 @@ -import os from typing import Type import httpx from exchange.providers import OpenAiProvider +from exchange.providers.utils import get_provider_env_value + +PROVIDER_NAME = "azure" class AzureProvider(OpenAiProvider): @@ -14,25 +16,11 @@ class AzureProvider(OpenAiProvider): @classmethod def from_env(cls: Type["AzureProvider"]) -> "AzureProvider": - try: - url = os.environ["AZURE_CHAT_COMPLETIONS_HOST_NAME"] - except KeyError: - raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_HOST_NAME from the environment.") + url = get_provider_env_value("AZURE_CHAT_COMPLETIONS_HOST_NAME", PROVIDER_NAME) + deployment_name = get_provider_env_value("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME", PROVIDER_NAME) - try: - deployment_name = os.environ["AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME"] - except KeyError: - raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME from the environment.") - - try: - api_version = os.environ["AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION"] - except KeyError: - raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION from the environment.") - - try: - key = os.environ["AZURE_CHAT_COMPLETIONS_KEY"] - except KeyError: - raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_KEY from the environment.") + api_version = get_provider_env_value("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION", PROVIDER_NAME) + key = get_provider_env_value("AZURE_CHAT_COMPLETIONS_KEY", PROVIDER_NAME) # format the url host/"openai/deployments/" + deployment_name + "/?api-version=" + api_version url = f"{url}/openai/deployments/{deployment_name}/" diff --git a/packages/exchange/src/exchange/providers/base.py b/packages/exchange/src/exchange/providers/base.py index 7b7ff88b..78c267e7 100644 --- a/packages/exchange/src/exchange/providers/base.py +++ b/packages/exchange/src/exchange/providers/base.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from attrs import define, field -from typing import List, Tuple, Type +from typing import List, Optional, Tuple, Type from exchange.message import Message from exchange.tool import Tool @@ -28,3 +28,14 @@ class Provider(ABC): ) -> Tuple[Message, Usage]: """Generate the next message using the specified model""" pass + + +class MissingProviderEnvVariableError(Exception): + def __init__(self, env_variable: str, provider: str, instructions_url: Optional[str] = None) -> None: + self.env_variable = env_variable + self.provider = provider + self.instructions_url = instructions_url + self.message = f"Missing environment variable: {env_variable} for provider {provider}." + if instructions_url: + self.message += f"\nPlease see {instructions_url} for instructions" + super().__init__(self.message) diff --git a/packages/exchange/src/exchange/providers/bedrock.py b/packages/exchange/src/exchange/providers/bedrock.py index 2a5f53dc..c8c1d681 100644 --- a/packages/exchange/src/exchange/providers/bedrock.py +++ b/packages/exchange/src/exchange/providers/bedrock.py @@ -13,8 +13,7 @@ from exchange.content import Text, ToolResult, ToolUse from exchange.message import Message from exchange.providers import Provider, Usage from tenacity import retry, wait_fixed, stop_after_attempt -from exchange.providers.utils import retry_if_status -from exchange.providers.utils import raise_for_status +from exchange.providers.utils import get_provider_env_value, raise_for_status, retry_if_status from exchange.tool import Tool SERVICE = "bedrock-runtime" @@ -147,6 +146,9 @@ class AwsClient(httpx.Client): return headers +PROVIDER_NAME = "bedrock" + + class BedrockProvider(Provider): def __init__(self, client: AwsClient) -> None: self.client = client @@ -154,12 +156,9 @@ class BedrockProvider(Provider): @classmethod def from_env(cls: Type["BedrockProvider"]) -> "BedrockProvider": aws_region = os.environ.get("AWS_REGION", "us-east-1") - try: - aws_access_key = os.environ["AWS_ACCESS_KEY_ID"] - aws_secret_key = os.environ["AWS_SECRET_ACCESS_KEY"] - aws_session_token = os.environ.get("AWS_SESSION_TOKEN") - except KeyError: - raise RuntimeError("Failed to get AWS_ACCESS_KEY_ID or AWS_SECRET_ACCESS_KEY from the environment") + aws_access_key = get_provider_env_value("AWS_ACCESS_KEY_ID", PROVIDER_NAME) + aws_secret_key = get_provider_env_value("AWS_SECRET_ACCESS_KEY", PROVIDER_NAME) + aws_session_token = get_provider_env_value("AWS_SESSION_TOKEN", PROVIDER_NAME) client = AwsClient( aws_region=aws_region, diff --git a/packages/exchange/src/exchange/providers/databricks.py b/packages/exchange/src/exchange/providers/databricks.py index 84dc7515..77d392e8 100644 --- a/packages/exchange/src/exchange/providers/databricks.py +++ b/packages/exchange/src/exchange/providers/databricks.py @@ -1,4 +1,3 @@ -import os from typing import Any, Dict, List, Tuple, Type import httpx @@ -6,7 +5,7 @@ import httpx from exchange.message import Message from exchange.providers.base import Provider, Usage from tenacity import retry, wait_fixed, stop_after_attempt -from exchange.providers.utils import raise_for_status, retry_if_status +from exchange.providers.utils import get_provider_env_value, raise_for_status, retry_if_status from exchange.providers.utils import ( messages_to_openai_spec, openai_response_to_message, @@ -37,18 +36,8 @@ class DatabricksProvider(Provider): @classmethod def from_env(cls: Type["DatabricksProvider"]) -> "DatabricksProvider": - try: - url = os.environ["DATABRICKS_HOST"] - except KeyError: - raise RuntimeError( - "Failed to get DATABRICKS_HOST from the environment. See https://docs.databricks.com/en/dev-tools/auth/index.html#general-host-token-and-account-id-environment-variables-and-fields" - ) - try: - key = os.environ["DATABRICKS_TOKEN"] - except KeyError: - raise RuntimeError( - "Failed to get DATABRICKS_TOKEN from the environment. See https://docs.databricks.com/en/dev-tools/auth/index.html#general-host-token-and-account-id-environment-variables-and-fields" - ) + url = cls._get_env_variable("DATABRICKS_HOST") + key = cls._get_env_variable("DATABRICKS_TOKEN") client = httpx.Client( base_url=url, auth=("token", key), @@ -100,3 +89,8 @@ class DatabricksProvider(Provider): json=payload, ) return raise_for_status(response).json() + + @classmethod + def _get_env_variable(cls: Type["DatabricksProvider"], key: str) -> str: + instruction = "https://docs.databricks.com/en/dev-tools/auth/index.html#general-host-token-and-account-id-environment-variables-and-fields" + return get_provider_env_value(key, "databricks", instruction) diff --git a/packages/exchange/src/exchange/providers/google.py b/packages/exchange/src/exchange/providers/google.py index 426aa79d..4fc020f3 100644 --- a/packages/exchange/src/exchange/providers/google.py +++ b/packages/exchange/src/exchange/providers/google.py @@ -7,8 +7,7 @@ from exchange import Message, Tool from exchange.content import Text, ToolResult, ToolUse from exchange.providers.base import Provider, Usage from tenacity import retry, wait_fixed, stop_after_attempt -from exchange.providers.utils import retry_if_status -from exchange.providers.utils import raise_for_status +from exchange.providers.utils import get_provider_env_value, raise_for_status, retry_if_status GOOGLE_HOST = "https://generativelanguage.googleapis.com/v1beta" @@ -27,12 +26,8 @@ class GoogleProvider(Provider): @classmethod def from_env(cls: Type["GoogleProvider"]) -> "GoogleProvider": url = os.environ.get("GOOGLE_HOST", GOOGLE_HOST) - try: - key = os.environ["GOOGLE_API_KEY"] - except KeyError: - raise RuntimeError( - "Failed to get GOOGLE_API_KEY from the environment, see https://ai.google.dev/gemini-api/docs/api-key" - ) + api_key_instructions_url = "https://ai.google.dev/gemini-api/docs/api-key" + key = get_provider_env_value("GOOGLE_API_KEY", "google", api_key_instructions_url) client = httpx.Client( base_url=url, diff --git a/packages/exchange/src/exchange/providers/openai.py b/packages/exchange/src/exchange/providers/openai.py index dbd293b4..c30558b8 100644 --- a/packages/exchange/src/exchange/providers/openai.py +++ b/packages/exchange/src/exchange/providers/openai.py @@ -6,6 +6,7 @@ import httpx from exchange.message import Message from exchange.providers.base import Provider, Usage from exchange.providers.utils import ( + get_provider_env_value, messages_to_openai_spec, openai_response_to_message, openai_single_message_context_length_exceeded, @@ -36,12 +37,8 @@ class OpenAiProvider(Provider): @classmethod def from_env(cls: Type["OpenAiProvider"]) -> "OpenAiProvider": url = os.environ.get("OPENAI_HOST", OPENAI_HOST) - try: - key = os.environ["OPENAI_API_KEY"] - except KeyError: - raise RuntimeError( - "Failed to get OPENAI_API_KEY from the environment, see https://platform.openai.com/docs/api-reference/api-keys" - ) + api_key_instructions_url = "https://platform.openai.com/docs/api-reference/api-keys" + key = get_provider_env_value("OPENAI_API_KEY", "openai", api_key_instructions_url) client = httpx.Client( base_url=url + "v1/", auth=("Bearer", key), diff --git a/packages/exchange/src/exchange/providers/utils.py b/packages/exchange/src/exchange/providers/utils.py index 4be7ac31..01504305 100644 --- a/packages/exchange/src/exchange/providers/utils.py +++ b/packages/exchange/src/exchange/providers/utils.py @@ -1,11 +1,13 @@ import base64 import json +import os import re from typing import Any, Callable, Dict, List, Optional, Tuple import httpx from exchange.content import Text, ToolResult, ToolUse from exchange.message import Message +from exchange.providers.base import MissingProviderEnvVariableError from exchange.tool import Tool from tenacity import retry_if_exception @@ -179,6 +181,13 @@ def openai_single_message_context_length_exceeded(error_dict: dict) -> None: raise InitialMessageTooLargeError(f"Input message too long. Message: {error_dict.get('message')}") +def get_provider_env_value(env_variable: str, provider: str, instructions_url: Optional[str] = None) -> str: + try: + return os.environ[env_variable] + except KeyError: + raise MissingProviderEnvVariableError(env_variable, provider, instructions_url) + + class InitialMessageTooLargeError(Exception): """Custom error raised when the first input message in an exchange is too large.""" diff --git a/packages/exchange/tests/providers/test_anthropic.py b/packages/exchange/tests/providers/test_anthropic.py index a6f5bc68..272ebcb0 100644 --- a/packages/exchange/tests/providers/test_anthropic.py +++ b/packages/exchange/tests/providers/test_anthropic.py @@ -6,6 +6,7 @@ import pytest from exchange import Message, Text from exchange.content import ToolResult, ToolUse from exchange.providers.anthropic import AnthropicProvider +from exchange.providers.base import MissingProviderEnvVariableError from exchange.tool import Tool @@ -25,6 +26,15 @@ def anthropic_provider(): return AnthropicProvider.from_env() +def test_from_env_throw_error_when_missing_api_key(): + with patch.dict(os.environ, {}, clear=True): + with pytest.raises(MissingProviderEnvVariableError) as context: + AnthropicProvider.from_env() + assert context.value.provider == "anthropic" + assert context.value.env_variable == "ANTHROPIC_API_KEY" + assert context.value.message == "Missing environment variable: ANTHROPIC_API_KEY for provider anthropic." + + def test_anthropic_response_to_text_message() -> None: response = { "content": [{"type": "text", "text": "Hello from Claude!"}], diff --git a/packages/exchange/tests/providers/test_azure.py b/packages/exchange/tests/providers/test_azure.py index adafabed..b46be30b 100644 --- a/packages/exchange/tests/providers/test_azure.py +++ b/packages/exchange/tests/providers/test_azure.py @@ -1,14 +1,44 @@ import os +from unittest.mock import patch import pytest from exchange import Text, ToolUse from exchange.providers.azure import AzureProvider +from exchange.providers.base import MissingProviderEnvVariableError from .conftest import complete, tools AZURE_MODEL = os.getenv("AZURE_MODEL", "gpt-4o-mini") +@pytest.mark.parametrize( + "env_var_name", + [ + ("AZURE_CHAT_COMPLETIONS_HOST_NAME"), + ("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME"), + ("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION"), + ("AZURE_CHAT_COMPLETIONS_KEY"), + ], +) +def test_from_env_throw_error_when_missing_env_var(env_var_name): + with patch.dict( + os.environ, + { + "AZURE_CHAT_COMPLETIONS_HOST_NAME": "test_host_name", + "AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME": "test_deployment_name", + "AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION": "test_api_version", + "AZURE_CHAT_COMPLETIONS_KEY": "test_api_key", + }, + clear=True, + ): + os.environ.pop(env_var_name) + with pytest.raises(MissingProviderEnvVariableError) as context: + AzureProvider.from_env() + assert context.value.provider == "azure" + assert context.value.env_variable == env_var_name + assert context.value.message == f"Missing environment variable: {env_var_name} for provider azure." + + @pytest.mark.vcr() def test_azure_complete(default_azure_env): reply_message, reply_usage = complete(AzureProvider, AZURE_MODEL) diff --git a/packages/exchange/tests/providers/test_bedrock.py b/packages/exchange/tests/providers/test_bedrock.py index 2525f650..f8fcaa4b 100644 --- a/packages/exchange/tests/providers/test_bedrock.py +++ b/packages/exchange/tests/providers/test_bedrock.py @@ -5,12 +5,39 @@ from unittest.mock import patch import pytest from exchange.content import Text, ToolResult, ToolUse from exchange.message import Message +from exchange.providers.base import MissingProviderEnvVariableError from exchange.providers.bedrock import BedrockProvider from exchange.tool import Tool logger = logging.getLogger(__name__) +@pytest.mark.parametrize( + "env_var_name", + [ + ("AWS_ACCESS_KEY_ID"), + ("AWS_SECRET_ACCESS_KEY"), + ("AWS_SESSION_TOKEN"), + ], +) +def test_from_env_throw_error_when_missing_env_var(env_var_name): + with patch.dict( + os.environ, + { + "AWS_ACCESS_KEY_ID": "test_access_key_id", + "AWS_SECRET_ACCESS_KEY": "test_secret_access_key", + "AWS_SESSION_TOKEN": "test_session_token", + }, + clear=True, + ): + os.environ.pop(env_var_name) + with pytest.raises(MissingProviderEnvVariableError) as context: + BedrockProvider.from_env() + assert context.value.provider == "bedrock" + assert context.value.env_variable == env_var_name + assert context.value.message == f"Missing environment variable: {env_var_name} for provider bedrock." + + @pytest.fixture @patch.dict( os.environ, diff --git a/packages/exchange/tests/providers/test_databricks.py b/packages/exchange/tests/providers/test_databricks.py index 3c142114..4b6793ab 100644 --- a/packages/exchange/tests/providers/test_databricks.py +++ b/packages/exchange/tests/providers/test_databricks.py @@ -3,9 +3,35 @@ from unittest.mock import patch import pytest from exchange import Message, Text +from exchange.providers.base import MissingProviderEnvVariableError from exchange.providers.databricks import DatabricksProvider +@pytest.mark.parametrize( + "env_var_name", + [ + ("DATABRICKS_HOST"), + ("DATABRICKS_TOKEN"), + ], +) +def test_from_env_throw_error_when_missing_env_var(env_var_name): + with patch.dict( + os.environ, + { + "DATABRICKS_HOST": "test_host", + "DATABRICKS_TOKEN": "test_token", + }, + clear=True, + ): + os.environ.pop(env_var_name) + with pytest.raises(MissingProviderEnvVariableError) as context: + DatabricksProvider.from_env() + assert context.value.provider == "databricks" + assert context.value.env_variable == env_var_name + assert f"Missing environment variable: {env_var_name} for provider databricks" in context.value.message + assert "https://docs.databricks.com" in context.value.message + + @pytest.fixture @patch.dict( os.environ, diff --git a/packages/exchange/tests/providers/test_google.py b/packages/exchange/tests/providers/test_google.py index 47ad46b4..76ae4c8d 100644 --- a/packages/exchange/tests/providers/test_google.py +++ b/packages/exchange/tests/providers/test_google.py @@ -5,6 +5,7 @@ import httpx import pytest from exchange import Message, Text from exchange.content import ToolResult, ToolUse +from exchange.providers.base import MissingProviderEnvVariableError from exchange.providers.google import GoogleProvider from exchange.tool import Tool @@ -19,6 +20,16 @@ def example_fn(param: str) -> None: pass +def test_from_env_throw_error_when_missing_api_key(): + with patch.dict(os.environ, {}, clear=True): + with pytest.raises(MissingProviderEnvVariableError) as context: + GoogleProvider.from_env() + assert context.value.provider == "google" + assert context.value.env_variable == "GOOGLE_API_KEY" + assert "Missing environment variable: GOOGLE_API_KEY for provider google" in context.value.message + assert "https://ai.google.dev/gemini-api/docs/api-key" in context.value.message + + @pytest.fixture @patch.dict(os.environ, {"GOOGLE_API_KEY": "test_api_key"}) def google_provider(): diff --git a/packages/exchange/tests/providers/test_openai.py b/packages/exchange/tests/providers/test_openai.py index 45bc6205..ea979abe 100644 --- a/packages/exchange/tests/providers/test_openai.py +++ b/packages/exchange/tests/providers/test_openai.py @@ -1,14 +1,26 @@ import os +from unittest.mock import patch import pytest from exchange import Text, ToolUse +from exchange.providers.base import MissingProviderEnvVariableError from exchange.providers.openai import OpenAiProvider from .conftest import complete, vision, tools OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini") +def test_from_env_throw_error_when_missing_api_key(): + with patch.dict(os.environ, {}, clear=True): + with pytest.raises(MissingProviderEnvVariableError) as context: + OpenAiProvider.from_env() + assert context.value.provider == "openai" + assert context.value.env_variable == "OPENAI_API_KEY" + assert "Missing environment variable: OPENAI_API_KEY for provider openai" in context.value.message + assert "https://platform.openai.com" in context.value.message + + @pytest.mark.vcr() def test_openai_complete(default_openai_env): reply_message, reply_usage = complete(OpenAiProvider, OPENAI_MODEL) diff --git a/packages/exchange/tests/providers/test_provider.py b/packages/exchange/tests/providers/test_provider.py new file mode 100644 index 00000000..fb7d15ce --- /dev/null +++ b/packages/exchange/tests/providers/test_provider.py @@ -0,0 +1,18 @@ +import pytest +from exchange.invalid_choice_error import InvalidChoiceError +from exchange.providers import get_provider + + +def test_get_provider_valid(): + provider_name = "openai" + provider = get_provider(provider_name) + assert provider.__name__ == "OpenAiProvider" + + +def test_get_provider_throw_error_for_unknown_provider(): + with pytest.raises(InvalidChoiceError) as error: + get_provider("nonexistent") + assert error.value.attribute_name == "provider" + assert error.value.attribute_value == "nonexistent" + assert "openai" in error.value.available_values + assert "openai" in error.value.message diff --git a/packages/exchange/tests/test_base.py b/packages/exchange/tests/test_base.py new file mode 100644 index 00000000..4aae8bde --- /dev/null +++ b/packages/exchange/tests/test_base.py @@ -0,0 +1,27 @@ +from exchange.providers.base import MissingProviderEnvVariableError + + +def test_missing_provider_env_variable_error_without_instructions_url(): + env_variable = "API_KEY" + provider = "TestProvider" + error = MissingProviderEnvVariableError(env_variable, provider) + + assert error.env_variable == env_variable + assert error.provider == provider + assert error.instructions_url is None + assert error.message == "Missing environment variable: API_KEY for provider TestProvider." + + +def test_missing_provider_env_variable_error_with_instructions_url(): + env_variable = "API_KEY" + provider = "TestProvider" + instructions_url = "http://example.com/instructions" + error = MissingProviderEnvVariableError(env_variable, provider, instructions_url) + + assert error.env_variable == env_variable + assert error.provider == provider + assert error.instructions_url == instructions_url + assert error.message == ( + "Missing environment variable: API_KEY for provider TestProvider.\n" + "Please see http://example.com/instructions for instructions" + ) diff --git a/packages/exchange/tests/test_invalid_choice_error.py b/packages/exchange/tests/test_invalid_choice_error.py new file mode 100644 index 00000000..9fad8b12 --- /dev/null +++ b/packages/exchange/tests/test_invalid_choice_error.py @@ -0,0 +1,13 @@ +from exchange.invalid_choice_error import InvalidChoiceError + + +def test_load_invalid_choice_error(): + attribute_name = "moderator" + attribute_value = "not_exist" + available_values = ["truncate", "summarizer"] + error = InvalidChoiceError(attribute_name, attribute_value, available_values) + + assert error.attribute_name == attribute_name + assert error.attribute_value == attribute_value + assert error.attribute_value == attribute_value + assert error.message == "Unknown moderator: not_exist. Available moderators: truncate, summarizer" diff --git a/packages/exchange/tests/test_moderators.py b/packages/exchange/tests/test_moderators.py new file mode 100644 index 00000000..16bcaa13 --- /dev/null +++ b/packages/exchange/tests/test_moderators.py @@ -0,0 +1,17 @@ +from exchange.invalid_choice_error import InvalidChoiceError +from exchange.moderators import get_moderator +import pytest + + +def test_get_moderator(): + moderator = get_moderator("truncate") + assert moderator.__name__ == "ContextTruncate" + + +def test_get_moderator_raise_error_for_unknown_moderator(): + with pytest.raises(InvalidChoiceError) as error: + get_moderator("nonexistent") + assert error.value.attribute_name == "moderator" + assert error.value.attribute_value == "nonexistent" + assert "truncate" in error.value.available_values + assert "truncate" in error.value.message diff --git a/src/goose/cli/config.py b/src/goose/cli/config.py index 2005c468..7bede0be 100644 --- a/src/goose/cli/config.py +++ b/src/goose/cli/config.py @@ -16,6 +16,7 @@ PROFILES_CONFIG_PATH = GOOSE_GLOBAL_PATH.joinpath("profiles.yaml") SESSIONS_PATH = GOOSE_GLOBAL_PATH.joinpath("sessions") SESSION_FILE_SUFFIX = ".jsonl" LOG_PATH = GOOSE_GLOBAL_PATH.joinpath("logs") +RECOMMENDED_DEFAULT_PROVIDER = "openai" @cache @@ -88,11 +89,7 @@ def default_model_configuration() -> Tuple[str, str, str]: except Exception: pass else: - raise ValueError( - "Could not detect an available provider," - + " make sure to plugin a provider or set an env var such as OPENAI_API_KEY" - ) - + provider = RECOMMENDED_DEFAULT_PROVIDER recommended = { "ollama": (OLLAMA_MODEL, OLLAMA_MODEL), "openai": ("gpt-4o", "gpt-4o-mini"), diff --git a/src/goose/cli/session.py b/src/goose/cli/session.py index bfa869a3..f9fe13b9 100644 --- a/src/goose/cli/session.py +++ b/src/goose/cli/session.py @@ -1,8 +1,11 @@ +import sys import traceback from pathlib import Path from typing import Any, Dict, List, Optional -from exchange import Message, ToolResult, ToolUse, Text +from exchange import Message, ToolResult, ToolUse, Text, Exchange +from exchange.providers.base import MissingProviderEnvVariableError +from exchange.invalid_choice_error import InvalidChoiceError from rich import print from rich.console import RenderableType from rich.live import Live @@ -11,7 +14,7 @@ from rich.panel import Panel from rich.status import Status from goose.build import build_exchange -from goose.cli.config import ensure_config, session_path, LOG_PATH +from goose.cli.config import PROFILES_CONFIG_PATH, ensure_config, session_path, LOG_PATH from goose._logger import get_logger, setup_logging from goose.cli.prompt.goose_prompt_session import GoosePromptSession from goose.notifier import Notifier @@ -89,8 +92,7 @@ class Session: self.profile = profile self.status_indicator = Status("", spinner="dots") self.notifier = SessionNotifier(self.status_indicator) - - self.exchange = build_exchange(profile=load_profile(profile), notifier=self.notifier) + self.exchange = self._create_exchange() setup_logging(log_file_directory=LOG_PATH, log_level=log_level) self.exchange.messages.extend(self._get_initial_messages()) @@ -100,6 +102,21 @@ class Session: self.prompt_session = GoosePromptSession() + def _create_exchange(self) -> Exchange: + try: + return build_exchange(profile=load_profile(self.profile), notifier=self.notifier) + except MissingProviderEnvVariableError as e: + error_message = f"{e.message}. Please set the required environment variable to continue." + print(Panel(error_message, style="red")) + sys.exit(1) + except InvalidChoiceError as e: + error_message = ( + f"[bold red]{e.message}[/bold red].\nPlease check your configuration file at {PROFILES_CONFIG_PATH}.\n" + + "Configuration doc: https://block-open-source.github.io/goose/configuration.html" + ) + print(error_message) + sys.exit(1) + def _get_initial_messages(self) -> List[Message]: messages = self.load_session() diff --git a/src/goose/toolkit/__init__.py b/src/goose/toolkit/__init__.py index a3a97d41..fc561ee6 100644 --- a/src/goose/toolkit/__init__.py +++ b/src/goose/toolkit/__init__.py @@ -1,9 +1,12 @@ from functools import cache - +from exchange.invalid_choice_error import InvalidChoiceError from goose.toolkit.base import Toolkit from goose.utils import load_plugins @cache def get_toolkit(name: str) -> type[Toolkit]: - return load_plugins(group="goose.toolkit")[name] + toolkits = load_plugins(group="goose.toolkit") + if name not in toolkits: + raise InvalidChoiceError("toolkit", name, toolkits.keys()) + return toolkits[name] diff --git a/tests/cli/test_session.py b/tests/cli/test_session.py index 6d9086bc..81519512 100644 --- a/tests/cli/test_session.py +++ b/tests/cli/test_session.py @@ -2,6 +2,8 @@ from unittest.mock import MagicMock, patch import pytest from exchange import Exchange, Message, ToolUse, ToolResult +from exchange.providers.base import MissingProviderEnvVariableError +from exchange.invalid_choice_error import InvalidChoiceError from goose.cli.prompt.goose_prompt_session import GoosePromptSession from goose.cli.prompt.user_input import PromptAction, UserInput from goose.cli.session import Session @@ -19,10 +21,11 @@ def mock_specified_session_name(): @pytest.fixture def create_session_with_mock_configs(mock_sessions_path, exchange_factory, profile_factory): - with patch("goose.cli.session.build_exchange") as mock_exchange, patch( - "goose.cli.session.load_profile", return_value=profile_factory() - ), patch("goose.cli.session.SessionNotifier") as mock_session_notifier, patch( - "goose.cli.session.load_provider", return_value="provider" + with ( + patch("goose.cli.session.build_exchange") as mock_exchange, + patch("goose.cli.session.load_profile", return_value=profile_factory()), + patch("goose.cli.session.SessionNotifier") as mock_session_notifier, + patch("goose.cli.session.load_provider", return_value="provider"), ): mock_session_notifier.return_value = MagicMock() mock_exchange.return_value = exchange_factory() @@ -113,9 +116,11 @@ def test_log_log_cost(create_session_with_mock_configs): session = create_session_with_mock_configs() mock_logger = MagicMock() cost_message = "You have used 100 tokens" - with patch("exchange.Exchange.get_token_usage", return_value={}), patch( - "goose.cli.session.get_total_cost_message", return_value=cost_message - ), patch("goose.cli.session.get_logger", return_value=mock_logger): + with ( + patch("exchange.Exchange.get_token_usage", return_value={}), + patch("goose.cli.session.get_total_cost_message", return_value=cost_message), + patch("goose.cli.session.get_logger", return_value=mock_logger), + ): session._log_cost() mock_logger.info.assert_called_once_with(cost_message) @@ -133,9 +138,11 @@ def test_run_should_auto_save_session(create_session_with_mock_configs, mock_ses ] session = create_session_with_mock_configs({"name": SESSION_NAME}) - with patch.object(GoosePromptSession, "get_user_input", side_effect=user_inputs), patch.object( - Exchange, "generate" - ) as mock_generate, patch("goose.cli.session.save_latest_session") as mock_save_latest_session: + with ( + patch.object(GoosePromptSession, "get_user_input", side_effect=user_inputs), + patch.object(Exchange, "generate") as mock_generate, + patch("goose.cli.session.save_latest_session") as mock_save_latest_session, + ): mock_generate.side_effect = lambda *args, **kwargs: custom_exchange_generate(session.exchange, *args, **kwargs) session.run() @@ -151,3 +158,34 @@ def test_set_generated_session_name(create_session_with_mock_configs, mock_sessi with patch("goose.cli.session.droid", return_value=generated_session_name): session = create_session_with_mock_configs({"name": None}) assert session.name == generated_session_name + + +def test_create_exchange_exit_when_env_var_does_not_exist(create_session_with_mock_configs, mock_sessions_path): + session = create_session_with_mock_configs() + expected_error = MissingProviderEnvVariableError(env_variable="OPENAI_API_KEY", provider="openai") + with ( + patch("goose.cli.session.build_exchange", side_effect=expected_error), + patch("goose.cli.session.print") as mock_print, + patch("sys.exit") as mock_exit, + ): + session._create_exchange() + mock_print.call_args_list[0][0][0].renderable == ( + "Missing environment variable OPENAI_API_KEY for provider openai. ", + "Please set the required environment variable to continue.", + ) + mock_exit.assert_called_once_with(1) + + +def test_create_exchange_exit_when_configuration_is_incorrect(create_session_with_mock_configs, mock_sessions_path): + session = create_session_with_mock_configs() + expected_error = InvalidChoiceError( + attribute_name="provider", attribute_value="wrong_provider", available_values=["openai"] + ) + with ( + patch("goose.cli.session.build_exchange", side_effect=expected_error), + patch("goose.cli.session.print") as mock_print, + patch("sys.exit") as mock_exit, + ): + session._create_exchange() + assert "Unknown provider: wrong_provider. Available providers: openai" in mock_print.call_args_list[0][0][0] + mock_exit.assert_called_once_with(1)