fix: session resume with arg handled incorrectly (#145)

This commit is contained in:
Lam Chau
2024-10-13 15:34:45 -07:00
committed by GitHub
parent fe7f27c6e5
commit 6ea6d1448d
5 changed files with 82 additions and 16 deletions

View File

@@ -14,6 +14,9 @@ from goose.utils import load_plugins
from goose.utils.autocomplete import SUPPORTED_SHELLS, setup_autocomplete
from goose.utils.session_file import list_sorted_session_files
LOG_LEVELS = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
LOG_CHOICE = click.Choice(LOG_LEVELS)
@click.group()
def goose_cli() -> None:
@@ -135,7 +138,7 @@ def get_session_files() -> dict[str, Path]:
@click.argument("name", required=False, shell_complete=autocomplete_session_files)
@click.option("--profile")
@click.option("--plan", type=click.Path(exists=True))
@click.option("--log-level", type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]), default="INFO")
@click.option("--log-level", type=LOG_CHOICE, default="INFO")
def session_start(name: Optional[str], profile: str, log_level: str, plan: Optional[str] = None) -> None:
"""Start a new goose session"""
if plan:
@@ -161,7 +164,7 @@ def parse_args(ctx: click.Context, param: click.Parameter, value: str) -> dict[s
@session.command(name="planned")
@click.option("--plan", type=click.Path(exists=True))
@click.option("--log-level", type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]), default="INFO")
@click.option("--log-level", type=LOG_CHOICE, default="INFO")
@click.option("-a", "--args", callback=parse_args, help="Args in the format arg1:value1,arg2:value2")
def session_planned(plan: str, log_level: str, args: Optional[dict[str, str]]) -> None:
plan_templated = render_template(Path(plan), context=args)
@@ -173,7 +176,7 @@ def session_planned(plan: str, log_level: str, args: Optional[dict[str, str]]) -
@session.command(name="resume")
@click.argument("name", required=False, shell_complete=autocomplete_session_files)
@click.option("--profile")
@click.option("--log-level", type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]), default="INFO")
@click.option("--log-level", type=LOG_CHOICE, default="INFO")
def session_resume(name: Optional[str], profile: str, log_level: str) -> None:
"""Resume an existing goose session"""
session_files = get_session_files()
@@ -190,13 +193,13 @@ def session_resume(name: Optional[str], profile: str, log_level: str) -> None:
else:
print(f"Creating new session: {name}")
session = Session(name=name, profile=profile, log_level=log_level)
session.run()
session.run(new_session=False)
@goose_cli.command(name="run")
@click.argument("message_file", required=False, type=click.Path(exists=True))
@click.option("--profile")
@click.option("--log-level", type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]), default="INFO")
@click.option("--log-level", type=LOG_CHOICE, default="INFO")
def run(message_file: Optional[str], profile: str, log_level: str) -> None:
"""Run a single-pass session with a message from a markdown input file"""
if message_file:

View File

@@ -14,9 +14,11 @@ class OverwriteSessionPrompt(Prompt):
self.default = "resume"
def check_choice(self, choice: str) -> bool:
normalized_choice = choice.lower()
for key in self.choices:
normalized_choice = choice.lower()
if normalized_choice == key or normalized_choice[0] == key[0]:
is_key = normalized_choice == key
is_first_letter = normalized_choice and normalized_choice[0] == key[0]
if is_key or is_first_letter:
return True
return False

View File

@@ -148,12 +148,15 @@ class Session:
print(f"[dim]ended run | name:[cyan]{self.name}[/] profile:[cyan]{profile}[/]")
print(f"[dim]to resume: [magenta]goose session resume {self.name} --profile {profile}[/][/]")
def run(self) -> None:
def run(self, new_session: bool = True) -> None:
"""
Runs the main loop to handle user inputs and responses.
Continues until an empty string is returned from the prompt.
Args:
new_session (bool): True when starting a new session, False when resuming.
"""
if is_existing_session(self.session_file_path):
if is_existing_session(self.session_file_path) and new_session:
self._prompt_overwrite_session()
profile_name = self.profile_name or "default"

View File

@@ -0,0 +1,49 @@
import pytest
from goose.cli.prompt.overwrite_session_prompt import OverwriteSessionPrompt
@pytest.fixture
def prompt():
return OverwriteSessionPrompt()
def test_init(prompt):
assert prompt.choices == {
"yes": "Overwrite the existing session",
"no": "Pick a new session name",
"resume": "Resume the existing session",
}
assert prompt.default == "resume"
@pytest.mark.parametrize(
"choice, expected",
[
("", False),
("invalid", False),
("n", True),
("N", True),
("no", True),
("NO", True),
("r", True),
("R", True),
("resume", True),
("RESUME", True),
("y", True),
("Y", True),
("yes", True),
("YES", True),
],
)
def test_check_choice(prompt, choice, expected):
assert prompt.check_choice(choice) == expected
def test_instantiation():
prompt = OverwriteSessionPrompt()
assert prompt.choices == {
"yes": "Overwrite the existing session",
"no": "Pick a new session name",
"resume": "Resume the existing session",
}
assert prompt.default == "resume"

View File

@@ -192,11 +192,20 @@ def test_set_generated_session_name(mock_droid, create_session_with_mock_configs
def test_existing_session_prompt(mock_prompt, mock_is_existing, create_session_with_mock_configs):
session = create_session_with_mock_configs({"name": SESSION_NAME})
mock_is_existing.return_value = True
session.run()
mock_prompt.assert_called_once()
def check_prompt_behavior(is_existing, new_session, should_prompt):
mock_is_existing.return_value = is_existing
if new_session is None:
session.run()
else:
session.run(new_session=new_session)
mock_prompt.reset_mock()
mock_is_existing.return_value = False
session.run()
mock_prompt.assert_not_called()
if should_prompt:
mock_prompt.assert_called_once()
else:
mock_prompt.assert_not_called()
mock_prompt.reset_mock()
check_prompt_behavior(is_existing=True, new_session=None, should_prompt=True)
check_prompt_behavior(is_existing=False, new_session=None, should_prompt=False)
check_prompt_behavior(is_existing=True, new_session=True, should_prompt=True)
check_prompt_behavior(is_existing=False, new_session=False, should_prompt=False)