Added setup command and refactored memstore (#5555)

* forge -  restructured memstore

* Stopped setup from being run by default when running an agent
This commit is contained in:
Swifty
2023-10-05 10:13:59 -07:00
committed by GitHub
parent 3b7d83a1a6
commit bef8203da2
4 changed files with 167 additions and 159 deletions

View File

@@ -1 +1,2 @@
from .memstore import MemStore
from .chroma_memstore import ChromaMemStore

View File

@@ -0,0 +1,160 @@
from .memstore import MemStore
import chromadb
from chromadb.config import Settings
import hashlib
# NOTE(review): this module imports MemStore but the class below does not
# subclass it -- confirm whether ChromaMemStore is meant to inherit from it.
class ChromaMemStore:
    """
    A memory store backed by a persistent chromadb client.

    Each task uses the collection named after its task ID. Every method
    obtains that collection with ``get_or_create_collection``, so the
    collection is created on first use, including by read-style calls.
    """

    def __init__(self, store_path: str):
        """
        Initialize the MemStore with a given store path.

        Args:
            store_path (str): The path to the store.
        """
        self.client = chromadb.PersistentClient(
            path=store_path, settings=Settings(anonymized_telemetry=False)
        )

    def add(self, task_id: str, document: str, metadatas: dict) -> None:
        """
        Add a document to the MemStore.

        The document ID is the first 20 hex characters of the SHA-256 digest
        of the document text, so identical text always maps to the same ID.

        Args:
            task_id (str): The ID of the task.
            document (str): The document to be added.
            metadatas (dict): The metadata of the document.
        """
        doc_id = hashlib.sha256(document.encode()).hexdigest()[:20]
        collection = self.client.get_or_create_collection(task_id)
        collection.add(documents=[document], metadatas=[metadatas], ids=[doc_id])

    def query(
        self,
        task_id: str,
        query: str,
        filters: dict = None,
        document_search: dict = None,
        n_results: int = 10,
    ) -> dict:
        """
        Query the MemStore.

        Args:
            task_id (str): The ID of the task.
            query (str): The query string.
            filters (dict, optional): Metadata filters, forwarded as ``where``.
                Omitted from the call when None or empty. Defaults to None.
            document_search (dict, optional): Document-content filter,
                forwarded as ``where_document``. Omitted from the call when
                None or empty. Defaults to None.
            n_results (int, optional): Value forwarded as ``n_results`` to
                the collection query. Defaults to 10.

        Returns:
            dict: The query results.
        """
        collection = self.client.get_or_create_collection(task_id)
        kwargs = {
            "query_texts": [query],
            "n_results": n_results,
        }
        if filters:
            kwargs["where"] = filters
        if document_search:
            kwargs["where_document"] = document_search
        return collection.query(**kwargs)

    def get(self, task_id: str, doc_ids: list = None, filters: dict = None) -> dict:
        """
        Get documents from the MemStore.

        Args:
            task_id (str): The ID of the task.
            doc_ids (list, optional): The IDs of the documents to be retrieved.
                Defaults to None.
            filters (dict, optional): The filters to be applied. Defaults to None.

        Returns:
            dict: The retrieved documents.
        """
        collection = self.client.get_or_create_collection(task_id)
        # Truthiness checks: an empty list/dict is treated exactly like None,
        # i.e. the corresponding argument is not passed to the collection.
        kwargs = {}
        if doc_ids:
            kwargs["ids"] = doc_ids
        if filters:
            kwargs["where"] = filters
        return collection.get(**kwargs)

    def update(
        self, task_id: str, doc_ids: list, documents: list, metadatas: list
    ) -> None:
        """
        Update documents in the MemStore.

        Args:
            task_id (str): The ID of the task.
            doc_ids (list): The IDs of the documents to be updated.
            documents (list): The updated documents.
            metadatas (list): The updated metadata.
        """
        collection = self.client.get_or_create_collection(task_id)
        collection.update(ids=doc_ids, documents=documents, metadatas=metadatas)

    def delete(self, task_id: str, doc_id: str) -> None:
        """
        Delete a document from the MemStore.

        Args:
            task_id (str): The ID of the task.
            doc_id (str): The ID of the document to be deleted.
        """
        collection = self.client.get_or_create_collection(task_id)
        collection.delete(ids=[doc_id])
if __name__ == "__main__":
    # Manual smoke test: exercises add, query, get, update and delete against
    # a store persisted at ".agent_mem_store" (relative to the working dir).
    print("#############################################")
    # Initialize MemStore
    mem = ChromaMemStore(".agent_mem_store")
    task_id = "test_task"

    # Test add function
    samples = [
        ("This is a another new test document.", {"metadata": "test_metadata"}),
        ("The quick brown fox jumps over the lazy dog.", {"metadata": "test_metadata"}),
        ("AI is a new technology that will change the world.", {"timestamp": 1623936000}),
    ]
    for document, metadatas in samples:
        mem.add(task_id, document, metadatas)

    # add() derives the ID from the document text; recompute the ID of the
    # last sample so that document can be updated and deleted below.
    last_document = samples[-1][0]
    doc_ids = [hashlib.sha256(last_document.encode()).hexdigest()[:20]]

    # Test query function
    query = "test"
    print("Query:")
    print(mem.query(task_id, query))

    # Test get function
    print("Get:")
    print(mem.get(task_id))

    # Test update function
    # update() declares list parameters, so the metadata goes in a list that
    # lines up one-to-one with doc_ids and documents.
    documents = ["This is an updated test document."]
    updated_metadatas = [{"metadata": "updated_test_metadata"}]
    print("Update:")
    mem.update(task_id, doc_ids, documents, updated_metadatas)
    # update() returns None; show the stored record instead.
    print(mem.get(task_id, doc_ids))

    # Test delete function
    print("Delete:")
    mem.delete(task_id, doc_ids[0])
    # delete() returns None; show what is left in the collection instead.
    print(mem.get(task_id))

View File

@@ -149,158 +149,3 @@ class MemStore(abc.ABC):
@abc.abstractmethod
def delete(self, collection_name: str, doc_id: str):
pass
class ChromaMemStore(MemStore):
    """
    A MemStore implementation backed by a persistent chromadb client.

    Each task uses the collection named after its task ID. Every method
    obtains that collection with ``get_or_create_collection``, so the
    collection is created on first use, including by read-style calls.
    """

    def __init__(self, store_path: str):
        """
        Initialize the MemStore with a given store path.

        Args:
            store_path (str): The path to the store.
        """
        self.client = chromadb.PersistentClient(
            path=store_path, settings=Settings(anonymized_telemetry=False)
        )

    def add(self, task_id: str, document: str, metadatas: dict) -> None:
        """
        Add a document to the MemStore.

        The document ID is the first 20 hex characters of the SHA-256 digest
        of the document text, so identical text always maps to the same ID.

        Args:
            task_id (str): The ID of the task.
            document (str): The document to be added.
            metadatas (dict): The metadata of the document.
        """
        doc_id = hashlib.sha256(document.encode()).hexdigest()[:20]
        collection = self.client.get_or_create_collection(task_id)
        collection.add(documents=[document], metadatas=[metadatas], ids=[doc_id])

    def query(
        self,
        task_id: str,
        query: str,
        filters: dict = None,
        document_search: dict = None,
        n_results: int = 10,
    ) -> dict:
        """
        Query the MemStore.

        Args:
            task_id (str): The ID of the task.
            query (str): The query string.
            filters (dict, optional): Metadata filters, forwarded as ``where``.
                Omitted from the call when None or empty. Defaults to None.
            document_search (dict, optional): Document-content filter,
                forwarded as ``where_document``. Omitted from the call when
                None or empty. Defaults to None.
            n_results (int, optional): Value forwarded as ``n_results`` to
                the collection query. Defaults to 10.

        Returns:
            dict: The query results.
        """
        collection = self.client.get_or_create_collection(task_id)
        kwargs = {
            "query_texts": [query],
            "n_results": n_results,
        }
        if filters:
            kwargs["where"] = filters
        if document_search:
            kwargs["where_document"] = document_search
        return collection.query(**kwargs)

    def get(self, task_id: str, doc_ids: list = None, filters: dict = None) -> dict:
        """
        Get documents from the MemStore.

        Args:
            task_id (str): The ID of the task.
            doc_ids (list, optional): The IDs of the documents to be retrieved.
                Defaults to None.
            filters (dict, optional): The filters to be applied. Defaults to None.

        Returns:
            dict: The retrieved documents.
        """
        collection = self.client.get_or_create_collection(task_id)
        # Truthiness checks: an empty list/dict is treated exactly like None,
        # i.e. the corresponding argument is not passed to the collection.
        kwargs = {}
        if doc_ids:
            kwargs["ids"] = doc_ids
        if filters:
            kwargs["where"] = filters
        return collection.get(**kwargs)

    def update(
        self, task_id: str, doc_ids: list, documents: list, metadatas: list
    ) -> None:
        """
        Update documents in the MemStore.

        Args:
            task_id (str): The ID of the task.
            doc_ids (list): The IDs of the documents to be updated.
            documents (list): The updated documents.
            metadatas (list): The updated metadata.
        """
        collection = self.client.get_or_create_collection(task_id)
        collection.update(ids=doc_ids, documents=documents, metadatas=metadatas)

    def delete(self, task_id: str, doc_id: str) -> None:
        """
        Delete a document from the MemStore.

        Args:
            task_id (str): The ID of the task.
            doc_id (str): The ID of the document to be deleted.
        """
        collection = self.client.get_or_create_collection(task_id)
        collection.delete(ids=[doc_id])
if __name__ == "__main__":
    # Manual smoke test: exercises add, query, get, update and delete against
    # a store persisted at ".agent_mem_store" (relative to the working dir).
    print("#############################################")
    # Initialize MemStore
    # MemStore is an abstract base class (abc.ABC with abstract methods) and
    # cannot be instantiated; use the concrete Chroma-backed implementation.
    mem = ChromaMemStore(".agent_mem_store")
    task_id = "test_task"

    # Test add function
    samples = [
        ("This is a another new test document.", {"metadata": "test_metadata"}),
        ("The quick brown fox jumps over the lazy dog.", {"metadata": "test_metadata"}),
        ("AI is a new technology that will change the world.", {"timestamp": 1623936000}),
    ]
    for document, metadatas in samples:
        mem.add(task_id, document, metadatas)

    # add() derives the ID from the document text; recompute the ID of the
    # last sample so that document can be updated and deleted below.
    last_document = samples[-1][0]
    doc_ids = [hashlib.sha256(last_document.encode()).hexdigest()[:20]]

    # Test query function
    query = "test"
    print("Query:")
    print(mem.query(task_id, query))

    # Test get function
    print("Get:")
    print(mem.get(task_id))

    # Test update function
    # update() declares list parameters, so the metadata goes in a list that
    # lines up one-to-one with doc_ids and documents.
    documents = ["This is an updated test document."]
    updated_metadatas = [{"metadata": "updated_test_metadata"}]
    print("Update:")
    mem.update(task_id, doc_ids, documents, updated_metadatas)
    # update() returns None; show the stored record instead.
    print(mem.get(task_id, doc_ids))

    # Test delete function
    print("Delete:")
    mem.delete(task_id, doc_ids[0])
    # delete() returns None; show what is left in the collection instead.
    print(mem.get(task_id))

8
cli.py
View File

@@ -258,7 +258,8 @@ def create(agent_name):
@agent.command()
@click.argument("agent_name")
def start(agent_name):
@click.option("--setup", is_flag=True, help="Rebuilds your poetry env")
def start(agent_name, setup):
"""Start agent command"""
import os
import subprocess
@@ -269,8 +270,9 @@ def start(agent_name):
run_bench_command = os.path.join(agent_dir, "run_benchmark")
if os.path.exists(agent_dir) and os.path.isfile(run_command) and os.path.isfile(run_bench_command):
os.chdir(agent_dir)
setup_process = subprocess.Popen(["./setup"], cwd=agent_dir)
setup_process.wait()
if setup:
setup_process = subprocess.Popen(["./setup"], cwd=agent_dir)
setup_process.wait()
subprocess.Popen(["./run_benchmark", "serve"], cwd=agent_dir)
click.echo(f"Benchmark Server starting please wait...")
subprocess.Popen(["./run"], cwd=agent_dir)