mirror of
https://github.com/aljazceru/Auto-GPT.git
synced 2026-01-09 09:14:19 +01:00
Updated Artifact Handling to schema 0.4 (#23)
This commit is contained in:
@@ -27,7 +27,7 @@ if __name__ == "__main__":
|
||||
workspace = LocalWorkspace(os.getenv("AGENT_WORKSPACE"))
|
||||
port = os.getenv("PORT")
|
||||
|
||||
database = autogpt.sdk.db.AgentDB(database_name, debug_enabled=False)
|
||||
database = autogpt.sdk.db.AgentDB(database_name, debug_enabled=True)
|
||||
agent = autogpt.agent.AutoGPTAgent(database=database, workspace=workspace)
|
||||
|
||||
agent.start(port=port, router=router)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import os
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import APIRouter, FastAPI, Response, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
@@ -15,7 +16,7 @@ from .routes.agent_protocol import base_router
|
||||
from .schema import *
|
||||
from .tracing import setup_tracing
|
||||
from .utils import run
|
||||
from .workspace import Workspace, load_from_uri
|
||||
from .workspace import Workspace
|
||||
|
||||
LOG = CustomLogger(__name__)
|
||||
|
||||
@@ -178,41 +179,33 @@ class Agent:
|
||||
raise
|
||||
|
||||
async def create_artifact(
|
||||
self,
|
||||
task_id: str,
|
||||
file: UploadFile | None = None,
|
||||
uri: str | None = None,
|
||||
self, task_id: str, file: UploadFile, relative_path: str
|
||||
) -> Artifact:
|
||||
"""
|
||||
Create an artifact for the task.
|
||||
"""
|
||||
data = None
|
||||
if not uri:
|
||||
file_name = file.filename or str(uuid4())
|
||||
try:
|
||||
data = b""
|
||||
while contents := file.file.read(1024 * 1024):
|
||||
data += contents
|
||||
except Exception as e:
|
||||
raise
|
||||
else:
|
||||
try:
|
||||
data = await load_from_uri(uri, task_id)
|
||||
file_name = uri.split("/")[-1]
|
||||
except Exception as e:
|
||||
raise
|
||||
file_name = file.filename or str(uuid4())
|
||||
try:
|
||||
data = b""
|
||||
while contents := file.file.read(1024 * 1024):
|
||||
data += contents
|
||||
# Check if relative path ends with filename
|
||||
if relative_path.endswith(file_name):
|
||||
file_path = relative_path
|
||||
else:
|
||||
file_path = os.path.join(relative_path, file_name)
|
||||
|
||||
file_path = os.path.join(task_id / file_name)
|
||||
self.write(file_path, data)
|
||||
self.db.save_artifact(task_id, artifact)
|
||||
|
||||
artifact = await self.create_artifact(
|
||||
task_id=task_id,
|
||||
file_name=file_name,
|
||||
uri=f"file://{file_path}",
|
||||
agent_created=False,
|
||||
)
|
||||
self.workspace.write(task_id, file_path, data)
|
||||
|
||||
artifact = await self.db.create_artifact(
|
||||
task_id=task_id,
|
||||
file_name=file_name,
|
||||
relative_path=relative_path,
|
||||
agent_created=False,
|
||||
)
|
||||
except Exception as e:
|
||||
raise
|
||||
return artifact
|
||||
|
||||
async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact:
|
||||
@@ -221,7 +214,8 @@ class Agent:
|
||||
"""
|
||||
try:
|
||||
artifact = await self.db.get_artifact(artifact_id)
|
||||
retrieved_artifact = await self.load_from_uri(artifact.uri, artifact_id)
|
||||
file_path = os.path.join(artifact.relative_path, artifact.file_name)
|
||||
retrieved_artifact = self.workspace.read(task_id=task_id, path=file_path)
|
||||
path = artifact.file_name
|
||||
with open(path, "wb") as f:
|
||||
f.write(retrieved_artifact)
|
||||
|
||||
@@ -23,7 +23,7 @@ from sqlalchemy.orm import DeclarativeBase, joinedload, relationship, sessionmak
|
||||
|
||||
from .errors import NotFoundError
|
||||
from .forge_log import CustomLogger
|
||||
from .schema import Artifact, Pagination, Status, Step, Task, TaskInput
|
||||
from .schema import Artifact, Pagination, Status, Step, StepRequestBody, Task, TaskInput
|
||||
|
||||
LOG = CustomLogger(__name__)
|
||||
|
||||
@@ -72,7 +72,7 @@ class ArtifactModel(Base):
|
||||
step_id = Column(String, ForeignKey("steps.step_id"))
|
||||
agent_created = Column(Boolean, default=False)
|
||||
file_name = Column(String)
|
||||
uri = Column(String)
|
||||
relative_path = Column(String)
|
||||
created_at = Column(DateTime, default=datetime.datetime.utcnow)
|
||||
modified_at = Column(
|
||||
DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow
|
||||
@@ -123,7 +123,8 @@ def convert_to_artifact(artifact_model: ArtifactModel) -> Artifact:
|
||||
created_at=artifact_model.created_at,
|
||||
modified_at=artifact_model.modified_at,
|
||||
agent_created=artifact_model.agent_created,
|
||||
uri=artifact_model.uri,
|
||||
relative_path=artifact_model.relative_path,
|
||||
file_name=artifact_model.file_name,
|
||||
)
|
||||
|
||||
|
||||
@@ -149,7 +150,9 @@ class AgentDB:
|
||||
new_task = TaskModel(
|
||||
task_id=str(uuid.uuid4()),
|
||||
input=input,
|
||||
additional_input=additional_input,
|
||||
additional_input=additional_input.json()
|
||||
if additional_input
|
||||
else {},
|
||||
)
|
||||
session.add(new_task)
|
||||
session.commit()
|
||||
@@ -169,7 +172,7 @@ class AgentDB:
|
||||
async def create_step(
|
||||
self,
|
||||
task_id: str,
|
||||
input: str,
|
||||
input: StepRequestBody,
|
||||
is_last: bool = False,
|
||||
additional_input: Optional[Dict[str, Any]] = {},
|
||||
) -> Step:
|
||||
@@ -180,7 +183,7 @@ class AgentDB:
|
||||
new_step = StepModel(
|
||||
task_id=task_id,
|
||||
step_id=str(uuid.uuid4()),
|
||||
name=input.name,
|
||||
name=input.input,
|
||||
input=input.input,
|
||||
status="created",
|
||||
is_last=is_last,
|
||||
@@ -205,7 +208,7 @@ class AgentDB:
|
||||
self,
|
||||
task_id: str,
|
||||
file_name: str,
|
||||
uri: str,
|
||||
relative_path: str,
|
||||
agent_created: bool = False,
|
||||
step_id: str | None = None,
|
||||
) -> Artifact:
|
||||
@@ -215,12 +218,14 @@ class AgentDB:
|
||||
with self.Session() as session:
|
||||
if (
|
||||
existing_artifact := session.query(ArtifactModel)
|
||||
.filter_by(uri=uri)
|
||||
.filter_by(relative_path=relative_path)
|
||||
.first()
|
||||
):
|
||||
session.close()
|
||||
if self.debug_enabled:
|
||||
LOG.debug(f"Artifact already exists with uri: {uri}")
|
||||
LOG.debug(
|
||||
f"Artifact already exists with relative_path: {relative_path}"
|
||||
)
|
||||
return convert_to_artifact(existing_artifact)
|
||||
|
||||
new_artifact = ArtifactModel(
|
||||
@@ -229,7 +234,7 @@ class AgentDB:
|
||||
step_id=step_id,
|
||||
agent_created=agent_created,
|
||||
file_name=file_name,
|
||||
uri=uri,
|
||||
relative_path=relative_path,
|
||||
)
|
||||
session.add(new_artifact)
|
||||
session.commit()
|
||||
|
||||
@@ -54,7 +54,6 @@ async def test_task_schema():
|
||||
Artifact(
|
||||
artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
|
||||
agent_created=True,
|
||||
uri="file:///path/to/artifact",
|
||||
file_name="main.py",
|
||||
relative_path="python/code/",
|
||||
created_at=now,
|
||||
@@ -88,7 +87,6 @@ async def test_step_schema():
|
||||
created_at=now,
|
||||
modified_at=now,
|
||||
agent_created=True,
|
||||
uri="file:///path/to/artifact",
|
||||
)
|
||||
],
|
||||
is_last=False,
|
||||
@@ -119,8 +117,9 @@ async def test_convert_to_task():
|
||||
artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
|
||||
created_at=now,
|
||||
modified_at=now,
|
||||
uri="file:///path/to/main.py",
|
||||
relative_path="file:///path/to/main.py",
|
||||
agent_created=True,
|
||||
file_name="main.py",
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -147,8 +146,9 @@ async def test_convert_to_step():
|
||||
artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
|
||||
created_at=now,
|
||||
modified_at=now,
|
||||
uri="file:///path/to/main.py",
|
||||
relative_path="file:///path/to/main.py",
|
||||
agent_created=True,
|
||||
file_name="main.py",
|
||||
)
|
||||
],
|
||||
is_last=False,
|
||||
@@ -170,12 +170,13 @@ async def test_convert_to_artifact():
|
||||
artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
|
||||
created_at=now,
|
||||
modified_at=now,
|
||||
uri="file:///path/to/main.py",
|
||||
relative_path="file:///path/to/main.py",
|
||||
agent_created=True,
|
||||
file_name="main.py",
|
||||
)
|
||||
artifact = convert_to_artifact(artifact_model)
|
||||
assert artifact.artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"
|
||||
assert artifact.uri == "file:///path/to/main.py"
|
||||
assert artifact.relative_path == "file:///path/to/main.py"
|
||||
assert artifact.agent_created == True
|
||||
|
||||
|
||||
@@ -210,25 +211,27 @@ async def test_get_task_not_found():
|
||||
os.remove(db_name.split("///")[1])
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_and_get_step():
|
||||
db_name = "sqlite:///test_db.sqlite3"
|
||||
agent_db = AgentDB(db_name)
|
||||
task = await agent_db.create_task("task_input")
|
||||
step = await agent_db.create_step(task.task_id, "step_name")
|
||||
step_input = StepInput(type="python/code")
|
||||
request = StepRequestBody(input="test_input debug", additional_input=step_input)
|
||||
step = await agent_db.create_step(task.task_id, request)
|
||||
step = await agent_db.get_step(task.task_id, step.step_id)
|
||||
assert step.name == "step_name"
|
||||
assert step.input == "test_input debug"
|
||||
os.remove(db_name.split("///")[1])
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_updating_step():
|
||||
db_name = "sqlite:///test_db.sqlite3"
|
||||
agent_db = AgentDB(db_name)
|
||||
created_task = await agent_db.create_task("task_input")
|
||||
created_step = await agent_db.create_step(created_task.task_id, "step_name")
|
||||
step_input = StepInput(type="python/code")
|
||||
request = StepRequestBody(input="test_input debug", additional_input=step_input)
|
||||
created_step = await agent_db.create_step(created_task.task_id, request)
|
||||
await agent_db.update_step(created_task.task_id, created_step.step_id, "completed")
|
||||
|
||||
step = await agent_db.get_step(created_task.task_id, created_step.step_id)
|
||||
@@ -245,7 +248,6 @@ async def test_get_step_not_found():
|
||||
os.remove(db_name.split("///")[1])
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_artifact():
|
||||
db_name = "sqlite:///test_db.sqlite3"
|
||||
@@ -253,13 +255,16 @@ async def test_get_artifact():
|
||||
|
||||
# Given: A task and its corresponding artifact
|
||||
task = await db.create_task("test_input debug")
|
||||
step = await db.create_step(task.task_id, "step_name")
|
||||
step_input = StepInput(type="python/code")
|
||||
requst = StepRequestBody(input="test_input debug", additional_input=step_input)
|
||||
|
||||
step = await db.create_step(task.task_id, requst)
|
||||
|
||||
# Create an artifact
|
||||
artifact = await db.create_artifact(
|
||||
task_id=task.task_id,
|
||||
file_name="test_get_artifact_sample_file.txt",
|
||||
uri="file:///path/to/test_get_artifact_sample_file.txt",
|
||||
relative_path="file:///path/to/test_get_artifact_sample_file.txt",
|
||||
agent_created=True,
|
||||
step_id=step.step_id,
|
||||
)
|
||||
@@ -269,7 +274,10 @@ async def test_get_artifact():
|
||||
|
||||
# Then: The fetched artifact matches the original
|
||||
assert fetched_artifact.artifact_id == artifact.artifact_id
|
||||
assert fetched_artifact.uri == "file:///path/to/test_get_artifact_sample_file.txt"
|
||||
assert (
|
||||
fetched_artifact.relative_path
|
||||
== "file:///path/to/test_get_artifact_sample_file.txt"
|
||||
)
|
||||
|
||||
os.remove(db_name.split("///")[1])
|
||||
|
||||
@@ -293,16 +301,19 @@ async def test_list_tasks():
|
||||
os.remove(db_name.split("///")[1])
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_steps():
|
||||
db_name = "sqlite:///test_db.sqlite3"
|
||||
db = AgentDB(db_name)
|
||||
|
||||
step_input = StepInput(type="python/code")
|
||||
requst = StepRequestBody(input="test_input debug", additional_input=step_input)
|
||||
|
||||
# Given: A task and multiple steps for that task
|
||||
task = await db.create_task("test_input")
|
||||
step1 = await db.create_step(task.task_id, "step_1")
|
||||
step2 = await db.create_step(task.task_id, "step_2")
|
||||
step1 = await db.create_step(task.task_id, requst)
|
||||
requst = StepRequestBody(input="step two", additional_input=step_input)
|
||||
step2 = await db.create_step(task.task_id, requst)
|
||||
|
||||
# When: All steps for the task are fetched
|
||||
fetched_steps, pagination = await db.list_steps(task.task_id)
|
||||
|
||||
@@ -185,7 +185,7 @@ logging_config: dict = dict(
|
||||
},
|
||||
root={
|
||||
"handlers": ["h"],
|
||||
"level": logging.WARNING,
|
||||
"level": logging.DEBUG,
|
||||
},
|
||||
loggers={
|
||||
"autogpt": {
|
||||
|
||||
@@ -107,13 +107,8 @@ async def create_agent_task(request: Request, task_request: TaskRequestBody) ->
|
||||
status_code=200,
|
||||
media_type="application/json",
|
||||
)
|
||||
except NotFoundError:
|
||||
return Response(
|
||||
content=json.dumps({"error": "Task not found"}),
|
||||
status_code=404,
|
||||
media_type="application/json",
|
||||
)
|
||||
except Exception:
|
||||
LOG.exception(f"Error whilst trying to create a task: {task_request}")
|
||||
return Response(
|
||||
content=json.dumps({"error": "Internal server error"}),
|
||||
status_code=500,
|
||||
@@ -171,12 +166,14 @@ async def list_agent_tasks(
|
||||
media_type="application/json",
|
||||
)
|
||||
except NotFoundError:
|
||||
LOG.exception("Error whilst trying to list tasks")
|
||||
return Response(
|
||||
content=json.dumps({"error": "Task not found"}),
|
||||
content=json.dumps({"error": "Tasks not found"}),
|
||||
status_code=404,
|
||||
media_type="application/json",
|
||||
)
|
||||
except Exception:
|
||||
LOG.exception("Error whilst trying to list tasks")
|
||||
return Response(
|
||||
content=json.dumps({"error": "Internal server error"}),
|
||||
status_code=500,
|
||||
@@ -246,12 +243,14 @@ async def get_agent_task(request: Request, task_id: str) -> Task:
|
||||
media_type="application/json",
|
||||
)
|
||||
except NotFoundError:
|
||||
LOG.exception(f"Error whilst trying to get task: {task_id}")
|
||||
return Response(
|
||||
content=json.dumps({"error": "Task not found"}),
|
||||
status_code=404,
|
||||
media_type="application/json",
|
||||
)
|
||||
except Exception:
|
||||
LOG.exception(f"Error whilst trying to get task: {task_id}")
|
||||
return Response(
|
||||
content=json.dumps({"error": "Internal server error"}),
|
||||
status_code=500,
|
||||
@@ -311,12 +310,14 @@ async def list_agent_task_steps(
|
||||
media_type="application/json",
|
||||
)
|
||||
except NotFoundError:
|
||||
LOG.exception("Error whilst trying to list steps")
|
||||
return Response(
|
||||
content=json.dumps({"error": "Task not found"}),
|
||||
content=json.dumps({"error": "Steps not found"}),
|
||||
status_code=404,
|
||||
media_type="application/json",
|
||||
)
|
||||
except Exception:
|
||||
LOG.exception("Error whilst trying to list steps")
|
||||
return Response(
|
||||
content=json.dumps({"error": "Internal server error"}),
|
||||
status_code=500,
|
||||
@@ -377,13 +378,14 @@ async def execute_agent_task_step(
|
||||
media_type="application/json",
|
||||
)
|
||||
except NotFoundError:
|
||||
LOG.exception(f"Error whilst trying to execute a task step: {task_id}")
|
||||
return Response(
|
||||
content=json.dumps({"error": f"Task not found {task_id}"}),
|
||||
status_code=404,
|
||||
media_type="application/json",
|
||||
)
|
||||
except Exception as e:
|
||||
LOG.exception("Error whilst trying to execute a test")
|
||||
LOG.exception(f"Error whilst trying to execute a task step: {task_id}")
|
||||
return Response(
|
||||
content=json.dumps({"error": "Internal server error"}),
|
||||
status_code=500,
|
||||
@@ -423,12 +425,14 @@ async def get_agent_task_step(request: Request, task_id: str, step_id: str) -> S
|
||||
step = await agent.get_step(task_id, step_id)
|
||||
return Response(content=step.json(), status_code=200)
|
||||
except NotFoundError:
|
||||
LOG.exception(f"Error whilst trying to get step: {step_id}")
|
||||
return Response(
|
||||
content=json.dumps({"error": "Task not found"}),
|
||||
content=json.dumps({"error": "Step not found"}),
|
||||
status_code=404,
|
||||
media_type="application/json",
|
||||
)
|
||||
except Exception:
|
||||
LOG.exception(f"Error whilst trying to get step: {step_id}")
|
||||
return Response(
|
||||
content=json.dumps({"error": "Internal server error"}),
|
||||
status_code=500,
|
||||
@@ -484,12 +488,14 @@ async def list_agent_task_artifacts(
|
||||
artifacts = await agent.list_artifacts(task_id, page, page_size)
|
||||
return artifacts
|
||||
except NotFoundError:
|
||||
LOG.exception("Error whilst trying to list artifacts")
|
||||
return Response(
|
||||
content=json.dumps({"error": "Task not found"}),
|
||||
content=json.dumps({"error": "Artifacts not found for task_id"}),
|
||||
status_code=404,
|
||||
media_type="application/json",
|
||||
)
|
||||
except Exception:
|
||||
LOG.exception("Error whilst trying to list artifacts")
|
||||
return Response(
|
||||
content=json.dumps({"error": "Internal server error"}),
|
||||
status_code=500,
|
||||
@@ -502,36 +508,25 @@ async def list_agent_task_artifacts(
|
||||
)
|
||||
@tracing("Uploading task artifact")
|
||||
async def upload_agent_task_artifacts(
|
||||
request: Request,
|
||||
task_id: str,
|
||||
file: UploadFile | None = None,
|
||||
uri: str | None = None,
|
||||
request: Request, task_id: str, file: UploadFile, relative_path: str
|
||||
) -> Artifact:
|
||||
"""
|
||||
Uploads an artifact for a specific task using either a provided file or a URI.
|
||||
At least one of the parameters, `file` or `uri`, must be specified. The `uri` can point to
|
||||
cloud storage resources such as S3, GCS, etc., or to other resources like FTP or HTTP.
|
||||
|
||||
To check the supported URI types for the agent, use the `/agent/artifacts/uris` endpoint.
|
||||
Uploads an artifact for a specific task using a provided file.
|
||||
|
||||
Args:
|
||||
request (Request): FastAPI request object.
|
||||
task_id (str): The ID of the task.
|
||||
file (UploadFile, optional): The uploaded file. Defaults to None.
|
||||
uri (str, optional): The URI pointing to the resource. Defaults to None.
|
||||
artifact_upload (ArtifactUpload): The uploaded file and its relative path.
|
||||
|
||||
Returns:
|
||||
Artifact: Details of the uploaded artifact.
|
||||
|
||||
Note:
|
||||
Either `file` or `uri` must be provided. If both are provided, the behavior depends on
|
||||
the agent's implementation. If neither is provided, the function will return an error.
|
||||
The `file` must be provided. If it is not provided, the function will return an error.
|
||||
Example:
|
||||
Request:
|
||||
POST /agent/tasks/50da533e-3904-4401-8a07-c49adf88b5eb/artifacts
|
||||
File: <uploaded_file>
|
||||
OR
|
||||
URI: "s3://path/to/artifact"
|
||||
|
||||
Response:
|
||||
{
|
||||
@@ -540,40 +535,22 @@ async def upload_agent_task_artifacts(
|
||||
}
|
||||
"""
|
||||
agent = request["agent"]
|
||||
if file is None and uri is None:
|
||||
|
||||
if file is None:
|
||||
return Response(
|
||||
content=json.dumps({"error": "Either file or uri must be specified"}),
|
||||
status_code=404,
|
||||
media_type="application/json",
|
||||
)
|
||||
if file is not None and uri is not None:
|
||||
return Response(
|
||||
content=json.dumps(
|
||||
{"error": "Both file and uri cannot be specified at the same time"}
|
||||
),
|
||||
status_code=404,
|
||||
media_type="application/json",
|
||||
)
|
||||
if uri is not None and not uri.startswith(("http://", "https://", "file://")):
|
||||
return Response(
|
||||
content=json.dumps({"error": "URI must start with http, https or file"}),
|
||||
content=json.dumps({"error": "File must be specified"}),
|
||||
status_code=404,
|
||||
media_type="application/json",
|
||||
)
|
||||
try:
|
||||
artifact = await agent.create_artifact(task_id, file, uri)
|
||||
artifact = await agent.create_artifact(task_id, file, relative_path)
|
||||
return Response(
|
||||
content=artifact.json(),
|
||||
status_code=200,
|
||||
media_type="application/json",
|
||||
)
|
||||
except NotFoundError:
|
||||
return Response(
|
||||
content=json.dumps({"error": "Task not found"}),
|
||||
status_code=404,
|
||||
media_type="application/json",
|
||||
)
|
||||
except Exception:
|
||||
LOG.exception(f"Error whilst trying to upload artifact: {task_id}")
|
||||
return Response(
|
||||
content=json.dumps({"error": "Internal server error"}),
|
||||
status_code=500,
|
||||
@@ -610,6 +587,7 @@ async def download_agent_task_artifact(
|
||||
try:
|
||||
return await agent.get_artifact(task_id, artifact_id)
|
||||
except NotFoundError:
|
||||
LOG.exception(f"Error whilst trying to download artifact: {task_id}")
|
||||
return Response(
|
||||
content=json.dumps(
|
||||
{
|
||||
@@ -620,6 +598,7 @@ async def download_agent_task_artifact(
|
||||
media_type="application/json",
|
||||
)
|
||||
except Exception:
|
||||
LOG.exception(f"Error whilst trying to download artifact: {task_id}")
|
||||
return Response(
|
||||
content=json.dumps(
|
||||
{
|
||||
|
||||
@@ -8,9 +8,19 @@ from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import UploadFile
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ArtifactUpload(BaseModel):
|
||||
file: str = Field(..., description="File to upload.", format="binary")
|
||||
relative_path: str = Field(
|
||||
...,
|
||||
description="Relative path of the artifact in the agent's workspace.",
|
||||
example="python/code",
|
||||
)
|
||||
|
||||
|
||||
class Pagination(BaseModel):
|
||||
total_items: int = Field(..., description="Total number of items.", example=42)
|
||||
total_pages: int = Field(..., description="Total number of pages.", example=97)
|
||||
@@ -45,10 +55,15 @@ class Artifact(BaseModel):
|
||||
description="Whether the artifact has been created by the agent.",
|
||||
example=False,
|
||||
)
|
||||
uri: str = Field(
|
||||
relative_path: str = Field(
|
||||
...,
|
||||
description="URI of the artifact.",
|
||||
example="file://home/bob/workspace/bucket/main.py",
|
||||
description="Relative path of the artifact in the agents workspace.",
|
||||
example="/my_folder/my_other_folder/",
|
||||
)
|
||||
file_name: str = Field(
|
||||
...,
|
||||
description="Filename of the artifact.",
|
||||
example="main.py",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -42,18 +42,18 @@ class LocalWorkspace(Workspace):
|
||||
def _resolve_path(self, task_id: str, path: str) -> Path:
|
||||
abs_path = (self.base_path / task_id / path).resolve()
|
||||
if not str(abs_path).startswith(str(self.base_path)):
|
||||
print("Error")
|
||||
raise ValueError("Directory traversal is not allowed!")
|
||||
(self.base_path / task_id).mkdir(parents=True, exist_ok=True)
|
||||
abs_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
return abs_path
|
||||
|
||||
def read(self, task_id: str, path: str) -> bytes:
|
||||
path = self.base_path / task_id / path
|
||||
with open(self._resolve_path(task_id, path), "rb") as f:
|
||||
return f.read()
|
||||
|
||||
def write(self, task_id: str, path: str, data: bytes) -> None:
|
||||
path = self.base_path / task_id / path
|
||||
with open(self._resolve_path(task_id, path), "wb") as f:
|
||||
file_path = self._resolve_path(task_id, path)
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(data)
|
||||
|
||||
def delete(
|
||||
@@ -77,28 +77,3 @@ class LocalWorkspace(Workspace):
|
||||
path = self.base_path / task_id / path
|
||||
base = self._resolve_path(task_id, path)
|
||||
return [str(p.relative_to(self.base_path / task_id)) for p in base.iterdir()]
|
||||
|
||||
|
||||
async def load_from_uri(self, uri: str, task_id: str, workspace: Workspace) -> bytes:
|
||||
"""
|
||||
Load file from given URI and return its bytes.
|
||||
"""
|
||||
file_path = None
|
||||
try:
|
||||
if uri.startswith("file://"):
|
||||
file_path = uri.split("file://")[1]
|
||||
if not workspace.exists(task_id, file_path):
|
||||
return Response(status_code=500, content="File not found")
|
||||
return workspace.read(task_id, file_path)
|
||||
elif uri.startswith("http://") or uri.startswith("https://"):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(uri) as resp:
|
||||
if resp.status != 200:
|
||||
return Response(
|
||||
status_code=500, content="Unable to load from URL"
|
||||
)
|
||||
return await resp.read()
|
||||
else:
|
||||
return Response(status_code=500, content="Loading from unsupported uri")
|
||||
except Exception as e:
|
||||
return Response(status_code=500, content=str(e))
|
||||
|
||||
44
test_artifacts.py
Normal file
44
test_artifacts.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
# Define the base URL of the API
|
||||
base_url = "http://localhost:8000" # Replace with your actual API base URL
|
||||
|
||||
# Create a new task
|
||||
task_request = {
|
||||
"input": "Write the words you receive to the file 'output.txt'.",
|
||||
"additional_input": {"type": "python/code"},
|
||||
}
|
||||
response = requests.post(f"{base_url}/agent/tasks", json=task_request)
|
||||
task = response.json()
|
||||
print(f"Created task: {task}")
|
||||
|
||||
# Upload a file as an artifact for the task
|
||||
task_id = task["task_id"]
|
||||
test_file_content = "This is a test file for testing."
|
||||
relative_path = "./relative/path/to/your/file" # Add your relative path here
|
||||
file_path = "test_file.txt"
|
||||
with open(file_path, "w") as f:
|
||||
f.write(test_file_content)
|
||||
with open(file_path, "rb") as f:
|
||||
files = {"file": f}
|
||||
data = {"relative_path": relative_path}
|
||||
|
||||
response = requests.post(
|
||||
f"{base_url}/agent/tasks/{task_id}/artifacts?relative_path={relative_path}",
|
||||
files=files,
|
||||
)
|
||||
artifact = response.json()
|
||||
|
||||
print(f"Uploaded artifact: {response.text}")
|
||||
|
||||
# Download the artifact
|
||||
artifact_id = artifact["artifact_id"]
|
||||
response = requests.get(f"{base_url}/agent/tasks/{task_id}/artifacts/{artifact_id}")
|
||||
if response.status_code == 200:
|
||||
with open("downloaded_file.txt", "wb") as f:
|
||||
f.write(response.content)
|
||||
print("Downloaded artifact.")
|
||||
else:
|
||||
print(f"Error downloading artifact: {response.content}")
|
||||
23
test_local_llm.py
Normal file
23
test_local_llm.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import openai
|
||||
|
||||
openai.api_base = "http://localhost:4891/v1"
|
||||
|
||||
openai.api_key = "not needed for a local LLM"
|
||||
|
||||
|
||||
model = "ggml-llama-2-13b-chat.ggmlv3.q4_0.bin"
|
||||
prompt = "Who is Michael Jordan?"
|
||||
response = openai.Completion.create(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
max_tokens=50,
|
||||
temperature=0.28,
|
||||
top_p=0.95,
|
||||
n=1,
|
||||
echo=True,
|
||||
stream=False,
|
||||
)
|
||||
assert len(response["choices"][0]["text"]) > len(prompt)
|
||||
print(f"Model: {response['model']}")
|
||||
print(f"Usage: {response['usage']}")
|
||||
print(f"Answer: {response['choices'][0]['text']}")
|
||||
Reference in New Issue
Block a user