mirror of
https://github.com/aljazceru/Auto-GPT.git
synced 2026-01-31 20:04:28 +01:00
fix(agent/file_operations): Fix path processing in file_operations.py and across workspace backend implementations
- Adjusted path processing and use of `agent.workspace` in the file_operations.py module to prevent double path resolution. - Updated the `is_duplicate_operation` and `log_operation` functions in file_operations.py to use the `make_relative` argument of the `sanitize_path_arg` decorator. - Refactored the `write_to_file`, `list_folder`, and `list_files` functions in file_operations.py to accept both string and Path objects as the path argument. - Modified the GCSFileWorkspace and S3FileWorkspace classes in the file_workspace module to ensure that the root path is always an absolute path. This commit addresses issues with path processing in the file_operations.py module and across different workspace backend implementations. The changes ensure that relative paths are correctly converted to absolute paths where necessary and that the file operations logic functions consistently handle path arguments as strings or Path objects. Additionally, the GCSFileWorkspace and S3FileWorkspace classes now enforce that the root path is always an absolute path.
This commit is contained in:
@@ -2,7 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
@@ -84,7 +83,7 @@ def file_operations_state(log_path: str | Path) -> dict[str, str]:
|
||||
return state
|
||||
|
||||
|
||||
@sanitize_path_arg("file_path")
|
||||
@sanitize_path_arg("file_path", make_relative=True)
|
||||
def is_duplicate_operation(
|
||||
operation: Operation, file_path: Path, agent: Agent, checksum: str | None = None
|
||||
) -> bool:
|
||||
@@ -99,10 +98,6 @@ def is_duplicate_operation(
|
||||
Returns:
|
||||
True if the operation has already been performed on the file
|
||||
"""
|
||||
# Make the file path into a relative path if possible
|
||||
with contextlib.suppress(ValueError):
|
||||
file_path = file_path.relative_to(agent.workspace.root)
|
||||
|
||||
state = file_operations_state(agent.file_manager.file_ops_log_path)
|
||||
if operation == "delete" and str(file_path) not in state:
|
||||
return True
|
||||
@@ -111,9 +106,12 @@ def is_duplicate_operation(
|
||||
return False
|
||||
|
||||
|
||||
@sanitize_path_arg("file_path")
|
||||
@sanitize_path_arg("file_path", make_relative=True)
|
||||
def log_operation(
|
||||
operation: Operation, file_path: Path, agent: Agent, checksum: str | None = None
|
||||
operation: Operation,
|
||||
file_path: str | Path,
|
||||
agent: Agent,
|
||||
checksum: str | None = None,
|
||||
) -> None:
|
||||
"""Log the file operation to the file_logger.log
|
||||
|
||||
@@ -122,10 +120,6 @@ def log_operation(
|
||||
file_path: The name of the file the operation was performed on
|
||||
checksum: The checksum of the contents to be written
|
||||
"""
|
||||
# Make the file path into a relative path if possible
|
||||
with contextlib.suppress(ValueError):
|
||||
file_path = file_path.relative_to(agent.workspace.root)
|
||||
|
||||
log_entry = f"{operation}: {file_path}"
|
||||
if checksum is not None:
|
||||
log_entry += f" #{checksum}"
|
||||
@@ -210,8 +204,7 @@ def ingest_file(
|
||||
},
|
||||
aliases=["create_file"],
|
||||
)
|
||||
@sanitize_path_arg("filename")
|
||||
async def write_to_file(filename: Path, contents: str, agent: Agent) -> str:
|
||||
async def write_to_file(filename: str | Path, contents: str, agent: Agent) -> str:
|
||||
"""Write contents to a file
|
||||
|
||||
Args:
|
||||
@@ -222,14 +215,14 @@ async def write_to_file(filename: Path, contents: str, agent: Agent) -> str:
|
||||
str: A message indicating success or failure
|
||||
"""
|
||||
checksum = text_checksum(contents)
|
||||
if is_duplicate_operation("write", filename, agent, checksum):
|
||||
raise DuplicateOperationError(f"File {filename.name} has already been updated.")
|
||||
if is_duplicate_operation("write", Path(filename), agent, checksum):
|
||||
raise DuplicateOperationError(f"File {filename} has already been updated.")
|
||||
|
||||
directory = os.path.dirname(filename)
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
if directory := os.path.dirname(filename):
|
||||
agent.workspace.get_path(directory).mkdir(exist_ok=True)
|
||||
await agent.workspace.write_file(filename, contents)
|
||||
log_operation("write", filename, agent, checksum)
|
||||
return f"File {filename.name} has been written successfully."
|
||||
return f"File {filename} has been written successfully."
|
||||
|
||||
|
||||
def append_to_file(
|
||||
@@ -264,8 +257,7 @@ def append_to_file(
|
||||
)
|
||||
},
|
||||
)
|
||||
@sanitize_path_arg("folder")
|
||||
def list_folder(folder: Path, agent: Agent) -> list[str]:
|
||||
def list_folder(folder: str | Path, agent: Agent) -> list[str]:
|
||||
"""Lists files in a folder recursively
|
||||
|
||||
Args:
|
||||
@@ -274,15 +266,4 @@ def list_folder(folder: Path, agent: Agent) -> list[str]:
|
||||
Returns:
|
||||
list[str]: A list of files found in the folder
|
||||
"""
|
||||
found_files = []
|
||||
|
||||
for root, _, files in os.walk(folder):
|
||||
for file in files:
|
||||
if file.startswith("."):
|
||||
continue
|
||||
relative_path = os.path.relpath(
|
||||
os.path.join(root, file), agent.workspace.root
|
||||
)
|
||||
found_files.append(relative_path)
|
||||
|
||||
return found_files
|
||||
return [str(p) for p in agent.workspace.list(folder)]
|
||||
|
||||
@@ -16,7 +16,7 @@ def get_workspace(
|
||||
) -> FileWorkspace:
|
||||
assert bool(root_path) != bool(id), "Specify root_path or id to get workspace"
|
||||
if root_path is None:
|
||||
root_path = Path(f"workspaces/{id}")
|
||||
root_path = Path(f"/workspaces/{id}")
|
||||
|
||||
match backend:
|
||||
case FileWorkspaceBackendName.LOCAL:
|
||||
|
||||
@@ -72,8 +72,8 @@ class FileWorkspace(ABC):
|
||||
"""Write to a file in the workspace."""
|
||||
|
||||
@abstractmethod
|
||||
def list_files(self, path: str | Path = ".") -> list[Path]:
|
||||
"""List all files in a directory in the workspace."""
|
||||
def list(self, path: str | Path = ".") -> list[Path]:
|
||||
"""List all files (recursively) in a directory in the workspace."""
|
||||
|
||||
@abstractmethod
|
||||
def delete_file(self, path: str | Path) -> None:
|
||||
|
||||
@@ -29,6 +29,7 @@ class GCSFileWorkspace(FileWorkspace):
|
||||
def __init__(self, config: GCSFileWorkspaceConfiguration):
|
||||
self._bucket_name = config.bucket
|
||||
self._root = config.root
|
||||
assert self._root.is_absolute()
|
||||
|
||||
self._gcs = storage.Client()
|
||||
super().__init__()
|
||||
@@ -47,7 +48,7 @@ class GCSFileWorkspace(FileWorkspace):
|
||||
self._bucket = self._gcs.get_bucket(self._bucket_name)
|
||||
|
||||
def get_path(self, relative_path: str | Path) -> Path:
|
||||
return super().get_path(relative_path).relative_to(Path("/"))
|
||||
return super().get_path(relative_path).relative_to("/")
|
||||
|
||||
def open_file(self, path: str | Path, mode: str = "r"):
|
||||
"""Open a file in the workspace."""
|
||||
@@ -78,11 +79,13 @@ class GCSFileWorkspace(FileWorkspace):
|
||||
if inspect.isawaitable(res):
|
||||
await res
|
||||
|
||||
def list_files(self, path: str | Path = ".") -> list[Path]:
|
||||
"""List all files in a directory in the workspace."""
|
||||
def list(self, path: str | Path = ".") -> list[Path]:
|
||||
"""List all files (recursively) in a directory in the workspace."""
|
||||
path = self.get_path(path)
|
||||
blobs = self._bucket.list_blobs(prefix=str(path))
|
||||
return [Path(blob.name) for blob in blobs if not blob.name.endswith("/")]
|
||||
blobs = self._bucket.list_blobs(
|
||||
prefix=f"{path}/" if path != Path(".") else None
|
||||
)
|
||||
return [Path(blob.name).relative_to(path) for blob in blobs]
|
||||
|
||||
def delete_file(self, path: str | Path) -> None:
|
||||
"""Delete a file in the workspace."""
|
||||
|
||||
@@ -56,10 +56,10 @@ class LocalFileWorkspace(FileWorkspace):
|
||||
if inspect.isawaitable(res):
|
||||
await res
|
||||
|
||||
def list_files(self, path: str | Path = "."):
|
||||
"""List all files in a directory in the workspace."""
|
||||
full_path = self.get_path(path)
|
||||
return [str(file) for file in full_path.glob("*") if file.is_file()]
|
||||
def list(self, path: str | Path = ".") -> list[Path]:
|
||||
"""List all files (recursively) in a directory in the workspace."""
|
||||
path = self.get_path(path)
|
||||
return [file.relative_to(path) for file in path.rglob("*") if file.is_file()]
|
||||
|
||||
def delete_file(self, path: str | Path):
|
||||
"""Delete a file in the workspace."""
|
||||
|
||||
@@ -40,6 +40,7 @@ class S3FileWorkspace(FileWorkspace):
|
||||
def __init__(self, config: S3FileWorkspaceConfiguration):
|
||||
self._bucket_name = config.bucket
|
||||
self._root = config.root
|
||||
assert self._root.is_absolute()
|
||||
|
||||
# https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html
|
||||
self._s3 = boto3.resource(
|
||||
@@ -71,7 +72,7 @@ class S3FileWorkspace(FileWorkspace):
|
||||
self._bucket = self._s3.create_bucket(Bucket=self._bucket_name)
|
||||
|
||||
def get_path(self, relative_path: str | Path) -> Path:
|
||||
return super().get_path(relative_path).relative_to(Path("/"))
|
||||
return super().get_path(relative_path).relative_to("/")
|
||||
|
||||
def open_file(self, path: str | Path, mode: str = "r"):
|
||||
"""Open a file in the workspace."""
|
||||
@@ -99,20 +100,14 @@ class S3FileWorkspace(FileWorkspace):
|
||||
if inspect.isawaitable(res):
|
||||
await res
|
||||
|
||||
def list_files(self, path: str | Path = ".") -> list[Path]:
|
||||
"""List all files in a directory in the workspace."""
|
||||
def list(self, path: str | Path = ".") -> list[Path]:
|
||||
"""List all files (recursively) in a directory in the workspace."""
|
||||
path = self.get_path(path)
|
||||
if path == Path("."):
|
||||
return [
|
||||
Path(obj.key)
|
||||
for obj in self._bucket.objects.all()
|
||||
if not obj.key.endswith("/")
|
||||
]
|
||||
return [Path(obj.key) for obj in self._bucket.objects.all()]
|
||||
else:
|
||||
return [
|
||||
Path(obj.key)
|
||||
for obj in self._bucket.objects.filter(Prefix=str(path))
|
||||
if not obj.key.endswith("/")
|
||||
Path(obj.key) for obj in self._bucket.objects.filter(Prefix=f"{path}/")
|
||||
]
|
||||
|
||||
def delete_file(self, path: str | Path) -> None:
|
||||
|
||||
@@ -60,8 +60,8 @@ def gcs_workspace(gcs_workspace_uninitialized: GCSFileWorkspace) -> GCSFileWorks
|
||||
|
||||
TEST_FILES: list[tuple[str | Path, str]] = [
|
||||
("existing_test_file_1", "test content 1"),
|
||||
("existing_test_file_2.txt", "test content 2"),
|
||||
(Path("existing_test_file_3"), "test content 3"),
|
||||
("/existing_test_file_2.txt", "test content 2"),
|
||||
(Path("/existing_test_file_3"), "test content 3"),
|
||||
(Path("existing/test/file/4"), "test content 4"),
|
||||
]
|
||||
|
||||
@@ -69,7 +69,9 @@ TEST_FILES: list[tuple[str | Path, str]] = [
|
||||
@pytest_asyncio.fixture
|
||||
async def gcs_workspace_with_files(gcs_workspace: GCSFileWorkspace) -> GCSFileWorkspace:
|
||||
for file_name, file_content in TEST_FILES:
|
||||
gcs_workspace._bucket.blob(str(file_name)).upload_from_string(file_content)
|
||||
gcs_workspace._bucket.blob(
|
||||
str(gcs_workspace.get_path(file_name))
|
||||
).upload_from_string(file_content)
|
||||
yield gcs_workspace # type: ignore
|
||||
|
||||
|
||||
@@ -84,8 +86,24 @@ async def test_read_file(gcs_workspace_with_files: GCSFileWorkspace):
|
||||
|
||||
|
||||
def test_list_files(gcs_workspace_with_files: GCSFileWorkspace):
|
||||
files = gcs_workspace_with_files.list_files()
|
||||
assert set(files) == set(Path(file_name) for file_name, _ in TEST_FILES)
|
||||
# List at root level
|
||||
assert (files := gcs_workspace_with_files.list()) == gcs_workspace_with_files.list()
|
||||
assert len(files) > 0
|
||||
assert set(files) == set(
|
||||
p.relative_to("/") if (p := Path(file_name)).is_absolute() else p
|
||||
for file_name, _ in TEST_FILES
|
||||
)
|
||||
|
||||
# List at nested path
|
||||
assert (
|
||||
nested_files := gcs_workspace_with_files.list("existing")
|
||||
) == gcs_workspace_with_files.list("existing")
|
||||
assert len(nested_files) > 0
|
||||
assert set(nested_files) == set(
|
||||
p
|
||||
for file_name, _ in TEST_FILES
|
||||
if (p := Path(file_name)).is_relative_to("existing")
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -58,8 +58,8 @@ def s3_workspace(s3_workspace_uninitialized: S3FileWorkspace) -> S3FileWorkspace
|
||||
|
||||
TEST_FILES: list[tuple[str | Path, str]] = [
|
||||
("existing_test_file_1", "test content 1"),
|
||||
("existing_test_file_2.txt", "test content 2"),
|
||||
(Path("existing_test_file_3"), "test content 3"),
|
||||
("/existing_test_file_2.txt", "test content 2"),
|
||||
(Path("/existing_test_file_3"), "test content 3"),
|
||||
(Path("existing/test/file/4"), "test content 4"),
|
||||
]
|
||||
|
||||
@@ -67,7 +67,9 @@ TEST_FILES: list[tuple[str | Path, str]] = [
|
||||
@pytest_asyncio.fixture
|
||||
async def s3_workspace_with_files(s3_workspace: S3FileWorkspace) -> S3FileWorkspace:
|
||||
for file_name, file_content in TEST_FILES:
|
||||
s3_workspace._bucket.Object(str(file_name)).put(Body=file_content)
|
||||
s3_workspace._bucket.Object(str(s3_workspace.get_path(file_name))).put(
|
||||
Body=file_content
|
||||
)
|
||||
yield s3_workspace # type: ignore
|
||||
|
||||
|
||||
@@ -82,8 +84,24 @@ async def test_read_file(s3_workspace_with_files: S3FileWorkspace):
|
||||
|
||||
|
||||
def test_list_files(s3_workspace_with_files: S3FileWorkspace):
|
||||
files = s3_workspace_with_files.list_files()
|
||||
assert set(files) == set(Path(file_name) for file_name, _ in TEST_FILES)
|
||||
# List at root level
|
||||
assert (files := s3_workspace_with_files.list()) == s3_workspace_with_files.list()
|
||||
assert len(files) > 0
|
||||
assert set(files) == set(
|
||||
p.relative_to("/") if (p := Path(file_name)).is_absolute() else p
|
||||
for file_name, _ in TEST_FILES
|
||||
)
|
||||
|
||||
# List at nested path
|
||||
assert (
|
||||
nested_files := s3_workspace_with_files.list("existing")
|
||||
) == s3_workspace_with_files.list("existing")
|
||||
assert len(nested_files) > 0
|
||||
assert set(nested_files) == set(
|
||||
p
|
||||
for file_name, _ in TEST_FILES
|
||||
if (p := Path(file_name)).is_relative_to("existing")
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Reference in New Issue
Block a user