Move path argument sanitization for commands to a decorator (#4918)

* Move path argument sanitization for commands to a decorator

* Fix tests

* Add `@functools.wraps` to `@sanitize_path_arg` decorator

Co-authored-by: James Collins <collijk@uw.edu>

---------

Co-authored-by: James Collins <collijk@uw.edu>
This commit is contained in:
Reinier van der Leer
2023-07-09 21:40:56 +02:00
committed by GitHub
parent 050c52a008
commit c562fbf4bc
8 changed files with 138 additions and 63 deletions

View File

@@ -169,8 +169,6 @@ class Agent:
if self.config.speak_mode:
say_text(f"I want to execute {command_name}", self.config)
arguments = self._resolve_pathlike_command_args(arguments)
except Exception as e:
logger.error("Error: \n", str(e))
self.log_cycle_handler.log_cycle(
@@ -309,14 +307,3 @@ class Agent:
logger.typewriter_log(
"SYSTEM: ", Fore.YELLOW, "Unable to execute command"
)
def _resolve_pathlike_command_args(self, command_args):
if "directory" in command_args and command_args["directory"] in {"", "/"}:
command_args["directory"] = str(self.workspace.root)
else:
for pathlike in ["filename", "directory", "clone_path"]:
if pathlike in command_args:
command_args[pathlike] = str(
self.workspace.get_path(command_args[pathlike])
)
return command_args

View File

@@ -0,0 +1,64 @@
import functools
from pathlib import Path
from typing import Callable
from autogpt.agent.agent import Agent
from autogpt.logs import logger
def sanitize_path_arg(arg_name: str):
def decorator(func: Callable):
# Get position of path parameter, in case it is passed as a positional argument
try:
arg_index = list(func.__annotations__.keys()).index(arg_name)
except ValueError:
raise TypeError(
f"Sanitized parameter '{arg_name}' absent or not annotated on function '{func.__name__}'"
)
# Get position of agent parameter, in case it is passed as a positional argument
try:
agent_arg_index = list(func.__annotations__.keys()).index("agent")
except ValueError:
raise TypeError(
f"Parameter 'agent' absent or not annotated on function '{func.__name__}'"
)
@functools.wraps(func)
def wrapper(*args, **kwargs):
logger.debug(f"Sanitizing arg '{arg_name}' on function '{func.__name__}'")
logger.debug(f"Function annotations: {func.__annotations__}")
# Get Agent from the called function's arguments
agent = kwargs.get(
"agent", len(args) > agent_arg_index and args[agent_arg_index]
)
logger.debug(f"Args: {args}")
logger.debug(f"KWArgs: {kwargs}")
logger.debug(f"Agent argument lifted from function call: {agent}")
if not isinstance(agent, Agent):
raise RuntimeError("Could not get Agent from decorated command's args")
# Sanitize the specified path argument, if one is given
given_path: str | Path | None = kwargs.get(
arg_name, len(args) > arg_index and args[arg_index] or None
)
if given_path:
if given_path in {"", "/"}:
sanitized_path = str(agent.workspace.root)
else:
sanitized_path = str(agent.workspace.get_path(given_path))
if arg_name in kwargs:
kwargs[arg_name] = sanitized_path
else:
# args is an immutable tuple; must be converted to a list to update
arg_list = list(args)
arg_list[arg_index] = sanitized_path
args = tuple(arg_list)
return func(*args, **kwargs)
return wrapper
return decorator

View File

@@ -12,6 +12,8 @@ from autogpt.command_decorator import command
from autogpt.config import Config
from autogpt.logs import logger
from .decorators import sanitize_path_arg
ALLOWLIST_CONTROL = "allowlist"
DENYLIST_CONTROL = "denylist"
@@ -43,14 +45,14 @@ def execute_python_code(code: str, name: str, agent: Agent) -> str:
Returns:
str: The STDOUT captured from the code when it ran
"""
ai_name = agent.ai_name
ai_name = agent.ai_config.ai_name
code_dir = agent.workspace.get_path(Path(ai_name, "executed_code"))
os.makedirs(code_dir, exist_ok=True)
if not name.endswith(".py"):
name = name + ".py"
# The `name` arg is not covered by Agent._resolve_pathlike_command_args(),
# The `name` arg is not covered by @sanitize_path_arg,
# so sanitization must be done here to prevent path traversal.
file_path = agent.workspace.get_path(code_dir / name)
if not file_path.is_relative_to(code_dir):
@@ -76,6 +78,7 @@ def execute_python_code(code: str, name: str, agent: Agent) -> str:
},
},
)
@sanitize_path_arg("filename")
def execute_python_file(filename: str, agent: Agent) -> str:
"""Execute a Python file in a Docker container and return the output

View File

@@ -1,20 +1,21 @@
"""File operations for AutoGPT"""
from __future__ import annotations
import contextlib
import hashlib
import os
import os.path
from pathlib import Path
from typing import Generator, Literal
from confection import Config
from autogpt.agent.agent import Agent
from autogpt.command_decorator import command
from autogpt.commands.file_operations_utils import read_textual_file
from autogpt.config import Config
from autogpt.logs import logger
from autogpt.memory.vector import MemoryItem, VectorMemory
from .decorators import sanitize_path_arg
from .file_operations_utils import read_textual_file
Operation = Literal["write", "append", "delete"]
@@ -74,21 +75,26 @@ def file_operations_state(log_path: str) -> dict[str, str]:
return state
@sanitize_path_arg("filename")
def is_duplicate_operation(
operation: Operation, filename: str, config: Config, checksum: str | None = None
operation: Operation, filename: str, agent: Agent, checksum: str | None = None
) -> bool:
"""Check if the operation has already been performed
Args:
operation: The operation to check for
filename: The name of the file to check for
config: The agent config
agent: The agent
checksum: The checksum of the contents to be written
Returns:
True if the operation has already been performed on the file
"""
state = file_operations_state(config.file_logger_path)
# Make the filename into a relative path if possible
with contextlib.suppress(ValueError):
filename = str(Path(filename).relative_to(agent.workspace.root))
state = file_operations_state(agent.config.file_logger_path)
if operation == "delete" and filename not in state:
return True
if operation == "write" and state.get(filename) == checksum:
@@ -96,8 +102,9 @@ def is_duplicate_operation(
return False
@sanitize_path_arg("filename")
def log_operation(
operation: str, filename: str, agent: Agent, checksum: str | None = None
operation: Operation, filename: str, agent: Agent, checksum: str | None = None
) -> None:
"""Log the file operation to the file_logger.txt
@@ -106,6 +113,10 @@ def log_operation(
filename: The name of the file the operation was performed on
checksum: The checksum of the contents to be written
"""
# Make the filename into a relative path if possible
with contextlib.suppress(ValueError):
filename = str(Path(filename).relative_to(agent.workspace.root))
log_entry = f"{operation}: {filename}"
if checksum is not None:
log_entry += f" #{checksum}"
@@ -126,6 +137,7 @@ def log_operation(
}
},
)
@sanitize_path_arg("filename")
def read_file(filename: str, agent: Agent) -> str:
"""Read a file and return the contents
@@ -191,6 +203,7 @@ def ingest_file(
},
aliases=["write_file", "create_file"],
)
@sanitize_path_arg("filename")
def write_to_file(filename: str, text: str, agent: Agent) -> str:
"""Write text to a file
@@ -202,7 +215,7 @@ def write_to_file(filename: str, text: str, agent: Agent) -> str:
str: A message indicating success or failure
"""
checksum = text_checksum(text)
if is_duplicate_operation("write", filename, agent.config, checksum):
if is_duplicate_operation("write", filename, agent, checksum):
return "Error: File has already been updated."
try:
directory = os.path.dirname(filename)
@@ -231,6 +244,7 @@ def write_to_file(filename: str, text: str, agent: Agent) -> str:
},
},
)
@sanitize_path_arg("filename")
def append_to_file(
filename: str, text: str, agent: Agent, should_log: bool = True
) -> str:
@@ -271,6 +285,7 @@ def append_to_file(
}
},
)
@sanitize_path_arg("filename")
def delete_file(filename: str, agent: Agent) -> str:
"""Delete a file
@@ -280,7 +295,7 @@ def delete_file(filename: str, agent: Agent) -> str:
Returns:
str: A message indicating success or failure
"""
if is_duplicate_operation("delete", filename, agent.config):
if is_duplicate_operation("delete", filename, agent):
return "Error: File has already been deleted."
try:
os.remove(filename)
@@ -301,6 +316,7 @@ def delete_file(filename: str, agent: Agent) -> str:
}
},
)
@sanitize_path_arg("directory")
def list_files(directory: str, agent: Agent) -> list[str]:
"""lists files in a directory recursively

View File

@@ -6,6 +6,8 @@ from autogpt.agent.agent import Agent
from autogpt.command_decorator import command
from autogpt.url_utils.validators import validate_url
from .decorators import sanitize_path_arg
@command(
"clone_repository",
@@ -22,9 +24,10 @@ from autogpt.url_utils.validators import validate_url
"required": True,
},
},
lambda config: config.github_username and config.github_api_key,
lambda config: bool(config.github_username and config.github_api_key),
"Configure github_username and github_api_key.",
)
@sanitize_path_arg("clone_path")
@validate_url
def clone_repository(url: str, clone_path: str, agent: Agent) -> str:
"""Clone a GitHub repository locally.

View File

@@ -24,7 +24,7 @@ from autogpt.logs import logger
"required": True,
},
},
lambda config: config.image_provider,
lambda config: bool(config.image_provider),
"Requires a image provider to be set.",
)
def generate_image(prompt: str, agent: Agent, size: int = 256) -> str:

View File

@@ -1,5 +1,6 @@
import os
import random
import re
import string
import tempfile
@@ -88,13 +89,9 @@ def test_execute_python_file_invalid(agent: Agent):
def test_execute_python_file_not_found(agent: Agent):
assert all(
s in sut.execute_python_file("notexist.py", agent).lower()
for s in [
"python: can't open file 'notexist.py'",
"[errno 2] no such file or directory",
]
)
result = sut.execute_python_file("notexist.py", agent).lower()
assert re.match(r"python: can't open file '([A-Z]:)?[/\\\-\w]*notexist.py'", result)
assert "[errno 2] no such file or directory" in result
def test_execute_shell(random_string: str, agent: Agent):

View File

@@ -44,8 +44,13 @@ def mock_MemoryItem_from_text(
@pytest.fixture()
def test_file_path(workspace: Workspace):
return workspace.get_path("test_file.txt")
def test_file_name():
return Path("test_file.txt")
@pytest.fixture
def test_file_path(test_file_name: Path, workspace: Workspace):
return workspace.get_path(test_file_name)
@pytest.fixture()
@@ -130,42 +135,34 @@ def test_is_duplicate_operation(agent: Agent, mocker: MockerFixture):
# Test cases with write operations
assert (
file_ops.is_duplicate_operation(
"write", "path/to/file1.txt", agent.config, "checksum1"
"write", "path/to/file1.txt", agent, "checksum1"
)
is True
)
assert (
file_ops.is_duplicate_operation(
"write", "path/to/file1.txt", agent.config, "checksum2"
"write", "path/to/file1.txt", agent, "checksum2"
)
is False
)
assert (
file_ops.is_duplicate_operation(
"write", "path/to/file3.txt", agent.config, "checksum3"
"write", "path/to/file3.txt", agent, "checksum3"
)
is False
)
# Test cases with append operations
assert (
file_ops.is_duplicate_operation(
"append", "path/to/file1.txt", agent.config, "checksum1"
"append", "path/to/file1.txt", agent, "checksum1"
)
is False
)
# Test cases with delete operations
assert (
file_ops.is_duplicate_operation(
"delete", "path/to/file1.txt", config=agent.config
)
is False
)
assert (
file_ops.is_duplicate_operation(
"delete", "path/to/file3.txt", config=agent.config
)
is True
file_ops.is_duplicate_operation("delete", "path/to/file1.txt", agent) is False
)
assert file_ops.is_duplicate_operation("delete", "path/to/file3.txt", agent) is True
# Test logging a file operation
@@ -206,7 +203,15 @@ def test_read_file_not_found(agent: Agent):
assert "Error:" in content and filename in content and "no such file" in content
def test_write_to_file(test_file_path: Path, agent: Agent):
def test_write_to_file_relative_path(test_file_name: Path, agent: Agent):
new_content = "This is new content.\n"
file_ops.write_to_file(str(test_file_name), new_content, agent=agent)
with open(agent.workspace.get_path(test_file_name), "r", encoding="utf-8") as f:
content = f.read()
assert content == new_content
def test_write_to_file_absolute_path(test_file_path: Path, agent: Agent):
new_content = "This is new content.\n"
file_ops.write_to_file(str(test_file_path), new_content, agent=agent)
with open(test_file_path, "r", encoding="utf-8") as f:
@@ -214,24 +219,24 @@ def test_write_to_file(test_file_path: Path, agent: Agent):
assert content == new_content
def test_write_file_logs_checksum(test_file_path: Path, agent: Agent):
def test_write_file_logs_checksum(test_file_name: Path, agent: Agent):
new_content = "This is new content.\n"
new_checksum = file_ops.text_checksum(new_content)
file_ops.write_to_file(str(test_file_path), new_content, agent=agent)
file_ops.write_to_file(str(test_file_name), new_content, agent=agent)
with open(agent.config.file_logger_path, "r", encoding="utf-8") as f:
log_entry = f.read()
assert log_entry == f"write: {test_file_path} #{new_checksum}\n"
assert log_entry == f"write: {test_file_name} #{new_checksum}\n"
def test_write_file_fails_if_content_exists(test_file_path: Path, agent: Agent):
def test_write_file_fails_if_content_exists(test_file_name: Path, agent: Agent):
new_content = "This is new content.\n"
file_ops.log_operation(
"write",
str(test_file_path),
str(test_file_name),
agent=agent,
checksum=file_ops.text_checksum(new_content),
)
result = file_ops.write_to_file(str(test_file_path), new_content, agent=agent)
result = file_ops.write_to_file(str(test_file_name), new_content, agent=agent)
assert result == "Error: File has already been updated."
@@ -258,11 +263,11 @@ def test_append_to_file(test_nested_file: Path, agent: Agent):
def test_append_to_file_uses_checksum_from_appended_file(
test_file_path: Path, agent: Agent
test_file_name: Path, agent: Agent
):
append_text = "This is appended text.\n"
file_ops.append_to_file(test_file_path, append_text, agent=agent)
file_ops.append_to_file(test_file_path, append_text, agent=agent)
file_ops.append_to_file(test_file_name, append_text, agent=agent)
file_ops.append_to_file(test_file_name, append_text, agent=agent)
with open(agent.config.file_logger_path, "r", encoding="utf-8") as f:
log_contents = f.read()
@@ -272,8 +277,8 @@ def test_append_to_file_uses_checksum_from_appended_file(
digest.update(append_text.encode("utf-8"))
checksum2 = digest.hexdigest()
assert log_contents == (
f"append: {test_file_path} #{checksum1}\n"
f"append: {test_file_path} #{checksum2}\n"
f"append: {test_file_name} #{checksum1}\n"
f"append: {test_file_name} #{checksum2}\n"
)
@@ -288,7 +293,7 @@ def test_delete_missing_file(agent: Agent):
# confuse the log
file_ops.log_operation("write", filename, agent=agent, checksum="fake")
try:
os.remove(filename)
os.remove(agent.workspace.get_path(filename))
except FileNotFoundError as err:
assert str(err) in file_ops.delete_file(filename, agent=agent)
return