Added sqlite database
@@ -22,10 +22,10 @@ repos:
      - id: black
        language_version: python3.10

  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: 'v1.3.0'
    hooks:
      - id: mypy
  # - repo: https://github.com/pre-commit/mirrors-mypy
  #   rev: 'v1.3.0'
  #   hooks:
  #     - id: mypy

  - repo: local
    hooks:
@@ -34,9 +34,9 @@ repos:
        entry: autoflake --in-place --remove-all-unused-imports --recursive --ignore-init-module-imports --ignore-pass-after-docstring autogpt tests
        language: python
        types: [ python ]
      # - id: pytest-check
      #   name: pytest-check
      #   entry: pytest --cov=autogpt --without-integration --without-slow-integration
      #   language: system
      #   pass_filenames: false
      #   always_run: true
      - id: pytest-check
        name: pytest-check
        entry: pytest
        language: system
        pass_filenames: false
        always_run: true
@@ -1,16 +1,6 @@
import asyncio
import sys
import subprocess

from agent_protocol_client import Configuration

import agbenchmark.e2b_boilerplate

configuration = Configuration(host="http://localhost:8915")


async def run_specific_agent(task: str) -> None:
    """Runs the agent for benchmarking"""
    # Start the agent Server
if __name__ == "__main__":
    command = [
        "poetry",
        "run",
@@ -18,19 +8,4 @@ async def run_specific_agent(task: str) -> None:
        "-m",
        "autogpt",
    ]
    agent_process = agbenchmark.e2b_boilerplate.start_agent_server(
        command, "localhost", 8915
    )

    # Send the task to the agent
    await agbenchmark.e2b_boilerplate.task_agent(task)

    agent_process.terminate()


if __name__ == "__main__":
    if len(sys.argv) != 2:
        print("Usage: python script.py <task>")
        sys.exit(1)
    task = sys.argv[-1]
    asyncio.run(run_specific_agent(task))
    subprocess.run(command)
@@ -1,5 +1 @@
{
    "workspace": "agbenchmark/workspace",
    "entry_path": "agbenchmark.benchmarks"
}
{"workspace": "agbenchmark/workspace", "entry_path": "agbenchmark.benchmarks", "api_mode": "True", "host": "http://localhost:8000"}
@@ -1,64 +0,0 @@
import subprocess
import time
import typing

import agent_protocol_client as apc
import requests

configuration = apc.Configuration(host="http://localhost:8915")


def start_agent_server(
    command: typing.List[str], host: str, port: int
) -> subprocess.Popen:
    """Boilerplate code to start the agent server and wait for it to be ready"""
    agent_process = subprocess.Popen(command, text=True)

    print("Agent server started")
    server_ready = False
    attempts = 0
    while not server_ready and attempts < 5:
        try:
            response = requests.get(f"http://{host}:{port}/hb")
            if response.status_code == 200:
                server_ready = True
        except Exception as e:
            print(f"Unable to connect to server: {e}")
        attempts += 1
        time.sleep(0.5)

    if not server_ready:
        agent_process.terminate()
        print("Agent server failed to start")
    return agent_process


async def task_agent(task: str) -> None:
    try:
        async with apc.ApiClient(configuration) as api_client:
            # Create an instance of the API class
            api_instance = apc.AgentApi(api_client)
            task_request_body = apc.TaskRequestBody(input=task)

            print("sending task to agent")
            response = await api_instance.create_agent_task(
                task_request_body=task_request_body
            )
            print("The response of AgentApi->create_agent_task:\n")
            print(response)

            task_id = response.task_id
            i = 1

            while (
                step := await api_instance.execute_agent_task_step(
                    task_id=task_id, step_request_body=apc.StepRequestBody(input=i)
                )
            ) and step.is_last is False:
                print("The response of AgentApi->execute_agent_task_step:\n")
                print(step)
                i += 1

            print("Agent finished its work!")
    except Exception as e:
        print(f"Exception whilst attempting to run agent task: {e}")
@@ -1,16 +0,0 @@
{
    "command": "agbenchmark start",
    "benchmark_git_commit_sha": null,
    "agent_git_commit_sha": "git@github.com:Significant-Gravitas/Auto-GPT-Forge/tree/9094137d25022d896c3fddb337f6593ec0244b8a",
    "completion_time": "2023-08-07-16:10",
    "benchmark_start_time": "2023-08-07-16:10",
    "metrics": {
        "run_time": "0.17 seconds",
        "highest_difficulty": "No successful tests"
    },
    "tests": {},
    "config": {
        "workspace": "agbenchmark/workspace",
        "entry_path": "agbenchmark.benchmarks"
    }
}
@@ -1 +0,0 @@
{}
@@ -1 +0,0 @@
{}
@@ -1,9 +1,19 @@
import os

from agent_protocol import Agent
from dotenv import load_dotenv

import autogpt.agent
import autogpt.db

if __name__ == "__main__":
    """Runs the agent server"""
    load_dotenv()
    autogpt.agent.start_agent(port=8915)
    database_name = os.getenv("DATABASE_NAME")
    port = os.getenv("PORT")
    auto_gpt = autogpt.agent.AutoGPT()

    database = autogpt.db.AgentDB(database_name)
    agent = Agent.setup_agent(auto_gpt.task_handler, auto_gpt.step_handler)
    agent.db = database
    agent.start(port=port)
@@ -1,44 +1,21 @@
from agent_protocol import Agent, Step, Task, router
from fastapi import APIRouter, Response
from agent_protocol import Agent, Step, Task

import autogpt.utils

##################################################
##################################################
# E2b boilerplate code
# We need to add heartbeat endpoint to agent_protocol
# so we can detect when the agent server is ready
e2b_extension_router = APIRouter()

class AutoGPT:
    def __init__(self) -> None:
        pass

@e2b_extension_router.get("/hb")
async def hello():
    return Response("Agent running")
    async def task_handler(self, task: Task) -> None:
        print(f"task: {task.input}")
        autogpt.utils.run(task.input)
        await Agent.db.create_step(task.task_id, task.input)


e2b_extension_router.include_router(router)


def start_agent(port: int):
    Agent.setup_agent(task_handler, step_handler).start(
        port=port, router=e2b_extension_router
    )


###################################################
###################################################


async def task_handler(task: Task) -> None:
    print(f"task: {task.input}")
    autogpt.utils.run(task.input)
    await Agent.db.create_step(task.task_id, task.input)


async def step_handler(step: Step) -> Step:
    print(f"step: {step.input}")
    await Agent.db.create_step(
        step.task_id, f"Nothing to see here.. {step.name}", is_last=True
    )
    step.output = step.input
    return step
    async def step_handler(self, step: Step) -> Step:
        print(f"step: {step.input}")
        await Agent.db.create_step(
            step.task_id, f"Nothing to see here.. {step.name}", is_last=True
        )
        step.output = step.input
        return step
autogpt/db.py (new file, 208 lines)
@@ -0,0 +1,208 @@
"""
This is an example implementation of the Agent Protocol DB for development purposes.
It uses SQLite as the database and file store backend.
IT IS NOT ADVISED TO USE THIS IN PRODUCTION!
"""


import sqlite3
from typing import Dict, List, Optional

from agent_protocol import Artifact, Step, Task, TaskDB
from agent_protocol.models import TaskInput


class DataNotFoundError(Exception):
    pass


class AgentDB(TaskDB):
    def __init__(self, database_name) -> None:
        super().__init__()
        self.conn = sqlite3.connect(database_name)
        cursor = self.conn.cursor()

        # Create tasks table
        cursor.execute(
            """
            CREATE TABLE IF NOT EXISTS tasks (
                task_id INTEGER PRIMARY KEY AUTOINCREMENT,
                input TEXT,
                additional_input TEXT
            )
            """
        )

        # Create steps table
        cursor.execute(
            """
            CREATE TABLE IF NOT EXISTS steps (
                step_id INTEGER PRIMARY KEY AUTOINCREMENT,
                task_id INTEGER,
                name TEXT,
                status TEXT,
                is_last INTEGER DEFAULT 0,
                additional_properties TEXT,
                FOREIGN KEY (task_id) REFERENCES tasks(task_id)
            )
            """
        )

        # Create artifacts table
        cursor.execute(
            """
            CREATE TABLE IF NOT EXISTS artifacts (
                artifact_id INTEGER PRIMARY KEY AUTOINCREMENT,
                task_id INTEGER,
                step_id INTEGER,
                file_name TEXT,
                relative_path TEXT,
                file_data BLOB,
                FOREIGN KEY (task_id) REFERENCES tasks(task_id)
            )
            """
        )

        # Commit the changes
        self.conn.commit()
        print("Databases Created")

    async def create_task(
        self,
        input: Optional[str],
        additional_input: Optional[TaskInput] = None,
        artifacts: List[Artifact] = None,
        steps: List[Step] = None,
    ) -> Task:
        """Create a task"""
        cursor = self.conn.cursor()
        cursor.execute(
            "INSERT INTO tasks (input, additional_input) VALUES (?, ?)",
            (input, additional_input.json() if additional_input else None),
        )
        task_id = cursor.lastrowid
        self.conn.commit()
        if task_id:
            return await self.get_task(task_id)
        else:
            raise DataNotFoundError("Task not found")

    async def create_step(
        self,
        task_id: str,
        name: Optional[str] = None,
        is_last: bool = False,
        additional_properties: Optional[Dict[str, str]] = None,
    ) -> Step:
        """Create a step for a given task"""
        cursor = self.conn.cursor()
        cursor.execute(
            "INSERT INTO steps (task_id, name, status, is_last, additional_properties) VALUES (?, ?, ?, ?, ?)",
            (task_id, name, "created", is_last, additional_properties),
        )
        step_id = cursor.lastrowid
        self.conn.commit()
        if step_id and task_id:
            return await self.get_step(task_id, step_id)
        else:
            raise DataNotFoundError("Step not found")

    async def create_artifact(
        self,
        task_id: str,
        file_name: str,
        relative_path: Optional[str] = None,
        step_id: Optional[str] = None,
        file_data: bytes | None = None,
    ) -> Artifact:
        """Create an artifact for a given task"""
        cursor = self.conn.cursor()
        cursor.execute(
            "INSERT INTO artifacts (task_id, step_id, file_name, relative_path, file_data) VALUES (?, ?, ?, ?, ?)",
            (task_id, step_id, file_name, relative_path, file_data),
        )
        artifact_id = cursor.lastrowid
        self.conn.commit()
        return await self.get_artifact(task_id, artifact_id)

    async def get_task(self, task_id: int) -> Task:
        """Get a task by its id"""
        cursor = self.conn.cursor()
        cursor.execute("SELECT * FROM tasks WHERE task_id=?", (task_id,))
        if task := cursor.fetchone():
            return Task(task_id=task[0], input=task[1], additional_input=task[2])
        else:
            raise DataNotFoundError("Task not found")

    async def get_step(self, task_id: int, step_id: int) -> Step:
        """Get a step by its id"""
        cursor = self.conn.cursor()
        cursor.execute(
            "SELECT * FROM steps WHERE task_id=? AND step_id=?", (task_id, step_id)
        )
        if step := cursor.fetchone():
            return Step(task_id=task_id, step_id=step_id, name=step[2], status=step[3])
        else:
            raise DataNotFoundError("Step not found")

    async def update_step(
        self,
        task_id: str,
        step_id: str,
        status: str,
        additional_properties: Optional[Dict[str, str]] = None,
    ) -> Step:
        """Update a step by its id"""
        cursor = self.conn.cursor()
        cursor.execute(
            "UPDATE steps SET status=?, additional_properties=? WHERE task_id=? AND step_id=?",
            (status, additional_properties, task_id, step_id),
        )
        self.conn.commit()
        return await self.get_step(task_id, step_id)

    async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact:
        """Get an artifact by its id"""
        cursor = self.conn.cursor()
        cursor.execute(
            "SELECT artifact_id, file_name, relative_path FROM artifacts WHERE task_id=? AND artifact_id=?",
            (task_id, artifact_id),
        )
        if artifact := cursor.fetchone():
            return Artifact(
                artifact_id=artifact[0],
                file_name=artifact[1],
                relative_path=artifact[2],
            )
        else:
            raise DataNotFoundError("Artifact not found")

    async def get_artifact_file(self, task_id: str, artifact_id: str) -> bytes:
        """Get an artifact file by its id"""
        cursor = self.conn.cursor()
        cursor.execute(
            "SELECT file_data, file_name FROM artifacts WHERE task_id=? AND artifact_id=?",
            (task_id, artifact_id),
        )
        if artifact := cursor.fetchone():
            return artifact[0]

    async def list_tasks(self) -> List[Task]:
        """List all tasks"""
        cursor = self.conn.cursor()
        cursor.execute("SELECT * FROM tasks")
        tasks = cursor.fetchall()
        return [
            Task(task_id=task[0], input=task[1], additional_input=task[2])
            for task in tasks
        ]

    async def list_steps(self, task_id: str) -> List[Step]:
        """List all steps for a given task"""
        cursor = self.conn.cursor()
        cursor.execute("SELECT * FROM steps WHERE task_id=?", (task_id,))
        steps = cursor.fetchall()
        return [
            Step(task_id=task_id, step_id=step[0], name=step[2], status=step[3])
            for step in steps
        ]
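For orientation, here is a minimal usage sketch of the AgentDB class added above. It is not part of the commit; the module path autogpt.db and the method signatures are taken from the diff, while the database file name and task text are illustrative.

import asyncio

from autogpt.db import AgentDB


async def main() -> None:
    # AgentDB creates the tasks/steps/artifacts tables when it is constructed
    db = AgentDB("example.sqlite3")
    task = await db.create_task("write a summary of the README")
    await db.create_step(task.task_id, "draft", is_last=True)
    # list_steps returns Step objects rebuilt from the steps table
    print(await db.list_steps(task.task_id))
    db.conn.close()


if __name__ == "__main__":
    asyncio.run(main())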
poetry.lock (generated, 1437 lines): file diff suppressed because it is too large.
@@ -8,16 +8,11 @@ readme = "README.md"

[tool.poetry.dependencies]
python = "^3.10"
langchain = "^0.0.215"
python-dotenv = "^0.21.0"
python-dotenv = "^1.0.0"
openai = "^0.27.8"
autopack-tools = "^0.2.0"
psutil = "^5.9.5"
agent-protocol = "^0.2.2"
helicone = "^1.0.6"

tenacity = "^8.2.2"
agent-protocol-client = "^0.2.2"


[tool.poetry.group.dev.dependencies]
@@ -26,8 +21,10 @@ black = "^23.3.0"
pre-commit = "^3.3.3"
mypy = "^1.4.1"
flake8 = "^6.0.0"
agbenchmark = "^0.0.2"
agbenchmark = "^0.0.7"
types-requests = "^2.31.0.2"
pytest = "^7.4.0"
pytest-asyncio = "^0.21.1"

[build-system]
requires = ["poetry-core"]
pytest.ini (new file, 2 lines)
@@ -0,0 +1,2 @@
[pytest]
python_paths = ./autogpt
tests/__init__.py (new file, 0 lines)

tests/test_db.py (new file, 190 lines)
@@ -0,0 +1,190 @@
import os
import sqlite3

import pytest

from autogpt.db import AgentDB, DataNotFoundError


def test_table_creation():
    db_name = "test_db.sqlite3"
    agent_db = AgentDB(db_name)

    conn = sqlite3.connect("test_db.sqlite3")
    cursor = conn.cursor()

    # Test for tasks table existence
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='tasks'")
    assert cursor.fetchone() is not None

    # Test for steps table existence
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='steps'")
    assert cursor.fetchone() is not None

    # Test for artifacts table existence
    cursor.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name='artifacts'"
    )
    assert cursor.fetchone() is not None

    agent_db.conn.close()
    os.remove(db_name)


@pytest.mark.asyncio
async def test_create_task():
    # Having issues with pytest fixture so added setup and teardown in each test as a rapid workaround
    # TODO: Fix this!
    db_name = "test_db.sqlite3"
    agent_db = AgentDB(db_name)

    task = await agent_db.create_task("task_input")
    assert task.input == "task_input"
    agent_db.conn.close()
    os.remove(db_name)


@pytest.mark.asyncio
async def test_create_and_get_task():
    db_name = "test_db.sqlite3"
    agent_db = AgentDB(db_name)
    await agent_db.create_task("task_input")
    task = await agent_db.get_task(1)
    assert task.input == "task_input"
    agent_db.conn.close()
    os.remove(db_name)


@pytest.mark.asyncio
async def test_get_task_not_found():
    db_name = "test_db.sqlite3"
    agent_db = AgentDB(db_name)
    with pytest.raises(DataNotFoundError):
        await agent_db.get_task(9999)
    agent_db.conn.close()
    os.remove(db_name)


@pytest.mark.asyncio
async def test_create_and_get_step():
    db_name = "test_db.sqlite3"
    agent_db = AgentDB(db_name)
    await agent_db.create_task("task_input")
    await agent_db.create_step(1, "step_name")
    step = await agent_db.get_step(1, 1)
    assert step.name == "step_name"
    agent_db.conn.close()
    os.remove(db_name)


@pytest.mark.asyncio
async def test_updating_step():
    db_name = "test_db.sqlite3"
    agent_db = AgentDB(db_name)
    await agent_db.create_task("task_input")
    await agent_db.create_step(1, "step_name")
    await agent_db.update_step(1, 1, "completed")

    step = await agent_db.get_step(1, 1)
    assert step.status.value == "completed"
    agent_db.conn.close()
    os.remove(db_name)


@pytest.mark.asyncio
async def test_get_step_not_found():
    db_name = "test_db.sqlite3"
    agent_db = AgentDB(db_name)
    with pytest.raises(DataNotFoundError):
        await agent_db.get_step(9999, 9999)
    agent_db.conn.close()
    os.remove(db_name)


@pytest.mark.asyncio
async def test_get_artifact():
    db_name = "test_db.sqlite3"
    db = AgentDB(db_name)

    # Given: A task and its corresponding artifact
    task = await db.create_task("test_input")
    step = await db.create_step(task.task_id, "step_name")
    artifact = await db.create_artifact(
        task.task_id, "sample_file.txt", "/path/to/sample_file.txt", step.step_id
    )

    # When: The artifact is fetched by its ID
    fetched_artifact = await db.get_artifact(task.task_id, artifact.artifact_id)

    # Then: The fetched artifact matches the original
    assert fetched_artifact.artifact_id == artifact.artifact_id
    assert fetched_artifact.file_name == "sample_file.txt"
    assert fetched_artifact.relative_path == "/path/to/sample_file.txt"
    db.conn.close()
    os.remove(db_name)


@pytest.mark.asyncio
async def test_get_artifact_file():
    db_name = "test_db.sqlite3"
    db = AgentDB(db_name)
    sample_data = b"sample data"
    # Given: A task and its corresponding artifact
    task = await db.create_task("test_input")
    step = await db.create_step(task.task_id, "step_name")
    artifact = await db.create_artifact(
        task.task_id,
        "sample_file.txt",
        "/path/to/sample_file.txt",
        step.step_id,
        sample_data,
    )

    # When: The artifact is fetched by its ID
    fetched_artifact = await db.get_artifact_file(task.task_id, artifact.artifact_id)

    # Then: The fetched artifact matches the original
    assert fetched_artifact == sample_data
    db.conn.close()
    os.remove(db_name)


@pytest.mark.asyncio
async def test_list_tasks():
    db_name = "test_db.sqlite3"
    db = AgentDB(db_name)

    # Given: Multiple tasks in the database
    task1 = await db.create_task("test_input_1")
    task2 = await db.create_task("test_input_2")

    # When: All tasks are fetched
    fetched_tasks = await db.list_tasks()

    # Then: The fetched tasks list includes the created tasks
    task_ids = [task.task_id for task in fetched_tasks]
    assert task1.task_id in task_ids
    assert task2.task_id in task_ids
    db.conn.close()
    os.remove(db_name)


@pytest.mark.asyncio
async def test_list_steps():
    db_name = "test_db.sqlite3"
    db = AgentDB(db_name)

    # Given: A task and multiple steps for that task
    task = await db.create_task("test_input")
    step1 = await db.create_step(task.task_id, "step_1")
    step2 = await db.create_step(task.task_id, "step_2")

    # When: All steps for the task are fetched
    fetched_steps = await db.list_steps(task.task_id)

    # Then: The fetched steps list includes the created steps
    step_ids = [step.step_id for step in fetched_steps]
    assert step1.step_id in step_ids
    assert step2.step_id in step_ids
    db.conn.close()
    os.remove(db_name)
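The TODO in test_create_task above notes that the per-test setup and teardown is a workaround for a fixture problem. A possible fixture-based alternative is sketched below; it is not part of the commit and assumes pytest's built-in tmp_path so each test gets its own database file instead of the shared test_db.sqlite3 name.

import pytest

from autogpt.db import AgentDB


@pytest.fixture
def agent_db(tmp_path):
    # Fresh SQLite file per test; the tmp_path directory is cleaned up by pytest
    db = AgentDB(str(tmp_path / "test_db.sqlite3"))
    yield db
    db.conn.close()


@pytest.mark.asyncio
async def test_create_task(agent_db):
    task = await agent_db.create_task("task_input")
    assert task.input == "task_input"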