Files
Auto-GPT/tests/test_db.py
2023-08-14 13:02:23 +02:00

191 lines
5.4 KiB
Python

import os
import sqlite3
import pytest
from autogpt.db import AgentDB, DataNotFoundError
def test_table_creation():
    """Check that constructing ``AgentDB`` creates the expected tables.

    Opens the database file through a separate ``sqlite3`` connection and
    queries ``sqlite_master`` for the ``tasks``, ``steps`` and ``artifacts``
    tables. Both connections are closed and the file is deleted even when an
    assertion fails.
    """
    db_name = "test_db.sqlite3"
    agent_db = AgentDB(db_name)
    # Use db_name rather than repeating the literal so the two connections
    # cannot drift apart.
    conn = sqlite3.connect(db_name)
    try:
        cursor = conn.cursor()
        for table in ("tasks", "steps", "artifacts"):
            cursor.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
                (table,),
            )
            assert cursor.fetchone() is not None, f"missing table: {table}"
    finally:
        # The original never closed this second connection. A leaked handle
        # keeps the file open (os.remove fails on Windows), and skipped
        # cleanup leaves a stale DB that every later test reuses.
        conn.close()
        agent_db.conn.close()
        os.remove(db_name)
@pytest.mark.asyncio
async def test_create_task():
    """Creating a task returns an object whose ``input`` echoes the given text."""
    # Having issues with pytest fixture so added setup and teardown in each test as a rapid workaround
    # TODO: Fix this!
    db_name = "test_db.sqlite3"
    agent_db = AgentDB(db_name)
    try:
        task = await agent_db.create_task("task_input")
        assert task.input == "task_input"
    finally:
        # Tear down even on failure: every test in this module shares the
        # same file path, so a leftover DB would leak into the next test.
        agent_db.conn.close()
        os.remove(db_name)
@pytest.mark.asyncio
async def test_create_and_get_task():
    """A task created in an empty database can be fetched back by id."""
    db_name = "test_db.sqlite3"
    agent_db = AgentDB(db_name)
    try:
        await agent_db.create_task("task_input")
        # Assumes the first task in a fresh database is assigned id 1.
        # This only holds if no stale file survived an earlier test.
        task = await agent_db.get_task(1)
        assert task.input == "task_input"
    finally:
        # Always release the connection and delete the shared test file.
        agent_db.conn.close()
        os.remove(db_name)
@pytest.mark.asyncio
async def test_get_task_not_found():
    """Fetching a task id that does not exist raises ``DataNotFoundError``."""
    db_name = "test_db.sqlite3"
    agent_db = AgentDB(db_name)
    try:
        with pytest.raises(DataNotFoundError):
            await agent_db.get_task(9999)
    finally:
        # If DataNotFoundError is not raised, pytest.raises fails the test;
        # cleanup must still run.
        agent_db.conn.close()
        os.remove(db_name)
@pytest.mark.asyncio
async def test_create_and_get_step():
    """A step created under a new task can be fetched back with its name."""
    db_name = "test_db.sqlite3"
    agent_db = AgentDB(db_name)
    try:
        await agent_db.create_task("task_input")
        await agent_db.create_step(1, "step_name")
        # Assumes fresh-DB ids: the task and its first step are both id 1.
        # TODO confirm the argument order of get_step is (task_id, step_id).
        step = await agent_db.get_step(1, 1)
        assert step.name == "step_name"
    finally:
        # Always release the connection and delete the shared test file.
        agent_db.conn.close()
        os.remove(db_name)
@pytest.mark.asyncio
async def test_updating_step():
    """Updating a step's status to "completed" is visible on the next read."""
    db_name = "test_db.sqlite3"
    agent_db = AgentDB(db_name)
    try:
        await agent_db.create_task("task_input")
        await agent_db.create_step(1, "step_name")
        # Assumes fresh-DB ids (task 1, step 1), same as the other step tests.
        await agent_db.update_step(1, 1, "completed")
        step = await agent_db.get_step(1, 1)
        # status is compared through .value, so it is presumably an
        # enum-like object.
        assert step.status.value == "completed"
    finally:
        # Always release the connection and delete the shared test file.
        agent_db.conn.close()
        os.remove(db_name)
