Files
Auto-GPT/autogpt/db.py
merwanehamadi d3e86622b7 Fix Forge (#22)
Signed-off-by: Merwane Hamadi <merwanehamadi@gmail.com>
2023-08-26 08:20:12 -07:00

460 lines
16 KiB
Python

"""
This is an example implementation of the Agent Protocol DB for development purposes.
It uses SQLite as the database and file store backend.
IT IS NOT ADVISED TO USE THIS IN PRODUCTION!
"""
import datetime
import math
import uuid
from typing import Any, Dict, List, Optional, Tuple
from sqlalchemy import (
JSON,
Boolean,
Column,
DateTime,
ForeignKey,
String,
create_engine,
)
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import DeclarativeBase, joinedload, relationship, sessionmaker
from .errors import NotFoundError
from .forge_log import CustomLogger
from .schema import Artifact, Pagination, Status, Step, Task, TaskInput
LOG = CustomLogger(__name__)
class Base(DeclarativeBase):
    """Declarative base shared by the ORM models declared in this module."""
class TaskModel(Base):
    """ORM mapping for a row of the ``tasks`` table."""

    __tablename__ = "tasks"

    # String primary key (also indexed).
    task_id = Column(String, primary_key=True, index=True)
    # Task input text.
    input = Column(String)
    # Extra input, persisted as JSON.
    additional_input = Column(JSON)

    # Timestamps come from datetime.utcnow, i.e. naive datetimes in UTC.
    created_at = Column(DateTime, default=datetime.datetime.utcnow)
    modified_at = Column(
        DateTime,
        default=datetime.datetime.utcnow,
        onupdate=datetime.datetime.utcnow,
    )

    # Artifacts attached to this task; mirrored by ArtifactModel.task.
    artifacts = relationship("ArtifactModel", back_populates="task")
class StepModel(Base):
    """ORM mapping for a row of the ``steps`` table."""

    __tablename__ = "steps"

    # String primary key (also indexed).
    step_id = Column(String, primary_key=True, index=True)
    # Owning task, referenced through tasks.task_id.
    task_id = Column(String, ForeignKey("tasks.task_id"))
    name = Column(String)
    input = Column(String)
    # Status is stored as a plain string (e.g. "created", "completed").
    status = Column(String)
    is_last = Column(Boolean, default=False)

    # Timestamps come from datetime.utcnow, i.e. naive datetimes in UTC.
    created_at = Column(DateTime, default=datetime.datetime.utcnow)
    modified_at = Column(
        DateTime,
        default=datetime.datetime.utcnow,
        onupdate=datetime.datetime.utcnow,
    )

    # Extra input, persisted as JSON.
    additional_input = Column(JSON)

    # Artifacts attached to this step; mirrored by ArtifactModel.step.
    artifacts = relationship("ArtifactModel", back_populates="step")
class ArtifactModel(Base):
    """ORM mapping for a row of the ``artifacts`` table."""

    __tablename__ = "artifacts"

    # String primary key (also indexed).
    artifact_id = Column(String, primary_key=True, index=True)
    # Owning task and step, referenced through their primary keys.
    task_id = Column(String, ForeignKey("tasks.task_id"))
    step_id = Column(String, ForeignKey("steps.step_id"))
    agent_created = Column(Boolean, default=False)
    file_name = Column(String)
    uri = Column(String)

    # Timestamps come from datetime.utcnow, i.e. naive datetimes in UTC.
    created_at = Column(DateTime, default=datetime.datetime.utcnow)
    modified_at = Column(
        DateTime,
        default=datetime.datetime.utcnow,
        onupdate=datetime.datetime.utcnow,
    )

    # Back-references paired with StepModel.artifacts / TaskModel.artifacts.
    step = relationship("StepModel", back_populates="artifacts")
    task = relationship("TaskModel", back_populates="artifacts")
def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False) -> Task:
    """Translate a ``TaskModel`` row into the schema-level ``Task``.

    Args:
        task_obj: ORM row to convert; its ``artifacts`` relationship is read.
        debug_enabled: when true, a debug line naming the task id is logged.

    Returns:
        A ``Task`` carrying the row's fields and its converted artifacts.
    """
    if debug_enabled:
        LOG.debug(f"Converting TaskModel to Task for task_id: {task_obj.task_id}")

    # Each related artifact row goes through the same converter.
    converted_artifacts = list(map(convert_to_artifact, task_obj.artifacts))

    return Task(
        task_id=task_obj.task_id,
        created_at=task_obj.created_at,
        modified_at=task_obj.modified_at,
        input=task_obj.input,
        additional_input=task_obj.additional_input,
        artifacts=converted_artifacts,
    )
