mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-25 18:14:23 +01:00
refactor: improve safety rails speed and prompt (#45)
This commit is contained in:
@@ -63,6 +63,7 @@ def load_profile(name: Optional[str]) -> Profile:
|
||||
class SessionNotifier(Notifier):
|
||||
def __init__(self, status_indicator: Status) -> None:
|
||||
self.status_indicator = status_indicator
|
||||
self.live = Live(self.status_indicator, refresh_per_second=8, transient=True)
|
||||
|
||||
def log(self, content: RenderableType) -> None:
|
||||
print(content)
|
||||
@@ -70,6 +71,12 @@ class SessionNotifier(Notifier):
|
||||
def status(self, status: str) -> None:
|
||||
self.status_indicator.update(status)
|
||||
|
||||
def start(self) -> None:
|
||||
self.live.start()
|
||||
|
||||
def stop(self) -> None:
|
||||
self.live.stop()
|
||||
|
||||
|
||||
class Session:
|
||||
"""A session handler for managing interactions between a user and the Goose exchange
|
||||
@@ -87,9 +94,9 @@ class Session:
|
||||
) -> None:
|
||||
self.name = name
|
||||
self.status_indicator = Status("", spinner="dots")
|
||||
notifier = SessionNotifier(self.status_indicator)
|
||||
self.notifier = SessionNotifier(self.status_indicator)
|
||||
|
||||
self.exchange = build_exchange(profile=load_profile(profile), notifier=notifier)
|
||||
self.exchange = build_exchange(profile=load_profile(profile), notifier=self.notifier)
|
||||
|
||||
if name is not None and self.session_file_path.exists():
|
||||
messages = self.load_session()
|
||||
@@ -143,22 +150,23 @@ class Session:
|
||||
"""
|
||||
message = self.process_first_message()
|
||||
while message: # Loop until no input (empty string).
|
||||
with Live(self.status_indicator, refresh_per_second=8, transient=True):
|
||||
try:
|
||||
self.exchange.add(message)
|
||||
self.reply() # Process the user message.
|
||||
except KeyboardInterrupt:
|
||||
self.interrupt_reply()
|
||||
except Exception:
|
||||
print(traceback.format_exc())
|
||||
if self.exchange.messages:
|
||||
self.exchange.messages.pop()
|
||||
print(
|
||||
"\n[red]The error above was an exception we were not able to handle.\n\n[/]"
|
||||
+ "These errors are often related to connection or authentication\n"
|
||||
+ "We've removed your most recent input"
|
||||
+ " - [yellow]depending on the error you may be able to continue[/]"
|
||||
)
|
||||
self.notifier.start()
|
||||
try:
|
||||
self.exchange.add(message)
|
||||
self.reply() # Process the user message.
|
||||
except KeyboardInterrupt:
|
||||
self.interrupt_reply()
|
||||
except Exception:
|
||||
print(traceback.format_exc())
|
||||
if self.exchange.messages:
|
||||
self.exchange.messages.pop()
|
||||
print(
|
||||
"\n[red]The error above was an exception we were not able to handle.\n\n[/]"
|
||||
+ "These errors are often related to connection or authentication\n"
|
||||
+ "We've removed your most recent input"
|
||||
+ " - [yellow]depending on the error you may be able to continue[/]"
|
||||
)
|
||||
self.notifier.stop()
|
||||
|
||||
print() # Print a newline for separation.
|
||||
user_input = self.prompt_session.get_user_input()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from typing import Optional
|
||||
from rich.console import RenderableType
|
||||
|
||||
|
||||
@@ -19,10 +20,20 @@ class Notifier(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def status(self, status: str) -> None:
|
||||
def status(self, status: Optional[str]) -> None:
|
||||
"""Log a status to ephemeral display
|
||||
|
||||
Args:
|
||||
status (str): The status to display
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def start(self) -> None:
|
||||
"""Start the display for the notifier"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def stop(self) -> None:
|
||||
"""Stop the display for the notifier"""
|
||||
pass
|
||||
|
||||
@@ -7,7 +7,7 @@ from exchange import Message
|
||||
from rich import box
|
||||
from rich.markdown import Markdown
|
||||
from rich.panel import Panel
|
||||
from rich.prompt import Confirm, PromptType
|
||||
from rich.prompt import Confirm
|
||||
from rich.table import Table
|
||||
from rich.text import Text
|
||||
|
||||
@@ -15,12 +15,10 @@ from goose.toolkit.base import Toolkit, tool
|
||||
from goose.toolkit.utils import get_language, render_template
|
||||
|
||||
|
||||
def keep_unsafe_command_prompt(command: str) -> PromptType:
|
||||
def keep_unsafe_command_prompt(command: str) -> bool:
|
||||
command_text = Text(command, style="bold red")
|
||||
message = (
|
||||
Text("\nWe flagged the command: ")
|
||||
+ command_text
|
||||
+ Text(" as potentially unsafe, do you want to proceed? (yes/no)")
|
||||
Text("\nWe flagged the command: ") + command_text + Text(" as potentially unsafe, do you want to proceed?")
|
||||
)
|
||||
return Confirm.ask(message, default=True)
|
||||
|
||||
@@ -148,26 +146,15 @@ class Developer(Toolkit):
|
||||
# logging and integrates with the overall UI logging system
|
||||
self.notifier.log(Panel.fit(Markdown(f"```bash\n{command}\n```"), title="shell"))
|
||||
|
||||
safety_rails_exchange = self.exchange_view.processor.replace(
|
||||
system=Message.load("prompts/safety_rails.jinja").text
|
||||
)
|
||||
# remove the previous message which was a tool_use Assistant message
|
||||
safety_rails_exchange.messages.pop()
|
||||
|
||||
safety_rails_exchange.add(Message.assistant(f"Here is the command I'd like to run: `{command}`"))
|
||||
safety_rails_exchange.add(Message.user("Please provide the danger rating of that command"))
|
||||
rating = safety_rails_exchange.reply().text
|
||||
|
||||
try:
|
||||
rating = int(rating)
|
||||
except ValueError:
|
||||
rating = 5 # if we can't interpret we default to unsafe
|
||||
if is_dangerous_command(command) or int(rating) > 3:
|
||||
if is_dangerous_command(command):
|
||||
# Stop the notifications so we can prompt
|
||||
self.notifier.stop()
|
||||
if not keep_unsafe_command_prompt(command):
|
||||
raise RuntimeError(
|
||||
f"The command {command} was rejected as dangerous by the user."
|
||||
+ " Do not proceed further, instead ask for instructions."
|
||||
)
|
||||
self.notifier.start()
|
||||
self.notifier.status("running shell command")
|
||||
result: CompletedProcess = run(command, shell=True, text=True, capture_output=True, check=False)
|
||||
if result.returncode == 0:
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
You are a third party auditor, with the goal of ensuring safety for commands
|
||||
that will run on the user's system. You will be provided a conversation,
|
||||
along with the next command that is queued to run. Evaluate if that command
|
||||
is safe.
|
||||
|
||||
Here are examples of harmful commands:
|
||||
1. Commands that delete files that weren't created in the session, like `rm -rf /`
|
||||
2. Commands that edit dotfiles or generally anything at the top level of the home `~/.` directory
|
||||
3. Commands that globally modify packages or the environment, such as `brew install`
|
||||
|
||||
However, it's safe to allow edits scoped to a project:
|
||||
1. Commands that create new directories or new files, or scaffold new projects. `mkdir example`
|
||||
2. Commands that manage the dependencies of the current project, such as virtualenvs, cargo, npm etc. `poetry add requests`
|
||||
3. Commands that edit dotfiles in a project directory, like a .gitignore
|
||||
|
||||
You should rate the potential harm of a command on a scale of 1 to 5, with 1 being totally safe and 5 being very dangerous.
|
||||
Reply with a single character, the numerical rating.
|
||||
|
||||
Here are some examples.
|
||||
|
||||
Assistant: Here is the command I'd like to run: `rm -rf ~`
|
||||
User: Please provide the danger rating of that command
|
||||
Assistant: 5
|
||||
|
||||
Assistant: Here is the command I'd like to run: `brew install ripgrep`
|
||||
User: Please provide the danger rating of that command
|
||||
Assistant: 4
|
||||
|
||||
Assistant: Here is the command I'd like to run: `rm tests/test_exchange.py`
|
||||
User: Please provide the danger rating of that command
|
||||
Assistant: 3
|
||||
|
||||
Assistant: Here is the command I'd like to run: `cat demo.py`
|
||||
User: Please provide the danger rating of that command
|
||||
Assistant: 1
|
||||
|
||||
Assistant: Here is the command I'd like to run: `echo "export PATH=$HOME/.local/bin/:$PATH" >> ~/.zprofile`
|
||||
User: Please provide the danger rating of that command
|
||||
Assistant: 5
|
||||
@@ -12,10 +12,10 @@ def is_dangerous_command(command: str) -> bool:
|
||||
bool: True if the command is dangerous, False otherwise.
|
||||
"""
|
||||
dangerous_patterns = [
|
||||
# Commands that are generally unsafe
|
||||
r"\brm\b", # rm command
|
||||
r"\bgit\s+push\b", # git push command
|
||||
r"\bsudo\b", # sudo command
|
||||
# Add more dangerous command patterns here
|
||||
r"\bmv\b", # mv command
|
||||
r"\bchmod\b", # chmod command
|
||||
r"\bchown\b", # chown command
|
||||
@@ -23,9 +23,8 @@ def is_dangerous_command(command: str) -> bool:
|
||||
r"\bsystemctl\b", # systemctl command
|
||||
r"\breboot\b", # reboot command
|
||||
r"\bshutdown\b", # shutdown command
|
||||
# Manipulating files in ~/ directly or dot files
|
||||
r"^~/[^/]+$", # Files directly in home directory
|
||||
r"/\.[^/]+$", # Dot files
|
||||
# Target files that are unsafe
|
||||
r"\b~\/\.|\/\.\w+", # commands that point to files or dirs in home that start with a dot (dotfiles)
|
||||
]
|
||||
for pattern in dangerous_patterns:
|
||||
if re.search(pattern, command):
|
||||
|
||||
@@ -15,13 +15,26 @@ from goose.utils.check_shell_command import is_dangerous_command
|
||||
"systemctl stop nginx",
|
||||
"reboot",
|
||||
"shutdown now",
|
||||
"echo hello > ~/.bashrc",
|
||||
"cat ~/.hello.txt",
|
||||
"cat ~/.config/example.txt",
|
||||
],
|
||||
)
|
||||
def test_dangerous_commands(command):
|
||||
assert is_dangerous_command(command)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("command", ["ls -la", 'echo "Hello World"', "cp ~/folder/file.txt /tmp/"])
|
||||
@pytest.mark.parametrize(
|
||||
"command",
|
||||
[
|
||||
"ls -la",
|
||||
'echo "Hello World"',
|
||||
"cp ~/folder/file.txt /tmp/",
|
||||
"echo hello > ~/toplevel/sublevel.txt",
|
||||
"cat hello.txt",
|
||||
"cat ~/config/example.txt",
|
||||
"ls -la path/to/visible/file",
|
||||
"echo 'file.with.dot.txt'",
|
||||
],
|
||||
)
|
||||
def test_safe_commands(command):
|
||||
assert not is_dangerous_command(command)
|
||||
|
||||
Reference in New Issue
Block a user