mirror of
https://github.com/aljazceru/Auto-GPT.git
synced 2025-12-27 19:04:25 +01:00
460 lines
16 KiB
Python
460 lines
16 KiB
Python
"""
|
|
This is an example implementation of the Agent Protocol DB for development purposes
|
|
It uses SQLite as the database and file store backend.
|
|
IT IS NOT ADVISED TO USE THIS IN PRODUCTION!
|
|
"""
|
|
|
|
import datetime
|
|
import math
|
|
import uuid
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
from sqlalchemy import (
|
|
JSON,
|
|
Boolean,
|
|
Column,
|
|
DateTime,
|
|
ForeignKey,
|
|
String,
|
|
create_engine,
|
|
)
|
|
from sqlalchemy.exc import SQLAlchemyError
|
|
from sqlalchemy.orm import DeclarativeBase, joinedload, relationship, sessionmaker
|
|
|
|
from .errors import NotFoundError
|
|
from .forge_log import CustomLogger
|
|
from .schema import Artifact, Pagination, Status, Step, Task, TaskInput
|
|
|
|
LOG = CustomLogger(__name__)
|
|
|
|
|
|
class Base(DeclarativeBase):
    """Declarative base shared by the ORM models defined in this module."""
|
|
|
|
|
|
class TaskModel(Base):
    """ORM mapping for one row of the ``tasks`` table."""

    __tablename__ = "tasks"

    # String primary key identifying the task.
    task_id = Column(String, primary_key=True, index=True)
    input = Column(String)
    additional_input = Column(JSON)

    # The callables are handed to SQLAlchemy uncalled, so the timestamp is
    # taken when the row is written rather than when this module is imported.
    created_at = Column(DateTime, default=datetime.datetime.utcnow)
    modified_at = Column(
        DateTime,
        default=datetime.datetime.utcnow,
        onupdate=datetime.datetime.utcnow,
    )

    # Paired with ArtifactModel.task.
    artifacts = relationship("ArtifactModel", back_populates="task")
|
|
|
|
|
|
class StepModel(Base):
    """ORM mapping for one row of the ``steps`` table."""

    __tablename__ = "steps"

    step_id = Column(String, primary_key=True, index=True)
    # Owning task, referenced through tasks.task_id.
    task_id = Column(String, ForeignKey("tasks.task_id"))
    name = Column(String)
    input = Column(String)
    # Stored as a plain string; convert_to_step only special-cases "completed".
    status = Column(String)
    is_last = Column(Boolean, default=False)

    # The callables are handed to SQLAlchemy uncalled, so the timestamp is
    # taken when the row is written rather than when this module is imported.
    created_at = Column(DateTime, default=datetime.datetime.utcnow)
    modified_at = Column(
        DateTime,
        default=datetime.datetime.utcnow,
        onupdate=datetime.datetime.utcnow,
    )

    additional_input = Column(JSON)
    # Paired with ArtifactModel.step.
    artifacts = relationship("ArtifactModel", back_populates="step")
|
|
|
|
|
|
class ArtifactModel(Base):
    """ORM mapping for one row of the ``artifacts`` table."""

    __tablename__ = "artifacts"

    artifact_id = Column(String, primary_key=True, index=True)
    # An artifact references a task and a step; AgentDB.create_artifact
    # leaves step_id as None unless the caller supplies one.
    task_id = Column(String, ForeignKey("tasks.task_id"))
    step_id = Column(String, ForeignKey("steps.step_id"))
    agent_created = Column(Boolean, default=False)
    file_name = Column(String)
    uri = Column(String)

    # The callables are handed to SQLAlchemy uncalled, so the timestamp is
    # taken when the row is written rather than when this module is imported.
    created_at = Column(DateTime, default=datetime.datetime.utcnow)
    modified_at = Column(
        DateTime,
        default=datetime.datetime.utcnow,
        onupdate=datetime.datetime.utcnow,
    )

    # Reverse sides of StepModel.artifacts and TaskModel.artifacts.
    step = relationship("StepModel", back_populates="artifacts")
    task = relationship("TaskModel", back_populates="artifacts")
