Mirror of https://github.com/aljazceru/Auto-GPT.git, synced 2025-12-19 06:54:22 +01:00
feat(agent): Fully abstracted file storage access with FileStorage (#6931)
* Rename `FileWorkspace` to `FileStorage`
  - `autogpt.file_workspace` -> `autogpt.file_storage`
  - `LocalFileWorkspace` -> `LocalFileStorage`
  - `S3FileWorkspace` -> `S3FileStorage`
  - `GCSFileWorkspace` -> `GCSFileStorage`
* Rename `WORKSPACE_BACKEND` to `FILE_STORAGE_BACKEND`
* Rename `WORKSPACE_STORAGE_BUCKET` to `STORAGE_BUCKET`
* Rewrite `AgentManager` to use `FileStorage` rather than direct local file access
* Rename `AgentManager.retrieve_state(..)` method to `load_agent_state`
* Add docstrings to `AgentManager`
* Create `AgentFileManagerMixin` to replace `AgentFileManager`, `FileWorkspaceMixin`, `BaseAgent.attach_fs(..)`
* Replace `BaseAgentSettings.save_to_json_file(..)` method by `AgentFileManagerMixin.save_state()`
* Replace `BaseAgent.set_id(..)` method by `AgentFileManagerMixin.change_agent_id(..)`
* Remove `BaseAgentSettings.load_from_json_file(..)`
* Remove `AgentSettings.agent_data_dir`
* Update `AgentProtocolServer` to work with the new `FileStorage` system and `AgentFileManagerMixin`
* Make `agent_id` and `file_storage` parameters for creating an Agent:
  - `create_agent`, `configure_agent_with_state`, `_configure_agent`, `create_agent_state` in `autogpt.agent_factory.configurators`
  - `generate_agent_for_task` in `autogpt.agent_factory.generators`
  - `Agent.__init__(..)`
  - `BaseAgent.__init__(..)`
  - Initialize and pass in `file_storage` in `autogpt.app.main.run_auto_gpt(..)` and `autogpt.app.main.run_auto_gpt_server(..)`
* Add `clone_with_subroot` to `FileStorage`
* Add `exists`, `make_dir`, `delete_dir`, `rename`, `list_files`, `list_folders` methods to `FileStorage`
* Update `autogpt.commands.file_operations` to use `FileStorage` and `AgentFileManagerMixin` features
* Update tests for `FileStorage` implementations and usages
* Rename `workspace` fixture to `storage`
* Update conftest.py
Committed via GitHub. Parent: 6c18627b0f. Commit: 37904a0f80
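For orientation, here is a minimal sketch of the storage flow these changes introduce, pieced together from the diff below. The agent ID is a made-up example and error handling is elided; this is not part of the commit itself.

```python
# Minimal sketch of the new FileStorage flow, based on the code in this diff.
# "AutoGPT-1a2b3c4d" is a hypothetical agent ID of the form generate_id() produces.
from autogpt.file_storage import FileStorageBackendName, get_storage

# One storage backend (local, gcs, or s3) is initialized at app startup...
file_storage = get_storage(
    FileStorageBackendName.LOCAL, root_path="data", restrict_to_root=True
)
file_storage.initialize()

# ...and each agent then gets scoped sub-storages instead of raw filesystem paths:
agent_files = file_storage.clone_with_subroot("agents/AutoGPT-1a2b3c4d/")
workspace = file_storage.clone_with_subroot("agents/AutoGPT-1a2b3c4d/workspace")
```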
@@ -20,12 +20,12 @@ OPENAI_API_KEY=your-openai-api-key
## DISABLED_COMMAND_CATEGORIES - The list of categories of commands that are disabled (Default: None)
# DISABLED_COMMAND_CATEGORIES=

## WORKSPACE_BACKEND - Choose a storage backend for workspace contents
## FILE_STORAGE_BACKEND - Choose a storage backend for contents
## Options: local, gcs, s3
# WORKSPACE_BACKEND=local
# FILE_STORAGE_BACKEND=local

## WORKSPACE_STORAGE_BUCKET - GCS/S3 Bucket to store workspace contents in
# WORKSPACE_STORAGE_BUCKET=autogpt
## STORAGE_BUCKET - GCS/S3 Bucket to store contents in
# STORAGE_BUCKET=autogpt

## GCS Credentials
# see https://cloud.google.com/storage/docs/authentication#libauth
@@ -3,10 +3,12 @@ import logging
import sys
from pathlib import Path

from autogpt.agent_manager.agent_manager import AgentManager
from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings
from autogpt.app.main import _configure_openai_provider, run_interaction_loop
from autogpt.commands import COMMAND_CATEGORIES
from autogpt.config import AIProfile, ConfigBuilder
from autogpt.file_storage import FileStorageBackendName, get_storage
from autogpt.logs.config import configure_logging
from autogpt.models.command_registry import CommandRegistry

@@ -42,6 +44,7 @@ def bootstrap_agent(task: str, continuous_mode: bool) -> Agent:
    agent_prompt_config.use_functions_api = config.openai_functions
    agent_settings = AgentSettings(
        name=Agent.default_settings.name,
        agent_id=AgentManager.generate_id("AutoGPT-benchmark"),
        description=Agent.default_settings.description,
        ai_profile=ai_profile,
        config=AgentConfiguration(
@@ -55,13 +58,20 @@ def bootstrap_agent(task: str, continuous_mode: bool) -> Agent:
        history=Agent.default_settings.history.copy(deep=True),
    )

    local = config.file_storage_backend == FileStorageBackendName.LOCAL
    restrict_to_root = not local or config.restrict_to_workspace
    file_storage = get_storage(
        config.file_storage_backend, root_path="data", restrict_to_root=restrict_to_root
    )
    file_storage.initialize()

    agent = Agent(
        settings=agent_settings,
        llm_provider=_configure_openai_provider(config),
        command_registry=command_registry,
        file_storage=file_storage,
        legacy_config=config,
    )
    agent.attach_fs(config.app_data_dir / "agents" / "AutoGPT-benchmark")  # HACK
    return agent
@@ -1,19 +1,21 @@
from typing import Optional

from autogpt.agent_manager import AgentManager
from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings
from autogpt.commands import COMMAND_CATEGORIES
from autogpt.config import AIDirectives, AIProfile, Config
from autogpt.core.resource.model_providers import ChatModelProvider
from autogpt.file_storage.base import FileStorage
from autogpt.logs.config import configure_chat_plugins
from autogpt.models.command_registry import CommandRegistry
from autogpt.plugins import scan_plugins


def create_agent(
    agent_id: str,
    task: str,
    ai_profile: AIProfile,
    app_config: Config,
    file_storage: FileStorage,
    llm_provider: ChatModelProvider,
    directives: Optional[AIDirectives] = None,
) -> Agent:
@@ -23,26 +25,28 @@ def create_agent(
        directives = AIDirectives.from_file(app_config.prompt_settings_file)

    agent = _configure_agent(
        agent_id=agent_id,
        task=task,
        ai_profile=ai_profile,
        directives=directives,
        app_config=app_config,
        file_storage=file_storage,
        llm_provider=llm_provider,
    )

    agent.state.agent_id = AgentManager.generate_id(agent.ai_profile.ai_name)

    return agent


def configure_agent_with_state(
    state: AgentSettings,
    app_config: Config,
    file_storage: FileStorage,
    llm_provider: ChatModelProvider,
) -> Agent:
    return _configure_agent(
        state=state,
        app_config=app_config,
        file_storage=file_storage,
        llm_provider=llm_provider,
    )

@@ -50,14 +54,17 @@ def configure_agent_with_state(
def _configure_agent(
    app_config: Config,
    llm_provider: ChatModelProvider,
    file_storage: FileStorage,
    agent_id: str = "",
    task: str = "",
    ai_profile: Optional[AIProfile] = None,
    directives: Optional[AIDirectives] = None,
    state: Optional[AgentSettings] = None,
) -> Agent:
    if not (state or task and ai_profile and directives):
    if not (state or agent_id and task and ai_profile and directives):
        raise TypeError(
            "Either (state) or (task, ai_profile, directives) must be specified"
            "Either (state) or (agent_id, task, ai_profile, directives)"
            " must be specified"
        )

    app_config.plugins = scan_plugins(app_config)
@@ -70,6 +77,7 @@ def _configure_agent(
    )

    agent_state = state or create_agent_state(
        agent_id=agent_id,
        task=task,
        ai_profile=ai_profile,
        directives=directives,
@@ -82,11 +90,13 @@ def _configure_agent(
        settings=agent_state,
        llm_provider=llm_provider,
        command_registry=command_registry,
        file_storage=file_storage,
        legacy_config=app_config,
    )


def create_agent_state(
    agent_id: str,
    task: str,
    ai_profile: AIProfile,
    directives: AIDirectives,
@@ -96,6 +106,7 @@ def create_agent_state(
    agent_prompt_config.use_functions_api = app_config.openai_functions

    return AgentSettings(
        agent_id=agent_id,
        name=Agent.default_settings.name,
        description=Agent.default_settings.description,
        task=task,
@@ -1,21 +1,26 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from autogpt.config.ai_directives import AIDirectives
from autogpt.file_storage.base import FileStorage

from .configurators import _configure_agent
from .profile_generator import generate_agent_profile_for_task

if TYPE_CHECKING:
    from autogpt.agents.agent import Agent
    from autogpt.config import Config
    from autogpt.core.resource.model_providers.schema import ChatModelProvider

from autogpt.config.ai_directives import AIDirectives

from .configurators import _configure_agent
from .profile_generator import generate_agent_profile_for_task


async def generate_agent_for_task(
    agent_id: str,
    task: str,
    app_config: "Config",
    llm_provider: "ChatModelProvider",
) -> "Agent":
    app_config: Config,
    file_storage: FileStorage,
    llm_provider: ChatModelProvider,
) -> Agent:
    base_directives = AIDirectives.from_file(app_config.prompt_settings_file)
    ai_profile, task_directives = await generate_agent_profile_for_task(
        task=task,
@@ -23,9 +28,11 @@ async def generate_agent_for_task(
        llm_provider=llm_provider,
    )
    return _configure_agent(
        agent_id=agent_id,
        task=task,
        ai_profile=ai_profile,
        directives=base_directives + task_directives,
        app_config=app_config,
        file_storage=file_storage,
        llm_provider=llm_provider,
    )
@@ -2,47 +2,44 @@ from __future__ import annotations

import uuid
from pathlib import Path
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from autogpt.agents.agent import AgentSettings

from autogpt.agents.utils.agent_file_manager import AgentFileManager
    from autogpt.agents.agent import AgentSettings
from autogpt.file_storage.base import FileStorage


class AgentManager:
    def __init__(self, app_data_dir: Path):
        self.agents_dir = app_data_dir / "agents"
        if not self.agents_dir.exists():
            self.agents_dir.mkdir()
    def __init__(self, file_storage: FileStorage):
        self.file_manager = file_storage.clone_with_subroot("agents")

    @staticmethod
    def generate_id(agent_name: str) -> str:
        """Generate a unique ID for an agent given agent name."""
        unique_id = str(uuid.uuid4())[:8]
        return f"{agent_name}-{unique_id}"

    def list_agents(self) -> list[str]:
        return [
            dir.name
            for dir in self.agents_dir.iterdir()
            if dir.is_dir() and AgentFileManager(dir).state_file_path.exists()
        ]
        """Return all agent directories within storage."""
        agent_dirs: list[str] = []
        for dir in self.file_manager.list_folders():
            if self.file_manager.exists(dir / "state.json"):
                agent_dirs.append(dir.name)
        return agent_dirs

    def get_agent_dir(self, agent_id: str, must_exist: bool = False) -> Path:
    def get_agent_dir(self, agent_id: str) -> Path:
        """Return the directory of the agent with the given ID."""
        assert len(agent_id) > 0
        agent_dir = self.agents_dir / agent_id
        if must_exist and not agent_dir.exists():
        agent_dir: Path | None = None
        if self.file_manager.exists(agent_id):
            agent_dir = self.file_manager.root / agent_id
        else:
            raise FileNotFoundError(f"No agent with ID '{agent_id}'")
        return agent_dir

    def retrieve_state(self, agent_id: str) -> AgentSettings:
        from autogpt.agents.agent import AgentSettings

        agent_dir = self.get_agent_dir(agent_id, True)
        state_file = AgentFileManager(agent_dir).state_file_path
        if not state_file.exists():
    def load_agent_state(self, agent_id: str) -> AgentSettings:
        """Load the state of the agent with the given ID."""
        state_file_path = Path(agent_id) / "state.json"
        if not self.file_manager.exists(state_file_path):
            raise FileNotFoundError(f"Agent with ID '{agent_id}' has no state.json")

        state = AgentSettings.load_from_json_file(state_file)
        state.agent_data_dir = agent_dir
        return state
        state = self.file_manager.read_file(state_file_path)
        return AgentSettings.parse_raw(state)
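A hedged usage sketch of the reworked, storage-backed AgentManager shown above (the agent ID value is illustrative):

```python
# Sketch: using the storage-backed AgentManager, per the diff above.
# "AutoGPT-1a2b3c4d" is a hypothetical ID of the form generate_id() produces.
from autogpt.agent_manager import AgentManager

agent_manager = AgentManager(file_storage)  # file_storage from get_storage()

new_id = AgentManager.generate_id("AutoGPT")  # e.g. "AutoGPT-1a2b3c4d"
existing = agent_manager.list_agents()        # folders containing a state.json
if existing:
    state = agent_manager.load_agent_state(existing[0])  # -> AgentSettings
```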
@@ -6,10 +6,6 @@ import time
from datetime import datetime
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    from autogpt.config import Config
    from autogpt.models.command_registry import CommandRegistry

import sentry_sdk
from pydantic import Field

@@ -20,6 +16,7 @@ from autogpt.core.resource.model_providers import (
    ChatMessage,
    ChatModelProvider,
)
from autogpt.file_storage.base import FileStorage
from autogpt.llm.api_manager import ApiManager
from autogpt.logs.log_cycle import (
    CURRENT_CONTEXT_FILE_NAME,
@@ -39,8 +36,8 @@ from autogpt.models.command import CommandOutput
from autogpt.models.context_item import ContextItem

from .base import BaseAgent, BaseAgentConfiguration, BaseAgentSettings
from .features.agent_file_manager import AgentFileManagerMixin
from .features.context import ContextMixin
from .features.file_workspace import FileWorkspaceMixin
from .features.watchdog import WatchdogMixin
from .prompt_strategies.one_shot import (
    OneShotAgentPromptConfiguration,
@@ -54,6 +51,10 @@ from .utils.exceptions import (
    UnknownCommandError,
)

if TYPE_CHECKING:
    from autogpt.config import Config
    from autogpt.models.command_registry import CommandRegistry

logger = logging.getLogger(__name__)


@@ -72,7 +73,7 @@ class AgentSettings(BaseAgentSettings):

class Agent(
    ContextMixin,
    FileWorkspaceMixin,
    AgentFileManagerMixin,
    WatchdogMixin,
    BaseAgent,
    Configurable[AgentSettings],
@@ -91,6 +92,7 @@ class Agent(
        settings: AgentSettings,
        llm_provider: ChatModelProvider,
        command_registry: CommandRegistry,
        file_storage: FileStorage,
        legacy_config: Config,
    ):
        prompt_strategy = OneShotAgentPromptStrategy(
@@ -102,6 +104,7 @@ class Agent(
            llm_provider=llm_provider,
            prompt_strategy=prompt_strategy,
            command_registry=command_registry,
            file_storage=file_storage,
            legacy_config=legacy_config,
        )
@@ -2,7 +2,6 @@ from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional

from auto_gpt_plugin_template import AutoGPTPluginTemplate
@@ -39,12 +38,11 @@ from autogpt.core.resource.model_providers.openai import (
    OpenAIModelName,
)
from autogpt.core.runner.client_lib.logging.helpers import dump_prompt
from autogpt.file_storage.base import FileStorage
from autogpt.llm.providers.openai import get_openai_command_specs
from autogpt.models.action_history import ActionResult, EpisodicActionHistory
from autogpt.prompts.prompt import DEFAULT_TRIGGERING_PROMPT

from .utils.agent_file_manager import AgentFileManager

logger = logging.getLogger(__name__)

CommandName = str
@@ -126,7 +124,6 @@ class BaseAgentConfiguration(SystemConfiguration):

class BaseAgentSettings(SystemSettings):
    agent_id: str = ""
    agent_data_dir: Optional[Path] = None

    ai_profile: AIProfile = Field(default_factory=lambda: AIProfile(ai_name="AutoGPT"))
    """The AI profile or "personality" of the agent."""
@@ -147,14 +144,6 @@ class BaseAgentSettings(SystemSettings):
    history: EpisodicActionHistory = Field(default_factory=EpisodicActionHistory)
    """(STATE) The action history of the agent."""

    def save_to_json_file(self, file_path: Path) -> None:
        with file_path.open("w") as f:
            f.write(self.json())

    @classmethod
    def load_from_json_file(cls, file_path: Path):
        return cls.parse_file(file_path)


class BaseAgent(Configurable[BaseAgentSettings], ABC):
    """Base class for all AutoGPT agent classes."""
@@ -172,6 +161,7 @@ class BaseAgent(Configurable[BaseAgentSettings], ABC):
        llm_provider: ChatModelProvider,
        prompt_strategy: PromptStrategy,
        command_registry: CommandRegistry,
        file_storage: FileStorage,
        legacy_config: Config,
    ):
        self.state = settings
@@ -183,12 +173,6 @@ class BaseAgent(Configurable[BaseAgentSettings], ABC):
        self.legacy_config = legacy_config
        """LEGACY: Monolithic application configuration."""

        self.file_manager: AgentFileManager = (
            AgentFileManager(settings.agent_data_dir)
            if settings.agent_data_dir
            else None
        )  # type: ignore

        self.llm_provider = llm_provider

        self.prompt_strategy = prompt_strategy
@@ -203,21 +187,6 @@ class BaseAgent(Configurable[BaseAgentSettings], ABC):

        logger.debug(f"Created {__class__} '{self.ai_profile.ai_name}'")

    def set_id(self, new_id: str, new_agent_dir: Optional[Path] = None):
        self.state.agent_id = new_id
        if self.state.agent_data_dir:
            if not new_agent_dir:
                raise ValueError(
                    "new_agent_dir must be specified if one is currently configured"
                )
            self.attach_fs(new_agent_dir)

    def attach_fs(self, agent_dir: Path) -> AgentFileManager:
        self.file_manager = AgentFileManager(agent_dir)
        self.file_manager.initialize()
        self.state.agent_data_dir = agent_dir
        return self.file_manager

    @property
    def llm(self) -> ChatModelInfo:
        """The LLM that the agent uses to think."""
@@ -236,10 +205,6 @@ class BaseAgent(Configurable[BaseAgentSettings], ABC):
        Returns:
            The command name and arguments, if any, and the agent's thoughts.
        """
        assert self.file_manager, (
            f"Agent has no FileManager: call {__class__.__name__}.attach_fs()"
            " before trying to run the agent."
        )

        # Scratchpad as surrogate PromptGenerator for plugin hooks
        self._prompt_scratchpad = PromptScratchpad()
@@ -0,0 +1,86 @@
from __future__ import annotations

import logging

from autogpt.file_storage.base import FileStorage

from ..base import BaseAgent, BaseAgentSettings

logger = logging.getLogger(__name__)


class AgentFileManagerMixin:
    """Mixin that adds file manager (e.g. Agent state)
    and workspace manager (e.g. Agent output files) support."""

    files: FileStorage = None
    """Agent-related files, e.g. state, logs.
    Use `workspace` to access the agent's workspace files."""

    workspace: FileStorage = None
    """Workspace that the agent has access to, e.g. for reading/writing files.
    Use `files` to access agent-related files, e.g. state, logs."""

    STATE_FILE = "state.json"
    """The name of the file where the agent's state is stored."""

    LOGS_FILE = "file_logger.log"
    """The name of the file where the agent's logs are stored."""

    def __init__(self, **kwargs):
        # Initialize other bases first, because we need the config from BaseAgent
        super(AgentFileManagerMixin, self).__init__(**kwargs)

        if not isinstance(self, BaseAgent):
            raise NotImplementedError(
                f"{__class__.__name__} can only be applied to BaseAgent derivatives"
            )

        if "file_storage" not in kwargs:
            raise ValueError(
                "AgentFileManagerMixin requires a file_storage in the constructor."
            )

        state: BaseAgentSettings = getattr(self, "state")
        if not state.agent_id:
            raise ValueError("Agent must have an ID.")

        file_storage: FileStorage = kwargs["file_storage"]
        self.files = file_storage.clone_with_subroot(f"agents/{state.agent_id}/")
        self.workspace = file_storage.clone_with_subroot(
            f"agents/{state.agent_id}/workspace"
        )
        self._file_storage = file_storage
        # Read and cache logs
        self._file_logs_cache = []
        if self.files.exists(self.LOGS_FILE):
            self._file_logs_cache = self.files.read_file(self.LOGS_FILE).split("\n")

    async def log_file_operation(self, content: str) -> None:
        """Log a file operation to the agent's log file."""
        logger.debug(f"Logging operation: {content}")
        self._file_logs_cache.append(content)
        await self.files.write_file(
            self.LOGS_FILE, "\n".join(self._file_logs_cache) + "\n"
        )

    def get_file_operation_lines(self) -> list[str]:
        """Get the agent's file operation logs as list of strings."""
        return self._file_logs_cache

    async def save_state(self) -> None:
        """Save the agent's state to the state file."""
        state: BaseAgentSettings = getattr(self, "state")
        await self.files.write_file(self.files.root / self.STATE_FILE, state.json())

    def change_agent_id(self, new_id: str):
        """Change the agent's ID and update the file storage accordingly."""
        state: BaseAgentSettings = getattr(self, "state")
        # Rename the agent's files and workspace
        self._file_storage.rename(f"agents/{state.agent_id}", f"agents/{new_id}")
        # Update the file storage objects
        self.files = self._file_storage.clone_with_subroot(f"agents/{new_id}/")
        self.workspace = self._file_storage.clone_with_subroot(
            f"agents/{new_id}/workspace"
        )
        state.agent_id = new_id
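In effect, the mixin pins all agent data to a fixed layout inside the shared storage, replacing the old attach_fs/agent_data_dir plumbing. A hedged sketch of the resulting layout and lifecycle (the agent IDs are hypothetical):

```python
# Layout the mixin establishes for an agent with ID "my-agent" (hypothetical):
#   <storage root>/agents/my-agent/state.json       <- agent.files (STATE_FILE)
#   <storage root>/agents/my-agent/file_logger.log  <- agent.files (LOGS_FILE)
#   <storage root>/agents/my-agent/workspace/...    <- agent.workspace

async def persist_and_rename(agent):
    await agent.save_state()           # writes state.json via agent.files
    agent.change_agent_id("new-name")  # renames agents/my-agent -> agents/new-name
```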
@@ -1,65 +0,0 @@
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from pathlib import Path

    from ..base import BaseAgent, Config

from autogpt.file_workspace import (
    FileWorkspace,
    FileWorkspaceBackendName,
    get_workspace,
)

from ..base import AgentFileManager, BaseAgentSettings


class FileWorkspaceMixin:
    """Mixin that adds workspace support to a class"""

    workspace: FileWorkspace = None
    """Workspace that the agent has access to, e.g. for reading/writing files."""

    def __init__(self, **kwargs):
        # Initialize other bases first, because we need the config from BaseAgent
        super(FileWorkspaceMixin, self).__init__(**kwargs)

        file_manager: AgentFileManager = getattr(self, "file_manager")
        if not file_manager:
            return

        self._setup_workspace()

    def attach_fs(self, agent_dir: Path):
        res = super(FileWorkspaceMixin, self).attach_fs(agent_dir)

        self._setup_workspace()

        return res

    def _setup_workspace(self) -> None:
        settings: BaseAgentSettings = getattr(self, "state")
        assert settings.agent_id, "Cannot attach workspace to anonymous agent"
        app_config: Config = getattr(self, "legacy_config")
        file_manager: AgentFileManager = getattr(self, "file_manager")

        ws_backend = app_config.workspace_backend
        local = ws_backend == FileWorkspaceBackendName.LOCAL
        workspace = get_workspace(
            backend=ws_backend,
            id=settings.agent_id if not local else "",
            root_path=file_manager.root / "workspace" if local else None,
        )
        if local and settings.config.allow_fs_access:
            workspace._restrict_to_root = False  # type: ignore
        workspace.initialize()
        self.workspace = workspace


def get_agent_workspace(agent: BaseAgent) -> FileWorkspace | None:
    if isinstance(agent, FileWorkspaceMixin):
        return agent.workspace

    return None
@@ -1,37 +0,0 @@
from __future__ import annotations

import logging
from pathlib import Path

logger = logging.getLogger(__name__)


class AgentFileManager:
    """A class that represents a workspace for an AutoGPT agent."""

    def __init__(self, agent_data_dir: Path):
        self._root = agent_data_dir.resolve()

    @property
    def root(self) -> Path:
        """The root directory of the workspace."""
        return self._root

    def initialize(self) -> None:
        self.root.mkdir(exist_ok=True, parents=True)
        self.init_file_ops_log(self.file_ops_log_path)

    @property
    def state_file_path(self) -> Path:
        return self.root / "state.json"

    @property
    def file_ops_log_path(self) -> Path:
        return self.root / "file_logger.log"

    @staticmethod
    def init_file_ops_log(file_logger_path: Path) -> Path:
        if not file_logger_path.exists():
            with file_logger_path.open(mode="w", encoding="utf-8") as f:
                f.write("")
        return file_logger_path
@@ -36,11 +36,7 @@ from autogpt.config import Config
from autogpt.core.resource.model_providers import ChatModelProvider
from autogpt.core.resource.model_providers.openai import OpenAIProvider
from autogpt.core.resource.model_providers.schema import ModelProviderBudget
from autogpt.file_workspace import (
    FileWorkspace,
    FileWorkspaceBackendName,
    get_workspace,
)
from autogpt.file_storage import FileStorage
from autogpt.logs.utils import fmt_kwargs
from autogpt.models.action_history import ActionErrorResult, ActionSuccessResult

@@ -54,12 +50,14 @@ class AgentProtocolServer:
        self,
        app_config: Config,
        database: AgentDB,
        file_storage: FileStorage,
        llm_provider: ChatModelProvider,
    ):
        self.app_config = app_config
        self.db = database
        self.file_storage = file_storage
        self.llm_provider = llm_provider
        self.agent_manager = AgentManager(app_data_dir=app_config.app_data_dir)
        self.agent_manager = AgentManager(file_storage)
        self._task_budgets = {}

    async def start(self, port: int = 8000, router: APIRouter = base_router):
@@ -134,16 +132,13 @@ class AgentProtocolServer:
        )
        logger.debug(f"Creating agent for task: '{task.input}'")
        task_agent = await generate_agent_for_task(
            agent_id=task_agent_id(task.task_id),
            task=task.input,
            app_config=self.app_config,
            file_storage=self.file_storage,
            llm_provider=self._get_task_llm_provider(task),
        )

        # Assign an ID and a folder to the Agent and persist it
        agent_id = task_agent.state.agent_id = task_agent_id(task.task_id)
        logger.debug(f"New agent ID: {agent_id}")
        task_agent.attach_fs(self.app_config.app_data_dir / "agents" / agent_id)
        task_agent.state.save_to_json_file(task_agent.file_manager.state_file_path)
        await task_agent.save_state()

        return task

@@ -182,8 +177,9 @@ class AgentProtocolServer:
        # Restore Agent instance
        task = await self.get_task(task_id)
        agent = configure_agent_with_state(
            state=self.agent_manager.retrieve_state(task_agent_id(task_id)),
            state=self.agent_manager.load_agent_state(task_agent_id(task_id)),
            app_config=self.app_config,
            file_storage=self.file_storage,
            llm_provider=self._get_task_llm_provider(task),
        )

@@ -346,7 +342,7 @@ class AgentProtocolServer:
            additional_output=additional_output,
        )

        agent.state.save_to_json_file(agent.file_manager.state_file_path)
        await agent.save_state()
        return step

    async def _on_agent_write_file(
@@ -405,7 +401,7 @@ class AgentProtocolServer:
        else:
            file_path = os.path.join(relative_path, file_name)

        workspace = self._get_task_agent_file_workspace(task_id, self.agent_manager)
        workspace = self._get_task_agent_file_workspace(task_id)
        await workspace.write_file(file_path, data)

        artifact = await self.db.create_artifact(
@@ -421,12 +417,12 @@ class AgentProtocolServer:
        Download a task artifact by ID.
        """
        try:
            workspace = self._get_task_agent_file_workspace(task_id)
            artifact = await self.db.get_artifact(artifact_id)
            if artifact.file_name not in artifact.relative_path:
                file_path = os.path.join(artifact.relative_path, artifact.file_name)
            else:
                file_path = artifact.relative_path
            workspace = self._get_task_agent_file_workspace(task_id, self.agent_manager)
            retrieved_artifact = workspace.read_file(file_path, binary=True)
        except NotFoundError:
            raise
@@ -441,28 +437,9 @@ class AgentProtocolServer:
            },
        )

    def _get_task_agent_file_workspace(
        self,
        task_id: str | int,
        agent_manager: AgentManager,
    ) -> FileWorkspace:
        use_local_ws = (
            self.app_config.workspace_backend == FileWorkspaceBackendName.LOCAL
        )
    def _get_task_agent_file_workspace(self, task_id: str | int) -> FileStorage:
        agent_id = task_agent_id(task_id)
        workspace = get_workspace(
            backend=self.app_config.workspace_backend,
            id=agent_id if not use_local_ws else "",
            root_path=agent_manager.get_agent_dir(
                agent_id=agent_id,
                must_exist=True,
            )
            / "workspace"
            if use_local_ws
            else None,
        )
        workspace.initialize()
        return workspace
        return self.file_storage.clone_with_subroot(f"agents/{agent_id}/workspace")

    def _get_task_llm_provider(
        self, task: Task, step_id: str = ""
@@ -1,6 +1,7 @@
"""
The application entry point. Can be invoked by a CLI or any other front end application.
"""

import enum
import logging
import math
@@ -15,6 +16,8 @@ from typing import TYPE_CHECKING, Optional
from colorama import Fore, Style
from forge.sdk.db import AgentDB

from autogpt.file_storage import FileStorageBackendName, get_storage

if TYPE_CHECKING:
    from autogpt.agents.agent import Agent

@@ -76,7 +79,15 @@ async def run_auto_gpt(
    best_practices: Optional[list[str]] = None,
    override_directives: bool = False,
):
    # Set up configuration
    config = ConfigBuilder.build_config_from_env()
    # Storage
    local = config.file_storage_backend == FileStorageBackendName.LOCAL
    restrict_to_root = not local or config.restrict_to_workspace
    file_storage = get_storage(
        config.file_storage_backend, root_path="data", restrict_to_root=restrict_to_root
    )
    file_storage.initialize()

    # TODO: fill in llm values here
    assert_config_has_openai_api_key(config)
@@ -148,7 +159,7 @@ async def run_auto_gpt(
    configure_chat_plugins(config)

    # Let user choose an existing agent to run
    agent_manager = AgentManager(config.app_data_dir)
    agent_manager = AgentManager(file_storage)
    existing_agents = agent_manager.list_agents()
    load_existing_agent = ""
    if existing_agents:
@@ -179,7 +190,7 @@ async def run_auto_gpt(
    # Resume an Existing Agent #
    ############################
    if load_existing_agent:
        agent_state = agent_manager.retrieve_state(load_existing_agent)
        agent_state = agent_manager.load_agent_state(load_existing_agent)
        while True:
            answer = await clean_input(config, "Resume? [Y/n]")
            if answer.lower() == "y":
@@ -194,6 +205,7 @@ async def run_auto_gpt(
        agent = configure_agent_with_state(
            state=agent_state,
            app_config=config,
            file_storage=file_storage,
            llm_provider=llm_provider,
        )
        apply_overrides_to_ai_settings(
@@ -274,13 +286,14 @@ async def run_auto_gpt(
        logger.info("AI config overrides specified through CLI; skipping revision")

    agent = create_agent(
        agent_id=agent_manager.generate_id(ai_profile.ai_name),
        task=task,
        ai_profile=ai_profile,
        directives=ai_directives,
        app_config=config,
        file_storage=file_storage,
        llm_provider=llm_provider,
    )
    agent.attach_fs(agent_manager.get_agent_dir(agent.state.agent_id))

    if not agent.config.allow_fs_access:
        logger.info(
@@ -309,14 +322,10 @@ async def run_auto_gpt(
            or agent_id
        )
        if save_as_id and save_as_id != agent_id:
            agent.set_id(
                new_id=save_as_id,
                new_agent_dir=agent_manager.get_agent_dir(save_as_id),
            )
            # TODO: clone workspace if user wants that
            # TODO: ... OR allow many-to-one relations of agents and workspaces
            agent.change_agent_id(save_as_id)
            # TODO: allow many-to-one relations of agents and workspaces

        agent.state.save_to_json_file(agent.file_manager.state_file_path)
        await agent.save_state()


@coroutine
@@ -336,6 +345,14 @@ async def run_auto_gpt_server(

    config = ConfigBuilder.build_config_from_env()

    # Storage
    local = config.file_storage_backend == FileStorageBackendName.LOCAL
    restrict_to_root = not local or config.restrict_to_workspace
    file_storage = get_storage(
        config.file_storage_backend, root_path="data", restrict_to_root=restrict_to_root
    )
    file_storage.initialize()

    # TODO: fill in llm values here
    assert_config_has_openai_api_key(config)

@@ -372,7 +389,10 @@ async def run_auto_gpt_server(
    )
    port: int = int(os.getenv("AP_SERVER_PORT", default=8000))
    server = AgentProtocolServer(
        app_config=config, database=database, llm_provider=llm_provider
        app_config=config,
        database=database,
        file_storage=file_storage,
        llm_provider=llm_provider,
    )
    await server.start(port=port)
@@ -35,17 +35,12 @@ def text_checksum(text: str) -> str:


def operations_from_log(
    log_path: str | Path,
    logs: list[str],
) -> Iterator[
    tuple[Literal["write", "append"], str, str] | tuple[Literal["delete"], str, None]
]:
    """Parse the file operations log and return a tuple containing the log entries"""
    try:
        log = open(log_path, "r", encoding="utf-8")
    except FileNotFoundError:
        return

    for line in log:
    """Parse logs and return a tuple containing the log entries"""
    for line in logs:
        line = line.replace("File Operation Logger", "").strip()
        if not line:
            continue
@@ -57,14 +52,12 @@ def operations_from_log(
        elif operation == "delete":
            yield (operation, tail.strip(), None)

    log.close()

def file_operations_state(logs: list[str]) -> dict[str, str]:
    """Iterates over the operations and returns the expected state.

def file_operations_state(log_path: str | Path) -> dict[str, str]:
    """Iterates over the operations log and returns the expected state.

    Parses a log file at file_manager.file_ops_log_path to construct a dictionary
    that maps each file path written or appended to its checksum. Deleted files are
    Constructs a dictionary that maps each file path written
    or appended to its checksum. Deleted files are
    removed from the dictionary.

    Returns:
@@ -75,7 +68,7 @@ def file_operations_state(log_path: str | Path) -> dict[str, str]:
        ValueError: If the log file content is not in the expected format.
    """
    state = {}
    for operation, path, checksum in operations_from_log(log_path):
    for operation, path, checksum in operations_from_log(logs):
        if operation in ("write", "append"):
            state[path] = checksum
        elif operation == "delete":
@@ -98,7 +91,7 @@ def is_duplicate_operation(
    Returns:
        True if the operation has already been performed on the file
    """
    state = file_operations_state(agent.file_manager.file_ops_log_path)
    state = file_operations_state(agent.get_file_operation_lines())
    if operation == "delete" and str(file_path) not in state:
        return True
    if operation == "write" and state.get(str(file_path)) == checksum:
@@ -107,7 +100,7 @@


@sanitize_path_arg("file_path", make_relative=True)
def log_operation(
async def log_operation(
    operation: Operation,
    file_path: str | Path,
    agent: Agent,
@@ -124,9 +117,7 @@ def log_operation(
    if checksum is not None:
        log_entry += f" #{checksum}"
    logger.debug(f"Logging file operation: {log_entry}")
    append_to_file(
        agent.file_manager.file_ops_log_path, f"{log_entry}\n", agent, should_log=False
    )
    await agent.log_file_operation(f"{log_entry}")


@command(
@@ -218,33 +209,12 @@ async def write_to_file(filename: str | Path, contents: str, agent: Agent) -> str:
        raise DuplicateOperationError(f"File {filename} has already been updated.")

    if directory := os.path.dirname(filename):
        agent.workspace.get_path(directory).mkdir(exist_ok=True)
        agent.workspace.make_dir(directory)
    await agent.workspace.write_file(filename, contents)
    log_operation("write", filename, agent, checksum)
    await log_operation("write", filename, agent, checksum)
    return f"File {filename} has been written successfully."


def append_to_file(
    filename: Path, text: str, agent: Agent, should_log: bool = True
) -> None:
    """Append text to a file

    Args:
        filename (Path): The name of the file to append to
        text (str): The text to append to the file
        should_log (bool): Should log output
    """
    directory = os.path.dirname(filename)
    os.makedirs(directory, exist_ok=True)
    with open(filename, "a") as f:
        f.write(text)

    if should_log:
        with open(filename, "r") as f:
            checksum = text_checksum(f.read())
        log_operation("append", filename, agent, checksum=checksum)


@command(
    "list_folder",
    "List the items in a folder",
@@ -265,4 +235,4 @@ def list_folder(folder: str | Path, agent: Agent) -> list[str]:
    Returns:
        list[str]: A list of files found in the folder
    """
    return [str(p) for p in agent.workspace.list(folder)]
    return [str(p) for p in agent.workspace.list_files(folder)]
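The log format itself appears unchanged; only its transport moved from an on-disk file handle to the mixin's cached lines. A sketch of what the list-based operations_from_log now consumes; the entry format is inferred from log_operation above and the example paths and checksums are illustrative:

```python
# Illustrative input for the list-based operations_from_log(); entries follow
# the "<operation>: <path> #<checksum>" shape that log_operation appears to emit.
logs = [
    "write: output.txt #9e107d9d",
    "write: scratch.txt #0a1b2c3d",
    "append: output.txt #e4d909c2",
    "delete: scratch.txt",
]
# file_operations_state(logs) would then reduce this to the expected end state:
#   {"output.txt": "e4d909c2"}   (scratch.txt removed by the delete entry)
```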
@@ -20,7 +20,7 @@ from autogpt.core.resource.model_providers.openai import (
    OPEN_AI_CHAT_MODELS,
    OpenAICredentials,
)
from autogpt.file_workspace import FileWorkspaceBackendName
from autogpt.file_storage import FileStorageBackendName
from autogpt.logs.config import LoggingConfig
from autogpt.plugins.plugins_config import PluginsConfig
from autogpt.speech import TTSConfig
@@ -57,11 +57,11 @@ class Config(SystemSettings, arbitrary_types_allowed=True):
    tts_config: TTSConfig = TTSConfig()
    logging: LoggingConfig = LoggingConfig()

    # Workspace
    workspace_backend: FileWorkspaceBackendName = UserConfigurable(
        default=FileWorkspaceBackendName.LOCAL,
        from_env=lambda: FileWorkspaceBackendName(v)
        if (v := os.getenv("WORKSPACE_BACKEND"))
    # File storage
    file_storage_backend: FileStorageBackendName = UserConfigurable(
        default=FileStorageBackendName.LOCAL,
        from_env=lambda: FileStorageBackendName(v)
        if (v := os.getenv("FILE_STORAGE_BACKEND"))
        else None,
    )
@@ -89,9 +89,11 @@ def bootstrap_agent(task, continuous_mode) -> Agent:
        ai_role="a multi-purpose AI assistant.",
        ai_goals=[task],
    )
    # FIXME this won't work - ai_profile and triggering_prompt are not valid arguments,
    # and it lacks file_storage, settings and llm_provider
    return Agent(
        command_registry=command_registry,
        ai_profile=ai_profile,
        config=config,
        legacy_config=config,
        triggering_prompt=DEFAULT_TRIGGERING_PROMPT,
    )
autogpts/autogpt/autogpt/file_storage/__init__.py (new file, 44 lines)
@@ -0,0 +1,44 @@
import enum
from pathlib import Path

from .base import FileStorage


class FileStorageBackendName(str, enum.Enum):
    LOCAL = "local"
    GCS = "gcs"
    S3 = "s3"


def get_storage(
    backend: FileStorageBackendName,
    root_path: Path = ".",
    restrict_to_root: bool = True,
) -> FileStorage:
    match backend:
        case FileStorageBackendName.LOCAL:
            from .local import FileStorageConfiguration, LocalFileStorage

            config = FileStorageConfiguration.from_env()
            config.root = root_path
            config.restrict_to_root = restrict_to_root
            return LocalFileStorage(config)
        case FileStorageBackendName.S3:
            from .s3 import S3FileStorage, S3FileStorageConfiguration

            config = S3FileStorageConfiguration.from_env()
            config.root = root_path
            return S3FileStorage(config)
        case FileStorageBackendName.GCS:
            from .gcs import GCSFileStorage, GCSFileStorageConfiguration

            config = GCSFileStorageConfiguration.from_env()
            config.root = root_path
            return GCSFileStorage(config)


__all__ = [
    "FileStorage",
    "FileStorageBackendName",
    "get_storage",
]
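A brief sketch of selecting a backend through this factory. The environment variable names come from the renames in this commit; the values shown are examples, and the FILE_STORAGE_BACKEND/STORAGE_BUCKET settings are read from the environment by the *Configuration classes:

```python
# Sketch: env-driven backend selection through get_storage() above.
import os

from autogpt.file_storage import FileStorageBackendName, get_storage

backend = FileStorageBackendName(os.getenv("FILE_STORAGE_BACKEND", "local"))
storage = get_storage(backend, root_path="data")
storage.initialize()  # creates the folder / bucket if needed
```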
autogpts/autogpt/autogpt/file_storage/base.py (new file, 200 lines)
@@ -0,0 +1,200 @@
"""
The FileStorage class provides an interface for interacting with a file storage.
"""

from __future__ import annotations

import logging
import os
from abc import ABC, abstractmethod
from io import IOBase, TextIOBase
from pathlib import Path
from typing import IO, Any, BinaryIO, Callable, Literal, TextIO, overload

from autogpt.core.configuration.schema import SystemConfiguration

logger = logging.getLogger(__name__)


class FileStorageConfiguration(SystemConfiguration):
    restrict_to_root: bool = True
    root: Path = Path("/")


class FileStorage(ABC):
    """A class that represents a file storage."""

    on_write_file: Callable[[Path], Any] | None = None
    """
    Event hook, executed after writing a file.

    Params:
        Path: The path of the file that was written, relative to the storage root.
    """

    @property
    @abstractmethod
    def root(self) -> Path:
        """The root path of the file storage."""

    @property
    @abstractmethod
    def restrict_to_root(self) -> bool:
        """Whether to restrict file access to within the storage's root path."""

    @property
    @abstractmethod
    def is_local(self) -> bool:
        """Whether the storage is local (i.e. on the same machine, not cloud-based)."""

    @abstractmethod
    def initialize(self) -> None:
        """
        Calling `initialize()` should bring the storage to a ready-to-use state.
        For example, it can create the resource in which files will be stored, if it
        doesn't exist yet. E.g. a folder on disk, or an S3 Bucket.
        """

    @overload
    @abstractmethod
    def open_file(
        self,
        path: str | Path,
        mode: Literal["w", "r"] = "r",
        binary: Literal[False] = False,
    ) -> TextIO | TextIOBase:
        """Returns a readable text file-like object representing the file."""

    @overload
    @abstractmethod
    def open_file(
        self,
        path: str | Path,
        mode: Literal["w", "r"] = "r",
        binary: Literal[True] = True,
    ) -> BinaryIO | IOBase:
        """Returns a readable binary file-like object representing the file."""

    @abstractmethod
    def open_file(
        self, path: str | Path, mode: Literal["w", "r"] = "r", binary: bool = False
    ) -> IO | IOBase:
        """Returns a readable file-like object representing the file."""

    @overload
    @abstractmethod
    def read_file(self, path: str | Path, binary: Literal[False] = False) -> str:
        """Read a file in the storage as text."""
        ...

    @overload
    @abstractmethod
    def read_file(self, path: str | Path, binary: Literal[True] = True) -> bytes:
        """Read a file in the storage as binary."""
        ...

    @abstractmethod
    def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
        """Read a file in the storage."""

    @abstractmethod
    async def write_file(self, path: str | Path, content: str | bytes) -> None:
        """Write to a file in the storage."""

    @abstractmethod
    def list_files(self, path: str | Path = ".") -> list[Path]:
        """List all files (recursively) in a directory in the storage."""

    @abstractmethod
    def list_folders(
        self, path: str | Path = ".", recursive: bool = False
    ) -> list[Path]:
        """List all folders in a directory in the storage."""

    @abstractmethod
    def delete_file(self, path: str | Path) -> None:
        """Delete a file in the storage."""

    @abstractmethod
    def delete_dir(self, path: str | Path) -> None:
        """Delete an empty folder in the storage."""

    @abstractmethod
    def exists(self, path: str | Path) -> bool:
        """Check if a file or folder exists in the storage."""

    @abstractmethod
    def rename(self, old_path: str | Path, new_path: str | Path) -> None:
        """Rename a file or folder in the storage."""

    @abstractmethod
    def make_dir(self, path: str | Path) -> None:
        """Create a directory in the storage if it doesn't exist."""

    @abstractmethod
    def clone_with_subroot(self, subroot: str | Path) -> FileStorage:
        """Create a new FileStorage with a subroot of the current storage."""

    def get_path(self, relative_path: str | Path) -> Path:
        """Get the full path for an item in the storage.

        Parameters:
            relative_path: The relative path to resolve in the storage.

        Returns:
            Path: The resolved path relative to the storage.
        """
        return self._sanitize_path(relative_path)

    def _sanitize_path(
        self,
        path: str | Path,
    ) -> Path:
        """Resolve the relative path within the given root if possible.

        Parameters:
            relative_path: The relative path to resolve.

        Returns:
            Path: The resolved path.

        Raises:
            ValueError: If the path is absolute and a root is provided.
            ValueError: If the path is outside the root and the root is restricted.
        """

        # Posix systems disallow null bytes in paths. Windows is agnostic about it.
        # Do an explicit check here for all sorts of null byte representations.
        if "\0" in str(path):
            raise ValueError("Embedded null byte")

        logger.debug(f"Resolving path '{path}' in storage '{self.root}'")

        relative_path = Path(path)

        # Allow absolute paths if they are contained in the storage.
        if (
            relative_path.is_absolute()
            and self.restrict_to_root
            and not relative_path.is_relative_to(self.root)
        ):
            raise ValueError(
                f"Attempted to access absolute path '{relative_path}' "
                f"in storage '{self.root}'"
            )

        full_path = self.root / relative_path
        if self.is_local:
            full_path = full_path.resolve()
        else:
            full_path = Path(os.path.normpath(full_path))

        logger.debug(f"Joined paths as '{full_path}'")

        if self.restrict_to_root and not full_path.is_relative_to(self.root):
            raise ValueError(
                f"Attempted to access path '{full_path}' "
                f"outside of storage '{self.root}'."
            )

        return full_path
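To illustrate the path-restriction logic above, a few expected outcomes for a storage rooted at /data with restrict_to_root=True on a local backend (illustrative paths):

```python
# Expected _sanitize_path() behavior, derived from the implementation above,
# for a LocalFileStorage rooted at /data with restrict_to_root=True:
storage.get_path("notes.txt")        # -> /data/notes.txt
storage.get_path("/data/notes.txt")  # absolute but inside root: allowed
storage.get_path("../etc/passwd")    # ValueError: outside of storage root
storage.get_path("bad\0name")        # ValueError: embedded null byte
```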
196
autogpts/autogpt/autogpt/file_storage/gcs.py
Normal file
196
autogpts/autogpt/autogpt/file_storage/gcs.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""
|
||||
The GCSWorkspace class provides an interface for interacting with a file workspace, and
|
||||
stores the files in a Google Cloud Storage bucket.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
from io import IOBase
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
from google.cloud import storage
|
||||
from google.cloud.exceptions import NotFound
|
||||
|
||||
from autogpt.core.configuration.schema import UserConfigurable
|
||||
|
||||
from .base import FileStorage, FileStorageConfiguration
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GCSFileStorageConfiguration(FileStorageConfiguration):
|
||||
bucket: str = UserConfigurable("autogpt", from_env="STORAGE_BUCKET")
|
||||
|
||||
|
||||
class GCSFileStorage(FileStorage):
|
||||
"""A class that represents a Google Cloud Storage."""
|
||||
|
||||
_bucket: storage.Bucket
|
||||
|
||||
def __init__(self, config: GCSFileStorageConfiguration):
|
||||
self._bucket_name = config.bucket
|
||||
self._root = config.root
|
||||
assert self._root.is_absolute()
|
||||
|
||||
self._gcs = storage.Client()
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def root(self) -> Path:
|
||||
"""The root directory of the file storage."""
|
||||
return self._root
|
||||
|
||||
@property
|
||||
def restrict_to_root(self) -> bool:
|
||||
"""Whether to restrict generated paths to the root."""
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_local(self) -> bool:
|
||||
"""Whether the storage is local (i.e. on the same machine, not cloud-based)."""
|
||||
return False
|
||||
|
||||
def initialize(self) -> None:
|
||||
logger.debug(f"Initializing {repr(self)}...")
|
||||
try:
|
||||
self._bucket = self._gcs.get_bucket(self._bucket_name)
|
||||
except NotFound:
|
||||
logger.info(f"Bucket '{self._bucket_name}' does not exist; creating it...")
|
||||
self._bucket = self._gcs.create_bucket(self._bucket_name)
|
||||
|
||||
def get_path(self, relative_path: str | Path) -> Path:
|
||||
# We set GCS root with "/" at the beginning
|
||||
# but relative_to("/") will remove it
|
||||
# because we don't actually want it in the storage filenames
|
||||
return super().get_path(relative_path).relative_to("/")
|
||||
|
||||
def _get_blob(self, path: str | Path) -> storage.Blob:
|
||||
path = self.get_path(path)
|
||||
return self._bucket.blob(str(path))
|
||||
|
||||
def open_file(
|
||||
self, path: str | Path, mode: Literal["w", "r"] = "r", binary: bool = False
|
||||
) -> IOBase:
|
||||
"""Open a file in the storage."""
|
||||
blob = self._get_blob(path)
|
||||
blob.reload() # pin revision number to prevent version mixing while reading
|
||||
return blob.open(f"{mode}b" if binary else mode)
|
||||
|
||||
def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
|
||||
"""Read a file in the storage."""
|
||||
return self.open_file(path, "r", binary).read()
|
||||
|
||||
async def write_file(self, path: str | Path, content: str | bytes) -> None:
|
||||
"""Write to a file in the storage."""
|
||||
blob = self._get_blob(path)
|
||||
|
||||
blob.upload_from_string(
|
||||
data=content,
|
||||
content_type=(
|
||||
"text/plain"
|
||||
if type(content) is str
|
||||
# TODO: get MIME type from file extension or binary content
|
||||
else "application/octet-stream"
|
||||
),
|
||||
)
|
||||
|
||||
if self.on_write_file:
|
||||
path = Path(path)
|
||||
if path.is_absolute():
|
||||
path = path.relative_to(self.root)
|
||||
res = self.on_write_file(path)
|
||||
if inspect.isawaitable(res):
|
||||
await res
|
||||
|
||||
def list_files(self, path: str | Path = ".") -> list[Path]:
|
||||
"""List all files (recursively) in a directory in the storage."""
|
||||
path = self.get_path(path)
|
||||
return [
|
||||
Path(blob.name).relative_to(path)
|
||||
for blob in self._bucket.list_blobs(
|
||||
prefix=f"{path}/" if path != Path(".") else None
|
||||
)
|
||||
]
|
||||
|
||||
def list_folders(
|
||||
self, path: str | Path = ".", recursive: bool = False
|
||||
) -> list[Path]:
|
||||
"""List 'directories' directly in a given path or recursively in the storage."""
|
||||
path = self.get_path(path)
|
||||
folder_names = set()
|
||||
|
||||
# List objects with the specified prefix and delimiter
|
||||
for blob in self._bucket.list_blobs(prefix=path):
|
||||
# Remove path prefix and the object name (last part)
|
||||
folder = Path(blob.name).relative_to(path).parent
|
||||
if not folder or folder == Path("."):
|
||||
continue
|
||||
# For non-recursive, only add the first level of folders
|
||||
if not recursive:
|
||||
folder_names.add(folder.parts[0])
|
||||
else:
|
||||
# For recursive, need to add all nested folders
|
||||
for i in range(len(folder.parts)):
|
||||
folder_names.add("/".join(folder.parts[: i + 1]))
|
||||
|
||||
return [Path(f) for f in folder_names]
|
||||
|
||||
def delete_file(self, path: str | Path) -> None:
|
||||
"""Delete a file in the storage."""
|
||||
path = self.get_path(path)
|
||||
blob = self._bucket.blob(str(path))
|
||||
blob.delete()
|
||||
|
||||
def delete_dir(self, path: str | Path) -> None:
|
||||
"""Delete an empty folder in the storage."""
|
||||
# Since GCS does not have directories, we don't need to do anything
|
||||
pass
|
||||
|
||||
def exists(self, path: str | Path) -> bool:
|
||||
"""Check if a file or folder exists in GCS storage."""
|
||||
path = self.get_path(path)
|
||||
# Check for exact blob match (file)
|
||||
blob = self._bucket.blob(str(path))
|
||||
if blob.exists():
|
||||
return True
|
||||
# Check for any blobs with prefix (folder)
|
||||
prefix = f"{str(path).rstrip('/')}/"
|
||||
blobs = self._bucket.list_blobs(prefix=prefix, max_results=1)
|
||||
return next(blobs, None) is not None
|
||||
|
||||
def make_dir(self, path: str | Path) -> None:
|
||||
"""Create a directory in the storage if doesn't exist."""
|
||||
# GCS does not have directories, so we don't need to do anything
|
||||
pass

    def rename(self, old_path: str | Path, new_path: str | Path) -> None:
        """Rename a file or folder in the storage."""
        old_path = self.get_path(old_path)
        new_path = self.get_path(new_path)
        blob = self._bucket.blob(str(old_path))
        # If the blob with exact name exists, rename it
        if blob.exists():
            self._bucket.rename_blob(blob, new_name=str(new_path))
            return
        # Otherwise, rename all blobs with the prefix (folder)
        for blob in self._bucket.list_blobs(prefix=f"{old_path}/"):
            new_name = str(blob.name).replace(str(old_path), str(new_path), 1)
            self._bucket.rename_blob(blob, new_name=new_name)

    def clone_with_subroot(self, subroot: str | Path) -> GCSFileStorage:
        """Create a new GCSFileStorage with a subroot of the current storage."""
        file_storage = GCSFileStorage(
            GCSFileStorageConfiguration(
                root=Path("/").joinpath(self.get_path(subroot)),
                bucket=self._bucket_name,
            )
        )
        file_storage._gcs = self._gcs
        file_storage._bucket = self._bucket
        return file_storage

    def __repr__(self) -> str:
        return f"{__class__.__name__}(bucket='{self._bucket_name}', root={self._root})"
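GCS has no real directories, so `list_folders` above derives pseudo-folders from object key prefixes. A minimal, storage-free sketch of that derivation logic (illustrative only, not part of the commit; the sample keys are hypothetical and already relative to the listing path):

    from pathlib import Path

    # Hypothetical object keys, as a bucket listing would return them
    keys = ["a/b/file1.txt", "a/file2.txt", "file3.txt"]

    def derive_folders(keys: list[str], recursive: bool) -> set[str]:
        folder_names: set[str] = set()
        for key in keys:
            folder = Path(key).parent  # drop the object name (last part)
            if folder == Path("."):
                continue  # object sits at the root, contributes no folder
            if not recursive:
                folder_names.add(folder.parts[0])
            else:
                # add every nesting level: "a", "a/b", ...
                for i in range(len(folder.parts)):
                    folder_names.add("/".join(folder.parts[: i + 1]))
        return folder_names

    assert derive_folders(keys, recursive=False) == {"a"}
    assert derive_folders(keys, recursive=True) == {"a", "a/b"}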
125 autogpts/autogpt/autogpt/file_storage/local.py Normal file
@@ -0,0 +1,125 @@
"""
The LocalFileStorage class implements a FileStorage that works with local files.
"""

from __future__ import annotations

import inspect
import logging
from pathlib import Path
from typing import IO, Literal

from .base import FileStorage, FileStorageConfiguration

logger = logging.getLogger(__name__)


class LocalFileStorage(FileStorage):
    """A class that represents a file storage."""

    def __init__(self, config: FileStorageConfiguration):
        self._root = config.root.resolve()
        self._restrict_to_root = config.restrict_to_root
        self.make_dir(self.root)
        super().__init__()

    @property
    def root(self) -> Path:
        """The root directory of the file storage."""
        return self._root

    @property
    def restrict_to_root(self) -> bool:
        """Whether to restrict generated paths to the root."""
        return self._restrict_to_root

    @property
    def is_local(self) -> bool:
        """Whether the storage is local (i.e. on the same machine, not cloud-based)."""
        return True

    def initialize(self) -> None:
        self.root.mkdir(exist_ok=True, parents=True)

    def open_file(
        self, path: str | Path, mode: Literal["w", "r"] = "r", binary: bool = False
    ) -> IO:
        """Open a file in the storage."""
        return self._open_file(path, f"{mode}b" if binary else mode)

    def _open_file(self, path: str | Path, mode: str) -> IO:
        full_path = self.get_path(path)
        return open(full_path, mode)  # type: ignore

    def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
        """Read a file in the storage."""
        with self._open_file(path, "rb" if binary else "r") as file:
            return file.read()

    async def write_file(self, path: str | Path, content: str | bytes) -> None:
        """Write to a file in the storage."""
        with self._open_file(path, "wb" if type(content) is bytes else "w") as file:
            file.write(content)

        if self.on_write_file:
            path = Path(path)
            if path.is_absolute():
                path = path.relative_to(self.root)
            res = self.on_write_file(path)
            if inspect.isawaitable(res):
                await res

    def list_files(self, path: str | Path = ".") -> list[Path]:
        """List all files (recursively) in a directory in the storage."""
        path = self.get_path(path)
        return [file.relative_to(path) for file in path.rglob("*") if file.is_file()]

    def list_folders(
        self, path: str | Path = ".", recursive: bool = False
    ) -> list[Path]:
        """List directories directly in a given path or recursively."""
        path = self.get_path(path)
        if recursive:
            return [
                folder.relative_to(path)
                for folder in path.rglob("*")
                if folder.is_dir()
            ]
        else:
            return [
                folder.relative_to(path) for folder in path.iterdir() if folder.is_dir()
            ]

    def delete_file(self, path: str | Path) -> None:
        """Delete a file in the storage."""
        full_path = self.get_path(path)
        full_path.unlink()

    def delete_dir(self, path: str | Path) -> None:
        """Delete an empty folder in the storage."""
        full_path = self.get_path(path)
        full_path.rmdir()

    def exists(self, path: str | Path) -> bool:
        """Check if a file or folder exists in the storage."""
        return self.get_path(path).exists()

    def make_dir(self, path: str | Path) -> None:
        """Create a directory in the storage if it doesn't exist."""
        full_path = self.get_path(path)
        full_path.mkdir(exist_ok=True, parents=True)

    def rename(self, old_path: str | Path, new_path: str | Path) -> None:
        """Rename a file or folder in the storage."""
        old_path = self.get_path(old_path)
        new_path = self.get_path(new_path)
        old_path.rename(new_path)

    def clone_with_subroot(self, subroot: str | Path) -> FileStorage:
        """Create a new LocalFileStorage with a subroot of the current storage."""
        return LocalFileStorage(
            FileStorageConfiguration(
                root=self.get_path(subroot),
                restrict_to_root=self.restrict_to_root,
            )
        )
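A short usage sketch of `LocalFileStorage` (illustrative, not part of the commit; the root path is a placeholder). It writes a file and fires the `on_write_file` hook, which may be a plain function or a coroutine thanks to the `inspect.isawaitable` check in `write_file` above. Note that `_open_file` does not create parent directories, so `make_dir` is called first:

    import asyncio
    from pathlib import Path

    from autogpt.file_storage.local import FileStorageConfiguration, LocalFileStorage

    async def main() -> None:
        # Placeholder root; __init__ creates it via make_dir
        storage = LocalFileStorage(FileStorageConfiguration(root=Path("/tmp/demo_storage")))
        storage.initialize()

        # The hook may be sync or async; write_file awaits it only if awaitable
        async def on_write(path: Path) -> None:
            print(f"wrote {path}")

        storage.on_write_file = on_write

        storage.make_dir("notes")  # parent dir must exist before writing into it
        await storage.write_file("notes/hello.txt", "hi")
        print(storage.list_files())  # e.g. [PosixPath('notes/hello.txt')]

    asyncio.run(main())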
237 autogpts/autogpt/autogpt/file_storage/s3.py Normal file
@@ -0,0 +1,237 @@
"""
The S3FileStorage class provides an interface for interacting with a file storage, and
stores the files in an S3 bucket.
"""

from __future__ import annotations

import contextlib
import inspect
import logging
import os
from io import IOBase, TextIOWrapper
from pathlib import Path
from typing import TYPE_CHECKING, Literal, Optional

import boto3
import botocore.exceptions
from pydantic import SecretStr

from autogpt.core.configuration.schema import UserConfigurable

from .base import FileStorage, FileStorageConfiguration

if TYPE_CHECKING:
    import mypy_boto3_s3

logger = logging.getLogger(__name__)


class S3FileStorageConfiguration(FileStorageConfiguration):
    bucket: str = UserConfigurable("autogpt", from_env="STORAGE_BUCKET")
    s3_endpoint_url: Optional[SecretStr] = UserConfigurable(
        from_env=lambda: SecretStr(v) if (v := os.getenv("S3_ENDPOINT_URL")) else None
    )


class S3FileStorage(FileStorage):
    """A class that represents an S3 storage."""

    _bucket: mypy_boto3_s3.service_resource.Bucket

    def __init__(self, config: S3FileStorageConfiguration):
        self._bucket_name = config.bucket
        self._root = config.root
        assert self._root.is_absolute()

        # https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html
        self._s3 = boto3.resource(
            "s3",
            endpoint_url=(
                config.s3_endpoint_url.get_secret_value()
                if config.s3_endpoint_url
                else None
            ),
        )

        super().__init__()

    @property
    def root(self) -> Path:
        """The root directory of the file storage."""
        return self._root

    @property
    def restrict_to_root(self):
        """Whether to restrict generated paths to the root."""
        return True

    @property
    def is_local(self) -> bool:
        """Whether the storage is local (i.e. on the same machine, not cloud-based)."""
        return False

    def initialize(self) -> None:
        logger.debug(f"Initializing {repr(self)}...")
        try:
            self._s3.meta.client.head_bucket(Bucket=self._bucket_name)
            self._bucket = self._s3.Bucket(self._bucket_name)
        except botocore.exceptions.ClientError as e:
            if "(404)" not in str(e):
                raise
            logger.info(f"Bucket '{self._bucket_name}' does not exist; creating it...")
            self._bucket = self._s3.create_bucket(Bucket=self._bucket_name)

    def get_path(self, relative_path: str | Path) -> Path:
        # We set S3 root with "/" at the beginning
        # but relative_to("/") will remove it
        # because we don't actually want it in the storage filenames
        return super().get_path(relative_path).relative_to("/")

    def _get_obj(self, path: str | Path) -> mypy_boto3_s3.service_resource.Object:
        """Get an S3 object."""
        path = self.get_path(path)
        obj = self._bucket.Object(str(path))
        with contextlib.suppress(botocore.exceptions.ClientError):
            obj.load()
        return obj

    def open_file(
        self, path: str | Path, mode: Literal["w", "r"] = "r", binary: bool = False
    ) -> IOBase:
        """Open a file in the storage."""
        obj = self._get_obj(path)
        return obj.get()["Body"] if binary else TextIOWrapper(obj.get()["Body"])

    def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
        """Read a file in the storage."""
        return self.open_file(path, binary=binary).read()

    async def write_file(self, path: str | Path, content: str | bytes) -> None:
        """Write to a file in the storage."""
        obj = self._get_obj(path)
        obj.put(Body=content)

        if self.on_write_file:
            path = Path(path)
            if path.is_absolute():
                path = path.relative_to(self.root)
            res = self.on_write_file(path)
            if inspect.isawaitable(res):
                await res

    def list_files(self, path: str | Path = ".") -> list[Path]:
        """List all files (recursively) in a directory in the storage."""
        path = self.get_path(path)
        if path == Path("."):  # root level of bucket
            return [Path(obj.key) for obj in self._bucket.objects.all()]
        else:
            return [
                Path(obj.key).relative_to(path)
                for obj in self._bucket.objects.filter(Prefix=f"{path}/")
            ]

    def list_folders(
        self, path: str | Path = ".", recursive: bool = False
    ) -> list[Path]:
        """List 'directories' directly in a given path or recursively in the storage."""
        path = self.get_path(path)
        folder_names = set()

        # List objects with the specified prefix and delimiter
        for obj_summary in self._bucket.objects.filter(Prefix=str(path)):
            # Remove path prefix and the object name (last part)
            folder = Path(obj_summary.key).relative_to(path).parent
            if not folder or folder == Path("."):
                continue
            # For non-recursive, only add the first level of folders
            if not recursive:
                folder_names.add(folder.parts[0])
            else:
                # For recursive, need to add all nested folders
                for i in range(len(folder.parts)):
                    folder_names.add("/".join(folder.parts[: i + 1]))

        return [Path(f) for f in folder_names]

    def delete_file(self, path: str | Path) -> None:
        """Delete a file in the storage."""
        path = self.get_path(path)
        obj = self._s3.Object(self._bucket_name, str(path))
        obj.delete()

    def delete_dir(self, path: str | Path) -> None:
        """Delete an empty folder in the storage."""
        # S3 does not have directories, so we don't need to do anything
        pass

    def exists(self, path: str | Path) -> bool:
        """Check if a file or folder exists in S3 storage."""
        path = self.get_path(path)
        try:
            # Check for exact object match (file)
            self._s3.meta.client.head_object(Bucket=self._bucket_name, Key=str(path))
            return True
        except botocore.exceptions.ClientError as e:
            if int(e.response["ResponseMetadata"]["HTTPStatusCode"]) == 404:
                # If the object does not exist,
                # check for objects with the prefix (folder)
                prefix = f"{str(path).rstrip('/')}/"
                objs = list(self._bucket.objects.filter(Prefix=prefix, MaxKeys=1))
                return len(objs) > 0  # True if any objects exist with the prefix
            else:
                raise  # Re-raise for any other client errors

    def make_dir(self, path: str | Path) -> None:
        """Create a directory in the storage if it doesn't exist."""
        # S3 does not have directories, so we don't need to do anything
        pass

    def rename(self, old_path: str | Path, new_path: str | Path) -> None:
        """Rename a file or folder in the storage."""
        old_path = str(self.get_path(old_path))
        new_path = str(self.get_path(new_path))

        try:
            # If file exists, rename it
            self._s3.meta.client.head_object(Bucket=self._bucket_name, Key=old_path)
            self._s3.meta.client.copy_object(
                CopySource={"Bucket": self._bucket_name, "Key": old_path},
                Bucket=self._bucket_name,
                Key=new_path,
            )
            self._s3.meta.client.delete_object(Bucket=self._bucket_name, Key=old_path)
        except botocore.exceptions.ClientError as e:
            if int(e.response["ResponseMetadata"]["HTTPStatusCode"]) == 404:
                # If the object does not exist,
                # it may be a folder
                prefix = f"{old_path.rstrip('/')}/"
                objs = list(self._bucket.objects.filter(Prefix=prefix))
                for obj in objs:
                    new_key = new_path + obj.key[len(old_path) :]
                    self._s3.meta.client.copy_object(
                        CopySource={"Bucket": self._bucket_name, "Key": obj.key},
                        Bucket=self._bucket_name,
                        Key=new_key,
                    )
                    self._s3.meta.client.delete_object(
                        Bucket=self._bucket_name, Key=obj.key
                    )
            else:
                raise  # Re-raise for any other client errors

    def clone_with_subroot(self, subroot: str | Path) -> S3FileStorage:
        """Create a new S3FileStorage with a subroot of the current storage."""
        file_storage = S3FileStorage(
            S3FileStorageConfiguration(
                bucket=self._bucket_name,
                root=Path("/").joinpath(self.get_path(subroot)),
                s3_endpoint_url=self._s3.meta.client.meta.endpoint_url,
            )
        )
        file_storage._s3 = self._s3
        file_storage._bucket = self._bucket
        return file_storage

    def __repr__(self) -> str:
        return f"{__class__.__name__}(bucket='{self._bucket_name}', root={self._root})"
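Both `exists` and `rename` above lean on the same boto3 idiom: issue a `head_object` call and branch on the HTTP 404 carried by the raised `ClientError`, since S3 has no cheap "does this key exist" primitive. A condensed sketch of that pattern in isolation (bucket and key names are placeholders, not from this commit):

    import boto3
    import botocore.exceptions

    def s3_object_exists(bucket_name: str, key: str) -> bool:
        """Return True if an exact object with this key exists in the bucket."""
        client = boto3.client("s3")
        try:
            client.head_object(Bucket=bucket_name, Key=key)
            return True
        except botocore.exceptions.ClientError as e:
            if int(e.response["ResponseMetadata"]["HTTPStatusCode"]) == 404:
                return False  # no such object; caller may still probe key prefixes
            raise  # any other error (403, throttling, ...) should propagate

    # Usage (placeholder names):
    # s3_object_exists("autogpt", "workspaces/task-1/file.txt")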
@@ -1,46 +0,0 @@
import enum
from pathlib import Path
from typing import Optional

from .base import FileWorkspace


class FileWorkspaceBackendName(str, enum.Enum):
    LOCAL = "local"
    GCS = "gcs"
    S3 = "s3"


def get_workspace(
    backend: FileWorkspaceBackendName, *, id: str = "", root_path: Optional[Path] = None
) -> FileWorkspace:
    assert bool(root_path) != bool(id), "Specify root_path or id to get workspace"
    if root_path is None:
        root_path = Path(f"/workspaces/{id}")

    match backend:
        case FileWorkspaceBackendName.LOCAL:
            from .local import FileWorkspaceConfiguration, LocalFileWorkspace

            config = FileWorkspaceConfiguration.from_env()
            config.root = root_path
            return LocalFileWorkspace(config)
        case FileWorkspaceBackendName.S3:
            from .s3 import S3FileWorkspace, S3FileWorkspaceConfiguration

            config = S3FileWorkspaceConfiguration.from_env()
            config.root = root_path
            return S3FileWorkspace(config)
        case FileWorkspaceBackendName.GCS:
            from .gcs import GCSFileWorkspace, GCSFileWorkspaceConfiguration

            config = GCSFileWorkspaceConfiguration.from_env()
            config.root = root_path
            return GCSFileWorkspace(config)


__all__ = [
    "FileWorkspace",
    "FileWorkspaceBackendName",
    "get_workspace",
]
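The deleted factory above is replaced by `get_storage` and `FileStorageBackendName` in `autogpt.file_storage`, which are imported under those names elsewhere in this commit. A hedged usage sketch; the keyword arguments of `get_storage` are assumed here to mirror the old `get_workspace` signature and may differ in the actual module:

    from pathlib import Path

    from autogpt.file_storage import FileStorageBackendName, get_storage

    # Assumed to mirror the old get_workspace(backend, *, id=..., root_path=...) call
    storage = get_storage(
        FileStorageBackendName.LOCAL,
        root_path=Path("./data/agents"),
    )
    storage.initialize()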
@@ -1,164 +0,0 @@
"""
The FileWorkspace class provides an interface for interacting with a file workspace.
"""
from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from io import IOBase, TextIOBase
from pathlib import Path
from typing import IO, Any, BinaryIO, Callable, Literal, Optional, TextIO, overload

from autogpt.core.configuration.schema import SystemConfiguration

logger = logging.getLogger(__name__)


class FileWorkspaceConfiguration(SystemConfiguration):
    restrict_to_root: bool = True
    root: Path = Path("/")


class FileWorkspace(ABC):
    """A class that represents a file workspace."""

    on_write_file: Callable[[Path], Any] | None = None
    """
    Event hook, executed after writing a file.

    Params:
        Path: The path of the file that was written, relative to the workspace root.
    """

    @property
    @abstractmethod
    def root(self) -> Path:
        """The root path of the file workspace."""

    @property
    @abstractmethod
    def restrict_to_root(self) -> bool:
        """Whether to restrict file access to within the workspace's root path."""

    @abstractmethod
    def initialize(self) -> None:
        """
        Calling `initialize()` should bring the workspace to a ready-to-use state.
        For example, it can create the resource in which files will be stored, if it
        doesn't exist yet. E.g. a folder on disk, or an S3 Bucket.
        """

    @overload
    @abstractmethod
    def open_file(
        self, path: str | Path, binary: Literal[False] = False
    ) -> TextIO | TextIOBase:
        """Returns a readable text file-like object representing the file."""

    @overload
    @abstractmethod
    def open_file(
        self, path: str | Path, binary: Literal[True] = True
    ) -> BinaryIO | IOBase:
        """Returns a readable binary file-like object representing the file."""

    @abstractmethod
    def open_file(self, path: str | Path, binary: bool = False) -> IO | IOBase:
        """Returns a readable file-like object representing the file."""

    @overload
    @abstractmethod
    def read_file(self, path: str | Path, binary: Literal[False] = False) -> str:
        """Read a file in the workspace as text."""
        ...

    @overload
    @abstractmethod
    def read_file(self, path: str | Path, binary: Literal[True] = True) -> bytes:
        """Read a file in the workspace as binary."""
        ...

    @abstractmethod
    def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
        """Read a file in the workspace."""

    @abstractmethod
    async def write_file(self, path: str | Path, content: str | bytes) -> None:
        """Write to a file in the workspace."""

    @abstractmethod
    def list(self, path: str | Path = ".") -> list[Path]:
        """List all files (recursively) in a directory in the workspace."""

    @abstractmethod
    def delete_file(self, path: str | Path) -> None:
        """Delete a file in the workspace."""

    def get_path(self, relative_path: str | Path) -> Path:
        """Get the full path for an item in the workspace.

        Parameters:
            relative_path: The relative path to resolve in the workspace.

        Returns:
            Path: The resolved path relative to the workspace.
        """
        return self._sanitize_path(relative_path, self.root)

    @staticmethod
    def _sanitize_path(
        relative_path: str | Path,
        root: Optional[str | Path] = None,
        restrict_to_root: bool = True,
    ) -> Path:
        """Resolve the relative path within the given root if possible.

        Parameters:
            relative_path: The relative path to resolve.
            root: The root path to resolve the relative path within.
            restrict_to_root: Whether to restrict the path to the root.

        Returns:
            Path: The resolved path.

        Raises:
            ValueError: If the path is absolute and a root is provided.
            ValueError: If the path is outside the root and the root is restricted.
        """

        # Posix systems disallow null bytes in paths. Windows is agnostic about it.
        # Do an explicit check here for all sorts of null byte representations.
        if "\0" in str(relative_path) or "\0" in str(root):
            raise ValueError("embedded null byte")

        if root is None:
            return Path(relative_path).resolve()

        logger.debug(f"Resolving path '{relative_path}' in workspace '{root}'")

        root, relative_path = Path(root).resolve(), Path(relative_path)

        logger.debug(f"Resolved root as '{root}'")

        # Allow absolute paths if they are contained in the workspace.
        if (
            relative_path.is_absolute()
            and restrict_to_root
            and not relative_path.is_relative_to(root)
        ):
            raise ValueError(
                f"Attempted to access absolute path '{relative_path}' "
                f"in workspace '{root}'."
            )

        full_path = root.joinpath(relative_path).resolve()

        logger.debug(f"Joined paths as '{full_path}'")

        if restrict_to_root and not full_path.is_relative_to(root):
            raise ValueError(
                f"Attempted to access path '{full_path}' outside of workspace '{root}'."
            )

        return full_path
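`_sanitize_path` is the security core of this class: with `restrict_to_root=True` it rejects absolute paths outside the root as well as relative escapes such as `../`. A small standalone sketch mirroring those checks in plain `pathlib` (illustrative, not part of the commit):

    from pathlib import Path

    def resolve_in_root(relative_path: str, root: str) -> Path:
        """Resolve relative_path inside root, refusing paths that escape it."""
        if "\0" in relative_path or "\0" in root:
            raise ValueError("embedded null byte")
        root_p = Path(root).resolve()
        full_path = root_p.joinpath(relative_path).resolve()
        if not full_path.is_relative_to(root_p):  # Python 3.9+
            raise ValueError(f"'{full_path}' escapes workspace '{root_p}'")
        return full_path

    print(resolve_in_root("notes/a.txt", "/tmp/ws"))    # /tmp/ws/notes/a.txt
    try:
        resolve_in_root("../../etc/passwd", "/tmp/ws")  # raises ValueError
    except ValueError as e:
        print(e)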
@@ -1,113 +0,0 @@
"""
The GCSWorkspace class provides an interface for interacting with a file workspace, and
stores the files in a Google Cloud Storage bucket.
"""
from __future__ import annotations

import inspect
import logging
from io import IOBase
from pathlib import Path

from google.cloud import storage
from google.cloud.exceptions import NotFound

from autogpt.core.configuration.schema import UserConfigurable

from .base import FileWorkspace, FileWorkspaceConfiguration

logger = logging.getLogger(__name__)


class GCSFileWorkspaceConfiguration(FileWorkspaceConfiguration):
    bucket: str = UserConfigurable("autogpt", from_env="WORKSPACE_STORAGE_BUCKET")


class GCSFileWorkspace(FileWorkspace):
    """A class that represents a Google Cloud Storage workspace."""

    _bucket: storage.Bucket

    def __init__(self, config: GCSFileWorkspaceConfiguration):
        self._bucket_name = config.bucket
        self._root = config.root
        assert self._root.is_absolute()

        self._gcs = storage.Client()
        super().__init__()

    @property
    def root(self) -> Path:
        """The root directory of the file workspace."""
        return self._root

    @property
    def restrict_to_root(self) -> bool:
        """Whether to restrict generated paths to the root."""
        return True

    def initialize(self) -> None:
        logger.debug(f"Initializing {repr(self)}...")
        try:
            self._bucket = self._gcs.get_bucket(self._bucket_name)
        except NotFound:
            logger.info(f"Bucket '{self._bucket_name}' does not exist; creating it...")
            self._bucket = self._gcs.create_bucket(self._bucket_name)

    def get_path(self, relative_path: str | Path) -> Path:
        return super().get_path(relative_path).relative_to("/")

    def _get_blob(self, path: str | Path) -> storage.Blob:
        path = self.get_path(path)
        return self._bucket.blob(str(path))

    def open_file(self, path: str | Path, binary: bool = False) -> IOBase:
        """Open a file in the workspace."""
        blob = self._get_blob(path)
        blob.reload()  # pin revision number to prevent version mixing while reading
        return blob.open("rb" if binary else "r")

    def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
        """Read a file in the workspace."""
        return self.open_file(path, binary).read()

    async def write_file(self, path: str | Path, content: str | bytes) -> None:
        """Write to a file in the workspace."""
        blob = self._get_blob(path)

        blob.upload_from_string(
            data=content,
            content_type=(
                "text/plain"
                if type(content) is str
                # TODO: get MIME type from file extension or binary content
                else "application/octet-stream"
            ),
        )

        if self.on_write_file:
            path = Path(path)
            if path.is_absolute():
                path = path.relative_to(self.root)
            res = self.on_write_file(path)
            if inspect.isawaitable(res):
                await res

    def list(self, path: str | Path = ".") -> list[Path]:
        """List all files (recursively) in a directory in the workspace."""
        path = self.get_path(path)
        return [
            Path(blob.name).relative_to(path)
            for blob in self._bucket.list_blobs(
                prefix=f"{path}/" if path != Path(".") else None
            )
        ]

    def delete_file(self, path: str | Path) -> None:
        """Delete a file in the workspace."""
        path = self.get_path(path)
        blob = self._bucket.blob(str(path))
        blob.delete()

    def __repr__(self) -> str:
        return f"{__class__.__name__}(bucket='{self._bucket_name}', root={self._root})"
@@ -1,71 +0,0 @@
"""
The LocalFileWorkspace class implements a FileWorkspace that works with local files.
"""
from __future__ import annotations

import inspect
import logging
from pathlib import Path
from typing import IO

from .base import FileWorkspace, FileWorkspaceConfiguration

logger = logging.getLogger(__name__)


class LocalFileWorkspace(FileWorkspace):
    """A class that represents a file workspace."""

    def __init__(self, config: FileWorkspaceConfiguration):
        self._root = self._sanitize_path(config.root)
        self._restrict_to_root = config.restrict_to_root
        super().__init__()

    @property
    def root(self) -> Path:
        """The root directory of the file workspace."""
        return self._root

    @property
    def restrict_to_root(self) -> bool:
        """Whether to restrict generated paths to the root."""
        return self._restrict_to_root

    def initialize(self) -> None:
        self.root.mkdir(exist_ok=True, parents=True)

    def open_file(self, path: str | Path, binary: bool = False) -> IO:
        """Open a file in the workspace."""
        return self._open_file(path, "rb" if binary else "r")

    def _open_file(self, path: str | Path, mode: str = "r") -> IO:
        full_path = self.get_path(path)
        return open(full_path, mode)  # type: ignore

    def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
        """Read a file in the workspace."""
        with self._open_file(path, "rb" if binary else "r") as file:
            return file.read()

    async def write_file(self, path: str | Path, content: str | bytes) -> None:
        """Write to a file in the workspace."""
        with self._open_file(path, "wb" if type(content) is bytes else "w") as file:
            file.write(content)

        if self.on_write_file:
            path = Path(path)
            if path.is_absolute():
                path = path.relative_to(self.root)
            res = self.on_write_file(path)
            if inspect.isawaitable(res):
                await res

    def list(self, path: str | Path = ".") -> list[Path]:
        """List all files (recursively) in a directory in the workspace."""
        path = self.get_path(path)
        return [file.relative_to(path) for file in path.rglob("*") if file.is_file()]

    def delete_file(self, path: str | Path) -> None:
        """Delete a file in the workspace."""
        full_path = self.get_path(path)
        full_path.unlink()
@@ -1,128 +0,0 @@
"""
The S3Workspace class provides an interface for interacting with a file workspace, and
stores the files in an S3 bucket.
"""
from __future__ import annotations

import contextlib
import inspect
import logging
import os
from io import IOBase, TextIOWrapper
from pathlib import Path
from typing import TYPE_CHECKING, Optional

import boto3
import botocore.exceptions
from pydantic import SecretStr

from autogpt.core.configuration.schema import UserConfigurable

from .base import FileWorkspace, FileWorkspaceConfiguration

if TYPE_CHECKING:
    import mypy_boto3_s3

logger = logging.getLogger(__name__)


class S3FileWorkspaceConfiguration(FileWorkspaceConfiguration):
    bucket: str = UserConfigurable("autogpt", from_env="WORKSPACE_STORAGE_BUCKET")
    s3_endpoint_url: Optional[SecretStr] = UserConfigurable(
        from_env=lambda: SecretStr(v) if (v := os.getenv("S3_ENDPOINT_URL")) else None
    )


class S3FileWorkspace(FileWorkspace):
    """A class that represents an S3 workspace."""

    _bucket: mypy_boto3_s3.service_resource.Bucket

    def __init__(self, config: S3FileWorkspaceConfiguration):
        self._bucket_name = config.bucket
        self._root = config.root
        assert self._root.is_absolute()

        # https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html
        self._s3 = boto3.resource(
            "s3",
            endpoint_url=config.s3_endpoint_url.get_secret_value()
            if config.s3_endpoint_url
            else None,
        )

        super().__init__()

    @property
    def root(self) -> Path:
        """The root directory of the file workspace."""
        return self._root

    @property
    def restrict_to_root(self):
        """Whether to restrict generated paths to the root."""
        return True

    def initialize(self) -> None:
        logger.debug(f"Initializing {repr(self)}...")
        try:
            self._s3.meta.client.head_bucket(Bucket=self._bucket_name)
            self._bucket = self._s3.Bucket(self._bucket_name)
        except botocore.exceptions.ClientError as e:
            if "(404)" not in str(e):
                raise
            logger.info(f"Bucket '{self._bucket_name}' does not exist; creating it...")
            self._bucket = self._s3.create_bucket(Bucket=self._bucket_name)

    def get_path(self, relative_path: str | Path) -> Path:
        return super().get_path(relative_path).relative_to("/")

    def _get_obj(self, path: str | Path) -> mypy_boto3_s3.service_resource.Object:
        """Get an S3 object."""
        path = self.get_path(path)
        obj = self._bucket.Object(str(path))
        with contextlib.suppress(botocore.exceptions.ClientError):
            obj.load()
        return obj

    def open_file(self, path: str | Path, binary: bool = False) -> IOBase:
        """Open a file in the workspace."""
        obj = self._get_obj(path)
        return obj.get()["Body"] if binary else TextIOWrapper(obj.get()["Body"])

    def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
        """Read a file in the workspace."""
        return self.open_file(path, binary).read()

    async def write_file(self, path: str | Path, content: str | bytes) -> None:
        """Write to a file in the workspace."""
        obj = self._get_obj(path)
        obj.put(Body=content)

        if self.on_write_file:
            path = Path(path)
            if path.is_absolute():
                path = path.relative_to(self.root)
            res = self.on_write_file(path)
            if inspect.isawaitable(res):
                await res

    def list(self, path: str | Path = ".") -> list[Path]:
        """List all files (recursively) in a directory in the workspace."""
        path = self.get_path(path)
        if path == Path("."):  # root level of bucket
            return [Path(obj.key) for obj in self._bucket.objects.all()]
        else:
            return [
                Path(obj.key).relative_to(path)
                for obj in self._bucket.objects.filter(Prefix=f"{path}/")
            ]

    def delete_file(self, path: str | Path) -> None:
        """Delete a file in the workspace."""
        path = self.get_path(path)
        obj = self._s3.Object(self._bucket_name, str(path))
        obj.delete()

    def __repr__(self) -> str:
        return f"{__class__.__name__}(bucket='{self._bucket_name}', root={self._root})"
@@ -1,3 +1,5 @@
from __future__ import annotations

import os
import uuid
from pathlib import Path
@@ -11,10 +13,10 @@ from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings
from autogpt.app.main import _configure_openai_provider
from autogpt.config import AIProfile, Config, ConfigBuilder
from autogpt.core.resource.model_providers import ChatModelProvider, OpenAIProvider
from autogpt.file_workspace.local import (
    FileWorkspace,
    FileWorkspaceConfiguration,
    LocalFileWorkspace,
from autogpt.file_storage.local import (
    FileStorage,
    FileStorageConfiguration,
    LocalFileStorage,
)
from autogpt.llm.api_manager import ApiManager
from autogpt.logs.config import configure_logging
@@ -40,20 +42,12 @@ def app_data_dir(tmp_project_root: Path) -> Path:


@pytest.fixture()
def agent_data_dir(app_data_dir: Path) -> Path:
    return app_data_dir / "agents/AutoGPT"


@pytest.fixture()
def workspace_root(agent_data_dir: Path) -> Path:
    return agent_data_dir / "workspace"


@pytest.fixture()
def workspace(workspace_root: Path) -> FileWorkspace:
    workspace = LocalFileWorkspace(FileWorkspaceConfiguration(root=workspace_root))
    workspace.initialize()
    return workspace
def storage(app_data_dir: Path) -> FileStorage:
    storage = LocalFileStorage(
        FileStorageConfiguration(root=app_data_dir, restrict_to_root=False)
    )
    storage.initialize()
    return storage


@pytest.fixture
@@ -120,7 +114,7 @@ def llm_provider(config: Config) -> OpenAIProvider:

@pytest.fixture
def agent(
    agent_data_dir: Path, config: Config, llm_provider: ChatModelProvider
    config: Config, llm_provider: ChatModelProvider, storage: FileStorage
) -> Agent:
    ai_profile = AIProfile(
        ai_name="Base",
@@ -153,7 +147,7 @@ def agent(
        settings=agent_settings,
        llm_provider=llm_provider,
        command_registry=command_registry,
        file_storage=storage,
        legacy_config=config,
    )
    agent.attach_fs(agent_data_dir)
    return agent
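With the renamed fixture, tests now request `storage` instead of `workspace` and receive a `LocalFileStorage` rooted at the app data dir. A hypothetical test consuming the fixture, written to the conventions of the updated conftest above (not part of the commit):

    import pytest

    from autogpt.file_storage import FileStorage

    @pytest.mark.asyncio
    async def test_storage_roundtrip(storage: FileStorage):
        # write through the FileStorage API, then read the file back
        await storage.write_file("roundtrip.txt", "hello")
        assert storage.read_file("roundtrip.txt") == "hello"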
@@ -4,14 +4,12 @@ import orjson
import pytest

from autogpt.config import Config
from autogpt.file_workspace import FileWorkspace
from autogpt.file_storage import FileStorage
from autogpt.memory.vector import JSONFileMemory, MemoryItem


def test_json_memory_init_without_backing_file(
    config: Config, workspace: FileWorkspace
):
    index_file = workspace.root / f"{config.memory_index}.json"
def test_json_memory_init_without_backing_file(config: Config, storage: FileStorage):
    index_file = storage.root / f"{config.memory_index}.json"

    assert not index_file.exists()
    JSONFileMemory(config)
@@ -19,10 +17,8 @@ def test_json_memory_init_without_backing_file(
    assert index_file.read_text() == "[]"


def test_json_memory_init_with_backing_empty_file(
    config: Config, workspace: FileWorkspace
):
    index_file = workspace.root / f"{config.memory_index}.json"
def test_json_memory_init_with_backing_empty_file(config: Config, storage: FileStorage):
    index_file = storage.root / f"{config.memory_index}.json"
    index_file.touch()

    assert index_file.exists()
@@ -32,9 +28,9 @@ def test_json_memory_init_with_backing_empty_file(


def test_json_memory_init_with_backing_invalid_file(
    config: Config, workspace: FileWorkspace
    config: Config, storage: FileStorage
):
    index_file = workspace.root / f"{config.memory_index}.json"
    index_file = storage.root / f"{config.memory_index}.json"
    index_file.touch()

    raw_data = {"texts": ["test"]}
@@ -18,11 +18,11 @@ def image_size(request):
@pytest.mark.requires_openai_api_key
@pytest.mark.vcr
def test_dalle(agent: Agent, workspace, image_size, cached_openai_client):
def test_dalle(agent: Agent, storage, image_size, cached_openai_client):
    """Test DALL-E image generation."""
    generate_and_validate(
        agent,
        workspace,
        storage,
        image_provider="dalle",
        image_size=image_size,
    )
@@ -37,11 +37,11 @@ def test_dalle(agent: Agent, workspace, image_size, cached_openai_client):
    "image_model",
    ["CompVis/stable-diffusion-v1-4", "stabilityai/stable-diffusion-2-1"],
)
def test_huggingface(agent: Agent, workspace, image_size, image_model):
def test_huggingface(agent: Agent, storage, image_size, image_model):
    """Test HuggingFace image generation."""
    generate_and_validate(
        agent,
        workspace,
        storage,
        image_provider="huggingface",
        image_size=image_size,
        hugging_face_image_model=image_model,
@@ -49,18 +49,18 @@ def test_huggingface(agent: Agent, workspace, image_size, image_model):


@pytest.mark.xfail(reason="SD WebUI call does not work.")
def test_sd_webui(agent: Agent, workspace, image_size):
def test_sd_webui(agent: Agent, storage, image_size):
    """Test SD WebUI image generation."""
    generate_and_validate(
        agent,
        workspace,
        storage,
        image_provider="sd_webui",
        image_size=image_size,
    )


@pytest.mark.xfail(reason="SD WebUI call does not work.")
def test_sd_webui_negative_prompt(agent: Agent, workspace, image_size):
def test_sd_webui_negative_prompt(agent: Agent, storage, image_size):
    gen_image = functools.partial(
        generate_image_with_sd_webui,
        prompt="astronaut riding a horse",
@@ -91,7 +91,7 @@ def lst(txt):

def generate_and_validate(
    agent: Agent,
    workspace,
    storage,
    image_size,
    image_provider,
    hugging_face_image_model=None,
@@ -125,7 +125,7 @@ def generate_and_validate(
)
@pytest.mark.parametrize("delay", [10, 0])
def test_huggingface_fail_request_with_delay(
    agent: Agent, workspace, image_size, image_model, return_text, delay
    agent: Agent, storage, image_size, image_model, return_text, delay
):
    return_text = return_text.replace("[model]", image_model).replace(
        "[delay]", str(delay)
@@ -1,4 +1,5 @@
from autogpt.config.ai_profile import AIProfile
from autogpt.file_storage.base import FileStorage

"""
Test cases for the AIProfile class, which handles loading the AI configuration
@@ -45,10 +46,10 @@ api_budget: 0.0
    assert ai_settings_file.read_text() == yaml_content2


def test_ai_profile_file_not_exists(workspace):
def test_ai_profile_file_not_exists(storage: FileStorage):
    """Test if file does not exist."""

    ai_settings_file = workspace.get_path("ai_settings.yaml")
    ai_settings_file = storage.get_path("ai_settings.yaml")

    ai_profile = AIProfile.load(str(ai_settings_file))
    assert ai_profile.ai_name == ""
@@ -57,10 +58,10 @@ def test_ai_profile_file_not_exists(workspace):
    assert ai_profile.api_budget == 0.0


def test_ai_profile_file_is_empty(workspace):
def test_ai_profile_file_is_empty(storage: FileStorage):
    """Test if file is empty."""

    ai_settings_file = workspace.get_path("ai_settings.yaml")
    ai_settings_file = storage.get_path("ai_settings.yaml")
    ai_settings_file.write_text("")

    ai_profile = AIProfile.load(str(ai_settings_file))
@@ -1,7 +1,5 @@
import hashlib
import os
import re
from io import TextIOWrapper
from pathlib import Path

import pytest
@@ -11,7 +9,7 @@ import autogpt.commands.file_operations as file_ops
from autogpt.agents.agent import Agent
from autogpt.agents.utils.exceptions import DuplicateOperationError
from autogpt.config import Config
from autogpt.file_workspace import FileWorkspace
from autogpt.file_storage import FileStorage
from autogpt.memory.vector.memory_item import MemoryItem
from autogpt.memory.vector.utils import Embedding

@@ -46,40 +44,22 @@ def test_file_name():


@pytest.fixture
def test_file_path(test_file_name: Path, workspace: FileWorkspace):
    return workspace.get_path(test_file_name)
def test_file_path(test_file_name: Path, storage: FileStorage):
    return storage.get_path(test_file_name)


@pytest.fixture()
def test_file(test_file_path: Path):
    file = open(test_file_path, "w")
    yield file
    if not file.closed:
        file.close()
def test_directory(storage: FileStorage):
    return storage.get_path("test_directory")


@pytest.fixture()
def test_file_with_content_path(test_file: TextIOWrapper, file_content, agent: Agent):
    test_file.write(file_content)
    test_file.close()
    file_ops.log_operation(
        "write", Path(test_file.name), agent, file_ops.text_checksum(file_content)
    )
    return Path(test_file.name)
def test_nested_file(storage: FileStorage):
    return storage.get_path("nested/test_file.txt")


@pytest.fixture()
def test_directory(workspace: FileWorkspace):
    return workspace.get_path("test_directory")


@pytest.fixture()
def test_nested_file(workspace: FileWorkspace):
    return workspace.get_path("nested/test_file.txt")


def test_file_operations_log(test_file: TextIOWrapper):
    log_file_content = (
def test_file_operations_log():
    all_logs = (
        "File Operation Logger\n"
        "write: path/to/file1.txt #checksum1\n"
        "write: path/to/file2.txt #checksum2\n"
@@ -87,8 +67,7 @@ def test_file_operations_log(test_file: TextIOWrapper):
        "append: path/to/file2.txt #checksum4\n"
        "delete: path/to/file3.txt\n"
    )
    test_file.write(log_file_content)
    test_file.close()
    logs = all_logs.split("\n")

    expected = [
        ("write", "path/to/file1.txt", "checksum1"),
@@ -97,28 +76,7 @@ def test_file_operations_log(test_file: TextIOWrapper):
        ("append", "path/to/file2.txt", "checksum4"),
        ("delete", "path/to/file3.txt", None),
    ]
    assert list(file_ops.operations_from_log(test_file.name)) == expected


def test_file_operations_state(test_file: TextIOWrapper):
    # Prepare a fake log file
    log_file_content = (
        "File Operation Logger\n"
        "write: path/to/file1.txt #checksum1\n"
        "write: path/to/file2.txt #checksum2\n"
        "write: path/to/file3.txt #checksum3\n"
        "append: path/to/file2.txt #checksum4\n"
        "delete: path/to/file3.txt\n"
    )
    test_file.write(log_file_content)
    test_file.close()

    # Call the function and check the returned dictionary
    expected_state = {
        "path/to/file1.txt": "checksum1",
        "path/to/file2.txt": "checksum4",
    }
    assert file_ops.file_operations_state(test_file.name) == expected_state
    assert list(file_ops.operations_from_log(logs)) == expected


def test_is_duplicate_operation(agent: Agent, mocker: MockerFixture):
@@ -167,11 +125,11 @@ def test_is_duplicate_operation(agent: Agent, mocker: MockerFixture):


# Test logging a file operation
def test_log_operation(agent: Agent):
    file_ops.log_operation("log_test", Path("path/to/test"), agent=agent)
    with open(agent.file_manager.file_ops_log_path, "r", encoding="utf-8") as f:
        content = f.read()
    assert "log_test: path/to/test\n" in content
@pytest.mark.asyncio
async def test_log_operation(agent: Agent):
    await file_ops.log_operation("log_test", Path("path/to/test"), agent=agent)
    log_entry = agent.get_file_operation_lines()[-1]
    assert "log_test: path/to/test" in log_entry


def test_text_checksum(file_content: str):
@@ -181,22 +139,27 @@ def test_text_checksum(file_content: str):
    assert checksum != different_checksum


def test_log_operation_with_checksum(agent: Agent):
    file_ops.log_operation(
@pytest.mark.asyncio
async def test_log_operation_with_checksum(agent: Agent):
    await file_ops.log_operation(
        "log_test", Path("path/to/test"), agent=agent, checksum="ABCDEF"
    )
    with open(agent.file_manager.file_ops_log_path, "r", encoding="utf-8") as f:
        content = f.read()
    assert "log_test: path/to/test #ABCDEF\n" in content
    log_entry = agent.get_file_operation_lines()[-1]
    assert "log_test: path/to/test #ABCDEF" in log_entry


def test_read_file(
@pytest.mark.asyncio
async def test_read_file(
    mock_MemoryItem_from_text,
    test_file_with_content_path: Path,
    test_file_path: Path,
    file_content,
    agent: Agent,
):
    content = file_ops.read_file(test_file_with_content_path, agent=agent)
    await agent.workspace.write_file(test_file_path.name, file_content)
    await file_ops.log_operation(
        "write", Path(test_file_path.name), agent, file_ops.text_checksum(file_content)
    )
    content = file_ops.read_file(test_file_path.name, agent=agent)
    assert content.replace("\r", "") == file_content


@@ -229,15 +192,14 @@ async def test_write_file_logs_checksum(test_file_name: Path, agent: Agent):
    new_content = "This is new content.\n"
    new_checksum = file_ops.text_checksum(new_content)
    await file_ops.write_to_file(test_file_name, new_content, agent=agent)
    with open(agent.file_manager.file_ops_log_path, "r", encoding="utf-8") as f:
        log_entry = f.read()
    assert log_entry == f"write: {test_file_name} #{new_checksum}\n"
    log_entry = agent.get_file_operation_lines()[-1]
    assert log_entry == f"write: {test_file_name} #{new_checksum}"


@pytest.mark.asyncio
async def test_write_file_fails_if_content_exists(test_file_name: Path, agent: Agent):
    new_content = "This is new content.\n"
    file_ops.log_operation(
    await file_ops.log_operation(
        "write",
        test_file_name,
        agent=agent,
@@ -249,81 +211,42 @@ async def test_write_file_fails_if_content_exists(test_file_name: Path, agent: A

@pytest.mark.asyncio
async def test_write_file_succeeds_if_content_different(
    test_file_with_content_path: Path, agent: Agent
    test_file_path: Path, file_content: str, agent: Agent
):
    await agent.workspace.write_file(test_file_path.name, file_content)
    await file_ops.log_operation(
        "write", Path(test_file_path.name), agent, file_ops.text_checksum(file_content)
    )
    new_content = "This is different content.\n"
    await file_ops.write_to_file(test_file_with_content_path, new_content, agent=agent)
    await file_ops.write_to_file(test_file_path.name, new_content, agent=agent)


@pytest.mark.asyncio
async def test_append_to_file(test_nested_file: Path, agent: Agent):
    append_text = "This is appended text.\n"
    await file_ops.write_to_file(test_nested_file, append_text, agent=agent)
async def test_list_files(agent: Agent):
    # Create files A and B
    file_a_name = "file_a.txt"
    file_b_name = "file_b.txt"
    test_directory = Path("test_directory")

    file_ops.append_to_file(test_nested_file, append_text, agent=agent)

    with open(test_nested_file, "r") as f:
        content_after = f.read()

    assert content_after == append_text + append_text


def test_append_to_file_uses_checksum_from_appended_file(
    test_file_name: Path, agent: Agent
):
    append_text = "This is appended text.\n"
    file_ops.append_to_file(
        agent.workspace.get_path(test_file_name),
        append_text,
        agent=agent,
    )
    file_ops.append_to_file(
        agent.workspace.get_path(test_file_name),
        append_text,
        agent=agent,
    )
    with open(agent.file_manager.file_ops_log_path, "r", encoding="utf-8") as f:
        log_contents = f.read()

    digest = hashlib.md5()
    digest.update(append_text.encode("utf-8"))
    checksum1 = digest.hexdigest()
    digest.update(append_text.encode("utf-8"))
    checksum2 = digest.hexdigest()
    assert log_contents == (
        f"append: {test_file_name} #{checksum1}\n"
        f"append: {test_file_name} #{checksum2}\n"
    )


def test_list_files(workspace: FileWorkspace, test_directory: Path, agent: Agent):
    # Case 1: Create files A and B, search for A, and ensure we don't return A and B
    file_a = workspace.get_path("file_a.txt")
    file_b = workspace.get_path("file_b.txt")

    with open(file_a, "w") as f:
        f.write("This is file A.")

    with open(file_b, "w") as f:
        f.write("This is file B.")
    await agent.workspace.write_file(file_a_name, "This is file A.")
    await agent.workspace.write_file(file_b_name, "This is file B.")

    # Create a subdirectory and place a copy of file_a in it
    if not os.path.exists(test_directory):
        os.makedirs(test_directory)
    agent.workspace.make_dir(test_directory)
    await agent.workspace.write_file(
        test_directory / file_a_name, "This is file A in the subdirectory."
    )

    with open(os.path.join(test_directory, file_a.name), "w") as f:
        f.write("This is file A in the subdirectory.")

    files = file_ops.list_folder(str(workspace.root), agent=agent)
    assert file_a.name in files
    assert file_b.name in files
    assert os.path.join(Path(test_directory).name, file_a.name) in files
    files = file_ops.list_folder(".", agent=agent)
    assert file_a_name in files
    assert file_b_name in files
    assert os.path.join(test_directory, file_a_name) in files

    # Clean up
    os.remove(file_a)
    os.remove(file_b)
    os.remove(os.path.join(test_directory, file_a.name))
    os.rmdir(test_directory)
    agent.workspace.delete_file(file_a_name)
    agent.workspace.delete_file(file_b_name)
    agent.workspace.delete_file(test_directory / file_a_name)
    agent.workspace.delete_dir(test_directory)

    # Case 2: Search for a file that does not exist and make sure we don't throw
    non_existent_file = "non_existent_file.txt"
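The expected tuples above pin down the log-line format: `<operation>: <path> #<checksum>`, with the checksum optional. A standalone sketch of a parser that yields those tuples (the real `operations_from_log` in `autogpt.commands.file_operations` may differ in details):

    def parse_log_lines(lines: list[str]):
        """Yield (operation, path, checksum) tuples from file-operation log lines."""
        for line in lines:
            line = line.strip()
            if not line or line == "File Operation Logger":
                continue  # skip header / blanks
            operation, _, tail = line.partition(": ")
            path, _, checksum = tail.partition(" #")
            yield (operation, path, checksum or None)

    logs = [
        "File Operation Logger",
        "write: path/to/file1.txt #checksum1",
        "delete: path/to/file3.txt",
    ]
    assert list(parse_log_lines(logs)) == [
        ("write", "path/to/file1.txt", "checksum1"),
        ("delete", "path/to/file3.txt", None),
    ]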
179 autogpts/autogpt/tests/unit/test_gcs_file_storage.py Normal file
@@ -0,0 +1,179 @@
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from google.auth.exceptions import GoogleAuthError
|
||||
from google.cloud import storage
|
||||
from google.cloud.exceptions import NotFound
|
||||
|
||||
from autogpt.file_storage.gcs import GCSFileStorage, GCSFileStorageConfiguration
|
||||
|
||||
try:
|
||||
storage.Client()
|
||||
except GoogleAuthError:
|
||||
pytest.skip("Google Cloud Authentication not configured", allow_module_level=True)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def gcs_bucket_name() -> str:
|
||||
return f"test-bucket-{str(uuid.uuid4())[:8]}"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def gcs_root() -> Path:
|
||||
return Path("/workspaces/AutoGPT-some-unique-task-id")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def gcs_storage_uninitialized(gcs_bucket_name: str, gcs_root: Path) -> GCSFileStorage:
|
||||
os.environ["STORAGE_BUCKET"] = gcs_bucket_name
|
||||
storage_config = GCSFileStorageConfiguration.from_env()
|
||||
storage_config.root = gcs_root
|
||||
storage = GCSFileStorage(storage_config)
|
||||
yield storage # type: ignore
|
||||
del os.environ["STORAGE_BUCKET"]
|
||||
|
||||
|
||||
def test_initialize(gcs_bucket_name: str, gcs_storage_uninitialized: GCSFileStorage):
|
||||
gcs = gcs_storage_uninitialized._gcs
|
||||
|
||||
# test that the bucket doesn't exist yet
|
||||
with pytest.raises(NotFound):
|
||||
gcs.get_bucket(gcs_bucket_name)
|
||||
|
||||
gcs_storage_uninitialized.initialize()
|
||||
|
||||
# test that the bucket has been created
|
||||
bucket = gcs.get_bucket(gcs_bucket_name)
|
||||
|
||||
# clean up
|
||||
bucket.delete(force=True)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def gcs_storage(gcs_storage_uninitialized: GCSFileStorage) -> GCSFileStorage:
    (gcs_storage := gcs_storage_uninitialized).initialize()
    yield gcs_storage  # type: ignore

    # Empty & delete the test bucket
    gcs_storage._bucket.delete(force=True)


def test_workspace_bucket_name(
    gcs_storage: GCSFileStorage,
    gcs_bucket_name: str,
):
    assert gcs_storage._bucket.name == gcs_bucket_name


NESTED_DIR = "existing/test/dir"
TEST_FILES: list[tuple[str | Path, str]] = [
    ("existing_test_file_1", "test content 1"),
    ("existing_test_file_2.txt", "test content 2"),
    (Path("existing_test_file_3"), "test content 3"),
    (Path(f"{NESTED_DIR}/test_file_4"), "test content 4"),
]


@pytest_asyncio.fixture
async def gcs_storage_with_files(gcs_storage: GCSFileStorage) -> GCSFileStorage:
    for file_name, file_content in TEST_FILES:
        gcs_storage._bucket.blob(
            str(gcs_storage.get_path(file_name))
        ).upload_from_string(file_content)
    yield gcs_storage  # type: ignore


@pytest.mark.asyncio
async def test_read_file(gcs_storage_with_files: GCSFileStorage):
    for file_name, file_content in TEST_FILES:
        content = gcs_storage_with_files.read_file(file_name)
        assert content == file_content

    with pytest.raises(NotFound):
        gcs_storage_with_files.read_file("non_existent_file")


def test_list_files(gcs_storage_with_files: GCSFileStorage):
    # List at root level
    assert (
        files := gcs_storage_with_files.list_files()
    ) == gcs_storage_with_files.list_files()
    assert len(files) > 0
    assert set(files) == set(Path(file_name) for file_name, _ in TEST_FILES)

    # List at nested path
    assert (
        nested_files := gcs_storage_with_files.list_files(NESTED_DIR)
    ) == gcs_storage_with_files.list_files(NESTED_DIR)
    assert len(nested_files) > 0
    assert set(nested_files) == set(
        p.relative_to(NESTED_DIR)
        for file_name, _ in TEST_FILES
        if (p := Path(file_name)).is_relative_to(NESTED_DIR)
    )


def test_list_folders(gcs_storage_with_files: GCSFileStorage):
    # List recursive
    folders = gcs_storage_with_files.list_folders(recursive=True)
    assert len(folders) > 0
    assert set(folders) == {
        Path("existing"),
        Path("existing/test"),
        Path("existing/test/dir"),
    }
    # List non-recursive
    folders = gcs_storage_with_files.list_folders(recursive=False)
    assert len(folders) > 0
    assert set(folders) == {Path("existing")}


@pytest.mark.asyncio
async def test_write_read_file(gcs_storage: GCSFileStorage):
    await gcs_storage.write_file("test_file", "test_content")
    assert gcs_storage.read_file("test_file") == "test_content"


@pytest.mark.asyncio
async def test_overwrite_file(gcs_storage_with_files: GCSFileStorage):
    for file_name, _ in TEST_FILES:
        await gcs_storage_with_files.write_file(file_name, "new content")
        assert gcs_storage_with_files.read_file(file_name) == "new content"


def test_delete_file(gcs_storage_with_files: GCSFileStorage):
    for file_to_delete, _ in TEST_FILES:
        gcs_storage_with_files.delete_file(file_to_delete)
        assert not gcs_storage_with_files.exists(file_to_delete)


def test_exists(gcs_storage_with_files: GCSFileStorage):
    for file_name, _ in TEST_FILES:
        assert gcs_storage_with_files.exists(file_name)

    assert not gcs_storage_with_files.exists("non_existent_file")


def test_rename_file(gcs_storage_with_files: GCSFileStorage):
    for file_name, _ in TEST_FILES:
        new_name = str(file_name) + "_renamed"
        gcs_storage_with_files.rename(file_name, new_name)
        assert gcs_storage_with_files.exists(new_name)
        assert not gcs_storage_with_files.exists(file_name)


def test_rename_dir(gcs_storage_with_files: GCSFileStorage):
    gcs_storage_with_files.rename(NESTED_DIR, "existing/test/dir_renamed")
    assert gcs_storage_with_files.exists("existing/test/dir_renamed")
    assert not gcs_storage_with_files.exists(NESTED_DIR)


def test_clone(gcs_storage_with_files: GCSFileStorage, gcs_root: Path):
    cloned = gcs_storage_with_files.clone_with_subroot("existing/test")
    assert cloned.root == gcs_root / Path("existing/test")
    assert cloned._bucket.name == gcs_storage_with_files._bucket.name
    assert cloned.exists("dir")
    assert cloned.exists("dir/test_file_4")
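A note on the fixture pattern used throughout these tests: the walrus assignment `(gcs_storage := gcs_storage_uninitialized).initialize()` initializes the backend and binds the name in one line, and everything after `yield` is pytest teardown, so the test bucket is force-deleted even when assertions fail. A minimal self-contained sketch of the same shape, with an in-memory stand-in (`FakeBucketStorage` is an illustrative assumption, not part of this commit):

import pytest


class FakeBucketStorage:
    """Illustrative stand-in for a cloud-backed FileStorage."""

    def __init__(self) -> None:
        self._blobs: dict[str, str] = {}
        self.initialized = False

    def initialize(self) -> None:  # plays the role of GCSFileStorage.initialize()
        self.initialized = True

    def delete_all(self) -> None:  # plays the role of bucket.delete(force=True)
        self._blobs.clear()
        self.initialized = False


@pytest.fixture
def fake_storage():
    # Same initialize-then-yield-then-teardown shape as the gcs_storage fixture above
    (storage := FakeBucketStorage()).initialize()
    yield storage
    storage.delete_all()  # runs after the test, pass or fail


def test_fake_storage_is_initialized(fake_storage: FakeBucketStorage):
    assert fake_storage.initialized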
@@ -1,131 +0,0 @@
import os
import uuid
from pathlib import Path

import pytest
import pytest_asyncio
from google.auth.exceptions import GoogleAuthError
from google.cloud import storage
from google.cloud.exceptions import NotFound

from autogpt.file_workspace.gcs import GCSFileWorkspace, GCSFileWorkspaceConfiguration

try:
    storage.Client()
except GoogleAuthError:
    pytest.skip("Google Cloud Authentication not configured", allow_module_level=True)


@pytest.fixture(scope="module")
def gcs_bucket_name() -> str:
    return f"test-bucket-{str(uuid.uuid4())[:8]}"


@pytest.fixture(scope="module")
def gcs_workspace_uninitialized(gcs_bucket_name: str) -> GCSFileWorkspace:
    os.environ["WORKSPACE_STORAGE_BUCKET"] = gcs_bucket_name
    ws_config = GCSFileWorkspaceConfiguration.from_env()
    ws_config.root = Path("/workspaces/AutoGPT-some-unique-task-id")
    workspace = GCSFileWorkspace(ws_config)
    yield workspace  # type: ignore
    del os.environ["WORKSPACE_STORAGE_BUCKET"]


def test_initialize(
    gcs_bucket_name: str, gcs_workspace_uninitialized: GCSFileWorkspace
):
    gcs = gcs_workspace_uninitialized._gcs

    # test that the bucket doesn't exist yet
    with pytest.raises(NotFound):
        gcs.get_bucket(gcs_bucket_name)

    gcs_workspace_uninitialized.initialize()

    # test that the bucket has been created
    bucket = gcs.get_bucket(gcs_bucket_name)

    # clean up
    bucket.delete(force=True)


@pytest.fixture(scope="module")
def gcs_workspace(gcs_workspace_uninitialized: GCSFileWorkspace) -> GCSFileWorkspace:
    (gcs_workspace := gcs_workspace_uninitialized).initialize()
    yield gcs_workspace  # type: ignore

    # Empty & delete the test bucket
    gcs_workspace._bucket.delete(force=True)


def test_workspace_bucket_name(
    gcs_workspace: GCSFileWorkspace,
    gcs_bucket_name: str,
):
    assert gcs_workspace._bucket.name == gcs_bucket_name


NESTED_DIR = "existing/test/dir"
TEST_FILES: list[tuple[str | Path, str]] = [
    ("existing_test_file_1", "test content 1"),
    ("existing_test_file_2.txt", "test content 2"),
    (Path("existing_test_file_3"), "test content 3"),
    (Path(f"{NESTED_DIR}/test/file/4"), "test content 4"),
]


@pytest_asyncio.fixture
async def gcs_workspace_with_files(gcs_workspace: GCSFileWorkspace) -> GCSFileWorkspace:
    for file_name, file_content in TEST_FILES:
        gcs_workspace._bucket.blob(
            str(gcs_workspace.get_path(file_name))
        ).upload_from_string(file_content)
    yield gcs_workspace  # type: ignore


@pytest.mark.asyncio
async def test_read_file(gcs_workspace_with_files: GCSFileWorkspace):
    for file_name, file_content in TEST_FILES:
        content = gcs_workspace_with_files.read_file(file_name)
        assert content == file_content

    with pytest.raises(NotFound):
        gcs_workspace_with_files.read_file("non_existent_file")


def test_list_files(gcs_workspace_with_files: GCSFileWorkspace):
    # List at root level
    assert (files := gcs_workspace_with_files.list()) == gcs_workspace_with_files.list()
    assert len(files) > 0
    assert set(files) == set(Path(file_name) for file_name, _ in TEST_FILES)

    # List at nested path
    assert (
        nested_files := gcs_workspace_with_files.list(NESTED_DIR)
    ) == gcs_workspace_with_files.list(NESTED_DIR)
    assert len(nested_files) > 0
    assert set(nested_files) == set(
        p.relative_to(NESTED_DIR)
        for file_name, _ in TEST_FILES
        if (p := Path(file_name)).is_relative_to(NESTED_DIR)
    )


@pytest.mark.asyncio
async def test_write_read_file(gcs_workspace: GCSFileWorkspace):
    await gcs_workspace.write_file("test_file", "test_content")
    assert gcs_workspace.read_file("test_file") == "test_content"


@pytest.mark.asyncio
async def test_overwrite_file(gcs_workspace_with_files: GCSFileWorkspace):
    for file_name, _ in TEST_FILES:
        await gcs_workspace_with_files.write_file(file_name, "new content")
        assert gcs_workspace_with_files.read_file(file_name) == "new content"


def test_delete_file(gcs_workspace_with_files: GCSFileWorkspace):
    for file_to_delete, _ in TEST_FILES:
        gcs_workspace_with_files.delete_file(file_to_delete)
        with pytest.raises(NotFound):
            gcs_workspace_with_files.read_file(file_to_delete)
@@ -5,6 +5,7 @@ from git.repo.base import Repo
from autogpt.agents.agent import Agent
from autogpt.agents.utils.exceptions import CommandExecutionError
from autogpt.commands.git_operations import clone_repository
from autogpt.file_storage.base import FileStorage


@pytest.fixture
@@ -12,13 +13,13 @@ def mock_clone_from(mocker):
    return mocker.patch.object(Repo, "clone_from")


def test_clone_auto_gpt_repository(workspace, mock_clone_from, agent: Agent):
def test_clone_auto_gpt_repository(storage: FileStorage, mock_clone_from, agent: Agent):
    mock_clone_from.return_value = None

    repo = "github.com/Significant-Gravitas/Auto-GPT.git"
    scheme = "https://"
    url = scheme + repo
    clone_path = workspace.get_path("auto-gpt-repo")
    clone_path = storage.get_path("auto-gpt-repo")

    expected_output = f"Cloned {url} to {clone_path}"

@@ -31,9 +32,9 @@ def test_clone_auto_gpt_repository(workspace, mock_clone_from, agent: Agent):
    )


def test_clone_repository_error(workspace, mock_clone_from, agent: Agent):
def test_clone_repository_error(storage: FileStorage, mock_clone_from, agent: Agent):
    url = "https://github.com/this-repository/does-not-exist.git"
    clone_path = workspace.get_path("does-not-exist")
    clone_path = storage.get_path("does-not-exist")

    mock_clone_from.side_effect = GitCommandError(
        "clone", "fatal: repository not found", ""
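Both tests above now take `storage: FileStorage` in place of the old `workspace` fixture. A rough sketch of what such a fixture can look like, assuming the LocalFileStorage configuration used elsewhere in this diff (the fixture body below is illustrative, not the actual conftest code):

from pathlib import Path

import pytest

from autogpt.file_storage.local import FileStorageConfiguration, LocalFileStorage


@pytest.fixture
def storage(tmp_path: Path) -> LocalFileStorage:
    # Illustrative: back the fixture with a tmp-dir-rooted local storage
    return LocalFileStorage(
        FileStorageConfiguration(root=tmp_path, restrict_to_root=True)
    )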
autogpts/autogpt/tests/unit/test_local_file_storage.py (new file, 190 lines)
@@ -0,0 +1,190 @@
from pathlib import Path

import pytest

from autogpt.file_storage.local import FileStorageConfiguration, LocalFileStorage

_ACCESSIBLE_PATHS = [
    Path("."),
    Path("test_file.txt"),
    Path("test_folder"),
    Path("test_folder/test_file.txt"),
    Path("test_folder/.."),
    Path("test_folder/../test_file.txt"),
    Path("test_folder/../test_folder"),
    Path("test_folder/../test_folder/test_file.txt"),
]

_INACCESSIBLE_PATHS = (
    [
        # Takes us out of the workspace
        Path(".."),
        Path("../test_file.txt"),
        Path("../not_auto_gpt_workspace"),
        Path("../not_auto_gpt_workspace/test_file.txt"),
        Path("test_folder/../.."),
        Path("test_folder/../../test_file.txt"),
        Path("test_folder/../../not_auto_gpt_workspace"),
        Path("test_folder/../../not_auto_gpt_workspace/test_file.txt"),
    ]
    + [
        # Contains null byte
        Path("\0"),
        Path("\0test_file.txt"),
        Path("test_folder/\0"),
        Path("test_folder/\0test_file.txt"),
    ]
    + [
        # Absolute paths
        Path("/"),
        Path("/test_file.txt"),
        Path("/home"),
    ]
)

_TEST_FILES = [
    Path("test_file.txt"),
    Path("dir/test_file.txt"),
    Path("dir/test_file2.txt"),
    Path("dir/sub_dir/test_file.txt"),
]

_TEST_DIRS = [
    Path("dir"),
    Path("dir/sub_dir"),
]


@pytest.fixture()
def storage_root(tmp_path):
    return tmp_path / "data"


@pytest.fixture()
def storage(storage_root):
    return LocalFileStorage(
        FileStorageConfiguration(root=storage_root, restrict_to_root=True)
    )


@pytest.fixture()
def content():
    return "test content"


@pytest.fixture(params=_ACCESSIBLE_PATHS)
def accessible_path(request):
    return request.param


@pytest.fixture(params=_INACCESSIBLE_PATHS)
def inaccessible_path(request):
    return request.param


@pytest.fixture(params=_TEST_FILES)
def file_path(request):
    return request.param


@pytest.mark.asyncio
async def test_open_file(file_path: Path, content: str, storage: LocalFileStorage):
    if file_path.parent:
        storage.make_dir(file_path.parent)
    await storage.write_file(file_path, content)
    file = storage.open_file(file_path)
    assert file.read() == content
    file.close()
    storage.delete_file(file_path)


@pytest.mark.asyncio
async def test_write_read_file(content: str, storage: LocalFileStorage):
    await storage.write_file("test_file.txt", content)
    assert storage.read_file("test_file.txt") == content


@pytest.mark.asyncio
async def test_list_files(content: str, storage: LocalFileStorage):
    storage.make_dir("dir")
    storage.make_dir("dir/sub_dir")
    await storage.write_file("test_file.txt", content)
    await storage.write_file("dir/test_file.txt", content)
    await storage.write_file("dir/test_file2.txt", content)
    await storage.write_file("dir/sub_dir/test_file.txt", content)
    files = storage.list_files()
    assert Path("test_file.txt") in files
    assert Path("dir/test_file.txt") in files
    assert Path("dir/test_file2.txt") in files
    assert Path("dir/sub_dir/test_file.txt") in files
    storage.delete_file("test_file.txt")
    storage.delete_file("dir/test_file.txt")
    storage.delete_file("dir/test_file2.txt")
    storage.delete_file("dir/sub_dir/test_file.txt")
    storage.delete_dir("dir/sub_dir")
    storage.delete_dir("dir")


@pytest.mark.asyncio
async def test_list_folders(content: str, storage: LocalFileStorage):
    storage.make_dir("dir")
    storage.make_dir("dir/sub_dir")
    await storage.write_file("dir/test_file.txt", content)
    await storage.write_file("dir/sub_dir/test_file.txt", content)
    folders = storage.list_folders(recursive=False)
    folders_recursive = storage.list_folders(recursive=True)
    assert Path("dir") in folders
    assert Path("dir/sub_dir") not in folders
    assert Path("dir") in folders_recursive
    assert Path("dir/sub_dir") in folders_recursive
    storage.delete_file("dir/test_file.txt")
    storage.delete_file("dir/sub_dir/test_file.txt")
    storage.delete_dir("dir/sub_dir")
    storage.delete_dir("dir")


@pytest.mark.asyncio
async def test_exists_delete_file(
    file_path: Path, content: str, storage: LocalFileStorage
):
    if file_path.parent:
        storage.make_dir(file_path.parent)
    await storage.write_file(file_path, content)
    assert storage.exists(file_path)
    storage.delete_file(file_path)
    assert not storage.exists(file_path)


@pytest.mark.parametrize("dir_path", _TEST_DIRS)
def test_make_delete_dir(dir_path: Path, storage: LocalFileStorage):
    storage.make_dir(dir_path)
    assert storage.exists(dir_path)
    storage.delete_dir(dir_path)
    assert not storage.exists(dir_path)


@pytest.mark.asyncio
async def test_rename(file_path: Path, content: str, storage: LocalFileStorage):
    if file_path.parent:
        storage.make_dir(file_path.parent)
    await storage.write_file(file_path, content)
    assert storage.exists(file_path)
    storage.rename(file_path, Path(str(file_path) + "_renamed"))
    assert not storage.exists(file_path)
    assert storage.exists(Path(str(file_path) + "_renamed"))


def test_clone_with_subroot(storage: LocalFileStorage):
    subroot = storage.clone_with_subroot("dir")
    assert subroot.root == storage.root / "dir"


def test_get_path_accessible(accessible_path: Path, storage: LocalFileStorage):
    full_path = storage.get_path(accessible_path)
    assert full_path.is_absolute()
    assert full_path.is_relative_to(storage.root)


def test_get_path_inaccessible(inaccessible_path: Path, storage: LocalFileStorage):
    with pytest.raises(ValueError):
        storage.get_path(inaccessible_path)
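_ACCESSIBLE_PATHS and _INACCESSIBLE_PATHS pin down the sandboxing contract tested above: with restrict_to_root=True, get_path() must raise ValueError for anything that escapes the storage root, whether by `..` traversal, an embedded null byte, or an absolute path. A standalone sketch of such a check (illustrative only; the real logic lives inside autogpt.file_storage):

from pathlib import Path


def sanitize_path(relative_path: str | Path, root: Path) -> Path:
    """Resolve `relative_path` under `root`, rejecting escapes."""
    if "\0" in str(relative_path):
        raise ValueError("embedded null byte")  # e.g. Path("\0test_file.txt")
    if Path(relative_path).is_absolute():
        raise ValueError("absolute paths are not allowed")  # e.g. Path("/home")
    full_path = (root / relative_path).resolve()
    if not full_path.is_relative_to(root.resolve()):
        raise ValueError(f"{relative_path} escapes the storage root")  # e.g. Path("..")
    return full_path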
@@ -1,92 +0,0 @@
from pathlib import Path

import pytest

from autogpt.file_workspace.local import FileWorkspaceConfiguration, LocalFileWorkspace

_WORKSPACE_ROOT = Path("home/users/monty/auto_gpt_workspace")

_ACCESSIBLE_PATHS = [
    Path("."),
    Path("test_file.txt"),
    Path("test_folder"),
    Path("test_folder/test_file.txt"),
    Path("test_folder/.."),
    Path("test_folder/../test_file.txt"),
    Path("test_folder/../test_folder"),
    Path("test_folder/../test_folder/test_file.txt"),
]

_INACCESSIBLE_PATHS = (
    [
        # Takes us out of the workspace
        Path(".."),
        Path("../test_file.txt"),
        Path("../not_auto_gpt_workspace"),
        Path("../not_auto_gpt_workspace/test_file.txt"),
        Path("test_folder/../.."),
        Path("test_folder/../../test_file.txt"),
        Path("test_folder/../../not_auto_gpt_workspace"),
        Path("test_folder/../../not_auto_gpt_workspace/test_file.txt"),
    ]
    + [
        # Contains null byte
        Path("\0"),
        Path("\0test_file.txt"),
        Path("test_folder/\0"),
        Path("test_folder/\0test_file.txt"),
    ]
    + [
        # Absolute paths
        Path("/"),
        Path("/test_file.txt"),
        Path("/home"),
    ]
)


@pytest.fixture()
def workspace_root(tmp_path):
    return tmp_path / _WORKSPACE_ROOT


@pytest.fixture(params=_ACCESSIBLE_PATHS)
def accessible_path(request):
    return request.param


@pytest.fixture(params=_INACCESSIBLE_PATHS)
def inaccessible_path(request):
    return request.param


def test_sanitize_path_accessible(accessible_path, workspace_root):
    full_path = LocalFileWorkspace._sanitize_path(
        accessible_path,
        root=workspace_root,
        restrict_to_root=True,
    )
    assert full_path.is_absolute()
    assert full_path.is_relative_to(workspace_root)


def test_sanitize_path_inaccessible(inaccessible_path, workspace_root):
    with pytest.raises(ValueError):
        LocalFileWorkspace._sanitize_path(
            inaccessible_path,
            root=workspace_root,
            restrict_to_root=True,
        )


def test_get_path_accessible(accessible_path, workspace_root):
    workspace = LocalFileWorkspace(FileWorkspaceConfiguration(root=workspace_root))
    full_path = workspace.get_path(accessible_path)
    assert full_path.is_absolute()
    assert full_path.is_relative_to(workspace_root)


def test_get_path_inaccessible(inaccessible_path, workspace_root):
    workspace = LocalFileWorkspace(FileWorkspaceConfiguration(root=workspace_root))
    with pytest.raises(ValueError):
        workspace.get_path(inaccessible_path)
autogpts/autogpt/tests/unit/test_s3_file_storage.py (new file, 174 lines)
@@ -0,0 +1,174 @@
import os
import uuid
from pathlib import Path

import pytest
import pytest_asyncio
from botocore.exceptions import ClientError

from autogpt.file_storage.s3 import S3FileStorage, S3FileStorageConfiguration

if not os.getenv("S3_ENDPOINT_URL") and not os.getenv("AWS_ACCESS_KEY_ID"):
    pytest.skip("S3 environment variables are not set", allow_module_level=True)


@pytest.fixture
def s3_bucket_name() -> str:
    return f"test-bucket-{str(uuid.uuid4())[:8]}"


@pytest.fixture
def s3_root() -> Path:
    return Path("/workspaces/AutoGPT-some-unique-task-id")


@pytest.fixture
def s3_storage_uninitialized(s3_bucket_name: str, s3_root: Path) -> S3FileStorage:
    os.environ["STORAGE_BUCKET"] = s3_bucket_name
    storage_config = S3FileStorageConfiguration.from_env()
    storage_config.root = s3_root
    storage = S3FileStorage(storage_config)
    yield storage  # type: ignore
    del os.environ["STORAGE_BUCKET"]


def test_initialize(s3_bucket_name: str, s3_storage_uninitialized: S3FileStorage):
    s3 = s3_storage_uninitialized._s3

    # test that the bucket doesn't exist yet
    with pytest.raises(ClientError):
        s3.meta.client.head_bucket(Bucket=s3_bucket_name)

    s3_storage_uninitialized.initialize()

    # test that the bucket has been created
    s3.meta.client.head_bucket(Bucket=s3_bucket_name)


def test_workspace_bucket_name(
    s3_storage: S3FileStorage,
    s3_bucket_name: str,
):
    assert s3_storage._bucket.name == s3_bucket_name


@pytest.fixture
def s3_storage(s3_storage_uninitialized: S3FileStorage) -> S3FileStorage:
    (s3_storage := s3_storage_uninitialized).initialize()
    yield s3_storage  # type: ignore

    # Empty & delete the test bucket
    s3_storage._bucket.objects.all().delete()
    s3_storage._bucket.delete()


NESTED_DIR = "existing/test/dir"
TEST_FILES: list[tuple[str | Path, str]] = [
    ("existing_test_file_1", "test content 1"),
    ("existing_test_file_2.txt", "test content 2"),
    (Path("existing_test_file_3"), "test content 3"),
    (Path(f"{NESTED_DIR}/test_file_4"), "test content 4"),
]


@pytest_asyncio.fixture
async def s3_storage_with_files(s3_storage: S3FileStorage) -> S3FileStorage:
    for file_name, file_content in TEST_FILES:
        s3_storage._bucket.Object(str(s3_storage.get_path(file_name))).put(
            Body=file_content
        )
    yield s3_storage  # type: ignore


@pytest.mark.asyncio
async def test_read_file(s3_storage_with_files: S3FileStorage):
    for file_name, file_content in TEST_FILES:
        content = s3_storage_with_files.read_file(file_name)
        assert content == file_content

    with pytest.raises(ClientError):
        s3_storage_with_files.read_file("non_existent_file")


def test_list_files(s3_storage_with_files: S3FileStorage):
    # List at root level
    assert (
        files := s3_storage_with_files.list_files()
    ) == s3_storage_with_files.list_files()
    assert len(files) > 0
    assert set(files) == set(Path(file_name) for file_name, _ in TEST_FILES)

    # List at nested path
    assert (
        nested_files := s3_storage_with_files.list_files(NESTED_DIR)
    ) == s3_storage_with_files.list_files(NESTED_DIR)
    assert len(nested_files) > 0
    assert set(nested_files) == set(
        p.relative_to(NESTED_DIR)
        for file_name, _ in TEST_FILES
        if (p := Path(file_name)).is_relative_to(NESTED_DIR)
    )


def test_list_folders(s3_storage_with_files: S3FileStorage):
    # List recursive
    folders = s3_storage_with_files.list_folders(recursive=True)
    assert len(folders) > 0
    assert set(folders) == {
        Path("existing"),
        Path("existing/test"),
        Path("existing/test/dir"),
    }
    # List non-recursive
    folders = s3_storage_with_files.list_folders(recursive=False)
    assert len(folders) > 0
    assert set(folders) == {Path("existing")}


@pytest.mark.asyncio
async def test_write_read_file(s3_storage: S3FileStorage):
    await s3_storage.write_file("test_file", "test_content")
    assert s3_storage.read_file("test_file") == "test_content"


@pytest.mark.asyncio
async def test_overwrite_file(s3_storage_with_files: S3FileStorage):
    for file_name, _ in TEST_FILES:
        await s3_storage_with_files.write_file(file_name, "new content")
        assert s3_storage_with_files.read_file(file_name) == "new content"


def test_delete_file(s3_storage_with_files: S3FileStorage):
    for file_to_delete, _ in TEST_FILES:
        s3_storage_with_files.delete_file(file_to_delete)
        with pytest.raises(ClientError):
            s3_storage_with_files.read_file(file_to_delete)


def test_exists(s3_storage_with_files: S3FileStorage):
    for file_name, _ in TEST_FILES:
        assert s3_storage_with_files.exists(file_name)

    assert not s3_storage_with_files.exists("non_existent_file")


def test_rename_file(s3_storage_with_files: S3FileStorage):
    for file_name, _ in TEST_FILES:
        new_name = str(file_name) + "_renamed"
        s3_storage_with_files.rename(file_name, new_name)
        assert s3_storage_with_files.exists(new_name)
        assert not s3_storage_with_files.exists(file_name)


def test_rename_dir(s3_storage_with_files: S3FileStorage):
    s3_storage_with_files.rename(NESTED_DIR, "existing/test/dir_renamed")
    assert s3_storage_with_files.exists("existing/test/dir_renamed")
    assert not s3_storage_with_files.exists(NESTED_DIR)


def test_clone(s3_storage_with_files: S3FileStorage, s3_root: Path):
    cloned = s3_storage_with_files.clone_with_subroot("existing/test")
    assert cloned.root == s3_root / Path("existing/test")
    assert cloned._bucket.name == s3_storage_with_files._bucket.name
    assert cloned.exists("dir")
    assert cloned.exists("dir/test_file_4")
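test_rename_dir is worth a second look for object stores: S3 and GCS have no real directories, so renaming `existing/test/dir` means rewriting the key of every object under that prefix. A minimal sketch of the idea over a plain dict of keys (illustrative only; the actual backends do this through the bucket APIs):

def rename_prefix(blobs: dict[str, str], old: str, new: str) -> dict[str, str]:
    """Rename a pseudo-directory by rewriting matching key prefixes."""
    renamed = {}
    for key, body in blobs.items():
        if key == old or key.startswith(old + "/"):
            key = new + key[len(old):]
        renamed[key] = body
    return renamed


blobs = {"existing/test/dir/test_file_4": "test content 4"}
assert rename_prefix(blobs, "existing/test/dir", "existing/test/dir_renamed") == {
    "existing/test/dir_renamed/test_file_4": "test content 4"
}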
@@ -1,123 +0,0 @@
import os
import uuid
from pathlib import Path

import pytest
import pytest_asyncio
from botocore.exceptions import ClientError

from autogpt.file_workspace.s3 import S3FileWorkspace, S3FileWorkspaceConfiguration

if not os.getenv("S3_ENDPOINT_URL") and not os.getenv("AWS_ACCESS_KEY_ID"):
    pytest.skip("S3 environment variables are not set", allow_module_level=True)


@pytest.fixture
def s3_bucket_name() -> str:
    return f"test-bucket-{str(uuid.uuid4())[:8]}"


@pytest.fixture
def s3_workspace_uninitialized(s3_bucket_name: str) -> S3FileWorkspace:
    os.environ["WORKSPACE_STORAGE_BUCKET"] = s3_bucket_name
    ws_config = S3FileWorkspaceConfiguration.from_env()
    ws_config.root = Path("/workspaces/AutoGPT-some-unique-task-id")
    workspace = S3FileWorkspace(ws_config)
    yield workspace  # type: ignore
    del os.environ["WORKSPACE_STORAGE_BUCKET"]


def test_initialize(s3_bucket_name: str, s3_workspace_uninitialized: S3FileWorkspace):
    s3 = s3_workspace_uninitialized._s3

    # test that the bucket doesn't exist yet
    with pytest.raises(ClientError):
        s3.meta.client.head_bucket(Bucket=s3_bucket_name)

    s3_workspace_uninitialized.initialize()

    # test that the bucket has been created
    s3.meta.client.head_bucket(Bucket=s3_bucket_name)


def test_workspace_bucket_name(
    s3_workspace: S3FileWorkspace,
    s3_bucket_name: str,
):
    assert s3_workspace._bucket.name == s3_bucket_name


@pytest.fixture
def s3_workspace(s3_workspace_uninitialized: S3FileWorkspace) -> S3FileWorkspace:
    (s3_workspace := s3_workspace_uninitialized).initialize()
    yield s3_workspace  # type: ignore

    # Empty & delete the test bucket
    s3_workspace._bucket.objects.all().delete()
    s3_workspace._bucket.delete()


NESTED_DIR = "existing/test/dir"
TEST_FILES: list[tuple[str | Path, str]] = [
    ("existing_test_file_1", "test content 1"),
    ("existing_test_file_2.txt", "test content 2"),
    (Path("existing_test_file_3"), "test content 3"),
    (Path(f"{NESTED_DIR}/test/file/4"), "test content 4"),
]


@pytest_asyncio.fixture
async def s3_workspace_with_files(s3_workspace: S3FileWorkspace) -> S3FileWorkspace:
    for file_name, file_content in TEST_FILES:
        s3_workspace._bucket.Object(str(s3_workspace.get_path(file_name))).put(
            Body=file_content
        )
    yield s3_workspace  # type: ignore


@pytest.mark.asyncio
async def test_read_file(s3_workspace_with_files: S3FileWorkspace):
    for file_name, file_content in TEST_FILES:
        content = s3_workspace_with_files.read_file(file_name)
        assert content == file_content

    with pytest.raises(ClientError):
        s3_workspace_with_files.read_file("non_existent_file")


def test_list_files(s3_workspace_with_files: S3FileWorkspace):
    # List at root level
    assert (files := s3_workspace_with_files.list()) == s3_workspace_with_files.list()
    assert len(files) > 0
    assert set(files) == set(Path(file_name) for file_name, _ in TEST_FILES)

    # List at nested path
    assert (
        nested_files := s3_workspace_with_files.list(NESTED_DIR)
    ) == s3_workspace_with_files.list(NESTED_DIR)
    assert len(nested_files) > 0
    assert set(nested_files) == set(
        p.relative_to(NESTED_DIR)
        for file_name, _ in TEST_FILES
        if (p := Path(file_name)).is_relative_to(NESTED_DIR)
    )


@pytest.mark.asyncio
async def test_write_read_file(s3_workspace: S3FileWorkspace):
    await s3_workspace.write_file("test_file", "test_content")
    assert s3_workspace.read_file("test_file") == "test_content"


@pytest.mark.asyncio
async def test_overwrite_file(s3_workspace_with_files: S3FileWorkspace):
    for file_name, _ in TEST_FILES:
        await s3_workspace_with_files.write_file(file_name, "new content")
        assert s3_workspace_with_files.read_file(file_name) == "new content"


def test_delete_file(s3_workspace_with_files: S3FileWorkspace):
    for file_to_delete, _ in TEST_FILES:
        s3_workspace_with_files.delete_file(file_to_delete)
        with pytest.raises(ClientError):
            s3_workspace_with_files.read_file(file_to_delete)
@@ -8,7 +8,3 @@ def skip_in_ci(test_function):
        os.environ.get("CI") == "true",
        reason="This test doesn't work on GitHub Actions.",
    )(test_function)


def get_workspace_file_path(workspace, file_name):
    return str(workspace.get_path(file_name))
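With get_workspace_file_path removed, call sites resolve paths directly through the storage object's get_path(). A small self-contained illustration (StubStorage and the file name are hypothetical stand-ins for the `storage` fixture and its contents):

from pathlib import Path


class StubStorage:
    """Hypothetical stand-in exposing the get_path() used by the tests."""

    def __init__(self, root: Path) -> None:
        self.root = root

    def get_path(self, name: str) -> Path:
        return self.root / name


storage = StubStorage(Path("/tmp/agent-workspace"))
# The removed helper collapses to a direct call at each call site:
assert str(storage.get_path("file_to_read.txt")) == "/tmp/agent-workspace/file_to_read.txt"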