mirror of
https://github.com/aljazceru/Auto-GPT.git
synced 2026-02-23 07:04:24 +01:00
fix(agent/file_workspace): Fix workspace initialization and listing behavior in GCS and S3 file workspaces
- Update GCSFileWorkspace.initialize() to handle cases where the bucket doesn't exist and create it if necessary
- Add logging to S3FileWorkspace.initialize() and GCSFileWorkspace.initialize()
- Update GCSFileWorkspace.list() and S3FileWorkspace.list() to correctly handle nested paths and return the relative paths of files
- Fix tests for GCSFileWorkspace and S3FileWorkspace to account for the changes in initialization and listing behavior
- Fix S3FileWorkspace.open_file() to correctly switch between binary and text mode
- Add tests to verify the fixes in workspace initialization and listing behavior
This commit is contained in:
@@ -10,6 +10,7 @@ from io import IOBase
|
||||
from pathlib import Path
|
||||
|
||||
from google.cloud import storage
|
||||
from google.cloud.exceptions import NotFound
|
||||
|
||||
from autogpt.core.configuration.schema import UserConfigurable
|
||||
|
||||
@@ -46,7 +47,12 @@ class GCSFileWorkspace(FileWorkspace):
|
||||
return True
|
||||
|
||||
def initialize(self) -> None:
|
||||
self._bucket = self._gcs.get_bucket(self._bucket_name)
|
||||
logger.debug(f"Initializing {repr(self)}...")
|
||||
try:
|
||||
self._bucket = self._gcs.get_bucket(self._bucket_name)
|
||||
except NotFound:
|
||||
logger.info(f"Bucket '{self._bucket_name}' does not exist; creating it...")
|
||||
self._bucket = self._gcs.create_bucket(self._bucket_name)
|
||||
|
||||
def get_path(self, relative_path: str | Path) -> Path:
|
||||
return super().get_path(relative_path).relative_to("/")
|
||||
@@ -85,13 +91,18 @@ class GCSFileWorkspace(FileWorkspace):
|
||||
def list(self, path: str | Path = ".") -> list[Path]:
|
||||
"""List all files (recursively) in a directory in the workspace."""
|
||||
path = self.get_path(path)
|
||||
blobs = self._bucket.list_blobs(
|
||||
prefix=f"{path}/" if path != Path(".") else None
|
||||
)
|
||||
return [Path(blob.name).relative_to(path) for blob in blobs]
|
||||
return [
|
||||
Path(blob.name).relative_to(path)
|
||||
for blob in self._bucket.list_blobs(
|
||||
prefix=f"{path}/" if path != Path(".") else None
|
||||
)
|
||||
]
|
||||
|
||||
def delete_file(self, path: str | Path) -> None:
|
||||
"""Delete a file in the workspace."""
|
||||
path = self.get_path(path)
|
||||
blob = self._bucket.blob(str(path))
|
||||
blob.delete()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{__class__.__name__}(bucket='{self._bucket_name}', root={self._root})"
|
||||
|
||||
@@ -64,12 +64,14 @@ class S3FileWorkspace(FileWorkspace):
|
||||
return True
|
||||
|
||||
def initialize(self) -> None:
|
||||
logger.debug(f"Initializing {repr(self)}...")
|
||||
try:
|
||||
self._s3.meta.client.head_bucket(Bucket=self._bucket_name)
|
||||
self._bucket = self._s3.Bucket(self._bucket_name)
|
||||
except botocore.exceptions.ClientError as e:
|
||||
if "(404)" not in str(e):
|
||||
raise
|
||||
logger.info(f"Bucket '{self._bucket_name}' does not exist; creating it...")
|
||||
self._bucket = self._s3.create_bucket(Bucket=self._bucket_name)
|
||||
|
||||
def get_path(self, relative_path: str | Path) -> Path:
|
||||
@@ -86,12 +88,11 @@ class S3FileWorkspace(FileWorkspace):
|
||||
def open_file(self, path: str | Path, binary: bool = False) -> IOBase:
|
||||
"""Open a file in the workspace."""
|
||||
obj = self._get_obj(path)
|
||||
return obj.get()["Body"] if not binary else TextIOWrapper(obj.get()["Body"])
|
||||
return obj.get()["Body"] if binary else TextIOWrapper(obj.get()["Body"])
|
||||
|
||||
def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
|
||||
"""Read a file in the workspace."""
|
||||
file_content = self.open_file(path, binary).read()
|
||||
return file_content if binary else file_content.decode()
|
||||
return self.open_file(path, binary).read()
|
||||
|
||||
async def write_file(self, path: str | Path, content: str | bytes) -> None:
|
||||
"""Write to a file in the workspace."""
|
||||
@@ -109,11 +110,12 @@ class S3FileWorkspace(FileWorkspace):
|
||||
def list(self, path: str | Path = ".") -> list[Path]:
|
||||
"""List all files (recursively) in a directory in the workspace."""
|
||||
path = self.get_path(path)
|
||||
if path == Path("."):
|
||||
if path == Path("."): # root level of bucket
|
||||
return [Path(obj.key) for obj in self._bucket.objects.all()]
|
||||
else:
|
||||
return [
|
||||
Path(obj.key) for obj in self._bucket.objects.filter(Prefix=f"{path}/")
|
||||
Path(obj.key).relative_to(path)
|
||||
for obj in self._bucket.objects.filter(Prefix=f"{path}/")
|
||||
]
|
||||
|
||||
def delete_file(self, path: str | Path) -> None:
|
||||
@@ -121,3 +123,6 @@ class S3FileWorkspace(FileWorkspace):
|
||||
path = self.get_path(path)
|
||||
obj = self._s3.Object(self._bucket_name, str(path))
|
||||
obj.delete()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{__class__.__name__}(bucket='{self._bucket_name}', root={self._root})"
|
||||
|
||||
@@ -4,12 +4,16 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from google.auth.exceptions import GoogleAuthError
|
||||
from google.cloud import storage
|
||||
from google.cloud.exceptions import NotFound
|
||||
|
||||
from autogpt.file_workspace.gcs import GCSFileWorkspace, GCSFileWorkspaceConfiguration
|
||||
|
||||
if not os.getenv("GOOGLE_APPLICATION_CREDENTIALS"):
|
||||
pytest.skip("GOOGLE_APPLICATION_CREDENTIALS are not set", allow_module_level=True)
|
||||
try:
|
||||
storage.Client()
|
||||
except GoogleAuthError:
|
||||
pytest.skip("Google Cloud Authentication not configured", allow_module_level=True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -21,6 +25,7 @@ def gcs_bucket_name() -> str:
|
||||
def gcs_workspace_uninitialized(gcs_bucket_name: str) -> GCSFileWorkspace:
|
||||
os.environ["WORKSPACE_STORAGE_BUCKET"] = gcs_bucket_name
|
||||
ws_config = GCSFileWorkspaceConfiguration.from_env()
|
||||
ws_config.root = Path("/workspaces/AutoGPT-some-unique-task-id")
|
||||
workspace = GCSFileWorkspace(ws_config)
|
||||
yield workspace # type: ignore
|
||||
del os.environ["WORKSPACE_STORAGE_BUCKET"]
|
||||
@@ -29,16 +34,28 @@ def gcs_workspace_uninitialized(gcs_bucket_name: str) -> GCSFileWorkspace:
|
||||
def test_initialize(
|
||||
gcs_bucket_name: str, gcs_workspace_uninitialized: GCSFileWorkspace
|
||||
):
|
||||
gcs = gcs_workspace_uninitialized._bucket
|
||||
gcs = gcs_workspace_uninitialized._gcs
|
||||
|
||||
# test that the bucket doesn't exist yet
|
||||
with pytest.raises(NotFound):
|
||||
gcs.get_blob(gcs_bucket_name)
|
||||
gcs.get_bucket(gcs_bucket_name)
|
||||
|
||||
gcs_workspace_uninitialized.initialize()
|
||||
|
||||
# test that the bucket has been created
|
||||
gcs.get_blob(gcs_bucket_name)
|
||||
bucket = gcs.get_bucket(gcs_bucket_name)
|
||||
|
||||
# clean up
|
||||
bucket.delete(force=True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gcs_workspace(gcs_workspace_uninitialized: GCSFileWorkspace) -> GCSFileWorkspace:
|
||||
(gcs_workspace := gcs_workspace_uninitialized).initialize()
|
||||
yield gcs_workspace # type: ignore
|
||||
|
||||
# Empty & delete the test bucket
|
||||
gcs_workspace._bucket.delete(force=True)
|
||||
|
||||
|
||||
def test_workspace_bucket_name(
|
||||
@@ -48,21 +65,12 @@ def test_workspace_bucket_name(
|
||||
assert gcs_workspace._bucket.name == gcs_bucket_name
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gcs_workspace(gcs_workspace_uninitialized: GCSFileWorkspace) -> GCSFileWorkspace:
|
||||
(gcs_workspace := gcs_workspace_uninitialized).initialize()
|
||||
yield gcs_workspace # type: ignore
|
||||
|
||||
# Empty & delete the test bucket
|
||||
gcs_workspace._bucket.delete_blobs(gcs_workspace._bucket.list_blobs())
|
||||
gcs_workspace._bucket.delete()
|
||||
|
||||
|
||||
NESTED_DIR = "existing/test/dir"
|
||||
TEST_FILES: list[tuple[str | Path, str]] = [
|
||||
("existing_test_file_1", "test content 1"),
|
||||
("/existing_test_file_2.txt", "test content 2"),
|
||||
(Path("/existing_test_file_3"), "test content 3"),
|
||||
(Path("existing/test/file/4"), "test content 4"),
|
||||
("existing_test_file_2.txt", "test content 2"),
|
||||
(Path("existing_test_file_3"), "test content 3"),
|
||||
(Path(f"{NESTED_DIR}/test/file/4"), "test content 4"),
|
||||
]
|
||||
|
||||
|
||||
@@ -89,20 +97,17 @@ def test_list_files(gcs_workspace_with_files: GCSFileWorkspace):
|
||||
# List at root level
|
||||
assert (files := gcs_workspace_with_files.list()) == gcs_workspace_with_files.list()
|
||||
assert len(files) > 0
|
||||
assert set(files) == set(
|
||||
p.relative_to("/") if (p := Path(file_name)).is_absolute() else p
|
||||
for file_name, _ in TEST_FILES
|
||||
)
|
||||
assert set(files) == set(Path(file_name) for file_name, _ in TEST_FILES)
|
||||
|
||||
# List at nested path
|
||||
assert (
|
||||
nested_files := gcs_workspace_with_files.list("existing")
|
||||
) == gcs_workspace_with_files.list("existing")
|
||||
nested_files := gcs_workspace_with_files.list(NESTED_DIR)
|
||||
) == gcs_workspace_with_files.list(NESTED_DIR)
|
||||
assert len(nested_files) > 0
|
||||
assert set(nested_files) == set(
|
||||
p
|
||||
p.relative_to(NESTED_DIR)
|
||||
for file_name, _ in TEST_FILES
|
||||
if (p := Path(file_name)).is_relative_to("existing")
|
||||
if (p := Path(file_name)).is_relative_to(NESTED_DIR)
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ def s3_bucket_name() -> str:
|
||||
def s3_workspace_uninitialized(s3_bucket_name: str) -> S3FileWorkspace:
|
||||
os.environ["WORKSPACE_STORAGE_BUCKET"] = s3_bucket_name
|
||||
ws_config = S3FileWorkspaceConfiguration.from_env()
|
||||
ws_config.root = Path("/workspaces/AutoGPT-some-unique-task-id")
|
||||
workspace = S3FileWorkspace(ws_config)
|
||||
yield workspace # type: ignore
|
||||
del os.environ["WORKSPACE_STORAGE_BUCKET"]
|
||||
@@ -56,11 +57,12 @@ def s3_workspace(s3_workspace_uninitialized: S3FileWorkspace) -> S3FileWorkspace
|
||||
s3_workspace._bucket.delete()
|
||||
|
||||
|
||||
NESTED_DIR = "existing/test/dir"
|
||||
TEST_FILES: list[tuple[str | Path, str]] = [
|
||||
("existing_test_file_1", "test content 1"),
|
||||
("/existing_test_file_2.txt", "test content 2"),
|
||||
(Path("/existing_test_file_3"), "test content 3"),
|
||||
(Path("existing/test/file/4"), "test content 4"),
|
||||
("existing_test_file_2.txt", "test content 2"),
|
||||
(Path("existing_test_file_3"), "test content 3"),
|
||||
(Path(f"{NESTED_DIR}/test/file/4"), "test content 4"),
|
||||
]
|
||||
|
||||
|
||||
@@ -87,20 +89,17 @@ def test_list_files(s3_workspace_with_files: S3FileWorkspace):
|
||||
# List at root level
|
||||
assert (files := s3_workspace_with_files.list()) == s3_workspace_with_files.list()
|
||||
assert len(files) > 0
|
||||
assert set(files) == set(
|
||||
p.relative_to("/") if (p := Path(file_name)).is_absolute() else p
|
||||
for file_name, _ in TEST_FILES
|
||||
)
|
||||
assert set(files) == set(Path(file_name) for file_name, _ in TEST_FILES)
|
||||
|
||||
# List at nested path
|
||||
assert (
|
||||
nested_files := s3_workspace_with_files.list("existing")
|
||||
) == s3_workspace_with_files.list("existing")
|
||||
nested_files := s3_workspace_with_files.list(NESTED_DIR)
|
||||
) == s3_workspace_with_files.list(NESTED_DIR)
|
||||
assert len(nested_files) > 0
|
||||
assert set(nested_files) == set(
|
||||
p
|
||||
p.relative_to(NESTED_DIR)
|
||||
for file_name, _ in TEST_FILES
|
||||
if (p := Path(file_name)).is_relative_to("existing")
|
||||
if (p := Path(file_name)).is_relative_to(NESTED_DIR)
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user