def convert_to_step(step_model: StepModel, debug_enabled: bool = False) -> Step:
    """Translate a ``StepModel`` row into the schema-level ``Step``.

    Args:
        step_model: ORM row to convert; its ``artifacts`` relationship is read.
        debug_enabled: when true, a debug line naming the step id is logged.

    Returns:
        A ``Step`` carrying the row's fields and its converted artifacts.
    """
    if debug_enabled:
        LOG.debug(f"Converting StepModel to Step for step_id: {step_model.step_id}")

    converted_artifacts = list(map(convert_to_artifact, step_model.artifacts))

    # Only the literal "completed" maps to Status.completed; every other
    # stored value is reported as Status.created.
    if step_model.status == "completed":
        step_status = Status.completed
    else:
        step_status = Status.created

    return Step(
        task_id=step_model.task_id,
        step_id=step_model.step_id,
        created_at=step_model.created_at,
        modified_at=step_model.modified_at,
        name=step_model.name,
        input=step_model.input,
        status=step_status,
        artifacts=converted_artifacts,
        # Normalised to a strict bool by comparing against 1.
        is_last=step_model.is_last == 1,
        additional_input=step_model.additional_input,
    )
def convert_to_artifact(artifact_model: ArtifactModel) -> Artifact:
    """Translate an ``ArtifactModel`` row into the schema-level ``Artifact``.

    Only the id, both timestamps, ``agent_created`` and ``uri`` are forwarded;
    ``file_name``, ``task_id`` and ``step_id`` are not passed on.
    """
    artifact_fields = {
        "artifact_id": artifact_model.artifact_id,
        "created_at": artifact_model.created_at,
        "modified_at": artifact_model.modified_at,
        "agent_created": artifact_model.agent_created,
        "uri": artifact_model.uri,
    }
    return Artifact(**artifact_fields)
