fix(agent/file_operations): Fix path processing in file_operations.py and across workspace backend implementations

- Adjusted path processing and use of `agent.workspace` in the file_operations.py module to prevent double path resolution.
- Updated the `is_duplicate_operation` and `log_operation` functions in file_operations.py to use the `make_relative` argument of the `sanitize_path_arg` decorator.
- Refactored the `write_to_file`, `list_folder`, and `list_files` functions in file_operations.py to accept both string and Path objects as the path argument.
- Modified the GCSFileWorkspace and S3FileWorkspace classes in the file_workspace module to ensure that the root path is always an absolute path.

This commit addresses issues with path processing in the file_operations.py module and across different workspace backend implementations. The changes ensure that relative paths are correctly converted to absolute paths where necessary and that the file operations logic functions consistently handle path arguments as strings or Path objects. Additionally, the GCSFileWorkspace and S3FileWorkspace classes now enforce that the root path is always an absolute path.
This commit is contained in:
Reinier van der Leer
2023-12-12 15:29:25 +01:00
parent 3e19da1258
commit 198a0ecad6
8 changed files with 81 additions and 66 deletions

View File

@@ -2,7 +2,6 @@
from __future__ import annotations
import contextlib
import hashlib
import logging
import os
@@ -84,7 +83,7 @@ def file_operations_state(log_path: str | Path) -> dict[str, str]:
return state
@sanitize_path_arg("file_path")
@sanitize_path_arg("file_path", make_relative=True)
def is_duplicate_operation(
operation: Operation, file_path: Path, agent: Agent, checksum: str | None = None
) -> bool:
@@ -99,10 +98,6 @@ def is_duplicate_operation(
Returns:
True if the operation has already been performed on the file
"""
# Make the file path into a relative path if possible
with contextlib.suppress(ValueError):
file_path = file_path.relative_to(agent.workspace.root)
state = file_operations_state(agent.file_manager.file_ops_log_path)
if operation == "delete" and str(file_path) not in state:
return True
@@ -111,9 +106,12 @@ def is_duplicate_operation(
return False
@sanitize_path_arg("file_path")
@sanitize_path_arg("file_path", make_relative=True)
def log_operation(
operation: Operation, file_path: Path, agent: Agent, checksum: str | None = None
operation: Operation,
file_path: str | Path,
agent: Agent,
checksum: str | None = None,
) -> None:
"""Log the file operation to the file_logger.log
@@ -122,10 +120,6 @@ def log_operation(
file_path: The name of the file the operation was performed on
checksum: The checksum of the contents to be written
"""
# Make the file path into a relative path if possible
with contextlib.suppress(ValueError):
file_path = file_path.relative_to(agent.workspace.root)
log_entry = f"{operation}: {file_path}"
if checksum is not None:
log_entry += f" #{checksum}"
@@ -210,8 +204,7 @@ def ingest_file(
},
aliases=["create_file"],
)
@sanitize_path_arg("filename")
async def write_to_file(filename: Path, contents: str, agent: Agent) -> str:
async def write_to_file(filename: str | Path, contents: str, agent: Agent) -> str:
"""Write contents to a file
Args:
@@ -222,14 +215,14 @@ async def write_to_file(filename: Path, contents: str, agent: Agent) -> str:
str: A message indicating success or failure
"""
checksum = text_checksum(contents)
if is_duplicate_operation("write", filename, agent, checksum):
raise DuplicateOperationError(f"File {filename.name} has already been updated.")
if is_duplicate_operation("write", Path(filename), agent, checksum):
raise DuplicateOperationError(f"File {filename} has already been updated.")
directory = os.path.dirname(filename)
os.makedirs(directory, exist_ok=True)
if directory := os.path.dirname(filename):
agent.workspace.get_path(directory).mkdir(exist_ok=True)
await agent.workspace.write_file(filename, contents)
log_operation("write", filename, agent, checksum)
return f"File {filename.name} has been written successfully."
return f"File {filename} has been written successfully."
def append_to_file(
@@ -264,8 +257,7 @@ def append_to_file(
)
},
)
@sanitize_path_arg("folder")
def list_folder(folder: Path, agent: Agent) -> list[str]:
def list_folder(folder: str | Path, agent: Agent) -> list[str]:
"""Lists files in a folder recursively
Args:
@@ -274,15 +266,4 @@ def list_folder(folder: Path, agent: Agent) -> list[str]:
Returns:
list[str]: A list of files found in the folder
"""
found_files = []
for root, _, files in os.walk(folder):
for file in files:
if file.startswith("."):
continue
relative_path = os.path.relpath(
os.path.join(root, file), agent.workspace.root
)
found_files.append(relative_path)
return found_files
return [str(p) for p in agent.workspace.list(folder)]

View File

@@ -16,7 +16,7 @@ def get_workspace(
) -> FileWorkspace:
assert bool(root_path) != bool(id), "Specify root_path or id to get workspace"
if root_path is None:
root_path = Path(f"workspaces/{id}")
root_path = Path(f"/workspaces/{id}")
match backend:
case FileWorkspaceBackendName.LOCAL:

View File

@@ -72,8 +72,8 @@ class FileWorkspace(ABC):
"""Write to a file in the workspace."""
@abstractmethod
def list_files(self, path: str | Path = ".") -> list[Path]:
"""List all files in a directory in the workspace."""
def list(self, path: str | Path = ".") -> list[Path]:
"""List all files (recursively) in a directory in the workspace."""
@abstractmethod
def delete_file(self, path: str | Path) -> None:

View File

@@ -29,6 +29,7 @@ class GCSFileWorkspace(FileWorkspace):
def __init__(self, config: GCSFileWorkspaceConfiguration):
self._bucket_name = config.bucket
self._root = config.root
assert self._root.is_absolute()
self._gcs = storage.Client()
super().__init__()
@@ -47,7 +48,7 @@ class GCSFileWorkspace(FileWorkspace):
self._bucket = self._gcs.get_bucket(self._bucket_name)
def get_path(self, relative_path: str | Path) -> Path:
return super().get_path(relative_path).relative_to(Path("/"))
return super().get_path(relative_path).relative_to("/")
def open_file(self, path: str | Path, mode: str = "r"):
"""Open a file in the workspace."""
@@ -78,11 +79,13 @@ class GCSFileWorkspace(FileWorkspace):
if inspect.isawaitable(res):
await res
def list_files(self, path: str | Path = ".") -> list[Path]:
"""List all files in a directory in the workspace."""
def list(self, path: str | Path = ".") -> list[Path]:
"""List all files (recursively) in a directory in the workspace."""
path = self.get_path(path)
blobs = self._bucket.list_blobs(prefix=str(path))
return [Path(blob.name) for blob in blobs if not blob.name.endswith("/")]
blobs = self._bucket.list_blobs(
prefix=f"{path}/" if path != Path(".") else None
)
return [Path(blob.name).relative_to(path) for blob in blobs]
def delete_file(self, path: str | Path) -> None:
"""Delete a file in the workspace."""

View File

@@ -56,10 +56,10 @@ class LocalFileWorkspace(FileWorkspace):
if inspect.isawaitable(res):
await res
def list_files(self, path: str | Path = "."):
"""List all files in a directory in the workspace."""
full_path = self.get_path(path)
return [str(file) for file in full_path.glob("*") if file.is_file()]
def list(self, path: str | Path = ".") -> list[Path]:
"""List all files (recursively) in a directory in the workspace."""
path = self.get_path(path)
return [file.relative_to(path) for file in path.rglob("*") if file.is_file()]
def delete_file(self, path: str | Path):
"""Delete a file in the workspace."""

View File

@@ -40,6 +40,7 @@ class S3FileWorkspace(FileWorkspace):
def __init__(self, config: S3FileWorkspaceConfiguration):
self._bucket_name = config.bucket
self._root = config.root
assert self._root.is_absolute()
# https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html
self._s3 = boto3.resource(
@@ -71,7 +72,7 @@ class S3FileWorkspace(FileWorkspace):
self._bucket = self._s3.create_bucket(Bucket=self._bucket_name)
def get_path(self, relative_path: str | Path) -> Path:
return super().get_path(relative_path).relative_to(Path("/"))
return super().get_path(relative_path).relative_to("/")
def open_file(self, path: str | Path, mode: str = "r"):
"""Open a file in the workspace."""
@@ -99,20 +100,14 @@ class S3FileWorkspace(FileWorkspace):
if inspect.isawaitable(res):
await res
def list_files(self, path: str | Path = ".") -> list[Path]:
"""List all files in a directory in the workspace."""
def list(self, path: str | Path = ".") -> list[Path]:
"""List all files (recursively) in a directory in the workspace."""
path = self.get_path(path)
if path == Path("."):
return [
Path(obj.key)
for obj in self._bucket.objects.all()
if not obj.key.endswith("/")
]
return [Path(obj.key) for obj in self._bucket.objects.all()]
else:
return [
Path(obj.key)
for obj in self._bucket.objects.filter(Prefix=str(path))
if not obj.key.endswith("/")
Path(obj.key) for obj in self._bucket.objects.filter(Prefix=f"{path}/")
]
def delete_file(self, path: str | Path) -> None:

View File

@@ -60,8 +60,8 @@ def gcs_workspace(gcs_workspace_uninitialized: GCSFileWorkspace) -> GCSFileWorks
TEST_FILES: list[tuple[str | Path, str]] = [
("existing_test_file_1", "test content 1"),
("existing_test_file_2.txt", "test content 2"),
(Path("existing_test_file_3"), "test content 3"),
("/existing_test_file_2.txt", "test content 2"),
(Path("/existing_test_file_3"), "test content 3"),
(Path("existing/test/file/4"), "test content 4"),
]
@@ -69,7 +69,9 @@ TEST_FILES: list[tuple[str | Path, str]] = [
@pytest_asyncio.fixture
async def gcs_workspace_with_files(gcs_workspace: GCSFileWorkspace) -> GCSFileWorkspace:
for file_name, file_content in TEST_FILES:
gcs_workspace._bucket.blob(str(file_name)).upload_from_string(file_content)
gcs_workspace._bucket.blob(
str(gcs_workspace.get_path(file_name))
).upload_from_string(file_content)
yield gcs_workspace # type: ignore
@@ -84,8 +86,24 @@ async def test_read_file(gcs_workspace_with_files: GCSFileWorkspace):
def test_list_files(gcs_workspace_with_files: GCSFileWorkspace):
files = gcs_workspace_with_files.list_files()
assert set(files) == set(Path(file_name) for file_name, _ in TEST_FILES)
# List at root level
assert (files := gcs_workspace_with_files.list()) == gcs_workspace_with_files.list()
assert len(files) > 0
assert set(files) == set(
p.relative_to("/") if (p := Path(file_name)).is_absolute() else p
for file_name, _ in TEST_FILES
)
# List at nested path
assert (
nested_files := gcs_workspace_with_files.list("existing")
) == gcs_workspace_with_files.list("existing")
assert len(nested_files) > 0
assert set(nested_files) == set(
p
for file_name, _ in TEST_FILES
if (p := Path(file_name)).is_relative_to("existing")
)
@pytest.mark.asyncio

View File

@@ -58,8 +58,8 @@ def s3_workspace(s3_workspace_uninitialized: S3FileWorkspace) -> S3FileWorkspace
TEST_FILES: list[tuple[str | Path, str]] = [
("existing_test_file_1", "test content 1"),
("existing_test_file_2.txt", "test content 2"),
(Path("existing_test_file_3"), "test content 3"),
("/existing_test_file_2.txt", "test content 2"),
(Path("/existing_test_file_3"), "test content 3"),
(Path("existing/test/file/4"), "test content 4"),
]
@@ -67,7 +67,9 @@ TEST_FILES: list[tuple[str | Path, str]] = [
@pytest_asyncio.fixture
async def s3_workspace_with_files(s3_workspace: S3FileWorkspace) -> S3FileWorkspace:
for file_name, file_content in TEST_FILES:
s3_workspace._bucket.Object(str(file_name)).put(Body=file_content)
s3_workspace._bucket.Object(str(s3_workspace.get_path(file_name))).put(
Body=file_content
)
yield s3_workspace # type: ignore
@@ -82,8 +84,24 @@ async def test_read_file(s3_workspace_with_files: S3FileWorkspace):
def test_list_files(s3_workspace_with_files: S3FileWorkspace):
files = s3_workspace_with_files.list_files()
assert set(files) == set(Path(file_name) for file_name, _ in TEST_FILES)
# List at root level
assert (files := s3_workspace_with_files.list()) == s3_workspace_with_files.list()
assert len(files) > 0
assert set(files) == set(
p.relative_to("/") if (p := Path(file_name)).is_absolute() else p
for file_name, _ in TEST_FILES
)
# List at nested path
assert (
nested_files := s3_workspace_with_files.list("existing")
) == s3_workspace_with_files.list("existing")
assert len(nested_files) > 0
assert set(nested_files) == set(
p
for file_name, _ in TEST_FILES
if (p := Path(file_name)).is_relative_to("existing")
)
@pytest.mark.asyncio