This commit is contained in:
BillSchumacher
2023-04-16 14:15:38 -05:00
parent c544cebbe6
commit 3fadf2c90b
10 changed files with 92 additions and 73 deletions

View File

@@ -48,7 +48,7 @@ class AgentManager(metaclass=Singleton):
for i, plugin in enumerate(self.cfg.plugins): for i, plugin in enumerate(self.cfg.plugins):
plugin_result = plugin.on_instruction(messages) plugin_result = plugin.on_instruction(messages)
if plugin_result: if plugin_result:
sep = '' if not i else '\n' sep = "" if not i else "\n"
plugins_reply = f"{plugins_reply}{sep}{plugin_result}" plugins_reply = f"{plugins_reply}{sep}{plugin_result}"
if plugins_reply and plugins_reply != "": if plugins_reply and plugins_reply != "":
@@ -98,7 +98,7 @@ class AgentManager(metaclass=Singleton):
for i, plugin in enumerate(self.cfg.plugins): for i, plugin in enumerate(self.cfg.plugins):
plugin_result = plugin.on_instruction(messages) plugin_result = plugin.on_instruction(messages)
if plugin_result: if plugin_result:
sep = '' if not i else '\n' sep = "" if not i else "\n"
plugins_reply = f"{plugins_reply}{sep}{plugin_result}" plugins_reply = f"{plugins_reply}{sep}{plugin_result}"
# Update full message history # Update full message history
if plugins_reply and plugins_reply != "": if plugins_reply and plugins_reply != "":

View File

@@ -132,9 +132,12 @@ def execute_command(command_name: str, arguments, prompt: PromptGenerator):
# google_result can be a list or a string depending on the search results # google_result can be a list or a string depending on the search results
if isinstance(google_result, list): if isinstance(google_result, list):
safe_message = [google_result_single.encode('utf-8', 'ignore') for google_result_single in google_result] safe_message = [
google_result_single.encode("utf-8", "ignore")
for google_result_single in google_result
]
else: else:
safe_message = google_result.encode('utf-8', 'ignore') safe_message = google_result.encode("utf-8", "ignore")
return str(safe_message) return str(safe_message)
elif command_name == "memory_add": elif command_name == "memory_add":

View File