# Example database_string for AgentDB: sqlite:///{database_name}
class AgentDB:
    """SQLAlchemy-backed store for tasks, steps and artifacts.

    Every public method opens its own session through ``self.Session`` and
    lets the ``with`` block close it. SQLAlchemy and unexpected errors are
    logged and re-raised unchanged; ``NotFoundError`` is re-raised without
    the extra "unexpected error" log line.
    """

    def __init__(self, database_string, debug_enabled: bool = False) -> None:
        """Create the engine, create any missing tables and build a session factory.

        Args:
            database_string: URL handed to ``sqlalchemy.create_engine``.
            debug_enabled: when true, methods emit debug-level log lines.
        """
        super().__init__()
        self.debug_enabled = debug_enabled
        if self.debug_enabled:
            LOG.debug(f"Initializing AgentDB with database_string: {database_string}")
        self.engine = create_engine(database_string)
        # Creates the tables of every model registered on Base.
        Base.metadata.create_all(self.engine)
        self.Session = sessionmaker(bind=self.engine)

    @staticmethod
    def _build_pagination(total: int, page: int, per_page: int) -> Pagination:
        """Assemble the ``Pagination`` summary shared by the ``list_*`` methods.

        Raises:
            ZeroDivisionError: if ``per_page`` is 0.
        """
        return Pagination(
            total_items=total,
            total_pages=math.ceil(total / per_page),
            current_page=page,
            page_size=per_page,
        )

    async def create_task(
        self, input: Optional[str], additional_input: Optional[TaskInput] = {}
    ) -> Task:
        """Insert a new task with a fresh UUID4 id and return it as a ``Task``.

        Args:
            input: task input text.
            additional_input: extra input stored in the JSON column.
                NOTE(review): the ``{}` default is a shared object; this
                method only hands it to the model and never mutates it.

        Raises:
            SQLAlchemyError: on database failure (logged, then re-raised).
        """
        if self.debug_enabled:
            LOG.debug("Creating new task")
        try:
            with self.Session() as session:
                new_task = TaskModel(
                    task_id=str(uuid.uuid4()),
                    input=input,
                    additional_input=additional_input,
                )
                session.add(new_task)
                session.commit()
                # Reload the row so column defaults (timestamps) are populated.
                session.refresh(new_task)
                if self.debug_enabled:
                    LOG.debug(f"Created new task with task_id: {new_task.task_id}")
                return convert_to_task(new_task, self.debug_enabled)
        except SQLAlchemyError as e:
            LOG.error(f"SQLAlchemy error while creating task: {e}")
            raise
        except NotFoundError:
            raise
        except Exception as e:
            LOG.error(f"Unexpected error while creating task: {e}")
            raise

    async def create_step(
        self,
        task_id: str,
        input: Any,
        is_last: bool = False,
        additional_input: Optional[Dict[str, Any]] = {},
    ) -> Step:
        """Insert a new step in status ``"created"`` and return it as a ``Step``.

        Args:
            task_id: id of the task the step belongs to.
            input: request object; its ``name`` and ``input`` attributes are
                copied onto the new row (a plain ``str`` would not work here).
            is_last: whether this step is the final one of the task.
            additional_input: extra input stored in the JSON column.
                NOTE(review): the ``{}`` default is a shared object; this
                method only hands it to the model and never mutates it.

        Raises:
            SQLAlchemyError: on database failure (logged, then re-raised).
        """
        if self.debug_enabled:
            LOG.debug(f"Creating new step for task_id: {task_id}")
        try:
            with self.Session() as session:
                new_step = StepModel(
                    task_id=task_id,
                    step_id=str(uuid.uuid4()),
                    name=input.name,
                    input=input.input,
                    status="created",
                    is_last=is_last,
                    additional_input=additional_input,
                )
                session.add(new_step)
                session.commit()
                # Reload the row so column defaults (timestamps) are populated.
                session.refresh(new_step)
                if self.debug_enabled:
                    LOG.debug(f"Created new step with step_id: {new_step.step_id}")
                return convert_to_step(new_step, self.debug_enabled)
        except SQLAlchemyError as e:
            LOG.error(f"SQLAlchemy error while creating step: {e}")
            raise
        except NotFoundError:
            raise
        except Exception as e:
            LOG.error(f"Unexpected error while creating step: {e}")
            raise

    async def create_artifact(
        self,
        task_id: str,
        file_name: str,
        uri: str,
        agent_created: bool = False,
        step_id: str | None = None,
    ) -> Artifact:
        """Return the artifact stored under ``uri``, inserting it if absent.

        The existence check looks at ``uri`` only, so an artifact already
        stored under the same uri is returned even when it belongs to a
        different task or step.

        Args:
            task_id: id of the owning task.
            file_name: file name recorded on the new row.
            uri: location of the artifact; also the de-duplication key.
            agent_created: whether the agent produced the artifact.
            step_id: id of the owning step, if any.

        Raises:
            SQLAlchemyError: on database failure (logged, then re-raised).
        """
        if self.debug_enabled:
            LOG.debug(f"Creating new artifact for task_id: {task_id}")
        try:
            with self.Session() as session:
                existing_artifact = (
                    session.query(ArtifactModel).filter_by(uri=uri).first()
                )
                if existing_artifact:
                    # The enclosing ``with`` closes the session on return;
                    # no explicit close is needed.
                    if self.debug_enabled:
                        LOG.debug(f"Artifact already exists with uri: {uri}")
                    return convert_to_artifact(existing_artifact)

                new_artifact = ArtifactModel(
                    artifact_id=str(uuid.uuid4()),
                    task_id=task_id,
                    step_id=step_id,
                    agent_created=agent_created,
                    file_name=file_name,
                    uri=uri,
                )
                session.add(new_artifact)
                session.commit()
                # Reload the row so column defaults (timestamps) are populated.
                session.refresh(new_artifact)
                if self.debug_enabled:
                    LOG.debug(
                        f"Created new artifact with artifact_id: {new_artifact.artifact_id}"
                    )
                return convert_to_artifact(new_artifact)
        except SQLAlchemyError as e:
            LOG.error(f"SQLAlchemy error while creating artifact: {e}")
            raise
        except NotFoundError:
            raise
        except Exception as e:
            LOG.error(f"Unexpected error while creating artifact: {e}")
            raise

    async def get_task(self, task_id: str) -> Task:
        """Get a task by its id, with its artifacts eagerly loaded.

        Raises:
            NotFoundError: if no task has this id.
            SQLAlchemyError: on database failure (logged, then re-raised).
        """
        if self.debug_enabled:
            LOG.debug(f"Getting task with task_id: {task_id}")
        try:
            with self.Session() as session:
                task_obj = (
                    session.query(TaskModel)
                    .options(joinedload(TaskModel.artifacts))
                    .filter_by(task_id=task_id)
                    .first()
                )
                if task_obj is None:
                    LOG.error(f"Task not found with task_id: {task_id}")
                    raise NotFoundError("Task not found")
                return convert_to_task(task_obj, self.debug_enabled)
        except SQLAlchemyError as e:
            LOG.error(f"SQLAlchemy error while getting task: {e}")
            raise
        except NotFoundError:
            raise
        except Exception as e:
            LOG.error(f"Unexpected error while getting task: {e}")
            raise

    async def get_step(self, task_id: str, step_id: str) -> Step:
        """Get a step by its id, with its artifacts eagerly loaded.

        NOTE(review): the lookup filters on ``step_id`` only; ``task_id`` is
        used solely in log messages, unlike ``update_step`` which filters on
        both. Confirm whether a task/step mismatch should be rejected.

        Raises:
            NotFoundError: if no step has this id.
            SQLAlchemyError: on database failure (logged, then re-raised).
        """
        if self.debug_enabled:
            LOG.debug(f"Getting step with task_id: {task_id} and step_id: {step_id}")
        try:
            with self.Session() as session:
                step = (
                    session.query(StepModel)
                    .options(joinedload(StepModel.artifacts))
                    .filter(StepModel.step_id == step_id)
                    .first()
                )
                if step is None:
                    LOG.error(
                        f"Step not found with task_id: {task_id} and step_id: {step_id}"
                    )
                    raise NotFoundError("Step not found")
                return convert_to_step(step, self.debug_enabled)
        except SQLAlchemyError as e:
            LOG.error(f"SQLAlchemy error while getting step: {e}")
            raise
        except NotFoundError:
            raise
        except Exception as e:
            LOG.error(f"Unexpected error while getting step: {e}")
            raise

    async def update_step(
        self,
        task_id: str,
        step_id: str,
        status: str,
        additional_input: Optional[Dict[str, Any]] = {},
    ) -> Step:
        """Set a step's status and additional input, then return the stored step.

        NOTE(review): ``additional_input`` is assigned unconditionally, so
        omitting it replaces whatever was stored with ``{}``.

        Args:
            task_id: id of the owning task (part of the lookup filter).
            step_id: id of the step to update.
            status: new status string to store.
            additional_input: replacement value for the JSON column.

        Raises:
            NotFoundError: if no step matches both ids.
            SQLAlchemyError: on database failure (logged, then re-raised).
        """
        if self.debug_enabled:
            LOG.debug(f"Updating step with task_id: {task_id} and step_id: {step_id}")
        try:
            with self.Session() as session:
                step = (
                    session.query(StepModel)
                    .filter_by(task_id=task_id, step_id=step_id)
                    .first()
                )
                if step is None:
                    LOG.error(
                        f"Step not found for update with task_id: {task_id} and step_id: {step_id}"
                    )
                    raise NotFoundError("Step not found")
                step.status = status
                step.additional_input = additional_input
                session.commit()
                # Re-read through get_step so the result includes artifacts.
                return await self.get_step(task_id, step_id)
        except SQLAlchemyError as e:
            LOG.error(f"SQLAlchemy error while updating step: {e}")
            raise
        except NotFoundError:
            raise
        except Exception as e:
            LOG.error(f"Unexpected error while updating step: {e}")
            raise

    async def get_artifact(self, artifact_id: str) -> Artifact:
        """Get an artifact by its id.

        Raises:
            NotFoundError: if no artifact has this id.
            SQLAlchemyError: on database failure (logged, then re-raised).
        """
        if self.debug_enabled:
            LOG.debug(f"Getting artifact with and artifact_id: {artifact_id}")
        try:
            with self.Session() as session:
                artifact_model = (
                    session.query(ArtifactModel)
                    .filter_by(artifact_id=artifact_id)
                    .first()
                )
                if artifact_model is None:
                    LOG.error(f"Artifact not found with and artifact_id: {artifact_id}")
                    raise NotFoundError("Artifact not found")
                return convert_to_artifact(artifact_model)
        except SQLAlchemyError as e:
            LOG.error(f"SQLAlchemy error while getting artifact: {e}")
            raise
        except NotFoundError:
            raise
        except Exception as e:
            LOG.error(f"Unexpected error while getting artifact: {e}")
            raise

    async def list_tasks(
        self, page: int = 1, per_page: int = 10
    ) -> Tuple[List[Task], Pagination]:
        """Return one page of tasks plus pagination metadata.

        NOTE(review): no ORDER BY is applied, so which rows land on a page
        depends on the order the database returns them in.

        Args:
            page: 1-based page number.
            per_page: page size; 0 raises ``ZeroDivisionError``.

        Raises:
            SQLAlchemyError: on database failure (logged, then re-raised).
        """
        if self.debug_enabled:
            LOG.debug("Listing tasks")
        try:
            with self.Session() as session:
                tasks = (
                    session.query(TaskModel)
                    .offset((page - 1) * per_page)
                    .limit(per_page)
                    .all()
                )
                total = session.query(TaskModel).count()
                pagination = self._build_pagination(total, page, per_page)
                return [
                    convert_to_task(task, self.debug_enabled) for task in tasks
                ], pagination
        except SQLAlchemyError as e:
            LOG.error(f"SQLAlchemy error while listing tasks: {e}")
            raise
        except NotFoundError:
            raise
        except Exception as e:
            LOG.error(f"Unexpected error while listing tasks: {e}")
            raise

    async def list_steps(
        self, task_id: str, page: int = 1, per_page: int = 10
    ) -> Tuple[List[Step], Pagination]:
        """Return one page of a task's steps plus pagination metadata.

        NOTE(review): no ORDER BY is applied, so which rows land on a page
        depends on the order the database returns them in.

        Args:
            task_id: id of the task whose steps are listed.
            page: 1-based page number.
            per_page: page size; 0 raises ``ZeroDivisionError``.

        Raises:
            SQLAlchemyError: on database failure (logged, then re-raised).
        """
        if self.debug_enabled:
            LOG.debug(f"Listing steps for task_id: {task_id}")
        try:
            with self.Session() as session:
                steps = (
                    session.query(StepModel)
                    .filter_by(task_id=task_id)
                    .offset((page - 1) * per_page)
                    .limit(per_page)
                    .all()
                )
                total = session.query(StepModel).filter_by(task_id=task_id).count()
                pagination = self._build_pagination(total, page, per_page)
                return [
                    convert_to_step(step, self.debug_enabled) for step in steps
                ], pagination
        except SQLAlchemyError as e:
            LOG.error(f"SQLAlchemy error while listing steps: {e}")
            raise
        except NotFoundError:
            raise
        except Exception as e:
            LOG.error(f"Unexpected error while listing steps: {e}")
            raise

    async def list_artifacts(
        self, task_id: str, page: int = 1, per_page: int = 10
    ) -> Tuple[List[Artifact], Pagination]:
        """Return one page of a task's artifacts plus pagination metadata.

        NOTE(review): no ORDER BY is applied, so which rows land on a page
        depends on the order the database returns them in.

        Args:
            task_id: id of the task whose artifacts are listed.
            page: 1-based page number.
            per_page: page size; 0 raises ``ZeroDivisionError``.

        Raises:
            SQLAlchemyError: on database failure (logged, then re-raised).
        """
        if self.debug_enabled:
            LOG.debug(f"Listing artifacts for task_id: {task_id}")
        try:
            with self.Session() as session:
                artifacts = (
                    session.query(ArtifactModel)
                    .filter_by(task_id=task_id)
                    .offset((page - 1) * per_page)
                    .limit(per_page)
                    .all()
                )
                total = session.query(ArtifactModel).filter_by(task_id=task_id).count()
                pagination = self._build_pagination(total, page, per_page)
                return [
                    convert_to_artifact(artifact) for artifact in artifacts
                ], pagination
        except SQLAlchemyError as e:
            LOG.error(f"SQLAlchemy error while listing artifacts: {e}")
            raise
        except NotFoundError:
            raise
        except Exception as e:
            LOG.error(f"Unexpected error while listing artifacts: {e}")
            raise