feat: add guards to session management (#101)

This commit is contained in:
Lam Chau
2024-10-10 05:01:04 -07:00
committed by GitHub
parent 798f346c5a
commit 4375e2fe5e
5 changed files with 184 additions and 34 deletions

View File

@@ -0,0 +1,30 @@
from typing import Any
from rich.prompt import Prompt
class OverwriteSessionPrompt(Prompt):
def __init__(self, *args: tuple[Any], **kwargs: dict[str, Any]) -> None:
super().__init__(*args, **kwargs)
self.choices = {
"yes": "Overwrite the existing session",
"no": "Pick a new session name",
"resume": "Resume the existing session",
}
self.default = "resume"
def check_choice(self, choice: str) -> bool:
for key in self.choices:
normalized_choice = choice.lower()
if normalized_choice == key or normalized_choice[0] == key[0]:
return True
return False
def pre_prompt(self) -> str:
print("Would you like to overwrite it?")
print()
for key, value in self.choices.items():
first_letter, remaining = key[0], key[1:]
rendered_key = rf"[{first_letter}]{remaining}"
print(f" {rendered_key:10} {value}")
print()

View File

@@ -1,22 +1,24 @@
import traceback
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from exchange import Message, ToolResult, ToolUse, Text
from exchange import Message, Text, ToolResult, ToolUse
from rich import print
from rich.markdown import Markdown
from rich.panel import Panel
from rich.prompt import Prompt
from rich.status import Status
from goose.cli.config import ensure_config, session_path, LOG_PATH
from goose._logger import get_logger, setup_logging
from goose.cli.config import LOG_PATH, ensure_config, session_path
from goose.cli.prompt.goose_prompt_session import GoosePromptSession
from goose.cli.session_notifier import SessionNotifier
from goose.cli.prompt.overwrite_session_prompt import OverwriteSessionPrompt
from goose.notifier import Notifier
from goose.profile import Profile
from goose.utils import droid, load_plugins
from goose.utils._cost_calculator import get_total_cost_message
from goose.utils._create_exchange import create_exchange
from goose.utils.session_file import read_or_create_file, save_latest_session
from goose.utils.session_file import is_empty_session, is_existing_session, read_or_create_file, save_latest_session
RESUME_MESSAGE = "I see we were interrupted. How can I help you?"
@@ -60,7 +62,7 @@ class Session:
profile: Optional[str] = None,
plan: Optional[dict] = None,
log_level: Optional[str] = "INFO",
**kwargs: Dict[str, Any],
**kwargs: dict[str, Any],
) -> None:
if name is None:
self.name = droid()
@@ -69,7 +71,7 @@ class Session:
self.profile_name = profile
self.prompt_session = GoosePromptSession()
self.status_indicator = Status("", spinner="dots")
self.notifier = SessionNotifier(self.status_indicator)
self.notifier = Notifier(self.status_indicator)
self.exchange = create_exchange(profile=load_profile(profile), notifier=self.notifier)
setup_logging(log_file_directory=LOG_PATH, log_level=log_level)
@@ -81,7 +83,7 @@ class Session:
self.prompt_session = GoosePromptSession()
def _get_initial_messages(self) -> List[Message]:
def _get_initial_messages(self) -> list[Message]:
messages = self.load_session()
if messages and messages[-1].role == "user":
@@ -151,8 +153,11 @@ class Session:
Runs the main loop to handle user inputs and responses.
Continues until an empty string is returned from the prompt.
"""
print(f"[dim]starting session | name:[cyan]{self.name}[/] profile:[cyan]{self.profile_name or 'default'}[/]")
print(f"[dim]saving to {self.session_file_path}")
if is_existing_session(self.session_file_path):
self._prompt_overwrite_session()
profile_name = self.profile_name or "default"
print(f"[dim]starting session | name: [cyan]{self.name}[/cyan] profile: [cyan]{profile_name}[/cyan][/dim]")
print()
message = self.process_first_message()
while message: # Loop until no input (empty string).
@@ -178,6 +183,7 @@ class Session:
user_input = self.prompt_session.get_user_input()
message = Message.user(text=user_input.text) if user_input.to_continue() else None
self._remove_empty_session()
self._log_cost()
def reply(self) -> None:
@@ -234,12 +240,53 @@ class Session:
def session_file_path(self) -> Path:
return session_path(self.name)
def load_session(self) -> List[Message]:
def load_session(self) -> list[Message]:
return read_or_create_file(self.session_file_path)
def _log_cost(self) -> None:
get_logger().info(get_total_cost_message(self.exchange.get_token_usage()))
print(f"[dim]you can view the cost and token usage in the log directory {LOG_PATH}")
print(f"[dim]you can view the cost and token usage in the log directory {LOG_PATH}[/]")
def _prompt_overwrite_session(self) -> None:
print(f"[yellow]Session already exists at {self.session_file_path}.[/]")
choice = OverwriteSessionPrompt.ask("Enter your choice", show_choices=False)
match choice:
case "y" | "yes":
print("Overwriting existing session")
case "n" | "no":
while True:
new_session_name = Prompt.ask("Enter a new session name")
if not is_existing_session(session_path(new_session_name)):
self.name = new_session_name
break
print(f"[yellow]Session '{new_session_name}' already exists[/]")
case "r" | "resume":
self.exchange.messages.extend(self.load_session())
def _remove_empty_session(self) -> bool:
"""
Removes the session file only when it's empty.
Note: This is because a session file is created at the start of the run
loop. When a user aborts before their first message empty session files
will be created, causing confusion when resuming sessions (which
depends on most recent mtime and is non-empty).
Returns:
bool: True if the session file was removed, False otherwise.
"""
logger = get_logger()
try:
if is_empty_session(self.session_file_path):
logger.debug(f"deleting empty session file: {self.session_file_path}")
self.session_file_path.unlink()
return True
except Exception as e:
logger.error(f"error deleting empty session file: {e}")
return False
if __name__ == "__main__":

View File

@@ -1,7 +1,7 @@
import json
import os
from pathlib import Path
import tempfile
from pathlib import Path
from typing import Dict, Iterator, List
from exchange import Message
@@ -9,6 +9,14 @@ from exchange import Message
from goose.cli.config import SESSION_FILE_SUFFIX
def is_existing_session(path: Path) -> bool:
return path.is_file() and path.stat().st_size > 0
def is_empty_session(path: Path) -> bool:
return path.is_file() and path.stat().st_size == 0
def write_to_file(file_path: Path, messages: List[Message]) -> None:
with open(file_path, "w") as f:
_write_messages_to_file(f, messages)

View File

@@ -1,7 +1,8 @@
import json
from unittest.mock import MagicMock, patch
import pytest
from exchange import Exchange, Message, ToolUse, ToolResult
from exchange import Exchange, Message, ToolResult, ToolUse
from goose.cli.prompt.goose_prompt_session import GoosePromptSession
from goose.cli.prompt.user_input import PromptAction, UserInput
from goose.cli.session import Session
@@ -22,7 +23,7 @@ def create_session_with_mock_configs(mock_sessions_path, exchange_factory, profi
with (
patch("goose.cli.session.create_exchange") as mock_exchange,
patch("goose.cli.session.load_profile", return_value=profile_factory()),
patch("goose.cli.session.SessionNotifier") as mock_session_notifier,
patch("goose.cli.session.Notifier") as mock_session_notifier,
patch("goose.cli.session.load_provider", return_value="provider"),
):
mock_session_notifier.return_value = MagicMock()
@@ -123,36 +124,79 @@ def test_log_log_cost(create_session_with_mock_configs):
mock_logger.info.assert_called_once_with(cost_message)
def test_run_should_auto_save_session(create_session_with_mock_configs, mock_sessions_path):
@patch.object(GoosePromptSession, "get_user_input", name="get_user_input")
@patch.object(Exchange, "generate", name="mock_generate")
@patch("goose.cli.session.save_latest_session", name="mock_save_latest_session")
def test_run_should_auto_save_session(
mock_save_latest_session,
mock_generate,
mock_get_user_input,
create_session_with_mock_configs,
mock_sessions_path,
):
def custom_exchange_generate(self, *args, **kwargs):
message = Message.assistant("Response")
self.add(message)
return message
def mock_generate_side_effect(*args, **kwargs):
return custom_exchange_generate(session.exchange, *args, **kwargs)
def save_latest_session(file, messages):
file.write_text("\n".join(json.dumps(m.to_dict()) for m in messages))
user_inputs = [
UserInput(action=PromptAction.CONTINUE, text="Question1"),
UserInput(action=PromptAction.CONTINUE, text="Question2"),
UserInput(action=PromptAction.EXIT),
]
session = create_session_with_mock_configs({"name": SESSION_NAME})
with (
patch.object(GoosePromptSession, "get_user_input", side_effect=user_inputs),
patch.object(Exchange, "generate") as mock_generate,
patch("goose.cli.session.save_latest_session") as mock_save_latest_session,
):
mock_generate.side_effect = lambda *args, **kwargs: custom_exchange_generate(session.exchange, *args, **kwargs)
session.run()
session_file = mock_sessions_path / f"{SESSION_NAME}.jsonl"
assert session.exchange.generate.call_count == 2
assert mock_save_latest_session.call_count == 2
assert mock_save_latest_session.call_args_list[0][0][0] == session_file
assert session_file.exists()
mock_get_user_input.side_effect = user_inputs
mock_generate.side_effect = mock_generate_side_effect
mock_save_latest_session.side_effect = save_latest_session
session.run()
session_file = mock_sessions_path / f"{SESSION_NAME}.jsonl"
assert mock_generate.call_count == 2
assert mock_save_latest_session.call_count == 2
assert mock_save_latest_session.call_args_list[0][0][0] == session_file
assert session_file.exists()
with open(session_file, "r") as f:
saved_messages = [json.loads(line) for line in f]
expected_messages = [
Message.user("Question1"),
Message.assistant("Response"),
Message.user("Question2"),
Message.assistant("Response"),
]
assert len(saved_messages) == len(expected_messages)
for saved, expected in zip(saved_messages, expected_messages):
assert saved["role"] == expected.role
assert saved["content"][0]["text"] == expected.text
def test_set_generated_session_name(create_session_with_mock_configs, mock_sessions_path):
generated_session_name = "generated_session_name"
with patch("goose.cli.session.droid", return_value=generated_session_name):
session = create_session_with_mock_configs({"name": None})
assert session.name == generated_session_name
@patch("goose.cli.session.droid", return_value="generated_session_name", name="mock_droid")
def test_set_generated_session_name(mock_droid, create_session_with_mock_configs, mock_sessions_path):
session = create_session_with_mock_configs({"name": None})
assert session.name == "generated_session_name"
@patch("goose.cli.session.is_existing_session", name="mock_is_existing")
@patch("goose.cli.session.Session._prompt_overwrite_session", name="mock_prompt")
def test_existing_session_prompt(mock_prompt, mock_is_existing, create_session_with_mock_configs):
session = create_session_with_mock_configs({"name": SESSION_NAME})
mock_is_existing.return_value = True
session.run()
mock_prompt.assert_called_once()
mock_prompt.reset_mock()
mock_is_existing.return_value = False
session.run()
mock_prompt.assert_not_called()

View File

@@ -1,9 +1,11 @@
import os
from pathlib import Path
from unittest.mock import patch
import pytest
from exchange import Message
from goose.utils.session_file import (
is_empty_session,
list_sorted_session_files,
read_from_file,
read_or_create_file,
@@ -115,3 +117,22 @@ def create_session_file(file_path, file_name) -> Path:
file = file_path / f"{file_name}.jsonl"
file.touch()
return file
@patch("pathlib.Path.is_file", return_value=True, name="mock_is_file")
@patch("pathlib.Path.stat", name="mock_stat")
def test_is_empty_session(mock_stat, mock_is_file):
mock_stat.return_value.st_size = 0
assert is_empty_session(Path("empty_file.json"))
@patch("pathlib.Path.is_file", return_value=True, name="mock_is_file")
@patch("pathlib.Path.stat", name="mock_stat")
def test_is_not_empty_session(mock_stat, mock_is_file):
mock_stat.return_value.st_size = 100
assert not is_empty_session(Path("non_empty_file.json"))
@patch("pathlib.Path.is_file", return_value=False, name="mock_is_file")
def test_is_not_empty_session_file_not_found(mock_is_file):
assert not is_empty_session(Path("file_not_found.json"))