feat: give commands the ability to execute logic (#63)

This commit is contained in:
Luke Alvoeiro
2024-09-17 09:50:17 -07:00
committed by GitHub
parent 005f745a00
commit dde3366bd4
8 changed files with 154 additions and 56 deletions

View File

@@ -6,10 +6,16 @@ from prompt_toolkit.styles import Style
from goose.cli.prompt.completer import GoosePromptCompleter
from goose.cli.prompt.lexer import PromptLexer
from goose.command import get_commands
from goose.command.base import Command
def create_prompt() -> PromptSession:
def create_prompt(commands: dict[str, Command]) -> PromptSession:
"""
Create a prompt session with the given commands.
Args:
commands (dict[str, Command]): A dictionary of command names, and instances of Command classes.
"""
# Define custom style
style = Style.from_dict(
{
@@ -52,12 +58,6 @@ def create_prompt() -> PromptSession:
# accept completion
buffer.complete_state = None
# instantiate the commands available in the prompt
commands = dict()
command_plugins = get_commands()
for command, command_cls in command_plugins.items():
commands[command] = command_cls()
return PromptSession(
completer=GoosePromptCompleter(commands=commands),
lexer=PromptLexer(list(commands.keys())),

View File

@@ -1,34 +1,88 @@
from typing import Optional
from prompt_toolkit import PromptSession
from prompt_toolkit.document import Document
from prompt_toolkit.formatted_text import FormattedText
from prompt_toolkit.validation import DummyValidator
from goose.cli.prompt.create import create_prompt
from goose.cli.prompt.lexer import PromptLexer
from goose.cli.prompt.prompt_validator import PromptValidator
from goose.cli.prompt.user_input import PromptAction, UserInput
from goose.command import get_commands
class GoosePromptSession:
def __init__(self, prompt_session: PromptSession) -> None:
self.prompt_session = prompt_session
def __init__(self) -> None:
# instantiate the commands available in the prompt
self.commands = dict()
command_plugins = get_commands()
for command, command_cls in command_plugins.items():
self.commands[command] = command_cls()
@staticmethod
def create_prompt_session() -> "GoosePromptSession":
return GoosePromptSession(create_prompt())
# the main prompt session that is used to interact with the llm
self.main_prompt_session = create_prompt(self.commands)
# a text-only prompt session that doesn't contain any commands
self.text_prompt_session = PromptSession()
def get_message_after_commands(self, message: str) -> str:
lexer = PromptLexer(command_names=list(self.commands.keys()))
doc = Document(message)
lines = []
# iterate through each line of the document
for line_num in range(len(doc.lines)):
classes_in_line = lexer.lex_document(doc)(line_num)
line_result = []
i = 0
while i < len(classes_in_line):
# if a command is found and it is not the last part of the line
if classes_in_line[i][0] == "class:command" and i + 1 < len(classes_in_line):
# extract the command name
command_name = classes_in_line[i][1].strip("/").strip(":")
# get the value following the command
if classes_in_line[i + 1][0] == "class:parameter":
command_value = classes_in_line[i + 1][1]
else:
command_value = ""
# execute the command with the given argument, expecting a return value
value_after_execution = self.commands[command_name].execute(command_value, message)
# if the command returns None, raise an error - this should never happen
# since the command should always return a string
if value_after_execution is None:
raise ValueError(f"Command {command_name} returned None")
# append the result of the command execution to the line results
line_result.append(value_after_execution)
i += 1
# if the part is plain text, just append it to the line results
elif classes_in_line[i][0] == "class:text":
line_result.append(classes_in_line[i][1])
i += 1
# join all processed parts of the current line and add it to the lines list
lines.append("".join(line_result))
# join all processed lines into a single string with newline characters and return
return "\n".join(lines)
def get_user_input(self) -> "UserInput":
try:
text = FormattedText([("#00AEAE", "G ")]) # Define the prompt style and text.
message = self.prompt_session.prompt(text, validator=PromptValidator(), validate_while_typing=False)
message = self.main_prompt_session.prompt(text, validator=PromptValidator(), validate_while_typing=False)
if message.strip() in ("exit", ":q"):
return UserInput(PromptAction.EXIT)
message = self.get_message_after_commands(message)
return UserInput(PromptAction.CONTINUE, message)
except (EOFError, KeyboardInterrupt):
return UserInput(PromptAction.EXIT)
def get_save_session_name(self) -> Optional[str]:
return self.prompt_session.prompt(
return self.text_prompt_session.prompt(
"Enter a name to save this session under. A name will be generated for you if empty: ",
validator=DummyValidator(),
)
).strip(" ")

View File

@@ -5,6 +5,11 @@ from prompt_toolkit.document import Document
from prompt_toolkit.lexers import Lexer
# These are lexers for the commands in the prompt. This is how we
# are extracting the different parts of a command (here, used for styling),
# but likely will be used to parse the command as well in the future.
def completion_for_command(target_string: str) -> re.Pattern[str]:
escaped_string = re.escape(target_string)
vals = [f"(?:{escaped_string[:i]}(?=$))" for i in range(len(escaped_string), 0, -1)]
@@ -13,22 +18,21 @@ def completion_for_command(target_string: str) -> re.Pattern[str]:
def command_itself(target_string: str) -> re.Pattern[str]:
escaped_string = re.escape(target_string)
return re.compile(rf"(?<!\S)(\/{escaped_string})")
return re.compile(rf"(?<!\S)(\/{escaped_string}:?)")
def value_for_command(command_string: str) -> re.Pattern[str]:
escaped_string = re.escape(command_string)
return re.compile(rf"(?<=(?<!\S)\/{escaped_string})([^\s]*)")
escaped_string = re.escape(command_string + ":")
return re.compile(rf"(?<=(?<!\S)\/{escaped_string})(?:(?:\"(.*?)(\"|$))|([^\s]*))")
class PromptLexer(Lexer):
def __init__(self, command_names: List[str]) -> None:
self.patterns = []
for command_name in command_names:
full_command = command_name + ":"
self.patterns.append((completion_for_command(full_command), "class:command"))
self.patterns.append((value_for_command(full_command), "class:parameter"))
self.patterns.append((command_itself(full_command), "class:command"))
self.patterns.append((completion_for_command(command_name), "class:command"))
self.patterns.append((value_for_command(command_name), "class:parameter"))
self.patterns.append((command_itself(command_name), "class:command"))
def lex_document(self, document: Document) -> Callable[[int], list]:
def get_line_tokens(line_number: int) -> Tuple[str, str]:

View File

@@ -121,7 +121,7 @@ class Session:
if len(self.exchange.messages) == 0 and plan:
self.setup_plan(plan=plan)
self.prompt_session = GoosePromptSession.create_prompt_session()
self.prompt_session = GoosePromptSession()
def setup_plan(self, plan: dict) -> None:
if len(self.exchange.messages):

View File

@@ -8,9 +8,19 @@ class Command(ABC):
"""A command that can be executed by the CLI."""
def get_completions(self, query: str) -> List[Completion]:
"""Get completions for the command."""
"""
Get completions for the command.
Args:
query (str): The current query.
"""
return []
def execute(self, query: str) -> Optional[str]:
"""Execute's the command and replaces it with the output."""
"""
Execute's the command and replaces it with the output.
Args:
query (str): The query to execute.
"""
return ""

View File

@@ -57,5 +57,4 @@ class FileCommand(Command):
return completions
def execute(self, query: str) -> str | None:
# GOOSE-TODO: return the query
pass
return query

View File

@@ -1,5 +1,6 @@
from unittest.mock import patch
from prompt_toolkit import PromptSession
import pytest
from goose.cli.prompt.goose_prompt_session import GoosePromptSession
from goose.cli.prompt.user_input import PromptAction, UserInput
@@ -7,41 +8,48 @@ from goose.cli.prompt.user_input import PromptAction, UserInput
@pytest.fixture
def mock_prompt_session():
with patch("prompt_toolkit.PromptSession") as mock_prompt_session:
with patch("goose.cli.prompt.goose_prompt_session.PromptSession") as mock_prompt_session:
yield mock_prompt_session
def test_get_save_session_name(mock_prompt_session):
mock_prompt_session.prompt.return_value = "my_session"
goose_prompt_session = GoosePromptSession(mock_prompt_session)
mock_prompt_session.return_value.prompt.return_value = "my_session"
goose_prompt_session = GoosePromptSession()
assert goose_prompt_session.get_save_session_name() == "my_session"
def test_get_user_input_to_continue(mock_prompt_session):
mock_prompt_session.prompt.return_value = "input_value"
goose_prompt_session = GoosePromptSession(mock_prompt_session)
def test_get_save_session_name_with_space(mock_prompt_session):
mock_prompt_session.return_value.prompt.return_value = "my_session "
goose_prompt_session = GoosePromptSession()
user_input = goose_prompt_session.get_user_input()
assert goose_prompt_session.get_save_session_name() == "my_session"
assert user_input == UserInput(PromptAction.CONTINUE, "input_value")
def test_get_user_input_to_continue():
with patch.object(PromptSession, "prompt", return_value="input_value"):
goose_prompt_session = GoosePromptSession()
user_input = goose_prompt_session.get_user_input()
assert user_input == UserInput(PromptAction.CONTINUE, "input_value")
@pytest.mark.parametrize("exit_input", ["exit", ":q"])
def test_get_user_input_to_exit(exit_input, mock_prompt_session):
mock_prompt_session.prompt.return_value = exit_input
goose_prompt_session = GoosePromptSession(mock_prompt_session)
with patch.object(PromptSession, "prompt", return_value=exit_input):
goose_prompt_session = GoosePromptSession()
user_input = goose_prompt_session.get_user_input()
user_input = goose_prompt_session.get_user_input()
assert user_input == UserInput(PromptAction.EXIT)
assert user_input == UserInput(PromptAction.EXIT)
@pytest.mark.parametrize("error", [EOFError, KeyboardInterrupt])
def test_get_user_input_to_exit_when_error_occurs(error, mock_prompt_session):
mock_prompt_session.prompt.side_effect = error
goose_prompt_session = GoosePromptSession(mock_prompt_session)
with patch.object(PromptSession, "prompt", side_effect=error):
goose_prompt_session = GoosePromptSession()
user_input = goose_prompt_session.get_user_input()
user_input = goose_prompt_session.get_user_input()
assert user_input == UserInput(PromptAction.EXIT)
assert user_input == UserInput(PromptAction.EXIT)

View File

@@ -232,22 +232,45 @@ def test_lex_document_ending_char_of_parameter_is_symbol():
assert actual_tokens == expected_tokens
def test_command_itself():
pattern = command_itself("file:")
matches = pattern.match("/file:example.txt")
def assert_pattern_matches(pattern, text, expected_group):
matches = pattern.search(text)
assert matches is not None
assert matches.group(1) == "/file:"
assert matches.group() == expected_group
def test_command_itself():
pattern = command_itself("file")
assert_pattern_matches(pattern, "/file:example.txt", "/file:")
assert_pattern_matches(pattern, "/file asdf", "/file")
assert_pattern_matches(pattern, "some /file", "/file")
assert_pattern_matches(pattern, "some /file:", "/file:")
assert_pattern_matches(pattern, "/file /file", "/file")
assert pattern.search("file") is None
assert pattern.search("/anothercommand") is None
def test_value_for_command():
pattern = value_for_command("file:")
matches = pattern.search("/file:example.txt")
assert matches is not None
assert matches.group(1) == "example.txt"
pattern = value_for_command("file")
assert_pattern_matches(pattern, "/file:example.txt", "example.txt")
assert_pattern_matches(pattern, '/file:"example space.txt"', '"example space.txt"')
assert_pattern_matches(pattern, '/file:"example.txt" some other string', '"example.txt"')
assert_pattern_matches(pattern, "something before /file:example.txt", "example.txt")
# assert no pattern matches when there is no value
assert pattern.search("/file:").group() == ""
assert pattern.search("/file: other").group() == ""
assert pattern.search("/file: ").group() == ""
assert pattern.search("/file other") is None
def test_completion_for_command():
pattern = completion_for_command("file:")
matches = pattern.search("/file:")
assert matches is not None
assert matches.group(1) == "file:"
pattern = completion_for_command("file")
assert_pattern_matches(pattern, "/file", "/file")
assert_pattern_matches(pattern, "/fi", "/fi")
assert_pattern_matches(pattern, "before /fi", "/fi")
assert_pattern_matches(pattern, "some /f", "/f")
assert pattern.search("/file after") is None
assert pattern.search("/ file") is None
assert pattern.search("/file:") is None