@pytest.mark.asyncio
async def test_get_step_not_found():
    """Fetching a step for ids that do not exist raises ``DataNotFoundError``."""
    db_name = "test_db.sqlite3"
    agent_db = AgentDB(db_name)
    try:
        with pytest.raises(DataNotFoundError):
            await agent_db.get_step(9999, 9999)
    finally:
        # Cleanup must run even when the expected exception is not raised.
        agent_db.conn.close()
        os.remove(db_name)
@pytest.mark.asyncio
async def test_get_artifact():
    """An artifact attached to a step can be fetched back by task and artifact id."""
    db_name = "test_db.sqlite3"
    db = AgentDB(db_name)
    try:
        # Given: A task and its corresponding artifact
        task = await db.create_task("test_input")
        step = await db.create_step(task.task_id, "step_name")
        artifact = await db.create_artifact(
            task.task_id, "sample_file.txt", "/path/to/sample_file.txt", step.step_id
        )
        # When: The artifact is fetched by its ID
        fetched_artifact = await db.get_artifact(task.task_id, artifact.artifact_id)
        # Then: The fetched artifact matches the original
        assert fetched_artifact.artifact_id == artifact.artifact_id
        assert fetched_artifact.file_name == "sample_file.txt"
        assert fetched_artifact.relative_path == "/path/to/sample_file.txt"
    finally:
        # Always release the connection and delete the shared test file.
        db.conn.close()
        os.remove(db_name)
@pytest.mark.asyncio
async def test_get_artifact_file():
    """The bytes stored with an artifact are returned unchanged by ``get_artifact_file``."""
    db_name = "test_db.sqlite3"
    db = AgentDB(db_name)
    try:
        sample_data = b"sample data"
        # Given: A task and its corresponding artifact
        task = await db.create_task("test_input")
        step = await db.create_step(task.task_id, "step_name")
        artifact = await db.create_artifact(
            task.task_id,
            "sample_file.txt",
            "/path/to/sample_file.txt",
            step.step_id,
            sample_data,
        )
        # When: The artifact is fetched by its ID
        fetched_artifact = await db.get_artifact_file(task.task_id, artifact.artifact_id)
        # Then: The fetched artifact matches the original
        assert fetched_artifact == sample_data
    finally:
        # Always release the connection and delete the shared test file.
        db.conn.close()
        os.remove(db_name)
@pytest.mark.asyncio
async def test_list_tasks():
    """``list_tasks`` includes every task that was created."""
    db_name = "test_db.sqlite3"
    db = AgentDB(db_name)
    try:
        # Given: Multiple tasks in the database
        task1 = await db.create_task("test_input_1")
        task2 = await db.create_task("test_input_2")
        # When: All tasks are fetched
        fetched_tasks = await db.list_tasks()
        # Then: The fetched tasks list includes the created tasks
        task_ids = [task.task_id for task in fetched_tasks]
        assert task1.task_id in task_ids
        assert task2.task_id in task_ids
    finally:
        # Always release the connection and delete the shared test file.
        db.conn.close()
        os.remove(db_name)
@pytest.mark.asyncio
async def test_list_steps():
    """``list_steps`` for a task includes every step created under it."""
    db_name = "test_db.sqlite3"
    db = AgentDB(db_name)
    try:
        # Given: A task and multiple steps for that task
        task = await db.create_task("test_input")
        step1 = await db.create_step(task.task_id, "step_1")
        step2 = await db.create_step(task.task_id, "step_2")
        # When: All steps for the task are fetched
        fetched_steps = await db.list_steps(task.task_id)
        # Then: The fetched steps list includes the created steps
        step_ids = [step.step_id for step in fetched_steps]
        assert step1.step_id in step_ids
        assert step2.step_id in step_ids
    finally:
        # Always release the connection and delete the shared test file.
        db.conn.close()
        os.remove(db_name)