feat(forge/db): Add AgentDB.update_artifact method

This commit is contained in:
Reinier van der Leer
2024-01-19 11:41:40 +01:00
parent 9012ff4db2
commit b238abac52

View File

@@ -7,7 +7,7 @@ IT IS NOT ADVISED TO USE THIS IN PRODUCTION!
import datetime
import math
import uuid
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Literal, Optional, Tuple
from sqlalchemy import (
JSON,
@@ -311,6 +311,29 @@ class AgentDB:
LOG.error(f"Unexpected error while getting step: {e}")
raise
async def get_artifact(self, artifact_id: str) -> Artifact:
if self.debug_enabled:
LOG.debug(f"Getting artifact with and artifact_id: {artifact_id}")
try:
with self.Session() as session:
if (
artifact_model := session.query(ArtifactModel)
.filter_by(artifact_id=artifact_id)
.first()
):
return convert_to_artifact(artifact_model)
else:
LOG.error(f"Artifact not found with and artifact_id: {artifact_id}")
raise NotFoundError("Artifact not found")
except SQLAlchemyError as e:
LOG.error(f"SQLAlchemy error while getting artifact: {e}")
raise
except NotFoundError as e:
raise
except Exception as e:
LOG.error(f"Unexpected error while getting artifact: {e}")
raise
async def update_step(
self,
task_id: str,
@@ -353,28 +376,32 @@ class AgentDB:
LOG.error(f"Unexpected error while getting step: {e}")
raise
async def get_artifact(self, artifact_id: str) -> Artifact:
if self.debug_enabled:
LOG.debug(f"Getting artifact with and artifact_id: {artifact_id}")
try:
with self.Session() as session:
if (
artifact_model := session.query(ArtifactModel)
.filter_by(artifact_id=artifact_id)
.first()
):
return convert_to_artifact(artifact_model)
else:
LOG.error(f"Artifact not found with and artifact_id: {artifact_id}")
raise NotFoundError("Artifact not found")
except SQLAlchemyError as e:
LOG.error(f"SQLAlchemy error while getting artifact: {e}")
raise
except NotFoundError as e:
raise
except Exception as e:
LOG.error(f"Unexpected error while getting artifact: {e}")
raise
async def update_artifact(
self,
artifact_id: str,
*,
file_name: str = "",
relative_path: str = "",
agent_created: Optional[Literal[True]] = None,
) -> Artifact:
LOG.debug(f"Updating artifact with artifact_id: {artifact_id}")
with self.Session() as session:
if (
artifact := session.query(ArtifactModel)
.filter_by(artifact_id=artifact_id)
.first()
):
if file_name:
artifact.file_name = file_name
if relative_path:
artifact.relative_path = relative_path
if agent_created:
artifact.agent_created = agent_created
session.commit()
return await self.get_artifact(artifact_id)
else:
LOG.error(f"Artifact not found with artifact_id: {artifact_id}")
raise NotFoundError("Artifact not found")
async def list_tasks(
self, page: int = 1, per_page: int = 10