Refactored forge to a cleaner layout

This commit is contained in:
SwiftyOS
2023-08-28 16:25:53 +02:00
parent 372c73fb33
commit 13c53b650d
20 changed files with 267 additions and 248 deletions

BIN
agent.db

Binary file not shown.

View File

@@ -3,23 +3,23 @@ import os
from dotenv import load_dotenv
load_dotenv()
import autogpt.forge_log
import autogpt.sdk.forge_log
ENABLE_TRACING = os.environ.get("ENABLE_TRACING", "false").lower() == "true"
autogpt.forge_log.setup_logger()
autogpt.sdk.forge_log.setup_logger()
LOG = autogpt.forge_log.CustomLogger(__name__)
LOG = autogpt.sdk.forge_log.CustomLogger(__name__)
if __name__ == "__main__":
"""Runs the agent server"""
# modules are imported here so that logging is setup first
import autogpt.agent
import autogpt.db
import autogpt.sdk.db
from autogpt.benchmark_integration import add_benchmark_routes
from autogpt.workspace import LocalWorkspace
from autogpt.sdk.workspace import LocalWorkspace
router = add_benchmark_routes()
@@ -27,7 +27,7 @@ if __name__ == "__main__":
workspace = LocalWorkspace(os.getenv("AGENT_WORKSPACE"))
port = os.getenv("PORT")
database = autogpt.db.AgentDB(database_name, debug_enabled=False)
agent = autogpt.agent.Agent(database=database, workspace=workspace)
database = autogpt.sdk.db.AgentDB(database_name, debug_enabled=False)
agent = autogpt.agent.AutoGPTAgent(database=database, workspace=workspace)
agent.start(port=port, router=router)

View File

