fix(agent/file_workspace): Fix workspace initialization and listing behavior in GCS and S3 file workspaces

- Update GCSFileWorkspace.initialize() to handle cases where the bucket doesn't exist and create it if necessary
- Add logging to S3FileWorkspace.initialize() and GCSFileWorkspace.initialize()
- Update GCSFileWorkspace.list() and S3FileWorkspace.list() to correctly handle nested paths and return the relative paths of files
- Fix tests for GCSFileWorkspace and S3FileWorkspace to account for the changes in initialization and listing behavior
- Fix S3FileWorkspace.open_file() to correctly switch between binary and text mode
- Add tests to verify the fixes in workspace initialization and listing behavior
This commit is contained in:
Reinier van der Leer
2023-12-13 18:07:25 +01:00
parent d820239a7c
commit 967338193e
4 changed files with 67 additions and 47 deletions

View File

@@ -10,6 +10,7 @@ from io import IOBase
from pathlib import Path
from google.cloud import storage
from google.cloud.exceptions import NotFound
from autogpt.core.configuration.schema import UserConfigurable
@@ -46,7 +47,12 @@ class GCSFileWorkspace(FileWorkspace):
return True
def initialize(self) -> None:
self._bucket = self._gcs.get_bucket(self._bucket_name)
logger.debug(f"Initializing {repr(self)}...")
try:
self._bucket = self._gcs.get_bucket(self._bucket_name)
except NotFound:
logger.info(f"Bucket '{self._bucket_name}' does not exist; creating it...")
self._bucket = self._gcs.create_bucket(self._bucket_name)
def get_path(self, relative_path: str | Path) -> Path:
return super().get_path(relative_path).relative_to("/")
@@ -85,13 +91,18 @@ class GCSFileWorkspace(FileWorkspace):
def list(self, path: str | Path = ".") -> list[Path]:
"""List all files (recursively) in a directory in the workspace."""
path = self.get_path(path)
blobs = self._bucket.list_blobs(
prefix=f"{path}/" if path != Path(".") else None
)
return [Path(blob.name).relative_to(path) for blob in blobs]
return [
Path(blob.name).relative_to(path)
for blob in self._bucket.list_blobs(
prefix=f"{path}/" if path != Path(".") else None
)
]
def delete_file(self, path: str | Path) -> None:
"""Delete a file in the workspace."""
path = self.get_path(path)
blob = self._bucket.blob(str(path))
blob.delete()
def __repr__(self) -> str:
return f"{__class__.__name__}(bucket='{self._bucket_name}', root={self._root})"

View File

@@ -64,12 +64,14 @@ class S3FileWorkspace(FileWorkspace):
return True
def initialize(self) -> None:
logger.debug(f"Initializing {repr(self)}...")
try:
self._s3.meta.client.head_bucket(Bucket=self._bucket_name)
self._bucket = self._s3.Bucket(self._bucket_name)
except botocore.exceptions.ClientError as e:
if "(404)" not in str(e):
raise
logger.info(f"Bucket '{self._bucket_name}' does not exist; creating it...")
self._bucket = self._s3.create_bucket(Bucket=self._bucket_name)
def get_path(self, relative_path: str | Path) -> Path:
@@ -86,12 +88,11 @@ class S3FileWorkspace(FileWorkspace):
def open_file(self, path: str | Path, binary: bool = False) -> IOBase:
"""Open a file in the workspace."""
obj = self._get_obj(path)
return obj.get()["Body"] if not binary else TextIOWrapper(obj.get()["Body"])
return obj.get()["Body"] if binary else TextIOWrapper(obj.get()["Body"])
def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
"""Read a file in the workspace."""
file_content = self.open_file(path, binary).read()
return file_content if binary else file_content.decode()
return self.open_file(path, binary).read()
async def write_file(self, path: str | Path, content: str | bytes) -> None:
"""Write to a file in the workspace."""
@@ -109,11 +110,12 @@ class S3FileWorkspace(FileWorkspace):
def list(self, path: str | Path = ".") -> list[Path]:
"""List all files (recursively) in a directory in the workspace."""
path = self.get_path(path)
if path == Path("."):
if path == Path("."): # root level of bucket
return [Path(obj.key) for obj in self._bucket.objects.all()]
else:
return [
Path(obj.key) for obj in self._bucket.objects.filter(Prefix=f"{path}/")
Path(obj.key).relative_to(path)
for obj in self._bucket.objects.filter(Prefix=f"{path}/")
]
def delete_file(self, path: str | Path) -> None:
@@ -121,3 +123,6 @@ class S3FileWorkspace(FileWorkspace):
path = self.get_path(path)
obj = self._s3.Object(self._bucket_name, str(path))
obj.delete()
def __repr__(self) -> str:
return f"{__class__.__name__}(bucket='{self._bucket_name}', root={self._root})"

