From 6ea6d1448df1905b4f7eefbf9e3222acc838d7d3 Mon Sep 17 00:00:00 2001 From: Lam Chau Date: Sun, 13 Oct 2024 15:34:45 -0700 Subject: [PATCH] fix: session resume with arg handled incorrectly (#145) --- src/goose/cli/main.py | 13 +++-- .../cli/prompt/overwrite_session_prompt.py | 6 ++- src/goose/cli/session.py | 7 ++- .../prompt/test_overwrite_session_prompt.py | 49 +++++++++++++++++++ tests/cli/test_session.py | 23 ++++++--- 5 files changed, 82 insertions(+), 16 deletions(-) create mode 100644 tests/cli/prompt/test_overwrite_session_prompt.py diff --git a/src/goose/cli/main.py b/src/goose/cli/main.py index 7d135988..aa345b60 100644 --- a/src/goose/cli/main.py +++ b/src/goose/cli/main.py @@ -14,6 +14,9 @@ from goose.utils import load_plugins from goose.utils.autocomplete import SUPPORTED_SHELLS, setup_autocomplete from goose.utils.session_file import list_sorted_session_files +LOG_LEVELS = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] +LOG_CHOICE = click.Choice(LOG_LEVELS) + @click.group() def goose_cli() -> None: @@ -135,7 +138,7 @@ def get_session_files() -> dict[str, Path]: @click.argument("name", required=False, shell_complete=autocomplete_session_files) @click.option("--profile") @click.option("--plan", type=click.Path(exists=True)) -@click.option("--log-level", type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]), default="INFO") +@click.option("--log-level", type=LOG_CHOICE, default="INFO") def session_start(name: Optional[str], profile: str, log_level: str, plan: Optional[str] = None) -> None: """Start a new goose session""" if plan: @@ -161,7 +164,7 @@ def parse_args(ctx: click.Context, param: click.Parameter, value: str) -> dict[s @session.command(name="planned") @click.option("--plan", type=click.Path(exists=True)) -@click.option("--log-level", type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]), default="INFO") +@click.option("--log-level", type=LOG_CHOICE, default="INFO") @click.option("-a", "--args", callback=parse_args, help="Args in the format arg1:value1,arg2:value2") def session_planned(plan: str, log_level: str, args: Optional[dict[str, str]]) -> None: plan_templated = render_template(Path(plan), context=args) @@ -173,7 +176,7 @@ def session_planned(plan: str, log_level: str, args: Optional[dict[str, str]]) - @session.command(name="resume") @click.argument("name", required=False, shell_complete=autocomplete_session_files) @click.option("--profile") -@click.option("--log-level", type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]), default="INFO") +@click.option("--log-level", type=LOG_CHOICE, default="INFO") def session_resume(name: Optional[str], profile: str, log_level: str) -> None: """Resume an existing goose session""" session_files = get_session_files() @@ -190,13 +193,13 @@ def session_resume(name: Optional[str], profile: str, log_level: str) -> None: else: print(f"Creating new session: {name}") session = Session(name=name, profile=profile, log_level=log_level) - session.run() + session.run(new_session=False) @goose_cli.command(name="run") @click.argument("message_file", required=False, type=click.Path(exists=True)) @click.option("--profile") -@click.option("--log-level", type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]), default="INFO") +@click.option("--log-level", type=LOG_CHOICE, default="INFO") def run(message_file: Optional[str], profile: str, log_level: str) -> None: """Run a single-pass session with a message from a markdown input file""" if message_file: diff --git a/src/goose/cli/prompt/overwrite_session_prompt.py b/src/goose/cli/prompt/overwrite_session_prompt.py index 0f4db693..1d90cbb1 100644 --- a/src/goose/cli/prompt/overwrite_session_prompt.py +++ b/src/goose/cli/prompt/overwrite_session_prompt.py @@ -14,9 +14,11 @@ class OverwriteSessionPrompt(Prompt): self.default = "resume" def check_choice(self, choice: str) -> bool: + normalized_choice = choice.lower() for key in self.choices: - normalized_choice = choice.lower() - if normalized_choice == key or normalized_choice[0] == key[0]: + is_key = normalized_choice == key + is_first_letter = normalized_choice and normalized_choice[0] == key[0] + if is_key or is_first_letter: return True return False diff --git a/src/goose/cli/session.py b/src/goose/cli/session.py index cf9dab70..d39e66ee 100644 --- a/src/goose/cli/session.py +++ b/src/goose/cli/session.py @@ -148,12 +148,15 @@ class Session: print(f"[dim]ended run | name:[cyan]{self.name}[/] profile:[cyan]{profile}[/]") print(f"[dim]to resume: [magenta]goose session resume {self.name} --profile {profile}[/][/]") - def run(self) -> None: + def run(self, new_session: bool = True) -> None: """ Runs the main loop to handle user inputs and responses. Continues until an empty string is returned from the prompt. + + Args: + new_session (bool): True when starting a new session, False when resuming. """ - if is_existing_session(self.session_file_path): + if is_existing_session(self.session_file_path) and new_session: self._prompt_overwrite_session() profile_name = self.profile_name or "default" diff --git a/tests/cli/prompt/test_overwrite_session_prompt.py b/tests/cli/prompt/test_overwrite_session_prompt.py new file mode 100644 index 00000000..95cf825b --- /dev/null +++ b/tests/cli/prompt/test_overwrite_session_prompt.py @@ -0,0 +1,49 @@ +import pytest +from goose.cli.prompt.overwrite_session_prompt import OverwriteSessionPrompt + + +@pytest.fixture +def prompt(): + return OverwriteSessionPrompt() + + +def test_init(prompt): + assert prompt.choices == { + "yes": "Overwrite the existing session", + "no": "Pick a new session name", + "resume": "Resume the existing session", + } + assert prompt.default == "resume" + + +@pytest.mark.parametrize( + "choice, expected", + [ + ("", False), + ("invalid", False), + ("n", True), + ("N", True), + ("no", True), + ("NO", True), + ("r", True), + ("R", True), + ("resume", True), + ("RESUME", True), + ("y", True), + ("Y", True), + ("yes", True), + ("YES", True), + ], +) +def test_check_choice(prompt, choice, expected): + assert prompt.check_choice(choice) == expected + + +def test_instantiation(): + prompt = OverwriteSessionPrompt() + assert prompt.choices == { + "yes": "Overwrite the existing session", + "no": "Pick a new session name", + "resume": "Resume the existing session", + } + assert prompt.default == "resume" diff --git a/tests/cli/test_session.py b/tests/cli/test_session.py index 899baa8d..b2eafea6 100644 --- a/tests/cli/test_session.py +++ b/tests/cli/test_session.py @@ -192,11 +192,20 @@ def test_set_generated_session_name(mock_droid, create_session_with_mock_configs def test_existing_session_prompt(mock_prompt, mock_is_existing, create_session_with_mock_configs): session = create_session_with_mock_configs({"name": SESSION_NAME}) - mock_is_existing.return_value = True - session.run() - mock_prompt.assert_called_once() + def check_prompt_behavior(is_existing, new_session, should_prompt): + mock_is_existing.return_value = is_existing + if new_session is None: + session.run() + else: + session.run(new_session=new_session) - mock_prompt.reset_mock() - mock_is_existing.return_value = False - session.run() - mock_prompt.assert_not_called() + if should_prompt: + mock_prompt.assert_called_once() + else: + mock_prompt.assert_not_called() + mock_prompt.reset_mock() + + check_prompt_behavior(is_existing=True, new_session=None, should_prompt=True) + check_prompt_behavior(is_existing=False, new_session=None, should_prompt=False) + check_prompt_behavior(is_existing=True, new_session=True, should_prompt=True) + check_prompt_behavior(is_existing=False, new_session=False, should_prompt=False)