fix: exit the goose and show the error message when provider environment variable is not set (#103)

This commit is contained in:
Lifei Zhou
2024-10-04 11:17:29 +10:00
committed by GitHub
parent 9e35c6370e
commit 908af7f157
26 changed files with 345 additions and 86 deletions

5
.gitignore vendored
View File

@@ -121,4 +121,7 @@ celerybeat.pid
.vscode
# Autogenerated docs files
docs/docs/reference
docs/docs/reference
# uv lock file
uv.lock

View File

@@ -0,0 +1,13 @@
from typing import List
class InvalidChoiceError(Exception):
def __init__(self, attribute_name: str, attribute_value: str, available_values: List[str]) -> None:
self.attribute_name = attribute_name
self.attribute_value = attribute_value
self.available_values = available_values
self.message = (
f"Unknown {attribute_name}: {attribute_value}."
+ f" Available {attribute_name}s: {', '.join(available_values)}"
)
super().__init__(self.message)

View File

@@ -1,6 +1,7 @@
from functools import cache
from typing import Type
from exchange.invalid_choice_error import InvalidChoiceError
from exchange.moderators.base import Moderator
from exchange.utils import load_plugins
from exchange.moderators.passive import PassiveModerator # noqa
@@ -10,4 +11,7 @@ from exchange.moderators.summarizer import ContextSummarizer # noqa
@cache
def get_moderator(name: str) -> Type[Moderator]:
return load_plugins(group="exchange.moderator")[name]
moderators = load_plugins(group="exchange.moderator")
if name not in moderators:
raise InvalidChoiceError("moderator", name, moderators.keys())
return moderators[name]

View File

@@ -1,6 +1,7 @@
from functools import cache
from typing import Type
from exchange.invalid_choice_error import InvalidChoiceError
from exchange.providers.anthropic import AnthropicProvider # noqa
from exchange.providers.base import Provider, Usage # noqa
from exchange.providers.databricks import DatabricksProvider # noqa
@@ -14,4 +15,7 @@ from exchange.utils import load_plugins
@cache
def get_provider(name: str) -> Type[Provider]:
return load_plugins(group="exchange.provider")[name]
providers = load_plugins(group="exchange.provider")
if name not in providers:
raise InvalidChoiceError("provider", name, providers.keys())
return providers[name]

View File

@@ -7,8 +7,7 @@ from exchange import Message, Tool
from exchange.content import Text, ToolResult, ToolUse
from exchange.providers.base import Provider, Usage
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import retry_if_status
from exchange.providers.utils import raise_for_status
from exchange.providers.utils import get_provider_env_value, retry_if_status, raise_for_status
ANTHROPIC_HOST = "https://api.anthropic.com/v1/messages"
@@ -27,10 +26,7 @@ class AnthropicProvider(Provider):
@classmethod
def from_env(cls: Type["AnthropicProvider"]) -> "AnthropicProvider":
url = os.environ.get("ANTHROPIC_HOST", ANTHROPIC_HOST)
try:
key = os.environ["ANTHROPIC_API_KEY"]
except KeyError:
raise RuntimeError("Failed to get ANTHROPIC_API_KEY from the environment")
key = get_provider_env_value("ANTHROPIC_API_KEY", "anthropic")
client = httpx.Client(
base_url=url,
headers={

View File

@@ -1,9 +1,11 @@
import os
from typing import Type
import httpx
from exchange.providers import OpenAiProvider
from exchange.providers.utils import get_provider_env_value
PROVIDER_NAME = "azure"
class AzureProvider(OpenAiProvider):
@@ -14,25 +16,11 @@ class AzureProvider(OpenAiProvider):
@classmethod
def from_env(cls: Type["AzureProvider"]) -> "AzureProvider":
try:
url = os.environ["AZURE_CHAT_COMPLETIONS_HOST_NAME"]
except KeyError:
raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_HOST_NAME from the environment.")
url = get_provider_env_value("AZURE_CHAT_COMPLETIONS_HOST_NAME", PROVIDER_NAME)
deployment_name = get_provider_env_value("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME", PROVIDER_NAME)
try:
deployment_name = os.environ["AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME"]
except KeyError:
raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME from the environment.")
try:
api_version = os.environ["AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION"]
except KeyError:
raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION from the environment.")
try:
key = os.environ["AZURE_CHAT_COMPLETIONS_KEY"]
except KeyError:
raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_KEY from the environment.")
api_version = get_provider_env_value("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION", PROVIDER_NAME)
key = get_provider_env_value("AZURE_CHAT_COMPLETIONS_KEY", PROVIDER_NAME)
# format the url host/"openai/deployments/" + deployment_name + "/?api-version=" + api_version
url = f"{url}/openai/deployments/{deployment_name}/"

View File

@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from attrs import define, field
from typing import List, Tuple, Type
from typing import List, Optional, Tuple, Type
from exchange.message import Message
from exchange.tool import Tool
@@ -28,3 +28,14 @@ class Provider(ABC):
) -> Tuple[Message, Usage]:
"""Generate the next message using the specified model"""
pass
class MissingProviderEnvVariableError(Exception):
def __init__(self, env_variable: str, provider: str, instructions_url: Optional[str] = None) -> None:
self.env_variable = env_variable
self.provider = provider
self.instructions_url = instructions_url
self.message = f"Missing environment variable: {env_variable} for provider {provider}."
if instructions_url:
self.message += f"\nPlease see {instructions_url} for instructions"
super().__init__(self.message)

View File

@@ -13,8 +13,7 @@ from exchange.content import Text, ToolResult, ToolUse
from exchange.message import Message
from exchange.providers import Provider, Usage
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import retry_if_status
from exchange.providers.utils import raise_for_status
from exchange.providers.utils import get_provider_env_value, raise_for_status, retry_if_status
from exchange.tool import Tool
SERVICE = "bedrock-runtime"
@@ -147,6 +146,9 @@ class AwsClient(httpx.Client):
return headers
PROVIDER_NAME = "bedrock"
class BedrockProvider(Provider):
def __init__(self, client: AwsClient) -> None:
self.client = client
@@ -154,12 +156,9 @@ class BedrockProvider(Provider):
@classmethod
def from_env(cls: Type["BedrockProvider"]) -> "BedrockProvider":
aws_region = os.environ.get("AWS_REGION", "us-east-1")
try:
aws_access_key = os.environ["AWS_ACCESS_KEY_ID"]
aws_secret_key = os.environ["AWS_SECRET_ACCESS_KEY"]
aws_session_token = os.environ.get("AWS_SESSION_TOKEN")
except KeyError:
raise RuntimeError("Failed to get AWS_ACCESS_KEY_ID or AWS_SECRET_ACCESS_KEY from the environment")
aws_access_key = get_provider_env_value("AWS_ACCESS_KEY_ID", PROVIDER_NAME)
aws_secret_key = get_provider_env_value("AWS_SECRET_ACCESS_KEY", PROVIDER_NAME)
aws_session_token = get_provider_env_value("AWS_SESSION_TOKEN", PROVIDER_NAME)
client = AwsClient(
aws_region=aws_region,

View File

@@ -1,4 +1,3 @@
import os
from typing import Any, Dict, List, Tuple, Type
import httpx
@@ -6,7 +5,7 @@ import httpx
from exchange.message import Message
from exchange.providers.base import Provider, Usage
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import raise_for_status, retry_if_status
from exchange.providers.utils import get_provider_env_value, raise_for_status, retry_if_status
from exchange.providers.utils import (
messages_to_openai_spec,
openai_response_to_message,
@@ -37,18 +36,8 @@ class DatabricksProvider(Provider):
@classmethod
def from_env(cls: Type["DatabricksProvider"]) -> "DatabricksProvider":
try:
url = os.environ["DATABRICKS_HOST"]
except KeyError:
raise RuntimeError(
"Failed to get DATABRICKS_HOST from the environment. See https://docs.databricks.com/en/dev-tools/auth/index.html#general-host-token-and-account-id-environment-variables-and-fields"
)
try:
key = os.environ["DATABRICKS_TOKEN"]
except KeyError:
raise RuntimeError(
"Failed to get DATABRICKS_TOKEN from the environment. See https://docs.databricks.com/en/dev-tools/auth/index.html#general-host-token-and-account-id-environment-variables-and-fields"
)
url = cls._get_env_variable("DATABRICKS_HOST")
key = cls._get_env_variable("DATABRICKS_TOKEN")
client = httpx.Client(
base_url=url,
auth=("token", key),
@@ -100,3 +89,8 @@ class DatabricksProvider(Provider):
json=payload,
)
return raise_for_status(response).json()
@classmethod
def _get_env_variable(cls: Type["DatabricksProvider"], key: str) -> str:
instruction = "https://docs.databricks.com/en/dev-tools/auth/index.html#general-host-token-and-account-id-environment-variables-and-fields"
return get_provider_env_value(key, "databricks", instruction)

View File

@@ -7,8 +7,7 @@ from exchange import Message, Tool
from exchange.content import Text, ToolResult, ToolUse
from exchange.providers.base import Provider, Usage
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import retry_if_status
from exchange.providers.utils import raise_for_status
from exchange.providers.utils import get_provider_env_value, raise_for_status, retry_if_status
GOOGLE_HOST = "https://generativelanguage.googleapis.com/v1beta"
@@ -27,12 +26,8 @@ class GoogleProvider(Provider):
@classmethod
def from_env(cls: Type["GoogleProvider"]) -> "GoogleProvider":
url = os.environ.get("GOOGLE_HOST", GOOGLE_HOST)
try:
key = os.environ["GOOGLE_API_KEY"]
except KeyError:
raise RuntimeError(
"Failed to get GOOGLE_API_KEY from the environment, see https://ai.google.dev/gemini-api/docs/api-key"
)
api_key_instructions_url = "https://ai.google.dev/gemini-api/docs/api-key"
key = get_provider_env_value("GOOGLE_API_KEY", "google", api_key_instructions_url)
client = httpx.Client(
base_url=url,

View File

@@ -6,6 +6,7 @@ import httpx
from exchange.message import Message
from exchange.providers.base import Provider, Usage
from exchange.providers.utils import (
get_provider_env_value,
messages_to_openai_spec,
openai_response_to_message,
openai_single_message_context_length_exceeded,
@@ -36,12 +37,8 @@ class OpenAiProvider(Provider):
@classmethod
def from_env(cls: Type["OpenAiProvider"]) -> "OpenAiProvider":
url = os.environ.get("OPENAI_HOST", OPENAI_HOST)
try:
key = os.environ["OPENAI_API_KEY"]
except KeyError:
raise RuntimeError(
"Failed to get OPENAI_API_KEY from the environment, see https://platform.openai.com/docs/api-reference/api-keys"
)
api_key_instructions_url = "https://platform.openai.com/docs/api-reference/api-keys"
key = get_provider_env_value("OPENAI_API_KEY", "openai", api_key_instructions_url)
client = httpx.Client(
base_url=url + "v1/",
auth=("Bearer", key),

View File

@@ -1,11 +1,13 @@
import base64
import json
import os
import re
from typing import Any, Callable, Dict, List, Optional, Tuple
import httpx
from exchange.content import Text, ToolResult, ToolUse
from exchange.message import Message
from exchange.providers.base import MissingProviderEnvVariableError
from exchange.tool import Tool
from tenacity import retry_if_exception
@@ -179,6 +181,13 @@ def openai_single_message_context_length_exceeded(error_dict: dict) -> None:
raise InitialMessageTooLargeError(f"Input message too long. Message: {error_dict.get('message')}")
def get_provider_env_value(env_variable: str, provider: str, instructions_url: Optional[str] = None) -> str:
try:
return os.environ[env_variable]
except KeyError:
raise MissingProviderEnvVariableError(env_variable, provider, instructions_url)
class InitialMessageTooLargeError(Exception):
"""Custom error raised when the first input message in an exchange is too large."""

View File

@@ -6,6 +6,7 @@ import pytest
from exchange import Message, Text
from exchange.content import ToolResult, ToolUse
from exchange.providers.anthropic import AnthropicProvider
from exchange.providers.base import MissingProviderEnvVariableError
from exchange.tool import Tool
@@ -25,6 +26,15 @@ def anthropic_provider():
return AnthropicProvider.from_env()
def test_from_env_throw_error_when_missing_api_key():
with patch.dict(os.environ, {}, clear=True):
with pytest.raises(MissingProviderEnvVariableError) as context:
AnthropicProvider.from_env()
assert context.value.provider == "anthropic"
assert context.value.env_variable == "ANTHROPIC_API_KEY"
assert context.value.message == "Missing environment variable: ANTHROPIC_API_KEY for provider anthropic."
def test_anthropic_response_to_text_message() -> None:
response = {
"content": [{"type": "text", "text": "Hello from Claude!"}],

View File

@@ -1,14 +1,44 @@
import os
from unittest.mock import patch
import pytest
from exchange import Text, ToolUse
from exchange.providers.azure import AzureProvider
from exchange.providers.base import MissingProviderEnvVariableError
from .conftest import complete, tools
AZURE_MODEL = os.getenv("AZURE_MODEL", "gpt-4o-mini")
@pytest.mark.parametrize(
"env_var_name",
[
("AZURE_CHAT_COMPLETIONS_HOST_NAME"),
("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME"),
("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION"),
("AZURE_CHAT_COMPLETIONS_KEY"),
],
)
def test_from_env_throw_error_when_missing_env_var(env_var_name):
with patch.dict(
os.environ,
{
"AZURE_CHAT_COMPLETIONS_HOST_NAME": "test_host_name",
"AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME": "test_deployment_name",
"AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION": "test_api_version",
"AZURE_CHAT_COMPLETIONS_KEY": "test_api_key",
},
clear=True,
):
os.environ.pop(env_var_name)
with pytest.raises(MissingProviderEnvVariableError) as context:
AzureProvider.from_env()
assert context.value.provider == "azure"
assert context.value.env_variable == env_var_name
assert context.value.message == f"Missing environment variable: {env_var_name} for provider azure."
@pytest.mark.vcr()
def test_azure_complete(default_azure_env):
reply_message, reply_usage = complete(AzureProvider, AZURE_MODEL)

View File

@@ -5,12 +5,39 @@ from unittest.mock import patch
import pytest
from exchange.content import Text, ToolResult, ToolUse
from exchange.message import Message
from exchange.providers.base import MissingProviderEnvVariableError
from exchange.providers.bedrock import BedrockProvider
from exchange.tool import Tool
logger = logging.getLogger(__name__)
@pytest.mark.parametrize(
"env_var_name",
[
("AWS_ACCESS_KEY_ID"),
("AWS_SECRET_ACCESS_KEY"),
("AWS_SESSION_TOKEN"),
],
)
def test_from_env_throw_error_when_missing_env_var(env_var_name):
with patch.dict(
os.environ,
{
"AWS_ACCESS_KEY_ID": "test_access_key_id",
"AWS_SECRET_ACCESS_KEY": "test_secret_access_key",
"AWS_SESSION_TOKEN": "test_session_token",
},
clear=True,
):
os.environ.pop(env_var_name)
with pytest.raises(MissingProviderEnvVariableError) as context:
BedrockProvider.from_env()
assert context.value.provider == "bedrock"
assert context.value.env_variable == env_var_name
assert context.value.message == f"Missing environment variable: {env_var_name} for provider bedrock."
@pytest.fixture
@patch.dict(
os.environ,

View File

@@ -3,9 +3,35 @@ from unittest.mock import patch
import pytest
from exchange import Message, Text
from exchange.providers.base import MissingProviderEnvVariableError
from exchange.providers.databricks import DatabricksProvider
@pytest.mark.parametrize(
"env_var_name",
[
("DATABRICKS_HOST"),
("DATABRICKS_TOKEN"),
],
)
def test_from_env_throw_error_when_missing_env_var(env_var_name):
with patch.dict(
os.environ,
{
"DATABRICKS_HOST": "test_host",
"DATABRICKS_TOKEN": "test_token",
},
clear=True,
):
os.environ.pop(env_var_name)
with pytest.raises(MissingProviderEnvVariableError) as context:
DatabricksProvider.from_env()
assert context.value.provider == "databricks"
assert context.value.env_variable == env_var_name
assert f"Missing environment variable: {env_var_name} for provider databricks" in context.value.message
assert "https://docs.databricks.com" in context.value.message
@pytest.fixture
@patch.dict(
os.environ,

View File

@@ -5,6 +5,7 @@ import httpx
import pytest
from exchange import Message, Text
from exchange.content import ToolResult, ToolUse
from exchange.providers.base import MissingProviderEnvVariableError
from exchange.providers.google import GoogleProvider
from exchange.tool import Tool
@@ -19,6 +20,16 @@ def example_fn(param: str) -> None:
pass
def test_from_env_throw_error_when_missing_api_key():
with patch.dict(os.environ, {}, clear=True):
with pytest.raises(MissingProviderEnvVariableError) as context:
GoogleProvider.from_env()
assert context.value.provider == "google"
assert context.value.env_variable == "GOOGLE_API_KEY"
assert "Missing environment variable: GOOGLE_API_KEY for provider google" in context.value.message
assert "https://ai.google.dev/gemini-api/docs/api-key" in context.value.message
@pytest.fixture
@patch.dict(os.environ, {"GOOGLE_API_KEY": "test_api_key"})
def google_provider():

View File

@@ -1,14 +1,26 @@
import os
from unittest.mock import patch
import pytest
from exchange import Text, ToolUse
from exchange.providers.base import MissingProviderEnvVariableError
from exchange.providers.openai import OpenAiProvider
from .conftest import complete, vision, tools
OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini")
def test_from_env_throw_error_when_missing_api_key():
with patch.dict(os.environ, {}, clear=True):
with pytest.raises(MissingProviderEnvVariableError) as context:
OpenAiProvider.from_env()
assert context.value.provider == "openai"
assert context.value.env_variable == "OPENAI_API_KEY"
assert "Missing environment variable: OPENAI_API_KEY for provider openai" in context.value.message
assert "https://platform.openai.com" in context.value.message
@pytest.mark.vcr()
def test_openai_complete(default_openai_env):
reply_message, reply_usage = complete(OpenAiProvider, OPENAI_MODEL)

View File

@@ -0,0 +1,18 @@
import pytest
from exchange.invalid_choice_error import InvalidChoiceError
from exchange.providers import get_provider
def test_get_provider_valid():
provider_name = "openai"
provider = get_provider(provider_name)
assert provider.__name__ == "OpenAiProvider"
def test_get_provider_throw_error_for_unknown_provider():
with pytest.raises(InvalidChoiceError) as error:
get_provider("nonexistent")
assert error.value.attribute_name == "provider"
assert error.value.attribute_value == "nonexistent"
assert "openai" in error.value.available_values
assert "openai" in error.value.message

View File

@@ -0,0 +1,27 @@
from exchange.providers.base import MissingProviderEnvVariableError
def test_missing_provider_env_variable_error_without_instructions_url():
env_variable = "API_KEY"
provider = "TestProvider"
error = MissingProviderEnvVariableError(env_variable, provider)
assert error.env_variable == env_variable
assert error.provider == provider
assert error.instructions_url is None
assert error.message == "Missing environment variable: API_KEY for provider TestProvider."
def test_missing_provider_env_variable_error_with_instructions_url():
env_variable = "API_KEY"
provider = "TestProvider"
instructions_url = "http://example.com/instructions"
error = MissingProviderEnvVariableError(env_variable, provider, instructions_url)
assert error.env_variable == env_variable
assert error.provider == provider
assert error.instructions_url == instructions_url
assert error.message == (
"Missing environment variable: API_KEY for provider TestProvider.\n"
"Please see http://example.com/instructions for instructions"
)

View File

@@ -0,0 +1,13 @@
from exchange.invalid_choice_error import InvalidChoiceError
def test_load_invalid_choice_error():
attribute_name = "moderator"
attribute_value = "not_exist"
available_values = ["truncate", "summarizer"]
error = InvalidChoiceError(attribute_name, attribute_value, available_values)
assert error.attribute_name == attribute_name
assert error.attribute_value == attribute_value
assert error.attribute_value == attribute_value
assert error.message == "Unknown moderator: not_exist. Available moderators: truncate, summarizer"

View File

@@ -0,0 +1,17 @@
from exchange.invalid_choice_error import InvalidChoiceError
from exchange.moderators import get_moderator
import pytest
def test_get_moderator():
moderator = get_moderator("truncate")
assert moderator.__name__ == "ContextTruncate"
def test_get_moderator_raise_error_for_unknown_moderator():
with pytest.raises(InvalidChoiceError) as error:
get_moderator("nonexistent")
assert error.value.attribute_name == "moderator"
assert error.value.attribute_value == "nonexistent"
assert "truncate" in error.value.available_values
assert "truncate" in error.value.message

View File

@@ -16,6 +16,7 @@ PROFILES_CONFIG_PATH = GOOSE_GLOBAL_PATH.joinpath("profiles.yaml")
SESSIONS_PATH = GOOSE_GLOBAL_PATH.joinpath("sessions")
SESSION_FILE_SUFFIX = ".jsonl"
LOG_PATH = GOOSE_GLOBAL_PATH.joinpath("logs")
RECOMMENDED_DEFAULT_PROVIDER = "openai"
@cache
@@ -88,11 +89,7 @@ def default_model_configuration() -> Tuple[str, str, str]:
except Exception:
pass
else:
raise ValueError(
"Could not detect an available provider,"
+ " make sure to plugin a provider or set an env var such as OPENAI_API_KEY"
)
provider = RECOMMENDED_DEFAULT_PROVIDER
recommended = {
"ollama": (OLLAMA_MODEL, OLLAMA_MODEL),
"openai": ("gpt-4o", "gpt-4o-mini"),

View File

@@ -1,8 +1,11 @@
import sys
import traceback
from pathlib import Path
from typing import Any, Dict, List, Optional
from exchange import Message, ToolResult, ToolUse, Text
from exchange import Message, ToolResult, ToolUse, Text, Exchange
from exchange.providers.base import MissingProviderEnvVariableError
from exchange.invalid_choice_error import InvalidChoiceError
from rich import print
from rich.console import RenderableType
from rich.live import Live
@@ -11,7 +14,7 @@ from rich.panel import Panel
from rich.status import Status
from goose.build import build_exchange
from goose.cli.config import ensure_config, session_path, LOG_PATH
from goose.cli.config import PROFILES_CONFIG_PATH, ensure_config, session_path, LOG_PATH
from goose._logger import get_logger, setup_logging
from goose.cli.prompt.goose_prompt_session import GoosePromptSession
from goose.notifier import Notifier
@@ -89,8 +92,7 @@ class Session:
self.profile = profile
self.status_indicator = Status("", spinner="dots")
self.notifier = SessionNotifier(self.status_indicator)
self.exchange = build_exchange(profile=load_profile(profile), notifier=self.notifier)
self.exchange = self._create_exchange()
setup_logging(log_file_directory=LOG_PATH, log_level=log_level)
self.exchange.messages.extend(self._get_initial_messages())
@@ -100,6 +102,21 @@ class Session:
self.prompt_session = GoosePromptSession()
def _create_exchange(self) -> Exchange:
try:
return build_exchange(profile=load_profile(self.profile), notifier=self.notifier)
except MissingProviderEnvVariableError as e:
error_message = f"{e.message}. Please set the required environment variable to continue."
print(Panel(error_message, style="red"))
sys.exit(1)
except InvalidChoiceError as e:
error_message = (
f"[bold red]{e.message}[/bold red].\nPlease check your configuration file at {PROFILES_CONFIG_PATH}.\n"
+ "Configuration doc: https://block-open-source.github.io/goose/configuration.html"
)
print(error_message)
sys.exit(1)
def _get_initial_messages(self) -> List[Message]:
messages = self.load_session()

View File

@@ -1,9 +1,12 @@
from functools import cache
from exchange.invalid_choice_error import InvalidChoiceError
from goose.toolkit.base import Toolkit
from goose.utils import load_plugins
@cache
def get_toolkit(name: str) -> type[Toolkit]:
return load_plugins(group="goose.toolkit")[name]
toolkits = load_plugins(group="goose.toolkit")
if name not in toolkits:
raise InvalidChoiceError("toolkit", name, toolkits.keys())
return toolkits[name]

View File

@@ -2,6 +2,8 @@ from unittest.mock import MagicMock, patch
import pytest
from exchange import Exchange, Message, ToolUse, ToolResult
from exchange.providers.base import MissingProviderEnvVariableError
from exchange.invalid_choice_error import InvalidChoiceError
from goose.cli.prompt.goose_prompt_session import GoosePromptSession
from goose.cli.prompt.user_input import PromptAction, UserInput
from goose.cli.session import Session
@@ -19,10 +21,11 @@ def mock_specified_session_name():
@pytest.fixture
def create_session_with_mock_configs(mock_sessions_path, exchange_factory, profile_factory):
with patch("goose.cli.session.build_exchange") as mock_exchange, patch(
"goose.cli.session.load_profile", return_value=profile_factory()
), patch("goose.cli.session.SessionNotifier") as mock_session_notifier, patch(
"goose.cli.session.load_provider", return_value="provider"
with (
patch("goose.cli.session.build_exchange") as mock_exchange,
patch("goose.cli.session.load_profile", return_value=profile_factory()),
patch("goose.cli.session.SessionNotifier") as mock_session_notifier,
patch("goose.cli.session.load_provider", return_value="provider"),
):
mock_session_notifier.return_value = MagicMock()
mock_exchange.return_value = exchange_factory()
@@ -113,9 +116,11 @@ def test_log_log_cost(create_session_with_mock_configs):
session = create_session_with_mock_configs()
mock_logger = MagicMock()
cost_message = "You have used 100 tokens"
with patch("exchange.Exchange.get_token_usage", return_value={}), patch(
"goose.cli.session.get_total_cost_message", return_value=cost_message
), patch("goose.cli.session.get_logger", return_value=mock_logger):
with (
patch("exchange.Exchange.get_token_usage", return_value={}),
patch("goose.cli.session.get_total_cost_message", return_value=cost_message),
patch("goose.cli.session.get_logger", return_value=mock_logger),
):
session._log_cost()
mock_logger.info.assert_called_once_with(cost_message)
@@ -133,9 +138,11 @@ def test_run_should_auto_save_session(create_session_with_mock_configs, mock_ses
]
session = create_session_with_mock_configs({"name": SESSION_NAME})
with patch.object(GoosePromptSession, "get_user_input", side_effect=user_inputs), patch.object(
Exchange, "generate"
) as mock_generate, patch("goose.cli.session.save_latest_session") as mock_save_latest_session:
with (
patch.object(GoosePromptSession, "get_user_input", side_effect=user_inputs),
patch.object(Exchange, "generate") as mock_generate,
patch("goose.cli.session.save_latest_session") as mock_save_latest_session,
):
mock_generate.side_effect = lambda *args, **kwargs: custom_exchange_generate(session.exchange, *args, **kwargs)
session.run()
@@ -151,3 +158,34 @@ def test_set_generated_session_name(create_session_with_mock_configs, mock_sessi
with patch("goose.cli.session.droid", return_value=generated_session_name):
session = create_session_with_mock_configs({"name": None})
assert session.name == generated_session_name
def test_create_exchange_exit_when_env_var_does_not_exist(create_session_with_mock_configs, mock_sessions_path):
session = create_session_with_mock_configs()
expected_error = MissingProviderEnvVariableError(env_variable="OPENAI_API_KEY", provider="openai")
with (
patch("goose.cli.session.build_exchange", side_effect=expected_error),
patch("goose.cli.session.print") as mock_print,
patch("sys.exit") as mock_exit,
):
session._create_exchange()
mock_print.call_args_list[0][0][0].renderable == (
"Missing environment variable OPENAI_API_KEY for provider openai. ",
"Please set the required environment variable to continue.",
)
mock_exit.assert_called_once_with(1)
def test_create_exchange_exit_when_configuration_is_incorrect(create_session_with_mock_configs, mock_sessions_path):
session = create_session_with_mock_configs()
expected_error = InvalidChoiceError(
attribute_name="provider", attribute_value="wrong_provider", available_values=["openai"]
)
with (
patch("goose.cli.session.build_exchange", side_effect=expected_error),
patch("goose.cli.session.print") as mock_print,
patch("sys.exit") as mock_exit,
):
session._create_exchange()
assert "Unknown provider: wrong_provider. Available providers: openai" in mock_print.call_args_list[0][0][0]
mock_exit.assert_called_once_with(1)