mirror of
https://github.com/aljazceru/Auto-GPT.git
synced 2025-12-28 03:14:32 +01:00
191 lines
5.4 KiB
Python
191 lines
5.4 KiB
Python
import os
|
|
import sqlite3
|
|
|
|
import pytest
|
|
|
|
from autogpt.db import AgentDB, DataNotFoundError
|
|
|
|
|
|
def test_table_creation():
|
|
db_name = "test_db.sqlite3"
|
|
agent_db = AgentDB(db_name)
|
|
|
|
conn = sqlite3.connect("test_db.sqlite3")
|
|
cursor = conn.cursor()
|
|
|
|
# Test for tasks table existence
|
|
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='tasks'")
|
|
assert cursor.fetchone() is not None
|
|
|
|
# Test for steps table existence
|
|
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='steps'")
|
|
assert cursor.fetchone() is not None
|
|
|
|
# Test for artifacts table existence
|
|
cursor.execute(
|
|
"SELECT name FROM sqlite_master WHERE type='table' AND name='artifacts'"
|
|
)
|
|
assert cursor.fetchone() is not None
|
|
|
|
agent_db.conn.close()
|
|
os.remove(db_name)
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_create_task():
|
|
# Having issues with pytest fixture so added setup and teardown in each test as a rapid workaround
|
|
# TODO: Fix this!
|
|
db_name = "test_db.sqlite3"
|
|
agent_db = AgentDB(db_name)
|
|
|
|
task = await agent_db.create_task("task_input")
|
|
assert task.input == "task_input"
|
|
agent_db.conn.close()
|
|
os.remove(db_name)
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_create_and_get_task():
|
|
db_name = "test_db.sqlite3"
|
|
agent_db = AgentDB(db_name)
|
|
await agent_db.create_task("task_input")
|
|
task = await agent_db.get_task(1)
|
|
assert task.input == "task_input"
|
|
agent_db.conn.close()
|
|
os.remove(db_name)
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_get_task_not_found():
|
|
db_name = "test_db.sqlite3"
|
|
agent_db = AgentDB(db_name)
|
|
with pytest.raises(DataNotFoundError):
|
|
await agent_db.get_task(9999)
|
|
agent_db.conn.close()
|
|
os.remove(db_name)
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_create_and_get_step():
|
|
db_name = "test_db.sqlite3"
|
|
agent_db = AgentDB(db_name)
|
|
await agent_db.create_task("task_input")
|
|
await agent_db.create_step(1, "step_name")
|
|
step = await agent_db.get_step(1, 1)
|
|
assert step.name == "step_name"
|
|
agent_db.conn.close()
|
|
os.remove(db_name)
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_updating_step():
|
|
db_name = "test_db.sqlite3"
|
|
agent_db = AgentDB(db_name)
|
|
await agent_db.create_task("task_input")
|
|
await agent_db.create_step(1, "step_name")
|
|
await agent_db.update_step(1, 1, "completed")
|
|
|
|
step = await agent_db.get_step(1, 1)
|
|
assert step.status.value == "completed"
|
|
agent_db.conn.close()
|
|
os.remove(db_name)
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_get_step_not_found():
|
|
db_name = "test_db.sqlite3"
|
|
agent_db = AgentDB(db_name)
|
|
with pytest.raises(DataNotFoundError):
|
|
await agent_db.get_step(9999, 9999)
|
|
agent_db.conn.close()
|
|
os.remove(db_name)
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_get_artifact():
|
|
db_name = "test_db.sqlite3"
|
|
db = AgentDB(db_name)
|
|
|
|
# Given: A task and its corresponding artifact
|
|
task = await db.create_task("test_input")
|
|
step = await db.create_step(task.task_id, "step_name")
|
|
artifact = await db.create_artifact(
|
|
task.task_id, "sample_file.txt", "/path/to/sample_file.txt", step.step_id
|
|
)
|
|
|
|
# When: The artifact is fetched by its ID
|
|
fetched_artifact = await db.get_artifact(task.task_id, artifact.artifact_id)
|
|
|
|
# Then: The fetched artifact matches the original
|
|
assert fetched_artifact.artifact_id == artifact.artifact_id
|
|
assert fetched_artifact.file_name == "sample_file.txt"
|
|
assert fetched_artifact.relative_path == "/path/to/sample_file.txt"
|
|
db.conn.close()
|
|
os.remove(db_name)
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_get_artifact_file():
|
|
db_name = "test_db.sqlite3"
|
|
db = AgentDB(db_name)
|
|
sample_data = b"sample data"
|
|
# Given: A task and its corresponding artifact
|
|
task = await db.create_task("test_input")
|
|
step = await db.create_step(task.task_id, "step_name")
|
|
artifact = await db.create_artifact(
|
|
task.task_id,
|
|
"sample_file.txt",
|
|
"/path/to/sample_file.txt",
|
|
step.step_id,
|
|
sample_data,
|
|
)
|
|
|
|
# When: The artifact is fetched by its ID
|
|
fetched_artifact = await db.get_artifact_file(task.task_id, artifact.artifact_id)
|
|
|
|
# Then: The fetched artifact matches the original
|
|
assert fetched_artifact == sample_data
|
|
db.conn.close()
|
|
os.remove(db_name)
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_list_tasks():
|
|
db_name = "test_db.sqlite3"
|
|
db = AgentDB(db_name)
|
|
|
|
# Given: Multiple tasks in the database
|
|
task1 = await db.create_task("test_input_1")
|
|
task2 = await db.create_task("test_input_2")
|
|
|
|
# When: All tasks are fetched
|
|
fetched_tasks = await db.list_tasks()
|
|
|
|
# Then: The fetched tasks list includes the created tasks
|
|
task_ids = [task.task_id for task in fetched_tasks]
|
|
assert task1.task_id in task_ids
|
|
assert task2.task_id in task_ids
|
|
db.conn.close()
|
|
os.remove(db_name)
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_list_steps():
|
|
db_name = "test_db.sqlite3"
|
|
db = AgentDB(db_name)
|
|
|
|
# Given: A task and multiple steps for that task
|
|
task = await db.create_task("test_input")
|
|
step1 = await db.create_step(task.task_id, "step_1")
|
|
step2 = await db.create_step(task.task_id, "step_2")
|
|
|
|
# When: All steps for the task are fetched
|
|
fetched_steps = await db.list_steps(task.task_id)
|
|
|
|
# Then: The fetched steps list includes the created steps
|
|
step_ids = [step.step_id for step in fetched_steps]
|
|
assert step1.step_id in step_ids
|
|
assert step2.step_id in step_ids
|
|
db.conn.close()
|
|
os.remove(db_name)
|