mirror of
https://github.com/aljazceru/Auto-GPT.git
synced 2026-02-09 08:14:27 +01:00
feat(forge/db): Add AgentDB.update_artifact method
This commit is contained in:
@@ -7,7 +7,7 @@ IT IS NOT ADVISED TO USE THIS IN PRODUCTION!
|
||||
import datetime
|
||||
import math
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple
|
||||
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
@@ -311,6 +311,29 @@ class AgentDB:
|
||||
LOG.error(f"Unexpected error while getting step: {e}")
|
||||
raise
|
||||
|
||||
async def get_artifact(self, artifact_id: str) -> Artifact:
|
||||
if self.debug_enabled:
|
||||
LOG.debug(f"Getting artifact with and artifact_id: {artifact_id}")
|
||||
try:
|
||||
with self.Session() as session:
|
||||
if (
|
||||
artifact_model := session.query(ArtifactModel)
|
||||
.filter_by(artifact_id=artifact_id)
|
||||
.first()
|
||||
):
|
||||
return convert_to_artifact(artifact_model)
|
||||
else:
|
||||
LOG.error(f"Artifact not found with and artifact_id: {artifact_id}")
|
||||
raise NotFoundError("Artifact not found")
|
||||
except SQLAlchemyError as e:
|
||||
LOG.error(f"SQLAlchemy error while getting artifact: {e}")
|
||||
raise
|
||||
except NotFoundError as e:
|
||||
raise
|
||||
except Exception as e:
|
||||
LOG.error(f"Unexpected error while getting artifact: {e}")
|
||||
raise
|
||||
|
||||
async def update_step(
|
||||
self,
|
||||
task_id: str,
|
||||
@@ -353,28 +376,32 @@ class AgentDB:
|
||||
LOG.error(f"Unexpected error while getting step: {e}")
|
||||
raise
|
||||
|
||||
async def get_artifact(self, artifact_id: str) -> Artifact:
|
||||
if self.debug_enabled:
|
||||
LOG.debug(f"Getting artifact with and artifact_id: {artifact_id}")
|
||||
try:
|
||||
with self.Session() as session:
|
||||
if (
|
||||
artifact_model := session.query(ArtifactModel)
|
||||
.filter_by(artifact_id=artifact_id)
|
||||
.first()
|
||||
):
|
||||
return convert_to_artifact(artifact_model)
|
||||
else:
|
||||
LOG.error(f"Artifact not found with and artifact_id: {artifact_id}")
|
||||
raise NotFoundError("Artifact not found")
|
||||
except SQLAlchemyError as e:
|
||||
LOG.error(f"SQLAlchemy error while getting artifact: {e}")
|
||||
raise
|
||||
except NotFoundError as e:
|
||||
raise
|
||||
except Exception as e:
|
||||
LOG.error(f"Unexpected error while getting artifact: {e}")
|
||||
raise
|
||||
async def update_artifact(
|
||||
self,
|
||||
artifact_id: str,
|
||||
*,
|
||||
file_name: str = "",
|
||||
relative_path: str = "",
|
||||
agent_created: Optional[Literal[True]] = None,
|
||||
) -> Artifact:
|
||||
LOG.debug(f"Updating artifact with artifact_id: {artifact_id}")
|
||||
with self.Session() as session:
|
||||
if (
|
||||
artifact := session.query(ArtifactModel)
|
||||
.filter_by(artifact_id=artifact_id)
|
||||
.first()
|
||||
):
|
||||
if file_name:
|
||||
artifact.file_name = file_name
|
||||
if relative_path:
|
||||
artifact.relative_path = relative_path
|
||||
if agent_created:
|
||||
artifact.agent_created = agent_created
|
||||
session.commit()
|
||||
return await self.get_artifact(artifact_id)
|
||||
else:
|
||||
LOG.error(f"Artifact not found with artifact_id: {artifact_id}")
|
||||
raise NotFoundError("Artifact not found")
|
||||
|
||||
async def list_tasks(
|
||||
self, page: int = 1, per_page: int = 10
|
||||
|
||||
Reference in New Issue
Block a user