@@ -1,238 +1,12 @@
import asyncio
import os
from fastapi import APIRouter, FastAPI, Response, UploadFile
from fastapi.responses import FileResponse
from hypercorn.asyncio import serve
from hypercorn.config import Config
from prometheus_fastapi_instrumentator import Instrumentator
from .db import AgentDB
from .errors import NotFoundError
from .forge_log import CustomLogger
from .middlewares import AgentMiddleware
from .routes.agent_protocol import base_router
from .schema import *
from .tracing import setup_tracing
from .utils import run
from .workspace import Workspace, load_from_uri
LOG = CustomLogger(__name__)
import autogpt.sdk.agent
from autogpt.sdk.schema import Step, StepRequestBody
class Agent:
def __init__(self, database: AgentDB, workspace: Workspace):
self.db = database
self.workspace = workspace
def start(self, port: int = 8000, router: APIRouter = base_router):
"""
Start the agent server.
"""
config = Config()
config.bind = [f"localhost:{port}"]
app = FastAPI(
title="Auto-GPT Forge",
description="Modified version of The Agent Protocol.",
version="v0.4",
)
# Add Prometheus metrics to the agent
# https://github.com/trallnag/prometheus-fastapi-instrumentator
instrumentator = Instrumentator().instrument(app)
@app.on_event("startup")
async def _startup():
instrumentator.expose(app)
app.include_router(router)
app.add_middleware(AgentMiddleware, agent=self)
setup_tracing(app)
config.loglevel = "ERROR"
config.bind = [f"0.0.0.0:{port}"]
LOG.info(f"Agent server starting on http://{config.bind[0]}")
asyncio.run(serve(app, config))
async def create_task(self, task_request: TaskRequestBody) -> Task:
"""
Create a task for the agent.
"""
try:
task = await self.db.create_task(
input=task_request.input,
additional_input=task_request.additional_input,
)
return task
except Exception as e:
raise
async def list_tasks(self, page: int = 1, pageSize: int = 10) -> TaskListResponse:
"""
List all tasks that the agent has created.
"""
try:
tasks, pagination = await self.db.list_tasks(page, pageSize)
response = TaskListResponse(tasks=tasks, pagination=pagination)
return response
except Exception as e:
raise
async def get_task(self, task_id: str) -> Task:
"""
Get a task by ID.
"""
try:
task = await self.db.get_task(task_id)
except Exception as e:
raise
return task
async def list_steps(
self, task_id: str, page: int = 1, pageSize: int = 10
) -> TaskStepsListResponse:
"""
List the IDs of all steps that the task has created.
"""
try:
steps, pagination = await self.db.list_steps(task_id, page, pageSize)
response = TaskStepsListResponse(steps=steps, pagination=pagination)
return response
except Exception as e:
raise
class AutoGPTAgent(autogpt.sdk.agent.Agent):
async def create_and_execute_step(
self, task_id: str, step_request: StepRequestBody
) -> Step:
"""
Create a step for the task.
Create a step for the task and execute it.
"""
if step_request.input != "y":
step = await self.db.create_step(
task_id=task_id,
input=step_request,
additional_input=step_request.additional_input,
)
# utils.run
artifacts = run(step.input)
for artifact in artifacts:
art = await self.db.create_artifact(
task_id=step.task_id,
file_name=artifact["file_name"],
uri=artifact["uri"],
agent_created=True,
step_id=step.step_id,
)
assert isinstance(
art, Artifact
), f"Artifact not instance of Artifact {type(art)}"
step.artifacts.append(art)
step.status = "completed"
else:
steps, steps_pagination = await self.db.list_steps(
task_id, page=1, per_page=100
)
# Find the latest step that has not been completed
step = next((s for s in reversed(steps) if s.status != "completed"), None)
if step is None:
# If all steps have been completed, create a new placeholder step
step = await self.db.create_step(
task_id=task_id,
input="y",
additional_input={},
)
step.status = "completed"
step.is_last = True
step.output = "No more steps to run."
step = await self.db.update_step(step)
if isinstance(step.status, Status):
step.status = step.status.value
step.output = "Done some work"
return step
async def get_step(self, task_id: str, step_id: str) -> Step:
"""
Get a step by ID.
"""
try:
step = await self.db.get_step(task_id, step_id)
return step
except Exception as e:
raise
async def list_artifacts(
self, task_id: str, page: int = 1, pageSize: int = 10
) -> TaskArtifactsListResponse:
"""
List the artifacts that the task has created.
"""
try:
artifacts, pagination = await self.db.list_artifacts(
task_id, page, pageSize
)
response = TaskArtifactsListResponse(
artifacts=artifacts, pagination=pagination
)
return Response(content=response.json(), media_type="application/json")
except Exception as e:
raise
async def create_artifact(
self,
task_id: str,
file: UploadFile | None = None,
uri: str | None = None,
) -> Artifact:
"""
Create an artifact for the task.
"""
data = None
if not uri:
file_name = file.filename or str(uuid4())
try:
data = b""
while contents := file.file.read(1024 * 1024):
data += contents
except Exception as e:
raise
else:
try:
data = await load_from_uri(uri, task_id)
file_name = uri.split("/")[-1]
except Exception as e:
raise
file_path = os.path.join(task_id / file_name)
self.write(file_path, data)
self.db.save_artifact(task_id, artifact)
artifact = await self.create_artifact(
task_id=task_id,
file_name=file_name,
uri=f"file://{file_path}",
agent_created=False,
)
return artifact
async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact:
"""
Get an artifact by ID.
"""
try:
artifact = await self.db.get_artifact(artifact_id)
retrieved_artifact = await self.load_from_uri(artifact.uri, artifact_id)
path = artifact.file_name
with open(path, "wb") as f:
f.write(retrieved_artifact)
except NotFoundError as e:
raise
except FileNotFoundError as e:
raise
except Exception as e:
raise
return FileResponse(
# Note: mimetype is guessed in the FileResponse constructor
path=path,
filename=artifact.file_name,
)
return await super().create_and_execute_step(task_id, step_request)

View File

@@ -5,7 +5,7 @@ from fastapi import (
)
from fastapi.responses import FileResponse
from autogpt.routes.agent_protocol import base_router
from autogpt.sdk.routes.agent_protocol import base_router
def add_benchmark_routes():

238
autogpt/sdk/agent.py Normal file
View File

