mirror of
https://github.com/aljazceru/Auto-GPT.git
synced 2026-01-07 08:14:25 +01:00
Updating to version v0.4 of the Protocol (#21)
This commit is contained in:
@@ -27,7 +27,7 @@ if __name__ == "__main__":
|
||||
workspace = LocalWorkspace(os.getenv("AGENT_WORKSPACE"))
|
||||
port = os.getenv("PORT")
|
||||
|
||||
database = autogpt.db.AgentDB(database_name, debug_enabled=True)
|
||||
database = autogpt.db.AgentDB(database_name, debug_enabled=False)
|
||||
agent = autogpt.agent.Agent(database=database, workspace=workspace)
|
||||
|
||||
agent.start(port=port, router=router)
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from fastapi import APIRouter, FastAPI, HTTPException, Response, UploadFile
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
from fastapi import APIRouter, FastAPI, Response, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
from hypercorn.asyncio import serve
|
||||
from hypercorn.config import Config
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
|
||||
from .db import AgentDB
|
||||
from .errors import NotFoundError
|
||||
from .forge_log import CustomLogger
|
||||
from .middlewares import AgentMiddleware
|
||||
from .routes.agent_protocol import base_router
|
||||
@@ -50,7 +51,7 @@ class Agent:
|
||||
config.loglevel = "ERROR"
|
||||
config.bind = [f"0.0.0.0:{port}"]
|
||||
|
||||
LOG.info(f"Agent server starting on {config.bind}")
|
||||
LOG.info(f"Agent server starting on http://{config.bind[0]}")
|
||||
asyncio.run(serve(app, config))
|
||||
|
||||
async def create_task(self, task_request: TaskRequestBody) -> Task:
|
||||
@@ -64,10 +65,9 @@ class Agent:
|
||||
if task_request.additional_input
|
||||
else None,
|
||||
)
|
||||
LOG.info(task.json())
|
||||
return task
|
||||
except Exception as e:
|
||||
return Response(status_code=500, content=str(e))
|
||||
return task
|
||||
raise
|
||||
|
||||
async def list_tasks(self, page: int = 1, pageSize: int = 10) -> TaskListResponse:
|
||||
"""
|
||||
@@ -76,22 +76,18 @@ class Agent:
|
||||
try:
|
||||
tasks, pagination = await self.db.list_tasks(page, pageSize)
|
||||
response = TaskListResponse(tasks=tasks, pagination=pagination)
|
||||
return Response(content=response.json(), media_type="application/json")
|
||||
return response
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
raise
|
||||
|
||||
async def get_task(self, task_id: str) -> Task:
|
||||
"""
|
||||
Get a task by ID.
|
||||
"""
|
||||
if not task_id:
|
||||
return Response(status_code=400, content="Task ID is required.")
|
||||
if not isinstance(task_id, str):
|
||||
return Response(status_code=400, content="Task ID must be a string.")
|
||||
try:
|
||||
task = await self.db.get_task(task_id)
|
||||
except Exception as e:
|
||||
return Response(status_code=500, content=str(e))
|
||||
raise
|
||||
return task
|
||||
|
||||
async def list_steps(
|
||||
@@ -100,16 +96,12 @@ class Agent:
|
||||
"""
|
||||
List the IDs of all steps that the task has created.
|
||||
"""
|
||||
if not task_id:
|
||||
return Response(status_code=400, content="Task ID is required.")
|
||||
if not isinstance(task_id, str):
|
||||
return Response(status_code=400, content="Task ID must be a string.")
|
||||
try:
|
||||
steps, pagination = await self.db.list_steps(task_id, page, pageSize)
|
||||
response = TaskStepsListResponse(steps=steps, pagination=pagination)
|
||||
return Response(content=response.json(), media_type="application/json")
|
||||
return response
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
raise
|
||||
|
||||
async def create_and_execute_step(
|
||||
self, task_id: str, step_request: StepRequestBody
|
||||
@@ -144,37 +136,33 @@ class Agent:
|
||||
steps, steps_pagination = await self.db.list_steps(
|
||||
task_id, page=1, per_page=100
|
||||
)
|
||||
artifacts, artifacts_pagination = await self.db.list_artifacts(
|
||||
task_id, page=1, per_page=100
|
||||
)
|
||||
step = steps[-1]
|
||||
step.artifacts = artifacts
|
||||
step.output = "No more steps to run."
|
||||
# The step is the last step on this page so checking if this is the
|
||||
# last page is sufficent to know if it is the last step
|
||||
step.is_last = steps_pagination.current_page == steps_pagination.total_pages
|
||||
# Find the latest step that has not been completed
|
||||
step = next((s for s in reversed(steps) if s.status != "completed"), None)
|
||||
if step is None:
|
||||
# If all steps have been completed, create a new placeholder step
|
||||
step = await self.db.create_step(
|
||||
task_id=task_id,
|
||||
input="y",
|
||||
additional_properties=None,
|
||||
)
|
||||
step.status = "completed"
|
||||
step.is_last = True
|
||||
step.output = "No more steps to run."
|
||||
step = await self.db.update_step(step)
|
||||
if isinstance(step.status, Status):
|
||||
step.status = step.status.value
|
||||
step.output = "Done some work"
|
||||
return JSONResponse(content=step.dict(), status_code=200)
|
||||
return step
|
||||
|
||||
async def get_step(self, task_id: str, step_id: str) -> Step:
|
||||
"""
|
||||
Get a step by ID.
|
||||
"""
|
||||
if not task_id or not step_id:
|
||||
return Response(
|
||||
status_code=400, content="Task ID and step ID are required."
|
||||
)
|
||||
if not isinstance(task_id, str) or not isinstance(step_id, str):
|
||||
return Response(
|
||||
status_code=400, content="Task ID and step ID must be strings."
|
||||
)
|
||||
try:
|
||||
step = await self.db.get_step(task_id, step_id)
|
||||
return step
|
||||
except Exception as e:
|
||||
return Response(status_code=500, content=str(e))
|
||||
return step
|
||||
raise
|
||||
|
||||
async def list_artifacts(
|
||||
self, task_id: str, page: int = 1, pageSize: int = 10
|
||||
@@ -182,18 +170,16 @@ class Agent:
|
||||
"""
|
||||
List the artifacts that the task has created.
|
||||
"""
|
||||
if not task_id:
|
||||
return Response(status_code=400, content="Task ID is required.")
|
||||
if not isinstance(task_id, str):
|
||||
return Response(status_code=400, content="Task ID must be a string.")
|
||||
try:
|
||||
artifacts, pagination = await self.db.list_artifacts(
|
||||
task_id, page, pageSize
|
||||
)
|
||||
response = TaskArtifactsListResponse(
|
||||
artifacts=artifacts, pagination=pagination
|
||||
)
|
||||
return Response(content=response.json(), media_type="application/json")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
response = TaskArtifactsListResponse(artifacts=artifacts, pagination=pagination)
|
||||
return Response(content=response.json(), media_type="application/json")
|
||||
raise
|
||||
|
||||
async def create_artifact(
|
||||
self,
|
||||
@@ -204,8 +190,6 @@ class Agent:
|
||||
"""
|
||||
Create an artifact for the task.
|
||||
"""
|
||||
if not file and not uri:
|
||||
return Response(status_code=400, content="No file or uri provided")
|
||||
data = None
|
||||
if not uri:
|
||||
file_name = file.filename or str(uuid4())
|
||||
@@ -214,13 +198,13 @@ class Agent:
|
||||
while contents := file.file.read(1024 * 1024):
|
||||
data += contents
|
||||
except Exception as e:
|
||||
return Response(status_code=500, content=str(e))
|
||||
raise
|
||||
else:
|
||||
try:
|
||||
data = await load_from_uri(uri, task_id)
|
||||
file_name = uri.split("/")[-1]
|
||||
except Exception as e:
|
||||
return Response(status_code=500, content=str(e))
|
||||
raise
|
||||
|
||||
file_path = os.path.join(task_id / file_name)
|
||||
self.write(file_path, data)
|
||||
@@ -241,18 +225,16 @@ class Agent:
|
||||
"""
|
||||
try:
|
||||
artifact = await self.db.get_artifact(task_id, artifact_id)
|
||||
except Exception as e:
|
||||
return Response(status_code=500, content=str(e))
|
||||
try:
|
||||
retrieved_artifact = await self.load_from_uri(artifact.uri, artifact_id)
|
||||
except Exception as e:
|
||||
return Response(status_code=500, content=str(e))
|
||||
path = artifact.file_name
|
||||
try:
|
||||
path = artifact.file_name
|
||||
with open(path, "wb") as f:
|
||||
f.write(retrieved_artifact)
|
||||
except NotFoundError as e:
|
||||
raise
|
||||
except FileNotFoundError as e:
|
||||
raise
|
||||
except Exception as e:
|
||||
return Response(status_code=500, content=str(e))
|
||||
raise
|
||||
return FileResponse(
|
||||
# Note: mimetype is guessed in the FileResponse constructor
|
||||
path=path,
|
||||
|
||||
104
autogpt/agent_test.py
Normal file
104
autogpt/agent_test.py
Normal file
@@ -0,0 +1,104 @@
|
||||
import pytest
|
||||
|
||||
from .agent import Agent
|
||||
from .db import AgentDB
|
||||
from .schema import StepRequestBody, Task, TaskListResponse, TaskRequestBody
|
||||
from .workspace import LocalWorkspace
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent():
|
||||
db = AgentDB("sqlite:///test.db")
|
||||
workspace = LocalWorkspace("./test_workspace")
|
||||
return Agent(db, workspace)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_task(agent):
|
||||
task_request = TaskRequestBody(
|
||||
input="test_input", additional_input="additional_test_input"
|
||||
)
|
||||
task: Task = await agent.create_task(task_request)
|
||||
assert task.input == "test_input"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tasks(agent):
|
||||
task_request = TaskRequestBody(
|
||||
input="test_input", additional_input="additional_test_input"
|
||||
)
|
||||
task = await agent.create_task(task_request)
|
||||
tasks = await agent.list_tasks()
|
||||
assert isinstance(tasks, TaskListResponse)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_task(agent):
|
||||
task_request = TaskRequestBody(
|
||||
input="test_input", additional_input="additional_test_input"
|
||||
)
|
||||
task = await agent.create_task(task_request)
|
||||
retrieved_task = await agent.get_task(task.task_id)
|
||||
assert retrieved_task.task_id == task.task_id
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_and_execute_step(agent):
|
||||
task_request = TaskRequestBody(
|
||||
input="test_input", additional_input="additional_test_input"
|
||||
)
|
||||
task = await agent.create_task(task_request)
|
||||
step_request = StepRequestBody(
|
||||
input="step_input", additional_input="additional_step_input"
|
||||
)
|
||||
step = await agent.create_and_execute_step(task.task_id, step_request)
|
||||
assert step.input == "step_input"
|
||||
assert step.additional_input == "additional_step_input"
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_step(agent):
|
||||
task_request = TaskRequestBody(
|
||||
input="test_input", additional_input="additional_test_input"
|
||||
)
|
||||
task = await agent.create_task(task_request)
|
||||
step_request = StepRequestBody(
|
||||
input="step_input", additional_input="additional_step_input"
|
||||
)
|
||||
step = await agent.create_and_execute_step(task.task_id, step_request)
|
||||
retrieved_step = await agent.get_step(task.task_id, step.step_id)
|
||||
assert retrieved_step.step_id == step.step_id
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_artifacts(agent):
|
||||
artifacts = await agent.list_artifacts()
|
||||
assert isinstance(artifacts, list)
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_artifact(agent):
|
||||
task_request = TaskRequestBody(
|
||||
input="test_input", additional_input="additional_test_input"
|
||||
)
|
||||
task = await agent.create_task(task_request)
|
||||
artifact_request = ArtifactRequestBody(file=None, uri="test_uri")
|
||||
artifact = await agent.create_artifact(task.task_id, artifact_request)
|
||||
assert artifact.uri == "test_uri"
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_artifact(agent):
|
||||
task_request = TaskRequestBody(
|
||||
input="test_input", additional_input="additional_test_input"
|
||||
)
|
||||
task = await agent.create_task(task_request)
|
||||
artifact_request = ArtifactRequestBody(file=None, uri="test_uri")
|
||||
artifact = await agent.create_artifact(task.task_id, artifact_request)
|
||||
retrieved_artifact = await agent.get_artifact(task.task_id, artifact.artifact_id)
|
||||
assert retrieved_artifact.artifact_id == artifact.artifact_id
|
||||
494
autogpt/db.py
494
autogpt/db.py
@@ -4,12 +4,16 @@ It uses SQLite as the database and file store backend.
|
||||
IT IS NOT ADVISED TO USE THIS IN PRODUCTION!
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import math
|
||||
import uuid
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, create_engine
|
||||
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, String, create_engine
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import DeclarativeBase, joinedload, relationship, sessionmaker
|
||||
|
||||
from .errors import NotFoundError
|
||||
from .forge_log import CustomLogger
|
||||
from .schema import Artifact, Pagination, Status, Step, Task, TaskInput
|
||||
|
||||
@@ -20,16 +24,16 @@ class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
class DataNotFoundError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class TaskModel(Base):
|
||||
__tablename__ = "tasks"
|
||||
|
||||
task_id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
task_id = Column(String, primary_key=True, index=True)
|
||||
input = Column(String)
|
||||
additional_input = Dict
|
||||
additional_input = Column(String)
|
||||
created_at = Column(DateTime, default=datetime.datetime.utcnow)
|
||||
modified_at = Column(
|
||||
DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow
|
||||
)
|
||||
|
||||
artifacts = relationship("ArtifactModel", back_populates="task")
|
||||
|
||||
@@ -37,12 +41,16 @@ class TaskModel(Base):
|
||||
class StepModel(Base):
|
||||
__tablename__ = "steps"
|
||||
|
||||
step_id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
task_id = Column(Integer, ForeignKey("tasks.task_id"))
|
||||
step_id = Column(String, primary_key=True, index=True)
|
||||
task_id = Column(String, ForeignKey("tasks.task_id"))
|
||||
name = Column(String)
|
||||
input = Column(String)
|
||||
status = Column(String)
|
||||
is_last = Column(Boolean, default=False)
|
||||
created_at = Column(DateTime, default=datetime.datetime.utcnow)
|
||||
modified_at = Column(
|
||||
DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow
|
||||
)
|
||||
|
||||
additional_properties = Column(String)
|
||||
artifacts = relationship("ArtifactModel", back_populates="step")
|
||||
@@ -51,12 +59,16 @@ class StepModel(Base):
|
||||
class ArtifactModel(Base):
|
||||
__tablename__ = "artifacts"
|
||||
|
||||
artifact_id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
task_id = Column(Integer, ForeignKey("tasks.task_id"))
|
||||
step_id = Column(Integer, ForeignKey("steps.step_id"))
|
||||
artifact_id = Column(String, primary_key=True, index=True)
|
||||
task_id = Column(String, ForeignKey("tasks.task_id"))
|
||||
step_id = Column(String, ForeignKey("steps.step_id"))
|
||||
agent_created = Column(Boolean, default=False)
|
||||
file_name = Column(String)
|
||||
uri = Column(String)
|
||||
created_at = Column(DateTime, default=datetime.datetime.utcnow)
|
||||
modified_at = Column(
|
||||
DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow
|
||||
)
|
||||
|
||||
step = relationship("StepModel", back_populates="artifacts")
|
||||
task = relationship("TaskModel", back_populates="artifacts")
|
||||
@@ -65,18 +77,11 @@ class ArtifactModel(Base):
|
||||
def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False) -> Task:
|
||||
if debug_enabled:
|
||||
LOG.debug(f"Converting TaskModel to Task for task_id: {task_obj.task_id}")
|
||||
task_artifacts = [
|
||||
Artifact(
|
||||
artifact_id=artifact.artifact_id,
|
||||
file_name=artifact.file_name,
|
||||
agent_created=artifact.agent_created,
|
||||
uri=artifact.uri,
|
||||
)
|
||||
for artifact in task_obj.artifacts
|
||||
if artifact.task_id == task_obj.task_id
|
||||
]
|
||||
task_artifacts = [convert_to_artifact(artifact) for artifact in task_obj.artifacts]
|
||||
return Task(
|
||||
task_id=task_obj.task_id,
|
||||
created_at=task_obj.created_at,
|
||||
modified_at=task_obj.modified_at,
|
||||
input=task_obj.input,
|
||||
additional_input=task_obj.additional_input,
|
||||
artifacts=task_artifacts,
|
||||
@@ -87,19 +92,14 @@ def convert_to_step(step_model: StepModel, debug_enabled: bool = False) -> Step:
|
||||
if debug_enabled:
|
||||
LOG.debug(f"Converting StepModel to Step for step_id: {step_model.step_id}")
|
||||
step_artifacts = [
|
||||
Artifact(
|
||||
artifact_id=artifact.artifact_id,
|
||||
file_name=artifact.file_name,
|
||||
agent_created=artifact.agent_created,
|
||||
uri=artifact.uri,
|
||||
)
|
||||
for artifact in step_model.artifacts
|
||||
if artifact.step_id == step_model.step_id
|
||||
convert_to_artifact(artifact) for artifact in step_model.artifacts
|
||||
]
|
||||
status = Status.completed if step_model.status == "completed" else Status.created
|
||||
return Step(
|
||||
task_id=step_model.task_id,
|
||||
step_id=step_model.step_id,
|
||||
created_at=step_model.created_at,
|
||||
modified_at=step_model.modified_at,
|
||||
name=step_model.name,
|
||||
input=step_model.input,
|
||||
status=status,
|
||||
@@ -109,6 +109,16 @@ def convert_to_step(step_model: StepModel, debug_enabled: bool = False) -> Step:
|
||||
)
|
||||
|
||||
|
||||
def convert_to_artifact(artifact_model: ArtifactModel) -> Artifact:
|
||||
return Artifact(
|
||||
artifact_id=artifact_model.artifact_id,
|
||||
created_at=artifact_model.created_at,
|
||||
modified_at=artifact_model.modified_at,
|
||||
agent_created=artifact_model.agent_created,
|
||||
uri=artifact_model.uri,
|
||||
)
|
||||
|
||||
|
||||
# sqlite:///{database_name}
|
||||
class AgentDB:
|
||||
def __init__(self, database_string, debug_enabled: bool = False) -> None:
|
||||
@@ -125,19 +135,30 @@ class AgentDB:
|
||||
) -> Task:
|
||||
if self.debug_enabled:
|
||||
LOG.debug("Creating new task")
|
||||
with self.Session() as session:
|
||||
new_task = TaskModel(
|
||||
input=input,
|
||||
additional_input=additional_input.__root__
|
||||
if additional_input
|
||||
else None,
|
||||
)
|
||||
session.add(new_task)
|
||||
session.commit()
|
||||
session.refresh(new_task)
|
||||
if self.debug_enabled:
|
||||
LOG.debug(f"Created new task with task_id: {new_task.task_id}")
|
||||
return convert_to_task(new_task, self.debug_enabled)
|
||||
|
||||
try:
|
||||
with self.Session() as session:
|
||||
new_task = TaskModel(
|
||||
task_id=str(uuid.uuid4()),
|
||||
input=input,
|
||||
additional_input=additional_input.__root__
|
||||
if additional_input
|
||||
else None,
|
||||
)
|
||||
session.add(new_task)
|
||||
session.commit()
|
||||
session.refresh(new_task)
|
||||
if self.debug_enabled:
|
||||
LOG.debug(f"Created new task with task_id: {new_task.task_id}")
|
||||
return convert_to_task(new_task, self.debug_enabled)
|
||||
except SQLAlchemyError as e:
|
||||
LOG.error(f"SQLAlchemy error while creating task: {e}")
|
||||
raise
|
||||
except NotFoundError as e:
|
||||
raise
|
||||
except Exception as e:
|
||||
LOG.error(f"Unexpected error while creating task: {e}")
|
||||
raise
|
||||
|
||||
async def create_step(
|
||||
self,
|
||||
@@ -149,23 +170,30 @@ class AgentDB:
|
||||
if self.debug_enabled:
|
||||
LOG.debug(f"Creating new step for task_id: {task_id}")
|
||||
try:
|
||||
session = self.Session()
|
||||
new_step = StepModel(
|
||||
task_id=task_id,
|
||||
name=input,
|
||||
input=input,
|
||||
status="created",
|
||||
is_last=is_last,
|
||||
additional_properties=additional_properties,
|
||||
)
|
||||
session.add(new_step)
|
||||
session.commit()
|
||||
session.refresh(new_step)
|
||||
if self.debug_enabled:
|
||||
LOG.debug(f"Created new step with step_id: {new_step.step_id}")
|
||||
with self.Session() as session:
|
||||
new_step = StepModel(
|
||||
task_id=task_id,
|
||||
step_id=str(uuid.uuid4()),
|
||||
name=input,
|
||||
input=input,
|
||||
status="created",
|
||||
is_last=is_last,
|
||||
additional_properties=additional_properties,
|
||||
)
|
||||
session.add(new_step)
|
||||
session.commit()
|
||||
session.refresh(new_step)
|
||||
if self.debug_enabled:
|
||||
LOG.debug(f"Created new step with step_id: {new_step.step_id}")
|
||||
return convert_to_step(new_step, self.debug_enabled)
|
||||
except SQLAlchemyError as e:
|
||||
LOG.error(f"SQLAlchemy error while creating step: {e}")
|
||||
raise
|
||||
except NotFoundError as e:
|
||||
raise
|
||||
except Exception as e:
|
||||
LOG.error(f"Error while creating step: {e}")
|
||||
return convert_to_step(new_step, self.debug_enabled)
|
||||
LOG.error(f"Unexpected error while creating step: {e}")
|
||||
raise
|
||||
|
||||
async def create_artifact(
|
||||
self,
|
||||
@@ -177,71 +205,94 @@ class AgentDB:
|
||||
) -> Artifact:
|
||||
if self.debug_enabled:
|
||||
LOG.debug(f"Creating new artifact for task_id: {task_id}")
|
||||
session = self.Session()
|
||||
try:
|
||||
with self.Session() as session:
|
||||
if (
|
||||
existing_artifact := session.query(ArtifactModel)
|
||||
.filter_by(uri=uri)
|
||||
.first()
|
||||
):
|
||||
session.close()
|
||||
if self.debug_enabled:
|
||||
LOG.debug(f"Artifact already exists with uri: {uri}")
|
||||
return convert_to_artifact(existing_artifact)
|
||||
|
||||
if existing_artifact := session.query(ArtifactModel).filter_by(uri=uri).first():
|
||||
session.close()
|
||||
if self.debug_enabled:
|
||||
LOG.debug(f"Artifact already exists with uri: {uri}")
|
||||
return Artifact(
|
||||
artifact_id=str(existing_artifact.artifact_id),
|
||||
file_name=existing_artifact.file_name,
|
||||
agent_created=existing_artifact.agent_created,
|
||||
uri=existing_artifact.uri,
|
||||
)
|
||||
|
||||
new_artifact = ArtifactModel(
|
||||
task_id=task_id,
|
||||
step_id=step_id,
|
||||
agent_created=agent_created,
|
||||
file_name=file_name,
|
||||
uri=uri,
|
||||
)
|
||||
session.add(new_artifact)
|
||||
session.commit()
|
||||
session.refresh(new_artifact)
|
||||
if self.debug_enabled:
|
||||
LOG.debug(
|
||||
f"Created new artifact with artifact_id: {new_artifact.artifact_id}"
|
||||
)
|
||||
return Artifact(
|
||||
artifact_id=str(new_artifact.artifact_id),
|
||||
file_name=new_artifact.file_name,
|
||||
agent_created=new_artifact.agent_created,
|
||||
uri=new_artifact.uri,
|
||||
)
|
||||
new_artifact = ArtifactModel(
|
||||
artifact_id=str(uuid.uuid4()),
|
||||
task_id=task_id,
|
||||
step_id=step_id,
|
||||
agent_created=agent_created,
|
||||
file_name=file_name,
|
||||
uri=uri,
|
||||
)
|
||||
session.add(new_artifact)
|
||||
session.commit()
|
||||
session.refresh(new_artifact)
|
||||
if self.debug_enabled:
|
||||
LOG.debug(
|
||||
f"Created new artifact with artifact_id: {new_artifact.artifact_id}"
|
||||
)
|
||||
return convert_to_artifact(new_artifact)
|
||||
except SQLAlchemyError as e:
|
||||
LOG.error(f"SQLAlchemy error while creating step: {e}")
|
||||
raise
|
||||
except NotFoundError as e:
|
||||
raise
|
||||
except Exception as e:
|
||||
LOG.error(f"Unexpected error while creating step: {e}")
|
||||
raise
|
||||
|
||||
async def get_task(self, task_id: int) -> Task:
|
||||
"""Get a task by its id"""
|
||||
if self.debug_enabled:
|
||||
LOG.debug(f"Getting task with task_id: {task_id}")
|
||||
session = self.Session()
|
||||
if task_obj := (
|
||||
session.query(TaskModel)
|
||||
.options(joinedload(TaskModel.artifacts))
|
||||
.filter_by(task_id=task_id)
|
||||
.first()
|
||||
):
|
||||
return convert_to_task(task_obj, self.debug_enabled)
|
||||
else:
|
||||
LOG.error(f"Task not found with task_id: {task_id}")
|
||||
raise DataNotFoundError("Task not found")
|
||||
try:
|
||||
with self.Session() as session:
|
||||
if task_obj := (
|
||||
session.query(TaskModel)
|
||||
.options(joinedload(TaskModel.artifacts))
|
||||
.filter_by(task_id=task_id)
|
||||
.first()
|
||||
):
|
||||
return convert_to_task(task_obj, self.debug_enabled)
|
||||
else:
|
||||
LOG.error(f"Task not found with task_id: {task_id}")
|
||||
raise NotFoundError("Task not found")
|
||||
except SQLAlchemyError as e:
|
||||
LOG.error(f"SQLAlchemy error while getting task: {e}")
|
||||
raise
|
||||
except NotFoundError as e:
|
||||
raise
|
||||
except Exception as e:
|
||||
LOG.error(f"Unexpected error while getting task: {e}")
|
||||
raise
|
||||
|
||||
async def get_step(self, task_id: int, step_id: int) -> Step:
|
||||
if self.debug_enabled:
|
||||
LOG.debug(f"Getting step with task_id: {task_id} and step_id: {step_id}")
|
||||
session = self.Session()
|
||||
if step := (
|
||||
session.query(StepModel)
|
||||
.options(joinedload(StepModel.artifacts))
|
||||
.filter(StepModel.step_id == step_id)
|
||||
.first()
|
||||
):
|
||||
return convert_to_step(step, self.debug_enabled)
|
||||
try:
|
||||
with self.Session() as session:
|
||||
if step := (
|
||||
session.query(StepModel)
|
||||
.options(joinedload(StepModel.artifacts))
|
||||
.filter(StepModel.step_id == step_id)
|
||||
.first()
|
||||
):
|
||||
return convert_to_step(step, self.debug_enabled)
|
||||
|
||||
else:
|
||||
LOG.error(f"Step not found with task_id: {task_id} and step_id: {step_id}")
|
||||
raise DataNotFoundError("Step not found")
|
||||
else:
|
||||
LOG.error(
|
||||
f"Step not found with task_id: {task_id} and step_id: {step_id}"
|
||||
)
|
||||
raise NotFoundError("Step not found")
|
||||
except SQLAlchemyError as e:
|
||||
LOG.error(f"SQLAlchemy error while getting step: {e}")
|
||||
raise
|
||||
except NotFoundError as e:
|
||||
raise
|
||||
except Exception as e:
|
||||
LOG.error(f"Unexpected error while getting step: {e}")
|
||||
raise
|
||||
|
||||
async def update_step(
|
||||
self,
|
||||
@@ -252,108 +303,157 @@ class AgentDB:
|
||||
) -> Step:
|
||||
if self.debug_enabled:
|
||||
LOG.debug(f"Updating step with task_id: {task_id} and step_id: {step_id}")
|
||||
session = self.Session()
|
||||
if (
|
||||
step := session.query(StepModel)
|
||||
.filter_by(task_id=task_id, step_id=step_id)
|
||||
.first()
|
||||
):
|
||||
step.status = status
|
||||
step.additional_properties = additional_properties
|
||||
session.commit()
|
||||
return await self.get_step(task_id, step_id)
|
||||
else:
|
||||
LOG.error(
|
||||
f"Step not found for update with task_id: {task_id} and step_id: {step_id}"
|
||||
)
|
||||
raise DataNotFoundError("Step not found")
|
||||
try:
|
||||
with self.Session() as session:
|
||||
if (
|
||||
step := session.query(StepModel)
|
||||
.filter_by(task_id=task_id, step_id=step_id)
|
||||
.first()
|
||||
):
|
||||
step.status = status
|
||||
step.additional_properties = additional_properties
|
||||
session.commit()
|
||||
return await self.get_step(task_id, step_id)
|
||||
else:
|
||||
LOG.error(
|
||||
f"Step not found for update with task_id: {task_id} and step_id: {step_id}"
|
||||
)
|
||||
raise NotFoundError("Step not found")
|
||||
except SQLAlchemyError as e:
|
||||
LOG.error(f"SQLAlchemy error while getting step: {e}")
|
||||
raise
|
||||
except NotFoundError as e:
|
||||
raise
|
||||
except Exception as e:
|
||||
LOG.error(f"Unexpected error while getting step: {e}")
|
||||
raise
|
||||
|
||||
async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact:
|
||||
async def get_artifact(self, artifact_id: str) -> Artifact:
|
||||
if self.debug_enabled:
|
||||
LOG.debug(
|
||||
f"Getting artifact with task_id: {task_id} and artifact_id: {artifact_id}"
|
||||
)
|
||||
session = self.Session()
|
||||
if artifact_model := (
|
||||
session.query(ArtifactModel)
|
||||
.filter_by(task_id=int(task_id), artifact_id=int(artifact_id))
|
||||
.first()
|
||||
):
|
||||
return Artifact(
|
||||
artifact_id=artifact_model.artifact_id, # Casting to string
|
||||
file_name=artifact_model.file_name,
|
||||
agent_created=artifact_model.agent_created,
|
||||
uri=artifact_model.uri,
|
||||
)
|
||||
else:
|
||||
LOG.error(
|
||||
f"Artifact not found with task_id: {task_id} and artifact_id: {artifact_id}"
|
||||
)
|
||||
raise DataNotFoundError("Artifact not found")
|
||||
LOG.debug(f"Getting artifact with and artifact_id: {artifact_id}")
|
||||
try:
|
||||
with self.Session() as session:
|
||||
if (
|
||||
artifact_model := session.query(ArtifactModel)
|
||||
.filter_by(artifact_id=artifact_id)
|
||||
.first()
|
||||
):
|
||||
return convert_to_artifact(artifact_model)
|
||||
else:
|
||||
LOG.error(f"Artifact not found with and artifact_id: {artifact_id}")
|
||||
raise NotFoundError("Artifact not found")
|
||||
except SQLAlchemyError as e:
|
||||
LOG.error(f"SQLAlchemy error while getting artifact: {e}")
|
||||
raise
|
||||
except NotFoundError as e:
|
||||
raise
|
||||
except Exception as e:
|
||||
LOG.error(f"Unexpected error while getting artifact: {e}")
|
||||
raise
|
||||
|
||||
async def list_tasks(
|
||||
self, page: int = 1, per_page: int = 10
|
||||
) -> Tuple[List[Task], Pagination]:
|
||||
if self.debug_enabled:
|
||||
LOG.debug("Listing tasks")
|
||||
session = self.Session()
|
||||
tasks = (
|
||||
session.query(TaskModel).offset((page - 1) * per_page).limit(per_page).all()
|
||||
)
|
||||
total = session.query(TaskModel).count()
|
||||
pages = math.ceil(total / per_page)
|
||||
pagination = Pagination(
|
||||
total_items=total, total_pages=pages, current_page=page, page_size=per_page
|
||||
)
|
||||
return [convert_to_task(task, self.debug_enabled) for task in tasks], pagination
|
||||
try:
|
||||
with self.Session() as session:
|
||||
tasks = (
|
||||
session.query(TaskModel)
|
||||
.offset((page - 1) * per_page)
|
||||
.limit(per_page)
|
||||
.all()
|
||||
)
|
||||
if not tasks:
|
||||
raise NotFoundError("No tasks found")
|
||||
total = session.query(TaskModel).count()
|
||||
pages = math.ceil(total / per_page)
|
||||
pagination = Pagination(
|
||||
total_items=total,
|
||||
total_pages=pages,
|
||||
current_page=page,
|
||||
page_size=per_page,
|
||||
)
|
||||
return [
|
||||
convert_to_task(task, self.debug_enabled) for task in tasks
|
||||
], pagination
|
||||
except SQLAlchemyError as e:
|
||||
LOG.error(f"SQLAlchemy error while listing tasks: {e}")
|
||||
raise
|
||||
except NotFoundError as e:
|
||||
raise
|
||||
except Exception as e:
|
||||
LOG.error(f"Unexpected error while listing tasks: {e}")
|
||||
raise
|
||||
|
||||
async def list_steps(
|
||||
self, task_id: str, page: int = 1, per_page: int = 10
|
||||
) -> Tuple[List[Step], Pagination]:
|
||||
if self.debug_enabled:
|
||||
LOG.debug(f"Listing steps for task_id: {task_id}")
|
||||
session = self.Session()
|
||||
steps = (
|
||||
session.query(StepModel)
|
||||
.filter_by(task_id=task_id)
|
||||
.offset((page - 1) * per_page)
|
||||
.limit(per_page)
|
||||
.all()
|
||||
)
|
||||
total = session.query(StepModel).filter_by(task_id=task_id).count()
|
||||
pages = math.ceil(total / per_page)
|
||||
pagination = Pagination(
|
||||
total_items=total, total_pages=pages, current_page=page, page_size=per_page
|
||||
)
|
||||
return [convert_to_step(step, self.debug_enabled) for step in steps], pagination
|
||||
try:
|
||||
with self.Session() as session:
|
||||
steps = (
|
||||
session.query(StepModel)
|
||||
.filter_by(task_id=task_id)
|
||||
.offset((page - 1) * per_page)
|
||||
.limit(per_page)
|
||||
.all()
|
||||
)
|
||||
if not steps:
|
||||
raise NotFoundError("No steps found")
|
||||
total = session.query(StepModel).filter_by(task_id=task_id).count()
|
||||
pages = math.ceil(total / per_page)
|
||||
pagination = Pagination(
|
||||
total_items=total,
|
||||
total_pages=pages,
|
||||
current_page=page,
|
||||
page_size=per_page,
|
||||
)
|
||||
return [
|
||||
convert_to_step(step, self.debug_enabled) for step in steps
|
||||
], pagination
|
||||
except SQLAlchemyError as e:
|
||||
LOG.error(f"SQLAlchemy error while listing steps: {e}")
|
||||
raise
|
||||
except NotFoundError as e:
|
||||
raise
|
||||
except Exception as e:
|
||||
LOG.error(f"Unexpected error while listing steps: {e}")
|
||||
raise
|
||||
|
||||
async def list_artifacts(
|
||||
self, task_id: str, page: int = 1, per_page: int = 10
|
||||
) -> Tuple[List[Artifact], Pagination]:
|
||||
if self.debug_enabled:
|
||||
LOG.debug(f"Listing artifacts for task_id: {task_id}")
|
||||
with self.Session() as session:
|
||||
artifacts = (
|
||||
session.query(ArtifactModel)
|
||||
.filter_by(task_id=task_id)
|
||||
.offset((page - 1) * per_page)
|
||||
.limit(per_page)
|
||||
.all()
|
||||
)
|
||||
total = session.query(ArtifactModel).filter_by(task_id=task_id).count()
|
||||
pages = math.ceil(total / per_page)
|
||||
pagination = Pagination(
|
||||
total_items=total,
|
||||
total_pages=pages,
|
||||
current_page=page,
|
||||
page_size=per_page,
|
||||
)
|
||||
return [
|
||||
Artifact(
|
||||
artifact_id=str(artifact.artifact_id),
|
||||
file_name=artifact.file_name,
|
||||
agent_created=artifact.agent_created,
|
||||
uri=artifact.uri,
|
||||
try:
|
||||
with self.Session() as session:
|
||||
artifacts = (
|
||||
session.query(ArtifactModel)
|
||||
.filter_by(task_id=task_id)
|
||||
.offset((page - 1) * per_page)
|
||||
.limit(per_page)
|
||||
.all()
|
||||
)
|
||||
for artifact in artifacts
|
||||
], pagination
|
||||
if not artifacts:
|
||||
raise NotFoundError("No artifacts found")
|
||||
total = session.query(ArtifactModel).filter_by(task_id=task_id).count()
|
||||
pages = math.ceil(total / per_page)
|
||||
pagination = Pagination(
|
||||
total_items=total,
|
||||
total_pages=pages,
|
||||
current_page=page,
|
||||
page_size=per_page,
|
||||
)
|
||||
return [
|
||||
convert_to_artifact(artifact) for artifact in artifacts
|
||||
], pagination
|
||||
except SQLAlchemyError as e:
|
||||
LOG.error(f"SQLAlchemy error while listing artifacts: {e}")
|
||||
raise
|
||||
except NotFoundError as e:
|
||||
raise
|
||||
except Exception as e:
|
||||
LOG.error(f"Unexpected error while listing artifacts: {e}")
|
||||
raise
|
||||
|
||||
@@ -1,11 +1,23 @@
|
||||
import os
|
||||
import sqlite3
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from autogpt.db import AgentDB, DataNotFoundError
|
||||
from autogpt.db import (
|
||||
AgentDB,
|
||||
ArtifactModel,
|
||||
StepModel,
|
||||
TaskModel,
|
||||
convert_to_artifact,
|
||||
convert_to_step,
|
||||
convert_to_task,
|
||||
)
|
||||
from autogpt.errors import NotFoundError as DataNotFoundError
|
||||
from autogpt.schema import *
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
def test_table_creation():
|
||||
db_name = "sqlite:///test_db.sqlite3"
|
||||
agent_db = AgentDB(db_name)
|
||||
@@ -30,6 +42,143 @@ def test_table_creation():
|
||||
os.remove(db_name.split("///")[1])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_schema():
|
||||
now = datetime.now()
|
||||
task = Task(
|
||||
task_id="50da533e-3904-4401-8a07-c49adf88b5eb",
|
||||
input="Write the words you receive to the file 'output.txt'.",
|
||||
created_at=now,
|
||||
modified_at=now,
|
||||
artifacts=[
|
||||
Artifact(
|
||||
artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
|
||||
agent_created=True,
|
||||
uri="file:///path/to/artifact",
|
||||
file_name="main.py",
|
||||
relative_path="python/code/",
|
||||
created_at=now,
|
||||
modified_at=now,
|
||||
)
|
||||
],
|
||||
)
|
||||
assert task.task_id == "50da533e-3904-4401-8a07-c49adf88b5eb"
|
||||
assert task.input == "Write the words you receive to the file 'output.txt'."
|
||||
assert len(task.artifacts) == 1
|
||||
assert task.artifacts[0].artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_step_schema():
|
||||
now = datetime.now()
|
||||
step = Step(
|
||||
task_id="50da533e-3904-4401-8a07-c49adf88b5eb",
|
||||
step_id="6bb1801a-fd80-45e8-899a-4dd723cc602e",
|
||||
created_at=now,
|
||||
modified_at=now,
|
||||
name="Write to file",
|
||||
input="Write the words you receive to the file 'output.txt'.",
|
||||
status=Status.created,
|
||||
output="I am going to use the write_to_file command and write Washington to a file called output.txt <write_to_file('output.txt', 'Washington')>",
|
||||
artifacts=[
|
||||
Artifact(
|
||||
artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
|
||||
file_name="main.py",
|
||||
relative_path="python/code/",
|
||||
created_at=now,
|
||||
modified_at=now,
|
||||
agent_created=True,
|
||||
uri="file:///path/to/artifact",
|
||||
)
|
||||
],
|
||||
is_last=False,
|
||||
)
|
||||
assert step.task_id == "50da533e-3904-4401-8a07-c49adf88b5eb"
|
||||
assert step.step_id == "6bb1801a-fd80-45e8-899a-4dd723cc602e"
|
||||
assert step.name == "Write to file"
|
||||
assert step.status == Status.created
|
||||
assert (
|
||||
step.output
|
||||
== "I am going to use the write_to_file command and write Washington to a file called output.txt <write_to_file('output.txt', 'Washington')>"
|
||||
)
|
||||
assert len(step.artifacts) == 1
|
||||
assert step.artifacts[0].artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"
|
||||
assert step.is_last == False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_to_task():
|
||||
now = datetime.now()
|
||||
task_model = TaskModel(
|
||||
task_id="50da533e-3904-4401-8a07-c49adf88b5eb",
|
||||
created_at=now,
|
||||
modified_at=now,
|
||||
input="Write the words you receive to the file 'output.txt'.",
|
||||
artifacts=[
|
||||
ArtifactModel(
|
||||
artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
|
||||
created_at=now,
|
||||
modified_at=now,
|
||||
uri="file:///path/to/main.py",
|
||||
agent_created=True,
|
||||
)
|
||||
],
|
||||
)
|
||||
task = convert_to_task(task_model)
|
||||
assert task.task_id == "50da533e-3904-4401-8a07-c49adf88b5eb"
|
||||
assert task.input == "Write the words you receive to the file 'output.txt'."
|
||||
assert len(task.artifacts) == 1
|
||||
assert task.artifacts[0].artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_to_step():
|
||||
now = datetime.now()
|
||||
step_model = StepModel(
|
||||
task_id="50da533e-3904-4401-8a07-c49adf88b5eb",
|
||||
step_id="6bb1801a-fd80-45e8-899a-4dd723cc602e",
|
||||
created_at=now,
|
||||
modified_at=now,
|
||||
name="Write to file",
|
||||
status="created",
|
||||
input="Write the words you receive to the file 'output.txt'.",
|
||||
artifacts=[
|
||||
ArtifactModel(
|
||||
artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
|
||||
created_at=now,
|
||||
modified_at=now,
|
||||
uri="file:///path/to/main.py",
|
||||
agent_created=True,
|
||||
)
|
||||
],
|
||||
is_last=False,
|
||||
)
|
||||
step = convert_to_step(step_model)
|
||||
assert step.task_id == "50da533e-3904-4401-8a07-c49adf88b5eb"
|
||||
assert step.step_id == "6bb1801a-fd80-45e8-899a-4dd723cc602e"
|
||||
assert step.name == "Write to file"
|
||||
assert step.status == Status.created
|
||||
assert len(step.artifacts) == 1
|
||||
assert step.artifacts[0].artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"
|
||||
assert step.is_last == False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_to_artifact():
|
||||
now = datetime.now()
|
||||
artifact_model = ArtifactModel(
|
||||
artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
|
||||
created_at=now,
|
||||
modified_at=now,
|
||||
uri="file:///path/to/main.py",
|
||||
agent_created=True,
|
||||
)
|
||||
artifact = convert_to_artifact(artifact_model)
|
||||
assert artifact.artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"
|
||||
assert artifact.uri == "file:///path/to/main.py"
|
||||
assert artifact.agent_created == True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_task():
|
||||
# Having issues with pytest fixture so added setup and teardown in each test as a rapid workaround
|
||||
@@ -52,17 +201,6 @@ async def test_create_and_get_task():
|
||||
os.remove(db_name.split("///")[1])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_and_get_task_with_steps():
|
||||
db_name = "sqlite:///test_db.sqlite3"
|
||||
agent_db = AgentDB(db_name)
|
||||
await agent_db.create_task("task_input")
|
||||
task = await agent_db.get_task(1)
|
||||
|
||||
assert task.input == "task_input"
|
||||
os.remove(db_name.split("///")[1])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_task_not_found():
|
||||
db_name = "sqlite:///test_db.sqlite3"
|
||||
@@ -87,11 +225,11 @@ async def test_create_and_get_step():
|
||||
async def test_updating_step():
|
||||
db_name = "sqlite:///test_db.sqlite3"
|
||||
agent_db = AgentDB(db_name)
|
||||
await agent_db.create_task("task_input")
|
||||
await agent_db.create_step(1, "step_name")
|
||||
await agent_db.update_step(1, 1, "completed")
|
||||
created_task = await agent_db.create_task("task_input")
|
||||
created_step = await agent_db.create_step(created_task.task_id, "step_name")
|
||||
await agent_db.update_step(created_task.task_id, created_step.step_id, "completed")
|
||||
|
||||
step = await agent_db.get_step(1, 1)
|
||||
step = await agent_db.get_step(created_task.task_id, created_step.step_id)
|
||||
assert step.status.value == "completed"
|
||||
os.remove(db_name.split("///")[1])
|
||||
|
||||
@@ -117,18 +255,18 @@ async def test_get_artifact():
|
||||
# Create an artifact
|
||||
artifact = await db.create_artifact(
|
||||
task_id=task.task_id,
|
||||
file_name="sample_file.txt",
|
||||
uri="file:///path/to/sample_file.txt",
|
||||
file_name="test_get_artifact_sample_file.txt",
|
||||
uri="file:///path/to/test_get_artifact_sample_file.txt",
|
||||
agent_created=True,
|
||||
step_id=step.step_id,
|
||||
)
|
||||
|
||||
# When: The artifact is fetched by its ID
|
||||
fetched_artifact = await db.get_artifact(task.task_id, artifact.artifact_id)
|
||||
fetched_artifact = await db.get_artifact(artifact.artifact_id)
|
||||
|
||||
# Then: The fetched artifact matches the original
|
||||
assert fetched_artifact.artifact_id == artifact.artifact_id
|
||||
assert fetched_artifact.uri == "file:///path/to/sample_file.txt"
|
||||
assert fetched_artifact.uri == "file:///path/to/test_get_artifact_sample_file.txt"
|
||||
|
||||
os.remove(db_name.split("///")[1])
|
||||
|
||||
|
||||
2
autogpt/errors.py
Normal file
2
autogpt/errors.py
Normal file
@@ -0,0 +1,2 @@
|
||||
class NotFoundError(Exception):
|
||||
pass
|
||||
@@ -22,16 +22,37 @@ the ones that require special attention due to their complexity are:
|
||||
Developers and contributors should be especially careful when making modifications to these routes to ensure
|
||||
consistency and correctness in the system's behavior.
|
||||
"""
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Query, Request, Response, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from autogpt.errors import *
|
||||
from autogpt.forge_log import CustomLogger
|
||||
from autogpt.schema import *
|
||||
from autogpt.tracing import tracing
|
||||
|
||||
base_router = APIRouter()
|
||||
|
||||
LOG = CustomLogger(__name__)
|
||||
|
||||
|
||||
@base_router.get("/", tags=["root"])
|
||||
async def root():
|
||||
"""
|
||||
Root endpoint that returns a welcome message.
|
||||
"""
|
||||
return Response(content="Welcome to the Auto-GPT Forge")
|
||||
|
||||
|
||||
@base_router.get("/heartbeat", tags=["server"])
|
||||
async def check_server_status():
|
||||
"""
|
||||
Check if the server is running.
|
||||
"""
|
||||
return Response(content="Server is running.", status_code=200)
|
||||
|
||||
|
||||
@base_router.get("/", tags=["root"])
|
||||
async def root():
|
||||
@@ -75,15 +96,25 @@ async def create_agent_task(request: Request, task_request: TaskRequestBody) ->
|
||||
"input": "Write the word 'Washington' to a .txt file",
|
||||
"additional_input": "python/code",
|
||||
"artifacts": [],
|
||||
"steps": []
|
||||
}
|
||||
"""
|
||||
agent = request["agent"]
|
||||
|
||||
if task_request := await agent.create_task(task_request):
|
||||
return task_request
|
||||
else:
|
||||
return Response(content={"error": "Task creation failed"}, status_code=400)
|
||||
try:
|
||||
task_request = await agent.create_task(task_request)
|
||||
return Response(content=task_request.json(), status_code=200)
|
||||
except NotFoundError:
|
||||
return Response(
|
||||
content=json.dumps({"error": "Task not found"}),
|
||||
status_code=404,
|
||||
media_type="application/json",
|
||||
)
|
||||
except Exception:
|
||||
return Response(
|
||||
content=json.dumps({"error": "Internal server error"}),
|
||||
status_code=500,
|
||||
media_type="application/json",
|
||||
)
|
||||
|
||||
|
||||
@base_router.get("/agent/tasks", tags=["agent"], response_model=TaskListResponse)
|
||||
@@ -128,7 +159,21 @@ async def list_agent_tasks(
|
||||
}
|
||||
"""
|
||||
agent = request["agent"]
|
||||
return await agent.list_tasks(page, page_size)
|
||||
try:
|
||||
tasks = await agent.list_tasks(page, page_size)
|
||||
return Response(content=tasks.json(), status_code=200)
|
||||
except NotFoundError:
|
||||
return Response(
|
||||
content=json.dumps({"error": "Task not found"}),
|
||||
status_code=404,
|
||||
media_type="application/json",
|
||||
)
|
||||
except Exception:
|
||||
return Response(
|
||||
content=json.dumps({"error": "Internal server error"}),
|
||||
status_code=500,
|
||||
media_type="application/json",
|
||||
)
|
||||
|
||||
|
||||
@base_router.get("/agent/tasks/{task_id}", tags=["agent"], response_model=Task)
|
||||
@@ -185,11 +230,21 @@ async def get_agent_task(request: Request, task_id: str) -> Task:
|
||||
}
|
||||
"""
|
||||
agent = request["agent"]
|
||||
task = await agent.get_task(task_id)
|
||||
if task:
|
||||
return task
|
||||
else:
|
||||
return Response(content={"error": "Task not found"}, status_code=404)
|
||||
try:
|
||||
task = await agent.get_task(task_id)
|
||||
return Response(content=task.json(), status_code=200)
|
||||
except NotFoundError:
|
||||
return Response(
|
||||
content=json.dumps({"error": "Task not found"}),
|
||||
status_code=404,
|
||||
media_type="application/json",
|
||||
)
|
||||
except Exception:
|
||||
return Response(
|
||||
content=json.dumps({"error": "Internal server error"}),
|
||||
status_code=500,
|
||||
media_type="application/json",
|
||||
)
|
||||
|
||||
|
||||
@base_router.get(
|
||||
@@ -236,7 +291,21 @@ async def list_agent_task_steps(
|
||||
}
|
||||
"""
|
||||
agent = request["agent"]
|
||||
return await agent.list_steps(task_id, page, page_size)
|
||||
try:
|
||||
steps = await agent.list_steps(task_id, page, page_size)
|
||||
return Response(content=steps.json(), status_code=200)
|
||||
except NotFoundError:
|
||||
return Response(
|
||||
content=json.dumps({"error": "Task not found"}),
|
||||
status_code=404,
|
||||
media_type="application/json",
|
||||
)
|
||||
except Exception:
|
||||
return Response(
|
||||
content=json.dumps({"error": "Internal server error"}),
|
||||
status_code=500,
|
||||
media_type="application/json",
|
||||
)
|
||||
|
||||
|
||||
@base_router.post("/agent/tasks/{task_id}/steps", tags=["agent"], response_model=Step)
|
||||
@@ -284,7 +353,22 @@ async def execute_agent_task_step(
|
||||
}
|
||||
"""
|
||||
agent = request["agent"]
|
||||
return await agent.create_and_execute_step(task_id, step)
|
||||
try:
|
||||
step = await agent.create_and_execute_step(task_id, step)
|
||||
return Response(content=step.json(), status_code=200)
|
||||
except NotFoundError:
|
||||
return Response(
|
||||
content=json.dumps({"error": f"Task not found {task_id}"}),
|
||||
status_code=404,
|
||||
media_type="application/json",
|
||||
)
|
||||
except Exception as e:
|
||||
LOG.exception("Error whilst trying to execute a test")
|
||||
return Response(
|
||||
content=json.dumps({"error": "Internal server error"}),
|
||||
status_code=500,
|
||||
media_type="application/json",
|
||||
)
|
||||
|
||||
|
||||
@base_router.get(
|
||||
@@ -315,7 +399,21 @@ async def get_agent_task_step(request: Request, task_id: str, step_id: str) -> S
|
||||
}
|
||||
"""
|
||||
agent = request["agent"]
|
||||
return await agent.get_step(task_id, step_id)
|
||||
try:
|
||||
step = await agent.get_step(task_id, step_id)
|
||||
return Response(content=step.json(), status_code=200)
|
||||
except NotFoundError:
|
||||
return Response(
|
||||
content=json.dumps({"error": "Task not found"}),
|
||||
status_code=404,
|
||||
media_type="application/json",
|
||||
)
|
||||
except Exception:
|
||||
return Response(
|
||||
content=json.dumps({"error": "Internal server error"}),
|
||||
status_code=500,
|
||||
media_type="application/json",
|
||||
)
|
||||
|
||||
|
||||
@base_router.get(
|
||||
@@ -362,7 +460,21 @@ async def list_agent_task_artifacts(
|
||||
}
|
||||
"""
|
||||
agent = request["agent"]
|
||||
return await agent.list_artifacts(task_id, page, page_size)
|
||||
try:
|
||||
artifacts = await agent.list_artifacts(task_id, page, page_size)
|
||||
return Response(content=artifacts.json(), status_code=200)
|
||||
except NotFoundError:
|
||||
return Response(
|
||||
content=json.dumps({"error": "Task not found"}),
|
||||
status_code=404,
|
||||
media_type="application/json",
|
||||
)
|
||||
except Exception:
|
||||
return Response(
|
||||
content=json.dumps({"error": "Internal server error"}),
|
||||
status_code=500,
|
||||
media_type="application/json",
|
||||
)
|
||||
|
||||
|
||||
@base_router.post(
|
||||
@@ -394,7 +506,6 @@ async def upload_agent_task_artifacts(
|
||||
Note:
|
||||
Either `file` or `uri` must be provided. If both are provided, the behavior depends on
|
||||
the agent's implementation. If neither is provided, the function will return an error.
|
||||
|
||||
Example:
|
||||
Request:
|
||||
POST /agent/tasks/50da533e-3904-4401-8a07-c49adf88b5eb/artifacts
|
||||
@@ -409,7 +520,35 @@ async def upload_agent_task_artifacts(
|
||||
}
|
||||
"""
|
||||
agent = request["agent"]
|
||||
return await agent.create_artifact(task_id, file, uri)
|
||||
if file is None and uri is None:
|
||||
return Response(
|
||||
content={"error": "Either file or uri must be specified"}, status_code=404
|
||||
)
|
||||
if file is not None and uri is not None:
|
||||
return Response(
|
||||
content={"error": "Both file and uri cannot be specified at the same time"},
|
||||
status_code=404,
|
||||
)
|
||||
if uri is not None and not uri.startswith(("http://", "https://", "file://")):
|
||||
return Response(
|
||||
content={"error": "URI must start with http, https or file"},
|
||||
status_code=404,
|
||||
)
|
||||
try:
|
||||
artifact = await agent.create_artifact(task_id, file, uri)
|
||||
return Response(content=artifact.json(), status_code=200)
|
||||
except NotFoundError:
|
||||
return Response(
|
||||
content=json.dumps({"error": "Task not found"}),
|
||||
status_code=404,
|
||||
media_type="application/json",
|
||||
)
|
||||
except Exception:
|
||||
return Response(
|
||||
content=json.dumps({"error": "Internal server error"}),
|
||||
status_code=500,
|
||||
media_type="application/json",
|
||||
)
|
||||
|
||||
|
||||
@base_router.get(
|
||||
@@ -438,4 +577,25 @@ async def download_agent_task_artifact(
|
||||
<file_content_of_artifact>
|
||||
"""
|
||||
agent = request["agent"]
|
||||
return await agent.get_artifact(task_id, artifact_id)
|
||||
try:
|
||||
return await agent.get_artifact(task_id, artifact_id)
|
||||
except NotFoundError:
|
||||
return Response(
|
||||
content=json.dumps(
|
||||
{
|
||||
"error": f"Artifact not found - task_id: {task_id}, artifact_id: {artifact_id}"
|
||||
}
|
||||
),
|
||||
status_code=404,
|
||||
media_type="application/json",
|
||||
)
|
||||
except Exception:
|
||||
return Response(
|
||||
content=json.dumps(
|
||||
{
|
||||
"error": f"Internal server error - task_id: {task_id}, artifact_id: {artifact_id}"
|
||||
}
|
||||
),
|
||||
status_code=500,
|
||||
media_type="application/json",
|
||||
)
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
# generated by fastapi-codegen:
|
||||
# filename: ../../openapi.yml
|
||||
# timestamp: 2023-08-24T13:55:59+00:00
|
||||
# filename: ../../postman/schemas/openapi.yaml
|
||||
# timestamp: 2023-08-25T10:36:11+00:00
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
@@ -22,6 +23,18 @@ class TaskInput(BaseModel):
|
||||
|
||||
|
||||
class Artifact(BaseModel):
|
||||
created_at: datetime = Field(
|
||||
...,
|
||||
description="The creation datetime of the task.",
|
||||
example="2023-01-01T00:00:00Z",
|
||||
json_encoders={datetime: lambda v: v.isoformat()},
|
||||
)
|
||||
modified_at: datetime = Field(
|
||||
...,
|
||||
description="The modification datetime of the task.",
|
||||
example="2023-01-01T00:00:00Z",
|
||||
json_encoders={datetime: lambda v: v.isoformat()},
|
||||
)
|
||||
artifact_id: str = Field(
|
||||
...,
|
||||
description="ID of the artifact.",
|
||||
@@ -32,8 +45,8 @@ class Artifact(BaseModel):
|
||||
description="Whether the artifact has been created by the agent.",
|
||||
example=False,
|
||||
)
|
||||
uri: Optional[str] = Field(
|
||||
None,
|
||||
uri: str = Field(
|
||||
...,
|
||||
description="URI of the artifact.",
|
||||
example="file://home/bob/workspace/bucket/main.py",
|
||||
)
|
||||
@@ -50,7 +63,6 @@ class StepOutput(BaseModel):
|
||||
class TaskRequestBody(BaseModel):
|
||||
input: str = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
description="Input prompt for the task.",
|
||||
example="Write the words you receive to the file 'output.txt'.",
|
||||
)
|
||||
@@ -58,6 +70,18 @@ class TaskRequestBody(BaseModel):
|
||||
|
||||
|
||||
class Task(TaskRequestBody):
|
||||
created_at: datetime = Field(
|
||||
...,
|
||||
description="The creation datetime of the task.",
|
||||
example="2023-01-01T00:00:00Z",
|
||||
json_encoders={datetime: lambda v: v.isoformat()},
|
||||
)
|
||||
modified_at: datetime = Field(
|
||||
...,
|
||||
description="The modification datetime of the task.",
|
||||
example="2023-01-01T00:00:00Z",
|
||||
json_encoders={datetime: lambda v: v.isoformat()},
|
||||
)
|
||||
task_id: str = Field(
|
||||
...,
|
||||
description="The ID of the task.",
|
||||
@@ -74,6 +98,9 @@ class Task(TaskRequestBody):
|
||||
|
||||
|
||||
class StepRequestBody(BaseModel):
|
||||
name: Optional[str] = Field(
|
||||
None, description="The name of the task step.", example="Write to file"
|
||||
)
|
||||
input: str = Field(
|
||||
..., description="Input prompt for the step.", example="Washington"
|
||||
)
|
||||
@@ -87,6 +114,18 @@ class Status(Enum):
|
||||
|
||||
|
||||
class Step(StepRequestBody):
|
||||
created_at: datetime = Field(
|
||||
...,
|
||||
description="The creation datetime of the task.",
|
||||
example="2023-01-01T00:00:00Z",
|
||||
json_encoders={datetime: lambda v: v.isoformat()},
|
||||
)
|
||||
modified_at: datetime = Field(
|
||||
...,
|
||||
description="The modification datetime of the task.",
|
||||
example="2023-01-01T00:00:00Z",
|
||||
json_encoders={datetime: lambda v: v.isoformat()},
|
||||
)
|
||||
task_id: str = Field(
|
||||
...,
|
||||
description="The ID of the task this step belongs to.",
|
||||
@@ -112,8 +151,8 @@ class Step(StepRequestBody):
|
||||
artifacts: Optional[List[Artifact]] = Field(
|
||||
[], description="A list of artifacts that the step has produced."
|
||||
)
|
||||
is_last: Optional[bool] = Field(
|
||||
False, description="Whether this is the last step in the task.", example=True
|
||||
is_last: bool = Field(
|
||||
..., description="Whether this is the last step in the task.", example=True
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user