refactor: improve safety rails speed and prompt (#45)

This commit is contained in:
Bradley Axen
2024-09-04 22:05:53 -07:00
committed by GitHub
parent 72d927f84f
commit a579e21037
6 changed files with 63 additions and 84 deletions

View File

@@ -63,6 +63,7 @@ def load_profile(name: Optional[str]) -> Profile:
class SessionNotifier(Notifier):
def __init__(self, status_indicator: Status) -> None:
self.status_indicator = status_indicator
self.live = Live(self.status_indicator, refresh_per_second=8, transient=True)
def log(self, content: RenderableType) -> None:
print(content)
@@ -70,6 +71,12 @@ class SessionNotifier(Notifier):
def status(self, status: str) -> None:
self.status_indicator.update(status)
def start(self) -> None:
self.live.start()
def stop(self) -> None:
self.live.stop()
class Session:
"""A session handler for managing interactions between a user and the Goose exchange
@@ -87,9 +94,9 @@ class Session:
) -> None:
self.name = name
self.status_indicator = Status("", spinner="dots")
notifier = SessionNotifier(self.status_indicator)
self.notifier = SessionNotifier(self.status_indicator)
self.exchange = build_exchange(profile=load_profile(profile), notifier=notifier)
self.exchange = build_exchange(profile=load_profile(profile), notifier=self.notifier)
if name is not None and self.session_file_path.exists():
messages = self.load_session()
@@ -143,22 +150,23 @@ class Session:
"""
message = self.process_first_message()
while message: # Loop until no input (empty string).
with Live(self.status_indicator, refresh_per_second=8, transient=True):
try:
self.exchange.add(message)
self.reply() # Process the user message.
except KeyboardInterrupt:
self.interrupt_reply()
except Exception:
print(traceback.format_exc())
if self.exchange.messages:
self.exchange.messages.pop()
print(
"\n[red]The error above was an exception we were not able to handle.\n\n[/]"
+ "These errors are often related to connection or authentication\n"
+ "We've removed your most recent input"
+ " - [yellow]depending on the error you may be able to continue[/]"
)
self.notifier.start()
try:
self.exchange.add(message)
self.reply() # Process the user message.
except KeyboardInterrupt:
self.interrupt_reply()
except Exception:
print(traceback.format_exc())
if self.exchange.messages:
self.exchange.messages.pop()
print(
"\n[red]The error above was an exception we were not able to handle.\n\n[/]"
+ "These errors are often related to connection or authentication\n"
+ "We've removed your most recent input"
+ " - [yellow]depending on the error you may be able to continue[/]"
)
self.notifier.stop()
print() # Print a newline for separation.
user_input = self.prompt_session.get_user_input()

View File

@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from typing import Optional
from rich.console import RenderableType
@@ -19,10 +20,20 @@ class Notifier(ABC):
pass
@abstractmethod
def status(self, status: str) -> None:
def status(self, status: Optional[str]) -> None:
"""Log a status to ephemeral display
Args:
status (str): The status to display
"""
pass
@abstractmethod
def start(self) -> None:
"""Start the display for the notifier"""
pass
@abstractmethod
def stop(self) -> None:
"""Stop the display for the notifier"""
pass

View File

@@ -7,7 +7,7 @@ from exchange import Message
from rich import box
from rich.markdown import Markdown
from rich.panel import Panel
from rich.prompt import Confirm, PromptType
from rich.prompt import Confirm
from rich.table import Table
from rich.text import Text
@@ -15,12 +15,10 @@ from goose.toolkit.base import Toolkit, tool
from goose.toolkit.utils import get_language, render_template
def keep_unsafe_command_prompt(command: str) -> PromptType:
def keep_unsafe_command_prompt(command: str) -> bool:
command_text = Text(command, style="bold red")
message = (
Text("\nWe flagged the command: ")
+ command_text
+ Text(" as potentially unsafe, do you want to proceed? (yes/no)")
Text("\nWe flagged the command: ") + command_text + Text(" as potentially unsafe, do you want to proceed?")
)
return Confirm.ask(message, default=True)
@@ -148,26 +146,15 @@ class Developer(Toolkit):
# logging and integrates with the overall UI logging system
self.notifier.log(Panel.fit(Markdown(f"```bash\n{command}\n```"), title="shell"))
safety_rails_exchange = self.exchange_view.processor.replace(
system=Message.load("prompts/safety_rails.jinja").text
)
# remove the previous message which was a tool_use Assistant message
safety_rails_exchange.messages.pop()
safety_rails_exchange.add(Message.assistant(f"Here is the command I'd like to run: `{command}`"))
safety_rails_exchange.add(Message.user("Please provide the danger rating of that command"))
rating = safety_rails_exchange.reply().text
try:
rating = int(rating)
except ValueError:
rating = 5 # if we can't interpret we default to unsafe
if is_dangerous_command(command) or int(rating) > 3:
if is_dangerous_command(command):
# Stop the notifications so we can prompt
self.notifier.stop()
if not keep_unsafe_command_prompt(command):
raise RuntimeError(
f"The command {command} was rejected as dangerous by the user."
+ " Do not proceed further, instead ask for instructions."
)
self.notifier.start()
self.notifier.status("running shell command")
result: CompletedProcess = run(command, shell=True, text=True, capture_output=True, check=False)
if result.returncode == 0:

View File

@@ -1,39 +0,0 @@
You are a third party auditor, with the goal of ensuring safety for commands
that will run on the user's system. You will be provided a conversation,
along with the next command that is queued to run. Evaluate if that command
is safe.
Here are examples of harmful commands:
1. Commands that delete files that weren't created in the session, like `rm -rf /`
2. Commands that edit dotfiles or generally anything at the top level of the home `~/.` directory
3. Commands that globally modify packages or the environment, such as `brew install`
However, it's safe to allow edits scoped to a project:
1. Commands that create new directories or new files, or scaffold new projects. `mkdir example`
2. Commands that manage the dependencies of the current project, such as virtualenvs, cargo, npm etc. `poetry add requests`
3. Commands that edit dotfiles in a project directory, like a .gitignore
You should rate the potential harm of a command on a scale of 1 to 5, with 1 being totally safe and 5 being very dangerous.
Reply with a single character, the numerical rating.
Here are some examples.
Assistant: Here is the command I'd like to run: `rm -rf ~`
User: Please provide the danger rating of that command
Assistant: 5
Assistant: Here is the command I'd like to run: `brew install ripgrep`
User: Please provide the danger rating of that command
Assistant: 4
Assistant: Here is the command I'd like to run: `rm tests/test_exchange.py`
User: Please provide the danger rating of that command
Assistant: 3
Assistant: Here is the command I'd like to run: `cat demo.py`
User: Please provide the danger rating of that command
Assistant: 1
Assistant: Here is the command I'd like to run: `echo "export PATH=$HOME/.local/bin/:$PATH" >> ~/.zprofile`
User: Please provide the danger rating of that command
Assistant: 5

View File

@@ -12,10 +12,10 @@ def is_dangerous_command(command: str) -> bool:
bool: True if the command is dangerous, False otherwise.
"""
dangerous_patterns = [
# Commands that are generally unsafe
r"\brm\b", # rm command
r"\bgit\s+push\b", # git push command
r"\bsudo\b", # sudo command
# Add more dangerous command patterns here
r"\bmv\b", # mv command
r"\bchmod\b", # chmod command
r"\bchown\b", # chown command
@@ -23,9 +23,8 @@ def is_dangerous_command(command: str) -> bool:
r"\bsystemctl\b", # systemctl command
r"\breboot\b", # reboot command
r"\bshutdown\b", # shutdown command
# Manipulating files in ~/ directly or dot files
r"^~/[^/]+$", # Files directly in home directory
r"/\.[^/]+$", # Dot files
# Target files that are unsafe
r"\b~\/\.|\/\.\w+", # commands that point to files or dirs in home that start with a dot (dotfiles)
]
for pattern in dangerous_patterns:
if re.search(pattern, command):

View File

@@ -15,13 +15,26 @@ from goose.utils.check_shell_command import is_dangerous_command
"systemctl stop nginx",
"reboot",
"shutdown now",
"echo hello > ~/.bashrc",
"cat ~/.hello.txt",
"cat ~/.config/example.txt",
],
)
def test_dangerous_commands(command):
assert is_dangerous_command(command)
@pytest.mark.parametrize("command", ["ls -la", 'echo "Hello World"', "cp ~/folder/file.txt /tmp/"])
@pytest.mark.parametrize(
"command",
[
"ls -la",
'echo "Hello World"',
"cp ~/folder/file.txt /tmp/",
"echo hello > ~/toplevel/sublevel.txt",
"cat hello.txt",
"cat ~/config/example.txt",
"ls -la path/to/visible/file",
"echo 'file.with.dot.txt'",
],
)
def test_safe_commands(command):
assert not is_dangerous_command(command)