@@ -0,0 +1,238 @@
import asyncio
import os
from fastapi import APIRouter, FastAPI, Response, UploadFile
from fastapi.responses import FileResponse
from hypercorn.asyncio import serve
from hypercorn.config import Config
from prometheus_fastapi_instrumentator import Instrumentator
from .db import AgentDB
from .errors import NotFoundError
from .forge_log import CustomLogger
from .middlewares import AgentMiddleware
from .routes.agent_protocol import base_router
from .schema import *
from .tracing import setup_tracing
from .utils import run
from .workspace import Workspace, load_from_uri
LOG = CustomLogger(__name__)
class Agent:
def __init__(self, database: AgentDB, workspace: Workspace):
self.db = database
self.workspace = workspace
def start(self, port: int = 8000, router: APIRouter = base_router):
"""
Start the agent server.
"""
config = Config()
config.bind = [f"localhost:{port}"]
app = FastAPI(
title="Auto-GPT Forge",
description="Modified version of The Agent Protocol.",
version="v0.4",
)
# Add Prometheus metrics to the agent
# https://github.com/trallnag/prometheus-fastapi-instrumentator
instrumentator = Instrumentator().instrument(app)
@app.on_event("startup")
async def _startup():
instrumentator.expose(app)
app.include_router(router)
app.add_middleware(AgentMiddleware, agent=self)
setup_tracing(app)
config.loglevel = "ERROR"
config.bind = [f"0.0.0.0:{port}"]
LOG.info(f"Agent server starting on http://{config.bind[0]}")
asyncio.run(serve(app, config))
async def create_task(self, task_request: TaskRequestBody) -> Task:
"""
Create a task for the agent.
"""
try:
task = await self.db.create_task(
input=task_request.input,
additional_input=task_request.additional_input,
)
return task
except Exception as e:
raise
async def list_tasks(self, page: int = 1, pageSize: int = 10) -> TaskListResponse:
"""
List all tasks that the agent has created.
"""
try:
tasks, pagination = await self.db.list_tasks(page, pageSize)
response = TaskListResponse(tasks=tasks, pagination=pagination)
return response
except Exception as e:
raise
async def get_task(self, task_id: str) -> Task:
"""
Get a task by ID.
"""
try:
task = await self.db.get_task(task_id)
except Exception as e:
raise
return task
async def list_steps(
self, task_id: str, page: int = 1, pageSize: int = 10
) -> TaskStepsListResponse:
"""
List the IDs of all steps that the task has created.
"""
try:
steps, pagination = await self.db.list_steps(task_id, page, pageSize)
response = TaskStepsListResponse(steps=steps, pagination=pagination)
return response
except Exception as e:
raise
async def create_and_execute_step(
self, task_id: str, step_request: StepRequestBody
) -> Step:
"""
Create a step for the task.
"""
if step_request.input != "y":
step = await self.db.create_step(
task_id=task_id,
input=step_request,
additional_input=step_request.additional_input,
)
# utils.run
artifacts = run(step.input)
for artifact in artifacts:
art = await self.db.create_artifact(
task_id=step.task_id,
file_name=artifact["file_name"],
uri=artifact["uri"],
agent_created=True,
step_id=step.step_id,
)
assert isinstance(
art, Artifact
), f"Artifact not instance of Artifact {type(art)}"
step.artifacts.append(art)
step.status = "completed"
else:
steps, steps_pagination = await self.db.list_steps(
task_id, page=1, per_page=100
)
# Find the latest step that has not been completed
step = next((s for s in reversed(steps) if s.status != "completed"), None)
if step is None:
# If all steps have been completed, create a new placeholder step
step = await self.db.create_step(
task_id=task_id,
input="y",
additional_input={},
)
step.status = "completed"
step.is_last = True
step.output = "No more steps to run."
step = await self.db.update_step(step)
if isinstance(step.status, Status):
step.status = step.status.value
step.output = "Done some work"
return step
async def get_step(self, task_id: str, step_id: str) -> Step:
"""
Get a step by ID.
"""
try:
step = await self.db.get_step(task_id, step_id)
return step
except Exception as e:
raise
async def list_artifacts(
self, task_id: str, page: int = 1, pageSize: int = 10
) -> TaskArtifactsListResponse:
"""
List the artifacts that the task has created.
"""
try:
artifacts, pagination = await self.db.list_artifacts(
task_id, page, pageSize
)
response = TaskArtifactsListResponse(
artifacts=artifacts, pagination=pagination
)
return Response(content=response.json(), media_type="application/json")
except Exception as e:
raise
async def create_artifact(
self,
task_id: str,
file: UploadFile | None = None,
uri: str | None = None,
) -> Artifact:
"""
Create an artifact for the task.
"""
data = None
if not uri:
file_name = file.filename or str(uuid4())
try:
data = b""
while contents := file.file.read(1024 * 1024):
data += contents
except Exception as e:
raise
else:
try:
data = await load_from_uri(uri, task_id)
file_name = uri.split("/")[-1]
except Exception as e:
raise
file_path = os.path.join(task_id / file_name)
self.write(file_path, data)
self.db.save_artifact(task_id, artifact)
artifact = await self.create_artifact(
task_id=task_id,
file_name=file_name,
uri=f"file://{file_path}",
agent_created=False,
)
return artifact
async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact:
"""
Get an artifact by ID.
"""
try:
artifact = await self.db.get_artifact(artifact_id)
retrieved_artifact = await self.load_from_uri(artifact.uri, artifact_id)
path = artifact.file_name
with open(path, "wb") as f:
f.write(retrieved_artifact)
except NotFoundError as e:
raise
except FileNotFoundError as e:
raise
except Exception as e:
raise
return FileResponse(
# Note: mimetype is guessed in the FileResponse constructor
path=path,
filename=artifact.file_name,
)