@@ -40,7 +40,7 @@ def log_operation(operation: str, filename: str) -> None:
with open(LOG_FILE_PATH, "w", encoding="utf-8") as f: with open(LOG_FILE_PATH, "w", encoding="utf-8") as f:
f.write("File Operation Logger ") f.write("File Operation Logger ")
append_to_file(LOG_FILE, log_entry, shouldLog = False) append_to_file(LOG_FILE, log_entry, shouldLog=False)
def split_file( def split_file(

View File

@@ -74,7 +74,9 @@ class Config(metaclass=Singleton):
self.weaviate_scopes = os.getenv("WEAVIATE_SCOPES", None) self.weaviate_scopes = os.getenv("WEAVIATE_SCOPES", None)
self.weaviate_embedded_path = os.getenv("WEAVIATE_EMBEDDED_PATH") self.weaviate_embedded_path = os.getenv("WEAVIATE_EMBEDDED_PATH")
self.weaviate_api_key = os.getenv("WEAVIATE_API_KEY", None) self.weaviate_api_key = os.getenv("WEAVIATE_API_KEY", None)
self.use_weaviate_embedded = os.getenv("USE_WEAVIATE_EMBEDDED", "False") == "True" self.use_weaviate_embedded = (
os.getenv("USE_WEAVIATE_EMBEDDED", "False") == "True"
)
# milvus configuration, e.g., localhost:19530. # milvus configuration, e.g., localhost:19530.
self.milvus_addr = os.getenv("MILVUS_ADDR", "localhost:19530") self.milvus_addr = os.getenv("MILVUS_ADDR", "localhost:19530")

View File

@@ -56,8 +56,10 @@ def get_memory(cfg, init=False):
memory = RedisMemory(cfg) memory = RedisMemory(cfg)
elif cfg.memory_backend == "weaviate": elif cfg.memory_backend == "weaviate":
if not WeaviateMemory: if not WeaviateMemory:
print("Error: Weaviate is not installed. Please install weaviate-client to" print(
" use Weaviate as a memory backend.") "Error: Weaviate is not installed. Please install weaviate-client to"
" use Weaviate as a memory backend."
)
else: else:
memory = WeaviateMemory(cfg) memory = WeaviateMemory(cfg)
elif cfg.memory_backend == "milvus": elif cfg.memory_backend == "milvus":
@@ -89,5 +91,5 @@ __all__ = [
"PineconeMemory", "PineconeMemory",
"NoMemory", "NoMemory",
"MilvusMemory", "MilvusMemory",
"WeaviateMemory" "WeaviateMemory",
] ]

View File

@@ -53,7 +53,7 @@ class NoMemory(MemoryProviderSingleton):
""" """
return "" return ""
def get_relevant(self, data: str, num_relevant: int = 5) ->list[Any] | None: def get_relevant(self, data: str, num_relevant: int = 5) -> list[Any] | None:
""" """
Returns all the data in the memory that is relevant to the given data. Returns all the data in the memory that is relevant to the given data.
NoMemory always returns None. NoMemory always returns None.

View File

@@ -14,7 +14,7 @@ def default_schema(weaviate_index):
{ {
"name": "raw_text", "name": "raw_text",
"dataType": ["text"], "dataType": ["text"],
"description": "original text for the embedding" "description": "original text for the embedding",
} }
], ],
} }
@@ -24,16 +24,20 @@ class WeaviateMemory(MemoryProviderSingleton):
def __init__(self, cfg): def __init__(self, cfg):
auth_credentials = self._build_auth_credentials(cfg) auth_credentials = self._build_auth_credentials(cfg)
url = f'{cfg.weaviate_protocol}://{cfg.weaviate_host}:{cfg.weaviate_port}' url = f"{cfg.weaviate_protocol}://{cfg.weaviate_host}:{cfg.weaviate_port}"
if cfg.use_weaviate_embedded: if cfg.use_weaviate_embedded:
self.client = Client(embedded_options=EmbeddedOptions( self.client = Client(
embedded_options=EmbeddedOptions(
hostname=cfg.weaviate_host, hostname=cfg.weaviate_host,
port=int(cfg.weaviate_port), port=int(cfg.weaviate_port),
persistence_data_path=cfg.weaviate_embedded_path persistence_data_path=cfg.weaviate_embedded_path,
)) )
)
print(f"Weaviate Embedded running on: {url} with persistence path: {cfg.weaviate_embedded_path}") print(
f"Weaviate Embedded running on: {url} with persistence path: {cfg.weaviate_embedded_path}"
)
else: else:
self.client = Client(url, auth_client_secret=auth_credentials) self.client = Client(url, auth_client_secret=auth_credentials)
@@ -47,7 +51,9 @@ class WeaviateMemory(MemoryProviderSingleton):
def _build_auth_credentials(self, cfg): def _build_auth_credentials(self, cfg):
if cfg.weaviate_username and cfg.weaviate_password: if cfg.weaviate_username and cfg.weaviate_password:
return weaviate.AuthClientPassword(cfg.weaviate_username, cfg.weaviate_password) return weaviate.AuthClientPassword(
cfg.weaviate_username, cfg.weaviate_password
)
if cfg.weaviate_api_key: if cfg.weaviate_api_key:
return weaviate.AuthApiKey(api_key=cfg.weaviate_api_key) return weaviate.AuthApiKey(api_key=cfg.weaviate_api_key)
else: else:
@@ -57,16 +63,14 @@ class WeaviateMemory(MemoryProviderSingleton):
vector = get_ada_embedding(data) vector = get_ada_embedding(data)
doc_uuid = generate_uuid5(data, self.index) doc_uuid = generate_uuid5(data, self.index)
data_object = { data_object = {"raw_text": data}
'raw_text': data
}
with self.client.batch as batch: with self.client.batch as batch:
batch.add_data_object( batch.add_data_object(
uuid=doc_uuid, uuid=doc_uuid,
data_object=data_object, data_object=data_object,
class_name=self.index, class_name=self.index,
vector=vector vector=vector,
) )
return f"Inserting data into memory at uuid: {doc_uuid}:\n data: {data}" return f"Inserting data into memory at uuid: {doc_uuid}:\n data: {data}"
@@ -82,29 +86,31 @@ class WeaviateMemory(MemoryProviderSingleton):
# after a call to delete_all # after a call to delete_all
self._create_schema() self._create_schema()
return 'Obliterated' return "Obliterated"
def get_relevant(self, data, num_relevant=5): def get_relevant(self, data, num_relevant=5):
query_embedding = get_ada_embedding(data) query_embedding = get_ada_embedding(data)
try: try:
results = self.client.query.get(self.index, ['raw_text']) \ results = (
.with_near_vector({'vector': query_embedding, 'certainty': 0.7}) \ self.client.query.get(self.index, ["raw_text"])
.with_limit(num_relevant) \ .with_near_vector({"vector": query_embedding, "certainty": 0.7})
.with_limit(num_relevant)
.do() .do()
)
if len(results['data']['Get'][self.index]) > 0: if len(results["data"]["Get"][self.index]) > 0:
return [str(item['raw_text']) for item in results['data']['Get'][self.index]] return [
str(item["raw_text"]) for item in results["data"]["Get"][self.index]
]
else: else:
return [] return []
except Exception as err: except Exception as err:
print(f'Unexpected error {err=}, {type(err)=}') print(f"Unexpected error {err=}, {type(err)=}")
return [] return []
def get_stats(self): def get_stats(self):
result = self.client.query.aggregate(self.index) \ result = self.client.query.aggregate(self.index).with_meta_count().do()
.with_meta_count() \ class_data = result["data"]["Aggregate"][self.index]
.do()
class_data = result['data']['Aggregate'][self.index]
return class_data[0]['meta'] if class_data else {} return class_data[0]["meta"] if class_data else {}

