diff --git a/src/goose/cli/session.py b/src/goose/cli/session.py index 3678484c..68bdc4ce 100644 --- a/src/goose/cli/session.py +++ b/src/goose/cli/session.py @@ -3,9 +3,9 @@ import traceback from pathlib import Path from typing import Optional -from langfuse.decorators import langfuse_context from exchange import Message, Text, ToolResult, ToolUse -from exchange.langfuse_wrapper import observe_wrapper, auth_check +from exchange.langfuse_wrapper import auth_check, observe_wrapper +from langfuse.decorators import langfuse_context from rich import print from rich.markdown import Markdown from rich.panel import Panel @@ -21,7 +21,7 @@ from goose.profile import Profile from goose.utils import droid, load_plugins from goose.utils._cost_calculator import get_total_cost_message from goose.utils._create_exchange import create_exchange -from goose.utils.session_file import is_empty_session, is_existing_session, read_or_create_file, log_messages +from goose.utils.session_file import is_empty_session, is_existing_session, log_messages, read_or_create_file RESUME_MESSAGE = "I see we were interrupted. How can I help you?" @@ -286,9 +286,15 @@ class Session: print(f"[yellow]Session already exists at {self.session_file_path}.[/]") choice = OverwriteSessionPrompt.ask("Enter your choice", show_choices=False) + # during __init__ we load the previous context, so we need to + # explicitly clear it + self.exchange.messages.clear() + match choice: case "y" | "yes": print("Overwriting existing session") + with open(self.session_file_path, "w") as f: + f.write("") case "n" | "no": while True: @@ -299,7 +305,7 @@ class Session: print(f"[yellow]Session '{new_session_name}' already exists[/]") case "r" | "resume": - self.exchange.messages.extend(self.load_session()) + self.exchange.messages.extend(self._get_initial_messages()) def _remove_empty_session(self) -> bool: """ diff --git a/tests/cli/test_session.py b/tests/cli/test_session.py index c605b524..75ad3376 100644 --- a/tests/cli/test_session.py +++ b/tests/cli/test_session.py @@ -1,8 +1,11 @@ -from unittest.mock import MagicMock, patch +import os +from typing import Union +from unittest.mock import MagicMock, mock_open, patch import pytest from exchange import Message, ToolResult, ToolUse from goose.cli.prompt.goose_prompt_session import GoosePromptSession +from goose.cli.prompt.overwrite_session_prompt import OverwriteSessionPrompt from goose.cli.prompt.user_input import PromptAction, UserInput from goose.cli.session import Session from prompt_toolkit import PromptSession @@ -11,27 +14,70 @@ SPECIFIED_SESSION_NAME = "mySession" SESSION_NAME = "test" -@pytest.fixture -def mock_specified_session_name(): - with patch.object(PromptSession, "prompt", return_value=SPECIFIED_SESSION_NAME) as specified_session_name: - yield specified_session_name +@pytest.fixture(scope="module", autouse=True) +def set_openai_api_key(): + key = "OPENAI_API_KEY" + value = "test_api_key" + + original_api_key = os.environ.get(key) + os.environ[key] = value + + yield + + if original_api_key is None: + os.environ.pop(key, None) + else: + os.environ[key] = original_api_key @pytest.fixture -def create_session_with_mock_configs(mock_sessions_path, exchange_factory, profile_factory): - with ( - patch("goose.cli.session.create_exchange") as mock_exchange, - patch("goose.cli.session.load_profile", return_value=profile_factory()), - patch("goose.cli.session.SessionNotifier") as mock_session_notifier, - patch("goose.cli.session.load_provider", return_value="provider"), +@patch.object(PromptSession, "prompt", return_value=SPECIFIED_SESSION_NAME) +def mock_specified_session_name(specified_session_name): + yield specified_session_name + + +@pytest.fixture +@patch("goose.cli.session.create_exchange", name="mock_exchange") +@patch("goose.cli.session.load_profile", name="mock_load_profile") +@patch("goose.cli.session.SessionNotifier", name="mock_session_notifier") +@patch("goose.cli.session.load_provider", name="mock_load_provider") +def create_session_with_mock_configs( + mock_load_provider, + mock_session_notifier, + mock_load_profile, + mock_exchange, + mock_sessions_path, + exchange_factory, + profile_factory, +): + mock_load_provider.return_value = "provider" + mock_session_notifier.return_value = MagicMock() + mock_load_profile.return_value = profile_factory() + mock_exchange.return_value = exchange_factory() + + def create_session(session_attributes: dict = {}): + return Session(**session_attributes) + + return create_session + + +@pytest.fixture +def session_factory(create_session_with_mock_configs): + def factory( + name=SESSION_NAME, + overwrite_prompt=None, + is_existing_session=None, + get_initial_messages=None, + file_opener=open, ): - mock_session_notifier.return_value = MagicMock() - mock_exchange.return_value = exchange_factory() + session = create_session_with_mock_configs({"name": name}) + session.overwrite_prompt = overwrite_prompt or OverwriteSessionPrompt() + session.is_existing_session = is_existing_session or (lambda _: False) + session._get_initial_messages = get_initial_messages or (lambda: []) + session.file_opener = file_opener + return session - def create_session(session_attributes: dict = {}): - return Session(**session_attributes) - - yield create_session + return factory def test_session_does_not_extend_last_user_text_message_on_init( @@ -123,18 +169,33 @@ def test_log_log_cost(create_session_with_mock_configs): mock_logger.info.assert_called_once_with(cost_message) -@patch("goose.cli.session.droid", return_value="generated_session_name", name="mock_droid") -def test_set_generated_session_name(mock_droid, create_session_with_mock_configs, mock_sessions_path): +@patch("goose.cli.session.droid", return_value="generated_session_name") +@patch("goose.cli.session.load_provider") +def test_set_generated_session_name( + mock_load_provider, mock_droid, create_session_with_mock_configs, mock_sessions_path +): + mock_provider = MagicMock() + mock_load_provider.return_value = mock_provider + session = create_session_with_mock_configs({"name": None}) + assert session.name == "generated_session_name" @patch("goose.cli.session.is_existing_session", name="mock_is_existing") @patch("goose.cli.session.Session._prompt_overwrite_session", name="mock_prompt") -def test_existing_session_prompt(mock_prompt, mock_is_existing, create_session_with_mock_configs): +def test_existing_session_prompt( + mock_prompt, + mock_is_existing, + create_session_with_mock_configs, +): session = create_session_with_mock_configs({"name": SESSION_NAME}) - def check_prompt_behavior(is_existing, new_session, should_prompt): + def check_prompt_behavior( + is_existing: bool, + new_session: Union[bool, None], + should_prompt: bool, + ) -> None: mock_is_existing.return_value = is_existing if new_session is None: session.run() @@ -151,3 +212,48 @@ def test_existing_session_prompt(mock_prompt, mock_is_existing, create_session_w check_prompt_behavior(is_existing=False, new_session=None, should_prompt=False) check_prompt_behavior(is_existing=True, new_session=True, should_prompt=True) check_prompt_behavior(is_existing=False, new_session=False, should_prompt=False) + + +def test_prompt_overwrite_session(session_factory): + def check_overwrite_behavior(choice: str, expected_messages: list[Message]) -> None: + session = session_factory() + + with ( + patch.object(OverwriteSessionPrompt, "ask", return_value=choice), + patch.object(session, "is_existing_session", return_value=True), + patch.object( + session, + "_get_initial_messages", + return_value=[Message.user(text="duck duck"), Message.user(text="goose")], + ), + patch("rich.prompt.Prompt.ask", return_value="new_session_name"), + patch("builtins.open", mock_open()) as mock_file, + ): + session._prompt_overwrite_session() + + if choice in ["y", "yes"]: + mock_file.assert_called_once_with(session.session_file_path, "w") + mock_file().write.assert_called_once_with("") + elif choice in ["n", "no"]: + assert session.name == "new_session_name" + elif choice in ["r", "resume"]: + # this is tested comparing the contents of the array + pass + + # because the messages are created with an id and creation date, we only want to check the text + actual_messages = [message.text for message in session.exchange.messages] + expected_messages = [message.text for message in expected_messages] + assert actual_messages == expected_messages + + check_overwrite_behavior(choice="yes", expected_messages=[]) + check_overwrite_behavior(choice="y", expected_messages=[]) + check_overwrite_behavior(choice="no", expected_messages=[]) + check_overwrite_behavior(choice="n", expected_messages=[]) + check_overwrite_behavior( + choice="resume", + expected_messages=[Message.user(text="duck duck"), Message.user(text="goose")], + ) + check_overwrite_behavior( + choice="r", + expected_messages=[Message.user(text="duck duck"), Message.user(text="goose")], + ) #