mirror of
https://github.com/aljazceru/goose.git
synced 2026-01-22 15:54:29 +01:00
feat: add guards to session management (#101)
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from exchange import Exchange, Message, ToolUse, ToolResult
|
||||
from exchange import Exchange, Message, ToolResult, ToolUse
|
||||
from goose.cli.prompt.goose_prompt_session import GoosePromptSession
|
||||
from goose.cli.prompt.user_input import PromptAction, UserInput
|
||||
from goose.cli.session import Session
|
||||
@@ -22,7 +23,7 @@ def create_session_with_mock_configs(mock_sessions_path, exchange_factory, profi
|
||||
with (
|
||||
patch("goose.cli.session.create_exchange") as mock_exchange,
|
||||
patch("goose.cli.session.load_profile", return_value=profile_factory()),
|
||||
patch("goose.cli.session.SessionNotifier") as mock_session_notifier,
|
||||
patch("goose.cli.session.Notifier") as mock_session_notifier,
|
||||
patch("goose.cli.session.load_provider", return_value="provider"),
|
||||
):
|
||||
mock_session_notifier.return_value = MagicMock()
|
||||
@@ -123,36 +124,79 @@ def test_log_log_cost(create_session_with_mock_configs):
|
||||
mock_logger.info.assert_called_once_with(cost_message)
|
||||
|
||||
|
||||
def test_run_should_auto_save_session(create_session_with_mock_configs, mock_sessions_path):
|
||||
@patch.object(GoosePromptSession, "get_user_input", name="get_user_input")
|
||||
@patch.object(Exchange, "generate", name="mock_generate")
|
||||
@patch("goose.cli.session.save_latest_session", name="mock_save_latest_session")
|
||||
def test_run_should_auto_save_session(
|
||||
mock_save_latest_session,
|
||||
mock_generate,
|
||||
mock_get_user_input,
|
||||
create_session_with_mock_configs,
|
||||
mock_sessions_path,
|
||||
):
|
||||
def custom_exchange_generate(self, *args, **kwargs):
|
||||
message = Message.assistant("Response")
|
||||
self.add(message)
|
||||
return message
|
||||
|
||||
def mock_generate_side_effect(*args, **kwargs):
|
||||
return custom_exchange_generate(session.exchange, *args, **kwargs)
|
||||
|
||||
def save_latest_session(file, messages):
|
||||
file.write_text("\n".join(json.dumps(m.to_dict()) for m in messages))
|
||||
|
||||
user_inputs = [
|
||||
UserInput(action=PromptAction.CONTINUE, text="Question1"),
|
||||
UserInput(action=PromptAction.CONTINUE, text="Question2"),
|
||||
UserInput(action=PromptAction.EXIT),
|
||||
]
|
||||
|
||||
session = create_session_with_mock_configs({"name": SESSION_NAME})
|
||||
with (
|
||||
patch.object(GoosePromptSession, "get_user_input", side_effect=user_inputs),
|
||||
patch.object(Exchange, "generate") as mock_generate,
|
||||
patch("goose.cli.session.save_latest_session") as mock_save_latest_session,
|
||||
):
|
||||
mock_generate.side_effect = lambda *args, **kwargs: custom_exchange_generate(session.exchange, *args, **kwargs)
|
||||
session.run()
|
||||
|
||||
session_file = mock_sessions_path / f"{SESSION_NAME}.jsonl"
|
||||
assert session.exchange.generate.call_count == 2
|
||||
assert mock_save_latest_session.call_count == 2
|
||||
assert mock_save_latest_session.call_args_list[0][0][0] == session_file
|
||||
assert session_file.exists()
|
||||
mock_get_user_input.side_effect = user_inputs
|
||||
mock_generate.side_effect = mock_generate_side_effect
|
||||
mock_save_latest_session.side_effect = save_latest_session
|
||||
|
||||
session.run()
|
||||
|
||||
session_file = mock_sessions_path / f"{SESSION_NAME}.jsonl"
|
||||
|
||||
assert mock_generate.call_count == 2
|
||||
assert mock_save_latest_session.call_count == 2
|
||||
assert mock_save_latest_session.call_args_list[0][0][0] == session_file
|
||||
assert session_file.exists()
|
||||
|
||||
with open(session_file, "r") as f:
|
||||
saved_messages = [json.loads(line) for line in f]
|
||||
|
||||
expected_messages = [
|
||||
Message.user("Question1"),
|
||||
Message.assistant("Response"),
|
||||
Message.user("Question2"),
|
||||
Message.assistant("Response"),
|
||||
]
|
||||
|
||||
assert len(saved_messages) == len(expected_messages)
|
||||
for saved, expected in zip(saved_messages, expected_messages):
|
||||
assert saved["role"] == expected.role
|
||||
assert saved["content"][0]["text"] == expected.text
|
||||
|
||||
|
||||
def test_set_generated_session_name(create_session_with_mock_configs, mock_sessions_path):
|
||||
generated_session_name = "generated_session_name"
|
||||
with patch("goose.cli.session.droid", return_value=generated_session_name):
|
||||
session = create_session_with_mock_configs({"name": None})
|
||||
assert session.name == generated_session_name
|
||||
@patch("goose.cli.session.droid", return_value="generated_session_name", name="mock_droid")
|
||||
def test_set_generated_session_name(mock_droid, create_session_with_mock_configs, mock_sessions_path):
|
||||
session = create_session_with_mock_configs({"name": None})
|
||||
assert session.name == "generated_session_name"
|
||||
|
||||
|
||||
@patch("goose.cli.session.is_existing_session", name="mock_is_existing")
|
||||
@patch("goose.cli.session.Session._prompt_overwrite_session", name="mock_prompt")
|
||||
def test_existing_session_prompt(mock_prompt, mock_is_existing, create_session_with_mock_configs):
|
||||
session = create_session_with_mock_configs({"name": SESSION_NAME})
|
||||
|
||||
mock_is_existing.return_value = True
|
||||
session.run()
|
||||
mock_prompt.assert_called_once()
|
||||
|
||||
mock_prompt.reset_mock()
|
||||
mock_is_existing.return_value = False
|
||||
session.run()
|
||||
mock_prompt.assert_not_called()
|
||||
|
||||
Reference in New Issue
Block a user