View File

@@ -84,11 +84,7 @@ def build_default_prompt_generator() -> PromptGenerator:
# Only add the audio to text command if the model is specified # Only add the audio to text command if the model is specified
if cfg.huggingface_audio_to_text_model: if cfg.huggingface_audio_to_text_model:
commands.append( commands.append(
( ("Convert Audio to text", "read_audio_from_file", {"file": "<file>"}),
"Convert Audio to text",
"read_audio_from_file",
{"file": "<file>"}
),
) )
# Only add shell command to the prompt if the AI is allowed to execute it # Only add shell command to the prompt if the AI is allowed to execute it

View File

@@ -36,6 +36,8 @@ def safe_path_join(base: Path, *paths: str | Path) -> Path:
joined_path = base.joinpath(*paths).resolve() joined_path = base.joinpath(*paths).resolve()
if not joined_path.is_relative_to(base): if not joined_path.is_relative_to(base):
raise ValueError(f"Attempted to access path '{joined_path}' outside of working directory '{base}'.") raise ValueError(
f"Attempted to access path '{joined_path}' outside of working directory '{base}'."
)
return joined_path return joined_path

View File

@@ -12,14 +12,17 @@ from autogpt.memory.weaviate import WeaviateMemory
from autogpt.memory.base import get_ada_embedding from autogpt.memory.base import get_ada_embedding
@mock.patch.dict(os.environ, { @mock.patch.dict(
os.environ,
{
"WEAVIATE_HOST": "127.0.0.1", "WEAVIATE_HOST": "127.0.0.1",
"WEAVIATE_PROTOCOL": "http", "WEAVIATE_PROTOCOL": "http",
"WEAVIATE_PORT": "8080", "WEAVIATE_PORT": "8080",
"WEAVIATE_USERNAME": "", "WEAVIATE_USERNAME": "",
"WEAVIATE_PASSWORD": "", "WEAVIATE_PASSWORD": "",
"MEMORY_INDEX": "AutogptTests" "MEMORY_INDEX": "AutogptTests",
}) },
)
class TestWeaviateMemory(unittest.TestCase): class TestWeaviateMemory(unittest.TestCase):
cfg = None cfg = None
client = None client = None
@@ -32,13 +35,17 @@ class TestWeaviateMemory(unittest.TestCase):
if cls.cfg.use_weaviate_embedded: if cls.cfg.use_weaviate_embedded:
from weaviate.embedded import EmbeddedOptions from weaviate.embedded import EmbeddedOptions
cls.client = Client(embedded_options=EmbeddedOptions( cls.client = Client(
embedded_options=EmbeddedOptions(
hostname=cls.cfg.weaviate_host, hostname=cls.cfg.weaviate_host,
port=int(cls.cfg.weaviate_port), port=int(cls.cfg.weaviate_port),
persistence_data_path=cls.cfg.weaviate_embedded_path persistence_data_path=cls.cfg.weaviate_embedded_path,
)) )
)
else: else:
cls.client = Client(f"{cls.cfg.weaviate_protocol}://{cls.cfg.weaviate_host}:{self.cfg.weaviate_port}") cls.client = Client(
f"{cls.cfg.weaviate_protocol}://{cls.cfg.weaviate_host}:{self.cfg.weaviate_port}"
)
""" """
In order to run these tests you will need a local instance of In order to run these tests you will need a local instance of
@@ -49,6 +56,7 @@ class TestWeaviateMemory(unittest.TestCase):
USE_WEAVIATE_EMBEDDED=True USE_WEAVIATE_EMBEDDED=True
WEAVIATE_EMBEDDED_PATH="/home/me/.local/share/weaviate" WEAVIATE_EMBEDDED_PATH="/home/me/.local/share/weaviate"
""" """
def setUp(self): def setUp(self):
try: try:
self.client.schema.delete_class(self.cfg.memory_index) self.client.schema.delete_class(self.cfg.memory_index)
@@ -58,23 +66,23 @@ class TestWeaviateMemory(unittest.TestCase):
self.memory = WeaviateMemory(self.cfg) self.memory = WeaviateMemory(self.cfg)
def test_add(self): def test_add(self):
doc = 'You are a Titan name Thanos and you are looking for the Infinity Stones' doc = "You are a Titan name Thanos and you are looking for the Infinity Stones"
self.memory.add(doc) self.memory.add(doc)
result = self.client.query.get(self.cfg.memory_index, ['raw_text']).do() result = self.client.query.get(self.cfg.memory_index, ["raw_text"]).do()
actual = result['data']['Get'][self.cfg.memory_index] actual = result["data"]["Get"][self.cfg.memory_index]
self.assertEqual(len(actual), 1) self.assertEqual(len(actual), 1)
self.assertEqual(actual[0]['raw_text'], doc) self.assertEqual(actual[0]["raw_text"], doc)
def test_get(self): def test_get(self):
doc = 'You are an Avenger and swore to defend the Galaxy from a menace called Thanos' doc = "You are an Avenger and swore to defend the Galaxy from a menace called Thanos"
with self.client.batch as batch: with self.client.batch as batch:
batch.add_data_object( batch.add_data_object(
uuid=get_valid_uuid(uuid4()), uuid=get_valid_uuid(uuid4()),
data_object={'raw_text': doc}, data_object={"raw_text": doc},
class_name=self.cfg.memory_index, class_name=self.cfg.memory_index,
vector=get_ada_embedding(doc) vector=get_ada_embedding(doc),
) )
batch.flush() batch.flush()
@@ -86,8 +94,8 @@ class TestWeaviateMemory(unittest.TestCase):
def test_get_stats(self): def test_get_stats(self):
docs = [ docs = [
'You are now about to count the number of docs in this index', "You are now about to count the number of docs in this index",
'And then you about to find out if you can count correctly' "And then you about to find out if you can count correctly",
] ]
[self.memory.add(doc) for doc in docs] [self.memory.add(doc) for doc in docs]
@@ -95,23 +103,23 @@ class TestWeaviateMemory(unittest.TestCase):
stats = self.memory.get_stats() stats = self.memory.get_stats()
self.assertTrue(stats) self.assertTrue(stats)
self.assertTrue('count' in stats) self.assertTrue("count" in stats)
self.assertEqual(stats['count'], 2) self.assertEqual(stats["count"], 2)
def test_clear(self): def test_clear(self):
docs = [ docs = [
'Shame this is the last test for this class', "Shame this is the last test for this class",
'Testing is fun when someone else is doing it' "Testing is fun when someone else is doing it",
] ]
[self.memory.add(doc) for doc in docs] [self.memory.add(doc) for doc in docs]
self.assertEqual(self.memory.get_stats()['count'], 2) self.assertEqual(self.memory.get_stats()["count"], 2)
self.memory.clear() self.memory.clear()
self.assertEqual(self.memory.get_stats()['count'], 0) self.assertEqual(self.memory.get_stats()["count"], 0)
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()