mirror of
https://github.com/aljazceru/goose.git
synced 2026-01-01 05:24:24 +01:00
fix: exit the goose and show the error message when provider environment variable is not set (#103)
This commit is contained in:
5
.gitignore
vendored
5
.gitignore
vendored
@@ -121,4 +121,7 @@ celerybeat.pid
|
||||
.vscode
|
||||
|
||||
# Autogenerated docs files
|
||||
docs/docs/reference
|
||||
docs/docs/reference
|
||||
|
||||
# uv lock file
|
||||
uv.lock
|
||||
|
||||
13
packages/exchange/src/exchange/invalid_choice_error.py
Normal file
13
packages/exchange/src/exchange/invalid_choice_error.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from typing import List
|
||||
|
||||
|
||||
class InvalidChoiceError(Exception):
|
||||
def __init__(self, attribute_name: str, attribute_value: str, available_values: List[str]) -> None:
|
||||
self.attribute_name = attribute_name
|
||||
self.attribute_value = attribute_value
|
||||
self.available_values = available_values
|
||||
self.message = (
|
||||
f"Unknown {attribute_name}: {attribute_value}."
|
||||
+ f" Available {attribute_name}s: {', '.join(available_values)}"
|
||||
)
|
||||
super().__init__(self.message)
|
||||
@@ -1,6 +1,7 @@
|
||||
from functools import cache
|
||||
from typing import Type
|
||||
|
||||
from exchange.invalid_choice_error import InvalidChoiceError
|
||||
from exchange.moderators.base import Moderator
|
||||
from exchange.utils import load_plugins
|
||||
from exchange.moderators.passive import PassiveModerator # noqa
|
||||
@@ -10,4 +11,7 @@ from exchange.moderators.summarizer import ContextSummarizer # noqa
|
||||
|
||||
@cache
|
||||
def get_moderator(name: str) -> Type[Moderator]:
|
||||
return load_plugins(group="exchange.moderator")[name]
|
||||
moderators = load_plugins(group="exchange.moderator")
|
||||
if name not in moderators:
|
||||
raise InvalidChoiceError("moderator", name, moderators.keys())
|
||||
return moderators[name]
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from functools import cache
|
||||
from typing import Type
|
||||
|
||||
from exchange.invalid_choice_error import InvalidChoiceError
|
||||
from exchange.providers.anthropic import AnthropicProvider # noqa
|
||||
from exchange.providers.base import Provider, Usage # noqa
|
||||
from exchange.providers.databricks import DatabricksProvider # noqa
|
||||
@@ -14,4 +15,7 @@ from exchange.utils import load_plugins
|
||||
|
||||
@cache
|
||||
def get_provider(name: str) -> Type[Provider]:
|
||||
return load_plugins(group="exchange.provider")[name]
|
||||
providers = load_plugins(group="exchange.provider")
|
||||
if name not in providers:
|
||||
raise InvalidChoiceError("provider", name, providers.keys())
|
||||
return providers[name]
|
||||
|
||||
@@ -7,8 +7,7 @@ from exchange import Message, Tool
|
||||
from exchange.content import Text, ToolResult, ToolUse
|
||||
from exchange.providers.base import Provider, Usage
|
||||
from tenacity import retry, wait_fixed, stop_after_attempt
|
||||
from exchange.providers.utils import retry_if_status
|
||||
from exchange.providers.utils import raise_for_status
|
||||
from exchange.providers.utils import get_provider_env_value, retry_if_status, raise_for_status
|
||||
|
||||
ANTHROPIC_HOST = "https://api.anthropic.com/v1/messages"
|
||||
|
||||
@@ -27,10 +26,7 @@ class AnthropicProvider(Provider):
|
||||
@classmethod
|
||||
def from_env(cls: Type["AnthropicProvider"]) -> "AnthropicProvider":
|
||||
url = os.environ.get("ANTHROPIC_HOST", ANTHROPIC_HOST)
|
||||
try:
|
||||
key = os.environ["ANTHROPIC_API_KEY"]
|
||||
except KeyError:
|
||||
raise RuntimeError("Failed to get ANTHROPIC_API_KEY from the environment")
|
||||
key = get_provider_env_value("ANTHROPIC_API_KEY", "anthropic")
|
||||
client = httpx.Client(
|
||||
base_url=url,
|
||||
headers={
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import os
|
||||
from typing import Type
|
||||
|
||||
import httpx
|
||||
|
||||
from exchange.providers import OpenAiProvider
|
||||
from exchange.providers.utils import get_provider_env_value
|
||||
|
||||
PROVIDER_NAME = "azure"
|
||||
|
||||
|
||||
class AzureProvider(OpenAiProvider):
|
||||
@@ -14,25 +16,11 @@ class AzureProvider(OpenAiProvider):
|
||||
|
||||
@classmethod
|
||||
def from_env(cls: Type["AzureProvider"]) -> "AzureProvider":
|
||||
try:
|
||||
url = os.environ["AZURE_CHAT_COMPLETIONS_HOST_NAME"]
|
||||
except KeyError:
|
||||
raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_HOST_NAME from the environment.")
|
||||
url = get_provider_env_value("AZURE_CHAT_COMPLETIONS_HOST_NAME", PROVIDER_NAME)
|
||||
deployment_name = get_provider_env_value("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME", PROVIDER_NAME)
|
||||
|
||||
try:
|
||||
deployment_name = os.environ["AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME"]
|
||||
except KeyError:
|
||||
raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME from the environment.")
|
||||
|
||||
try:
|
||||
api_version = os.environ["AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION"]
|
||||
except KeyError:
|
||||
raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION from the environment.")
|
||||
|
||||
try:
|
||||
key = os.environ["AZURE_CHAT_COMPLETIONS_KEY"]
|
||||
except KeyError:
|
||||
raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_KEY from the environment.")
|
||||
api_version = get_provider_env_value("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION", PROVIDER_NAME)
|
||||
key = get_provider_env_value("AZURE_CHAT_COMPLETIONS_KEY", PROVIDER_NAME)
|
||||
|
||||
# format the url host/"openai/deployments/" + deployment_name + "/?api-version=" + api_version
|
||||
url = f"{url}/openai/deployments/{deployment_name}/"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from attrs import define, field
|
||||
from typing import List, Tuple, Type
|
||||
from typing import List, Optional, Tuple, Type
|
||||
|
||||
from exchange.message import Message
|
||||
from exchange.tool import Tool
|
||||
@@ -28,3 +28,14 @@ class Provider(ABC):
|
||||
) -> Tuple[Message, Usage]:
|
||||
"""Generate the next message using the specified model"""
|
||||
pass
|
||||
|
||||
|
||||
class MissingProviderEnvVariableError(Exception):
|
||||
def __init__(self, env_variable: str, provider: str, instructions_url: Optional[str] = None) -> None:
|
||||
self.env_variable = env_variable
|
||||
self.provider = provider
|
||||
self.instructions_url = instructions_url
|
||||
self.message = f"Missing environment variable: {env_variable} for provider {provider}."
|
||||
if instructions_url:
|
||||
self.message += f"\nPlease see {instructions_url} for instructions"
|
||||
super().__init__(self.message)
|
||||
|
||||
@@ -13,8 +13,7 @@ from exchange.content import Text, ToolResult, ToolUse
|
||||
from exchange.message import Message
|
||||
from exchange.providers import Provider, Usage
|
||||
from tenacity import retry, wait_fixed, stop_after_attempt
|
||||
from exchange.providers.utils import retry_if_status
|
||||
from exchange.providers.utils import raise_for_status
|
||||
from exchange.providers.utils import get_provider_env_value, raise_for_status, retry_if_status
|
||||
from exchange.tool import Tool
|
||||
|
||||
SERVICE = "bedrock-runtime"
|
||||
@@ -147,6 +146,9 @@ class AwsClient(httpx.Client):
|
||||
return headers
|
||||
|
||||
|
||||
PROVIDER_NAME = "bedrock"
|
||||
|
||||
|
||||
class BedrockProvider(Provider):
|
||||
def __init__(self, client: AwsClient) -> None:
|
||||
self.client = client
|
||||
@@ -154,12 +156,9 @@ class BedrockProvider(Provider):
|
||||
@classmethod
|
||||
def from_env(cls: Type["BedrockProvider"]) -> "BedrockProvider":
|
||||
aws_region = os.environ.get("AWS_REGION", "us-east-1")
|
||||
try:
|
||||
aws_access_key = os.environ["AWS_ACCESS_KEY_ID"]
|
||||
aws_secret_key = os.environ["AWS_SECRET_ACCESS_KEY"]
|
||||
aws_session_token = os.environ.get("AWS_SESSION_TOKEN")
|
||||
except KeyError:
|
||||
raise RuntimeError("Failed to get AWS_ACCESS_KEY_ID or AWS_SECRET_ACCESS_KEY from the environment")
|
||||
aws_access_key = get_provider_env_value("AWS_ACCESS_KEY_ID", PROVIDER_NAME)
|
||||
aws_secret_key = get_provider_env_value("AWS_SECRET_ACCESS_KEY", PROVIDER_NAME)
|
||||
aws_session_token = get_provider_env_value("AWS_SESSION_TOKEN", PROVIDER_NAME)
|
||||
|
||||
client = AwsClient(
|
||||
aws_region=aws_region,
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import os
|
||||
from typing import Any, Dict, List, Tuple, Type
|
||||
|
||||
import httpx
|
||||
@@ -6,7 +5,7 @@ import httpx
|
||||
from exchange.message import Message
|
||||
from exchange.providers.base import Provider, Usage
|
||||
from tenacity import retry, wait_fixed, stop_after_attempt
|
||||
from exchange.providers.utils import raise_for_status, retry_if_status
|
||||
from exchange.providers.utils import get_provider_env_value, raise_for_status, retry_if_status
|
||||
from exchange.providers.utils import (
|
||||
messages_to_openai_spec,
|
||||
openai_response_to_message,
|
||||
@@ -37,18 +36,8 @@ class DatabricksProvider(Provider):
|
||||
|
||||
@classmethod
|
||||
def from_env(cls: Type["DatabricksProvider"]) -> "DatabricksProvider":
|
||||
try:
|
||||
url = os.environ["DATABRICKS_HOST"]
|
||||
except KeyError:
|
||||
raise RuntimeError(
|
||||
"Failed to get DATABRICKS_HOST from the environment. See https://docs.databricks.com/en/dev-tools/auth/index.html#general-host-token-and-account-id-environment-variables-and-fields"
|
||||
)
|
||||
try:
|
||||
key = os.environ["DATABRICKS_TOKEN"]
|
||||
except KeyError:
|
||||
raise RuntimeError(
|
||||
"Failed to get DATABRICKS_TOKEN from the environment. See https://docs.databricks.com/en/dev-tools/auth/index.html#general-host-token-and-account-id-environment-variables-and-fields"
|
||||
)
|
||||
url = cls._get_env_variable("DATABRICKS_HOST")
|
||||
key = cls._get_env_variable("DATABRICKS_TOKEN")
|
||||
client = httpx.Client(
|
||||
base_url=url,
|
||||
auth=("token", key),
|
||||
@@ -100,3 +89,8 @@ class DatabricksProvider(Provider):
|
||||
json=payload,
|
||||
)
|
||||
return raise_for_status(response).json()
|
||||
|
||||
@classmethod
|
||||
def _get_env_variable(cls: Type["DatabricksProvider"], key: str) -> str:
|
||||
instruction = "https://docs.databricks.com/en/dev-tools/auth/index.html#general-host-token-and-account-id-environment-variables-and-fields"
|
||||
return get_provider_env_value(key, "databricks", instruction)
|
||||
|
||||
@@ -7,8 +7,7 @@ from exchange import Message, Tool
|
||||
from exchange.content import Text, ToolResult, ToolUse
|
||||
from exchange.providers.base import Provider, Usage
|
||||
from tenacity import retry, wait_fixed, stop_after_attempt
|
||||
from exchange.providers.utils import retry_if_status
|
||||
from exchange.providers.utils import raise_for_status
|
||||
from exchange.providers.utils import get_provider_env_value, raise_for_status, retry_if_status
|
||||
|
||||
GOOGLE_HOST = "https://generativelanguage.googleapis.com/v1beta"
|
||||
|
||||
@@ -27,12 +26,8 @@ class GoogleProvider(Provider):
|
||||
@classmethod
|
||||
def from_env(cls: Type["GoogleProvider"]) -> "GoogleProvider":
|
||||
url = os.environ.get("GOOGLE_HOST", GOOGLE_HOST)
|
||||
try:
|
||||
key = os.environ["GOOGLE_API_KEY"]
|
||||
except KeyError:
|
||||
raise RuntimeError(
|
||||
"Failed to get GOOGLE_API_KEY from the environment, see https://ai.google.dev/gemini-api/docs/api-key"
|
||||
)
|
||||
api_key_instructions_url = "https://ai.google.dev/gemini-api/docs/api-key"
|
||||
key = get_provider_env_value("GOOGLE_API_KEY", "google", api_key_instructions_url)
|
||||
|
||||
client = httpx.Client(
|
||||
base_url=url,
|
||||
|
||||
@@ -6,6 +6,7 @@ import httpx
|
||||
from exchange.message import Message
|
||||
from exchange.providers.base import Provider, Usage
|
||||
from exchange.providers.utils import (
|
||||
get_provider_env_value,
|
||||
messages_to_openai_spec,
|
||||
openai_response_to_message,
|
||||
openai_single_message_context_length_exceeded,
|
||||
@@ -36,12 +37,8 @@ class OpenAiProvider(Provider):
|
||||
@classmethod
|
||||
def from_env(cls: Type["OpenAiProvider"]) -> "OpenAiProvider":
|
||||
url = os.environ.get("OPENAI_HOST", OPENAI_HOST)
|
||||
try:
|
||||
key = os.environ["OPENAI_API_KEY"]
|
||||
except KeyError:
|
||||
raise RuntimeError(
|
||||
"Failed to get OPENAI_API_KEY from the environment, see https://platform.openai.com/docs/api-reference/api-keys"
|
||||
)
|
||||
api_key_instructions_url = "https://platform.openai.com/docs/api-reference/api-keys"
|
||||
key = get_provider_env_value("OPENAI_API_KEY", "openai", api_key_instructions_url)
|
||||
client = httpx.Client(
|
||||
base_url=url + "v1/",
|
||||
auth=("Bearer", key),
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import httpx
|
||||
from exchange.content import Text, ToolResult, ToolUse
|
||||
from exchange.message import Message
|
||||
from exchange.providers.base import MissingProviderEnvVariableError
|
||||
from exchange.tool import Tool
|
||||
from tenacity import retry_if_exception
|
||||
|
||||
@@ -179,6 +181,13 @@ def openai_single_message_context_length_exceeded(error_dict: dict) -> None:
|
||||
raise InitialMessageTooLargeError(f"Input message too long. Message: {error_dict.get('message')}")
|
||||
|
||||
|
||||
def get_provider_env_value(env_variable: str, provider: str, instructions_url: Optional[str] = None) -> str:
|
||||
try:
|
||||
return os.environ[env_variable]
|
||||
except KeyError:
|
||||
raise MissingProviderEnvVariableError(env_variable, provider, instructions_url)
|
||||
|
||||
|
||||
class InitialMessageTooLargeError(Exception):
|
||||
"""Custom error raised when the first input message in an exchange is too large."""
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import pytest
|
||||
from exchange import Message, Text
|
||||
from exchange.content import ToolResult, ToolUse
|
||||
from exchange.providers.anthropic import AnthropicProvider
|
||||
from exchange.providers.base import MissingProviderEnvVariableError
|
||||
from exchange.tool import Tool
|
||||
|
||||
|
||||
@@ -25,6 +26,15 @@ def anthropic_provider():
|
||||
return AnthropicProvider.from_env()
|
||||
|
||||
|
||||
def test_from_env_throw_error_when_missing_api_key():
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
with pytest.raises(MissingProviderEnvVariableError) as context:
|
||||
AnthropicProvider.from_env()
|
||||
assert context.value.provider == "anthropic"
|
||||
assert context.value.env_variable == "ANTHROPIC_API_KEY"
|
||||
assert context.value.message == "Missing environment variable: ANTHROPIC_API_KEY for provider anthropic."
|
||||
|
||||
|
||||
def test_anthropic_response_to_text_message() -> None:
|
||||
response = {
|
||||
"content": [{"type": "text", "text": "Hello from Claude!"}],
|
||||
|
||||
@@ -1,14 +1,44 @@
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from exchange import Text, ToolUse
|
||||
from exchange.providers.azure import AzureProvider
|
||||
from exchange.providers.base import MissingProviderEnvVariableError
|
||||
from .conftest import complete, tools
|
||||
|
||||
AZURE_MODEL = os.getenv("AZURE_MODEL", "gpt-4o-mini")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env_var_name",
|
||||
[
|
||||
("AZURE_CHAT_COMPLETIONS_HOST_NAME"),
|
||||
("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME"),
|
||||
("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION"),
|
||||
("AZURE_CHAT_COMPLETIONS_KEY"),
|
||||
],
|
||||
)
|
||||
def test_from_env_throw_error_when_missing_env_var(env_var_name):
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"AZURE_CHAT_COMPLETIONS_HOST_NAME": "test_host_name",
|
||||
"AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME": "test_deployment_name",
|
||||
"AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION": "test_api_version",
|
||||
"AZURE_CHAT_COMPLETIONS_KEY": "test_api_key",
|
||||
},
|
||||
clear=True,
|
||||
):
|
||||
os.environ.pop(env_var_name)
|
||||
with pytest.raises(MissingProviderEnvVariableError) as context:
|
||||
AzureProvider.from_env()
|
||||
assert context.value.provider == "azure"
|
||||
assert context.value.env_variable == env_var_name
|
||||
assert context.value.message == f"Missing environment variable: {env_var_name} for provider azure."
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_azure_complete(default_azure_env):
|
||||
reply_message, reply_usage = complete(AzureProvider, AZURE_MODEL)
|
||||
|
||||
@@ -5,12 +5,39 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
from exchange.content import Text, ToolResult, ToolUse
|
||||
from exchange.message import Message
|
||||
from exchange.providers.base import MissingProviderEnvVariableError
|
||||
from exchange.providers.bedrock import BedrockProvider
|
||||
from exchange.tool import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env_var_name",
|
||||
[
|
||||
("AWS_ACCESS_KEY_ID"),
|
||||
("AWS_SECRET_ACCESS_KEY"),
|
||||
("AWS_SESSION_TOKEN"),
|
||||
],
|
||||
)
|
||||
def test_from_env_throw_error_when_missing_env_var(env_var_name):
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"AWS_ACCESS_KEY_ID": "test_access_key_id",
|
||||
"AWS_SECRET_ACCESS_KEY": "test_secret_access_key",
|
||||
"AWS_SESSION_TOKEN": "test_session_token",
|
||||
},
|
||||
clear=True,
|
||||
):
|
||||
os.environ.pop(env_var_name)
|
||||
with pytest.raises(MissingProviderEnvVariableError) as context:
|
||||
BedrockProvider.from_env()
|
||||
assert context.value.provider == "bedrock"
|
||||
assert context.value.env_variable == env_var_name
|
||||
assert context.value.message == f"Missing environment variable: {env_var_name} for provider bedrock."
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
|
||||
@@ -3,9 +3,35 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from exchange import Message, Text
|
||||
from exchange.providers.base import MissingProviderEnvVariableError
|
||||
from exchange.providers.databricks import DatabricksProvider
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env_var_name",
|
||||
[
|
||||
("DATABRICKS_HOST"),
|
||||
("DATABRICKS_TOKEN"),
|
||||
],
|
||||
)
|
||||
def test_from_env_throw_error_when_missing_env_var(env_var_name):
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"DATABRICKS_HOST": "test_host",
|
||||
"DATABRICKS_TOKEN": "test_token",
|
||||
},
|
||||
clear=True,
|
||||
):
|
||||
os.environ.pop(env_var_name)
|
||||
with pytest.raises(MissingProviderEnvVariableError) as context:
|
||||
DatabricksProvider.from_env()
|
||||
assert context.value.provider == "databricks"
|
||||
assert context.value.env_variable == env_var_name
|
||||
assert f"Missing environment variable: {env_var_name} for provider databricks" in context.value.message
|
||||
assert "https://docs.databricks.com" in context.value.message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
|
||||
@@ -5,6 +5,7 @@ import httpx
|
||||
import pytest
|
||||
from exchange import Message, Text
|
||||
from exchange.content import ToolResult, ToolUse
|
||||
from exchange.providers.base import MissingProviderEnvVariableError
|
||||
from exchange.providers.google import GoogleProvider
|
||||
from exchange.tool import Tool
|
||||
|
||||
@@ -19,6 +20,16 @@ def example_fn(param: str) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def test_from_env_throw_error_when_missing_api_key():
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
with pytest.raises(MissingProviderEnvVariableError) as context:
|
||||
GoogleProvider.from_env()
|
||||
assert context.value.provider == "google"
|
||||
assert context.value.env_variable == "GOOGLE_API_KEY"
|
||||
assert "Missing environment variable: GOOGLE_API_KEY for provider google" in context.value.message
|
||||
assert "https://ai.google.dev/gemini-api/docs/api-key" in context.value.message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@patch.dict(os.environ, {"GOOGLE_API_KEY": "test_api_key"})
|
||||
def google_provider():
|
||||
|
||||
@@ -1,14 +1,26 @@
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from exchange import Text, ToolUse
|
||||
from exchange.providers.base import MissingProviderEnvVariableError
|
||||
from exchange.providers.openai import OpenAiProvider
|
||||
from .conftest import complete, vision, tools
|
||||
|
||||
OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini")
|
||||
|
||||
|
||||
def test_from_env_throw_error_when_missing_api_key():
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
with pytest.raises(MissingProviderEnvVariableError) as context:
|
||||
OpenAiProvider.from_env()
|
||||
assert context.value.provider == "openai"
|
||||
assert context.value.env_variable == "OPENAI_API_KEY"
|
||||
assert "Missing environment variable: OPENAI_API_KEY for provider openai" in context.value.message
|
||||
assert "https://platform.openai.com" in context.value.message
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_openai_complete(default_openai_env):
|
||||
reply_message, reply_usage = complete(OpenAiProvider, OPENAI_MODEL)
|
||||
|
||||
18
packages/exchange/tests/providers/test_provider.py
Normal file
18
packages/exchange/tests/providers/test_provider.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import pytest
|
||||
from exchange.invalid_choice_error import InvalidChoiceError
|
||||
from exchange.providers import get_provider
|
||||
|
||||
|
||||
def test_get_provider_valid():
|
||||
provider_name = "openai"
|
||||
provider = get_provider(provider_name)
|
||||
assert provider.__name__ == "OpenAiProvider"
|
||||
|
||||
|
||||
def test_get_provider_throw_error_for_unknown_provider():
|
||||
with pytest.raises(InvalidChoiceError) as error:
|
||||
get_provider("nonexistent")
|
||||
assert error.value.attribute_name == "provider"
|
||||
assert error.value.attribute_value == "nonexistent"
|
||||
assert "openai" in error.value.available_values
|
||||
assert "openai" in error.value.message
|
||||
27
packages/exchange/tests/test_base.py
Normal file
27
packages/exchange/tests/test_base.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from exchange.providers.base import MissingProviderEnvVariableError
|
||||
|
||||
|
||||
def test_missing_provider_env_variable_error_without_instructions_url():
|
||||
env_variable = "API_KEY"
|
||||
provider = "TestProvider"
|
||||
error = MissingProviderEnvVariableError(env_variable, provider)
|
||||
|
||||
assert error.env_variable == env_variable
|
||||
assert error.provider == provider
|
||||
assert error.instructions_url is None
|
||||
assert error.message == "Missing environment variable: API_KEY for provider TestProvider."
|
||||
|
||||
|
||||
def test_missing_provider_env_variable_error_with_instructions_url():
|
||||
env_variable = "API_KEY"
|
||||
provider = "TestProvider"
|
||||
instructions_url = "http://example.com/instructions"
|
||||
error = MissingProviderEnvVariableError(env_variable, provider, instructions_url)
|
||||
|
||||
assert error.env_variable == env_variable
|
||||
assert error.provider == provider
|
||||
assert error.instructions_url == instructions_url
|
||||
assert error.message == (
|
||||
"Missing environment variable: API_KEY for provider TestProvider.\n"
|
||||
"Please see http://example.com/instructions for instructions"
|
||||
)
|
||||
13
packages/exchange/tests/test_invalid_choice_error.py
Normal file
13
packages/exchange/tests/test_invalid_choice_error.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from exchange.invalid_choice_error import InvalidChoiceError
|
||||
|
||||
|
||||
def test_load_invalid_choice_error():
|
||||
attribute_name = "moderator"
|
||||
attribute_value = "not_exist"
|
||||
available_values = ["truncate", "summarizer"]
|
||||
error = InvalidChoiceError(attribute_name, attribute_value, available_values)
|
||||
|
||||
assert error.attribute_name == attribute_name
|
||||
assert error.attribute_value == attribute_value
|
||||
assert error.attribute_value == attribute_value
|
||||
assert error.message == "Unknown moderator: not_exist. Available moderators: truncate, summarizer"
|
||||
17
packages/exchange/tests/test_moderators.py
Normal file
17
packages/exchange/tests/test_moderators.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from exchange.invalid_choice_error import InvalidChoiceError
|
||||
from exchange.moderators import get_moderator
|
||||
import pytest
|
||||
|
||||
|
||||
def test_get_moderator():
|
||||
moderator = get_moderator("truncate")
|
||||
assert moderator.__name__ == "ContextTruncate"
|
||||
|
||||
|
||||
def test_get_moderator_raise_error_for_unknown_moderator():
|
||||
with pytest.raises(InvalidChoiceError) as error:
|
||||
get_moderator("nonexistent")
|
||||
assert error.value.attribute_name == "moderator"
|
||||
assert error.value.attribute_value == "nonexistent"
|
||||
assert "truncate" in error.value.available_values
|
||||
assert "truncate" in error.value.message
|
||||
@@ -16,6 +16,7 @@ PROFILES_CONFIG_PATH = GOOSE_GLOBAL_PATH.joinpath("profiles.yaml")
|
||||
SESSIONS_PATH = GOOSE_GLOBAL_PATH.joinpath("sessions")
|
||||
SESSION_FILE_SUFFIX = ".jsonl"
|
||||
LOG_PATH = GOOSE_GLOBAL_PATH.joinpath("logs")
|
||||
RECOMMENDED_DEFAULT_PROVIDER = "openai"
|
||||
|
||||
|
||||
@cache
|
||||
@@ -88,11 +89,7 @@ def default_model_configuration() -> Tuple[str, str, str]:
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(
|
||||
"Could not detect an available provider,"
|
||||
+ " make sure to plugin a provider or set an env var such as OPENAI_API_KEY"
|
||||
)
|
||||
|
||||
provider = RECOMMENDED_DEFAULT_PROVIDER
|
||||
recommended = {
|
||||
"ollama": (OLLAMA_MODEL, OLLAMA_MODEL),
|
||||
"openai": ("gpt-4o", "gpt-4o-mini"),
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
import sys
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from exchange import Message, ToolResult, ToolUse, Text
|
||||
from exchange import Message, ToolResult, ToolUse, Text, Exchange
|
||||
from exchange.providers.base import MissingProviderEnvVariableError
|
||||
from exchange.invalid_choice_error import InvalidChoiceError
|
||||
from rich import print
|
||||
from rich.console import RenderableType
|
||||
from rich.live import Live
|
||||
@@ -11,7 +14,7 @@ from rich.panel import Panel
|
||||
from rich.status import Status
|
||||
|
||||
from goose.build import build_exchange
|
||||
from goose.cli.config import ensure_config, session_path, LOG_PATH
|
||||
from goose.cli.config import PROFILES_CONFIG_PATH, ensure_config, session_path, LOG_PATH
|
||||
from goose._logger import get_logger, setup_logging
|
||||
from goose.cli.prompt.goose_prompt_session import GoosePromptSession
|
||||
from goose.notifier import Notifier
|
||||
@@ -89,8 +92,7 @@ class Session:
|
||||
self.profile = profile
|
||||
self.status_indicator = Status("", spinner="dots")
|
||||
self.notifier = SessionNotifier(self.status_indicator)
|
||||
|
||||
self.exchange = build_exchange(profile=load_profile(profile), notifier=self.notifier)
|
||||
self.exchange = self._create_exchange()
|
||||
setup_logging(log_file_directory=LOG_PATH, log_level=log_level)
|
||||
|
||||
self.exchange.messages.extend(self._get_initial_messages())
|
||||
@@ -100,6 +102,21 @@ class Session:
|
||||
|
||||
self.prompt_session = GoosePromptSession()
|
||||
|
||||
def _create_exchange(self) -> Exchange:
|
||||
try:
|
||||
return build_exchange(profile=load_profile(self.profile), notifier=self.notifier)
|
||||
except MissingProviderEnvVariableError as e:
|
||||
error_message = f"{e.message}. Please set the required environment variable to continue."
|
||||
print(Panel(error_message, style="red"))
|
||||
sys.exit(1)
|
||||
except InvalidChoiceError as e:
|
||||
error_message = (
|
||||
f"[bold red]{e.message}[/bold red].\nPlease check your configuration file at {PROFILES_CONFIG_PATH}.\n"
|
||||
+ "Configuration doc: https://block-open-source.github.io/goose/configuration.html"
|
||||
)
|
||||
print(error_message)
|
||||
sys.exit(1)
|
||||
|
||||
def _get_initial_messages(self) -> List[Message]:
|
||||
messages = self.load_session()
|
||||
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
from functools import cache
|
||||
|
||||
from exchange.invalid_choice_error import InvalidChoiceError
|
||||
from goose.toolkit.base import Toolkit
|
||||
from goose.utils import load_plugins
|
||||
|
||||
|
||||
@cache
|
||||
def get_toolkit(name: str) -> type[Toolkit]:
|
||||
return load_plugins(group="goose.toolkit")[name]
|
||||
toolkits = load_plugins(group="goose.toolkit")
|
||||
if name not in toolkits:
|
||||
raise InvalidChoiceError("toolkit", name, toolkits.keys())
|
||||
return toolkits[name]
|
||||
|
||||
@@ -2,6 +2,8 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from exchange import Exchange, Message, ToolUse, ToolResult
|
||||
from exchange.providers.base import MissingProviderEnvVariableError
|
||||
from exchange.invalid_choice_error import InvalidChoiceError
|
||||
from goose.cli.prompt.goose_prompt_session import GoosePromptSession
|
||||
from goose.cli.prompt.user_input import PromptAction, UserInput
|
||||
from goose.cli.session import Session
|
||||
@@ -19,10 +21,11 @@ def mock_specified_session_name():
|
||||
|
||||
@pytest.fixture
|
||||
def create_session_with_mock_configs(mock_sessions_path, exchange_factory, profile_factory):
|
||||
with patch("goose.cli.session.build_exchange") as mock_exchange, patch(
|
||||
"goose.cli.session.load_profile", return_value=profile_factory()
|
||||
), patch("goose.cli.session.SessionNotifier") as mock_session_notifier, patch(
|
||||
"goose.cli.session.load_provider", return_value="provider"
|
||||
with (
|
||||
patch("goose.cli.session.build_exchange") as mock_exchange,
|
||||
patch("goose.cli.session.load_profile", return_value=profile_factory()),
|
||||
patch("goose.cli.session.SessionNotifier") as mock_session_notifier,
|
||||
patch("goose.cli.session.load_provider", return_value="provider"),
|
||||
):
|
||||
mock_session_notifier.return_value = MagicMock()
|
||||
mock_exchange.return_value = exchange_factory()
|
||||
@@ -113,9 +116,11 @@ def test_log_log_cost(create_session_with_mock_configs):
|
||||
session = create_session_with_mock_configs()
|
||||
mock_logger = MagicMock()
|
||||
cost_message = "You have used 100 tokens"
|
||||
with patch("exchange.Exchange.get_token_usage", return_value={}), patch(
|
||||
"goose.cli.session.get_total_cost_message", return_value=cost_message
|
||||
), patch("goose.cli.session.get_logger", return_value=mock_logger):
|
||||
with (
|
||||
patch("exchange.Exchange.get_token_usage", return_value={}),
|
||||
patch("goose.cli.session.get_total_cost_message", return_value=cost_message),
|
||||
patch("goose.cli.session.get_logger", return_value=mock_logger),
|
||||
):
|
||||
session._log_cost()
|
||||
mock_logger.info.assert_called_once_with(cost_message)
|
||||
|
||||
@@ -133,9 +138,11 @@ def test_run_should_auto_save_session(create_session_with_mock_configs, mock_ses
|
||||
]
|
||||
|
||||
session = create_session_with_mock_configs({"name": SESSION_NAME})
|
||||
with patch.object(GoosePromptSession, "get_user_input", side_effect=user_inputs), patch.object(
|
||||
Exchange, "generate"
|
||||
) as mock_generate, patch("goose.cli.session.save_latest_session") as mock_save_latest_session:
|
||||
with (
|
||||
patch.object(GoosePromptSession, "get_user_input", side_effect=user_inputs),
|
||||
patch.object(Exchange, "generate") as mock_generate,
|
||||
patch("goose.cli.session.save_latest_session") as mock_save_latest_session,
|
||||
):
|
||||
mock_generate.side_effect = lambda *args, **kwargs: custom_exchange_generate(session.exchange, *args, **kwargs)
|
||||
session.run()
|
||||
|
||||
@@ -151,3 +158,34 @@ def test_set_generated_session_name(create_session_with_mock_configs, mock_sessi
|
||||
with patch("goose.cli.session.droid", return_value=generated_session_name):
|
||||
session = create_session_with_mock_configs({"name": None})
|
||||
assert session.name == generated_session_name
|
||||
|
||||
|
||||
def test_create_exchange_exit_when_env_var_does_not_exist(create_session_with_mock_configs, mock_sessions_path):
|
||||
session = create_session_with_mock_configs()
|
||||
expected_error = MissingProviderEnvVariableError(env_variable="OPENAI_API_KEY", provider="openai")
|
||||
with (
|
||||
patch("goose.cli.session.build_exchange", side_effect=expected_error),
|
||||
patch("goose.cli.session.print") as mock_print,
|
||||
patch("sys.exit") as mock_exit,
|
||||
):
|
||||
session._create_exchange()
|
||||
mock_print.call_args_list[0][0][0].renderable == (
|
||||
"Missing environment variable OPENAI_API_KEY for provider openai. ",
|
||||
"Please set the required environment variable to continue.",
|
||||
)
|
||||
mock_exit.assert_called_once_with(1)
|
||||
|
||||
|
||||
def test_create_exchange_exit_when_configuration_is_incorrect(create_session_with_mock_configs, mock_sessions_path):
|
||||
session = create_session_with_mock_configs()
|
||||
expected_error = InvalidChoiceError(
|
||||
attribute_name="provider", attribute_value="wrong_provider", available_values=["openai"]
|
||||
)
|
||||
with (
|
||||
patch("goose.cli.session.build_exchange", side_effect=expected_error),
|
||||
patch("goose.cli.session.print") as mock_print,
|
||||
patch("sys.exit") as mock_exit,
|
||||
):
|
||||
session._create_exchange()
|
||||
assert "Unknown provider: wrong_provider. Available providers: openai" in mock_print.call_args_list[0][0][0]
|
||||
mock_exit.assert_called_once_with(1)
|
||||
|
||||
Reference in New Issue
Block a user