View File

@@ -4,12 +4,16 @@ from pathlib import Path
import pytest
import pytest_asyncio
from google.auth.exceptions import GoogleAuthError
from google.cloud import storage
from google.cloud.exceptions import NotFound
from autogpt.file_workspace.gcs import GCSFileWorkspace, GCSFileWorkspaceConfiguration
if not os.getenv("GOOGLE_APPLICATION_CREDENTIALS"):
pytest.skip("GOOGLE_APPLICATION_CREDENTIALS are not set", allow_module_level=True)
try:
storage.Client()
except GoogleAuthError:
pytest.skip("Google Cloud Authentication not configured", allow_module_level=True)
@pytest.fixture
@@ -21,6 +25,7 @@ def gcs_bucket_name() -> str:
def gcs_workspace_uninitialized(gcs_bucket_name: str) -> GCSFileWorkspace:
os.environ["WORKSPACE_STORAGE_BUCKET"] = gcs_bucket_name
ws_config = GCSFileWorkspaceConfiguration.from_env()
ws_config.root = Path("/workspaces/AutoGPT-some-unique-task-id")
workspace = GCSFileWorkspace(ws_config)
yield workspace # type: ignore
del os.environ["WORKSPACE_STORAGE_BUCKET"]
@@ -29,16 +34,28 @@ def gcs_workspace_uninitialized(gcs_bucket_name: str) -> GCSFileWorkspace:
def test_initialize(
gcs_bucket_name: str, gcs_workspace_uninitialized: GCSFileWorkspace
):
gcs = gcs_workspace_uninitialized._bucket
gcs = gcs_workspace_uninitialized._gcs
# test that the bucket doesn't exist yet
with pytest.raises(NotFound):
gcs.get_blob(gcs_bucket_name)
gcs.get_bucket(gcs_bucket_name)
gcs_workspace_uninitialized.initialize()
# test that the bucket has been created
gcs.get_blob(gcs_bucket_name)
bucket = gcs.get_bucket(gcs_bucket_name)
# clean up
bucket.delete(force=True)
@pytest.fixture
def gcs_workspace(gcs_workspace_uninitialized: GCSFileWorkspace) -> GCSFileWorkspace:
(gcs_workspace := gcs_workspace_uninitialized).initialize()
yield gcs_workspace # type: ignore
# Empty & delete the test bucket
gcs_workspace._bucket.delete(force=True)
def test_workspace_bucket_name(
@@ -48,21 +65,12 @@ def test_workspace_bucket_name(
assert gcs_workspace._bucket.name == gcs_bucket_name
@pytest.fixture
def gcs_workspace(gcs_workspace_uninitialized: GCSFileWorkspace) -> GCSFileWorkspace:
(gcs_workspace := gcs_workspace_uninitialized).initialize()
yield gcs_workspace # type: ignore
# Empty & delete the test bucket
gcs_workspace._bucket.delete_blobs(gcs_workspace._bucket.list_blobs())
gcs_workspace._bucket.delete()
NESTED_DIR = "existing/test/dir"
TEST_FILES: list[tuple[str | Path, str]] = [
("existing_test_file_1", "test content 1"),
("/existing_test_file_2.txt", "test content 2"),
(Path("/existing_test_file_3"), "test content 3"),
(Path("existing/test/file/4"), "test content 4"),
("existing_test_file_2.txt", "test content 2"),
(Path("existing_test_file_3"), "test content 3"),
(Path(f"{NESTED_DIR}/test/file/4"), "test content 4"),
]
@@ -89,20 +97,17 @@ def test_list_files(gcs_workspace_with_files: GCSFileWorkspace):
# List at root level
assert (files := gcs_workspace_with_files.list()) == gcs_workspace_with_files.list()
assert len(files) > 0
assert set(files) == set(
p.relative_to("/") if (p := Path(file_name)).is_absolute() else p
for file_name, _ in TEST_FILES
)
assert set(files) == set(Path(file_name) for file_name, _ in TEST_FILES)
# List at nested path
assert (
nested_files := gcs_workspace_with_files.list("existing")
) == gcs_workspace_with_files.list("existing")
nested_files := gcs_workspace_with_files.list(NESTED_DIR)
) == gcs_workspace_with_files.list(NESTED_DIR)
assert len(nested_files) > 0
assert set(nested_files) == set(
p
p.relative_to(NESTED_DIR)
for file_name, _ in TEST_FILES
if (p := Path(file_name)).is_relative_to("existing")
if (p := Path(file_name)).is_relative_to(NESTED_DIR)
)

View File

@@ -21,6 +21,7 @@ def s3_bucket_name() -> str:
def s3_workspace_uninitialized(s3_bucket_name: str) -> S3FileWorkspace:
os.environ["WORKSPACE_STORAGE_BUCKET"] = s3_bucket_name
ws_config = S3FileWorkspaceConfiguration.from_env()
ws_config.root = Path("/workspaces/AutoGPT-some-unique-task-id")
workspace = S3FileWorkspace(ws_config)
yield workspace # type: ignore
del os.environ["WORKSPACE_STORAGE_BUCKET"]
@@ -56,11 +57,12 @@ def s3_workspace(s3_workspace_uninitialized: S3FileWorkspace) -> S3FileWorkspace
s3_workspace._bucket.delete()
NESTED_DIR = "existing/test/dir"
TEST_FILES: list[tuple[str | Path, str]] = [
("existing_test_file_1", "test content 1"),
("/existing_test_file_2.txt", "test content 2"),
(Path("/existing_test_file_3"), "test content 3"),
(Path("existing/test/file/4"), "test content 4"),
("existing_test_file_2.txt", "test content 2"),
(Path("existing_test_file_3"), "test content 3"),
(Path(f"{NESTED_DIR}/test/file/4"), "test content 4"),
]
@@ -87,20 +89,17 @@ def test_list_files(s3_workspace_with_files: S3FileWorkspace):
# List at root level
assert (files := s3_workspace_with_files.list()) == s3_workspace_with_files.list()
assert len(files) > 0
assert set(files) == set(
p.relative_to("/") if (p := Path(file_name)).is_absolute() else p
for file_name, _ in TEST_FILES
)
assert set(files) == set(Path(file_name) for file_name, _ in TEST_FILES)
# List at nested path
assert (
nested_files := s3_workspace_with_files.list("existing")
) == s3_workspace_with_files.list("existing")
nested_files := s3_workspace_with_files.list(NESTED_DIR)
) == s3_workspace_with_files.list(NESTED_DIR)
assert len(nested_files) > 0
assert set(nested_files) == set(
p
p.relative_to(NESTED_DIR)
for file_name, _ in TEST_FILES
if (p := Path(file_name)).is_relative_to("existing")
if (p := Path(file_name)).is_relative_to(NESTED_DIR)
)