Mirror of https://github.com/aljazceru/Auto-GPT.git (synced 2026-02-06 14:54:40 +01:00)
Addition of Simple Memory System Based on ChromaDB (#28)
@@ -31,12 +31,13 @@ repos:
     hooks:
       - id: autoflake
         name: autoflake
-        entry: autoflake --in-place --remove-all-unused-imports --recursive --ignore-init-module-imports --ignore-pass-after-docstring autogpt
+        entry: autoflake --in-place --remove-all-unused-imports --recursive --ignore-init-module-imports --ignore-pass-after-docstring forge/autogpt
         language: python
         types: [ python ]
-      - id: pytest-check
-        name: pytest-check
-        entry: pytest
-        language: system
-        pass_filenames: false
-        always_run: true
+      # Mono repo has broken this. TODO: fix
+      # - id: pytest-check
+      #   name: pytest-check
+      #   entry: pytest
+      #   language: system
+      #   pass_filenames: false
+      #   always_run: true
forge/autogpt/sdk/memory/__init__.py (new file, one blank line)
forge/autogpt/sdk/memory/memstore.py (new file, 159 lines)
@@ -0,0 +1,159 @@
import hashlib

import chromadb
from chromadb.config import Settings


class MemStore:
    """
    A class used to represent a Memory Store
    """

    def __init__(self, store_path: str):
        """
        Initialize the MemStore with a given store path.

        Args:
            store_path (str): The path to the store.
        """
        self.client = chromadb.PersistentClient(
            path=store_path, settings=Settings(anonymized_telemetry=False)
        )

    def add(self, task_id: str, document: str, metadatas: dict) -> None:
        """
        Add a document to the MemStore.

        Args:
            task_id (str): The ID of the task.
            document (str): The document to be added.
            metadatas (dict): The metadata of the document.
        """
        doc_id = hashlib.sha256(document.encode()).hexdigest()[:20]
        collection = self.client.get_or_create_collection(task_id)
        collection.add(documents=[document], metadatas=[metadatas], ids=[doc_id])

    def query(
        self,
        task_id: str,
        query: str,
        filters: dict = None,
        document_search: dict = None,
    ) -> dict:
        """
        Query the MemStore.

        Args:
            task_id (str): The ID of the task.
            query (str): The query string.
            filters (dict, optional): The filters to be applied. Defaults to None.
            document_search (dict, optional): A document-content filter (ChromaDB `where_document`). Defaults to None.

        Returns:
            dict: The query results.
        """
        collection = self.client.get_or_create_collection(task_id)

        kwargs = {
            "query_texts": [query],
            "n_results": 10,
        }

        if filters:
            kwargs["where"] = filters

        if document_search:
            kwargs["where_document"] = document_search

        return collection.query(**kwargs)

    def get(self, task_id: str, doc_ids: list = None, filters: dict = None) -> dict:
        """
        Get documents from the MemStore.

        Args:
            task_id (str): The ID of the task.
            doc_ids (list, optional): The IDs of the documents to be retrieved. Defaults to None.
            filters (dict, optional): The filters to be applied. Defaults to None.

        Returns:
            dict: The retrieved documents.
        """
        collection = self.client.get_or_create_collection(task_id)
        kwargs = {}
        if doc_ids:
            kwargs["ids"] = doc_ids
        if filters:
            kwargs["where"] = filters
        return collection.get(**kwargs)

    def update(self, task_id: str, doc_ids: list, documents: list, metadatas: list):
        """
        Update documents in the MemStore.

        Args:
            task_id (str): The ID of the task.
            doc_ids (list): The IDs of the documents to be updated.
            documents (list): The updated documents.
            metadatas (list): The updated metadata.
        """
        collection = self.client.get_or_create_collection(task_id)
        collection.update(ids=doc_ids, documents=documents, metadatas=metadatas)

    def delete(self, task_id: str, doc_id: str):
        """
        Delete a document from the MemStore.

        Args:
            task_id (str): The ID of the task.
            doc_id (str): The ID of the document to be deleted.
        """
        collection = self.client.get_or_create_collection(task_id)
        collection.delete(ids=[doc_id])


if __name__ == "__main__":
    print("#############################################")
    # Initialize MemStore
    mem = MemStore(".agent_mem_store")

    # Test add function
    task_id = "test_task"
    document = "This is another new test document."
    metadatas = {"metadata": "test_metadata"}
    mem.add(task_id, document, metadatas)

    task_id = "test_task"
    document = "The quick brown fox jumps over the lazy dog."
    metadatas = {"metadata": "test_metadata"}
    mem.add(task_id, document, metadatas)

    task_id = "test_task"
    document = "AI is a new technology that will change the world."
    metadatas = {"timestamp": 1623936000}
    mem.add(task_id, document, metadatas)

    doc_id = hashlib.sha256(document.encode()).hexdigest()[:20]
    # Test query function
    query = "test"
    filters = {"metadata": {"$eq": "test"}}
    search_string = {"$contains": "test"}
    doc_ids = [doc_id]
    documents = ["This is an updated test document."]
    updated_metadatas = {"metadata": "updated_test_metadata"}

    print("Query:")
    print(mem.query(task_id, query))

    # Test get function
    print("Get:")

    print(mem.get(task_id))

    # Test update function
    print("Update:")
    print(mem.update(task_id, doc_ids, documents, updated_metadatas))

    print("Delete:")
    # Test delete function
    print(mem.delete(task_id, doc_ids[0]))
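As a usage sketch (not part of this commit): the task ID, document, and metadata below are invented for illustration, and the import path is the one used by memstore_test.py further down. It shows the pattern the class encodes: one ChromaDB collection per task, with document IDs derived from a truncated SHA-256 of the content.

from autogpt.sdk.memory.memstore import MemStore

# Persist the store to a local directory.
mem = MemStore(".agent_mem_store")

# "hypothetical-task" and the document text are made-up examples.
mem.add("hypothetical-task", "User asked for a weather report.", {"role": "user"})

# Semantic query scoped to that task's collection, with an optional ChromaDB `where` filter.
hits = mem.query("hypothetical-task", "what did the user ask for?", filters={"role": {"$eq": "user"}})
print(hits["documents"])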
forge/autogpt/sdk/memory/memstore_test.py (new file, 58 lines)
@@ -0,0 +1,58 @@
import hashlib
import shutil

import pytest

from autogpt.sdk.memory.memstore import MemStore


@pytest.fixture
def memstore():
    mem = MemStore(".test_mem_store")
    yield mem
    shutil.rmtree(".test_mem_store")


def test_add(memstore):
    task_id = "test_task"
    document = "This is a test document."
    metadatas = {"metadata": "test_metadata"}
    memstore.add(task_id, document, metadatas)
    doc_id = hashlib.sha256(document.encode()).hexdigest()[:20]
    assert memstore.client.get_or_create_collection(task_id).count() == 1


def test_query(memstore):
    task_id = "test_task"
    document = "This is a test document."
    metadatas = {"metadata": "test_metadata"}
    memstore.add(task_id, document, metadatas)
    query = "test"
    assert len(memstore.query(task_id, query)["documents"]) == 1


def test_update(memstore):
    task_id = "test_task"
    document = "This is a test document."
    metadatas = {"metadata": "test_metadata"}
    memstore.add(task_id, document, metadatas)
    doc_id = hashlib.sha256(document.encode()).hexdigest()[:20]
    updated_document = "This is an updated test document."
    updated_metadatas = {"metadata": "updated_test_metadata"}
    memstore.update(task_id, [doc_id], [updated_document], [updated_metadatas])
    assert memstore.get(task_id, [doc_id]) == {
        "documents": [updated_document],
        "metadatas": [updated_metadatas],
        "embeddings": None,
        "ids": [doc_id],
    }


def test_delete(memstore):
    task_id = "test_task"
    document = "This is a test document."
    metadatas = {"metadata": "test_metadata"}
    memstore.add(task_id, document, metadatas)
    doc_id = hashlib.sha256(document.encode()).hexdigest()[:20]
    memstore.delete(task_id, doc_id)
    assert memstore.client.get_or_create_collection(task_id).count() == 0
forge/poetry.lock (generated, 1465 lines; file diff suppressed because it is too large)
@@ -14,6 +14,7 @@ tenacity = "^8.2.2"
 sqlalchemy = "^2.0.19"
 aiohttp = "^3.8.5"
 colorlog = "^6.7.0"
+chromadb = "^0.4.8"


 [tool.poetry.group.dev.dependencies]
@@ -22,7 +23,6 @@ black = "^23.3.0"
pre-commit = "^3.3.3"
mypy = "^1.4.1"
flake8 = "^6.0.0"
agbenchmark = "^0.0.9"
types-requests = "^2.31.0.2"
pytest = "^7.4.0"
pytest-asyncio = "^0.21.1"