From 4375e2fe5ea5647dc59809b2f829150ebd9c7301 Mon Sep 17 00:00:00 2001 From: Lam Chau Date: Thu, 10 Oct 2024 05:01:04 -0700 Subject: [PATCH] feat: add guards to session management (#101) --- .../cli/prompt/overwrite_session_prompt.py | 30 +++++++ src/goose/cli/session.py | 71 ++++++++++++--- src/goose/utils/session_file.py | 10 ++- tests/cli/test_session.py | 86 ++++++++++++++----- tests/utils/test_session_file.py | 21 +++++ 5 files changed, 184 insertions(+), 34 deletions(-) create mode 100644 src/goose/cli/prompt/overwrite_session_prompt.py diff --git a/src/goose/cli/prompt/overwrite_session_prompt.py b/src/goose/cli/prompt/overwrite_session_prompt.py new file mode 100644 index 00000000..0f4db693 --- /dev/null +++ b/src/goose/cli/prompt/overwrite_session_prompt.py @@ -0,0 +1,30 @@ +from typing import Any + +from rich.prompt import Prompt + + +class OverwriteSessionPrompt(Prompt): + def __init__(self, *args: tuple[Any], **kwargs: dict[str, Any]) -> None: + super().__init__(*args, **kwargs) + self.choices = { + "yes": "Overwrite the existing session", + "no": "Pick a new session name", + "resume": "Resume the existing session", + } + self.default = "resume" + + def check_choice(self, choice: str) -> bool: + for key in self.choices: + normalized_choice = choice.lower() + if normalized_choice == key or normalized_choice[0] == key[0]: + return True + return False + + def pre_prompt(self) -> str: + print("Would you like to overwrite it?") + print() + for key, value in self.choices.items(): + first_letter, remaining = key[0], key[1:] + rendered_key = rf"[{first_letter}]{remaining}" + print(f" {rendered_key:10} {value}") + print() diff --git a/src/goose/cli/session.py b/src/goose/cli/session.py index 16b3cab6..bd249e8f 100644 --- a/src/goose/cli/session.py +++ b/src/goose/cli/session.py @@ -1,22 +1,24 @@ import traceback from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Optional -from exchange import Message, ToolResult, ToolUse, Text +from exchange import Message, Text, ToolResult, ToolUse from rich import print from rich.markdown import Markdown from rich.panel import Panel +from rich.prompt import Prompt from rich.status import Status -from goose.cli.config import ensure_config, session_path, LOG_PATH from goose._logger import get_logger, setup_logging +from goose.cli.config import LOG_PATH, ensure_config, session_path from goose.cli.prompt.goose_prompt_session import GoosePromptSession -from goose.cli.session_notifier import SessionNotifier +from goose.cli.prompt.overwrite_session_prompt import OverwriteSessionPrompt +from goose.notifier import Notifier from goose.profile import Profile from goose.utils import droid, load_plugins from goose.utils._cost_calculator import get_total_cost_message from goose.utils._create_exchange import create_exchange -from goose.utils.session_file import read_or_create_file, save_latest_session +from goose.utils.session_file import is_empty_session, is_existing_session, read_or_create_file, save_latest_session RESUME_MESSAGE = "I see we were interrupted. How can I help you?" @@ -60,7 +62,7 @@ class Session: profile: Optional[str] = None, plan: Optional[dict] = None, log_level: Optional[str] = "INFO", - **kwargs: Dict[str, Any], + **kwargs: dict[str, Any], ) -> None: if name is None: self.name = droid() @@ -69,7 +71,7 @@ class Session: self.profile_name = profile self.prompt_session = GoosePromptSession() self.status_indicator = Status("", spinner="dots") - self.notifier = SessionNotifier(self.status_indicator) + self.notifier = Notifier(self.status_indicator) self.exchange = create_exchange(profile=load_profile(profile), notifier=self.notifier) setup_logging(log_file_directory=LOG_PATH, log_level=log_level) @@ -81,7 +83,7 @@ class Session: self.prompt_session = GoosePromptSession() - def _get_initial_messages(self) -> List[Message]: + def _get_initial_messages(self) -> list[Message]: messages = self.load_session() if messages and messages[-1].role == "user": @@ -151,8 +153,11 @@ class Session: Runs the main loop to handle user inputs and responses. Continues until an empty string is returned from the prompt. """ - print(f"[dim]starting session | name:[cyan]{self.name}[/] profile:[cyan]{self.profile_name or 'default'}[/]") - print(f"[dim]saving to {self.session_file_path}") + if is_existing_session(self.session_file_path): + self._prompt_overwrite_session() + + profile_name = self.profile_name or "default" + print(f"[dim]starting session | name: [cyan]{self.name}[/cyan] profile: [cyan]{profile_name}[/cyan][/dim]") print() message = self.process_first_message() while message: # Loop until no input (empty string). @@ -178,6 +183,7 @@ class Session: user_input = self.prompt_session.get_user_input() message = Message.user(text=user_input.text) if user_input.to_continue() else None + self._remove_empty_session() self._log_cost() def reply(self) -> None: @@ -234,12 +240,53 @@ class Session: def session_file_path(self) -> Path: return session_path(self.name) - def load_session(self) -> List[Message]: + def load_session(self) -> list[Message]: return read_or_create_file(self.session_file_path) def _log_cost(self) -> None: get_logger().info(get_total_cost_message(self.exchange.get_token_usage())) - print(f"[dim]you can view the cost and token usage in the log directory {LOG_PATH}") + print(f"[dim]you can view the cost and token usage in the log directory {LOG_PATH}[/]") + + def _prompt_overwrite_session(self) -> None: + print(f"[yellow]Session already exists at {self.session_file_path}.[/]") + + choice = OverwriteSessionPrompt.ask("Enter your choice", show_choices=False) + match choice: + case "y" | "yes": + print("Overwriting existing session") + + case "n" | "no": + while True: + new_session_name = Prompt.ask("Enter a new session name") + if not is_existing_session(session_path(new_session_name)): + self.name = new_session_name + break + print(f"[yellow]Session '{new_session_name}' already exists[/]") + + case "r" | "resume": + self.exchange.messages.extend(self.load_session()) + + def _remove_empty_session(self) -> bool: + """ + Removes the session file only when it's empty. + + Note: This is because a session file is created at the start of the run + loop. When a user aborts before their first message empty session files + will be created, causing confusion when resuming sessions (which + depends on most recent mtime and is non-empty). + + Returns: + bool: True if the session file was removed, False otherwise. + """ + logger = get_logger() + try: + if is_empty_session(self.session_file_path): + logger.debug(f"deleting empty session file: {self.session_file_path}") + self.session_file_path.unlink() + return True + except Exception as e: + logger.error(f"error deleting empty session file: {e}") + return False if __name__ == "__main__": diff --git a/src/goose/utils/session_file.py b/src/goose/utils/session_file.py index 435186ce..e367dcf1 100644 --- a/src/goose/utils/session_file.py +++ b/src/goose/utils/session_file.py @@ -1,7 +1,7 @@ import json import os -from pathlib import Path import tempfile +from pathlib import Path from typing import Dict, Iterator, List from exchange import Message @@ -9,6 +9,14 @@ from exchange import Message from goose.cli.config import SESSION_FILE_SUFFIX +def is_existing_session(path: Path) -> bool: + return path.is_file() and path.stat().st_size > 0 + + +def is_empty_session(path: Path) -> bool: + return path.is_file() and path.stat().st_size == 0 + + def write_to_file(file_path: Path, messages: List[Message]) -> None: with open(file_path, "w") as f: _write_messages_to_file(f, messages) diff --git a/tests/cli/test_session.py b/tests/cli/test_session.py index f2437462..8def0442 100644 --- a/tests/cli/test_session.py +++ b/tests/cli/test_session.py @@ -1,7 +1,8 @@ +import json from unittest.mock import MagicMock, patch import pytest -from exchange import Exchange, Message, ToolUse, ToolResult +from exchange import Exchange, Message, ToolResult, ToolUse from goose.cli.prompt.goose_prompt_session import GoosePromptSession from goose.cli.prompt.user_input import PromptAction, UserInput from goose.cli.session import Session @@ -22,7 +23,7 @@ def create_session_with_mock_configs(mock_sessions_path, exchange_factory, profi with ( patch("goose.cli.session.create_exchange") as mock_exchange, patch("goose.cli.session.load_profile", return_value=profile_factory()), - patch("goose.cli.session.SessionNotifier") as mock_session_notifier, + patch("goose.cli.session.Notifier") as mock_session_notifier, patch("goose.cli.session.load_provider", return_value="provider"), ): mock_session_notifier.return_value = MagicMock() @@ -123,36 +124,79 @@ def test_log_log_cost(create_session_with_mock_configs): mock_logger.info.assert_called_once_with(cost_message) -def test_run_should_auto_save_session(create_session_with_mock_configs, mock_sessions_path): +@patch.object(GoosePromptSession, "get_user_input", name="get_user_input") +@patch.object(Exchange, "generate", name="mock_generate") +@patch("goose.cli.session.save_latest_session", name="mock_save_latest_session") +def test_run_should_auto_save_session( + mock_save_latest_session, + mock_generate, + mock_get_user_input, + create_session_with_mock_configs, + mock_sessions_path, +): def custom_exchange_generate(self, *args, **kwargs): message = Message.assistant("Response") self.add(message) return message + def mock_generate_side_effect(*args, **kwargs): + return custom_exchange_generate(session.exchange, *args, **kwargs) + + def save_latest_session(file, messages): + file.write_text("\n".join(json.dumps(m.to_dict()) for m in messages)) + user_inputs = [ UserInput(action=PromptAction.CONTINUE, text="Question1"), UserInput(action=PromptAction.CONTINUE, text="Question2"), UserInput(action=PromptAction.EXIT), ] - session = create_session_with_mock_configs({"name": SESSION_NAME}) - with ( - patch.object(GoosePromptSession, "get_user_input", side_effect=user_inputs), - patch.object(Exchange, "generate") as mock_generate, - patch("goose.cli.session.save_latest_session") as mock_save_latest_session, - ): - mock_generate.side_effect = lambda *args, **kwargs: custom_exchange_generate(session.exchange, *args, **kwargs) - session.run() - session_file = mock_sessions_path / f"{SESSION_NAME}.jsonl" - assert session.exchange.generate.call_count == 2 - assert mock_save_latest_session.call_count == 2 - assert mock_save_latest_session.call_args_list[0][0][0] == session_file - assert session_file.exists() + mock_get_user_input.side_effect = user_inputs + mock_generate.side_effect = mock_generate_side_effect + mock_save_latest_session.side_effect = save_latest_session + + session.run() + + session_file = mock_sessions_path / f"{SESSION_NAME}.jsonl" + + assert mock_generate.call_count == 2 + assert mock_save_latest_session.call_count == 2 + assert mock_save_latest_session.call_args_list[0][0][0] == session_file + assert session_file.exists() + + with open(session_file, "r") as f: + saved_messages = [json.loads(line) for line in f] + + expected_messages = [ + Message.user("Question1"), + Message.assistant("Response"), + Message.user("Question2"), + Message.assistant("Response"), + ] + + assert len(saved_messages) == len(expected_messages) + for saved, expected in zip(saved_messages, expected_messages): + assert saved["role"] == expected.role + assert saved["content"][0]["text"] == expected.text -def test_set_generated_session_name(create_session_with_mock_configs, mock_sessions_path): - generated_session_name = "generated_session_name" - with patch("goose.cli.session.droid", return_value=generated_session_name): - session = create_session_with_mock_configs({"name": None}) - assert session.name == generated_session_name +@patch("goose.cli.session.droid", return_value="generated_session_name", name="mock_droid") +def test_set_generated_session_name(mock_droid, create_session_with_mock_configs, mock_sessions_path): + session = create_session_with_mock_configs({"name": None}) + assert session.name == "generated_session_name" + + +@patch("goose.cli.session.is_existing_session", name="mock_is_existing") +@patch("goose.cli.session.Session._prompt_overwrite_session", name="mock_prompt") +def test_existing_session_prompt(mock_prompt, mock_is_existing, create_session_with_mock_configs): + session = create_session_with_mock_configs({"name": SESSION_NAME}) + + mock_is_existing.return_value = True + session.run() + mock_prompt.assert_called_once() + + mock_prompt.reset_mock() + mock_is_existing.return_value = False + session.run() + mock_prompt.assert_not_called() diff --git a/tests/utils/test_session_file.py b/tests/utils/test_session_file.py index 6a2a6498..29056456 100644 --- a/tests/utils/test_session_file.py +++ b/tests/utils/test_session_file.py @@ -1,9 +1,11 @@ import os from pathlib import Path +from unittest.mock import patch import pytest from exchange import Message from goose.utils.session_file import ( + is_empty_session, list_sorted_session_files, read_from_file, read_or_create_file, @@ -115,3 +117,22 @@ def create_session_file(file_path, file_name) -> Path: file = file_path / f"{file_name}.jsonl" file.touch() return file + + +@patch("pathlib.Path.is_file", return_value=True, name="mock_is_file") +@patch("pathlib.Path.stat", name="mock_stat") +def test_is_empty_session(mock_stat, mock_is_file): + mock_stat.return_value.st_size = 0 + assert is_empty_session(Path("empty_file.json")) + + +@patch("pathlib.Path.is_file", return_value=True, name="mock_is_file") +@patch("pathlib.Path.stat", name="mock_stat") +def test_is_not_empty_session(mock_stat, mock_is_file): + mock_stat.return_value.st_size = 100 + assert not is_empty_session(Path("non_empty_file.json")) + + +@patch("pathlib.Path.is_file", return_value=False, name="mock_is_file") +def test_is_not_empty_session_file_not_found(mock_is_file): + assert not is_empty_session(Path("file_not_found.json"))