mirror of
https://github.com/aljazceru/Auto-GPT.git
synced 2025-12-17 14:04:27 +01:00
Added setup command refactored memstore (#5555)
* forge - restructured memstore * Stopped setup from being ran as defualt when running an agent
This commit is contained in:
@@ -1 +1,2 @@
|
|||||||
|
from .memstore import MemStore
|
||||||
|
from .chroma_memstore import ChromaMemStore
|
||||||
|
|||||||
160
autogpts/forge/forge/sdk/memory/chroma_memstore.py
Normal file
160
autogpts/forge/forge/sdk/memory/chroma_memstore.py
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
from .memstore import MemStore
|
||||||
|
|
||||||
|
import chromadb
|
||||||
|
from chromadb.config import Settings
|
||||||
|
import hashlib
|
||||||
|
|
||||||
|
|
||||||
|
class ChromaMemStore:
|
||||||
|
"""
|
||||||
|
A class used to represent a Memory Store
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, store_path: str):
|
||||||
|
"""
|
||||||
|
Initialize the MemStore with a given store path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
store_path (str): The path to the store.
|
||||||
|
"""
|
||||||
|
self.client = chromadb.PersistentClient(
|
||||||
|
path=store_path, settings=Settings(anonymized_telemetry=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
def add(self, task_id: str, document: str, metadatas: dict) -> None:
|
||||||
|
"""
|
||||||
|
Add a document to the MemStore.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id (str): The ID of the task.
|
||||||
|
document (str): The document to be added.
|
||||||
|
metadatas (dict): The metadata of the document.
|
||||||
|
"""
|
||||||
|
doc_id = hashlib.sha256(document.encode()).hexdigest()[:20]
|
||||||
|
collection = self.client.get_or_create_collection(task_id)
|
||||||
|
collection.add(documents=[document], metadatas=[metadatas], ids=[doc_id])
|
||||||
|
|
||||||
|
def query(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
query: str,
|
||||||
|
filters: dict = None,
|
||||||
|
document_search: dict = None,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Query the MemStore.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id (str): The ID of the task.
|
||||||
|
query (str): The query string.
|
||||||
|
filters (dict, optional): The filters to be applied. Defaults to None.
|
||||||
|
search_string (str, optional): The search string. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: The query results.
|
||||||
|
"""
|
||||||
|
collection = self.client.get_or_create_collection(task_id)
|
||||||
|
|
||||||
|
kwargs = {
|
||||||
|
"query_texts": [query],
|
||||||
|
"n_results": 10,
|
||||||
|
}
|
||||||
|
|
||||||
|
if filters:
|
||||||
|
kwargs["where"] = filters
|
||||||
|
|
||||||
|
if document_search:
|
||||||
|
kwargs["where_document"] = document_search
|
||||||
|
|
||||||
|
return collection.query(**kwargs)
|
||||||
|
|
||||||
|
def get(self, task_id: str, doc_ids: list = None, filters: dict = None) -> dict:
|
||||||
|
"""
|
||||||
|
Get documents from the MemStore.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id (str): The ID of the task.
|
||||||
|
doc_ids (list, optional): The IDs of the documents to be retrieved. Defaults to None.
|
||||||
|
filters (dict, optional): The filters to be applied. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: The retrieved documents.
|
||||||
|
"""
|
||||||
|
collection = self.client.get_or_create_collection(task_id)
|
||||||
|
kwargs = {}
|
||||||
|
if doc_ids:
|
||||||
|
kwargs["ids"] = doc_ids
|
||||||
|
if filters:
|
||||||
|
kwargs["where"] = filters
|
||||||
|
return collection.get(**kwargs)
|
||||||
|
|
||||||
|
def update(self, task_id: str, doc_ids: list, documents: list, metadatas: list):
|
||||||
|
"""
|
||||||
|
Update documents in the MemStore.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id (str): The ID of the task.
|
||||||
|
doc_ids (list): The IDs of the documents to be updated.
|
||||||
|
documents (list): The updated documents.
|
||||||
|
metadatas (list): The updated metadata.
|
||||||
|
"""
|
||||||
|
collection = self.client.get_or_create_collection(task_id)
|
||||||
|
collection.update(ids=doc_ids, documents=documents, metadatas=metadatas)
|
||||||
|
|
||||||
|
def delete(self, task_id: str, doc_id: str):
|
||||||
|
"""
|
||||||
|
Delete a document from the MemStore.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id (str): The ID of the task.
|
||||||
|
doc_id (str): The ID of the document to be deleted.
|
||||||
|
"""
|
||||||
|
collection = self.client.get_or_create_collection(task_id)
|
||||||
|
collection.delete(ids=[doc_id])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("#############################################")
|
||||||
|
# Initialize MemStore
|
||||||
|
mem = ChromaMemStore(".agent_mem_store")
|
||||||
|
|
||||||
|
# Test add function
|
||||||
|
task_id = "test_task"
|
||||||
|
document = "This is a another new test document."
|
||||||
|
metadatas = {"metadata": "test_metadata"}
|
||||||
|
mem.add(task_id, document, metadatas)
|
||||||
|
|
||||||
|
task_id = "test_task"
|
||||||
|
document = "The quick brown fox jumps over the lazy dog."
|
||||||
|
metadatas = {"metadata": "test_metadata"}
|
||||||
|
mem.add(task_id, document, metadatas)
|
||||||
|
|
||||||
|
task_id = "test_task"
|
||||||
|
document = "AI is a new technology that will change the world."
|
||||||
|
metadatas = {"timestamp": 1623936000}
|
||||||
|
mem.add(task_id, document, metadatas)
|
||||||
|
|
||||||
|
doc_id = hashlib.sha256(document.encode()).hexdigest()[:20]
|
||||||
|
# Test query function
|
||||||
|
query = "test"
|
||||||
|
filters = {"metadata": {"$eq": "test"}}
|
||||||
|
search_string = {"$contains": "test"}
|
||||||
|
doc_ids = [doc_id]
|
||||||
|
documents = ["This is an updated test document."]
|
||||||
|
updated_metadatas = {"metadata": "updated_test_metadata"}
|
||||||
|
|
||||||
|
print("Query:")
|
||||||
|
print(mem.query(task_id, query))
|
||||||
|
|
||||||
|
# Test get function
|
||||||
|
print("Get:")
|
||||||
|
|
||||||
|
print(mem.get(task_id))
|
||||||
|
|
||||||
|
# Test update function
|
||||||
|
print("Update:")
|
||||||
|
print(mem.update(task_id, doc_ids, documents, updated_metadatas))
|
||||||
|
|
||||||
|
print("Delete:")
|
||||||
|
# Test delete function
|
||||||
|
print(mem.delete(task_id, doc_ids[0]))
|
||||||
@@ -149,158 +149,3 @@ class MemStore(abc.ABC):
|
|||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def delete(self, collection_name: str, doc_id: str):
|
def delete(self, collection_name: str, doc_id: str):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ChromaMemStore(MemStore):
|
|
||||||
"""
|
|
||||||
A class used to represent a Memory Store
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, store_path: str):
|
|
||||||
"""
|
|
||||||
Initialize the MemStore with a given store path.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
store_path (str): The path to the store.
|
|
||||||
"""
|
|
||||||
self.client = chromadb.PersistentClient(
|
|
||||||
path=store_path, settings=Settings(anonymized_telemetry=False)
|
|
||||||
)
|
|
||||||
|
|
||||||
def add(self, task_id: str, document: str, metadatas: dict) -> None:
|
|
||||||
"""
|
|
||||||
Add a document to the MemStore.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task_id (str): The ID of the task.
|
|
||||||
document (str): The document to be added.
|
|
||||||
metadatas (dict): The metadata of the document.
|
|
||||||
"""
|
|
||||||
doc_id = hashlib.sha256(document.encode()).hexdigest()[:20]
|
|
||||||
collection = self.client.get_or_create_collection(task_id)
|
|
||||||
collection.add(documents=[document], metadatas=[metadatas], ids=[doc_id])
|
|
||||||
|
|
||||||
def query(
|
|
||||||
self,
|
|
||||||
task_id: str,
|
|
||||||
query: str,
|
|
||||||
filters: dict = None,
|
|
||||||
document_search: dict = None,
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
Query the MemStore.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task_id (str): The ID of the task.
|
|
||||||
query (str): The query string.
|
|
||||||
filters (dict, optional): The filters to be applied. Defaults to None.
|
|
||||||
search_string (str, optional): The search string. Defaults to None.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: The query results.
|
|
||||||
"""
|
|
||||||
collection = self.client.get_or_create_collection(task_id)
|
|
||||||
|
|
||||||
kwargs = {
|
|
||||||
"query_texts": [query],
|
|
||||||
"n_results": 10,
|
|
||||||
}
|
|
||||||
|
|
||||||
if filters:
|
|
||||||
kwargs["where"] = filters
|
|
||||||
|
|
||||||
if document_search:
|
|
||||||
kwargs["where_document"] = document_search
|
|
||||||
|
|
||||||
return collection.query(**kwargs)
|
|
||||||
|
|
||||||
def get(self, task_id: str, doc_ids: list = None, filters: dict = None) -> dict:
|
|
||||||
"""
|
|
||||||
Get documents from the MemStore.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task_id (str): The ID of the task.
|
|
||||||
doc_ids (list, optional): The IDs of the documents to be retrieved. Defaults to None.
|
|
||||||
filters (dict, optional): The filters to be applied. Defaults to None.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: The retrieved documents.
|
|
||||||
"""
|
|
||||||
collection = self.client.get_or_create_collection(task_id)
|
|
||||||
kwargs = {}
|
|
||||||
if doc_ids:
|
|
||||||
kwargs["ids"] = doc_ids
|
|
||||||
if filters:
|
|
||||||
kwargs["where"] = filters
|
|
||||||
return collection.get(**kwargs)
|
|
||||||
|
|
||||||
def update(self, task_id: str, doc_ids: list, documents: list, metadatas: list):
|
|
||||||
"""
|
|
||||||
Update documents in the MemStore.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task_id (str): The ID of the task.
|
|
||||||
doc_ids (list): The IDs of the documents to be updated.
|
|
||||||
documents (list): The updated documents.
|
|
||||||
metadatas (list): The updated metadata.
|
|
||||||
"""
|
|
||||||
collection = self.client.get_or_create_collection(task_id)
|
|
||||||
collection.update(ids=doc_ids, documents=documents, metadatas=metadatas)
|
|
||||||
|
|
||||||
def delete(self, task_id: str, doc_id: str):
|
|
||||||
"""
|
|
||||||
Delete a document from the MemStore.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task_id (str): The ID of the task.
|
|
||||||
doc_id (str): The ID of the document to be deleted.
|
|
||||||
"""
|
|
||||||
collection = self.client.get_or_create_collection(task_id)
|
|
||||||
collection.delete(ids=[doc_id])
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
print("#############################################")
|
|
||||||
# Initialize MemStore
|
|
||||||
mem = MemStore(".agent_mem_store")
|
|
||||||
|
|
||||||
# Test add function
|
|
||||||
task_id = "test_task"
|
|
||||||
document = "This is a another new test document."
|
|
||||||
metadatas = {"metadata": "test_metadata"}
|
|
||||||
mem.add(task_id, document, metadatas)
|
|
||||||
|
|
||||||
task_id = "test_task"
|
|
||||||
document = "The quick brown fox jumps over the lazy dog."
|
|
||||||
metadatas = {"metadata": "test_metadata"}
|
|
||||||
mem.add(task_id, document, metadatas)
|
|
||||||
|
|
||||||
task_id = "test_task"
|
|
||||||
document = "AI is a new technology that will change the world."
|
|
||||||
metadatas = {"timestamp": 1623936000}
|
|
||||||
mem.add(task_id, document, metadatas)
|
|
||||||
|
|
||||||
doc_id = hashlib.sha256(document.encode()).hexdigest()[:20]
|
|
||||||
# Test query function
|
|
||||||
query = "test"
|
|
||||||
filters = {"metadata": {"$eq": "test"}}
|
|
||||||
search_string = {"$contains": "test"}
|
|
||||||
doc_ids = [doc_id]
|
|
||||||
documents = ["This is an updated test document."]
|
|
||||||
updated_metadatas = {"metadata": "updated_test_metadata"}
|
|
||||||
|
|
||||||
print("Query:")
|
|
||||||
print(mem.query(task_id, query))
|
|
||||||
|
|
||||||
# Test get function
|
|
||||||
print("Get:")
|
|
||||||
|
|
||||||
print(mem.get(task_id))
|
|
||||||
|
|
||||||
# Test update function
|
|
||||||
print("Update:")
|
|
||||||
print(mem.update(task_id, doc_ids, documents, updated_metadatas))
|
|
||||||
|
|
||||||
print("Delete:")
|
|
||||||
# Test delete function
|
|
||||||
print(mem.delete(task_id, doc_ids[0]))
|
|
||||||
|
|||||||
4
cli.py
4
cli.py
@@ -258,7 +258,8 @@ def create(agent_name):
|
|||||||
|
|
||||||
@agent.command()
|
@agent.command()
|
||||||
@click.argument("agent_name")
|
@click.argument("agent_name")
|
||||||
def start(agent_name):
|
@click.option("--setup", is_flag=True, help="Rebuilds your poetry env")
|
||||||
|
def start(agent_name, setup):
|
||||||
"""Start agent command"""
|
"""Start agent command"""
|
||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
@@ -269,6 +270,7 @@ def start(agent_name):
|
|||||||
run_bench_command = os.path.join(agent_dir, "run_benchmark")
|
run_bench_command = os.path.join(agent_dir, "run_benchmark")
|
||||||
if os.path.exists(agent_dir) and os.path.isfile(run_command) and os.path.isfile(run_bench_command):
|
if os.path.exists(agent_dir) and os.path.isfile(run_command) and os.path.isfile(run_bench_command):
|
||||||
os.chdir(agent_dir)
|
os.chdir(agent_dir)
|
||||||
|
if setup:
|
||||||
setup_process = subprocess.Popen(["./setup"], cwd=agent_dir)
|
setup_process = subprocess.Popen(["./setup"], cwd=agent_dir)
|
||||||
setup_process.wait()
|
setup_process.wait()
|
||||||
subprocess.Popen(["./run_benchmark", "serve"], cwd=agent_dir)
|
subprocess.Popen(["./run_benchmark", "serve"], cwd=agent_dir)
|
||||||
|
|||||||
Reference in New Issue
Block a user