# Mirror of https://github.com/aljazceru/Auto-GPT.git
# Synced 2025-12-18 06:24:20 +01:00 (326 lines, 11 KiB, Python)
import os
|
|
import sqlite3
|
|
from datetime import datetime
|
|
|
|
import pytest
|
|
|
|
from forge.sdk.db import (
|
|
AgentDB,
|
|
ArtifactModel,
|
|
StepModel,
|
|
TaskModel,
|
|
convert_to_artifact,
|
|
convert_to_step,
|
|
convert_to_task,
|
|
)
|
|
from forge.sdk.errors import NotFoundError as DataNotFoundError
|
|
from forge.sdk.schema import *
|
|
|
|
|
|
def test_table_creation():
    """Check that constructing ``AgentDB`` creates the expected tables.

    After ``AgentDB`` is built against a SQLite URL, the database file is
    opened directly with ``sqlite3`` and ``sqlite_master`` is queried for the
    ``tasks``, ``steps`` and ``artifacts`` tables.

    This test is synchronous, so it carries no ``asyncio`` marker.
    """
    db_url = "sqlite:///test_db.sqlite3"
    db_path = db_url.split("///")[1]

    # Constructed only for its side effect; the instance itself is not used.
    AgentDB(db_url)

    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        for table in ("tasks", "steps", "artifacts"):
            cursor.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
                (table,),
            )
            assert cursor.fetchone() is not None, f"missing table: {table}"
    finally:
        # Close the handle before deleting the file, and always delete it so a
        # failed assertion does not leave a stale database for later tests.
        conn.close()
        os.remove(db_path)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_task_schema():
    """A ``Task`` built from keyword arguments exposes them as attributes."""
    task_id = "50da533e-3904-4401-8a07-c49adf88b5eb"
    artifact_id = "b225e278-8b4c-4f99-a696-8facf19f0e56"
    prompt = "Write the words you receive to the file 'output.txt'."
    timestamp = datetime.now()

    # Build the nested artifact first so the Task call stays flat.
    attached = Artifact(
        artifact_id=artifact_id,
        agent_created=True,
        file_name="main.py",
        relative_path="python/code/",
        created_at=timestamp,
        modified_at=timestamp,
    )
    task = Task(
        task_id=task_id,
        input=prompt,
        created_at=timestamp,
        modified_at=timestamp,
        artifacts=[attached],
    )

    assert task.task_id == task_id
    assert task.input == prompt
    assert len(task.artifacts) == 1
    assert task.artifacts[0].artifact_id == artifact_id
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_step_schema():
    """A ``Step`` built from keyword arguments exposes them as attributes."""
    task_id = "50da533e-3904-4401-8a07-c49adf88b5eb"
    step_id = "6bb1801a-fd80-45e8-899a-4dd723cc602e"
    artifact_id = "b225e278-8b4c-4f99-a696-8facf19f0e56"
    # Kept in one place so the constructor argument and the assertion below
    # cannot drift apart.
    expected_output = "I am going to use the write_to_file command and write Washington to a file called output.txt <write_to_file('output.txt', 'Washington')>"
    now = datetime.now()

    step = Step(
        task_id=task_id,
        step_id=step_id,
        created_at=now,
        modified_at=now,
        name="Write to file",
        input="Write the words you receive to the file 'output.txt'.",
        status=Status.created,
        output=expected_output,
        artifacts=[
            Artifact(
                artifact_id=artifact_id,
                file_name="main.py",
                relative_path="python/code/",
                created_at=now,
                modified_at=now,
                agent_created=True,
            )
        ],
        is_last=False,
    )

    assert step.task_id == task_id
    assert step.step_id == step_id
    assert step.name == "Write to file"
    assert step.status == Status.created
    assert step.output == expected_output
    assert len(step.artifacts) == 1
    assert step.artifacts[0].artifact_id == artifact_id
    # Identity check: the flag must be the boolean False, not merely falsy.
    assert step.is_last is False
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_convert_to_task():
    """``convert_to_task`` carries id, input and artifacts over from a ``TaskModel``."""
    task_id = "50da533e-3904-4401-8a07-c49adf88b5eb"
    artifact_id = "b225e278-8b4c-4f99-a696-8facf19f0e56"
    prompt = "Write the words you receive to the file 'output.txt'."
    stamp = datetime.now()

    artifact_row = ArtifactModel(
        artifact_id=artifact_id,
        created_at=stamp,
        modified_at=stamp,
        relative_path="file:///path/to/main.py",
        agent_created=True,
        file_name="main.py",
    )
    task_row = TaskModel(
        task_id=task_id,
        created_at=stamp,
        modified_at=stamp,
        input=prompt,
        artifacts=[artifact_row],
    )

    converted = convert_to_task(task_row)

    assert converted.task_id == task_id
    assert converted.input == prompt
    assert len(converted.artifacts) == 1
    assert converted.artifacts[0].artifact_id == artifact_id
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_convert_to_step():
    """``convert_to_step`` carries the fields of a ``StepModel`` over to the result."""
    task_id = "50da533e-3904-4401-8a07-c49adf88b5eb"
    step_id = "6bb1801a-fd80-45e8-899a-4dd723cc602e"
    artifact_id = "b225e278-8b4c-4f99-a696-8facf19f0e56"
    now = datetime.now()

    step_model = StepModel(
        task_id=task_id,
        step_id=step_id,
        created_at=now,
        modified_at=now,
        name="Write to file",
        # The model takes the status as a plain string; the converted step is
        # compared against the Status enum below.
        status="created",
        input="Write the words you receive to the file 'output.txt'.",
        artifacts=[
            ArtifactModel(
                artifact_id=artifact_id,
                created_at=now,
                modified_at=now,
                relative_path="file:///path/to/main.py",
                agent_created=True,
                file_name="main.py",
            )
        ],
        is_last=False,
    )

    step = convert_to_step(step_model)

    assert step.task_id == task_id
    assert step.step_id == step_id
    assert step.name == "Write to file"
    assert step.status == Status.created
    assert len(step.artifacts) == 1
    assert step.artifacts[0].artifact_id == artifact_id
    # Identity check: the flag must be the boolean False, not merely falsy.
    assert step.is_last is False
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_convert_to_artifact():
    """``convert_to_artifact`` carries id, path and flag over from an ``ArtifactModel``."""
    artifact_id = "b225e278-8b4c-4f99-a696-8facf19f0e56"
    relative_path = "file:///path/to/main.py"
    now = datetime.now()

    artifact_model = ArtifactModel(
        artifact_id=artifact_id,
        created_at=now,
        modified_at=now,
        relative_path=relative_path,
        agent_created=True,
        file_name="main.py",
    )

    artifact = convert_to_artifact(artifact_model)

    assert artifact.artifact_id == artifact_id
    assert artifact.relative_path == relative_path
    # Identity check: the flag must be the boolean True, not merely truthy.
    assert artifact.agent_created is True
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_task():
    """``AgentDB.create_task`` returns a task carrying the given input."""
    # Setup and teardown are inlined in each test as a workaround for pytest
    # fixture issues.
    # TODO: move the database setup/teardown into a fixture.
    db_url = "sqlite:///test_db.sqlite3"
    db_path = db_url.split("///")[1]
    agent_db = AgentDB(db_url)
    try:
        task = await agent_db.create_task("task_input")
        assert task.input == "task_input"
    finally:
        # Always delete the database file so a failure here does not leave
        # stale data behind for later tests that reuse the same path.
        os.remove(db_path)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_and_get_task():
    """A task created via ``create_task`` can be read back with ``get_task``."""
    db_url = "sqlite:///test_db.sqlite3"
    db_path = db_url.split("///")[1]
    agent_db = AgentDB(db_url)
    try:
        task = await agent_db.create_task("test_input")
        fetched_task = await agent_db.get_task(task.task_id)
        assert fetched_task.input == "test_input"
    finally:
        # Always delete the database file so a failure here does not leave
        # stale data behind for later tests that reuse the same path.
        os.remove(db_path)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_task_not_found():
    """``get_task`` raises ``DataNotFoundError`` for an id that was never created."""
    db_url = "sqlite:///test_db.sqlite3"
    db_path = db_url.split("///")[1]
    agent_db = AgentDB(db_url)
    try:
        with pytest.raises(DataNotFoundError):
            await agent_db.get_task(9999)
    finally:
        # Always delete the database file, including when the expected
        # exception is not raised and pytest.raises fails the test.
        os.remove(db_path)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_and_get_step():
    """A step created via ``create_step`` can be read back with ``get_step``."""
    db_url = "sqlite:///test_db.sqlite3"
    db_path = db_url.split("///")[1]
    agent_db = AgentDB(db_url)
    try:
        task = await agent_db.create_task("task_input")
        step_input = StepInput(type="python/code")
        request = StepRequestBody(
            input="test_input debug", additional_input=step_input
        )

        created_step = await agent_db.create_step(task.task_id, request)
        fetched_step = await agent_db.get_step(task.task_id, created_step.step_id)

        assert fetched_step.input == "test_input debug"
    finally:
        # Always delete the database file so a failure here does not leave
        # stale data behind for later tests that reuse the same path.
        os.remove(db_path)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_updating_step():
    """``update_step`` with ``"completed"`` is reflected by a later ``get_step``."""
    db_url = "sqlite:///test_db.sqlite3"
    db_path = db_url.split("///")[1]
    agent_db = AgentDB(db_url)
    try:
        created_task = await agent_db.create_task("task_input")
        step_input = StepInput(type="python/code")
        request = StepRequestBody(
            input="test_input debug", additional_input=step_input
        )
        created_step = await agent_db.create_step(created_task.task_id, request)

        await agent_db.update_step(
            created_task.task_id, created_step.step_id, "completed"
        )

        step = await agent_db.get_step(created_task.task_id, created_step.step_id)
        assert step.status.value == "completed"
    finally:
        # Always delete the database file so a failure here does not leave
        # stale data behind for later tests that reuse the same path.
        os.remove(db_path)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_step_not_found():
    """``get_step`` raises ``DataNotFoundError`` for ids that were never created."""
    db_url = "sqlite:///test_db.sqlite3"
    db_path = db_url.split("///")[1]
    agent_db = AgentDB(db_url)
    try:
        with pytest.raises(DataNotFoundError):
            await agent_db.get_step(9999, 9999)
    finally:
        # Always delete the database file, including when the expected
        # exception is not raised and pytest.raises fails the test.
        os.remove(db_path)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_artifact():
    """An artifact created via ``create_artifact`` can be read back by its id."""
    db_url = "sqlite:///test_db.sqlite3"
    db_path = db_url.split("///")[1]
    artifact_path = "file:///path/to/test_get_artifact_sample_file.txt"
    db = AgentDB(db_url)
    try:
        # Given: a task, a step on that task, and an artifact tied to both
        task = await db.create_task("test_input debug")
        step_input = StepInput(type="python/code")
        request = StepRequestBody(
            input="test_input debug", additional_input=step_input
        )
        step = await db.create_step(task.task_id, request)

        artifact = await db.create_artifact(
            task_id=task.task_id,
            file_name="test_get_artifact_sample_file.txt",
            relative_path=artifact_path,
            agent_created=True,
            step_id=step.step_id,
        )

        # When: the artifact is fetched by its ID
        fetched_artifact = await db.get_artifact(artifact.artifact_id)

        # Then: the fetched artifact matches the original
        assert fetched_artifact.artifact_id == artifact.artifact_id
        assert fetched_artifact.relative_path == artifact_path
    finally:
        # Always delete the database file so a failure here does not leave
        # stale data behind for later tests that reuse the same path.
        os.remove(db_path)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_list_tasks():
    """``list_tasks`` returns every task that was created."""
    db_url = "sqlite:///test_db.sqlite3"
    db_path = db_url.split("///")[1]
    db = AgentDB(db_url)
    try:
        # Given: multiple tasks in the database
        task1 = await db.create_task("test_input_1")
        task2 = await db.create_task("test_input_2")

        # When: all tasks are fetched (the pagination half of the result is
        # not inspected by this test)
        fetched_tasks, _pagination = await db.list_tasks()

        # Then: the fetched tasks list includes the created tasks
        task_ids = [task.task_id for task in fetched_tasks]
        assert task1.task_id in task_ids
        assert task2.task_id in task_ids
    finally:
        # Always delete the database file so a failure here does not leave
        # stale data behind for later tests that reuse the same path.
        os.remove(db_path)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_list_steps():
    """``list_steps`` returns every step that was created for a task."""
    db_url = "sqlite:///test_db.sqlite3"
    db_path = db_url.split("///")[1]
    db = AgentDB(db_url)
    try:
        step_input = StepInput(type="python/code")
        first_request = StepRequestBody(
            input="test_input debug", additional_input=step_input
        )
        second_request = StepRequestBody(
            input="step two", additional_input=step_input
        )

        # Given: a task and multiple steps for that task
        task = await db.create_task("test_input")
        step1 = await db.create_step(task.task_id, first_request)
        step2 = await db.create_step(task.task_id, second_request)

        # When: all steps for the task are fetched (the pagination half of the
        # result is not inspected by this test)
        fetched_steps, _pagination = await db.list_steps(task.task_id)

        # Then: the fetched steps list includes the created steps
        step_ids = [step.step_id for step in fetched_steps]
        assert step1.step_id in step_ids
        assert step2.step_id in step_ids
    finally:
        # Always delete the database file so a failure here does not leave
        # stale data behind for later tests that reuse the same path.
        os.remove(db_path)
|