mirror of
https://github.com/aljazceru/Auto-GPT.git
synced 2025-12-18 06:24:20 +01:00
Blacked
This commit is contained in:
@@ -48,7 +48,7 @@ class AgentManager(metaclass=Singleton):
|
||||
for i, plugin in enumerate(self.cfg.plugins):
|
||||
plugin_result = plugin.on_instruction(messages)
|
||||
if plugin_result:
|
||||
sep = '' if not i else '\n'
|
||||
sep = "" if not i else "\n"
|
||||
plugins_reply = f"{plugins_reply}{sep}{plugin_result}"
|
||||
|
||||
if plugins_reply and plugins_reply != "":
|
||||
@@ -98,7 +98,7 @@ class AgentManager(metaclass=Singleton):
|
||||
for i, plugin in enumerate(self.cfg.plugins):
|
||||
plugin_result = plugin.on_instruction(messages)
|
||||
if plugin_result:
|
||||
sep = '' if not i else '\n'
|
||||
sep = "" if not i else "\n"
|
||||
plugins_reply = f"{plugins_reply}{sep}{plugin_result}"
|
||||
# Update full message history
|
||||
if plugins_reply and plugins_reply != "":
|
||||
|
||||
@@ -132,9 +132,12 @@ def execute_command(command_name: str, arguments, prompt: PromptGenerator):
|
||||
|
||||
# google_result can be a list or a string depending on the search results
|
||||
if isinstance(google_result, list):
|
||||
safe_message = [google_result_single.encode('utf-8', 'ignore') for google_result_single in google_result]
|
||||
safe_message = [
|
||||
google_result_single.encode("utf-8", "ignore")
|
||||
for google_result_single in google_result
|
||||
]
|
||||
else:
|
||||
safe_message = google_result.encode('utf-8', 'ignore')
|
||||
safe_message = google_result.encode("utf-8", "ignore")
|
||||
|
||||
return str(safe_message)
|
||||
elif command_name == "memory_add":
|
||||
|
||||
@@ -40,7 +40,7 @@ def log_operation(operation: str, filename: str) -> None:
|
||||
with open(LOG_FILE_PATH, "w", encoding="utf-8") as f:
|
||||
f.write("File Operation Logger ")
|
||||
|
||||
append_to_file(LOG_FILE, log_entry, shouldLog = False)
|
||||
append_to_file(LOG_FILE, log_entry, shouldLog=False)
|
||||
|
||||
|
||||
def split_file(
|
||||
|
||||
@@ -74,7 +74,9 @@ class Config(metaclass=Singleton):
|
||||
self.weaviate_scopes = os.getenv("WEAVIATE_SCOPES", None)
|
||||
self.weaviate_embedded_path = os.getenv("WEAVIATE_EMBEDDED_PATH")
|
||||
self.weaviate_api_key = os.getenv("WEAVIATE_API_KEY", None)
|
||||
self.use_weaviate_embedded = os.getenv("USE_WEAVIATE_EMBEDDED", "False") == "True"
|
||||
self.use_weaviate_embedded = (
|
||||
os.getenv("USE_WEAVIATE_EMBEDDED", "False") == "True"
|
||||
)
|
||||
|
||||
# milvus configuration, e.g., localhost:19530.
|
||||
self.milvus_addr = os.getenv("MILVUS_ADDR", "localhost:19530")
|
||||
|
||||
@@ -56,8 +56,10 @@ def get_memory(cfg, init=False):
|
||||
memory = RedisMemory(cfg)
|
||||
elif cfg.memory_backend == "weaviate":
|
||||
if not WeaviateMemory:
|
||||
print("Error: Weaviate is not installed. Please install weaviate-client to"
|
||||
" use Weaviate as a memory backend.")
|
||||
print(
|
||||
"Error: Weaviate is not installed. Please install weaviate-client to"
|
||||
" use Weaviate as a memory backend."
|
||||
)
|
||||
else:
|
||||
memory = WeaviateMemory(cfg)
|
||||
elif cfg.memory_backend == "milvus":
|
||||
@@ -89,5 +91,5 @@ __all__ = [
|
||||
"PineconeMemory",
|
||||
"NoMemory",
|
||||
"MilvusMemory",
|
||||
"WeaviateMemory"
|
||||
"WeaviateMemory",
|
||||
]
|
||||
|
||||
@@ -53,7 +53,7 @@ class NoMemory(MemoryProviderSingleton):
|
||||
"""
|
||||
return ""
|
||||
|
||||
def get_relevant(self, data: str, num_relevant: int = 5) ->list[Any] | None:
|
||||
def get_relevant(self, data: str, num_relevant: int = 5) -> list[Any] | None:
|
||||
"""
|
||||
Returns all the data in the memory that is relevant to the given data.
|
||||
NoMemory always returns None.
|
||||
|
||||
@@ -14,7 +14,7 @@ def default_schema(weaviate_index):
|
||||
{
|
||||
"name": "raw_text",
|
||||
"dataType": ["text"],
|
||||
"description": "original text for the embedding"
|
||||
"description": "original text for the embedding",
|
||||
}
|
||||
],
|
||||
}
|
||||
@@ -24,16 +24,20 @@ class WeaviateMemory(MemoryProviderSingleton):
|
||||
def __init__(self, cfg):
|
||||
auth_credentials = self._build_auth_credentials(cfg)
|
||||
|
||||
url = f'{cfg.weaviate_protocol}://{cfg.weaviate_host}:{cfg.weaviate_port}'
|
||||
url = f"{cfg.weaviate_protocol}://{cfg.weaviate_host}:{cfg.weaviate_port}"
|
||||
|
||||
if cfg.use_weaviate_embedded:
|
||||
self.client = Client(embedded_options=EmbeddedOptions(
|
||||
self.client = Client(
|
||||
embedded_options=EmbeddedOptions(
|
||||
hostname=cfg.weaviate_host,
|
||||
port=int(cfg.weaviate_port),
|
||||
persistence_data_path=cfg.weaviate_embedded_path
|
||||
))
|
||||
persistence_data_path=cfg.weaviate_embedded_path,
|
||||
)
|
||||
)
|
||||
|
||||
print(f"Weaviate Embedded running on: {url} with persistence path: {cfg.weaviate_embedded_path}")
|
||||
print(
|
||||
f"Weaviate Embedded running on: {url} with persistence path: {cfg.weaviate_embedded_path}"
|
||||
)
|
||||
else:
|
||||
self.client = Client(url, auth_client_secret=auth_credentials)
|
||||
|
||||
@@ -47,7 +51,9 @@ class WeaviateMemory(MemoryProviderSingleton):
|
||||
|
||||
def _build_auth_credentials(self, cfg):
|
||||
if cfg.weaviate_username and cfg.weaviate_password:
|
||||
return weaviate.AuthClientPassword(cfg.weaviate_username, cfg.weaviate_password)
|
||||
return weaviate.AuthClientPassword(
|
||||
cfg.weaviate_username, cfg.weaviate_password
|
||||
)
|
||||
if cfg.weaviate_api_key:
|
||||
return weaviate.AuthApiKey(api_key=cfg.weaviate_api_key)
|
||||
else:
|
||||
@@ -57,16 +63,14 @@ class WeaviateMemory(MemoryProviderSingleton):
|
||||
vector = get_ada_embedding(data)
|
||||
|
||||
doc_uuid = generate_uuid5(data, self.index)
|
||||
data_object = {
|
||||
'raw_text': data
|
||||
}
|
||||
data_object = {"raw_text": data}
|
||||
|
||||
with self.client.batch as batch:
|
||||
batch.add_data_object(
|
||||
uuid=doc_uuid,
|
||||
data_object=data_object,
|
||||
class_name=self.index,
|
||||
vector=vector
|
||||
vector=vector,
|
||||
)
|
||||
|
||||
return f"Inserting data into memory at uuid: {doc_uuid}:\n data: {data}"
|
||||
@@ -82,29 +86,31 @@ class WeaviateMemory(MemoryProviderSingleton):
|
||||
# after a call to delete_all
|
||||
self._create_schema()
|
||||
|
||||
return 'Obliterated'
|
||||
return "Obliterated"
|
||||
|
||||
def get_relevant(self, data, num_relevant=5):
|
||||
query_embedding = get_ada_embedding(data)
|
||||
try:
|
||||
results = self.client.query.get(self.index, ['raw_text']) \
|
||||
.with_near_vector({'vector': query_embedding, 'certainty': 0.7}) \
|
||||
.with_limit(num_relevant) \
|
||||
results = (
|
||||
self.client.query.get(self.index, ["raw_text"])
|
||||
.with_near_vector({"vector": query_embedding, "certainty": 0.7})
|
||||
.with_limit(num_relevant)
|
||||
.do()
|
||||
)
|
||||
|
||||
if len(results['data']['Get'][self.index]) > 0:
|
||||
return [str(item['raw_text']) for item in results['data']['Get'][self.index]]
|
||||
if len(results["data"]["Get"][self.index]) > 0:
|
||||
return [
|
||||
str(item["raw_text"]) for item in results["data"]["Get"][self.index]
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
except Exception as err:
|
||||
print(f'Unexpected error {err=}, {type(err)=}')
|
||||
print(f"Unexpected error {err=}, {type(err)=}")
|
||||
return []
|
||||
|
||||
def get_stats(self):
|
||||
result = self.client.query.aggregate(self.index) \
|
||||
.with_meta_count() \
|
||||
.do()
|
||||
class_data = result['data']['Aggregate'][self.index]
|
||||
result = self.client.query.aggregate(self.index).with_meta_count().do()
|
||||
class_data = result["data"]["Aggregate"][self.index]
|
||||
|
||||
return class_data[0]['meta'] if class_data else {}
|
||||
return class_data[0]["meta"] if class_data else {}
|
||||
|
||||
@@ -84,11 +84,7 @@ def build_default_prompt_generator() -> PromptGenerator:
|
||||
# Only add the audio to text command if the model is specified
|
||||
if cfg.huggingface_audio_to_text_model:
|
||||
commands.append(
|
||||
(
|
||||
"Convert Audio to text",
|
||||
"read_audio_from_file",
|
||||
{"file": "<file>"}
|
||||
),
|
||||
("Convert Audio to text", "read_audio_from_file", {"file": "<file>"}),
|
||||
)
|
||||
|
||||
# Only add shell command to the prompt if the AI is allowed to execute it
|
||||
|
||||
@@ -36,6 +36,8 @@ def safe_path_join(base: Path, *paths: str | Path) -> Path:
|
||||
joined_path = base.joinpath(*paths).resolve()
|
||||
|
||||
if not joined_path.is_relative_to(base):
|
||||
raise ValueError(f"Attempted to access path '{joined_path}' outside of working directory '{base}'.")
|
||||
raise ValueError(
|
||||
f"Attempted to access path '{joined_path}' outside of working directory '{base}'."
|
||||
)
|
||||
|
||||
return joined_path
|
||||
|
||||
@@ -12,14 +12,17 @@ from autogpt.memory.weaviate import WeaviateMemory
|
||||
from autogpt.memory.base import get_ada_embedding
|
||||
|
||||
|
||||
@mock.patch.dict(os.environ, {
|
||||
@mock.patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"WEAVIATE_HOST": "127.0.0.1",
|
||||
"WEAVIATE_PROTOCOL": "http",
|
||||
"WEAVIATE_PORT": "8080",
|
||||
"WEAVIATE_USERNAME": "",
|
||||
"WEAVIATE_PASSWORD": "",
|
||||
"MEMORY_INDEX": "AutogptTests"
|
||||
})
|
||||
"MEMORY_INDEX": "AutogptTests",
|
||||
},
|
||||
)
|
||||
class TestWeaviateMemory(unittest.TestCase):
|
||||
cfg = None
|
||||
client = None
|
||||
@@ -32,13 +35,17 @@ class TestWeaviateMemory(unittest.TestCase):
|
||||
if cls.cfg.use_weaviate_embedded:
|
||||
from weaviate.embedded import EmbeddedOptions
|
||||
|
||||
cls.client = Client(embedded_options=EmbeddedOptions(
|
||||
cls.client = Client(
|
||||
embedded_options=EmbeddedOptions(
|
||||
hostname=cls.cfg.weaviate_host,
|
||||
port=int(cls.cfg.weaviate_port),
|
||||
persistence_data_path=cls.cfg.weaviate_embedded_path
|
||||
))
|
||||
persistence_data_path=cls.cfg.weaviate_embedded_path,
|
||||
)
|
||||
)
|
||||
else:
|
||||
cls.client = Client(f"{cls.cfg.weaviate_protocol}://{cls.cfg.weaviate_host}:{self.cfg.weaviate_port}")
|
||||
cls.client = Client(
|
||||
f"{cls.cfg.weaviate_protocol}://{cls.cfg.weaviate_host}:{self.cfg.weaviate_port}"
|
||||
)
|
||||
|
||||
"""
|
||||
In order to run these tests you will need a local instance of
|
||||
@@ -49,6 +56,7 @@ class TestWeaviateMemory(unittest.TestCase):
|
||||
USE_WEAVIATE_EMBEDDED=True
|
||||
WEAVIATE_EMBEDDED_PATH="/home/me/.local/share/weaviate"
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
try:
|
||||
self.client.schema.delete_class(self.cfg.memory_index)
|
||||
@@ -58,23 +66,23 @@ class TestWeaviateMemory(unittest.TestCase):
|
||||
self.memory = WeaviateMemory(self.cfg)
|
||||
|
||||
def test_add(self):
|
||||
doc = 'You are a Titan name Thanos and you are looking for the Infinity Stones'
|
||||
doc = "You are a Titan name Thanos and you are looking for the Infinity Stones"
|
||||
self.memory.add(doc)
|
||||
result = self.client.query.get(self.cfg.memory_index, ['raw_text']).do()
|
||||
actual = result['data']['Get'][self.cfg.memory_index]
|
||||
result = self.client.query.get(self.cfg.memory_index, ["raw_text"]).do()
|
||||
actual = result["data"]["Get"][self.cfg.memory_index]
|
||||
|
||||
self.assertEqual(len(actual), 1)
|
||||
self.assertEqual(actual[0]['raw_text'], doc)
|
||||
self.assertEqual(actual[0]["raw_text"], doc)
|
||||
|
||||
def test_get(self):
|
||||
doc = 'You are an Avenger and swore to defend the Galaxy from a menace called Thanos'
|
||||
doc = "You are an Avenger and swore to defend the Galaxy from a menace called Thanos"
|
||||
|
||||
with self.client.batch as batch:
|
||||
batch.add_data_object(
|
||||
uuid=get_valid_uuid(uuid4()),
|
||||
data_object={'raw_text': doc},
|
||||
data_object={"raw_text": doc},
|
||||
class_name=self.cfg.memory_index,
|
||||
vector=get_ada_embedding(doc)
|
||||
vector=get_ada_embedding(doc),
|
||||
)
|
||||
|
||||
batch.flush()
|
||||
@@ -86,8 +94,8 @@ class TestWeaviateMemory(unittest.TestCase):
|
||||
|
||||
def test_get_stats(self):
|
||||
docs = [
|
||||
'You are now about to count the number of docs in this index',
|
||||
'And then you about to find out if you can count correctly'
|
||||
"You are now about to count the number of docs in this index",
|
||||
"And then you about to find out if you can count correctly",
|
||||
]
|
||||
|
||||
[self.memory.add(doc) for doc in docs]
|
||||
@@ -95,23 +103,23 @@ class TestWeaviateMemory(unittest.TestCase):
|
||||
stats = self.memory.get_stats()
|
||||
|
||||
self.assertTrue(stats)
|
||||
self.assertTrue('count' in stats)
|
||||
self.assertEqual(stats['count'], 2)
|
||||
self.assertTrue("count" in stats)
|
||||
self.assertEqual(stats["count"], 2)
|
||||
|
||||
def test_clear(self):
|
||||
docs = [
|
||||
'Shame this is the last test for this class',
|
||||
'Testing is fun when someone else is doing it'
|
||||
"Shame this is the last test for this class",
|
||||
"Testing is fun when someone else is doing it",
|
||||
]
|
||||
|
||||
[self.memory.add(doc) for doc in docs]
|
||||
|
||||
self.assertEqual(self.memory.get_stats()['count'], 2)
|
||||
self.assertEqual(self.memory.get_stats()["count"], 2)
|
||||
|
||||
self.memory.clear()
|
||||
|
||||
self.assertEqual(self.memory.get_stats()['count'], 0)
|
||||
self.assertEqual(self.memory.get_stats()["count"], 0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user