Files
Auto-GPT/autogpt/agent.py
merwanehamadi d3e86622b7 Fix Forge (#22)
Signed-off-by: Merwane Hamadi <merwanehamadi@gmail.com>
2023-08-26 08:20:12 -07:00

239 lines
7.6 KiB
Python

import asyncio
import os
from fastapi import APIRouter, FastAPI, Response, UploadFile
from fastapi.responses import FileResponse
from hypercorn.asyncio import serve
from hypercorn.config import Config
from prometheus_fastapi_instrumentator import Instrumentator
from .db import AgentDB
from .errors import NotFoundError
from .forge_log import CustomLogger
from .middlewares import AgentMiddleware
from .routes.agent_protocol import base_router
from .schema import *
from .tracing import setup_tracing
from .utils import run
from .workspace import Workspace, load_from_uri
# Module-level logger; Agent.start uses it to announce the address being served.
LOG = CustomLogger(__name__)
class Agent:
    """
    Agent that exposes tasks, steps and artifacts over an HTTP API.

    Records are read and written through ``self.db``; ``self.workspace`` is
    the storage handle used when artifact contents are saved.
    """

    def __init__(self, database: AgentDB, workspace: Workspace):
        """
        Keep references to the database and workspace used by the handlers.
        """
        self.db = database
        self.workspace = workspace

    def start(self, port: int = 8000, router: APIRouter = base_router):
        """
        Start the agent server.

        Builds the FastAPI application, attaches metrics, routes, middleware
        and tracing, then blocks while Hypercorn serves it on ``0.0.0.0:port``.
        """
        config = Config()
        app = FastAPI(
            title="Auto-GPT Forge",
            description="Modified version of The Agent Protocol.",
            version="v0.4",
        )

        # Add Prometheus metrics to the agent
        # https://github.com/trallnag/prometheus-fastapi-instrumentator
        instrumentator = Instrumentator().instrument(app)

        @app.on_event("startup")
        async def _startup():
            instrumentator.expose(app)

        app.include_router(router)
        app.add_middleware(AgentMiddleware, agent=self)
        setup_tracing(app)

        config.loglevel = "ERROR"
        # The bind address is set exactly once; an earlier "localhost" value
        # used to be assigned and then immediately overwritten by this one.
        config.bind = [f"0.0.0.0:{port}"]

        LOG.info(f"Agent server starting on http://{config.bind[0]}")
        asyncio.run(serve(app, config))

    async def create_task(self, task_request: TaskRequestBody) -> Task:
        """
        Create a task for the agent.

        Errors raised by the database layer propagate to the caller.
        """
        return await self.db.create_task(
            input=task_request.input,
            additional_input=task_request.additional_input,
        )

    async def list_tasks(self, page: int = 1, pageSize: int = 10) -> TaskListResponse:
        """
        List all tasks that the agent has created.

        ``page`` and ``pageSize`` are forwarded positionally to the database.
        """
        tasks, pagination = await self.db.list_tasks(page, pageSize)
        return TaskListResponse(tasks=tasks, pagination=pagination)

    async def get_task(self, task_id: str) -> Task:
        """
        Get a task by ID.

        Errors raised by the database layer propagate to the caller.
        """
        return await self.db.get_task(task_id)

    async def list_steps(
        self, task_id: str, page: int = 1, pageSize: int = 10
    ) -> TaskStepsListResponse:
        """
        List the IDs of all steps that the task has created.
        """
        steps, pagination = await self.db.list_steps(task_id, page, pageSize)
        return TaskStepsListResponse(steps=steps, pagination=pagination)

    async def create_and_execute_step(
        self, task_id: str, step_request: StepRequestBody
    ) -> Step:
        """
        Create a step for the task.

        Any input other than ``"y"`` creates a new step, runs it through
        ``run`` and records each returned artifact on the step. The input
        ``"y"`` instead picks the latest step that is not completed, or
        creates and persists a completed placeholder step when there is none.
        """
        if step_request.input != "y":
            # NOTE(review): the whole request body is passed as ``input`` here
            # while the placeholder branch below passes a plain string --
            # confirm which form the database layer expects.
            step = await self.db.create_step(
                task_id=task_id,
                input=step_request,
                additional_input=step_request.additional_input,
            )
            # utils.run
            artifacts = run(step.input)
            for artifact in artifacts:
                art = await self.db.create_artifact(
                    task_id=step.task_id,
                    file_name=artifact["file_name"],
                    uri=artifact["uri"],
                    agent_created=True,
                    step_id=step.step_id,
                )
                assert isinstance(
                    art, Artifact
                ), f"Artifact not instance of Artifact {type(art)}"
                step.artifacts.append(art)
            step.status = "completed"
        else:
            steps, _ = await self.db.list_steps(task_id, page=1, per_page=100)
            # Find the latest step that has not been completed
            step = next((s for s in reversed(steps) if s.status != "completed"), None)
            if step is None:
                # If all steps have been completed, create a new placeholder step
                step = await self.db.create_step(
                    task_id=task_id,
                    input="y",
                    additional_input={},
                )
                step.status = "completed"
                step.is_last = True
                step.output = "No more steps to run."
                step = await self.db.update_step(step)

        # Expose the status as its plain value when it is a Status enum member.
        if isinstance(step.status, Status):
            step.status = step.status.value
        # NOTE(review): this also replaces the placeholder's
        # "No more steps to run." output on the returned object -- confirm
        # that is intended.
        step.output = "Done some work"
        return step

    async def get_step(self, task_id: str, step_id: str) -> Step:
        """
        Get a step by ID.
        """
        return await self.db.get_step(task_id, step_id)

    async def list_artifacts(
        self, task_id: str, page: int = 1, pageSize: int = 10
    ) -> Response:
        """
        List the artifacts that the task has created.

        The listing is serialised to JSON and returned as a ``Response`` with
        the ``application/json`` media type.
        """
        artifacts, pagination = await self.db.list_artifacts(task_id, page, pageSize)
        listing = TaskArtifactsListResponse(artifacts=artifacts, pagination=pagination)
        return Response(content=listing.json(), media_type="application/json")

    async def create_artifact(
        self,
        task_id: str,
        file: UploadFile | None = None,
        uri: str | None = None,
    ) -> Artifact:
        """
        Create an artifact for the task.

        The contents come from ``uri`` when it is given, otherwise from the
        uploaded ``file``. They are stored under ``<task_id>/<file_name>`` and
        an artifact record is then created in the database.

        Raises:
            ValueError: if neither ``file`` nor ``uri`` is provided.
        """
        # Imported locally because the module's import block does not
        # explicitly provide uuid4.
        from uuid import uuid4

        if uri:
            data = await load_from_uri(uri, task_id)
            # A URI ending in "/" has an empty last segment; fall back to a
            # generated name, as the upload branch does.
            file_name = uri.split("/")[-1] or str(uuid4())
        elif file is not None:
            file_name = file.filename or str(uuid4())
            # Read in 1 MiB chunks and join once rather than growing a bytes
            # object with repeated concatenation.
            chunks = []
            while contents := file.file.read(1024 * 1024):
                chunks.append(contents)
            data = b"".join(chunks)
        else:
            raise ValueError("Either 'file' or 'uri' must be provided")

        file_path = os.path.join(task_id, file_name)
        # NOTE(review): Agent defines no ``write`` method of its own, so the
        # bytes are handed to the workspace -- confirm that Workspace.write
        # accepts (path, data).
        self.workspace.write(file_path, data)
        # Same database call and keywords as create_and_execute_step uses,
        # minus ``step_id`` -- presumably optional there; verify.
        artifact = await self.db.create_artifact(
            task_id=task_id,
            file_name=file_name,
            uri=f"file://{file_path}",
            agent_created=False,
        )
        return artifact

    async def get_artifact(self, task_id: str, artifact_id: str) -> FileResponse:
        """
        Get an artifact by ID.

        Looks the artifact up in the database, loads its contents from the
        stored URI, writes them to a local file named after the artifact and
        returns that file as a ``FileResponse``. Lookup and I/O errors
        propagate to the caller.
        """
        artifact = await self.db.get_artifact(artifact_id)
        # ``load_from_uri`` is the module-level helper imported from
        # ``.workspace``; Agent has no method of that name.
        # NOTE(review): the second argument is the artifact id here but the
        # task id in create_artifact -- confirm which one the helper expects.
        retrieved_artifact = await load_from_uri(artifact.uri, artifact_id)
        path = artifact.file_name
        with open(path, "wb") as f:
            f.write(retrieved_artifact)
        return FileResponse(
            # Note: mimetype is guessed in the FileResponse constructor
            path=path,
            filename=artifact.file_name,
        )