Updated Artifact Handling to schema 0.4 (#23)

This commit is contained in:
Swifty
2023-08-30 11:36:57 +02:00
committed by GitHub
parent 13c53b650d
commit 77a726dd79
10 changed files with 187 additions and 141 deletions

View File

@@ -27,7 +27,7 @@ if __name__ == "__main__":
workspace = LocalWorkspace(os.getenv("AGENT_WORKSPACE"))
port = os.getenv("PORT")
database = autogpt.sdk.db.AgentDB(database_name, debug_enabled=False)
database = autogpt.sdk.db.AgentDB(database_name, debug_enabled=True)
agent = autogpt.agent.AutoGPTAgent(database=database, workspace=workspace)
agent.start(port=port, router=router)

View File

@@ -1,5 +1,6 @@
import asyncio
import os
from uuid import uuid4
from fastapi import APIRouter, FastAPI, Response, UploadFile
from fastapi.responses import FileResponse
@@ -15,7 +16,7 @@ from .routes.agent_protocol import base_router
from .schema import *
from .tracing import setup_tracing
from .utils import run
from .workspace import Workspace, load_from_uri
from .workspace import Workspace
LOG = CustomLogger(__name__)
@@ -178,41 +179,33 @@ class Agent:
raise
async def create_artifact(
self,
task_id: str,
file: UploadFile | None = None,
uri: str | None = None,
self, task_id: str, file: UploadFile, relative_path: str
) -> Artifact:
"""
Create an artifact for the task.
"""
data = None
if not uri:
file_name = file.filename or str(uuid4())
try:
data = b""
while contents := file.file.read(1024 * 1024):
data += contents
except Exception as e:
raise
else:
try:
data = await load_from_uri(uri, task_id)
file_name = uri.split("/")[-1]
except Exception as e:
raise
file_name = file.filename or str(uuid4())
try:
data = b""
while contents := file.file.read(1024 * 1024):
data += contents
# Check if relative path ends with filename
if relative_path.endswith(file_name):
file_path = relative_path
else:
file_path = os.path.join(relative_path, file_name)
file_path = os.path.join(task_id / file_name)
self.write(file_path, data)
self.db.save_artifact(task_id, artifact)
artifact = await self.create_artifact(
task_id=task_id,
file_name=file_name,
uri=f"file://{file_path}",
agent_created=False,
)
self.workspace.write(task_id, file_path, data)
artifact = await self.db.create_artifact(
task_id=task_id,
file_name=file_name,
relative_path=relative_path,
agent_created=False,
)
except Exception as e:
raise
return artifact
async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact:
@@ -221,7 +214,8 @@ class Agent:
"""
try:
artifact = await self.db.get_artifact(artifact_id)
retrieved_artifact = await self.load_from_uri(artifact.uri, artifact_id)
file_path = os.path.join(artifact.relative_path, artifact.file_name)
retrieved_artifact = self.workspace.read(task_id=task_id, path=file_path)
path = artifact.file_name
with open(path, "wb") as f:
f.write(retrieved_artifact)

View File

@@ -23,7 +23,7 @@ from sqlalchemy.orm import DeclarativeBase, joinedload, relationship, sessionmak
from .errors import NotFoundError
from .forge_log import CustomLogger
from .schema import Artifact, Pagination, Status, Step, Task, TaskInput
from .schema import Artifact, Pagination, Status, Step, StepRequestBody, Task, TaskInput
LOG = CustomLogger(__name__)
@@ -72,7 +72,7 @@ class ArtifactModel(Base):
step_id = Column(String, ForeignKey("steps.step_id"))
agent_created = Column(Boolean, default=False)
file_name = Column(String)
uri = Column(String)
relative_path = Column(String)
created_at = Column(DateTime, default=datetime.datetime.utcnow)
modified_at = Column(
DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow
@@ -123,7 +123,8 @@ def convert_to_artifact(artifact_model: ArtifactModel) -> Artifact:
created_at=artifact_model.created_at,
modified_at=artifact_model.modified_at,
agent_created=artifact_model.agent_created,
uri=artifact_model.uri,
relative_path=artifact_model.relative_path,
file_name=artifact_model.file_name,
)
@@ -149,7 +150,9 @@ class AgentDB:
new_task = TaskModel(
task_id=str(uuid.uuid4()),
input=input,
additional_input=additional_input,
additional_input=additional_input.json()
if additional_input
else {},
)
session.add(new_task)
session.commit()
@@ -169,7 +172,7 @@ class AgentDB:
async def create_step(
self,
task_id: str,
input: str,
input: StepRequestBody,
is_last: bool = False,
additional_input: Optional[Dict[str, Any]] = {},
) -> Step:
@@ -180,7 +183,7 @@ class AgentDB:
new_step = StepModel(
task_id=task_id,
step_id=str(uuid.uuid4()),
name=input.name,
name=input.input,
input=input.input,
status="created",
is_last=is_last,
@@ -205,7 +208,7 @@ class AgentDB:
self,
task_id: str,
file_name: str,
uri: str,
relative_path: str,
agent_created: bool = False,
step_id: str | None = None,
) -> Artifact:
@@ -215,12 +218,14 @@ class AgentDB:
with self.Session() as session:
if (
existing_artifact := session.query(ArtifactModel)
.filter_by(uri=uri)
.filter_by(relative_path=relative_path)
.first()
):
session.close()
if self.debug_enabled:
LOG.debug(f"Artifact already exists with uri: {uri}")
LOG.debug(
f"Artifact already exists with relative_path: {relative_path}"
)
return convert_to_artifact(existing_artifact)
new_artifact = ArtifactModel(
@@ -229,7 +234,7 @@ class AgentDB:
step_id=step_id,
agent_created=agent_created,
file_name=file_name,
uri=uri,
relative_path=relative_path,
)
session.add(new_artifact)
session.commit()

View File

@@ -54,7 +54,6 @@ async def test_task_schema():
Artifact(
artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
agent_created=True,
uri="file:///path/to/artifact",
file_name="main.py",
relative_path="python/code/",
created_at=now,
@@ -88,7 +87,6 @@ async def test_step_schema():
created_at=now,
modified_at=now,
agent_created=True,
uri="file:///path/to/artifact",
)
],
is_last=False,
@@ -119,8 +117,9 @@ async def test_convert_to_task():
artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
created_at=now,
modified_at=now,
uri="file:///path/to/main.py",
relative_path="file:///path/to/main.py",
agent_created=True,
file_name="main.py",
)
],
)
@@ -147,8 +146,9 @@ async def test_convert_to_step():
artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
created_at=now,
modified_at=now,
uri="file:///path/to/main.py",
relative_path="file:///path/to/main.py",
agent_created=True,
file_name="main.py",
)
],
is_last=False,
@@ -170,12 +170,13 @@ async def test_convert_to_artifact():
artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
created_at=now,
modified_at=now,
uri="file:///path/to/main.py",
relative_path="file:///path/to/main.py",
agent_created=True,
file_name="main.py",
)
artifact = convert_to_artifact(artifact_model)
assert artifact.artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"
assert artifact.uri == "file:///path/to/main.py"
assert artifact.relative_path == "file:///path/to/main.py"
assert artifact.agent_created == True
@@ -210,25 +211,27 @@ async def test_get_task_not_found():
os.remove(db_name.split("///")[1])
@pytest.mark.skip
@pytest.mark.asyncio
async def test_create_and_get_step():
db_name = "sqlite:///test_db.sqlite3"
agent_db = AgentDB(db_name)
task = await agent_db.create_task("task_input")
step = await agent_db.create_step(task.task_id, "step_name")
step_input = StepInput(type="python/code")
request = StepRequestBody(input="test_input debug", additional_input=step_input)
step = await agent_db.create_step(task.task_id, request)
step = await agent_db.get_step(task.task_id, step.step_id)
assert step.name == "step_name"
assert step.input == "test_input debug"
os.remove(db_name.split("///")[1])
@pytest.mark.skip
@pytest.mark.asyncio
async def test_updating_step():
db_name = "sqlite:///test_db.sqlite3"
agent_db = AgentDB(db_name)
created_task = await agent_db.create_task("task_input")
created_step = await agent_db.create_step(created_task.task_id, "step_name")
step_input = StepInput(type="python/code")
request = StepRequestBody(input="test_input debug", additional_input=step_input)
created_step = await agent_db.create_step(created_task.task_id, request)
await agent_db.update_step(created_task.task_id, created_step.step_id, "completed")
step = await agent_db.get_step(created_task.task_id, created_step.step_id)
@@ -245,7 +248,6 @@ async def test_get_step_not_found():
os.remove(db_name.split("///")[1])
@pytest.mark.skip
@pytest.mark.asyncio
async def test_get_artifact():
db_name = "sqlite:///test_db.sqlite3"
@@ -253,13 +255,16 @@ async def test_get_artifact():
# Given: A task and its corresponding artifact
task = await db.create_task("test_input debug")
step = await db.create_step(task.task_id, "step_name")
step_input = StepInput(type="python/code")
requst = StepRequestBody(input="test_input debug", additional_input=step_input)
step = await db.create_step(task.task_id, requst)
# Create an artifact
artifact = await db.create_artifact(
task_id=task.task_id,
file_name="test_get_artifact_sample_file.txt",
uri="file:///path/to/test_get_artifact_sample_file.txt",
relative_path="file:///path/to/test_get_artifact_sample_file.txt",
agent_created=True,
step_id=step.step_id,
)
@@ -269,7 +274,10 @@ async def test_get_artifact():
# Then: The fetched artifact matches the original
assert fetched_artifact.artifact_id == artifact.artifact_id
assert fetched_artifact.uri == "file:///path/to/test_get_artifact_sample_file.txt"
assert (
fetched_artifact.relative_path
== "file:///path/to/test_get_artifact_sample_file.txt"
)
os.remove(db_name.split("///")[1])
@@ -293,16 +301,19 @@ async def test_list_tasks():
os.remove(db_name.split("///")[1])
@pytest.mark.skip
@pytest.mark.asyncio
async def test_list_steps():
db_name = "sqlite:///test_db.sqlite3"
db = AgentDB(db_name)
step_input = StepInput(type="python/code")
requst = StepRequestBody(input="test_input debug", additional_input=step_input)
# Given: A task and multiple steps for that task
task = await db.create_task("test_input")
step1 = await db.create_step(task.task_id, "step_1")
step2 = await db.create_step(task.task_id, "step_2")
step1 = await db.create_step(task.task_id, requst)
requst = StepRequestBody(input="step two", additional_input=step_input)
step2 = await db.create_step(task.task_id, requst)
# When: All steps for the task are fetched
fetched_steps, pagination = await db.list_steps(task.task_id)

View File

@@ -185,7 +185,7 @@ logging_config: dict = dict(
},
root={
"handlers": ["h"],
"level": logging.WARNING,
"level": logging.DEBUG,
},
loggers={
"autogpt": {

View File

@@ -107,13 +107,8 @@ async def create_agent_task(request: Request, task_request: TaskRequestBody) ->
status_code=200,
media_type="application/json",
)
except NotFoundError:
return Response(
content=json.dumps({"error": "Task not found"}),
status_code=404,
media_type="application/json",
)
except Exception:
LOG.exception(f"Error whilst trying to create a task: {task_request}")
return Response(
content=json.dumps({"error": "Internal server error"}),
status_code=500,
@@ -171,12 +166,14 @@ async def list_agent_tasks(
media_type="application/json",
)
except NotFoundError:
LOG.exception("Error whilst trying to list tasks")
return Response(
content=json.dumps({"error": "Task not found"}),
content=json.dumps({"error": "Tasks not found"}),
status_code=404,
media_type="application/json",
)
except Exception:
LOG.exception("Error whilst trying to list tasks")
return Response(
content=json.dumps({"error": "Internal server error"}),
status_code=500,
@@ -246,12 +243,14 @@ async def get_agent_task(request: Request, task_id: str) -> Task:
media_type="application/json",
)
except NotFoundError:
LOG.exception(f"Error whilst trying to get task: {task_id}")
return Response(
content=json.dumps({"error": "Task not found"}),
status_code=404,
media_type="application/json",
)
except Exception:
LOG.exception(f"Error whilst trying to get task: {task_id}")
return Response(
content=json.dumps({"error": "Internal server error"}),
status_code=500,
@@ -311,12 +310,14 @@ async def list_agent_task_steps(
media_type="application/json",
)
except NotFoundError:
LOG.exception("Error whilst trying to list steps")
return Response(
content=json.dumps({"error": "Task not found"}),
content=json.dumps({"error": "Steps not found"}),
status_code=404,
media_type="application/json",
)
except Exception:
LOG.exception("Error whilst trying to list steps")
return Response(
content=json.dumps({"error": "Internal server error"}),
status_code=500,
@@ -377,13 +378,14 @@ async def execute_agent_task_step(
media_type="application/json",
)
except NotFoundError:
LOG.exception(f"Error whilst trying to execute a task step: {task_id}")
return Response(
content=json.dumps({"error": f"Task not found {task_id}"}),
status_code=404,
media_type="application/json",
)
except Exception as e:
LOG.exception("Error whilst trying to execute a test")
LOG.exception(f"Error whilst trying to execute a task step: {task_id}")
return Response(
content=json.dumps({"error": "Internal server error"}),
status_code=500,
@@ -423,12 +425,14 @@ async def get_agent_task_step(request: Request, task_id: str, step_id: str) -> S
step = await agent.get_step(task_id, step_id)
return Response(content=step.json(), status_code=200)
except NotFoundError:
LOG.exception(f"Error whilst trying to get step: {step_id}")
return Response(
content=json.dumps({"error": "Task not found"}),
content=json.dumps({"error": "Step not found"}),
status_code=404,
media_type="application/json",
)
except Exception:
LOG.exception(f"Error whilst trying to get step: {step_id}")
return Response(
content=json.dumps({"error": "Internal server error"}),
status_code=500,
@@ -484,12 +488,14 @@ async def list_agent_task_artifacts(
artifacts = await agent.list_artifacts(task_id, page, page_size)
return artifacts
except NotFoundError:
LOG.exception("Error whilst trying to list artifacts")
return Response(
content=json.dumps({"error": "Task not found"}),
content=json.dumps({"error": "Artifacts not found for task_id"}),
status_code=404,
media_type="application/json",
)
except Exception:
LOG.exception("Error whilst trying to list artifacts")
return Response(
content=json.dumps({"error": "Internal server error"}),
status_code=500,
@@ -502,36 +508,25 @@ async def list_agent_task_artifacts(
)
@tracing("Uploading task artifact")
async def upload_agent_task_artifacts(
request: Request,
task_id: str,
file: UploadFile | None = None,
uri: str | None = None,
request: Request, task_id: str, file: UploadFile, relative_path: str
) -> Artifact:
"""
Uploads an artifact for a specific task using either a provided file or a URI.
At least one of the parameters, `file` or `uri`, must be specified. The `uri` can point to
cloud storage resources such as S3, GCS, etc., or to other resources like FTP or HTTP.
To check the supported URI types for the agent, use the `/agent/artifacts/uris` endpoint.
Uploads an artifact for a specific task using a provided file.
Args:
request (Request): FastAPI request object.
task_id (str): The ID of the task.
file (UploadFile, optional): The uploaded file. Defaults to None.
uri (str, optional): The URI pointing to the resource. Defaults to None.
artifact_upload (ArtifactUpload): The uploaded file and its relative path.
Returns:
Artifact: Details of the uploaded artifact.
Note:
Either `file` or `uri` must be provided. If both are provided, the behavior depends on
the agent's implementation. If neither is provided, the function will return an error.
The `file` must be provided. If it is not provided, the function will return an error.
Example:
Request:
POST /agent/tasks/50da533e-3904-4401-8a07-c49adf88b5eb/artifacts
File: <uploaded_file>
OR
URI: "s3://path/to/artifact"
Response:
{
@@ -540,40 +535,22 @@ async def upload_agent_task_artifacts(
}
"""
agent = request["agent"]
if file is None and uri is None:
if file is None:
return Response(
content=json.dumps({"error": "Either file or uri must be specified"}),
status_code=404,
media_type="application/json",
)
if file is not None and uri is not None:
return Response(
content=json.dumps(
{"error": "Both file and uri cannot be specified at the same time"}
),
status_code=404,
media_type="application/json",
)
if uri is not None and not uri.startswith(("http://", "https://", "file://")):
return Response(
content=json.dumps({"error": "URI must start with http, https or file"}),
content=json.dumps({"error": "File must be specified"}),
status_code=404,
media_type="application/json",
)
try:
artifact = await agent.create_artifact(task_id, file, uri)
artifact = await agent.create_artifact(task_id, file, relative_path)
return Response(
content=artifact.json(),
status_code=200,
media_type="application/json",
)
except NotFoundError:
return Response(
content=json.dumps({"error": "Task not found"}),
status_code=404,
media_type="application/json",
)
except Exception:
LOG.exception(f"Error whilst trying to upload artifact: {task_id}")
return Response(
content=json.dumps({"error": "Internal server error"}),
status_code=500,
@@ -610,6 +587,7 @@ async def download_agent_task_artifact(
try:
return await agent.get_artifact(task_id, artifact_id)
except NotFoundError:
LOG.exception(f"Error whilst trying to download artifact: {task_id}")
return Response(
content=json.dumps(
{
@@ -620,6 +598,7 @@ async def download_agent_task_artifact(
media_type="application/json",
)
except Exception:
LOG.exception(f"Error whilst trying to download artifact: {task_id}")
return Response(
content=json.dumps(
{

View File

@@ -8,9 +8,19 @@ from datetime import datetime
from enum import Enum
from typing import List, Optional
from fastapi import UploadFile
from pydantic import BaseModel, Field
class ArtifactUpload(BaseModel):
file: str = Field(..., description="File to upload.", format="binary")
relative_path: str = Field(
...,
description="Relative path of the artifact in the agent's workspace.",
example="python/code",
)
class Pagination(BaseModel):
total_items: int = Field(..., description="Total number of items.", example=42)
total_pages: int = Field(..., description="Total number of pages.", example=97)
@@ -45,10 +55,15 @@ class Artifact(BaseModel):
description="Whether the artifact has been created by the agent.",
example=False,
)
uri: str = Field(
relative_path: str = Field(
...,
description="URI of the artifact.",
example="file://home/bob/workspace/bucket/main.py",
description="Relative path of the artifact in the agents workspace.",
example="/my_folder/my_other_folder/",
)
file_name: str = Field(
...,
description="Filename of the artifact.",
example="main.py",
)

View File

@@ -42,18 +42,18 @@ class LocalWorkspace(Workspace):
def _resolve_path(self, task_id: str, path: str) -> Path:
abs_path = (self.base_path / task_id / path).resolve()
if not str(abs_path).startswith(str(self.base_path)):
print("Error")
raise ValueError("Directory traversal is not allowed!")
(self.base_path / task_id).mkdir(parents=True, exist_ok=True)
abs_path.parent.mkdir(parents=True, exist_ok=True)
return abs_path
def read(self, task_id: str, path: str) -> bytes:
path = self.base_path / task_id / path
with open(self._resolve_path(task_id, path), "rb") as f:
return f.read()
def write(self, task_id: str, path: str, data: bytes) -> None:
path = self.base_path / task_id / path
with open(self._resolve_path(task_id, path), "wb") as f:
file_path = self._resolve_path(task_id, path)
with open(file_path, "wb") as f:
f.write(data)
def delete(
@@ -77,28 +77,3 @@ class LocalWorkspace(Workspace):
path = self.base_path / task_id / path
base = self._resolve_path(task_id, path)
return [str(p.relative_to(self.base_path / task_id)) for p in base.iterdir()]
async def load_from_uri(self, uri: str, task_id: str, workspace: Workspace) -> bytes:
"""
Load file from given URI and return its bytes.
"""
file_path = None
try:
if uri.startswith("file://"):
file_path = uri.split("file://")[1]
if not workspace.exists(task_id, file_path):
return Response(status_code=500, content="File not found")
return workspace.read(task_id, file_path)
elif uri.startswith("http://") or uri.startswith("https://"):
async with aiohttp.ClientSession() as session:
async with session.get(uri) as resp:
if resp.status != 200:
return Response(
status_code=500, content="Unable to load from URL"
)
return await resp.read()
else:
return Response(status_code=500, content="Loading from unsupported uri")
except Exception as e:
return Response(status_code=500, content=str(e))

44
test_artifacts.py Normal file
View File

@@ -0,0 +1,44 @@
import json
import requests
# Define the base URL of the API
base_url = "http://localhost:8000" # Replace with your actual API base URL
# Create a new task
task_request = {
"input": "Write the words you receive to the file 'output.txt'.",
"additional_input": {"type": "python/code"},
}
response = requests.post(f"{base_url}/agent/tasks", json=task_request)
task = response.json()
print(f"Created task: {task}")
# Upload a file as an artifact for the task
task_id = task["task_id"]
test_file_content = "This is a test file for testing."
relative_path = "./relative/path/to/your/file" # Add your relative path here
file_path = "test_file.txt"
with open(file_path, "w") as f:
f.write(test_file_content)
with open(file_path, "rb") as f:
files = {"file": f}
data = {"relative_path": relative_path}
response = requests.post(
f"{base_url}/agent/tasks/{task_id}/artifacts?relative_path={relative_path}",
files=files,
)
artifact = response.json()
print(f"Uploaded artifact: {response.text}")
# Download the artifact
artifact_id = artifact["artifact_id"]
response = requests.get(f"{base_url}/agent/tasks/{task_id}/artifacts/{artifact_id}")
if response.status_code == 200:
with open("downloaded_file.txt", "wb") as f:
f.write(response.content)
print("Downloaded artifact.")
else:
print(f"Error downloading artifact: {response.content}")

23
test_local_llm.py Normal file
View File

@@ -0,0 +1,23 @@
import openai
openai.api_base = "http://localhost:4891/v1"
openai.api_key = "not needed for a local LLM"
model = "ggml-llama-2-13b-chat.ggmlv3.q4_0.bin"
prompt = "Who is Michael Jordan?"
response = openai.Completion.create(
model=model,
prompt=prompt,
max_tokens=50,
temperature=0.28,
top_p=0.95,
n=1,
echo=True,
stream=False,
)
assert len(response["choices"][0]["text"]) > len(prompt)
print(f"Model: {response['model']}")
print(f"Usage: {response['usage']}")
print(f"Answer: {response['choices'][0]['text']}")