From 62a8c7ae0433b6b3575c824a8a0885a7352097e9 Mon Sep 17 00:00:00 2001 From: Swifty Date: Fri, 25 Aug 2023 16:03:16 +0200 Subject: [PATCH] Updating to version v0.4 of the Protocol (#21) --- autogpt/__main__.py | 2 +- autogpt/agent.py | 98 +++--- autogpt/agent_test.py | 104 +++++++ autogpt/db.py | 494 +++++++++++++++++++------------ autogpt/db_test.py | 178 +++++++++-- autogpt/errors.py | 2 + autogpt/routes/agent_protocol.py | 196 ++++++++++-- autogpt/schema.py | 53 +++- 8 files changed, 826 insertions(+), 301 deletions(-) create mode 100644 autogpt/agent_test.py create mode 100644 autogpt/errors.py diff --git a/autogpt/__main__.py b/autogpt/__main__.py index 536a18de..7deefcf6 100644 --- a/autogpt/__main__.py +++ b/autogpt/__main__.py @@ -27,7 +27,7 @@ if __name__ == "__main__": workspace = LocalWorkspace(os.getenv("AGENT_WORKSPACE")) port = os.getenv("PORT") - database = autogpt.db.AgentDB(database_name, debug_enabled=True) + database = autogpt.db.AgentDB(database_name, debug_enabled=False) agent = autogpt.agent.Agent(database=database, workspace=workspace) agent.start(port=port, router=router) diff --git a/autogpt/agent.py b/autogpt/agent.py index b6f3d428..8c70470f 100644 --- a/autogpt/agent.py +++ b/autogpt/agent.py @@ -1,13 +1,14 @@ import asyncio import os -from fastapi import APIRouter, FastAPI, HTTPException, Response, UploadFile -from fastapi.responses import FileResponse, JSONResponse +from fastapi import APIRouter, FastAPI, Response, UploadFile +from fastapi.responses import FileResponse from hypercorn.asyncio import serve from hypercorn.config import Config from prometheus_fastapi_instrumentator import Instrumentator from .db import AgentDB +from .errors import NotFoundError from .forge_log import CustomLogger from .middlewares import AgentMiddleware from .routes.agent_protocol import base_router @@ -50,7 +51,7 @@ class Agent: config.loglevel = "ERROR" config.bind = [f"0.0.0.0:{port}"] - LOG.info(f"Agent server starting on {config.bind}") + LOG.info(f"Agent server starting on http://{config.bind[0]}") asyncio.run(serve(app, config)) async def create_task(self, task_request: TaskRequestBody) -> Task: @@ -64,10 +65,9 @@ class Agent: if task_request.additional_input else None, ) - LOG.info(task.json()) + return task except Exception as e: - return Response(status_code=500, content=str(e)) - return task + raise async def list_tasks(self, page: int = 1, pageSize: int = 10) -> TaskListResponse: """ @@ -76,22 +76,18 @@ class Agent: try: tasks, pagination = await self.db.list_tasks(page, pageSize) response = TaskListResponse(tasks=tasks, pagination=pagination) - return Response(content=response.json(), media_type="application/json") + return response except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) + raise async def get_task(self, task_id: str) -> Task: """ Get a task by ID. """ - if not task_id: - return Response(status_code=400, content="Task ID is required.") - if not isinstance(task_id, str): - return Response(status_code=400, content="Task ID must be a string.") try: task = await self.db.get_task(task_id) except Exception as e: - return Response(status_code=500, content=str(e)) + raise return task async def list_steps( @@ -100,16 +96,12 @@ class Agent: """ List the IDs of all steps that the task has created. """ - if not task_id: - return Response(status_code=400, content="Task ID is required.") - if not isinstance(task_id, str): - return Response(status_code=400, content="Task ID must be a string.") try: steps, pagination = await self.db.list_steps(task_id, page, pageSize) response = TaskStepsListResponse(steps=steps, pagination=pagination) - return Response(content=response.json(), media_type="application/json") + return response except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) + raise async def create_and_execute_step( self, task_id: str, step_request: StepRequestBody @@ -144,37 +136,33 @@ class Agent: steps, steps_pagination = await self.db.list_steps( task_id, page=1, per_page=100 ) - artifacts, artifacts_pagination = await self.db.list_artifacts( - task_id, page=1, per_page=100 - ) - step = steps[-1] - step.artifacts = artifacts - step.output = "No more steps to run." - # The step is the last step on this page so checking if this is the - # last page is sufficent to know if it is the last step - step.is_last = steps_pagination.current_page == steps_pagination.total_pages + # Find the latest step that has not been completed + step = next((s for s in reversed(steps) if s.status != "completed"), None) + if step is None: + # If all steps have been completed, create a new placeholder step + step = await self.db.create_step( + task_id=task_id, + input="y", + additional_properties=None, + ) + step.status = "completed" + step.is_last = True + step.output = "No more steps to run." + step = await self.db.update_step(step) if isinstance(step.status, Status): step.status = step.status.value step.output = "Done some work" - return JSONResponse(content=step.dict(), status_code=200) + return step async def get_step(self, task_id: str, step_id: str) -> Step: """ Get a step by ID. """ - if not task_id or not step_id: - return Response( - status_code=400, content="Task ID and step ID are required." - ) - if not isinstance(task_id, str) or not isinstance(step_id, str): - return Response( - status_code=400, content="Task ID and step ID must be strings." - ) try: step = await self.db.get_step(task_id, step_id) + return step except Exception as e: - return Response(status_code=500, content=str(e)) - return step + raise async def list_artifacts( self, task_id: str, page: int = 1, pageSize: int = 10 @@ -182,18 +170,16 @@ class Agent: """ List the artifacts that the task has created. """ - if not task_id: - return Response(status_code=400, content="Task ID is required.") - if not isinstance(task_id, str): - return Response(status_code=400, content="Task ID must be a string.") try: artifacts, pagination = await self.db.list_artifacts( task_id, page, pageSize ) + response = TaskArtifactsListResponse( + artifacts=artifacts, pagination=pagination + ) + return Response(content=response.json(), media_type="application/json") except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - response = TaskArtifactsListResponse(artifacts=artifacts, pagination=pagination) - return Response(content=response.json(), media_type="application/json") + raise async def create_artifact( self, @@ -204,8 +190,6 @@ class Agent: """ Create an artifact for the task. """ - if not file and not uri: - return Response(status_code=400, content="No file or uri provided") data = None if not uri: file_name = file.filename or str(uuid4()) @@ -214,13 +198,13 @@ class Agent: while contents := file.file.read(1024 * 1024): data += contents except Exception as e: - return Response(status_code=500, content=str(e)) + raise else: try: data = await load_from_uri(uri, task_id) file_name = uri.split("/")[-1] except Exception as e: - return Response(status_code=500, content=str(e)) + raise file_path = os.path.join(task_id / file_name) self.write(file_path, data) @@ -241,18 +225,16 @@ class Agent: """ try: artifact = await self.db.get_artifact(task_id, artifact_id) - except Exception as e: - return Response(status_code=500, content=str(e)) - try: retrieved_artifact = await self.load_from_uri(artifact.uri, artifact_id) - except Exception as e: - return Response(status_code=500, content=str(e)) - path = artifact.file_name - try: + path = artifact.file_name with open(path, "wb") as f: f.write(retrieved_artifact) + except NotFoundError as e: + raise + except FileNotFoundError as e: + raise except Exception as e: - return Response(status_code=500, content=str(e)) + raise return FileResponse( # Note: mimetype is guessed in the FileResponse constructor path=path, diff --git a/autogpt/agent_test.py b/autogpt/agent_test.py new file mode 100644 index 00000000..562f20bf --- /dev/null +++ b/autogpt/agent_test.py @@ -0,0 +1,104 @@ +import pytest + +from .agent import Agent +from .db import AgentDB +from .schema import StepRequestBody, Task, TaskListResponse, TaskRequestBody +from .workspace import LocalWorkspace + + +@pytest.fixture +def agent(): + db = AgentDB("sqlite:///test.db") + workspace = LocalWorkspace("./test_workspace") + return Agent(db, workspace) + + +@pytest.mark.asyncio +async def test_create_task(agent): + task_request = TaskRequestBody( + input="test_input", additional_input="additional_test_input" + ) + task: Task = await agent.create_task(task_request) + assert task.input == "test_input" + + +@pytest.mark.asyncio +async def test_list_tasks(agent): + task_request = TaskRequestBody( + input="test_input", additional_input="additional_test_input" + ) + task = await agent.create_task(task_request) + tasks = await agent.list_tasks() + assert isinstance(tasks, TaskListResponse) + + +@pytest.mark.asyncio +async def test_get_task(agent): + task_request = TaskRequestBody( + input="test_input", additional_input="additional_test_input" + ) + task = await agent.create_task(task_request) + retrieved_task = await agent.get_task(task.task_id) + assert retrieved_task.task_id == task.task_id + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_create_and_execute_step(agent): + task_request = TaskRequestBody( + input="test_input", additional_input="additional_test_input" + ) + task = await agent.create_task(task_request) + step_request = StepRequestBody( + input="step_input", additional_input="additional_step_input" + ) + step = await agent.create_and_execute_step(task.task_id, step_request) + assert step.input == "step_input" + assert step.additional_input == "additional_step_input" + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_get_step(agent): + task_request = TaskRequestBody( + input="test_input", additional_input="additional_test_input" + ) + task = await agent.create_task(task_request) + step_request = StepRequestBody( + input="step_input", additional_input="additional_step_input" + ) + step = await agent.create_and_execute_step(task.task_id, step_request) + retrieved_step = await agent.get_step(task.task_id, step.step_id) + assert retrieved_step.step_id == step.step_id + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_list_artifacts(agent): + artifacts = await agent.list_artifacts() + assert isinstance(artifacts, list) + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_create_artifact(agent): + task_request = TaskRequestBody( + input="test_input", additional_input="additional_test_input" + ) + task = await agent.create_task(task_request) + artifact_request = ArtifactRequestBody(file=None, uri="test_uri") + artifact = await agent.create_artifact(task.task_id, artifact_request) + assert artifact.uri == "test_uri" + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_get_artifact(agent): + task_request = TaskRequestBody( + input="test_input", additional_input="additional_test_input" + ) + task = await agent.create_task(task_request) + artifact_request = ArtifactRequestBody(file=None, uri="test_uri") + artifact = await agent.create_artifact(task.task_id, artifact_request) + retrieved_artifact = await agent.get_artifact(task.task_id, artifact.artifact_id) + assert retrieved_artifact.artifact_id == artifact.artifact_id diff --git a/autogpt/db.py b/autogpt/db.py index bbd9cfc5..76d93ca4 100644 --- a/autogpt/db.py +++ b/autogpt/db.py @@ -4,12 +4,16 @@ It uses SQLite as the database and file store backend. IT IS NOT ADVISED TO USE THIS IN PRODUCTION! """ +import datetime import math +import uuid from typing import Dict, List, Optional, Tuple -from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, create_engine +from sqlalchemy import Boolean, Column, DateTime, ForeignKey, String, create_engine +from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import DeclarativeBase, joinedload, relationship, sessionmaker +from .errors import NotFoundError from .forge_log import CustomLogger from .schema import Artifact, Pagination, Status, Step, Task, TaskInput @@ -20,16 +24,16 @@ class Base(DeclarativeBase): pass -class DataNotFoundError(Exception): - pass - - class TaskModel(Base): __tablename__ = "tasks" - task_id = Column(Integer, primary_key=True, autoincrement=True) + task_id = Column(String, primary_key=True, index=True) input = Column(String) - additional_input = Dict + additional_input = Column(String) + created_at = Column(DateTime, default=datetime.datetime.utcnow) + modified_at = Column( + DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow + ) artifacts = relationship("ArtifactModel", back_populates="task") @@ -37,12 +41,16 @@ class TaskModel(Base): class StepModel(Base): __tablename__ = "steps" - step_id = Column(Integer, primary_key=True, autoincrement=True) - task_id = Column(Integer, ForeignKey("tasks.task_id")) + step_id = Column(String, primary_key=True, index=True) + task_id = Column(String, ForeignKey("tasks.task_id")) name = Column(String) input = Column(String) status = Column(String) is_last = Column(Boolean, default=False) + created_at = Column(DateTime, default=datetime.datetime.utcnow) + modified_at = Column( + DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow + ) additional_properties = Column(String) artifacts = relationship("ArtifactModel", back_populates="step") @@ -51,12 +59,16 @@ class StepModel(Base): class ArtifactModel(Base): __tablename__ = "artifacts" - artifact_id = Column(Integer, primary_key=True, autoincrement=True) - task_id = Column(Integer, ForeignKey("tasks.task_id")) - step_id = Column(Integer, ForeignKey("steps.step_id")) + artifact_id = Column(String, primary_key=True, index=True) + task_id = Column(String, ForeignKey("tasks.task_id")) + step_id = Column(String, ForeignKey("steps.step_id")) agent_created = Column(Boolean, default=False) file_name = Column(String) uri = Column(String) + created_at = Column(DateTime, default=datetime.datetime.utcnow) + modified_at = Column( + DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow + ) step = relationship("StepModel", back_populates="artifacts") task = relationship("TaskModel", back_populates="artifacts") @@ -65,18 +77,11 @@ class ArtifactModel(Base): def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False) -> Task: if debug_enabled: LOG.debug(f"Converting TaskModel to Task for task_id: {task_obj.task_id}") - task_artifacts = [ - Artifact( - artifact_id=artifact.artifact_id, - file_name=artifact.file_name, - agent_created=artifact.agent_created, - uri=artifact.uri, - ) - for artifact in task_obj.artifacts - if artifact.task_id == task_obj.task_id - ] + task_artifacts = [convert_to_artifact(artifact) for artifact in task_obj.artifacts] return Task( task_id=task_obj.task_id, + created_at=task_obj.created_at, + modified_at=task_obj.modified_at, input=task_obj.input, additional_input=task_obj.additional_input, artifacts=task_artifacts, @@ -87,19 +92,14 @@ def convert_to_step(step_model: StepModel, debug_enabled: bool = False) -> Step: if debug_enabled: LOG.debug(f"Converting StepModel to Step for step_id: {step_model.step_id}") step_artifacts = [ - Artifact( - artifact_id=artifact.artifact_id, - file_name=artifact.file_name, - agent_created=artifact.agent_created, - uri=artifact.uri, - ) - for artifact in step_model.artifacts - if artifact.step_id == step_model.step_id + convert_to_artifact(artifact) for artifact in step_model.artifacts ] status = Status.completed if step_model.status == "completed" else Status.created return Step( task_id=step_model.task_id, step_id=step_model.step_id, + created_at=step_model.created_at, + modified_at=step_model.modified_at, name=step_model.name, input=step_model.input, status=status, @@ -109,6 +109,16 @@ def convert_to_step(step_model: StepModel, debug_enabled: bool = False) -> Step: ) +def convert_to_artifact(artifact_model: ArtifactModel) -> Artifact: + return Artifact( + artifact_id=artifact_model.artifact_id, + created_at=artifact_model.created_at, + modified_at=artifact_model.modified_at, + agent_created=artifact_model.agent_created, + uri=artifact_model.uri, + ) + + # sqlite:///{database_name} class AgentDB: def __init__(self, database_string, debug_enabled: bool = False) -> None: @@ -125,19 +135,30 @@ class AgentDB: ) -> Task: if self.debug_enabled: LOG.debug("Creating new task") - with self.Session() as session: - new_task = TaskModel( - input=input, - additional_input=additional_input.__root__ - if additional_input - else None, - ) - session.add(new_task) - session.commit() - session.refresh(new_task) - if self.debug_enabled: - LOG.debug(f"Created new task with task_id: {new_task.task_id}") - return convert_to_task(new_task, self.debug_enabled) + + try: + with self.Session() as session: + new_task = TaskModel( + task_id=str(uuid.uuid4()), + input=input, + additional_input=additional_input.__root__ + if additional_input + else None, + ) + session.add(new_task) + session.commit() + session.refresh(new_task) + if self.debug_enabled: + LOG.debug(f"Created new task with task_id: {new_task.task_id}") + return convert_to_task(new_task, self.debug_enabled) + except SQLAlchemyError as e: + LOG.error(f"SQLAlchemy error while creating task: {e}") + raise + except NotFoundError as e: + raise + except Exception as e: + LOG.error(f"Unexpected error while creating task: {e}") + raise async def create_step( self, @@ -149,23 +170,30 @@ class AgentDB: if self.debug_enabled: LOG.debug(f"Creating new step for task_id: {task_id}") try: - session = self.Session() - new_step = StepModel( - task_id=task_id, - name=input, - input=input, - status="created", - is_last=is_last, - additional_properties=additional_properties, - ) - session.add(new_step) - session.commit() - session.refresh(new_step) - if self.debug_enabled: - LOG.debug(f"Created new step with step_id: {new_step.step_id}") + with self.Session() as session: + new_step = StepModel( + task_id=task_id, + step_id=str(uuid.uuid4()), + name=input, + input=input, + status="created", + is_last=is_last, + additional_properties=additional_properties, + ) + session.add(new_step) + session.commit() + session.refresh(new_step) + if self.debug_enabled: + LOG.debug(f"Created new step with step_id: {new_step.step_id}") + return convert_to_step(new_step, self.debug_enabled) + except SQLAlchemyError as e: + LOG.error(f"SQLAlchemy error while creating step: {e}") + raise + except NotFoundError as e: + raise except Exception as e: - LOG.error(f"Error while creating step: {e}") - return convert_to_step(new_step, self.debug_enabled) + LOG.error(f"Unexpected error while creating step: {e}") + raise async def create_artifact( self, @@ -177,71 +205,94 @@ class AgentDB: ) -> Artifact: if self.debug_enabled: LOG.debug(f"Creating new artifact for task_id: {task_id}") - session = self.Session() + try: + with self.Session() as session: + if ( + existing_artifact := session.query(ArtifactModel) + .filter_by(uri=uri) + .first() + ): + session.close() + if self.debug_enabled: + LOG.debug(f"Artifact already exists with uri: {uri}") + return convert_to_artifact(existing_artifact) - if existing_artifact := session.query(ArtifactModel).filter_by(uri=uri).first(): - session.close() - if self.debug_enabled: - LOG.debug(f"Artifact already exists with uri: {uri}") - return Artifact( - artifact_id=str(existing_artifact.artifact_id), - file_name=existing_artifact.file_name, - agent_created=existing_artifact.agent_created, - uri=existing_artifact.uri, - ) - - new_artifact = ArtifactModel( - task_id=task_id, - step_id=step_id, - agent_created=agent_created, - file_name=file_name, - uri=uri, - ) - session.add(new_artifact) - session.commit() - session.refresh(new_artifact) - if self.debug_enabled: - LOG.debug( - f"Created new artifact with artifact_id: {new_artifact.artifact_id}" - ) - return Artifact( - artifact_id=str(new_artifact.artifact_id), - file_name=new_artifact.file_name, - agent_created=new_artifact.agent_created, - uri=new_artifact.uri, - ) + new_artifact = ArtifactModel( + artifact_id=str(uuid.uuid4()), + task_id=task_id, + step_id=step_id, + agent_created=agent_created, + file_name=file_name, + uri=uri, + ) + session.add(new_artifact) + session.commit() + session.refresh(new_artifact) + if self.debug_enabled: + LOG.debug( + f"Created new artifact with artifact_id: {new_artifact.artifact_id}" + ) + return convert_to_artifact(new_artifact) + except SQLAlchemyError as e: + LOG.error(f"SQLAlchemy error while creating step: {e}") + raise + except NotFoundError as e: + raise + except Exception as e: + LOG.error(f"Unexpected error while creating step: {e}") + raise async def get_task(self, task_id: int) -> Task: """Get a task by its id""" if self.debug_enabled: LOG.debug(f"Getting task with task_id: {task_id}") - session = self.Session() - if task_obj := ( - session.query(TaskModel) - .options(joinedload(TaskModel.artifacts)) - .filter_by(task_id=task_id) - .first() - ): - return convert_to_task(task_obj, self.debug_enabled) - else: - LOG.error(f"Task not found with task_id: {task_id}") - raise DataNotFoundError("Task not found") + try: + with self.Session() as session: + if task_obj := ( + session.query(TaskModel) + .options(joinedload(TaskModel.artifacts)) + .filter_by(task_id=task_id) + .first() + ): + return convert_to_task(task_obj, self.debug_enabled) + else: + LOG.error(f"Task not found with task_id: {task_id}") + raise NotFoundError("Task not found") + except SQLAlchemyError as e: + LOG.error(f"SQLAlchemy error while getting task: {e}") + raise + except NotFoundError as e: + raise + except Exception as e: + LOG.error(f"Unexpected error while getting task: {e}") + raise async def get_step(self, task_id: int, step_id: int) -> Step: if self.debug_enabled: LOG.debug(f"Getting step with task_id: {task_id} and step_id: {step_id}") - session = self.Session() - if step := ( - session.query(StepModel) - .options(joinedload(StepModel.artifacts)) - .filter(StepModel.step_id == step_id) - .first() - ): - return convert_to_step(step, self.debug_enabled) + try: + with self.Session() as session: + if step := ( + session.query(StepModel) + .options(joinedload(StepModel.artifacts)) + .filter(StepModel.step_id == step_id) + .first() + ): + return convert_to_step(step, self.debug_enabled) - else: - LOG.error(f"Step not found with task_id: {task_id} and step_id: {step_id}") - raise DataNotFoundError("Step not found") + else: + LOG.error( + f"Step not found with task_id: {task_id} and step_id: {step_id}" + ) + raise NotFoundError("Step not found") + except SQLAlchemyError as e: + LOG.error(f"SQLAlchemy error while getting step: {e}") + raise + except NotFoundError as e: + raise + except Exception as e: + LOG.error(f"Unexpected error while getting step: {e}") + raise async def update_step( self, @@ -252,108 +303,157 @@ class AgentDB: ) -> Step: if self.debug_enabled: LOG.debug(f"Updating step with task_id: {task_id} and step_id: {step_id}") - session = self.Session() - if ( - step := session.query(StepModel) - .filter_by(task_id=task_id, step_id=step_id) - .first() - ): - step.status = status - step.additional_properties = additional_properties - session.commit() - return await self.get_step(task_id, step_id) - else: - LOG.error( - f"Step not found for update with task_id: {task_id} and step_id: {step_id}" - ) - raise DataNotFoundError("Step not found") + try: + with self.Session() as session: + if ( + step := session.query(StepModel) + .filter_by(task_id=task_id, step_id=step_id) + .first() + ): + step.status = status + step.additional_properties = additional_properties + session.commit() + return await self.get_step(task_id, step_id) + else: + LOG.error( + f"Step not found for update with task_id: {task_id} and step_id: {step_id}" + ) + raise NotFoundError("Step not found") + except SQLAlchemyError as e: + LOG.error(f"SQLAlchemy error while getting step: {e}") + raise + except NotFoundError as e: + raise + except Exception as e: + LOG.error(f"Unexpected error while getting step: {e}") + raise - async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact: + async def get_artifact(self, artifact_id: str) -> Artifact: if self.debug_enabled: - LOG.debug( - f"Getting artifact with task_id: {task_id} and artifact_id: {artifact_id}" - ) - session = self.Session() - if artifact_model := ( - session.query(ArtifactModel) - .filter_by(task_id=int(task_id), artifact_id=int(artifact_id)) - .first() - ): - return Artifact( - artifact_id=artifact_model.artifact_id, # Casting to string - file_name=artifact_model.file_name, - agent_created=artifact_model.agent_created, - uri=artifact_model.uri, - ) - else: - LOG.error( - f"Artifact not found with task_id: {task_id} and artifact_id: {artifact_id}" - ) - raise DataNotFoundError("Artifact not found") + LOG.debug(f"Getting artifact with and artifact_id: {artifact_id}") + try: + with self.Session() as session: + if ( + artifact_model := session.query(ArtifactModel) + .filter_by(artifact_id=artifact_id) + .first() + ): + return convert_to_artifact(artifact_model) + else: + LOG.error(f"Artifact not found with and artifact_id: {artifact_id}") + raise NotFoundError("Artifact not found") + except SQLAlchemyError as e: + LOG.error(f"SQLAlchemy error while getting artifact: {e}") + raise + except NotFoundError as e: + raise + except Exception as e: + LOG.error(f"Unexpected error while getting artifact: {e}") + raise async def list_tasks( self, page: int = 1, per_page: int = 10 ) -> Tuple[List[Task], Pagination]: if self.debug_enabled: LOG.debug("Listing tasks") - session = self.Session() - tasks = ( - session.query(TaskModel).offset((page - 1) * per_page).limit(per_page).all() - ) - total = session.query(TaskModel).count() - pages = math.ceil(total / per_page) - pagination = Pagination( - total_items=total, total_pages=pages, current_page=page, page_size=per_page - ) - return [convert_to_task(task, self.debug_enabled) for task in tasks], pagination + try: + with self.Session() as session: + tasks = ( + session.query(TaskModel) + .offset((page - 1) * per_page) + .limit(per_page) + .all() + ) + if not tasks: + raise NotFoundError("No tasks found") + total = session.query(TaskModel).count() + pages = math.ceil(total / per_page) + pagination = Pagination( + total_items=total, + total_pages=pages, + current_page=page, + page_size=per_page, + ) + return [ + convert_to_task(task, self.debug_enabled) for task in tasks + ], pagination + except SQLAlchemyError as e: + LOG.error(f"SQLAlchemy error while listing tasks: {e}") + raise + except NotFoundError as e: + raise + except Exception as e: + LOG.error(f"Unexpected error while listing tasks: {e}") + raise async def list_steps( self, task_id: str, page: int = 1, per_page: int = 10 ) -> Tuple[List[Step], Pagination]: if self.debug_enabled: LOG.debug(f"Listing steps for task_id: {task_id}") - session = self.Session() - steps = ( - session.query(StepModel) - .filter_by(task_id=task_id) - .offset((page - 1) * per_page) - .limit(per_page) - .all() - ) - total = session.query(StepModel).filter_by(task_id=task_id).count() - pages = math.ceil(total / per_page) - pagination = Pagination( - total_items=total, total_pages=pages, current_page=page, page_size=per_page - ) - return [convert_to_step(step, self.debug_enabled) for step in steps], pagination + try: + with self.Session() as session: + steps = ( + session.query(StepModel) + .filter_by(task_id=task_id) + .offset((page - 1) * per_page) + .limit(per_page) + .all() + ) + if not steps: + raise NotFoundError("No steps found") + total = session.query(StepModel).filter_by(task_id=task_id).count() + pages = math.ceil(total / per_page) + pagination = Pagination( + total_items=total, + total_pages=pages, + current_page=page, + page_size=per_page, + ) + return [ + convert_to_step(step, self.debug_enabled) for step in steps + ], pagination + except SQLAlchemyError as e: + LOG.error(f"SQLAlchemy error while listing steps: {e}") + raise + except NotFoundError as e: + raise + except Exception as e: + LOG.error(f"Unexpected error while listing steps: {e}") + raise async def list_artifacts( self, task_id: str, page: int = 1, per_page: int = 10 ) -> Tuple[List[Artifact], Pagination]: if self.debug_enabled: LOG.debug(f"Listing artifacts for task_id: {task_id}") - with self.Session() as session: - artifacts = ( - session.query(ArtifactModel) - .filter_by(task_id=task_id) - .offset((page - 1) * per_page) - .limit(per_page) - .all() - ) - total = session.query(ArtifactModel).filter_by(task_id=task_id).count() - pages = math.ceil(total / per_page) - pagination = Pagination( - total_items=total, - total_pages=pages, - current_page=page, - page_size=per_page, - ) - return [ - Artifact( - artifact_id=str(artifact.artifact_id), - file_name=artifact.file_name, - agent_created=artifact.agent_created, - uri=artifact.uri, + try: + with self.Session() as session: + artifacts = ( + session.query(ArtifactModel) + .filter_by(task_id=task_id) + .offset((page - 1) * per_page) + .limit(per_page) + .all() ) - for artifact in artifacts - ], pagination + if not artifacts: + raise NotFoundError("No artifacts found") + total = session.query(ArtifactModel).filter_by(task_id=task_id).count() + pages = math.ceil(total / per_page) + pagination = Pagination( + total_items=total, + total_pages=pages, + current_page=page, + page_size=per_page, + ) + return [ + convert_to_artifact(artifact) for artifact in artifacts + ], pagination + except SQLAlchemyError as e: + LOG.error(f"SQLAlchemy error while listing artifacts: {e}") + raise + except NotFoundError as e: + raise + except Exception as e: + LOG.error(f"Unexpected error while listing artifacts: {e}") + raise diff --git a/autogpt/db_test.py b/autogpt/db_test.py index 19de697c..6c59c2a0 100644 --- a/autogpt/db_test.py +++ b/autogpt/db_test.py @@ -1,11 +1,23 @@ import os import sqlite3 +from datetime import datetime import pytest -from autogpt.db import AgentDB, DataNotFoundError +from autogpt.db import ( + AgentDB, + ArtifactModel, + StepModel, + TaskModel, + convert_to_artifact, + convert_to_step, + convert_to_task, +) +from autogpt.errors import NotFoundError as DataNotFoundError +from autogpt.schema import * +@pytest.mark.asyncio def test_table_creation(): db_name = "sqlite:///test_db.sqlite3" agent_db = AgentDB(db_name) @@ -30,6 +42,143 @@ def test_table_creation(): os.remove(db_name.split("///")[1]) +@pytest.mark.asyncio +async def test_task_schema(): + now = datetime.now() + task = Task( + task_id="50da533e-3904-4401-8a07-c49adf88b5eb", + input="Write the words you receive to the file 'output.txt'.", + created_at=now, + modified_at=now, + artifacts=[ + Artifact( + artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56", + agent_created=True, + uri="file:///path/to/artifact", + file_name="main.py", + relative_path="python/code/", + created_at=now, + modified_at=now, + ) + ], + ) + assert task.task_id == "50da533e-3904-4401-8a07-c49adf88b5eb" + assert task.input == "Write the words you receive to the file 'output.txt'." + assert len(task.artifacts) == 1 + assert task.artifacts[0].artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56" + + +@pytest.mark.asyncio +async def test_step_schema(): + now = datetime.now() + step = Step( + task_id="50da533e-3904-4401-8a07-c49adf88b5eb", + step_id="6bb1801a-fd80-45e8-899a-4dd723cc602e", + created_at=now, + modified_at=now, + name="Write to file", + input="Write the words you receive to the file 'output.txt'.", + status=Status.created, + output="I am going to use the write_to_file command and write Washington to a file called output.txt ", + artifacts=[ + Artifact( + artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56", + file_name="main.py", + relative_path="python/code/", + created_at=now, + modified_at=now, + agent_created=True, + uri="file:///path/to/artifact", + ) + ], + is_last=False, + ) + assert step.task_id == "50da533e-3904-4401-8a07-c49adf88b5eb" + assert step.step_id == "6bb1801a-fd80-45e8-899a-4dd723cc602e" + assert step.name == "Write to file" + assert step.status == Status.created + assert ( + step.output + == "I am going to use the write_to_file command and write Washington to a file called output.txt " + ) + assert len(step.artifacts) == 1 + assert step.artifacts[0].artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56" + assert step.is_last == False + + +@pytest.mark.asyncio +async def test_convert_to_task(): + now = datetime.now() + task_model = TaskModel( + task_id="50da533e-3904-4401-8a07-c49adf88b5eb", + created_at=now, + modified_at=now, + input="Write the words you receive to the file 'output.txt'.", + artifacts=[ + ArtifactModel( + artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56", + created_at=now, + modified_at=now, + uri="file:///path/to/main.py", + agent_created=True, + ) + ], + ) + task = convert_to_task(task_model) + assert task.task_id == "50da533e-3904-4401-8a07-c49adf88b5eb" + assert task.input == "Write the words you receive to the file 'output.txt'." + assert len(task.artifacts) == 1 + assert task.artifacts[0].artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56" + + +@pytest.mark.asyncio +async def test_convert_to_step(): + now = datetime.now() + step_model = StepModel( + task_id="50da533e-3904-4401-8a07-c49adf88b5eb", + step_id="6bb1801a-fd80-45e8-899a-4dd723cc602e", + created_at=now, + modified_at=now, + name="Write to file", + status="created", + input="Write the words you receive to the file 'output.txt'.", + artifacts=[ + ArtifactModel( + artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56", + created_at=now, + modified_at=now, + uri="file:///path/to/main.py", + agent_created=True, + ) + ], + is_last=False, + ) + step = convert_to_step(step_model) + assert step.task_id == "50da533e-3904-4401-8a07-c49adf88b5eb" + assert step.step_id == "6bb1801a-fd80-45e8-899a-4dd723cc602e" + assert step.name == "Write to file" + assert step.status == Status.created + assert len(step.artifacts) == 1 + assert step.artifacts[0].artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56" + assert step.is_last == False + + +@pytest.mark.asyncio +async def test_convert_to_artifact(): + now = datetime.now() + artifact_model = ArtifactModel( + artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56", + created_at=now, + modified_at=now, + uri="file:///path/to/main.py", + agent_created=True, + ) + artifact = convert_to_artifact(artifact_model) + assert artifact.artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56" + assert artifact.uri == "file:///path/to/main.py" + assert artifact.agent_created == True + + @pytest.mark.asyncio async def test_create_task(): # Having issues with pytest fixture so added setup and teardown in each test as a rapid workaround @@ -52,17 +201,6 @@ async def test_create_and_get_task(): os.remove(db_name.split("///")[1]) -@pytest.mark.asyncio -async def test_create_and_get_task_with_steps(): - db_name = "sqlite:///test_db.sqlite3" - agent_db = AgentDB(db_name) - await agent_db.create_task("task_input") - task = await agent_db.get_task(1) - - assert task.input == "task_input" - os.remove(db_name.split("///")[1]) - - @pytest.mark.asyncio async def test_get_task_not_found(): db_name = "sqlite:///test_db.sqlite3" @@ -87,11 +225,11 @@ async def test_create_and_get_step(): async def test_updating_step(): db_name = "sqlite:///test_db.sqlite3" agent_db = AgentDB(db_name) - await agent_db.create_task("task_input") - await agent_db.create_step(1, "step_name") - await agent_db.update_step(1, 1, "completed") + created_task = await agent_db.create_task("task_input") + created_step = await agent_db.create_step(created_task.task_id, "step_name") + await agent_db.update_step(created_task.task_id, created_step.step_id, "completed") - step = await agent_db.get_step(1, 1) + step = await agent_db.get_step(created_task.task_id, created_step.step_id) assert step.status.value == "completed" os.remove(db_name.split("///")[1]) @@ -117,18 +255,18 @@ async def test_get_artifact(): # Create an artifact artifact = await db.create_artifact( task_id=task.task_id, - file_name="sample_file.txt", - uri="file:///path/to/sample_file.txt", + file_name="test_get_artifact_sample_file.txt", + uri="file:///path/to/test_get_artifact_sample_file.txt", agent_created=True, step_id=step.step_id, ) # When: The artifact is fetched by its ID - fetched_artifact = await db.get_artifact(task.task_id, artifact.artifact_id) + fetched_artifact = await db.get_artifact(artifact.artifact_id) # Then: The fetched artifact matches the original assert fetched_artifact.artifact_id == artifact.artifact_id - assert fetched_artifact.uri == "file:///path/to/sample_file.txt" + assert fetched_artifact.uri == "file:///path/to/test_get_artifact_sample_file.txt" os.remove(db_name.split("///")[1]) diff --git a/autogpt/errors.py b/autogpt/errors.py new file mode 100644 index 00000000..c901a5c9 --- /dev/null +++ b/autogpt/errors.py @@ -0,0 +1,2 @@ +class NotFoundError(Exception): + pass diff --git a/autogpt/routes/agent_protocol.py b/autogpt/routes/agent_protocol.py index ff86282b..28a5bfbf 100644 --- a/autogpt/routes/agent_protocol.py +++ b/autogpt/routes/agent_protocol.py @@ -22,16 +22,37 @@ the ones that require special attention due to their complexity are: Developers and contributors should be especially careful when making modifications to these routes to ensure consistency and correctness in the system's behavior. """ +import json from typing import Optional from fastapi import APIRouter, Query, Request, Response, UploadFile from fastapi.responses import FileResponse +from autogpt.errors import * +from autogpt.forge_log import CustomLogger from autogpt.schema import * from autogpt.tracing import tracing base_router = APIRouter() +LOG = CustomLogger(__name__) + + +@base_router.get("/", tags=["root"]) +async def root(): + """ + Root endpoint that returns a welcome message. + """ + return Response(content="Welcome to the Auto-GPT Forge") + + +@base_router.get("/heartbeat", tags=["server"]) +async def check_server_status(): + """ + Check if the server is running. + """ + return Response(content="Server is running.", status_code=200) + @base_router.get("/", tags=["root"]) async def root(): @@ -75,15 +96,25 @@ async def create_agent_task(request: Request, task_request: TaskRequestBody) -> "input": "Write the word 'Washington' to a .txt file", "additional_input": "python/code", "artifacts": [], - "steps": [] } """ agent = request["agent"] - if task_request := await agent.create_task(task_request): - return task_request - else: - return Response(content={"error": "Task creation failed"}, status_code=400) + try: + task_request = await agent.create_task(task_request) + return Response(content=task_request.json(), status_code=200) + except NotFoundError: + return Response( + content=json.dumps({"error": "Task not found"}), + status_code=404, + media_type="application/json", + ) + except Exception: + return Response( + content=json.dumps({"error": "Internal server error"}), + status_code=500, + media_type="application/json", + ) @base_router.get("/agent/tasks", tags=["agent"], response_model=TaskListResponse) @@ -128,7 +159,21 @@ async def list_agent_tasks( } """ agent = request["agent"] - return await agent.list_tasks(page, page_size) + try: + tasks = await agent.list_tasks(page, page_size) + return Response(content=tasks.json(), status_code=200) + except NotFoundError: + return Response( + content=json.dumps({"error": "Task not found"}), + status_code=404, + media_type="application/json", + ) + except Exception: + return Response( + content=json.dumps({"error": "Internal server error"}), + status_code=500, + media_type="application/json", + ) @base_router.get("/agent/tasks/{task_id}", tags=["agent"], response_model=Task) @@ -185,11 +230,21 @@ async def get_agent_task(request: Request, task_id: str) -> Task: } """ agent = request["agent"] - task = await agent.get_task(task_id) - if task: - return task - else: - return Response(content={"error": "Task not found"}, status_code=404) + try: + task = await agent.get_task(task_id) + return Response(content=task.json(), status_code=200) + except NotFoundError: + return Response( + content=json.dumps({"error": "Task not found"}), + status_code=404, + media_type="application/json", + ) + except Exception: + return Response( + content=json.dumps({"error": "Internal server error"}), + status_code=500, + media_type="application/json", + ) @base_router.get( @@ -236,7 +291,21 @@ async def list_agent_task_steps( } """ agent = request["agent"] - return await agent.list_steps(task_id, page, page_size) + try: + steps = await agent.list_steps(task_id, page, page_size) + return Response(content=steps.json(), status_code=200) + except NotFoundError: + return Response( + content=json.dumps({"error": "Task not found"}), + status_code=404, + media_type="application/json", + ) + except Exception: + return Response( + content=json.dumps({"error": "Internal server error"}), + status_code=500, + media_type="application/json", + ) @base_router.post("/agent/tasks/{task_id}/steps", tags=["agent"], response_model=Step) @@ -284,7 +353,22 @@ async def execute_agent_task_step( } """ agent = request["agent"] - return await agent.create_and_execute_step(task_id, step) + try: + step = await agent.create_and_execute_step(task_id, step) + return Response(content=step.json(), status_code=200) + except NotFoundError: + return Response( + content=json.dumps({"error": f"Task not found {task_id}"}), + status_code=404, + media_type="application/json", + ) + except Exception as e: + LOG.exception("Error whilst trying to execute a test") + return Response( + content=json.dumps({"error": "Internal server error"}), + status_code=500, + media_type="application/json", + ) @base_router.get( @@ -315,7 +399,21 @@ async def get_agent_task_step(request: Request, task_id: str, step_id: str) -> S } """ agent = request["agent"] - return await agent.get_step(task_id, step_id) + try: + step = await agent.get_step(task_id, step_id) + return Response(content=step.json(), status_code=200) + except NotFoundError: + return Response( + content=json.dumps({"error": "Task not found"}), + status_code=404, + media_type="application/json", + ) + except Exception: + return Response( + content=json.dumps({"error": "Internal server error"}), + status_code=500, + media_type="application/json", + ) @base_router.get( @@ -362,7 +460,21 @@ async def list_agent_task_artifacts( } """ agent = request["agent"] - return await agent.list_artifacts(task_id, page, page_size) + try: + artifacts = await agent.list_artifacts(task_id, page, page_size) + return Response(content=artifacts.json(), status_code=200) + except NotFoundError: + return Response( + content=json.dumps({"error": "Task not found"}), + status_code=404, + media_type="application/json", + ) + except Exception: + return Response( + content=json.dumps({"error": "Internal server error"}), + status_code=500, + media_type="application/json", + ) @base_router.post( @@ -394,7 +506,6 @@ async def upload_agent_task_artifacts( Note: Either `file` or `uri` must be provided. If both are provided, the behavior depends on the agent's implementation. If neither is provided, the function will return an error. - Example: Request: POST /agent/tasks/50da533e-3904-4401-8a07-c49adf88b5eb/artifacts @@ -409,7 +520,35 @@ async def upload_agent_task_artifacts( } """ agent = request["agent"] - return await agent.create_artifact(task_id, file, uri) + if file is None and uri is None: + return Response( + content={"error": "Either file or uri must be specified"}, status_code=404 + ) + if file is not None and uri is not None: + return Response( + content={"error": "Both file and uri cannot be specified at the same time"}, + status_code=404, + ) + if uri is not None and not uri.startswith(("http://", "https://", "file://")): + return Response( + content={"error": "URI must start with http, https or file"}, + status_code=404, + ) + try: + artifact = await agent.create_artifact(task_id, file, uri) + return Response(content=artifact.json(), status_code=200) + except NotFoundError: + return Response( + content=json.dumps({"error": "Task not found"}), + status_code=404, + media_type="application/json", + ) + except Exception: + return Response( + content=json.dumps({"error": "Internal server error"}), + status_code=500, + media_type="application/json", + ) @base_router.get( @@ -438,4 +577,25 @@ async def download_agent_task_artifact( """ agent = request["agent"] - return await agent.get_artifact(task_id, artifact_id) + try: + return await agent.get_artifact(task_id, artifact_id) + except NotFoundError: + return Response( + content=json.dumps( + { + "error": f"Artifact not found - task_id: {task_id}, artifact_id: {artifact_id}" + } + ), + status_code=404, + media_type="application/json", + ) + except Exception: + return Response( + content=json.dumps( + { + "error": f"Internal server error - task_id: {task_id}, artifact_id: {artifact_id}" + } + ), + status_code=500, + media_type="application/json", + ) diff --git a/autogpt/schema.py b/autogpt/schema.py index b69438b6..12aa3664 100644 --- a/autogpt/schema.py +++ b/autogpt/schema.py @@ -1,9 +1,10 @@ # generated by fastapi-codegen: -# filename: ../../openapi.yml -# timestamp: 2023-08-24T13:55:59+00:00 +# filename: ../../postman/schemas/openapi.yaml +# timestamp: 2023-08-25T10:36:11+00:00 from __future__ import annotations +from datetime import datetime from enum import Enum from typing import List, Optional @@ -22,6 +23,18 @@ class TaskInput(BaseModel): class Artifact(BaseModel): + created_at: datetime = Field( + ..., + description="The creation datetime of the task.", + example="2023-01-01T00:00:00Z", + json_encoders={datetime: lambda v: v.isoformat()}, + ) + modified_at: datetime = Field( + ..., + description="The modification datetime of the task.", + example="2023-01-01T00:00:00Z", + json_encoders={datetime: lambda v: v.isoformat()}, + ) artifact_id: str = Field( ..., description="ID of the artifact.", @@ -32,8 +45,8 @@ class Artifact(BaseModel): description="Whether the artifact has been created by the agent.", example=False, ) - uri: Optional[str] = Field( - None, + uri: str = Field( + ..., description="URI of the artifact.", example="file://home/bob/workspace/bucket/main.py", ) @@ -50,7 +63,6 @@ class StepOutput(BaseModel): class TaskRequestBody(BaseModel): input: str = Field( ..., - min_length=1, description="Input prompt for the task.", example="Write the words you receive to the file 'output.txt'.", ) @@ -58,6 +70,18 @@ class TaskRequestBody(BaseModel): class Task(TaskRequestBody): + created_at: datetime = Field( + ..., + description="The creation datetime of the task.", + example="2023-01-01T00:00:00Z", + json_encoders={datetime: lambda v: v.isoformat()}, + ) + modified_at: datetime = Field( + ..., + description="The modification datetime of the task.", + example="2023-01-01T00:00:00Z", + json_encoders={datetime: lambda v: v.isoformat()}, + ) task_id: str = Field( ..., description="The ID of the task.", @@ -74,6 +98,9 @@ class Task(TaskRequestBody): class StepRequestBody(BaseModel): + name: Optional[str] = Field( + None, description="The name of the task step.", example="Write to file" + ) input: str = Field( ..., description="Input prompt for the step.", example="Washington" ) @@ -87,6 +114,18 @@ class Status(Enum): class Step(StepRequestBody): + created_at: datetime = Field( + ..., + description="The creation datetime of the task.", + example="2023-01-01T00:00:00Z", + json_encoders={datetime: lambda v: v.isoformat()}, + ) + modified_at: datetime = Field( + ..., + description="The modification datetime of the task.", + example="2023-01-01T00:00:00Z", + json_encoders={datetime: lambda v: v.isoformat()}, + ) task_id: str = Field( ..., description="The ID of the task this step belongs to.", @@ -112,8 +151,8 @@ class Step(StepRequestBody): artifacts: Optional[List[Artifact]] = Field( [], description="A list of artifacts that the step has produced." ) - is_last: Optional[bool] = Field( - False, description="Whether this is the last step in the task.", example=True + is_last: bool = Field( + ..., description="Whether this is the last step in the task.", example=True )