mirror of
https://github.com/aljazceru/Auto-GPT.git
synced 2025-12-27 19:04:25 +01:00
Agent Skeleton (#12)
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -163,6 +163,5 @@ CURRENT_BULLETIN.md
|
||||
|
||||
# agbenchmark
|
||||
|
||||
agbenchmark/reports
|
||||
agbenchmark/workspace
|
||||
agbenchmark
|
||||
agent.db
|
||||
|
||||
22
Makefile
Normal file
22
Makefile
Normal file
@@ -0,0 +1,22 @@
|
||||
.PHONY: update-protocol
|
||||
|
||||
update-protocol:
|
||||
@if [ -d "../agent-protocol/sdk/python/agent_protocol" ]; then \
|
||||
cp -r ../agent-protocol/sdk/python/agent_protocol autogpt; \
|
||||
rm -Rf autogpt/agent_protocol/utils; \
|
||||
rm -Rf autogpt/agent_protocol/cli.py; \
|
||||
echo "Protocol updated successfully!"; \
|
||||
else \
|
||||
echo "Error: Source directory ../agent-protocol/sdk/python/agent_protocol does not exist."; \
|
||||
exit 1; \
|
||||
fi
|
||||
|
||||
change-protocol:
|
||||
@if [ -d "autogpt/agent_protocol" ]; then \
|
||||
cp -r autogpt/agent_protocol ../agent-protocol/sdk/python; \
|
||||
rm ../agent-protocol/sdk/python/agent_protocol/README.md; \
|
||||
echo "Protocol reversed successfully!"; \
|
||||
else \
|
||||
echo "Error: Target directory autogpt/agent_protocol does not exist."; \
|
||||
exit 1; \
|
||||
fi
|
||||
@@ -1,11 +1,11 @@
|
||||
import os
|
||||
|
||||
from agent_protocol import Agent
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import autogpt.agent
|
||||
import autogpt.db
|
||||
from autogpt.benchmark_integration import add_benchmark_routes
|
||||
from autogpt.workspace import LocalWorkspace
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""Runs the agent server"""
|
||||
@@ -13,13 +13,11 @@ if __name__ == "__main__":
|
||||
router = add_benchmark_routes()
|
||||
|
||||
database_name = os.getenv("DATABASE_STRING")
|
||||
workspace = LocalWorkspace(os.getenv("AGENT_WORKSPACE"))
|
||||
print(database_name)
|
||||
port = os.getenv("PORT")
|
||||
workspace = os.getenv("AGENT_WORKSPACE")
|
||||
auto_gpt = autogpt.agent.AutoGPT()
|
||||
|
||||
database = autogpt.db.AgentDB(database_name)
|
||||
agent = Agent.setup_agent(auto_gpt.task_handler, auto_gpt.step_handler)
|
||||
agent.db = database
|
||||
agent.workspace = workspace
|
||||
agent = autogpt.agent.AutoGPT(db=database, workspace=workspace)
|
||||
|
||||
agent.start(port=port, router=router)
|
||||
|
||||
@@ -1,32 +1,64 @@
|
||||
import time
|
||||
|
||||
from agent_protocol import Agent, Step, Task
|
||||
import os
|
||||
|
||||
import autogpt.utils
|
||||
from autogpt.agent_protocol import Agent, Artifact, Step, Task, TaskDB
|
||||
|
||||
from .workspace import Workspace
|
||||
|
||||
|
||||
class AutoGPT:
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
class AutoGPT(Agent):
|
||||
def __init__(self, db: TaskDB, workspace: Workspace) -> None:
|
||||
super().__init__(db)
|
||||
self.workspace = workspace
|
||||
|
||||
async def task_handler(self, task: Task) -> None:
|
||||
async def create_task(self, task: Task) -> None:
|
||||
print(f"task: {task.input}")
|
||||
await Agent.db.create_step(task.task_id, task.input, is_last=True)
|
||||
time.sleep(2)
|
||||
|
||||
# autogpt.utils.run(task.input) the task_handler only creates the task, it doesn't execute it
|
||||
# print(f"Created Task id: {task.task_id}")
|
||||
return task
|
||||
|
||||
async def step_handler(self, step: Step) -> Step:
|
||||
# print(f"step: {step}")
|
||||
agent_step = await Agent.db.get_step(step.task_id, step.step_id)
|
||||
updated_step: Step = await Agent.db.update_step(
|
||||
agent_step.task_id, agent_step.step_id, status="completed"
|
||||
)
|
||||
updated_step.output = agent_step.input
|
||||
if step.is_last:
|
||||
print(f"Task completed: {updated_step.task_id}")
|
||||
async def run_step(self, step: Step) -> Step:
|
||||
artifacts = autogpt.utils.run(step.input)
|
||||
for artifact in artifacts:
|
||||
art = await self.db.create_artifact(
|
||||
task_id=step.task_id,
|
||||
file_name=artifact["file_name"],
|
||||
uri=artifact["uri"],
|
||||
agent_created=True,
|
||||
step_id=step.step_id,
|
||||
)
|
||||
assert isinstance(
|
||||
art, Artifact
|
||||
), f"Artifact not isntance of Artifact {type(art)}"
|
||||
step.artifacts.append(art)
|
||||
step.status = "completed"
|
||||
return step
|
||||
|
||||
async def retrieve_artifact(self, task_id: str, artifact: Artifact) -> bytes:
|
||||
"""
|
||||
Retrieve the artifact data from wherever it is stored and return it as bytes.
|
||||
"""
|
||||
if not artifact.uri.startswith("file://"):
|
||||
raise NotImplementedError("Loading from uri not implemented")
|
||||
file_path = artifact.uri.split("file://")[1]
|
||||
if not self.workspace.exists(file_path):
|
||||
raise FileNotFoundError(f"File {file_path} not found in workspace")
|
||||
return self.workspace.read(file_path)
|
||||
|
||||
async def save_artifact(
|
||||
self, task_id: str, artifact: Artifact, data: bytes
|
||||
) -> Artifact:
|
||||
"""
|
||||
Save the artifact data to the agent's workspace, loading from uri if bytes are not available.
|
||||
"""
|
||||
assert (
|
||||
data is not None and artifact.uri is not None
|
||||
), "Data or Artifact uri must be set"
|
||||
|
||||
if data is not None:
|
||||
file_path = os.path.join(task_id / artifact.file_name)
|
||||
self.write(file_path, data)
|
||||
artifact.uri = f"file://{file_path}"
|
||||
self.db.save_artifact(task_id, artifact)
|
||||
else:
|
||||
print(f"Step completed: {updated_step}")
|
||||
return updated_step
|
||||
raise NotImplementedError("Loading from uri not implemented")
|
||||
|
||||
return artifact
|
||||
|
||||
21
autogpt/agent_protocol/README.md
Normal file
21
autogpt/agent_protocol/README.md
Normal file
@@ -0,0 +1,21 @@
|
||||
# Autogpt Protocol Directory
|
||||
|
||||
# DO NOT MODIFY ANY FILES IN THIS DIRECTORY
|
||||
|
||||
This directory contains protocol definitions crucial for our project. The current setup is a temporary measure to allow for speedy updating of the protocol.
|
||||
|
||||
## Background
|
||||
|
||||
In an ideal scenario, we'd directly use a submodule pointing to the original repository. However, given our specific needs and to expedite our development process, we've chosen a slightly different approach.
|
||||
|
||||
## Process
|
||||
|
||||
1. **Fork and Clone**: We started by forking the original repository `e2b-dev/agent-protocol` (not `Swiftyos/agent-protocol` as previously mentioned) to have our own version. This allows us to have more control over updates and possibly any specific changes that our project might need in the future.
|
||||
|
||||
2. **Manual Content Integration**: Instead of adding the entire forked repository as a submodule, we've manually copied over the contents of `sdk/python/agent_protocol` into this directory. This ensures we only have the parts we need, without any additional overhead.
|
||||
|
||||
3. **Updates**: Any necessary updates to the protocol can be made directly in our fork, and subsequently, the required changes can be reflected in this directory.
|
||||
|
||||
## Credits
|
||||
|
||||
All credit for the original protocol definitions goes to [e2b-dev/agent-protocol](https://github.com/e2b-dev/agent-protocol). We deeply appreciate their efforts in building the protocol, and this temporary measure is in no way intended to diminish the significance of their work. It's purely a practical approach for our specific requirements at this point in our development phase.
|
||||
16
autogpt/agent_protocol/__init__.py
Normal file
16
autogpt/agent_protocol/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from .agent import Agent
|
||||
from .agent import base_router as router
|
||||
from .db import Step, Task, TaskDB
|
||||
from .models import Artifact, Status, StepRequestBody, TaskRequestBody
|
||||
|
||||
__all__ = [
|
||||
"Agent",
|
||||
"Artifact",
|
||||
"Status",
|
||||
"Step",
|
||||
"StepRequestBody",
|
||||
"Task",
|
||||
"TaskDB",
|
||||
"TaskRequestBody",
|
||||
"router",
|
||||
]
|
||||
233
autogpt/agent_protocol/agent.py
Normal file
233
autogpt/agent_protocol/agent.py
Normal file
@@ -0,0 +1,233 @@
|
||||
import asyncio
|
||||
from typing import List
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import APIRouter, Request, Response, UploadFile
|
||||
from fastapi.responses import FileResponse, JSONResponse, Response
|
||||
from hypercorn.asyncio import serve
|
||||
from hypercorn.config import Config
|
||||
|
||||
from .db import Step, Task, TaskDB
|
||||
from .middlewares import AgentMiddleware
|
||||
from .models import Artifact, Status, StepRequestBody, TaskRequestBody
|
||||
from .server import app
|
||||
|
||||
base_router = APIRouter()
|
||||
|
||||
|
||||
@base_router.get("/heartbeat")
|
||||
async def heartbeat() -> Response:
|
||||
"""
|
||||
Heartbeat endpoint to check if the server is running.
|
||||
"""
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@base_router.post("/agent/tasks", response_model=Task, tags=["agent"])
|
||||
async def create_agent_task(
|
||||
request: Request, body: TaskRequestBody | None = None
|
||||
) -> Task:
|
||||
"""
|
||||
Creates a task for the agent.
|
||||
"""
|
||||
agent: Agent = request["agent"]
|
||||
|
||||
task = await agent.db.create_task(
|
||||
input=body.input if body else None,
|
||||
additional_input=body.additional_input if body else None,
|
||||
)
|
||||
print(task)
|
||||
await agent.create_task(task)
|
||||
|
||||
return task
|
||||
|
||||
|
||||
@base_router.get("/agent/tasks", response_model=List[str], tags=["agent"])
|
||||
async def list_agent_tasks_ids(request: Request) -> List[str]:
|
||||
"""
|
||||
List all tasks that have been created for the agent.
|
||||
"""
|
||||
agent: Agent = request["agent"]
|
||||
return [task.task_id for task in await agent.db.list_tasks()]
|
||||
|
||||
|
||||
@base_router.get("/agent/tasks/{task_id}", response_model=Task, tags=["agent"])
|
||||
async def get_agent_task(request: Request, task_id: str) -> Task:
|
||||
"""
|
||||
Get details about a specified agent task.
|
||||
"""
|
||||
agent: Agent = request["agent"]
|
||||
return await agent.db.get_task(task_id)
|
||||
|
||||
|
||||
@base_router.get(
|
||||
"/agent/tasks/{task_id}/steps",
|
||||
response_model=List[str],
|
||||
tags=["agent"],
|
||||
)
|
||||
async def list_agent_task_steps(request: Request, task_id: str) -> List[str]:
|
||||
"""
|
||||
List all steps for the specified task.
|
||||
"""
|
||||
agent: Agent = request["agent"]
|
||||
task = await agent.db.get_task(task_id)
|
||||
return [s.step_id for s in task.steps]
|
||||
|
||||
|
||||
@base_router.post(
|
||||
"/agent/tasks/{task_id}/steps",
|
||||
response_model=Step,
|
||||
tags=["agent"],
|
||||
)
|
||||
async def execute_agent_task_step(
|
||||
request: Request,
|
||||
task_id: str,
|
||||
body: StepRequestBody | None = None,
|
||||
) -> Step:
|
||||
"""
|
||||
Execute a step in the specified agent task.
|
||||
"""
|
||||
agent: Agent = request["agent"]
|
||||
|
||||
if body.input != "y":
|
||||
step = await agent.db.create_step(
|
||||
task_id=task_id,
|
||||
input=body.input if body else None,
|
||||
additional_properties=body.additional_input if body else None,
|
||||
)
|
||||
step = await agent.run_step(step)
|
||||
step.output = "Task completed"
|
||||
step.is_last = True
|
||||
else:
|
||||
steps = await agent.db.list_steps(task_id)
|
||||
artifacts = await agent.db.list_artifacts(task_id)
|
||||
step = steps[-1]
|
||||
step.artifacts = artifacts
|
||||
step.output = "No more steps to run."
|
||||
step.is_last = True
|
||||
if isinstance(step.status, Status):
|
||||
step.status = step.status.value
|
||||
return JSONResponse(content=step.dict(), status_code=200)
|
||||
|
||||
|
||||
@base_router.get(
|
||||
"/agent/tasks/{task_id}/steps/{step_id}",
|
||||
response_model=Step,
|
||||
tags=["agent"],
|
||||
)
|
||||
async def get_agent_task_step(request: Request, task_id: str, step_id: str) -> Step:
|
||||
"""
|
||||
Get details about a specified task step.
|
||||
"""
|
||||
agent: Agent = request["agent"]
|
||||
return await agent.db.get_step(task_id, step_id)
|
||||
|
||||
|
||||
@base_router.get(
|
||||
"/agent/tasks/{task_id}/artifacts",
|
||||
response_model=List[Artifact],
|
||||
tags=["agent"],
|
||||
)
|
||||
async def list_agent_task_artifacts(request: Request, task_id: str) -> List[Artifact]:
|
||||
"""
|
||||
List all artifacts for the specified task.
|
||||
"""
|
||||
agent: Agent = request["agent"]
|
||||
task = await agent.db.get_task(task_id)
|
||||
return task.artifacts
|
||||
|
||||
|
||||
@base_router.post(
|
||||
"/agent/tasks/{task_id}/artifacts",
|
||||
response_model=Artifact,
|
||||
tags=["agent"],
|
||||
)
|
||||
async def upload_agent_task_artifacts(
|
||||
request: Request,
|
||||
task_id: str,
|
||||
file: UploadFile | None = None,
|
||||
uri: str | None = None,
|
||||
) -> Artifact:
|
||||
"""
|
||||
Upload an artifact for the specified task.
|
||||
"""
|
||||
agent: Agent = request["agent"]
|
||||
if not file and not uri:
|
||||
return Response(status_code=400, content="No file or uri provided")
|
||||
data = None
|
||||
if not uri:
|
||||
file_name = file.filename or str(uuid4())
|
||||
try:
|
||||
data = b""
|
||||
while contents := file.file.read(1024 * 1024):
|
||||
data += contents
|
||||
except Exception as e:
|
||||
return Response(status_code=500, content=str(e))
|
||||
|
||||
artifact = await agent.save_artifact(task_id, artifact, data)
|
||||
agent.db.add_artifact(task_id, artifact)
|
||||
|
||||
return artifact
|
||||
|
||||
|
||||
@base_router.get(
|
||||
"/agent/tasks/{task_id}/artifacts/{artifact_id}",
|
||||
tags=["agent"],
|
||||
)
|
||||
async def download_agent_task_artifacts(
|
||||
request: Request, task_id: str, artifact_id: str
|
||||
) -> FileResponse:
|
||||
"""
|
||||
Download the specified artifact.
|
||||
"""
|
||||
agent: Agent = request["agent"]
|
||||
artifact = await agent.db.get_artifact(task_id, artifact_id)
|
||||
retrieved_artifact: Artifact = await agent.retrieve_artifact(task_id, artifact)
|
||||
path = artifact.file_name
|
||||
with open(path, "wb") as f:
|
||||
f.write(retrieved_artifact)
|
||||
return FileResponse(
|
||||
# Note: mimetype is guessed in the FileResponse constructor
|
||||
path=path,
|
||||
filename=artifact.file_name,
|
||||
)
|
||||
|
||||
|
||||
class Agent:
|
||||
def __init__(self, db: TaskDB):
|
||||
self.name = self.__class__.__name__
|
||||
self.db = db
|
||||
|
||||
def start(self, port: int = 8000, router: APIRouter = base_router):
|
||||
"""
|
||||
Start the agent server.
|
||||
"""
|
||||
config = Config()
|
||||
config.bind = [f"localhost:{port}"] # As an example configuration setting
|
||||
app.title = f"{self.name} - Agent Communication Protocol"
|
||||
app.include_router(router)
|
||||
app.add_middleware(AgentMiddleware, agent=self)
|
||||
asyncio.run(serve(app, config))
|
||||
|
||||
async def create_task(self, task: Task):
|
||||
"""
|
||||
Handles a new task
|
||||
"""
|
||||
return task
|
||||
|
||||
async def run_step(self, step: Step) -> Step:
|
||||
return step
|
||||
|
||||
async def retrieve_artifact(self, task_id: str, artifact: Artifact) -> bytes:
|
||||
"""
|
||||
Retrieve the artifact data from wherever it is stored and return it as bytes.
|
||||
"""
|
||||
raise NotImplementedError("")
|
||||
|
||||
async def save_artifact(
|
||||
self, task_id: str, artifact: Artifact, data: bytes | None = None
|
||||
) -> Artifact:
|
||||
"""
|
||||
Save the artifact data to the agent's workspace, loading from uri if bytes are not available.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
173
autogpt/agent_protocol/db.py
Normal file
173
autogpt/agent_protocol/db.py
Normal file
@@ -0,0 +1,173 @@
|
||||
import uuid
|
||||
from abc import ABC
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from .models import Artifact, Status
|
||||
from .models import Step as APIStep
|
||||
from .models import Task as APITask
|
||||
|
||||
|
||||
class Step(APIStep):
|
||||
additional_properties: Optional[Dict[str, str]] = None
|
||||
|
||||
|
||||
class Task(APITask):
|
||||
steps: Optional[List[Step]] = []
|
||||
|
||||
|
||||
class NotFoundException(Exception):
|
||||
"""
|
||||
Exception raised when a resource is not found.
|
||||
"""
|
||||
|
||||
def __init__(self, item_name: str, item_id: str):
|
||||
self.item_name = item_name
|
||||
self.item_id = item_id
|
||||
super().__init__(f"{item_name} with {item_id} not found.")
|
||||
|
||||
|
||||
class TaskDB(ABC):
|
||||
async def create_task(
|
||||
self, input: Optional[str], additional_input: Any = None
|
||||
) -> Task:
|
||||
raise NotImplementedError
|
||||
|
||||
async def create_step(
|
||||
self,
|
||||
task_id: str,
|
||||
name: Optional[str] = None,
|
||||
input: Optional[str] = None,
|
||||
is_last: bool = False,
|
||||
additional_properties: Optional[Dict[str, str]] = None,
|
||||
) -> Step:
|
||||
raise NotImplementedError
|
||||
|
||||
async def create_artifact(
|
||||
self,
|
||||
task_id: str,
|
||||
artifact: Artifact,
|
||||
step_id: str | None = None,
|
||||
) -> Artifact:
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_task(self, task_id: str) -> Task:
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_step(self, task_id: str, step_id: str) -> Step:
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact:
|
||||
raise NotImplementedError
|
||||
|
||||
async def list_tasks(self) -> List[Task]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def list_steps(
|
||||
self, task_id: str, status: Optional[Status] = None
|
||||
) -> List[Step]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def update_step(
|
||||
self,
|
||||
task_id: str,
|
||||
step_id: str,
|
||||
status: str,
|
||||
additional_properties: Optional[Dict[str, str]] = None,
|
||||
) -> Step:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class InMemoryTaskDB(TaskDB):
|
||||
_tasks: Dict[str, Task] = {}
|
||||
|
||||
async def create_task(
|
||||
self,
|
||||
input: Optional[str],
|
||||
additional_input: Any = None,
|
||||
) -> Task:
|
||||
task_id = str(uuid.uuid4())
|
||||
task = Task(
|
||||
task_id=task_id,
|
||||
input=input,
|
||||
steps=[],
|
||||
artifacts=[],
|
||||
additional_input=additional_input,
|
||||
)
|
||||
self._tasks[task_id] = task
|
||||
return task
|
||||
|
||||
async def create_step(
|
||||
self,
|
||||
task_id: str,
|
||||
name: Optional[str] = None,
|
||||
input: Optional[str] = None,
|
||||
is_last=False,
|
||||
additional_properties: Optional[Dict[str, Any]] = None,
|
||||
) -> Step:
|
||||
step_id = str(uuid.uuid4())
|
||||
step = Step(
|
||||
task_id=task_id,
|
||||
step_id=step_id,
|
||||
name=name,
|
||||
input=input,
|
||||
status=Status.created,
|
||||
is_last=is_last,
|
||||
additional_properties=additional_properties,
|
||||
)
|
||||
task = await self.get_task(task_id)
|
||||
task.steps.append(step)
|
||||
return step
|
||||
|
||||
async def get_task(self, task_id: str) -> Task:
|
||||
task = self._tasks.get(task_id, None)
|
||||
if not task:
|
||||
raise NotFoundException("Task", task_id)
|
||||
return task
|
||||
|
||||
async def get_step(self, task_id: str, step_id: str) -> Step:
|
||||
task = await self.get_task(task_id)
|
||||
step = next(filter(lambda s: s.task_id == task_id, task.steps), None)
|
||||
if not step:
|
||||
raise NotFoundException("Step", step_id)
|
||||
return step
|
||||
|
||||
async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact:
|
||||
task = await self.get_task(task_id)
|
||||
artifact = next(
|
||||
filter(lambda a: a.artifact_id == artifact_id, task.artifacts), None
|
||||
)
|
||||
if not artifact:
|
||||
raise NotFoundException("Artifact", artifact_id)
|
||||
return artifact
|
||||
|
||||
async def create_artifact(
|
||||
self,
|
||||
task_id: str,
|
||||
file_name: str,
|
||||
relative_path: Optional[str] = None,
|
||||
step_id: Optional[str] = None,
|
||||
) -> Artifact:
|
||||
artifact_id = str(uuid.uuid4())
|
||||
artifact = Artifact(
|
||||
artifact_id=artifact_id, file_name=file_name, relative_path=relative_path
|
||||
)
|
||||
task = await self.get_task(task_id)
|
||||
task.artifacts.append(artifact)
|
||||
|
||||
if step_id:
|
||||
step = await self.get_step(task_id, step_id)
|
||||
step.artifacts.append(artifact)
|
||||
|
||||
return artifact
|
||||
|
||||
async def list_tasks(self) -> List[Task]:
|
||||
return [task for task in self._tasks.values()]
|
||||
|
||||
async def list_steps(
|
||||
self, task_id: str, status: Optional[Status] = None
|
||||
) -> List[Step]:
|
||||
task = await self.get_task(task_id)
|
||||
steps = task.steps
|
||||
if status:
|
||||
steps = list(filter(lambda s: s.status == status, steps))
|
||||
return steps
|
||||
46
autogpt/agent_protocol/middlewares.py
Normal file
46
autogpt/agent_protocol/middlewares.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import PlainTextResponse
|
||||
|
||||
from .db import NotFoundException
|
||||
|
||||
|
||||
async def not_found_exception_handler(
|
||||
request: Request, exc: NotFoundException
|
||||
) -> PlainTextResponse:
|
||||
return PlainTextResponse(
|
||||
str(exc),
|
||||
status_code=404,
|
||||
)
|
||||
|
||||
|
||||
class AgentMiddleware:
|
||||
"""
|
||||
Middleware that injects the agent instance into the request scope.
|
||||
"""
|
||||
|
||||
def __init__(self, app: FastAPI, agent: "Agent"):
|
||||
"""
|
||||
|
||||
Args:
|
||||
app: The FastAPI app - automatically injected by FastAPI.
|
||||
agent: The agent instance to inject into the request scope.
|
||||
|
||||
Examples:
|
||||
>>> from fastapi import FastAPI, Request
|
||||
>>> from agent_protocol.agent import Agent
|
||||
>>> from agent_protocol.middlewares import AgentMiddleware
|
||||
>>> app = FastAPI()
|
||||
>>> @app.get("/")
|
||||
>>> async def root(request: Request):
|
||||
>>> agent = request["agent"]
|
||||
>>> task = agent.db.create_task("Do something.")
|
||||
>>> return {"task_id": a.task_id}
|
||||
>>> agent = Agent()
|
||||
>>> app.add_middleware(AgentMiddleware, agent=agent)
|
||||
"""
|
||||
self.app = app
|
||||
self.agent = agent
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
scope["agent"] = self.agent
|
||||
await self.app(scope, receive, send)
|
||||
131
autogpt/agent_protocol/models.py
Normal file
131
autogpt/agent_protocol/models.py
Normal file
@@ -0,0 +1,131 @@
|
||||
# generated by fastapi-codegen:
|
||||
# filename: ../../openapi.yml
|
||||
# timestamp: 2023-08-17T11:26:07+00:00
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class TaskInput(BaseModel):
|
||||
__root__: Any = Field(
|
||||
...,
|
||||
description="Input parameters for the task. Any value is allowed.",
|
||||
example='{\n"debug": false,\n"mode": "benchmarks"\n}',
|
||||
)
|
||||
|
||||
|
||||
class Artifact(BaseModel):
|
||||
artifact_id: str = Field(
|
||||
...,
|
||||
description="ID of the artifact.",
|
||||
example="b225e278-8b4c-4f99-a696-8facf19f0e56",
|
||||
)
|
||||
file_name: str = Field(
|
||||
..., description="Filename of the artifact.", example="main.py"
|
||||
)
|
||||
agent_created: Optional[bool] = Field(
|
||||
None,
|
||||
description="Whether the artifact has been created by the agent.",
|
||||
example=False,
|
||||
)
|
||||
uri: Optional[str] = Field(
|
||||
None,
|
||||
description="URI of the artifact.",
|
||||
example="file://home/bob/workspace/bucket/main.py",
|
||||
)
|
||||
|
||||
|
||||
class ArtifactUpload(BaseModel):
|
||||
file: bytes = Field(..., description="File to upload.")
|
||||
relative_path: Optional[str] = Field(
|
||||
None,
|
||||
description="Relative path of the artifact in the agent's workspace.",
|
||||
example="python/code",
|
||||
)
|
||||
|
||||
|
||||
class StepInput(BaseModel):
|
||||
__root__: Any = Field(
|
||||
...,
|
||||
description="Input parameters for the task step. Any value is allowed.",
|
||||
example='{\n"file_to_refactor": "models.py"\n}',
|
||||
)
|
||||
|
||||
|
||||
class StepOutput(BaseModel):
|
||||
__root__: Any = Field(
|
||||
...,
|
||||
description="Output that the task step has produced. Any value is allowed.",
|
||||
example='{\n"tokens": 7894,\n"estimated_cost": "0,24$"\n}',
|
||||
)
|
||||
|
||||
|
||||
class TaskRequestBody(BaseModel):
|
||||
input: Optional[str] = Field(
|
||||
None,
|
||||
description="Input prompt for the task.",
|
||||
example="Write the words you receive to the file 'output.txt'.",
|
||||
)
|
||||
additional_input: Optional[TaskInput] = None
|
||||
|
||||
|
||||
class Task(TaskRequestBody):
|
||||
task_id: str = Field(
|
||||
...,
|
||||
description="The ID of the task.",
|
||||
example="50da533e-3904-4401-8a07-c49adf88b5eb",
|
||||
)
|
||||
artifacts: Optional[List[Artifact]] = Field(
|
||||
[],
|
||||
description="A list of artifacts that the task has produced.",
|
||||
example=[
|
||||
"7a49f31c-f9c6-4346-a22c-e32bc5af4d8e",
|
||||
"ab7b4091-2560-4692-a4fe-d831ea3ca7d6",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class StepRequestBody(BaseModel):
|
||||
input: Optional[str] = Field(
|
||||
None, description="Input prompt for the step.", example="Washington"
|
||||
)
|
||||
additional_input: Optional[StepInput] = None
|
||||
|
||||
|
||||
class Status(Enum):
|
||||
created = "created"
|
||||
running = "running"
|
||||
completed = "completed"
|
||||
|
||||
|
||||
class Step(StepRequestBody):
|
||||
task_id: str = Field(
|
||||
...,
|
||||
description="The ID of the task this step belongs to.",
|
||||
example="50da533e-3904-4401-8a07-c49adf88b5eb",
|
||||
)
|
||||
step_id: str = Field(
|
||||
...,
|
||||
description="The ID of the task step.",
|
||||
example="6bb1801a-fd80-45e8-899a-4dd723cc602e",
|
||||
)
|
||||
name: Optional[str] = Field(
|
||||
None, description="The name of the task step.", example="Write to file"
|
||||
)
|
||||
status: Status = Field(..., description="The status of the task step.")
|
||||
output: Optional[str] = Field(
|
||||
None,
|
||||
description="Output of the task step.",
|
||||
example="I am going to use the write_to_file command and write Washington to a file called output.txt <write_to_file('output.txt', 'Washington')",
|
||||
)
|
||||
additional_output: Optional[StepOutput] = None
|
||||
artifacts: Optional[List[Artifact]] = Field(
|
||||
[], description="A list of artifacts that the step has produced."
|
||||
)
|
||||
is_last: Optional[bool] = Field(
|
||||
False, description="Whether this is the last step in the task."
|
||||
)
|
||||
18
autogpt/agent_protocol/server.py
Normal file
18
autogpt/agent_protocol/server.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# generated by fastapi-codegen:
|
||||
# filename: openapi.yml
|
||||
# timestamp: 2023-08-07T12:14:43+00:00
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
from .db import NotFoundException
|
||||
from .middlewares import not_found_exception_handler
|
||||
|
||||
app = FastAPI(
|
||||
title="Agent Communication Protocol",
|
||||
description="Specification of the API protocol for communication with an agent.",
|
||||
version="v0.3",
|
||||
)
|
||||
|
||||
app.add_exception_handler(NotFoundException, not_found_exception_handler)
|
||||
@@ -1,13 +1,12 @@
|
||||
from pathlib import Path
|
||||
|
||||
from agbenchmark.app import get_artifact, get_skill_tree
|
||||
from agent_protocol.agent import base_router
|
||||
from fastapi import APIRouter
|
||||
from fastapi import (
|
||||
HTTPException as FastAPIHTTPException, # Import HTTPException from FastAPI
|
||||
)
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from autogpt.agent_protocol.agent import base_router
|
||||
|
||||
|
||||
def add_benchmark_routes():
|
||||
new_router = APIRouter()
|
||||
|
||||
183
autogpt/db.py
183
autogpt/db.py
@@ -6,18 +6,11 @@ IT IS NOT ADVISED TO USE THIS IN PRODUCTION!
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from agent_protocol import Artifact, Step, Task, TaskDB
|
||||
from agent_protocol.models import Status, TaskInput
|
||||
from sqlalchemy import (
|
||||
Boolean,
|
||||
Column,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
LargeBinary,
|
||||
String,
|
||||
create_engine,
|
||||
)
|
||||
from sqlalchemy.orm import DeclarativeBase, relationship, sessionmaker
|
||||
from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, create_engine
|
||||
from sqlalchemy.orm import DeclarativeBase, joinedload, relationship, sessionmaker
|
||||
|
||||
from autogpt.agent_protocol import Artifact, Step, Task, TaskDB
|
||||
from autogpt.agent_protocol.models import Status, TaskInput
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
@@ -45,6 +38,7 @@ class StepModel(Base):
|
||||
step_id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
task_id = Column(Integer, ForeignKey("tasks.task_id"))
|
||||
name = Column(String)
|
||||
input = Column(String)
|
||||
status = Column(String)
|
||||
is_last = Column(Boolean, default=False)
|
||||
additional_properties = Column(String)
|
||||
@@ -58,13 +52,61 @@ class ArtifactModel(Base):
|
||||
artifact_id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
task_id = Column(Integer, ForeignKey("tasks.task_id"))
|
||||
step_id = Column(Integer, ForeignKey("steps.step_id"))
|
||||
agent_created = Column(Boolean, default=False)
|
||||
file_name = Column(String)
|
||||
relative_path = Column(String)
|
||||
file_data = Column(LargeBinary)
|
||||
uri = Column(String)
|
||||
|
||||
task = relationship("TaskModel", back_populates="artifacts")
|
||||
|
||||
|
||||
def convert_to_task(task_obj: TaskModel) -> Task:
|
||||
steps_list = []
|
||||
for step in task_obj.steps:
|
||||
status = Status.completed if step.status == "completed" else Status.created
|
||||
steps_list.append(
|
||||
Step(
|
||||
task_id=step.task_id,
|
||||
step_id=step.step_id,
|
||||
name=step.name,
|
||||
status=status,
|
||||
is_last=step.is_last == 1,
|
||||
additional_properties=step.additional_properties,
|
||||
)
|
||||
)
|
||||
return Task(
|
||||
task_id=task_obj.task_id,
|
||||
input=task_obj.input,
|
||||
additional_input=task_obj.additional_input,
|
||||
artifacts=[],
|
||||
steps=steps_list,
|
||||
)
|
||||
|
||||
|
||||
def convert_to_step(step_model: StepModel) -> Step:
|
||||
print(step_model)
|
||||
step_artifacts = [
|
||||
Artifact(
|
||||
artifact_id=artifact.artifact_id,
|
||||
file_name=artifact.file_name,
|
||||
agent_created=artifact.agent_created,
|
||||
uri=artifact.uri,
|
||||
)
|
||||
for artifact in step_model.task.artifacts
|
||||
if artifact.step_id == step_model.step_id
|
||||
]
|
||||
status = Status.completed if step_model.status == "completed" else Status.created
|
||||
return Step(
|
||||
task_id=step_model.task_id,
|
||||
step_id=step_model.step_id,
|
||||
name=step_model.name,
|
||||
input=step_model.input,
|
||||
status=status,
|
||||
artifacts=step_artifacts,
|
||||
is_last=step_model.is_last == 1,
|
||||
additional_properties=step_model.additional_properties,
|
||||
)
|
||||
|
||||
|
||||
# sqlite:///{database_name}
|
||||
class AgentDB(TaskDB):
|
||||
def __init__(self, database_string) -> None:
|
||||
@@ -75,26 +117,23 @@ class AgentDB(TaskDB):
|
||||
print("Databases Created")
|
||||
|
||||
async def create_task(
|
||||
self,
|
||||
input: Optional[str],
|
||||
additional_input: Optional[TaskInput] = None,
|
||||
artifacts: List[Artifact] = None,
|
||||
steps: List[Step] = None,
|
||||
self, input: Optional[str], additional_input: Optional[TaskInput] = None
|
||||
) -> Task:
|
||||
session = self.Session()
|
||||
new_task = TaskModel(
|
||||
input=input,
|
||||
additional_input=additional_input.json() if additional_input else None,
|
||||
additional_input=additional_input.__root__ if additional_input else None,
|
||||
)
|
||||
session.add(new_task)
|
||||
session.commit()
|
||||
session.refresh(new_task)
|
||||
return await self.get_task(new_task.task_id)
|
||||
return convert_to_task(new_task)
|
||||
|
||||
async def create_step(
|
||||
self,
|
||||
task_id: str,
|
||||
name: Optional[str] = None,
|
||||
input: Optional[str] = None,
|
||||
is_last: bool = False,
|
||||
additional_properties: Optional[Dict[str, str]] = None,
|
||||
) -> Step:
|
||||
@@ -102,6 +141,7 @@ class AgentDB(TaskDB):
|
||||
new_step = StepModel(
|
||||
task_id=task_id,
|
||||
name=name,
|
||||
input=input,
|
||||
status="created",
|
||||
is_last=is_last,
|
||||
additional_properties=additional_properties,
|
||||
@@ -109,23 +149,28 @@ class AgentDB(TaskDB):
|
||||
session.add(new_step)
|
||||
session.commit()
|
||||
session.refresh(new_step)
|
||||
return await self.get_step(task_id, new_step.step_id)
|
||||
return convert_to_step(new_step)
|
||||
|
||||
async def create_artifact(
|
||||
self,
|
||||
task_id: str,
|
||||
file_name: str,
|
||||
relative_path: Optional[str] = None,
|
||||
step_id: Optional[str] = None,
|
||||
file_data: bytes | None = None,
|
||||
uri: str,
|
||||
agent_created: bool = False,
|
||||
step_id: str | None = None,
|
||||
) -> Artifact:
|
||||
session = self.Session()
|
||||
|
||||
if existing_artifact := session.query(ArtifactModel).filter_by(uri=uri).first():
|
||||
session.close()
|
||||
return existing_artifact
|
||||
|
||||
new_artifact = ArtifactModel(
|
||||
task_id=task_id,
|
||||
step_id=step_id,
|
||||
agent_created=agent_created,
|
||||
file_name=file_name,
|
||||
relative_path=relative_path,
|
||||
file_data=file_data,
|
||||
uri=uri,
|
||||
)
|
||||
session.add(new_artifact)
|
||||
session.commit()
|
||||
@@ -135,50 +180,27 @@ class AgentDB(TaskDB):
|
||||
async def get_task(self, task_id: int) -> Task:
|
||||
"""Get a task by its id"""
|
||||
session = self.Session()
|
||||
task_obj = session.query(TaskModel).filter_by(task_id=task_id).first()
|
||||
task_obj = (
|
||||
session.query(TaskModel)
|
||||
.options(joinedload(TaskModel.steps))
|
||||
.filter_by(task_id=task_id)
|
||||
.first()
|
||||
)
|
||||
if task_obj:
|
||||
task = Task(
|
||||
task_id=task_obj.task_id,
|
||||
input=task_obj.input,
|
||||
additional_input=task_obj.additional_input,
|
||||
steps=[],
|
||||
)
|
||||
steps_obj = session.query(StepModel).filter_by(task_id=task_id).all()
|
||||
if steps_obj:
|
||||
for step in steps_obj:
|
||||
status = (
|
||||
Status.created if step.status == "created" else Status.completed
|
||||
)
|
||||
task.steps.append(
|
||||
Step(
|
||||
task_id=step.task_id,
|
||||
step_id=step.step_id,
|
||||
name=step.name,
|
||||
status=status,
|
||||
is_last=step.is_last == 1,
|
||||
additional_properties=step.additional_properties,
|
||||
)
|
||||
)
|
||||
return task
|
||||
return convert_to_task(task_obj)
|
||||
else:
|
||||
raise DataNotFoundError("Task not found")
|
||||
|
||||
async def get_step(self, task_id: int, step_id: int) -> Step:
|
||||
session = self.Session()
|
||||
if (
|
||||
step := session.query(StepModel)
|
||||
.filter_by(task_id=task_id, step_id=step_id)
|
||||
if step := (
|
||||
session.query(StepModel)
|
||||
.options(joinedload(StepModel.task).joinedload(TaskModel.artifacts))
|
||||
.filter(StepModel.step_id == step_id)
|
||||
.first()
|
||||
):
|
||||
status = Status.completed if step.status == "completed" else Status.created
|
||||
return Step(
|
||||
task_id=task_id,
|
||||
step_id=step_id,
|
||||
name=step.name,
|
||||
status=status,
|
||||
is_last=step.is_last == 1,
|
||||
additional_properties=step.additional_properties,
|
||||
)
|
||||
return convert_to_step(step)
|
||||
|
||||
else:
|
||||
raise DataNotFoundError("Step not found")
|
||||
|
||||
@@ -204,30 +226,20 @@ class AgentDB(TaskDB):
|
||||
|
||||
async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact:
|
||||
session = self.Session()
|
||||
if (
|
||||
artifact := session.query(ArtifactModel)
|
||||
if artifact_model := (
|
||||
session.query(ArtifactModel)
|
||||
.filter_by(task_id=task_id, artifact_id=artifact_id)
|
||||
.first()
|
||||
):
|
||||
return Artifact(
|
||||
artifact_id=artifact.artifact_id,
|
||||
file_name=artifact.file_name,
|
||||
relative_path=artifact.relative_path,
|
||||
artifact_id=str(artifact_model.artifact_id), # Casting to string
|
||||
file_name=artifact_model.file_name,
|
||||
agent_created=artifact_model.agent_created,
|
||||
uri=artifact_model.uri,
|
||||
)
|
||||
else:
|
||||
raise DataNotFoundError("Artifact not found")
|
||||
|
||||
async def get_artifact_file(self, task_id: str, artifact_id: str) -> bytes:
|
||||
session = self.Session()
|
||||
if (
|
||||
artifact := session.query(ArtifactModel.file_data)
|
||||
.filter_by(task_id=task_id, artifact_id=artifact_id)
|
||||
.first()
|
||||
):
|
||||
return artifact.file_data
|
||||
else:
|
||||
raise DataNotFoundError("Artifact not found")
|
||||
|
||||
async def list_tasks(self) -> List[Task]:
|
||||
session = self.Session()
|
||||
tasks = session.query(TaskModel).all()
|
||||
@@ -252,3 +264,16 @@ class AgentDB(TaskDB):
|
||||
)
|
||||
for step in steps
|
||||
]
|
||||
|
||||
async def list_artifacts(self, task_id: str) -> List[Artifact]:
|
||||
session = self.Session()
|
||||
artifacts = session.query(ArtifactModel).filter_by(task_id=task_id).all()
|
||||
return [
|
||||
Artifact(
|
||||
artifact_id=str(artifact.artifact_id),
|
||||
file_name=artifact.file_name,
|
||||
agent_created=artifact.agent_created,
|
||||
uri=artifact.uri,
|
||||
)
|
||||
for artifact in artifacts
|
||||
]
|
||||
|
||||
@@ -52,6 +52,17 @@ async def test_create_and_get_task():
|
||||
os.remove(db_name.split("///")[1])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_and_get_task_with_steps():
|
||||
db_name = "sqlite:///test_db.sqlite3"
|
||||
agent_db = AgentDB(db_name)
|
||||
await agent_db.create_task("task_input")
|
||||
task = await agent_db.get_task(1)
|
||||
|
||||
assert task.input == "task_input"
|
||||
os.remove(db_name.split("///")[1])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_task_not_found():
|
||||
db_name = "sqlite:///test_db.sqlite3"
|
||||
@@ -102,41 +113,23 @@ async def test_get_artifact():
|
||||
# Given: A task and its corresponding artifact
|
||||
task = await db.create_task("test_input")
|
||||
step = await db.create_step(task.task_id, "step_name")
|
||||
|
||||
# Create an artifact
|
||||
artifact = await db.create_artifact(
|
||||
task.task_id, "sample_file.txt", "/path/to/sample_file.txt", step.step_id
|
||||
task_id=task.task_id,
|
||||
file_name="sample_file.txt",
|
||||
uri="file:///path/to/sample_file.txt",
|
||||
agent_created=True,
|
||||
step_id=step.step_id,
|
||||
)
|
||||
|
||||
# When: The artifact is fetched by its ID
|
||||
fetched_artifact = await db.get_artifact(task.task_id, artifact.artifact_id)
|
||||
fetched_artifact = await db.get_artifact(int(task.task_id), artifact.artifact_id)
|
||||
|
||||
# Then: The fetched artifact matches the original
|
||||
assert fetched_artifact.artifact_id == artifact.artifact_id
|
||||
assert fetched_artifact.file_name == "sample_file.txt"
|
||||
assert fetched_artifact.relative_path == "/path/to/sample_file.txt"
|
||||
os.remove(db_name.split("///")[1])
|
||||
assert fetched_artifact.uri == "file:///path/to/sample_file.txt"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_artifact_file():
|
||||
db_name = "sqlite:///test_db.sqlite3"
|
||||
db = AgentDB(db_name)
|
||||
sample_data = b"sample data"
|
||||
# Given: A task and its corresponding artifact
|
||||
task = await db.create_task("test_input")
|
||||
step = await db.create_step(task.task_id, "step_name")
|
||||
artifact = await db.create_artifact(
|
||||
task.task_id,
|
||||
"sample_file.txt",
|
||||
"/path/to/sample_file.txt",
|
||||
step.step_id,
|
||||
sample_data,
|
||||
)
|
||||
|
||||
# When: The artifact is fetched by its ID
|
||||
fetched_artifact = await db.get_artifact_file(task.task_id, artifact.artifact_id)
|
||||
|
||||
# Then: The fetched artifact matches the original
|
||||
assert fetched_artifact == sample_data
|
||||
os.remove(db_name.split("///")[1])
|
||||
|
||||
|
||||
@@ -174,6 +167,6 @@ async def test_list_steps():
|
||||
|
||||
# Then: The fetched steps list includes the created steps
|
||||
step_ids = [step.step_id for step in fetched_steps]
|
||||
assert step1.step_id in step_ids
|
||||
assert step2.step_id in step_ids
|
||||
assert str(step1.step_id) in step_ids
|
||||
assert str(step2.step_id) in step_ids
|
||||
os.remove(db_name.split("///")[1])
|
||||
|
||||
0
autogpt/llm.py
Normal file
0
autogpt/llm.py
Normal file
@@ -5,6 +5,7 @@ PLEASE IGNORE
|
||||
-------------------------------------------------------------
|
||||
"""
|
||||
|
||||
import glob
|
||||
import os
|
||||
import typing
|
||||
from pathlib import Path
|
||||
@@ -41,11 +42,28 @@ def chat_completion_request(
|
||||
exit()
|
||||
|
||||
|
||||
def run(task: str) -> None:
|
||||
def run(task: str):
|
||||
"""Runs the agent for benchmarking"""
|
||||
print("Running agent")
|
||||
steps = plan(task)
|
||||
execute_plan(steps)
|
||||
# check for artifacts in workspace
|
||||
items = glob.glob(os.path.join(workspace, "*"))
|
||||
if items:
|
||||
artifacts = []
|
||||
print(f"Found {len(items)} artifacts in workspace")
|
||||
for item in items:
|
||||
with open(item, "r") as f:
|
||||
item_contents = f.read()
|
||||
path_within_workspace = os.path.relpath(item, workspace)
|
||||
artifacts.append(
|
||||
{
|
||||
"file_name": os.path.basename(item),
|
||||
"uri": f"file://{path_within_workspace}",
|
||||
"contents": item_contents,
|
||||
}
|
||||
)
|
||||
return artifacts
|
||||
|
||||
|
||||
def execute_plan(plan: typing.List[str]) -> None:
|
||||
|
||||
127
autogpt/workspace.py
Normal file
127
autogpt/workspace.py
Normal file
@@ -0,0 +1,127 @@
|
||||
import abc
|
||||
import os
|
||||
import typing
|
||||
from pathlib import Path
|
||||
|
||||
from google.cloud import storage
|
||||
|
||||
|
||||
class Workspace(abc.ABC):
|
||||
@abc.abstractclassmethod
|
||||
def __init__(self, base_path: str) -> None:
|
||||
self.base_path = base_path
|
||||
|
||||
@abc.abstractclassmethod
|
||||
def read(self, path: str) -> bytes:
|
||||
pass
|
||||
|
||||
@abc.abstractclassmethod
|
||||
def write(self, path: str, data: bytes) -> None:
|
||||
pass
|
||||
|
||||
@abc.abstractclassmethod
|
||||
def delete(
|
||||
self, path: str, directory: bool = False, recursive: bool = False
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@abc.abstractclassmethod
|
||||
def exists(self, path: str) -> bool:
|
||||
pass
|
||||
|
||||
@abc.abstractclassmethod
|
||||
def list(self, path: str) -> typing.List[str]:
|
||||
pass
|
||||
|
||||
|
||||
class LocalWorkspace(Workspace):
|
||||
def __init__(self, base_path: str):
|
||||
self.base_path = Path(base_path).resolve()
|
||||
|
||||
def _resolve_path(self, path: str) -> Path:
|
||||
abs_path = (self.base_path / path).resolve()
|
||||
if not str(abs_path).startswith(str(self.base_path)):
|
||||
raise ValueError("Directory traversal is not allowed!")
|
||||
return abs_path
|
||||
|
||||
def read(self, path: str) -> bytes:
|
||||
path = self.base_path / path
|
||||
with open(self._resolve_path(path), "rb") as f:
|
||||
return f.read()
|
||||
|
||||
def write(self, path: str, data: bytes) -> None:
|
||||
path = self.base_path / path
|
||||
with open(self._resolve_path(path), "wb") as f:
|
||||
f.write(data)
|
||||
|
||||
def delete(
|
||||
self, path: str, directory: bool = False, recursive: bool = False
|
||||
) -> None:
|
||||
path = self.base_path / path
|
||||
resolved_path = self._resolve_path(path)
|
||||
if directory:
|
||||
if recursive:
|
||||
os.rmdir(resolved_path)
|
||||
else:
|
||||
os.removedirs(resolved_path)
|
||||
else:
|
||||
os.remove(resolved_path)
|
||||
|
||||
def exists(self, path: str) -> bool:
|
||||
path = self.base_path / path
|
||||
return self._resolve_path(path).exists()
|
||||
|
||||
def list(self, path: str) -> typing.List[str]:
|
||||
path = self.base_path / path
|
||||
base = self._resolve_path(path)
|
||||
return [str(p.relative_to(self.base_path)) for p in base.iterdir()]
|
||||
|
||||
|
||||
class GCSWorkspace(Workspace):
|
||||
def __init__(self, base_path: str, bucket_name: str):
|
||||
self.client = storage.Client()
|
||||
self.bucket_name = bucket_name
|
||||
self.base_path = base_path.strip("/") # Ensure no trailing or leading slash
|
||||
|
||||
def _resolve_path(self, path: str) -> str:
|
||||
resolved = os.path.join(self.base_path, path).strip("/")
|
||||
if not resolved.startswith(self.base_path):
|
||||
raise ValueError("Directory traversal is not allowed!")
|
||||
return resolved
|
||||
|
||||
def read(self, path: str) -> bytes:
|
||||
path = self.base_path / path
|
||||
bucket = self.client.get_bucket(self.bucket_name)
|
||||
blob = bucket.get_blob(self._resolve_path(path))
|
||||
return blob.download_as_bytes()
|
||||
|
||||
def write(self, path: str, data: bytes) -> None:
|
||||
path = self.base_path / path
|
||||
bucket = self.client.get_bucket(self.bucket_name)
|
||||
blob = bucket.blob(self._resolve_path(path))
|
||||
blob.upload_from_string(data)
|
||||
|
||||
def delete(
|
||||
self, path: str, directory: bool = False, recursive: bool = False
|
||||
) -> None:
|
||||
path = self.base_path / path
|
||||
bucket = self.client.get_bucket(self.bucket_name)
|
||||
if directory and recursive:
|
||||
# Note: GCS doesn't really have directories, so this will just delete all blobs with the given prefix
|
||||
blobs = bucket.list_blobs(prefix=self._resolve_path(path))
|
||||
bucket.delete_blobs(blobs)
|
||||
else:
|
||||
blob = bucket.blob(self._resolve_path(path))
|
||||
blob.delete()
|
||||
|
||||
def exists(self, path: str) -> bool:
|
||||
path = self.base_path / path
|
||||
bucket = self.client.get_bucket(self.bucket_name)
|
||||
blob = bucket.blob(self._resolve_path(path))
|
||||
return blob.exists()
|
||||
|
||||
def list(self, path: str) -> typing.List[str]:
|
||||
path = self.base_path / path
|
||||
bucket = self.client.get_bucket(self.bucket_name)
|
||||
blobs = bucket.list_blobs(prefix=self._resolve_path(path))
|
||||
return [blob.name for blob in blobs]
|
||||
46
autogpt/workspace_test.py
Normal file
46
autogpt/workspace_test.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
# Assuming the classes are defined in a file named workspace.py
|
||||
from .workspace import LocalWorkspace
|
||||
|
||||
# Constants
|
||||
TEST_BASE_PATH = "/tmp/test_workspace"
|
||||
TEST_FILE_CONTENT = b"Hello World"
|
||||
|
||||
|
||||
# Setup and Teardown for LocalWorkspace
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_local_workspace():
|
||||
os.makedirs(TEST_BASE_PATH, exist_ok=True)
|
||||
yield
|
||||
os.system(f"rm -rf {TEST_BASE_PATH}") # Cleanup after tests
|
||||
|
||||
|
||||
def test_local_read_write_delete_exists(setup_local_workspace):
|
||||
workspace = LocalWorkspace(TEST_BASE_PATH)
|
||||
|
||||
# Write
|
||||
workspace.write("test_file.txt", TEST_FILE_CONTENT)
|
||||
|
||||
# Exists
|
||||
assert workspace.exists("test_file.txt")
|
||||
|
||||
# Read
|
||||
assert workspace.read("test_file.txt") == TEST_FILE_CONTENT
|
||||
|
||||
# Delete
|
||||
workspace.delete("test_file.txt")
|
||||
assert not workspace.exists("test_file.txt")
|
||||
|
||||
|
||||
def test_local_list(setup_local_workspace):
|
||||
workspace = LocalWorkspace(TEST_BASE_PATH)
|
||||
workspace.write("test1.txt", TEST_FILE_CONTENT)
|
||||
workspace.write("test2.txt", TEST_FILE_CONTENT)
|
||||
|
||||
files = workspace.list(".")
|
||||
assert set(files) == set(["test1.txt", "test2.txt"])
|
||||
479
poetry.lock
generated
479
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -10,10 +10,10 @@ readme = "README.md"
|
||||
python = "^3.10"
|
||||
python-dotenv = "^1.0.0"
|
||||
openai = "^0.27.8"
|
||||
agent-protocol = "^0.2.2"
|
||||
helicone = "^1.0.6"
|
||||
tenacity = "^8.2.2"
|
||||
sqlalchemy = "^2.0.19"
|
||||
google-cloud-storage = "^2.10.0"
|
||||
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
@@ -27,6 +27,7 @@ types-requests = "^2.31.0.2"
|
||||
pytest = "^7.4.0"
|
||||
pytest-asyncio = "^0.21.1"
|
||||
watchdog = "^3.0.0"
|
||||
mock = "^5.1.0"
|
||||
|
||||
|
||||
[tool.poetry.group.ui.dependencies]
|
||||
|
||||
8
run
8
run
@@ -1,8 +1,8 @@
|
||||
#!/bin/bash
|
||||
|
||||
export PYTHONPATH=$PYTHONPATH:$PWD
|
||||
|
||||
poetry install
|
||||
poetry shell
|
||||
|
||||
watchmedo auto-restart -p "*.py" -R python3 -- autogpt/__main__.py "$@"
|
||||
# poetry install
|
||||
# poetry shell
|
||||
|
||||
export PYTHONPATH=$PYTHONPATH:$PWD; watchmedo auto-restart -p "*.py" -R python3 -- autogpt/__main__.py "$@"
|
||||
|
||||
Reference in New Issue
Block a user