mirror of
https://github.com/aljazceru/Auto-GPT.git
synced 2025-12-18 14:34:23 +01:00
Blacked
This commit is contained in:
@@ -48,7 +48,7 @@ class AgentManager(metaclass=Singleton):
|
|||||||
for i, plugin in enumerate(self.cfg.plugins):
|
for i, plugin in enumerate(self.cfg.plugins):
|
||||||
plugin_result = plugin.on_instruction(messages)
|
plugin_result = plugin.on_instruction(messages)
|
||||||
if plugin_result:
|
if plugin_result:
|
||||||
sep = '' if not i else '\n'
|
sep = "" if not i else "\n"
|
||||||
plugins_reply = f"{plugins_reply}{sep}{plugin_result}"
|
plugins_reply = f"{plugins_reply}{sep}{plugin_result}"
|
||||||
|
|
||||||
if plugins_reply and plugins_reply != "":
|
if plugins_reply and plugins_reply != "":
|
||||||
@@ -98,7 +98,7 @@ class AgentManager(metaclass=Singleton):
|
|||||||
for i, plugin in enumerate(self.cfg.plugins):
|
for i, plugin in enumerate(self.cfg.plugins):
|
||||||
plugin_result = plugin.on_instruction(messages)
|
plugin_result = plugin.on_instruction(messages)
|
||||||
if plugin_result:
|
if plugin_result:
|
||||||
sep = '' if not i else '\n'
|
sep = "" if not i else "\n"
|
||||||
plugins_reply = f"{plugins_reply}{sep}{plugin_result}"
|
plugins_reply = f"{plugins_reply}{sep}{plugin_result}"
|
||||||
# Update full message history
|
# Update full message history
|
||||||
if plugins_reply and plugins_reply != "":
|
if plugins_reply and plugins_reply != "":
|
||||||
|
|||||||
@@ -132,9 +132,12 @@ def execute_command(command_name: str, arguments, prompt: PromptGenerator):
|
|||||||
|
|
||||||
# google_result can be a list or a string depending on the search results
|
# google_result can be a list or a string depending on the search results
|
||||||
if isinstance(google_result, list):
|
if isinstance(google_result, list):
|
||||||
safe_message = [google_result_single.encode('utf-8', 'ignore') for google_result_single in google_result]
|
safe_message = [
|
||||||
|
google_result_single.encode("utf-8", "ignore")
|
||||||
|
for google_result_single in google_result
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
safe_message = google_result.encode('utf-8', 'ignore')
|
safe_message = google_result.encode("utf-8", "ignore")
|
||||||
|
|
||||||
return str(safe_message)
|
return str(safe_message)
|
||||||
elif command_name == "memory_add":
|
elif command_name == "memory_add":
|
||||||
|
|||||||
@@ -74,7 +74,9 @@ class Config(metaclass=Singleton):
|
|||||||
self.weaviate_scopes = os.getenv("WEAVIATE_SCOPES", None)
|
self.weaviate_scopes = os.getenv("WEAVIATE_SCOPES", None)
|
||||||
self.weaviate_embedded_path = os.getenv("WEAVIATE_EMBEDDED_PATH")
|
self.weaviate_embedded_path = os.getenv("WEAVIATE_EMBEDDED_PATH")
|
||||||
self.weaviate_api_key = os.getenv("WEAVIATE_API_KEY", None)
|
self.weaviate_api_key = os.getenv("WEAVIATE_API_KEY", None)
|
||||||
self.use_weaviate_embedded = os.getenv("USE_WEAVIATE_EMBEDDED", "False") == "True"
|
self.use_weaviate_embedded = (
|
||||||
|
os.getenv("USE_WEAVIATE_EMBEDDED", "False") == "True"
|
||||||
|
)
|
||||||
|
|
||||||
# milvus configuration, e.g., localhost:19530.
|
# milvus configuration, e.g., localhost:19530.
|
||||||
self.milvus_addr = os.getenv("MILVUS_ADDR", "localhost:19530")
|
self.milvus_addr = os.getenv("MILVUS_ADDR", "localhost:19530")
|
||||||
|
|||||||
@@ -56,8 +56,10 @@ def get_memory(cfg, init=False):
|
|||||||
memory = RedisMemory(cfg)
|
memory = RedisMemory(cfg)
|
||||||
elif cfg.memory_backend == "weaviate":
|
elif cfg.memory_backend == "weaviate":
|
||||||
if not WeaviateMemory:
|
if not WeaviateMemory:
|
||||||
print("Error: Weaviate is not installed. Please install weaviate-client to"
|
print(
|
||||||
" use Weaviate as a memory backend.")
|
"Error: Weaviate is not installed. Please install weaviate-client to"
|
||||||
|
" use Weaviate as a memory backend."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
memory = WeaviateMemory(cfg)
|
memory = WeaviateMemory(cfg)
|
||||||
elif cfg.memory_backend == "milvus":
|
elif cfg.memory_backend == "milvus":
|
||||||
@@ -89,5 +91,5 @@ __all__ = [
|
|||||||
"PineconeMemory",
|
"PineconeMemory",
|
||||||
"NoMemory",
|
"NoMemory",
|
||||||
"MilvusMemory",
|
"MilvusMemory",
|
||||||
"WeaviateMemory"
|
"WeaviateMemory",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ def default_schema(weaviate_index):
|
|||||||
{
|
{
|
||||||
"name": "raw_text",
|
"name": "raw_text",
|
||||||
"dataType": ["text"],
|
"dataType": ["text"],
|
||||||
"description": "original text for the embedding"
|
"description": "original text for the embedding",
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
@@ -24,16 +24,20 @@ class WeaviateMemory(MemoryProviderSingleton):
|
|||||||
def __init__(self, cfg):
|
def __init__(self, cfg):
|
||||||
auth_credentials = self._build_auth_credentials(cfg)
|
auth_credentials = self._build_auth_credentials(cfg)
|
||||||
|
|
||||||
url = f'{cfg.weaviate_protocol}://{cfg.weaviate_host}:{cfg.weaviate_port}'
|
url = f"{cfg.weaviate_protocol}://{cfg.weaviate_host}:{cfg.weaviate_port}"
|
||||||
|
|
||||||
if cfg.use_weaviate_embedded:
|
if cfg.use_weaviate_embedded:
|
||||||
self.client = Client(embedded_options=EmbeddedOptions(
|
self.client = Client(
|
||||||
|
embedded_options=EmbeddedOptions(
|
||||||
hostname=cfg.weaviate_host,
|
hostname=cfg.weaviate_host,
|
||||||
port=int(cfg.weaviate_port),
|
port=int(cfg.weaviate_port),
|
||||||
persistence_data_path=cfg.weaviate_embedded_path
|
persistence_data_path=cfg.weaviate_embedded_path,
|
||||||
))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
print(f"Weaviate Embedded running on: {url} with persistence path: {cfg.weaviate_embedded_path}")
|
print(
|
||||||
|
f"Weaviate Embedded running on: {url} with persistence path: {cfg.weaviate_embedded_path}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.client = Client(url, auth_client_secret=auth_credentials)
|
self.client = Client(url, auth_client_secret=auth_credentials)
|
||||||
|
|
||||||
@@ -47,7 +51,9 @@ class WeaviateMemory(MemoryProviderSingleton):
|
|||||||
|
|
||||||
def _build_auth_credentials(self, cfg):
|
def _build_auth_credentials(self, cfg):
|
||||||
if cfg.weaviate_username and cfg.weaviate_password:
|
if cfg.weaviate_username and cfg.weaviate_password:
|
||||||
return weaviate.AuthClientPassword(cfg.weaviate_username, cfg.weaviate_password)
|
return weaviate.AuthClientPassword(
|
||||||
|
cfg.weaviate_username, cfg.weaviate_password
|
||||||
|
)
|
||||||
if cfg.weaviate_api_key:
|
if cfg.weaviate_api_key:
|
||||||
return weaviate.AuthApiKey(api_key=cfg.weaviate_api_key)
|
return weaviate.AuthApiKey(api_key=cfg.weaviate_api_key)
|
||||||
else:
|
else:
|
||||||
@@ -57,16 +63,14 @@ class WeaviateMemory(MemoryProviderSingleton):
|
|||||||
vector = get_ada_embedding(data)
|
vector = get_ada_embedding(data)
|
||||||
|
|
||||||
doc_uuid = generate_uuid5(data, self.index)
|
doc_uuid = generate_uuid5(data, self.index)
|
||||||
data_object = {
|
data_object = {"raw_text": data}
|
||||||
'raw_text': data
|
|
||||||
}
|
|
||||||
|
|
||||||
with self.client.batch as batch:
|
with self.client.batch as batch:
|
||||||
batch.add_data_object(
|
batch.add_data_object(
|
||||||
uuid=doc_uuid,
|
uuid=doc_uuid,
|
||||||
data_object=data_object,
|
data_object=data_object,
|
||||||
class_name=self.index,
|
class_name=self.index,
|
||||||
vector=vector
|
vector=vector,
|
||||||
)
|
)
|
||||||
|
|
||||||
return f"Inserting data into memory at uuid: {doc_uuid}:\n data: {data}"
|
return f"Inserting data into memory at uuid: {doc_uuid}:\n data: {data}"
|
||||||
@@ -82,29 +86,31 @@ class WeaviateMemory(MemoryProviderSingleton):
|
|||||||
# after a call to delete_all
|
# after a call to delete_all
|
||||||
self._create_schema()
|
self._create_schema()
|
||||||
|
|
||||||
return 'Obliterated'
|
return "Obliterated"
|
||||||
|
|
||||||
def get_relevant(self, data, num_relevant=5):
|
def get_relevant(self, data, num_relevant=5):
|
||||||
query_embedding = get_ada_embedding(data)
|
query_embedding = get_ada_embedding(data)
|
||||||
try:
|
try:
|
||||||
results = self.client.query.get(self.index, ['raw_text']) \
|
results = (
|
||||||
.with_near_vector({'vector': query_embedding, 'certainty': 0.7}) \
|
self.client.query.get(self.index, ["raw_text"])
|
||||||
.with_limit(num_relevant) \
|
.with_near_vector({"vector": query_embedding, "certainty": 0.7})
|
||||||
|
.with_limit(num_relevant)
|
||||||
.do()
|
.do()
|
||||||
|
)
|
||||||
|
|
||||||
if len(results['data']['Get'][self.index]) > 0:
|
if len(results["data"]["Get"][self.index]) > 0:
|
||||||
return [str(item['raw_text']) for item in results['data']['Get'][self.index]]
|
return [
|
||||||
|
str(item["raw_text"]) for item in results["data"]["Get"][self.index]
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
print(f'Unexpected error {err=}, {type(err)=}')
|
print(f"Unexpected error {err=}, {type(err)=}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def get_stats(self):
|
def get_stats(self):
|
||||||
result = self.client.query.aggregate(self.index) \
|
result = self.client.query.aggregate(self.index).with_meta_count().do()
|
||||||
.with_meta_count() \
|
class_data = result["data"]["Aggregate"][self.index]
|
||||||
.do()
|
|
||||||
class_data = result['data']['Aggregate'][self.index]
|
|
||||||
|
|
||||||
return class_data[0]['meta'] if class_data else {}
|
return class_data[0]["meta"] if class_data else {}
|
||||||
|
|||||||
@@ -84,11 +84,7 @@ def build_default_prompt_generator() -> PromptGenerator:
|
|||||||
# Only add the audio to text command if the model is specified
|
# Only add the audio to text command if the model is specified
|
||||||
if cfg.huggingface_audio_to_text_model:
|
if cfg.huggingface_audio_to_text_model:
|
||||||
commands.append(
|
commands.append(
|
||||||
(
|
("Convert Audio to text", "read_audio_from_file", {"file": "<file>"}),
|
||||||
"Convert Audio to text",
|
|
||||||
"read_audio_from_file",
|
|
||||||
{"file": "<file>"}
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Only add shell command to the prompt if the AI is allowed to execute it
|
# Only add shell command to the prompt if the AI is allowed to execute it
|
||||||
|
|||||||
@@ -36,6 +36,8 @@ def safe_path_join(base: Path, *paths: str | Path) -> Path:
|
|||||||
joined_path = base.joinpath(*paths).resolve()
|
joined_path = base.joinpath(*paths).resolve()
|
||||||
|
|
||||||
if not joined_path.is_relative_to(base):
|
if not joined_path.is_relative_to(base):
|
||||||
raise ValueError(f"Attempted to access path '{joined_path}' outside of working directory '{base}'.")
|
raise ValueError(
|
||||||
|
f"Attempted to access path '{joined_path}' outside of working directory '{base}'."
|
||||||
|
)
|
||||||
|
|
||||||
return joined_path
|
return joined_path
|
||||||
|
|||||||
@@ -12,14 +12,17 @@ from autogpt.memory.weaviate import WeaviateMemory
|
|||||||
from autogpt.memory.base import get_ada_embedding
|
from autogpt.memory.base import get_ada_embedding
|
||||||
|
|
||||||
|
|
||||||
@mock.patch.dict(os.environ, {
|
@mock.patch.dict(
|
||||||
|
os.environ,
|
||||||
|
{
|
||||||
"WEAVIATE_HOST": "127.0.0.1",
|
"WEAVIATE_HOST": "127.0.0.1",
|
||||||
"WEAVIATE_PROTOCOL": "http",
|
"WEAVIATE_PROTOCOL": "http",
|
||||||
"WEAVIATE_PORT": "8080",
|
"WEAVIATE_PORT": "8080",
|
||||||
"WEAVIATE_USERNAME": "",
|
"WEAVIATE_USERNAME": "",
|
||||||
"WEAVIATE_PASSWORD": "",
|
"WEAVIATE_PASSWORD": "",
|
||||||
"MEMORY_INDEX": "AutogptTests"
|
"MEMORY_INDEX": "AutogptTests",
|
||||||
})
|
},
|
||||||
|
)
|
||||||
class TestWeaviateMemory(unittest.TestCase):
|
class TestWeaviateMemory(unittest.TestCase):
|
||||||
cfg = None
|
cfg = None
|
||||||
client = None
|
client = None
|
||||||
@@ -32,13 +35,17 @@ class TestWeaviateMemory(unittest.TestCase):
|
|||||||
if cls.cfg.use_weaviate_embedded:
|
if cls.cfg.use_weaviate_embedded:
|
||||||
from weaviate.embedded import EmbeddedOptions
|
from weaviate.embedded import EmbeddedOptions
|
||||||
|
|
||||||
cls.client = Client(embedded_options=EmbeddedOptions(
|
cls.client = Client(
|
||||||
|
embedded_options=EmbeddedOptions(
|
||||||
hostname=cls.cfg.weaviate_host,
|
hostname=cls.cfg.weaviate_host,
|
||||||
port=int(cls.cfg.weaviate_port),
|
port=int(cls.cfg.weaviate_port),
|
||||||
persistence_data_path=cls.cfg.weaviate_embedded_path
|
persistence_data_path=cls.cfg.weaviate_embedded_path,
|
||||||
))
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
cls.client = Client(f"{cls.cfg.weaviate_protocol}://{cls.cfg.weaviate_host}:{self.cfg.weaviate_port}")
|
cls.client = Client(
|
||||||
|
f"{cls.cfg.weaviate_protocol}://{cls.cfg.weaviate_host}:{self.cfg.weaviate_port}"
|
||||||
|
)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
In order to run these tests you will need a local instance of
|
In order to run these tests you will need a local instance of
|
||||||
@@ -49,6 +56,7 @@ class TestWeaviateMemory(unittest.TestCase):
|
|||||||
USE_WEAVIATE_EMBEDDED=True
|
USE_WEAVIATE_EMBEDDED=True
|
||||||
WEAVIATE_EMBEDDED_PATH="/home/me/.local/share/weaviate"
|
WEAVIATE_EMBEDDED_PATH="/home/me/.local/share/weaviate"
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
try:
|
try:
|
||||||
self.client.schema.delete_class(self.cfg.memory_index)
|
self.client.schema.delete_class(self.cfg.memory_index)
|
||||||
@@ -58,23 +66,23 @@ class TestWeaviateMemory(unittest.TestCase):
|
|||||||
self.memory = WeaviateMemory(self.cfg)
|
self.memory = WeaviateMemory(self.cfg)
|
||||||
|
|
||||||
def test_add(self):
|
def test_add(self):
|
||||||
doc = 'You are a Titan name Thanos and you are looking for the Infinity Stones'
|
doc = "You are a Titan name Thanos and you are looking for the Infinity Stones"
|
||||||
self.memory.add(doc)
|
self.memory.add(doc)
|
||||||
result = self.client.query.get(self.cfg.memory_index, ['raw_text']).do()
|
result = self.client.query.get(self.cfg.memory_index, ["raw_text"]).do()
|
||||||
actual = result['data']['Get'][self.cfg.memory_index]
|
actual = result["data"]["Get"][self.cfg.memory_index]
|
||||||
|
|
||||||
self.assertEqual(len(actual), 1)
|
self.assertEqual(len(actual), 1)
|
||||||
self.assertEqual(actual[0]['raw_text'], doc)
|
self.assertEqual(actual[0]["raw_text"], doc)
|
||||||
|
|
||||||
def test_get(self):
|
def test_get(self):
|
||||||
doc = 'You are an Avenger and swore to defend the Galaxy from a menace called Thanos'
|
doc = "You are an Avenger and swore to defend the Galaxy from a menace called Thanos"
|
||||||
|
|
||||||
with self.client.batch as batch:
|
with self.client.batch as batch:
|
||||||
batch.add_data_object(
|
batch.add_data_object(
|
||||||
uuid=get_valid_uuid(uuid4()),
|
uuid=get_valid_uuid(uuid4()),
|
||||||
data_object={'raw_text': doc},
|
data_object={"raw_text": doc},
|
||||||
class_name=self.cfg.memory_index,
|
class_name=self.cfg.memory_index,
|
||||||
vector=get_ada_embedding(doc)
|
vector=get_ada_embedding(doc),
|
||||||
)
|
)
|
||||||
|
|
||||||
batch.flush()
|
batch.flush()
|
||||||
@@ -86,8 +94,8 @@ class TestWeaviateMemory(unittest.TestCase):
|
|||||||
|
|
||||||
def test_get_stats(self):
|
def test_get_stats(self):
|
||||||
docs = [
|
docs = [
|
||||||
'You are now about to count the number of docs in this index',
|
"You are now about to count the number of docs in this index",
|
||||||
'And then you about to find out if you can count correctly'
|
"And then you about to find out if you can count correctly",
|
||||||
]
|
]
|
||||||
|
|
||||||
[self.memory.add(doc) for doc in docs]
|
[self.memory.add(doc) for doc in docs]
|
||||||
@@ -95,23 +103,23 @@ class TestWeaviateMemory(unittest.TestCase):
|
|||||||
stats = self.memory.get_stats()
|
stats = self.memory.get_stats()
|
||||||
|
|
||||||
self.assertTrue(stats)
|
self.assertTrue(stats)
|
||||||
self.assertTrue('count' in stats)
|
self.assertTrue("count" in stats)
|
||||||
self.assertEqual(stats['count'], 2)
|
self.assertEqual(stats["count"], 2)
|
||||||
|
|
||||||
def test_clear(self):
|
def test_clear(self):
|
||||||
docs = [
|
docs = [
|
||||||
'Shame this is the last test for this class',
|
"Shame this is the last test for this class",
|
||||||
'Testing is fun when someone else is doing it'
|
"Testing is fun when someone else is doing it",
|
||||||
]
|
]
|
||||||
|
|
||||||
[self.memory.add(doc) for doc in docs]
|
[self.memory.add(doc) for doc in docs]
|
||||||
|
|
||||||
self.assertEqual(self.memory.get_stats()['count'], 2)
|
self.assertEqual(self.memory.get_stats()["count"], 2)
|
||||||
|
|
||||||
self.memory.clear()
|
self.memory.clear()
|
||||||
|
|
||||||
self.assertEqual(self.memory.get_stats()['count'], 0)
|
self.assertEqual(self.memory.get_stats()["count"], 0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user