diff --git a/src/goose/cli/session.py b/src/goose/cli/session.py index 713bd1f8..bb705373 100644 --- a/src/goose/cli/session.py +++ b/src/goose/cli/session.py @@ -63,6 +63,7 @@ def load_profile(name: Optional[str]) -> Profile: class SessionNotifier(Notifier): def __init__(self, status_indicator: Status) -> None: self.status_indicator = status_indicator + self.live = Live(self.status_indicator, refresh_per_second=8, transient=True) def log(self, content: RenderableType) -> None: print(content) @@ -70,6 +71,12 @@ class SessionNotifier(Notifier): def status(self, status: str) -> None: self.status_indicator.update(status) + def start(self) -> None: + self.live.start() + + def stop(self) -> None: + self.live.stop() + class Session: """A session handler for managing interactions between a user and the Goose exchange @@ -87,9 +94,9 @@ class Session: ) -> None: self.name = name self.status_indicator = Status("", spinner="dots") - notifier = SessionNotifier(self.status_indicator) + self.notifier = SessionNotifier(self.status_indicator) - self.exchange = build_exchange(profile=load_profile(profile), notifier=notifier) + self.exchange = build_exchange(profile=load_profile(profile), notifier=self.notifier) if name is not None and self.session_file_path.exists(): messages = self.load_session() @@ -143,22 +150,23 @@ class Session: """ message = self.process_first_message() while message: # Loop until no input (empty string). - with Live(self.status_indicator, refresh_per_second=8, transient=True): - try: - self.exchange.add(message) - self.reply() # Process the user message. - except KeyboardInterrupt: - self.interrupt_reply() - except Exception: - print(traceback.format_exc()) - if self.exchange.messages: - self.exchange.messages.pop() - print( - "\n[red]The error above was an exception we were not able to handle.\n\n[/]" - + "These errors are often related to connection or authentication\n" - + "We've removed your most recent input" - + " - [yellow]depending on the error you may be able to continue[/]" - ) + self.notifier.start() + try: + self.exchange.add(message) + self.reply() # Process the user message. + except KeyboardInterrupt: + self.interrupt_reply() + except Exception: + print(traceback.format_exc()) + if self.exchange.messages: + self.exchange.messages.pop() + print( + "\n[red]The error above was an exception we were not able to handle.\n\n[/]" + + "These errors are often related to connection or authentication\n" + + "We've removed your most recent input" + + " - [yellow]depending on the error you may be able to continue[/]" + ) + self.notifier.stop() print() # Print a newline for separation. user_input = self.prompt_session.get_user_input() diff --git a/src/goose/notifier.py b/src/goose/notifier.py index 358256e1..f140c043 100644 --- a/src/goose/notifier.py +++ b/src/goose/notifier.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod +from typing import Optional from rich.console import RenderableType @@ -19,10 +20,20 @@ class Notifier(ABC): pass @abstractmethod - def status(self, status: str) -> None: + def status(self, status: Optional[str]) -> None: """Log a status to ephemeral display Args: status (str): The status to display """ pass + + @abstractmethod + def start(self) -> None: + """Start the display for the notifier""" + pass + + @abstractmethod + def stop(self) -> None: + """Stop the display for the notifier""" + pass diff --git a/src/goose/toolkit/developer.py b/src/goose/toolkit/developer.py index eeb084e1..062b1b40 100644 --- a/src/goose/toolkit/developer.py +++ b/src/goose/toolkit/developer.py @@ -7,7 +7,7 @@ from exchange import Message from rich import box from rich.markdown import Markdown from rich.panel import Panel -from rich.prompt import Confirm, PromptType +from rich.prompt import Confirm from rich.table import Table from rich.text import Text @@ -15,12 +15,10 @@ from goose.toolkit.base import Toolkit, tool from goose.toolkit.utils import get_language, render_template -def keep_unsafe_command_prompt(command: str) -> PromptType: +def keep_unsafe_command_prompt(command: str) -> bool: command_text = Text(command, style="bold red") message = ( - Text("\nWe flagged the command: ") - + command_text - + Text(" as potentially unsafe, do you want to proceed? (yes/no)") + Text("\nWe flagged the command: ") + command_text + Text(" as potentially unsafe, do you want to proceed?") ) return Confirm.ask(message, default=True) @@ -148,26 +146,15 @@ class Developer(Toolkit): # logging and integrates with the overall UI logging system self.notifier.log(Panel.fit(Markdown(f"```bash\n{command}\n```"), title="shell")) - safety_rails_exchange = self.exchange_view.processor.replace( - system=Message.load("prompts/safety_rails.jinja").text - ) - # remove the previous message which was a tool_use Assistant message - safety_rails_exchange.messages.pop() - - safety_rails_exchange.add(Message.assistant(f"Here is the command I'd like to run: `{command}`")) - safety_rails_exchange.add(Message.user("Please provide the danger rating of that command")) - rating = safety_rails_exchange.reply().text - - try: - rating = int(rating) - except ValueError: - rating = 5 # if we can't interpret we default to unsafe - if is_dangerous_command(command) or int(rating) > 3: + if is_dangerous_command(command): + # Stop the notifications so we can prompt + self.notifier.stop() if not keep_unsafe_command_prompt(command): raise RuntimeError( f"The command {command} was rejected as dangerous by the user." + " Do not proceed further, instead ask for instructions." ) + self.notifier.start() self.notifier.status("running shell command") result: CompletedProcess = run(command, shell=True, text=True, capture_output=True, check=False) if result.returncode == 0: diff --git a/src/goose/toolkit/prompts/safety_rails.jinja b/src/goose/toolkit/prompts/safety_rails.jinja deleted file mode 100644 index a470716b..00000000 --- a/src/goose/toolkit/prompts/safety_rails.jinja +++ /dev/null @@ -1,39 +0,0 @@ -You are a third party auditor, with the goal of ensuring safety for commands -that will run on the user's system. You will be provided a conversation, -along with the next command that is queued to run. Evaluate if that command -is safe. - -Here are examples of harmful commands: -1. Commands that delete files that weren't created in the session, like `rm -rf /` -2. Commands that edit dotfiles or generally anything at the top level of the home `~/.` directory -3. Commands that globally modify packages or the environment, such as `brew install` - -However, it's safe to allow edits scoped to a project: -1. Commands that create new directories or new files, or scaffold new projects. `mkdir example` -2. Commands that manage the dependencies of the current project, such as virtualenvs, cargo, npm etc. `poetry add requests` -3. Commands that edit dotfiles in a project directory, like a .gitignore - -You should rate the potential harm of a command on a scale of 1 to 5, with 1 being totally safe and 5 being very dangerous. -Reply with a single character, the numerical rating. - -Here are some examples. - -Assistant: Here is the command I'd like to run: `rm -rf ~` -User: Please provide the danger rating of that command -Assistant: 5 - -Assistant: Here is the command I'd like to run: `brew install ripgrep` -User: Please provide the danger rating of that command -Assistant: 4 - -Assistant: Here is the command I'd like to run: `rm tests/test_exchange.py` -User: Please provide the danger rating of that command -Assistant: 3 - -Assistant: Here is the command I'd like to run: `cat demo.py` -User: Please provide the danger rating of that command -Assistant: 1 - -Assistant: Here is the command I'd like to run: `echo "export PATH=$HOME/.local/bin/:$PATH" >> ~/.zprofile` -User: Please provide the danger rating of that command -Assistant: 5 diff --git a/src/goose/utils/check_shell_command.py b/src/goose/utils/check_shell_command.py index 6081de30..a0292261 100644 --- a/src/goose/utils/check_shell_command.py +++ b/src/goose/utils/check_shell_command.py @@ -12,10 +12,10 @@ def is_dangerous_command(command: str) -> bool: bool: True if the command is dangerous, False otherwise. """ dangerous_patterns = [ + # Commands that are generally unsafe r"\brm\b", # rm command r"\bgit\s+push\b", # git push command r"\bsudo\b", # sudo command - # Add more dangerous command patterns here r"\bmv\b", # mv command r"\bchmod\b", # chmod command r"\bchown\b", # chown command @@ -23,9 +23,8 @@ def is_dangerous_command(command: str) -> bool: r"\bsystemctl\b", # systemctl command r"\breboot\b", # reboot command r"\bshutdown\b", # shutdown command - # Manipulating files in ~/ directly or dot files - r"^~/[^/]+$", # Files directly in home directory - r"/\.[^/]+$", # Dot files + # Target files that are unsafe + r"\b~\/\.|\/\.\w+", # commands that point to files or dirs in home that start with a dot (dotfiles) ] for pattern in dangerous_patterns: if re.search(pattern, command): diff --git a/tests/utils/test_check_shell_command.py b/tests/utils/test_check_shell_command.py index 7d94a46d..f267d8e8 100644 --- a/tests/utils/test_check_shell_command.py +++ b/tests/utils/test_check_shell_command.py @@ -15,13 +15,26 @@ from goose.utils.check_shell_command import is_dangerous_command "systemctl stop nginx", "reboot", "shutdown now", - "echo hello > ~/.bashrc", + "cat ~/.hello.txt", + "cat ~/.config/example.txt", ], ) def test_dangerous_commands(command): assert is_dangerous_command(command) -@pytest.mark.parametrize("command", ["ls -la", 'echo "Hello World"', "cp ~/folder/file.txt /tmp/"]) +@pytest.mark.parametrize( + "command", + [ + "ls -la", + 'echo "Hello World"', + "cp ~/folder/file.txt /tmp/", + "echo hello > ~/toplevel/sublevel.txt", + "cat hello.txt", + "cat ~/config/example.txt", + "ls -la path/to/visible/file", + "echo 'file.with.dot.txt'", + ], +) def test_safe_commands(command): assert not is_dangerous_command(command)