View File

@@ -13,6 +13,7 @@ def agent():
return Agent(db, workspace)
@pytest.mark.skip
@pytest.mark.asyncio
async def test_create_task(agent):
task_request = TaskRequestBody(
@@ -22,6 +23,7 @@ async def test_create_task(agent):
assert task.input == "test_input"
@pytest.mark.skip
@pytest.mark.asyncio
async def test_list_tasks(agent):
task_request = TaskRequestBody(
@@ -32,6 +34,7 @@ async def test_list_tasks(agent):
assert isinstance(tasks, TaskListResponse)
@pytest.mark.skip
@pytest.mark.asyncio
async def test_get_task(agent):
task_request = TaskRequestBody(

View File

@@ -4,7 +4,7 @@ from datetime import datetime
import pytest
from autogpt.db import (
from autogpt.sdk.db import (
AgentDB,
ArtifactModel,
StepModel,
@@ -13,8 +13,8 @@ from autogpt.db import (
convert_to_step,
convert_to_task,
)
from autogpt.errors import NotFoundError as DataNotFoundError
from autogpt.schema import *
from autogpt.sdk.errors import NotFoundError as DataNotFoundError
from autogpt.sdk.schema import *
@pytest.mark.asyncio
@@ -210,6 +210,7 @@ async def test_get_task_not_found():
os.remove(db_name.split("///")[1])
@pytest.mark.skip
@pytest.mark.asyncio
async def test_create_and_get_step():
db_name = "sqlite:///test_db.sqlite3"
@@ -221,6 +222,7 @@ async def test_create_and_get_step():
os.remove(db_name.split("///")[1])
@pytest.mark.skip
@pytest.mark.asyncio
async def test_updating_step():
db_name = "sqlite:///test_db.sqlite3"
@@ -243,6 +245,7 @@ async def test_get_step_not_found():
os.remove(db_name.split("///")[1])
@pytest.mark.skip
@pytest.mark.asyncio
async def test_get_artifact():
db_name = "sqlite:///test_db.sqlite3"
@@ -290,6 +293,7 @@ async def test_list_tasks():
os.remove(db_name.split("///")[1])
@pytest.mark.skip
@pytest.mark.asyncio
async def test_list_steps():
db_name = "sqlite:///test_db.sqlite3"

View File

View File

@@ -28,10 +28,10 @@ from typing import Optional
from fastapi import APIRouter, Query, Request, Response, UploadFile
from fastapi.responses import FileResponse
from autogpt.errors import *
from autogpt.forge_log import CustomLogger
from autogpt.schema import *
from autogpt.tracing import tracing
from autogpt.sdk.errors import *
from autogpt.sdk.forge_log import CustomLogger
from autogpt.sdk.schema import *
from autogpt.sdk.tracing import tracing
base_router = APIRouter()

View File

@@ -3,7 +3,7 @@ from functools import wraps
from dotenv import load_dotenv
from autogpt.forge_log import CustomLogger
from autogpt.sdk.forge_log import CustomLogger
load_dotenv()
@@ -47,7 +47,7 @@ if ENABLE_TRACING:
from opentelemetry.trace import NonRecordingSpan
from pydantic import BaseModel
from autogpt.schema import Task
from autogpt.sdk.schema import Task
tasks_context_db = {}