|
|
|
|
|
|
def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False) -> Task:
    """Build a ``Task`` schema object from a ``TaskModel`` row.

    Args:
        task_obj: The ORM row to convert; its ``artifacts`` relationship is
            read and every entry is converted with ``convert_to_artifact``.
        debug_enabled: When true, a debug line naming the task id is logged.

    Returns:
        A ``Task`` populated from the row's columns and converted artifacts.
    """
    if debug_enabled:
        LOG.debug(f"Converting TaskModel to Task for task_id: {task_obj.task_id}")

    # Convert the related rows first, exactly once, before assembling the Task.
    converted_artifacts = list(map(convert_to_artifact, task_obj.artifacts))

    return Task(
        task_id=task_obj.task_id,
        created_at=task_obj.created_at,
        modified_at=task_obj.modified_at,
        input=task_obj.input,
        additional_input=task_obj.additional_input,
        artifacts=converted_artifacts,
    )
|
|
|
|
|
|
def convert_to_step(step_model: StepModel, debug_enabled: bool = False) -> Step:
    """Build a ``Step`` schema object from a ``StepModel`` row.

    Args:
        step_model: The ORM row to convert; its ``artifacts`` relationship is
            read and every entry is converted with ``convert_to_artifact``.
        debug_enabled: When true, a debug line naming the step id is logged.

    Returns:
        A ``Step`` populated from the row. The stored status string maps to
        ``Status.completed`` only when it equals ``"completed"``; every other
        value maps to ``Status.created``.
    """
    if debug_enabled:
        LOG.debug(f"Converting StepModel to Step for step_id: {step_model.step_id}")

    converted_artifacts = list(map(convert_to_artifact, step_model.artifacts))

    if step_model.status == "completed":
        step_status = Status.completed
    else:
        step_status = Status.created

    return Step(
        task_id=step_model.task_id,
        step_id=step_model.step_id,
        created_at=step_model.created_at,
        modified_at=step_model.modified_at,
        name=step_model.name,
        input=step_model.input,
        status=step_status,
        artifacts=converted_artifacts,
        # Comparison against 1 yields a plain bool for the schema object.
        is_last=step_model.is_last == 1,
        additional_input=step_model.additional_input,
    )
|
|
|
|
|
|
def convert_to_artifact(artifact_model: ArtifactModel) -> Artifact:
    """Build an ``Artifact`` schema object from an ``ArtifactModel`` row.

    Only the id, both timestamps, ``agent_created`` and ``uri`` are copied.

    Args:
        artifact_model: The ORM row to convert.

    Returns:
        An ``Artifact`` carrying the copied fields.
    """
    fields = {
        "artifact_id": artifact_model.artifact_id,
        "created_at": artifact_model.created_at,
        "modified_at": artifact_model.modified_at,
        "agent_created": artifact_model.agent_created,
        "uri": artifact_model.uri,
    }
    return Artifact(**fields)
