mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-27 11:04:26 +01:00
feat: give commands the ability to execute logic (#63)
This commit is contained in:
@@ -6,10 +6,16 @@ from prompt_toolkit.styles import Style
|
||||
|
||||
from goose.cli.prompt.completer import GoosePromptCompleter
|
||||
from goose.cli.prompt.lexer import PromptLexer
|
||||
from goose.command import get_commands
|
||||
from goose.command.base import Command
|
||||
|
||||
|
||||
def create_prompt() -> PromptSession:
|
||||
def create_prompt(commands: dict[str, Command]) -> PromptSession:
|
||||
"""
|
||||
Create a prompt session with the given commands.
|
||||
|
||||
Args:
|
||||
commands (dict[str, Command]): A dictionary of command names, and instances of Command classes.
|
||||
"""
|
||||
# Define custom style
|
||||
style = Style.from_dict(
|
||||
{
|
||||
@@ -52,12 +58,6 @@ def create_prompt() -> PromptSession:
|
||||
# accept completion
|
||||
buffer.complete_state = None
|
||||
|
||||
# instantiate the commands available in the prompt
|
||||
commands = dict()
|
||||
command_plugins = get_commands()
|
||||
for command, command_cls in command_plugins.items():
|
||||
commands[command] = command_cls()
|
||||
|
||||
return PromptSession(
|
||||
completer=GoosePromptCompleter(commands=commands),
|
||||
lexer=PromptLexer(list(commands.keys())),
|
||||
|
||||
@@ -1,34 +1,88 @@
|
||||
from typing import Optional
|
||||
|
||||
from prompt_toolkit import PromptSession
|
||||
from prompt_toolkit.document import Document
|
||||
from prompt_toolkit.formatted_text import FormattedText
|
||||
from prompt_toolkit.validation import DummyValidator
|
||||
|
||||
from goose.cli.prompt.create import create_prompt
|
||||
from goose.cli.prompt.lexer import PromptLexer
|
||||
from goose.cli.prompt.prompt_validator import PromptValidator
|
||||
from goose.cli.prompt.user_input import PromptAction, UserInput
|
||||
from goose.command import get_commands
|
||||
|
||||
|
||||
class GoosePromptSession:
|
||||
def __init__(self, prompt_session: PromptSession) -> None:
|
||||
self.prompt_session = prompt_session
|
||||
def __init__(self) -> None:
|
||||
# instantiate the commands available in the prompt
|
||||
self.commands = dict()
|
||||
command_plugins = get_commands()
|
||||
for command, command_cls in command_plugins.items():
|
||||
self.commands[command] = command_cls()
|
||||
|
||||
@staticmethod
|
||||
def create_prompt_session() -> "GoosePromptSession":
|
||||
return GoosePromptSession(create_prompt())
|
||||
# the main prompt session that is used to interact with the llm
|
||||
self.main_prompt_session = create_prompt(self.commands)
|
||||
|
||||
# a text-only prompt session that doesn't contain any commands
|
||||
self.text_prompt_session = PromptSession()
|
||||
|
||||
def get_message_after_commands(self, message: str) -> str:
|
||||
lexer = PromptLexer(command_names=list(self.commands.keys()))
|
||||
doc = Document(message)
|
||||
lines = []
|
||||
# iterate through each line of the document
|
||||
for line_num in range(len(doc.lines)):
|
||||
classes_in_line = lexer.lex_document(doc)(line_num)
|
||||
line_result = []
|
||||
i = 0
|
||||
while i < len(classes_in_line):
|
||||
# if a command is found and it is not the last part of the line
|
||||
if classes_in_line[i][0] == "class:command" and i + 1 < len(classes_in_line):
|
||||
# extract the command name
|
||||
command_name = classes_in_line[i][1].strip("/").strip(":")
|
||||
# get the value following the command
|
||||
if classes_in_line[i + 1][0] == "class:parameter":
|
||||
command_value = classes_in_line[i + 1][1]
|
||||
else:
|
||||
command_value = ""
|
||||
|
||||
# execute the command with the given argument, expecting a return value
|
||||
value_after_execution = self.commands[command_name].execute(command_value, message)
|
||||
|
||||
# if the command returns None, raise an error - this should never happen
|
||||
# since the command should always return a string
|
||||
if value_after_execution is None:
|
||||
raise ValueError(f"Command {command_name} returned None")
|
||||
|
||||
# append the result of the command execution to the line results
|
||||
line_result.append(value_after_execution)
|
||||
i += 1
|
||||
|
||||
# if the part is plain text, just append it to the line results
|
||||
elif classes_in_line[i][0] == "class:text":
|
||||
line_result.append(classes_in_line[i][1])
|
||||
i += 1
|
||||
|
||||
# join all processed parts of the current line and add it to the lines list
|
||||
lines.append("".join(line_result))
|
||||
|
||||
# join all processed lines into a single string with newline characters and return
|
||||
return "\n".join(lines)
|
||||
|
||||
def get_user_input(self) -> "UserInput":
|
||||
try:
|
||||
text = FormattedText([("#00AEAE", "G❯ ")]) # Define the prompt style and text.
|
||||
message = self.prompt_session.prompt(text, validator=PromptValidator(), validate_while_typing=False)
|
||||
message = self.main_prompt_session.prompt(text, validator=PromptValidator(), validate_while_typing=False)
|
||||
if message.strip() in ("exit", ":q"):
|
||||
return UserInput(PromptAction.EXIT)
|
||||
|
||||
message = self.get_message_after_commands(message)
|
||||
return UserInput(PromptAction.CONTINUE, message)
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
return UserInput(PromptAction.EXIT)
|
||||
|
||||
def get_save_session_name(self) -> Optional[str]:
|
||||
return self.prompt_session.prompt(
|
||||
return self.text_prompt_session.prompt(
|
||||
"Enter a name to save this session under. A name will be generated for you if empty: ",
|
||||
validator=DummyValidator(),
|
||||
)
|
||||
).strip(" ")
|
||||
|
||||
@@ -5,6 +5,11 @@ from prompt_toolkit.document import Document
|
||||
from prompt_toolkit.lexers import Lexer
|
||||
|
||||
|
||||
# These are lexers for the commands in the prompt. This is how we
|
||||
# are extracting the different parts of a command (here, used for styling),
|
||||
# but likely will be used to parse the command as well in the future.
|
||||
|
||||
|
||||
def completion_for_command(target_string: str) -> re.Pattern[str]:
|
||||
escaped_string = re.escape(target_string)
|
||||
vals = [f"(?:{escaped_string[:i]}(?=$))" for i in range(len(escaped_string), 0, -1)]
|
||||
@@ -13,22 +18,21 @@ def completion_for_command(target_string: str) -> re.Pattern[str]:
|
||||
|
||||
def command_itself(target_string: str) -> re.Pattern[str]:
|
||||
escaped_string = re.escape(target_string)
|
||||
return re.compile(rf"(?<!\S)(\/{escaped_string})")
|
||||
return re.compile(rf"(?<!\S)(\/{escaped_string}:?)")
|
||||
|
||||
|
||||
def value_for_command(command_string: str) -> re.Pattern[str]:
|
||||
escaped_string = re.escape(command_string)
|
||||
return re.compile(rf"(?<=(?<!\S)\/{escaped_string})([^\s]*)")
|
||||
escaped_string = re.escape(command_string + ":")
|
||||
return re.compile(rf"(?<=(?<!\S)\/{escaped_string})(?:(?:\"(.*?)(\"|$))|([^\s]*))")
|
||||
|
||||
|
||||
class PromptLexer(Lexer):
|
||||
def __init__(self, command_names: List[str]) -> None:
|
||||
self.patterns = []
|
||||
for command_name in command_names:
|
||||
full_command = command_name + ":"
|
||||
self.patterns.append((completion_for_command(full_command), "class:command"))
|
||||
self.patterns.append((value_for_command(full_command), "class:parameter"))
|
||||
self.patterns.append((command_itself(full_command), "class:command"))
|
||||
self.patterns.append((completion_for_command(command_name), "class:command"))
|
||||
self.patterns.append((value_for_command(command_name), "class:parameter"))
|
||||
self.patterns.append((command_itself(command_name), "class:command"))
|
||||
|
||||
def lex_document(self, document: Document) -> Callable[[int], list]:
|
||||
def get_line_tokens(line_number: int) -> Tuple[str, str]:
|
||||
|
||||
@@ -121,7 +121,7 @@ class Session:
|
||||
if len(self.exchange.messages) == 0 and plan:
|
||||
self.setup_plan(plan=plan)
|
||||
|
||||
self.prompt_session = GoosePromptSession.create_prompt_session()
|
||||
self.prompt_session = GoosePromptSession()
|
||||
|
||||
def setup_plan(self, plan: dict) -> None:
|
||||
if len(self.exchange.messages):
|
||||
|
||||
@@ -8,9 +8,19 @@ class Command(ABC):
|
||||
"""A command that can be executed by the CLI."""
|
||||
|
||||
def get_completions(self, query: str) -> List[Completion]:
|
||||
"""Get completions for the command."""
|
||||
"""
|
||||
Get completions for the command.
|
||||
|
||||
Args:
|
||||
query (str): The current query.
|
||||
"""
|
||||
return []
|
||||
|
||||
def execute(self, query: str) -> Optional[str]:
|
||||
"""Execute's the command and replaces it with the output."""
|
||||
"""
|
||||
Execute's the command and replaces it with the output.
|
||||
|
||||
Args:
|
||||
query (str): The query to execute.
|
||||
"""
|
||||
return ""
|
||||
|
||||
@@ -57,5 +57,4 @@ class FileCommand(Command):
|
||||
return completions
|
||||
|
||||
def execute(self, query: str) -> str | None:
|
||||
# GOOSE-TODO: return the query
|
||||
pass
|
||||
return query
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
from prompt_toolkit import PromptSession
|
||||
import pytest
|
||||
from goose.cli.prompt.goose_prompt_session import GoosePromptSession
|
||||
from goose.cli.prompt.user_input import PromptAction, UserInput
|
||||
@@ -7,41 +8,48 @@ from goose.cli.prompt.user_input import PromptAction, UserInput
|
||||
|
||||
@pytest.fixture
|
||||
def mock_prompt_session():
|
||||
with patch("prompt_toolkit.PromptSession") as mock_prompt_session:
|
||||
with patch("goose.cli.prompt.goose_prompt_session.PromptSession") as mock_prompt_session:
|
||||
yield mock_prompt_session
|
||||
|
||||
|
||||
def test_get_save_session_name(mock_prompt_session):
|
||||
mock_prompt_session.prompt.return_value = "my_session"
|
||||
goose_prompt_session = GoosePromptSession(mock_prompt_session)
|
||||
mock_prompt_session.return_value.prompt.return_value = "my_session"
|
||||
goose_prompt_session = GoosePromptSession()
|
||||
|
||||
assert goose_prompt_session.get_save_session_name() == "my_session"
|
||||
|
||||
|
||||
def test_get_user_input_to_continue(mock_prompt_session):
|
||||
mock_prompt_session.prompt.return_value = "input_value"
|
||||
goose_prompt_session = GoosePromptSession(mock_prompt_session)
|
||||
def test_get_save_session_name_with_space(mock_prompt_session):
|
||||
mock_prompt_session.return_value.prompt.return_value = "my_session "
|
||||
goose_prompt_session = GoosePromptSession()
|
||||
|
||||
user_input = goose_prompt_session.get_user_input()
|
||||
assert goose_prompt_session.get_save_session_name() == "my_session"
|
||||
|
||||
assert user_input == UserInput(PromptAction.CONTINUE, "input_value")
|
||||
|
||||
def test_get_user_input_to_continue():
|
||||
with patch.object(PromptSession, "prompt", return_value="input_value"):
|
||||
goose_prompt_session = GoosePromptSession()
|
||||
|
||||
user_input = goose_prompt_session.get_user_input()
|
||||
|
||||
assert user_input == UserInput(PromptAction.CONTINUE, "input_value")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("exit_input", ["exit", ":q"])
|
||||
def test_get_user_input_to_exit(exit_input, mock_prompt_session):
|
||||
mock_prompt_session.prompt.return_value = exit_input
|
||||
goose_prompt_session = GoosePromptSession(mock_prompt_session)
|
||||
with patch.object(PromptSession, "prompt", return_value=exit_input):
|
||||
goose_prompt_session = GoosePromptSession()
|
||||
|
||||
user_input = goose_prompt_session.get_user_input()
|
||||
user_input = goose_prompt_session.get_user_input()
|
||||
|
||||
assert user_input == UserInput(PromptAction.EXIT)
|
||||
assert user_input == UserInput(PromptAction.EXIT)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("error", [EOFError, KeyboardInterrupt])
|
||||
def test_get_user_input_to_exit_when_error_occurs(error, mock_prompt_session):
|
||||
mock_prompt_session.prompt.side_effect = error
|
||||
goose_prompt_session = GoosePromptSession(mock_prompt_session)
|
||||
with patch.object(PromptSession, "prompt", side_effect=error):
|
||||
goose_prompt_session = GoosePromptSession()
|
||||
|
||||
user_input = goose_prompt_session.get_user_input()
|
||||
user_input = goose_prompt_session.get_user_input()
|
||||
|
||||
assert user_input == UserInput(PromptAction.EXIT)
|
||||
assert user_input == UserInput(PromptAction.EXIT)
|
||||
|
||||
@@ -232,22 +232,45 @@ def test_lex_document_ending_char_of_parameter_is_symbol():
|
||||
assert actual_tokens == expected_tokens
|
||||
|
||||
|
||||
def test_command_itself():
|
||||
pattern = command_itself("file:")
|
||||
matches = pattern.match("/file:example.txt")
|
||||
def assert_pattern_matches(pattern, text, expected_group):
|
||||
matches = pattern.search(text)
|
||||
assert matches is not None
|
||||
assert matches.group(1) == "/file:"
|
||||
assert matches.group() == expected_group
|
||||
|
||||
|
||||
def test_command_itself():
|
||||
pattern = command_itself("file")
|
||||
assert_pattern_matches(pattern, "/file:example.txt", "/file:")
|
||||
assert_pattern_matches(pattern, "/file asdf", "/file")
|
||||
assert_pattern_matches(pattern, "some /file", "/file")
|
||||
assert_pattern_matches(pattern, "some /file:", "/file:")
|
||||
assert_pattern_matches(pattern, "/file /file", "/file")
|
||||
|
||||
assert pattern.search("file") is None
|
||||
assert pattern.search("/anothercommand") is None
|
||||
|
||||
|
||||
def test_value_for_command():
|
||||
pattern = value_for_command("file:")
|
||||
matches = pattern.search("/file:example.txt")
|
||||
assert matches is not None
|
||||
assert matches.group(1) == "example.txt"
|
||||
pattern = value_for_command("file")
|
||||
assert_pattern_matches(pattern, "/file:example.txt", "example.txt")
|
||||
assert_pattern_matches(pattern, '/file:"example space.txt"', '"example space.txt"')
|
||||
assert_pattern_matches(pattern, '/file:"example.txt" some other string', '"example.txt"')
|
||||
assert_pattern_matches(pattern, "something before /file:example.txt", "example.txt")
|
||||
|
||||
# assert no pattern matches when there is no value
|
||||
assert pattern.search("/file:").group() == ""
|
||||
assert pattern.search("/file: other").group() == ""
|
||||
assert pattern.search("/file: ").group() == ""
|
||||
assert pattern.search("/file other") is None
|
||||
|
||||
|
||||
def test_completion_for_command():
|
||||
pattern = completion_for_command("file:")
|
||||
matches = pattern.search("/file:")
|
||||
assert matches is not None
|
||||
assert matches.group(1) == "file:"
|
||||
pattern = completion_for_command("file")
|
||||
assert_pattern_matches(pattern, "/file", "/file")
|
||||
assert_pattern_matches(pattern, "/fi", "/fi")
|
||||
assert_pattern_matches(pattern, "before /fi", "/fi")
|
||||
assert_pattern_matches(pattern, "some /f", "/f")
|
||||
|
||||
assert pattern.search("/file after") is None
|
||||
assert pattern.search("/ file") is None
|
||||
assert pattern.search("/file:") is None
|
||||
|
||||
Reference in New Issue
Block a user