Updating to version v0.4 of the Protocol (#21)

This commit is contained in:
Swifty
2023-08-25 16:03:16 +02:00
committed by GitHub
parent 16cbf5dc2a
commit 62a8c7ae04
8 changed files with 826 additions and 301 deletions

View File

@@ -27,7 +27,7 @@ if __name__ == "__main__":
workspace = LocalWorkspace(os.getenv("AGENT_WORKSPACE"))
port = os.getenv("PORT")
database = autogpt.db.AgentDB(database_name, debug_enabled=True)
database = autogpt.db.AgentDB(database_name, debug_enabled=False)
agent = autogpt.agent.Agent(database=database, workspace=workspace)
agent.start(port=port, router=router)

View File

@@ -1,13 +1,14 @@
import asyncio
import os
from fastapi import APIRouter, FastAPI, HTTPException, Response, UploadFile
from fastapi.responses import FileResponse, JSONResponse
from fastapi import APIRouter, FastAPI, Response, UploadFile
from fastapi.responses import FileResponse
from hypercorn.asyncio import serve
from hypercorn.config import Config
from prometheus_fastapi_instrumentator import Instrumentator
from .db import AgentDB
from .errors import NotFoundError
from .forge_log import CustomLogger
from .middlewares import AgentMiddleware
from .routes.agent_protocol import base_router
@@ -50,7 +51,7 @@ class Agent:
config.loglevel = "ERROR"
config.bind = [f"0.0.0.0:{port}"]
LOG.info(f"Agent server starting on {config.bind}")
LOG.info(f"Agent server starting on http://{config.bind[0]}")
asyncio.run(serve(app, config))
async def create_task(self, task_request: TaskRequestBody) -> Task:
@@ -64,10 +65,9 @@ class Agent:
if task_request.additional_input
else None,
)
LOG.info(task.json())
return task
except Exception as e:
return Response(status_code=500, content=str(e))
return task
raise
async def list_tasks(self, page: int = 1, pageSize: int = 10) -> TaskListResponse:
"""
@@ -76,22 +76,18 @@ class Agent:
try:
tasks, pagination = await self.db.list_tasks(page, pageSize)
response = TaskListResponse(tasks=tasks, pagination=pagination)
return Response(content=response.json(), media_type="application/json")
return response
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
raise
async def get_task(self, task_id: str) -> Task:
"""
Get a task by ID.
"""
if not task_id:
return Response(status_code=400, content="Task ID is required.")
if not isinstance(task_id, str):
return Response(status_code=400, content="Task ID must be a string.")
try:
task = await self.db.get_task(task_id)
except Exception as e:
return Response(status_code=500, content=str(e))
raise
return task
async def list_steps(
@@ -100,16 +96,12 @@ class Agent:
"""
List the IDs of all steps that the task has created.
"""
if not task_id:
return Response(status_code=400, content="Task ID is required.")
if not isinstance(task_id, str):
return Response(status_code=400, content="Task ID must be a string.")
try:
steps, pagination = await self.db.list_steps(task_id, page, pageSize)
response = TaskStepsListResponse(steps=steps, pagination=pagination)
return Response(content=response.json(), media_type="application/json")
return response
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
raise
async def create_and_execute_step(
self, task_id: str, step_request: StepRequestBody
@@ -144,37 +136,33 @@ class Agent:
steps, steps_pagination = await self.db.list_steps(
task_id, page=1, per_page=100
)
artifacts, artifacts_pagination = await self.db.list_artifacts(
task_id, page=1, per_page=100
)
step = steps[-1]
step.artifacts = artifacts
step.output = "No more steps to run."
# The step is the last step on this page so checking if this is the
# last page is sufficent to know if it is the last step
step.is_last = steps_pagination.current_page == steps_pagination.total_pages
# Find the latest step that has not been completed
step = next((s for s in reversed(steps) if s.status != "completed"), None)
if step is None:
# If all steps have been completed, create a new placeholder step
step = await self.db.create_step(
task_id=task_id,
input="y",
additional_properties=None,
)
step.status = "completed"
step.is_last = True
step.output = "No more steps to run."
step = await self.db.update_step(step)
if isinstance(step.status, Status):
step.status = step.status.value
step.output = "Done some work"
return JSONResponse(content=step.dict(), status_code=200)
return step
async def get_step(self, task_id: str, step_id: str) -> Step:
"""
Get a step by ID.
"""
if not task_id or not step_id:
return Response(
status_code=400, content="Task ID and step ID are required."
)
if not isinstance(task_id, str) or not isinstance(step_id, str):
return Response(
status_code=400, content="Task ID and step ID must be strings."
)
try:
step = await self.db.get_step(task_id, step_id)
return step
except Exception as e:
return Response(status_code=500, content=str(e))
return step
raise
async def list_artifacts(
self, task_id: str, page: int = 1, pageSize: int = 10
@@ -182,18 +170,16 @@ class Agent:
"""
List the artifacts that the task has created.
"""
if not task_id:
return Response(status_code=400, content="Task ID is required.")
if not isinstance(task_id, str):
return Response(status_code=400, content="Task ID must be a string.")
try:
artifacts, pagination = await self.db.list_artifacts(
task_id, page, pageSize
)
response = TaskArtifactsListResponse(
artifacts=artifacts, pagination=pagination
)
return Response(content=response.json(), media_type="application/json")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
response = TaskArtifactsListResponse(artifacts=artifacts, pagination=pagination)
return Response(content=response.json(), media_type="application/json")
raise
async def create_artifact(
self,
@@ -204,8 +190,6 @@ class Agent:
"""
Create an artifact for the task.
"""
if not file and not uri:
return Response(status_code=400, content="No file or uri provided")
data = None
if not uri:
file_name = file.filename or str(uuid4())
@@ -214,13 +198,13 @@ class Agent:
while contents := file.file.read(1024 * 1024):
data += contents
except Exception as e:
return Response(status_code=500, content=str(e))
raise
else:
try:
data = await load_from_uri(uri, task_id)
file_name = uri.split("/")[-1]
except Exception as e:
return Response(status_code=500, content=str(e))
raise
file_path = os.path.join(task_id / file_name)
self.write(file_path, data)
@@ -241,18 +225,16 @@ class Agent:
"""
try:
artifact = await self.db.get_artifact(task_id, artifact_id)
except Exception as e:
return Response(status_code=500, content=str(e))
try:
retrieved_artifact = await self.load_from_uri(artifact.uri, artifact_id)
except Exception as e:
return Response(status_code=500, content=str(e))
path = artifact.file_name
try:
path = artifact.file_name
with open(path, "wb") as f:
f.write(retrieved_artifact)
except NotFoundError as e:
raise
except FileNotFoundError as e:
raise
except Exception as e:
return Response(status_code=500, content=str(e))
raise
return FileResponse(
# Note: mimetype is guessed in the FileResponse constructor
path=path,

104
autogpt/agent_test.py Normal file
View File

@@ -0,0 +1,104 @@
import pytest
from .agent import Agent
from .db import AgentDB
from .schema import StepRequestBody, Task, TaskListResponse, TaskRequestBody
from .workspace import LocalWorkspace
@pytest.fixture
def agent():
db = AgentDB("sqlite:///test.db")
workspace = LocalWorkspace("./test_workspace")
return Agent(db, workspace)
@pytest.mark.asyncio
async def test_create_task(agent):
task_request = TaskRequestBody(
input="test_input", additional_input="additional_test_input"
)
task: Task = await agent.create_task(task_request)
assert task.input == "test_input"
@pytest.mark.asyncio
async def test_list_tasks(agent):
task_request = TaskRequestBody(
input="test_input", additional_input="additional_test_input"
)
task = await agent.create_task(task_request)
tasks = await agent.list_tasks()
assert isinstance(tasks, TaskListResponse)
@pytest.mark.asyncio
async def test_get_task(agent):
task_request = TaskRequestBody(
input="test_input", additional_input="additional_test_input"
)
task = await agent.create_task(task_request)
retrieved_task = await agent.get_task(task.task_id)
assert retrieved_task.task_id == task.task_id
@pytest.mark.skip
@pytest.mark.asyncio
async def test_create_and_execute_step(agent):
task_request = TaskRequestBody(
input="test_input", additional_input="additional_test_input"
)
task = await agent.create_task(task_request)
step_request = StepRequestBody(
input="step_input", additional_input="additional_step_input"
)
step = await agent.create_and_execute_step(task.task_id, step_request)
assert step.input == "step_input"
assert step.additional_input == "additional_step_input"
@pytest.mark.skip
@pytest.mark.asyncio
async def test_get_step(agent):
task_request = TaskRequestBody(
input="test_input", additional_input="additional_test_input"
)
task = await agent.create_task(task_request)
step_request = StepRequestBody(
input="step_input", additional_input="additional_step_input"
)
step = await agent.create_and_execute_step(task.task_id, step_request)
retrieved_step = await agent.get_step(task.task_id, step.step_id)
assert retrieved_step.step_id == step.step_id
@pytest.mark.skip
@pytest.mark.asyncio
async def test_list_artifacts(agent):
artifacts = await agent.list_artifacts()
assert isinstance(artifacts, list)
@pytest.mark.skip
@pytest.mark.asyncio
async def test_create_artifact(agent):
task_request = TaskRequestBody(
input="test_input", additional_input="additional_test_input"
)
task = await agent.create_task(task_request)
artifact_request = ArtifactRequestBody(file=None, uri="test_uri")
artifact = await agent.create_artifact(task.task_id, artifact_request)
assert artifact.uri == "test_uri"
@pytest.mark.skip
@pytest.mark.asyncio
async def test_get_artifact(agent):
task_request = TaskRequestBody(
input="test_input", additional_input="additional_test_input"
)
task = await agent.create_task(task_request)
artifact_request = ArtifactRequestBody(file=None, uri="test_uri")
artifact = await agent.create_artifact(task.task_id, artifact_request)
retrieved_artifact = await agent.get_artifact(task.task_id, artifact.artifact_id)
assert retrieved_artifact.artifact_id == artifact.artifact_id

View File

@@ -4,12 +4,16 @@ It uses SQLite as the database and file store backend.
IT IS NOT ADVISED TO USE THIS IN PRODUCTION!
"""
import datetime
import math
import uuid
from typing import Dict, List, Optional, Tuple
from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, create_engine
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, String, create_engine
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import DeclarativeBase, joinedload, relationship, sessionmaker
from .errors import NotFoundError
from .forge_log import CustomLogger
from .schema import Artifact, Pagination, Status, Step, Task, TaskInput
@@ -20,16 +24,16 @@ class Base(DeclarativeBase):
pass
class DataNotFoundError(Exception):
pass
class TaskModel(Base):
__tablename__ = "tasks"
task_id = Column(Integer, primary_key=True, autoincrement=True)
task_id = Column(String, primary_key=True, index=True)
input = Column(String)
additional_input = Dict
additional_input = Column(String)
created_at = Column(DateTime, default=datetime.datetime.utcnow)
modified_at = Column(
DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow
)
artifacts = relationship("ArtifactModel", back_populates="task")
@@ -37,12 +41,16 @@ class TaskModel(Base):
class StepModel(Base):
__tablename__ = "steps"
step_id = Column(Integer, primary_key=True, autoincrement=True)
task_id = Column(Integer, ForeignKey("tasks.task_id"))
step_id = Column(String, primary_key=True, index=True)
task_id = Column(String, ForeignKey("tasks.task_id"))
name = Column(String)
input = Column(String)
status = Column(String)
is_last = Column(Boolean, default=False)
created_at = Column(DateTime, default=datetime.datetime.utcnow)
modified_at = Column(
DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow
)
additional_properties = Column(String)
artifacts = relationship("ArtifactModel", back_populates="step")
@@ -51,12 +59,16 @@ class StepModel(Base):
class ArtifactModel(Base):
__tablename__ = "artifacts"
artifact_id = Column(Integer, primary_key=True, autoincrement=True)
task_id = Column(Integer, ForeignKey("tasks.task_id"))
step_id = Column(Integer, ForeignKey("steps.step_id"))
artifact_id = Column(String, primary_key=True, index=True)
task_id = Column(String, ForeignKey("tasks.task_id"))
step_id = Column(String, ForeignKey("steps.step_id"))
agent_created = Column(Boolean, default=False)
file_name = Column(String)
uri = Column(String)
created_at = Column(DateTime, default=datetime.datetime.utcnow)
modified_at = Column(
DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow
)
step = relationship("StepModel", back_populates="artifacts")
task = relationship("TaskModel", back_populates="artifacts")
@@ -65,18 +77,11 @@ class ArtifactModel(Base):
def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False) -> Task:
if debug_enabled:
LOG.debug(f"Converting TaskModel to Task for task_id: {task_obj.task_id}")
task_artifacts = [
Artifact(
artifact_id=artifact.artifact_id,
file_name=artifact.file_name,
agent_created=artifact.agent_created,
uri=artifact.uri,
)
for artifact in task_obj.artifacts
if artifact.task_id == task_obj.task_id
]
task_artifacts = [convert_to_artifact(artifact) for artifact in task_obj.artifacts]
return Task(
task_id=task_obj.task_id,
created_at=task_obj.created_at,
modified_at=task_obj.modified_at,
input=task_obj.input,
additional_input=task_obj.additional_input,
artifacts=task_artifacts,
@@ -87,19 +92,14 @@ def convert_to_step(step_model: StepModel, debug_enabled: bool = False) -> Step:
if debug_enabled:
LOG.debug(f"Converting StepModel to Step for step_id: {step_model.step_id}")
step_artifacts = [
Artifact(
artifact_id=artifact.artifact_id,
file_name=artifact.file_name,
agent_created=artifact.agent_created,
uri=artifact.uri,
)
for artifact in step_model.artifacts
if artifact.step_id == step_model.step_id
convert_to_artifact(artifact) for artifact in step_model.artifacts
]
status = Status.completed if step_model.status == "completed" else Status.created
return Step(
task_id=step_model.task_id,
step_id=step_model.step_id,
created_at=step_model.created_at,
modified_at=step_model.modified_at,
name=step_model.name,
input=step_model.input,
status=status,
@@ -109,6 +109,16 @@ def convert_to_step(step_model: StepModel, debug_enabled: bool = False) -> Step:
)
def convert_to_artifact(artifact_model: ArtifactModel) -> Artifact:
return Artifact(
artifact_id=artifact_model.artifact_id,
created_at=artifact_model.created_at,
modified_at=artifact_model.modified_at,
agent_created=artifact_model.agent_created,
uri=artifact_model.uri,
)
# sqlite:///{database_name}
class AgentDB:
def __init__(self, database_string, debug_enabled: bool = False) -> None:
@@ -125,19 +135,30 @@ class AgentDB:
) -> Task:
if self.debug_enabled:
LOG.debug("Creating new task")
with self.Session() as session:
new_task = TaskModel(
input=input,
additional_input=additional_input.__root__
if additional_input
else None,
)
session.add(new_task)
session.commit()
session.refresh(new_task)
if self.debug_enabled:
LOG.debug(f"Created new task with task_id: {new_task.task_id}")
return convert_to_task(new_task, self.debug_enabled)
try:
with self.Session() as session:
new_task = TaskModel(
task_id=str(uuid.uuid4()),
input=input,
additional_input=additional_input.__root__
if additional_input
else None,
)
session.add(new_task)
session.commit()
session.refresh(new_task)
if self.debug_enabled:
LOG.debug(f"Created new task with task_id: {new_task.task_id}")
return convert_to_task(new_task, self.debug_enabled)
except SQLAlchemyError as e:
LOG.error(f"SQLAlchemy error while creating task: {e}")
raise
except NotFoundError as e:
raise
except Exception as e:
LOG.error(f"Unexpected error while creating task: {e}")
raise
async def create_step(
self,
@@ -149,23 +170,30 @@ class AgentDB:
if self.debug_enabled:
LOG.debug(f"Creating new step for task_id: {task_id}")
try:
session = self.Session()
new_step = StepModel(
task_id=task_id,
name=input,
input=input,
status="created",
is_last=is_last,
additional_properties=additional_properties,
)
session.add(new_step)
session.commit()
session.refresh(new_step)
if self.debug_enabled:
LOG.debug(f"Created new step with step_id: {new_step.step_id}")
with self.Session() as session:
new_step = StepModel(
task_id=task_id,
step_id=str(uuid.uuid4()),
name=input,
input=input,
status="created",
is_last=is_last,
additional_properties=additional_properties,
)
session.add(new_step)
session.commit()
session.refresh(new_step)
if self.debug_enabled:
LOG.debug(f"Created new step with step_id: {new_step.step_id}")
return convert_to_step(new_step, self.debug_enabled)
except SQLAlchemyError as e:
LOG.error(f"SQLAlchemy error while creating step: {e}")
raise
except NotFoundError as e:
raise
except Exception as e:
LOG.error(f"Error while creating step: {e}")
return convert_to_step(new_step, self.debug_enabled)
LOG.error(f"Unexpected error while creating step: {e}")
raise
async def create_artifact(
self,
@@ -177,71 +205,94 @@ class AgentDB:
) -> Artifact:
if self.debug_enabled:
LOG.debug(f"Creating new artifact for task_id: {task_id}")
session = self.Session()
try:
with self.Session() as session:
if (
existing_artifact := session.query(ArtifactModel)
.filter_by(uri=uri)
.first()
):
session.close()
if self.debug_enabled:
LOG.debug(f"Artifact already exists with uri: {uri}")
return convert_to_artifact(existing_artifact)
if existing_artifact := session.query(ArtifactModel).filter_by(uri=uri).first():
session.close()
if self.debug_enabled:
LOG.debug(f"Artifact already exists with uri: {uri}")
return Artifact(
artifact_id=str(existing_artifact.artifact_id),
file_name=existing_artifact.file_name,
agent_created=existing_artifact.agent_created,
uri=existing_artifact.uri,
)
new_artifact = ArtifactModel(
task_id=task_id,
step_id=step_id,
agent_created=agent_created,
file_name=file_name,
uri=uri,
)
session.add(new_artifact)
session.commit()
session.refresh(new_artifact)
if self.debug_enabled:
LOG.debug(
f"Created new artifact with artifact_id: {new_artifact.artifact_id}"
)
return Artifact(
artifact_id=str(new_artifact.artifact_id),
file_name=new_artifact.file_name,
agent_created=new_artifact.agent_created,
uri=new_artifact.uri,
)
new_artifact = ArtifactModel(
artifact_id=str(uuid.uuid4()),
task_id=task_id,
step_id=step_id,
agent_created=agent_created,
file_name=file_name,
uri=uri,
)
session.add(new_artifact)
session.commit()
session.refresh(new_artifact)
if self.debug_enabled:
LOG.debug(
f"Created new artifact with artifact_id: {new_artifact.artifact_id}"
)
return convert_to_artifact(new_artifact)
except SQLAlchemyError as e:
LOG.error(f"SQLAlchemy error while creating step: {e}")
raise
except NotFoundError as e:
raise
except Exception as e:
LOG.error(f"Unexpected error while creating step: {e}")
raise
async def get_task(self, task_id: int) -> Task:
"""Get a task by its id"""
if self.debug_enabled:
LOG.debug(f"Getting task with task_id: {task_id}")
session = self.Session()
if task_obj := (
session.query(TaskModel)
.options(joinedload(TaskModel.artifacts))
.filter_by(task_id=task_id)
.first()
):
return convert_to_task(task_obj, self.debug_enabled)
else:
LOG.error(f"Task not found with task_id: {task_id}")
raise DataNotFoundError("Task not found")
try:
with self.Session() as session:
if task_obj := (
session.query(TaskModel)
.options(joinedload(TaskModel.artifacts))
.filter_by(task_id=task_id)
.first()
):
return convert_to_task(task_obj, self.debug_enabled)
else:
LOG.error(f"Task not found with task_id: {task_id}")
raise NotFoundError("Task not found")
except SQLAlchemyError as e:
LOG.error(f"SQLAlchemy error while getting task: {e}")
raise
except NotFoundError as e:
raise
except Exception as e:
LOG.error(f"Unexpected error while getting task: {e}")
raise
async def get_step(self, task_id: int, step_id: int) -> Step:
if self.debug_enabled:
LOG.debug(f"Getting step with task_id: {task_id} and step_id: {step_id}")
session = self.Session()
if step := (
session.query(StepModel)
.options(joinedload(StepModel.artifacts))
.filter(StepModel.step_id == step_id)
.first()
):
return convert_to_step(step, self.debug_enabled)
try:
with self.Session() as session:
if step := (
session.query(StepModel)
.options(joinedload(StepModel.artifacts))
.filter(StepModel.step_id == step_id)
.first()
):
return convert_to_step(step, self.debug_enabled)
else:
LOG.error(f"Step not found with task_id: {task_id} and step_id: {step_id}")
raise DataNotFoundError("Step not found")
else:
LOG.error(
f"Step not found with task_id: {task_id} and step_id: {step_id}"
)
raise NotFoundError("Step not found")
except SQLAlchemyError as e:
LOG.error(f"SQLAlchemy error while getting step: {e}")
raise
except NotFoundError as e:
raise
except Exception as e:
LOG.error(f"Unexpected error while getting step: {e}")
raise
async def update_step(
self,
@@ -252,108 +303,157 @@ class AgentDB:
) -> Step:
if self.debug_enabled:
LOG.debug(f"Updating step with task_id: {task_id} and step_id: {step_id}")
session = self.Session()
if (
step := session.query(StepModel)
.filter_by(task_id=task_id, step_id=step_id)
.first()
):
step.status = status
step.additional_properties = additional_properties
session.commit()
return await self.get_step(task_id, step_id)
else:
LOG.error(
f"Step not found for update with task_id: {task_id} and step_id: {step_id}"
)
raise DataNotFoundError("Step not found")
try:
with self.Session() as session:
if (
step := session.query(StepModel)
.filter_by(task_id=task_id, step_id=step_id)
.first()
):
step.status = status
step.additional_properties = additional_properties
session.commit()
return await self.get_step(task_id, step_id)
else:
LOG.error(
f"Step not found for update with task_id: {task_id} and step_id: {step_id}"
)
raise NotFoundError("Step not found")
except SQLAlchemyError as e:
LOG.error(f"SQLAlchemy error while getting step: {e}")
raise
except NotFoundError as e:
raise
except Exception as e:
LOG.error(f"Unexpected error while getting step: {e}")
raise
async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact:
async def get_artifact(self, artifact_id: str) -> Artifact:
if self.debug_enabled:
LOG.debug(
f"Getting artifact with task_id: {task_id} and artifact_id: {artifact_id}"
)
session = self.Session()
if artifact_model := (
session.query(ArtifactModel)
.filter_by(task_id=int(task_id), artifact_id=int(artifact_id))
.first()
):
return Artifact(
artifact_id=artifact_model.artifact_id, # Casting to string
file_name=artifact_model.file_name,
agent_created=artifact_model.agent_created,
uri=artifact_model.uri,
)
else:
LOG.error(
f"Artifact not found with task_id: {task_id} and artifact_id: {artifact_id}"
)
raise DataNotFoundError("Artifact not found")
LOG.debug(f"Getting artifact with and artifact_id: {artifact_id}")
try:
with self.Session() as session:
if (
artifact_model := session.query(ArtifactModel)
.filter_by(artifact_id=artifact_id)
.first()
):
return convert_to_artifact(artifact_model)
else:
LOG.error(f"Artifact not found with and artifact_id: {artifact_id}")
raise NotFoundError("Artifact not found")
except SQLAlchemyError as e:
LOG.error(f"SQLAlchemy error while getting artifact: {e}")
raise
except NotFoundError as e:
raise
except Exception as e:
LOG.error(f"Unexpected error while getting artifact: {e}")
raise
async def list_tasks(
self, page: int = 1, per_page: int = 10
) -> Tuple[List[Task], Pagination]:
if self.debug_enabled:
LOG.debug("Listing tasks")
session = self.Session()
tasks = (
session.query(TaskModel).offset((page - 1) * per_page).limit(per_page).all()
)
total = session.query(TaskModel).count()
pages = math.ceil(total / per_page)
pagination = Pagination(
total_items=total, total_pages=pages, current_page=page, page_size=per_page
)
return [convert_to_task(task, self.debug_enabled) for task in tasks], pagination
try:
with self.Session() as session:
tasks = (
session.query(TaskModel)
.offset((page - 1) * per_page)
.limit(per_page)
.all()
)
if not tasks:
raise NotFoundError("No tasks found")
total = session.query(TaskModel).count()
pages = math.ceil(total / per_page)
pagination = Pagination(
total_items=total,
total_pages=pages,
current_page=page,
page_size=per_page,
)
return [
convert_to_task(task, self.debug_enabled) for task in tasks
], pagination
except SQLAlchemyError as e:
LOG.error(f"SQLAlchemy error while listing tasks: {e}")
raise
except NotFoundError as e:
raise
except Exception as e:
LOG.error(f"Unexpected error while listing tasks: {e}")
raise
async def list_steps(
self, task_id: str, page: int = 1, per_page: int = 10
) -> Tuple[List[Step], Pagination]:
if self.debug_enabled:
LOG.debug(f"Listing steps for task_id: {task_id}")
session = self.Session()
steps = (
session.query(StepModel)
.filter_by(task_id=task_id)
.offset((page - 1) * per_page)
.limit(per_page)
.all()
)
total = session.query(StepModel).filter_by(task_id=task_id).count()
pages = math.ceil(total / per_page)
pagination = Pagination(
total_items=total, total_pages=pages, current_page=page, page_size=per_page
)
return [convert_to_step(step, self.debug_enabled) for step in steps], pagination
try:
with self.Session() as session:
steps = (
session.query(StepModel)
.filter_by(task_id=task_id)
.offset((page - 1) * per_page)
.limit(per_page)
.all()
)
if not steps:
raise NotFoundError("No steps found")
total = session.query(StepModel).filter_by(task_id=task_id).count()
pages = math.ceil(total / per_page)
pagination = Pagination(
total_items=total,
total_pages=pages,
current_page=page,
page_size=per_page,
)
return [
convert_to_step(step, self.debug_enabled) for step in steps
], pagination
except SQLAlchemyError as e:
LOG.error(f"SQLAlchemy error while listing steps: {e}")
raise
except NotFoundError as e:
raise
except Exception as e:
LOG.error(f"Unexpected error while listing steps: {e}")
raise
async def list_artifacts(
self, task_id: str, page: int = 1, per_page: int = 10
) -> Tuple[List[Artifact], Pagination]:
if self.debug_enabled:
LOG.debug(f"Listing artifacts for task_id: {task_id}")
with self.Session() as session:
artifacts = (
session.query(ArtifactModel)
.filter_by(task_id=task_id)
.offset((page - 1) * per_page)
.limit(per_page)
.all()
)
total = session.query(ArtifactModel).filter_by(task_id=task_id).count()
pages = math.ceil(total / per_page)
pagination = Pagination(
total_items=total,
total_pages=pages,
current_page=page,
page_size=per_page,
)
return [
Artifact(
artifact_id=str(artifact.artifact_id),
file_name=artifact.file_name,
agent_created=artifact.agent_created,
uri=artifact.uri,
try:
with self.Session() as session:
artifacts = (
session.query(ArtifactModel)
.filter_by(task_id=task_id)
.offset((page - 1) * per_page)
.limit(per_page)
.all()
)
for artifact in artifacts
], pagination
if not artifacts:
raise NotFoundError("No artifacts found")
total = session.query(ArtifactModel).filter_by(task_id=task_id).count()
pages = math.ceil(total / per_page)
pagination = Pagination(
total_items=total,
total_pages=pages,
current_page=page,
page_size=per_page,
)
return [
convert_to_artifact(artifact) for artifact in artifacts
], pagination
except SQLAlchemyError as e:
LOG.error(f"SQLAlchemy error while listing artifacts: {e}")
raise
except NotFoundError as e:
raise
except Exception as e:
LOG.error(f"Unexpected error while listing artifacts: {e}")
raise

View File

@@ -1,11 +1,23 @@
import os
import sqlite3
from datetime import datetime
import pytest
from autogpt.db import AgentDB, DataNotFoundError
from autogpt.db import (
AgentDB,
ArtifactModel,
StepModel,
TaskModel,
convert_to_artifact,
convert_to_step,
convert_to_task,
)
from autogpt.errors import NotFoundError as DataNotFoundError
from autogpt.schema import *
@pytest.mark.asyncio
def test_table_creation():
db_name = "sqlite:///test_db.sqlite3"
agent_db = AgentDB(db_name)
@@ -30,6 +42,143 @@ def test_table_creation():
os.remove(db_name.split("///")[1])
@pytest.mark.asyncio
async def test_task_schema():
now = datetime.now()
task = Task(
task_id="50da533e-3904-4401-8a07-c49adf88b5eb",
input="Write the words you receive to the file 'output.txt'.",
created_at=now,
modified_at=now,
artifacts=[
Artifact(
artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
agent_created=True,
uri="file:///path/to/artifact",
file_name="main.py",
relative_path="python/code/",
created_at=now,
modified_at=now,
)
],
)
assert task.task_id == "50da533e-3904-4401-8a07-c49adf88b5eb"
assert task.input == "Write the words you receive to the file 'output.txt'."
assert len(task.artifacts) == 1
assert task.artifacts[0].artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"
@pytest.mark.asyncio
async def test_step_schema():
now = datetime.now()
step = Step(
task_id="50da533e-3904-4401-8a07-c49adf88b5eb",
step_id="6bb1801a-fd80-45e8-899a-4dd723cc602e",
created_at=now,
modified_at=now,
name="Write to file",
input="Write the words you receive to the file 'output.txt'.",
status=Status.created,
output="I am going to use the write_to_file command and write Washington to a file called output.txt <write_to_file('output.txt', 'Washington')>",
artifacts=[
Artifact(
artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
file_name="main.py",
relative_path="python/code/",
created_at=now,
modified_at=now,
agent_created=True,
uri="file:///path/to/artifact",
)
],
is_last=False,
)
assert step.task_id == "50da533e-3904-4401-8a07-c49adf88b5eb"
assert step.step_id == "6bb1801a-fd80-45e8-899a-4dd723cc602e"
assert step.name == "Write to file"
assert step.status == Status.created
assert (
step.output
== "I am going to use the write_to_file command and write Washington to a file called output.txt <write_to_file('output.txt', 'Washington')>"
)
assert len(step.artifacts) == 1
assert step.artifacts[0].artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"
assert step.is_last == False
@pytest.mark.asyncio
async def test_convert_to_task():
now = datetime.now()
task_model = TaskModel(
task_id="50da533e-3904-4401-8a07-c49adf88b5eb",
created_at=now,
modified_at=now,
input="Write the words you receive to the file 'output.txt'.",
artifacts=[
ArtifactModel(
artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
created_at=now,
modified_at=now,
uri="file:///path/to/main.py",
agent_created=True,
)
],
)
task = convert_to_task(task_model)
assert task.task_id == "50da533e-3904-4401-8a07-c49adf88b5eb"
assert task.input == "Write the words you receive to the file 'output.txt'."
assert len(task.artifacts) == 1
assert task.artifacts[0].artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"
@pytest.mark.asyncio
async def test_convert_to_step():
now = datetime.now()
step_model = StepModel(
task_id="50da533e-3904-4401-8a07-c49adf88b5eb",
step_id="6bb1801a-fd80-45e8-899a-4dd723cc602e",
created_at=now,
modified_at=now,
name="Write to file",
status="created",
input="Write the words you receive to the file 'output.txt'.",
artifacts=[
ArtifactModel(
artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
created_at=now,
modified_at=now,
uri="file:///path/to/main.py",
agent_created=True,
)
],
is_last=False,
)
step = convert_to_step(step_model)
assert step.task_id == "50da533e-3904-4401-8a07-c49adf88b5eb"
assert step.step_id == "6bb1801a-fd80-45e8-899a-4dd723cc602e"
assert step.name == "Write to file"
assert step.status == Status.created
assert len(step.artifacts) == 1
assert step.artifacts[0].artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"
assert step.is_last == False
@pytest.mark.asyncio
async def test_convert_to_artifact():
now = datetime.now()
artifact_model = ArtifactModel(
artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
created_at=now,
modified_at=now,
uri="file:///path/to/main.py",
agent_created=True,
)
artifact = convert_to_artifact(artifact_model)
assert artifact.artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"
assert artifact.uri == "file:///path/to/main.py"
assert artifact.agent_created == True
@pytest.mark.asyncio
async def test_create_task():
# Having issues with pytest fixture so added setup and teardown in each test as a rapid workaround
@@ -52,17 +201,6 @@ async def test_create_and_get_task():
os.remove(db_name.split("///")[1])
@pytest.mark.asyncio
async def test_create_and_get_task_with_steps():
db_name = "sqlite:///test_db.sqlite3"
agent_db = AgentDB(db_name)
await agent_db.create_task("task_input")
task = await agent_db.get_task(1)
assert task.input == "task_input"
os.remove(db_name.split("///")[1])
@pytest.mark.asyncio
async def test_get_task_not_found():
db_name = "sqlite:///test_db.sqlite3"
@@ -87,11 +225,11 @@ async def test_create_and_get_step():
async def test_updating_step():
db_name = "sqlite:///test_db.sqlite3"
agent_db = AgentDB(db_name)
await agent_db.create_task("task_input")
await agent_db.create_step(1, "step_name")
await agent_db.update_step(1, 1, "completed")
created_task = await agent_db.create_task("task_input")
created_step = await agent_db.create_step(created_task.task_id, "step_name")
await agent_db.update_step(created_task.task_id, created_step.step_id, "completed")
step = await agent_db.get_step(1, 1)
step = await agent_db.get_step(created_task.task_id, created_step.step_id)
assert step.status.value == "completed"
os.remove(db_name.split("///")[1])
@@ -117,18 +255,18 @@ async def test_get_artifact():
# Create an artifact
artifact = await db.create_artifact(
task_id=task.task_id,
file_name="sample_file.txt",
uri="file:///path/to/sample_file.txt",
file_name="test_get_artifact_sample_file.txt",
uri="file:///path/to/test_get_artifact_sample_file.txt",
agent_created=True,
step_id=step.step_id,
)
# When: The artifact is fetched by its ID
fetched_artifact = await db.get_artifact(task.task_id, artifact.artifact_id)
fetched_artifact = await db.get_artifact(artifact.artifact_id)
# Then: The fetched artifact matches the original
assert fetched_artifact.artifact_id == artifact.artifact_id
assert fetched_artifact.uri == "file:///path/to/sample_file.txt"
assert fetched_artifact.uri == "file:///path/to/test_get_artifact_sample_file.txt"
os.remove(db_name.split("///")[1])

2
autogpt/errors.py Normal file
View File

@@ -0,0 +1,2 @@
class NotFoundError(Exception):
pass

View File

@@ -22,16 +22,37 @@ the ones that require special attention due to their complexity are:
Developers and contributors should be especially careful when making modifications to these routes to ensure
consistency and correctness in the system's behavior.
"""
import json
from typing import Optional
from fastapi import APIRouter, Query, Request, Response, UploadFile
from fastapi.responses import FileResponse
from autogpt.errors import *
from autogpt.forge_log import CustomLogger
from autogpt.schema import *
from autogpt.tracing import tracing
base_router = APIRouter()
LOG = CustomLogger(__name__)
@base_router.get("/", tags=["root"])
async def root():
"""
Root endpoint that returns a welcome message.
"""
return Response(content="Welcome to the Auto-GPT Forge")
@base_router.get("/heartbeat", tags=["server"])
async def check_server_status():
"""
Check if the server is running.
"""
return Response(content="Server is running.", status_code=200)
@base_router.get("/", tags=["root"])
async def root():
@@ -75,15 +96,25 @@ async def create_agent_task(request: Request, task_request: TaskRequestBody) ->
"input": "Write the word 'Washington' to a .txt file",
"additional_input": "python/code",
"artifacts": [],
"steps": []
}
"""
agent = request["agent"]
if task_request := await agent.create_task(task_request):
return task_request
else:
return Response(content={"error": "Task creation failed"}, status_code=400)
try:
task_request = await agent.create_task(task_request)
return Response(content=task_request.json(), status_code=200)
except NotFoundError:
return Response(
content=json.dumps({"error": "Task not found"}),
status_code=404,
media_type="application/json",
)
except Exception:
return Response(
content=json.dumps({"error": "Internal server error"}),
status_code=500,
media_type="application/json",
)
@base_router.get("/agent/tasks", tags=["agent"], response_model=TaskListResponse)
@@ -128,7 +159,21 @@ async def list_agent_tasks(
}
"""
agent = request["agent"]
return await agent.list_tasks(page, page_size)
try:
tasks = await agent.list_tasks(page, page_size)
return Response(content=tasks.json(), status_code=200)
except NotFoundError:
return Response(
content=json.dumps({"error": "Task not found"}),
status_code=404,
media_type="application/json",
)
except Exception:
return Response(
content=json.dumps({"error": "Internal server error"}),
status_code=500,
media_type="application/json",
)
@base_router.get("/agent/tasks/{task_id}", tags=["agent"], response_model=Task)
@@ -185,11 +230,21 @@ async def get_agent_task(request: Request, task_id: str) -> Task:
}
"""
agent = request["agent"]
task = await agent.get_task(task_id)
if task:
return task
else:
return Response(content={"error": "Task not found"}, status_code=404)
try:
task = await agent.get_task(task_id)
return Response(content=task.json(), status_code=200)
except NotFoundError:
return Response(
content=json.dumps({"error": "Task not found"}),
status_code=404,
media_type="application/json",
)
except Exception:
return Response(
content=json.dumps({"error": "Internal server error"}),
status_code=500,
media_type="application/json",
)
@base_router.get(
@@ -236,7 +291,21 @@ async def list_agent_task_steps(
}
"""
agent = request["agent"]
return await agent.list_steps(task_id, page, page_size)
try:
steps = await agent.list_steps(task_id, page, page_size)
return Response(content=steps.json(), status_code=200)
except NotFoundError:
return Response(
content=json.dumps({"error": "Task not found"}),
status_code=404,
media_type="application/json",
)
except Exception:
return Response(
content=json.dumps({"error": "Internal server error"}),
status_code=500,
media_type="application/json",
)
@base_router.post("/agent/tasks/{task_id}/steps", tags=["agent"], response_model=Step)
@@ -284,7 +353,22 @@ async def execute_agent_task_step(
}
"""
agent = request["agent"]
return await agent.create_and_execute_step(task_id, step)
try:
step = await agent.create_and_execute_step(task_id, step)
return Response(content=step.json(), status_code=200)
except NotFoundError:
return Response(
content=json.dumps({"error": f"Task not found {task_id}"}),
status_code=404,
media_type="application/json",
)
except Exception as e:
LOG.exception("Error whilst trying to execute a test")
return Response(
content=json.dumps({"error": "Internal server error"}),
status_code=500,
media_type="application/json",
)
@base_router.get(
@@ -315,7 +399,21 @@ async def get_agent_task_step(request: Request, task_id: str, step_id: str) -> S
}
"""
agent = request["agent"]
return await agent.get_step(task_id, step_id)
try:
step = await agent.get_step(task_id, step_id)
return Response(content=step.json(), status_code=200)
except NotFoundError:
return Response(
content=json.dumps({"error": "Task not found"}),
status_code=404,
media_type="application/json",
)
except Exception:
return Response(
content=json.dumps({"error": "Internal server error"}),
status_code=500,
media_type="application/json",
)
@base_router.get(
@@ -362,7 +460,21 @@ async def list_agent_task_artifacts(
}
"""
agent = request["agent"]
return await agent.list_artifacts(task_id, page, page_size)
try:
artifacts = await agent.list_artifacts(task_id, page, page_size)
return Response(content=artifacts.json(), status_code=200)
except NotFoundError:
return Response(
content=json.dumps({"error": "Task not found"}),
status_code=404,
media_type="application/json",
)
except Exception:
return Response(
content=json.dumps({"error": "Internal server error"}),
status_code=500,
media_type="application/json",
)
@base_router.post(
@@ -394,7 +506,6 @@ async def upload_agent_task_artifacts(
Note:
Either `file` or `uri` must be provided. If both are provided, the behavior depends on
the agent's implementation. If neither is provided, the function will return an error.
Example:
Request:
POST /agent/tasks/50da533e-3904-4401-8a07-c49adf88b5eb/artifacts
@@ -409,7 +520,35 @@ async def upload_agent_task_artifacts(
}
"""
agent = request["agent"]
return await agent.create_artifact(task_id, file, uri)
if file is None and uri is None:
return Response(
content={"error": "Either file or uri must be specified"}, status_code=404
)
if file is not None and uri is not None:
return Response(
content={"error": "Both file and uri cannot be specified at the same time"},
status_code=404,
)
if uri is not None and not uri.startswith(("http://", "https://", "file://")):
return Response(
content={"error": "URI must start with http, https or file"},
status_code=404,
)
try:
artifact = await agent.create_artifact(task_id, file, uri)
return Response(content=artifact.json(), status_code=200)
except NotFoundError:
return Response(
content=json.dumps({"error": "Task not found"}),
status_code=404,
media_type="application/json",
)
except Exception:
return Response(
content=json.dumps({"error": "Internal server error"}),
status_code=500,
media_type="application/json",
)
@base_router.get(
@@ -438,4 +577,25 @@ async def download_agent_task_artifact(
<file_content_of_artifact>
"""
agent = request["agent"]
return await agent.get_artifact(task_id, artifact_id)
try:
return await agent.get_artifact(task_id, artifact_id)
except NotFoundError:
return Response(
content=json.dumps(
{
"error": f"Artifact not found - task_id: {task_id}, artifact_id: {artifact_id}"
}
),
status_code=404,
media_type="application/json",
)
except Exception:
return Response(
content=json.dumps(
{
"error": f"Internal server error - task_id: {task_id}, artifact_id: {artifact_id}"
}
),
status_code=500,
media_type="application/json",
)

View File

@@ -1,9 +1,10 @@
# generated by fastapi-codegen:
# filename: ../../openapi.yml
# timestamp: 2023-08-24T13:55:59+00:00
# filename: ../../postman/schemas/openapi.yaml
# timestamp: 2023-08-25T10:36:11+00:00
from __future__ import annotations
from datetime import datetime
from enum import Enum
from typing import List, Optional
@@ -22,6 +23,18 @@ class TaskInput(BaseModel):
class Artifact(BaseModel):
created_at: datetime = Field(
...,
description="The creation datetime of the task.",
example="2023-01-01T00:00:00Z",
json_encoders={datetime: lambda v: v.isoformat()},
)
modified_at: datetime = Field(
...,
description="The modification datetime of the task.",
example="2023-01-01T00:00:00Z",
json_encoders={datetime: lambda v: v.isoformat()},
)
artifact_id: str = Field(
...,
description="ID of the artifact.",
@@ -32,8 +45,8 @@ class Artifact(BaseModel):
description="Whether the artifact has been created by the agent.",
example=False,
)
uri: Optional[str] = Field(
None,
uri: str = Field(
...,
description="URI of the artifact.",
example="file://home/bob/workspace/bucket/main.py",
)
@@ -50,7 +63,6 @@ class StepOutput(BaseModel):
class TaskRequestBody(BaseModel):
input: str = Field(
...,
min_length=1,
description="Input prompt for the task.",
example="Write the words you receive to the file 'output.txt'.",
)
@@ -58,6 +70,18 @@ class TaskRequestBody(BaseModel):
class Task(TaskRequestBody):
created_at: datetime = Field(
...,
description="The creation datetime of the task.",
example="2023-01-01T00:00:00Z",
json_encoders={datetime: lambda v: v.isoformat()},
)
modified_at: datetime = Field(
...,
description="The modification datetime of the task.",
example="2023-01-01T00:00:00Z",
json_encoders={datetime: lambda v: v.isoformat()},
)
task_id: str = Field(
...,
description="The ID of the task.",
@@ -74,6 +98,9 @@ class Task(TaskRequestBody):
class StepRequestBody(BaseModel):
name: Optional[str] = Field(
None, description="The name of the task step.", example="Write to file"
)
input: str = Field(
..., description="Input prompt for the step.", example="Washington"
)
@@ -87,6 +114,18 @@ class Status(Enum):
class Step(StepRequestBody):
created_at: datetime = Field(
...,
description="The creation datetime of the task.",
example="2023-01-01T00:00:00Z",
json_encoders={datetime: lambda v: v.isoformat()},
)
modified_at: datetime = Field(
...,
description="The modification datetime of the task.",
example="2023-01-01T00:00:00Z",
json_encoders={datetime: lambda v: v.isoformat()},
)
task_id: str = Field(
...,
description="The ID of the task this step belongs to.",
@@ -112,8 +151,8 @@ class Step(StepRequestBody):
artifacts: Optional[List[Artifact]] = Field(
[], description="A list of artifacts that the step has produced."
)
is_last: Optional[bool] = Field(
False, description="Whether this is the last step in the task.", example=True
is_last: bool = Field(
..., description="Whether this is the last step in the task.", example=True
)