|
|
|
|
|
|
# Typical database_string format (SQLite): sqlite:///{database_name}
|
|
class AgentDB:
    """SQLAlchemy-backed store for tasks, steps and artifacts.

    Each public coroutine opens its own session from ``self.Session``, does
    its work and returns schema objects (``Task``, ``Step``, ``Artifact``)
    instead of ORM rows. ``SQLAlchemyError`` and any other unexpected
    exception are logged and re-raised; ``NotFoundError`` is re-raised
    without the "unexpected error" log line.
    """

    # Marks "argument omitted". ``None`` cannot serve as the default for
    # ``additional_input`` because callers may pass ``None`` explicitly and
    # expect it to reach the column unchanged.
    _UNSET: Any = object()

    def __init__(self, database_string, debug_enabled: bool = False) -> None:
        """Create the engine, make sure all tables exist and set up sessions.

        Args:
            database_string: URL handed to ``create_engine``.
            debug_enabled: When true, methods emit extra debug log lines.
        """
        super().__init__()
        self.debug_enabled = debug_enabled
        if self.debug_enabled:
            LOG.debug(f"Initializing AgentDB with database_string: {database_string}")
        self.engine = create_engine(database_string)
        Base.metadata.create_all(self.engine)
        self.Session = sessionmaker(bind=self.engine)

    @staticmethod
    def _build_pagination(total: int, page: int, per_page: int) -> Pagination:
        """Describe page ``page`` of a listing that has ``total`` rows overall."""
        return Pagination(
            total_items=total,
            total_pages=math.ceil(total / per_page),
            current_page=page,
            page_size=per_page,
        )

    async def create_task(
        self, input: Optional[str], additional_input: Optional[TaskInput] = _UNSET
    ) -> Task:
        """Insert a new task row with a freshly generated UUID4 id.

        Args:
            input: Task input text.
            additional_input: Extra data for the JSON column. When omitted, a
                new empty dict is stored; an explicit ``None`` is passed
                through unchanged.

        Returns:
            The stored task as a ``Task``.

        Raises:
            SQLAlchemyError: Database failure (logged, then re-raised).
        """
        if additional_input is AgentDB._UNSET:
            # A fresh dict per call, so no object is shared between calls.
            additional_input = {}
        if self.debug_enabled:
            LOG.debug("Creating new task")

        try:
            with self.Session() as session:
                new_task = TaskModel(
                    task_id=str(uuid.uuid4()),
                    input=input,
                    additional_input=additional_input,
                )
                session.add(new_task)
                session.commit()
                # Reload so column defaults (timestamps) are populated.
                session.refresh(new_task)
                if self.debug_enabled:
                    LOG.debug(f"Created new task with task_id: {new_task.task_id}")
                return convert_to_task(new_task, self.debug_enabled)
        except SQLAlchemyError as e:
            LOG.error(f"SQLAlchemy error while creating task: {e}")
            raise
        except NotFoundError:
            raise
        except Exception as e:
            LOG.error(f"Unexpected error while creating task: {e}")
            raise

    async def create_step(
        self,
        task_id: str,
        input: Any,
        is_last: bool = False,
        additional_input: Optional[Dict[str, Any]] = _UNSET,
    ) -> Step:
        """Insert a new step row in status ``"created"`` for ``task_id``.

        Args:
            task_id: Id of the task the step belongs to.
            input: Object exposing ``name`` and ``input`` attributes; both are
                copied onto the new row. (It is not a plain string, despite
                the parameter name.)
            is_last: Stored in the row's ``is_last`` column.
            additional_input: Extra data for the JSON column. When omitted, a
                new empty dict is stored; an explicit ``None`` is passed
                through unchanged.

        Returns:
            The stored step as a ``Step``.

        Raises:
            SQLAlchemyError: Database failure (logged, then re-raised).
        """
        if additional_input is AgentDB._UNSET:
            additional_input = {}
        if self.debug_enabled:
            LOG.debug(f"Creating new step for task_id: {task_id}")
        try:
            with self.Session() as session:
                new_step = StepModel(
                    task_id=task_id,
                    step_id=str(uuid.uuid4()),
                    name=input.name,
                    input=input.input,
                    status="created",
                    is_last=is_last,
                    additional_input=additional_input,
                )
                session.add(new_step)
                session.commit()
                session.refresh(new_step)
                if self.debug_enabled:
                    LOG.debug(f"Created new step with step_id: {new_step.step_id}")
                return convert_to_step(new_step, self.debug_enabled)
        except SQLAlchemyError as e:
            LOG.error(f"SQLAlchemy error while creating step: {e}")
            raise
        except NotFoundError:
            raise
        except Exception as e:
            LOG.error(f"Unexpected error while creating step: {e}")
            raise

    async def create_artifact(
        self,
        task_id: str,
        file_name: str,
        uri: str,
        agent_created: bool = False,
        step_id: str | None = None,
    ) -> Artifact:
        """Return the artifact stored under ``uri``, creating it if needed.

        If a row with the same ``uri`` already exists it is returned as-is and
        the other arguments are ignored.

        Args:
            task_id: Id of the owning task.
            file_name: Stored in the row's ``file_name`` column.
            uri: Location of the artifact; also the de-duplication key.
            agent_created: Stored in the row's ``agent_created`` column.
            step_id: Optional id of the step that produced the artifact.

        Returns:
            The existing or newly stored artifact as an ``Artifact``.

        Raises:
            SQLAlchemyError: Database failure (logged, then re-raised).
        """
        if self.debug_enabled:
            LOG.debug(f"Creating new artifact for task_id: {task_id}")
        try:
            with self.Session() as session:
                existing_artifact = (
                    session.query(ArtifactModel).filter_by(uri=uri).first()
                )
                if existing_artifact:
                    # No explicit session.close(): the ``with`` block closes
                    # the session on exit.
                    if self.debug_enabled:
                        LOG.debug(f"Artifact already exists with uri: {uri}")
                    return convert_to_artifact(existing_artifact)

                new_artifact = ArtifactModel(
                    artifact_id=str(uuid.uuid4()),
                    task_id=task_id,
                    step_id=step_id,
                    agent_created=agent_created,
                    file_name=file_name,
                    uri=uri,
                )
                session.add(new_artifact)
                session.commit()
                session.refresh(new_artifact)
                if self.debug_enabled:
                    LOG.debug(
                        f"Created new artifact with artifact_id: {new_artifact.artifact_id}"
                    )
                return convert_to_artifact(new_artifact)
        except SQLAlchemyError as e:
            LOG.error(f"SQLAlchemy error while creating artifact: {e}")
            raise
        except NotFoundError:
            raise
        except Exception as e:
            LOG.error(f"Unexpected error while creating artifact: {e}")
            raise

    async def get_task(self, task_id: str) -> Task:
        """Fetch one task, with its artifacts eagerly loaded.

        Args:
            task_id: Id of the task to fetch.

        Returns:
            The matching task as a ``Task``.

        Raises:
            NotFoundError: No task has this id.
            SQLAlchemyError: Database failure (logged, then re-raised).
        """
        if self.debug_enabled:
            LOG.debug(f"Getting task with task_id: {task_id}")
        try:
            with self.Session() as session:
                task_obj = (
                    session.query(TaskModel)
                    .options(joinedload(TaskModel.artifacts))
                    .filter_by(task_id=task_id)
                    .first()
                )
                if task_obj:
                    return convert_to_task(task_obj, self.debug_enabled)

                LOG.error(f"Task not found with task_id: {task_id}")
                raise NotFoundError("Task not found")
        except SQLAlchemyError as e:
            LOG.error(f"SQLAlchemy error while getting task: {e}")
            raise
        except NotFoundError:
            raise
        except Exception as e:
            LOG.error(f"Unexpected error while getting task: {e}")
            raise

    async def get_step(self, task_id: str, step_id: str) -> Step:
        """Fetch one step, with its artifacts eagerly loaded.

        NOTE(review): the lookup filters on ``step_id`` only; ``task_id`` is
        used solely in log messages. Confirm whether a task mismatch should
        be treated as "not found".

        Args:
            task_id: Id of the task the step is expected to belong to.
            step_id: Id of the step to fetch.

        Returns:
            The matching step as a ``Step``.

        Raises:
            NotFoundError: No step has this ``step_id``.
            SQLAlchemyError: Database failure (logged, then re-raised).
        """
        if self.debug_enabled:
            LOG.debug(f"Getting step with task_id: {task_id} and step_id: {step_id}")
        try:
            with self.Session() as session:
                step = (
                    session.query(StepModel)
                    .options(joinedload(StepModel.artifacts))
                    .filter(StepModel.step_id == step_id)
                    .first()
                )
                if step:
                    return convert_to_step(step, self.debug_enabled)

                LOG.error(
                    f"Step not found with task_id: {task_id} and step_id: {step_id}"
                )
                raise NotFoundError("Step not found")
        except SQLAlchemyError as e:
            LOG.error(f"SQLAlchemy error while getting step: {e}")
            raise
        except NotFoundError:
            raise
        except Exception as e:
            LOG.error(f"Unexpected error while getting step: {e}")
            raise

    async def update_step(
        self,
        task_id: str,
        step_id: str,
        status: str,
        additional_input: Optional[Dict[str, Any]] = _UNSET,
    ) -> Step:
        """Overwrite a step's status and additional input, then re-read it.

        Args:
            task_id: Id of the owning task; must match the stored row.
            step_id: Id of the step to update.
            status: New status string.
            additional_input: Replacement for the JSON column. When omitted,
                the stored value is overwritten with a new empty dict; an
                explicit ``None`` is passed through unchanged.

        Returns:
            The updated step as returned by ``get_step``.

        Raises:
            NotFoundError: No step matches both ``task_id`` and ``step_id``.
            SQLAlchemyError: Database failure (logged, then re-raised).
        """
        if additional_input is AgentDB._UNSET:
            additional_input = {}
        if self.debug_enabled:
            LOG.debug(f"Updating step with task_id: {task_id} and step_id: {step_id}")
        try:
            with self.Session() as session:
                step = (
                    session.query(StepModel)
                    .filter_by(task_id=task_id, step_id=step_id)
                    .first()
                )
                if step:
                    step.status = status
                    step.additional_input = additional_input
                    session.commit()
                    # Re-read through get_step so artifacts are loaded too.
                    return await self.get_step(task_id, step_id)

                LOG.error(
                    f"Step not found for update with task_id: {task_id} and step_id: {step_id}"
                )
                raise NotFoundError("Step not found")
        except SQLAlchemyError as e:
            LOG.error(f"SQLAlchemy error while updating step: {e}")
            raise
        except NotFoundError:
            raise
        except Exception as e:
            LOG.error(f"Unexpected error while updating step: {e}")
            raise

    async def get_artifact(self, artifact_id: str) -> Artifact:
        """Fetch one artifact by id.

        Args:
            artifact_id: Id of the artifact to fetch.

        Returns:
            The matching artifact as an ``Artifact``.

        Raises:
            NotFoundError: No artifact has this id.
            SQLAlchemyError: Database failure (logged, then re-raised).
        """
        if self.debug_enabled:
            LOG.debug(f"Getting artifact with and artifact_id: {artifact_id}")
        try:
            with self.Session() as session:
                artifact_model = (
                    session.query(ArtifactModel)
                    .filter_by(artifact_id=artifact_id)
                    .first()
                )
                if artifact_model:
                    return convert_to_artifact(artifact_model)

                LOG.error(f"Artifact not found with and artifact_id: {artifact_id}")
                raise NotFoundError("Artifact not found")
        except SQLAlchemyError as e:
            LOG.error(f"SQLAlchemy error while getting artifact: {e}")
            raise
        except NotFoundError:
            raise
        except Exception as e:
            LOG.error(f"Unexpected error while getting artifact: {e}")
            raise

    async def list_tasks(
        self, page: int = 1, per_page: int = 10
    ) -> Tuple[List[Task], Pagination]:
        """Return one page of tasks plus pagination metadata.

        Args:
            page: 1-based page number; rows are skipped by
                ``(page - 1) * per_page``.
            per_page: Maximum number of rows on the page. Must be non-zero,
                otherwise computing the page count raises
                ``ZeroDivisionError``.

        Returns:
            ``(tasks, pagination)`` for the requested page.

        Raises:
            SQLAlchemyError: Database failure (logged, then re-raised).
        """
        if self.debug_enabled:
            LOG.debug("Listing tasks")
        try:
            with self.Session() as session:
                tasks = (
                    session.query(TaskModel)
                    .offset((page - 1) * per_page)
                    .limit(per_page)
                    .all()
                )
                total = session.query(TaskModel).count()
                pagination = self._build_pagination(total, page, per_page)
                return [
                    convert_to_task(task, self.debug_enabled) for task in tasks
                ], pagination
        except SQLAlchemyError as e:
            LOG.error(f"SQLAlchemy error while listing tasks: {e}")
            raise
        except NotFoundError:
            raise
        except Exception as e:
            LOG.error(f"Unexpected error while listing tasks: {e}")
            raise

    async def list_steps(
        self, task_id: str, page: int = 1, per_page: int = 10
    ) -> Tuple[List[Step], Pagination]:
        """Return one page of the steps of ``task_id`` plus pagination metadata.

        Args:
            task_id: Id of the task whose steps are listed.
            page: 1-based page number; rows are skipped by
                ``(page - 1) * per_page``.
            per_page: Maximum number of rows on the page. Must be non-zero,
                otherwise computing the page count raises
                ``ZeroDivisionError``.

        Returns:
            ``(steps, pagination)`` for the requested page.

        Raises:
            SQLAlchemyError: Database failure (logged, then re-raised).
        """
        if self.debug_enabled:
            LOG.debug(f"Listing steps for task_id: {task_id}")
        try:
            with self.Session() as session:
                steps = (
                    session.query(StepModel)
                    .filter_by(task_id=task_id)
                    .offset((page - 1) * per_page)
                    .limit(per_page)
                    .all()
                )
                total = session.query(StepModel).filter_by(task_id=task_id).count()
                pagination = self._build_pagination(total, page, per_page)
                return [
                    convert_to_step(step, self.debug_enabled) for step in steps
                ], pagination
        except SQLAlchemyError as e:
            LOG.error(f"SQLAlchemy error while listing steps: {e}")
            raise
        except NotFoundError:
            raise
        except Exception as e:
            LOG.error(f"Unexpected error while listing steps: {e}")
            raise

    async def list_artifacts(
        self, task_id: str, page: int = 1, per_page: int = 10
    ) -> Tuple[List[Artifact], Pagination]:
        """Return one page of the artifacts of ``task_id`` plus pagination metadata.

        Args:
            task_id: Id of the task whose artifacts are listed.
            page: 1-based page number; rows are skipped by
                ``(page - 1) * per_page``.
            per_page: Maximum number of rows on the page. Must be non-zero,
                otherwise computing the page count raises
                ``ZeroDivisionError``.

        Returns:
            ``(artifacts, pagination)`` for the requested page.

        Raises:
            SQLAlchemyError: Database failure (logged, then re-raised).
        """
        if self.debug_enabled:
            LOG.debug(f"Listing artifacts for task_id: {task_id}")
        try:
            with self.Session() as session:
                artifacts = (
                    session.query(ArtifactModel)
                    .filter_by(task_id=task_id)
                    .offset((page - 1) * per_page)
                    .limit(per_page)
                    .all()
                )
                total = session.query(ArtifactModel).filter_by(task_id=task_id).count()
                pagination = self._build_pagination(total, page, per_page)
                return [
                    convert_to_artifact(artifact) for artifact in artifacts
                ], pagination
        except SQLAlchemyError as e:
            LOG.error(f"SQLAlchemy error while listing artifacts: {e}")
            raise
        except NotFoundError:
            raise
        except Exception as e:
            LOG.error(f"Unexpected error while listing artifacts: {e}")
            raise
|