mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-27 02:54:23 +01:00
feat: add guards to session management (#101)
This commit is contained in:
30
src/goose/cli/prompt/overwrite_session_prompt.py
Normal file
30
src/goose/cli/prompt/overwrite_session_prompt.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from typing import Any
|
||||
|
||||
from rich.prompt import Prompt
|
||||
|
||||
|
||||
class OverwriteSessionPrompt(Prompt):
|
||||
def __init__(self, *args: tuple[Any], **kwargs: dict[str, Any]) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.choices = {
|
||||
"yes": "Overwrite the existing session",
|
||||
"no": "Pick a new session name",
|
||||
"resume": "Resume the existing session",
|
||||
}
|
||||
self.default = "resume"
|
||||
|
||||
def check_choice(self, choice: str) -> bool:
|
||||
for key in self.choices:
|
||||
normalized_choice = choice.lower()
|
||||
if normalized_choice == key or normalized_choice[0] == key[0]:
|
||||
return True
|
||||
return False
|
||||
|
||||
def pre_prompt(self) -> str:
|
||||
print("Would you like to overwrite it?")
|
||||
print()
|
||||
for key, value in self.choices.items():
|
||||
first_letter, remaining = key[0], key[1:]
|
||||
rendered_key = rf"[{first_letter}]{remaining}"
|
||||
print(f" {rendered_key:10} {value}")
|
||||
print()
|
||||
@@ -1,22 +1,24 @@
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from exchange import Message, ToolResult, ToolUse, Text
|
||||
from exchange import Message, Text, ToolResult, ToolUse
|
||||
from rich import print
|
||||
from rich.markdown import Markdown
|
||||
from rich.panel import Panel
|
||||
from rich.prompt import Prompt
|
||||
from rich.status import Status
|
||||
|
||||
from goose.cli.config import ensure_config, session_path, LOG_PATH
|
||||
from goose._logger import get_logger, setup_logging
|
||||
from goose.cli.config import LOG_PATH, ensure_config, session_path
|
||||
from goose.cli.prompt.goose_prompt_session import GoosePromptSession
|
||||
from goose.cli.session_notifier import SessionNotifier
|
||||
from goose.cli.prompt.overwrite_session_prompt import OverwriteSessionPrompt
|
||||
from goose.notifier import Notifier
|
||||
from goose.profile import Profile
|
||||
from goose.utils import droid, load_plugins
|
||||
from goose.utils._cost_calculator import get_total_cost_message
|
||||
from goose.utils._create_exchange import create_exchange
|
||||
from goose.utils.session_file import read_or_create_file, save_latest_session
|
||||
from goose.utils.session_file import is_empty_session, is_existing_session, read_or_create_file, save_latest_session
|
||||
|
||||
RESUME_MESSAGE = "I see we were interrupted. How can I help you?"
|
||||
|
||||
@@ -60,7 +62,7 @@ class Session:
|
||||
profile: Optional[str] = None,
|
||||
plan: Optional[dict] = None,
|
||||
log_level: Optional[str] = "INFO",
|
||||
**kwargs: Dict[str, Any],
|
||||
**kwargs: dict[str, Any],
|
||||
) -> None:
|
||||
if name is None:
|
||||
self.name = droid()
|
||||
@@ -69,7 +71,7 @@ class Session:
|
||||
self.profile_name = profile
|
||||
self.prompt_session = GoosePromptSession()
|
||||
self.status_indicator = Status("", spinner="dots")
|
||||
self.notifier = SessionNotifier(self.status_indicator)
|
||||
self.notifier = Notifier(self.status_indicator)
|
||||
|
||||
self.exchange = create_exchange(profile=load_profile(profile), notifier=self.notifier)
|
||||
setup_logging(log_file_directory=LOG_PATH, log_level=log_level)
|
||||
@@ -81,7 +83,7 @@ class Session:
|
||||
|
||||
self.prompt_session = GoosePromptSession()
|
||||
|
||||
def _get_initial_messages(self) -> List[Message]:
|
||||
def _get_initial_messages(self) -> list[Message]:
|
||||
messages = self.load_session()
|
||||
|
||||
if messages and messages[-1].role == "user":
|
||||
@@ -151,8 +153,11 @@ class Session:
|
||||
Runs the main loop to handle user inputs and responses.
|
||||
Continues until an empty string is returned from the prompt.
|
||||
"""
|
||||
print(f"[dim]starting session | name:[cyan]{self.name}[/] profile:[cyan]{self.profile_name or 'default'}[/]")
|
||||
print(f"[dim]saving to {self.session_file_path}")
|
||||
if is_existing_session(self.session_file_path):
|
||||
self._prompt_overwrite_session()
|
||||
|
||||
profile_name = self.profile_name or "default"
|
||||
print(f"[dim]starting session | name: [cyan]{self.name}[/cyan] profile: [cyan]{profile_name}[/cyan][/dim]")
|
||||
print()
|
||||
message = self.process_first_message()
|
||||
while message: # Loop until no input (empty string).
|
||||
@@ -178,6 +183,7 @@ class Session:
|
||||
user_input = self.prompt_session.get_user_input()
|
||||
message = Message.user(text=user_input.text) if user_input.to_continue() else None
|
||||
|
||||
self._remove_empty_session()
|
||||
self._log_cost()
|
||||
|
||||
def reply(self) -> None:
|
||||
@@ -234,12 +240,53 @@ class Session:
|
||||
def session_file_path(self) -> Path:
|
||||
return session_path(self.name)
|
||||
|
||||
def load_session(self) -> List[Message]:
|
||||
def load_session(self) -> list[Message]:
|
||||
return read_or_create_file(self.session_file_path)
|
||||
|
||||
def _log_cost(self) -> None:
|
||||
get_logger().info(get_total_cost_message(self.exchange.get_token_usage()))
|
||||
print(f"[dim]you can view the cost and token usage in the log directory {LOG_PATH}")
|
||||
print(f"[dim]you can view the cost and token usage in the log directory {LOG_PATH}[/]")
|
||||
|
||||
def _prompt_overwrite_session(self) -> None:
|
||||
print(f"[yellow]Session already exists at {self.session_file_path}.[/]")
|
||||
|
||||
choice = OverwriteSessionPrompt.ask("Enter your choice", show_choices=False)
|
||||
match choice:
|
||||
case "y" | "yes":
|
||||
print("Overwriting existing session")
|
||||
|
||||
case "n" | "no":
|
||||
while True:
|
||||
new_session_name = Prompt.ask("Enter a new session name")
|
||||
if not is_existing_session(session_path(new_session_name)):
|
||||
self.name = new_session_name
|
||||
break
|
||||
print(f"[yellow]Session '{new_session_name}' already exists[/]")
|
||||
|
||||
case "r" | "resume":
|
||||
self.exchange.messages.extend(self.load_session())
|
||||
|
||||
def _remove_empty_session(self) -> bool:
|
||||
"""
|
||||
Removes the session file only when it's empty.
|
||||
|
||||
Note: This is because a session file is created at the start of the run
|
||||
loop. When a user aborts before their first message empty session files
|
||||
will be created, causing confusion when resuming sessions (which
|
||||
depends on most recent mtime and is non-empty).
|
||||
|
||||
Returns:
|
||||
bool: True if the session file was removed, False otherwise.
|
||||
"""
|
||||
logger = get_logger()
|
||||
try:
|
||||
if is_empty_session(self.session_file_path):
|
||||
logger.debug(f"deleting empty session file: {self.session_file_path}")
|
||||
self.session_file_path.unlink()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"error deleting empty session file: {e}")
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Dict, Iterator, List
|
||||
|
||||
from exchange import Message
|
||||
@@ -9,6 +9,14 @@ from exchange import Message
|
||||
from goose.cli.config import SESSION_FILE_SUFFIX
|
||||
|
||||
|
||||
def is_existing_session(path: Path) -> bool:
|
||||
return path.is_file() and path.stat().st_size > 0
|
||||
|
||||
|
||||
def is_empty_session(path: Path) -> bool:
|
||||
return path.is_file() and path.stat().st_size == 0
|
||||
|
||||
|
||||
def write_to_file(file_path: Path, messages: List[Message]) -> None:
|
||||
with open(file_path, "w") as f:
|
||||
_write_messages_to_file(f, messages)
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from exchange import Exchange, Message, ToolUse, ToolResult
|
||||
from exchange import Exchange, Message, ToolResult, ToolUse
|
||||
from goose.cli.prompt.goose_prompt_session import GoosePromptSession
|
||||
from goose.cli.prompt.user_input import PromptAction, UserInput
|
||||
from goose.cli.session import Session
|
||||
@@ -22,7 +23,7 @@ def create_session_with_mock_configs(mock_sessions_path, exchange_factory, profi
|
||||
with (
|
||||
patch("goose.cli.session.create_exchange") as mock_exchange,
|
||||
patch("goose.cli.session.load_profile", return_value=profile_factory()),
|
||||
patch("goose.cli.session.SessionNotifier") as mock_session_notifier,
|
||||
patch("goose.cli.session.Notifier") as mock_session_notifier,
|
||||
patch("goose.cli.session.load_provider", return_value="provider"),
|
||||
):
|
||||
mock_session_notifier.return_value = MagicMock()
|
||||
@@ -123,36 +124,79 @@ def test_log_log_cost(create_session_with_mock_configs):
|
||||
mock_logger.info.assert_called_once_with(cost_message)
|
||||
|
||||
|
||||
def test_run_should_auto_save_session(create_session_with_mock_configs, mock_sessions_path):
|
||||
@patch.object(GoosePromptSession, "get_user_input", name="get_user_input")
|
||||
@patch.object(Exchange, "generate", name="mock_generate")
|
||||
@patch("goose.cli.session.save_latest_session", name="mock_save_latest_session")
|
||||
def test_run_should_auto_save_session(
|
||||
mock_save_latest_session,
|
||||
mock_generate,
|
||||
mock_get_user_input,
|
||||
create_session_with_mock_configs,
|
||||
mock_sessions_path,
|
||||
):
|
||||
def custom_exchange_generate(self, *args, **kwargs):
|
||||
message = Message.assistant("Response")
|
||||
self.add(message)
|
||||
return message
|
||||
|
||||
def mock_generate_side_effect(*args, **kwargs):
|
||||
return custom_exchange_generate(session.exchange, *args, **kwargs)
|
||||
|
||||
def save_latest_session(file, messages):
|
||||
file.write_text("\n".join(json.dumps(m.to_dict()) for m in messages))
|
||||
|
||||
user_inputs = [
|
||||
UserInput(action=PromptAction.CONTINUE, text="Question1"),
|
||||
UserInput(action=PromptAction.CONTINUE, text="Question2"),
|
||||
UserInput(action=PromptAction.EXIT),
|
||||
]
|
||||
|
||||
session = create_session_with_mock_configs({"name": SESSION_NAME})
|
||||
with (
|
||||
patch.object(GoosePromptSession, "get_user_input", side_effect=user_inputs),
|
||||
patch.object(Exchange, "generate") as mock_generate,
|
||||
patch("goose.cli.session.save_latest_session") as mock_save_latest_session,
|
||||
):
|
||||
mock_generate.side_effect = lambda *args, **kwargs: custom_exchange_generate(session.exchange, *args, **kwargs)
|
||||
session.run()
|
||||
|
||||
session_file = mock_sessions_path / f"{SESSION_NAME}.jsonl"
|
||||
assert session.exchange.generate.call_count == 2
|
||||
assert mock_save_latest_session.call_count == 2
|
||||
assert mock_save_latest_session.call_args_list[0][0][0] == session_file
|
||||
assert session_file.exists()
|
||||
mock_get_user_input.side_effect = user_inputs
|
||||
mock_generate.side_effect = mock_generate_side_effect
|
||||
mock_save_latest_session.side_effect = save_latest_session
|
||||
|
||||
session.run()
|
||||
|
||||
session_file = mock_sessions_path / f"{SESSION_NAME}.jsonl"
|
||||
|
||||
assert mock_generate.call_count == 2
|
||||
assert mock_save_latest_session.call_count == 2
|
||||
assert mock_save_latest_session.call_args_list[0][0][0] == session_file
|
||||
assert session_file.exists()
|
||||
|
||||
with open(session_file, "r") as f:
|
||||
saved_messages = [json.loads(line) for line in f]
|
||||
|
||||
expected_messages = [
|
||||
Message.user("Question1"),
|
||||
Message.assistant("Response"),
|
||||
Message.user("Question2"),
|
||||
Message.assistant("Response"),
|
||||
]
|
||||
|
||||
assert len(saved_messages) == len(expected_messages)
|
||||
for saved, expected in zip(saved_messages, expected_messages):
|
||||
assert saved["role"] == expected.role
|
||||
assert saved["content"][0]["text"] == expected.text
|
||||
|
||||
|
||||
def test_set_generated_session_name(create_session_with_mock_configs, mock_sessions_path):
|
||||
generated_session_name = "generated_session_name"
|
||||
with patch("goose.cli.session.droid", return_value=generated_session_name):
|
||||
session = create_session_with_mock_configs({"name": None})
|
||||
assert session.name == generated_session_name
|
||||
@patch("goose.cli.session.droid", return_value="generated_session_name", name="mock_droid")
|
||||
def test_set_generated_session_name(mock_droid, create_session_with_mock_configs, mock_sessions_path):
|
||||
session = create_session_with_mock_configs({"name": None})
|
||||
assert session.name == "generated_session_name"
|
||||
|
||||
|
||||
@patch("goose.cli.session.is_existing_session", name="mock_is_existing")
|
||||
@patch("goose.cli.session.Session._prompt_overwrite_session", name="mock_prompt")
|
||||
def test_existing_session_prompt(mock_prompt, mock_is_existing, create_session_with_mock_configs):
|
||||
session = create_session_with_mock_configs({"name": SESSION_NAME})
|
||||
|
||||
mock_is_existing.return_value = True
|
||||
session.run()
|
||||
mock_prompt.assert_called_once()
|
||||
|
||||
mock_prompt.reset_mock()
|
||||
mock_is_existing.return_value = False
|
||||
session.run()
|
||||
mock_prompt.assert_not_called()
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from exchange import Message
|
||||
from goose.utils.session_file import (
|
||||
is_empty_session,
|
||||
list_sorted_session_files,
|
||||
read_from_file,
|
||||
read_or_create_file,
|
||||
@@ -115,3 +117,22 @@ def create_session_file(file_path, file_name) -> Path:
|
||||
file = file_path / f"{file_name}.jsonl"
|
||||
file.touch()
|
||||
return file
|
||||
|
||||
|
||||
@patch("pathlib.Path.is_file", return_value=True, name="mock_is_file")
|
||||
@patch("pathlib.Path.stat", name="mock_stat")
|
||||
def test_is_empty_session(mock_stat, mock_is_file):
|
||||
mock_stat.return_value.st_size = 0
|
||||
assert is_empty_session(Path("empty_file.json"))
|
||||
|
||||
|
||||
@patch("pathlib.Path.is_file", return_value=True, name="mock_is_file")
|
||||
@patch("pathlib.Path.stat", name="mock_stat")
|
||||
def test_is_not_empty_session(mock_stat, mock_is_file):
|
||||
mock_stat.return_value.st_size = 100
|
||||
assert not is_empty_session(Path("non_empty_file.json"))
|
||||
|
||||
|
||||
@patch("pathlib.Path.is_file", return_value=False, name="mock_is_file")
|
||||
def test_is_not_empty_session_file_not_found(mock_is_file):
|
||||
assert not is_empty_session(Path("file_not_found.json"))
|
||||
|
||||
Reference in New Issue
Block a user