mirror of
https://github.com/aljazceru/Auto-GPT.git
synced 2025-12-18 14:34:23 +01:00
110 lines
3.6 KiB
Python
110 lines
3.6 KiB
Python
from autogpt.config import Config
|
|
from autogpt.memory.base import MemoryProviderSingleton, get_ada_embedding
|
|
import uuid
|
|
import weaviate
|
|
from weaviate import Client
|
|
from weaviate.embedded import EmbeddedOptions
|
|
from weaviate.util import generate_uuid5
|
|
|
|
def default_schema(weaviate_index):
|
|
return {
|
|
"class": weaviate_index,
|
|
"properties": [
|
|
{
|
|
"name": "raw_text",
|
|
"dataType": ["text"],
|
|
"description": "original text for the embedding"
|
|
}
|
|
],
|
|
}
|
|
|
|
class WeaviateMemory(MemoryProviderSingleton):
|
|
def __init__(self, cfg):
|
|
auth_credentials = self._build_auth_credentials(cfg)
|
|
|
|
url = f'{cfg.weaviate_protocol}://{cfg.weaviate_host}:{cfg.weaviate_port}'
|
|
|
|
if cfg.use_weaviate_embedded:
|
|
self.client = Client(embedded_options=EmbeddedOptions(
|
|
hostname=cfg.weaviate_host,
|
|
port=int(cfg.weaviate_port),
|
|
persistence_data_path=cfg.weaviate_embedded_path
|
|
))
|
|
|
|
print(f"Weaviate Embedded running on: {url} with persistence path: {cfg.weaviate_embedded_path}")
|
|
else:
|
|
self.client = Client(url, auth_client_secret=auth_credentials)
|
|
|
|
self.index = cfg.memory_index
|
|
self._create_schema()
|
|
|
|
def _create_schema(self):
|
|
schema = default_schema(self.index)
|
|
if not self.client.schema.contains(schema):
|
|
self.client.schema.create_class(schema)
|
|
|
|
def _build_auth_credentials(self, cfg):
|
|
if cfg.weaviate_username and cfg.weaviate_password:
|
|
return weaviate_auth.AuthClientPassword(cfg.weaviate_username, cfg.weaviate_password)
|
|
if cfg.weaviate_api_key:
|
|
return weaviate.auth.AuthApiKey(api_key=cfg.weaviate_api_key)
|
|
else:
|
|
return None
|
|
|
|
def add(self, data):
|
|
vector = get_ada_embedding(data)
|
|
|
|
doc_uuid = generate_uuid5(data, self.index)
|
|
data_object = {
|
|
'raw_text': data
|
|
}
|
|
|
|
with self.client.batch as batch:
|
|
batch.add_data_object(
|
|
uuid=doc_uuid,
|
|
data_object=data_object,
|
|
class_name=self.index,
|
|
vector=vector
|
|
)
|
|
|
|
return f"Inserting data into memory at uuid: {doc_uuid}:\n data: {data}"
|
|
|
|
def get(self, data):
|
|
return self.get_relevant(data, 1)
|
|
|
|
|
|
def clear(self):
|
|
self.client.schema.delete_all()
|
|
|
|
# weaviate does not yet have a neat way to just remove the items in an index
|
|
# without removing the entire schema, therefore we need to re-create it
|
|
# after a call to delete_all
|
|
self._create_schema()
|
|
|
|
return 'Obliterated'
|
|
|
|
def get_relevant(self, data, num_relevant=5):
|
|
query_embedding = get_ada_embedding(data)
|
|
try:
|
|
results = self.client.query.get(self.index, ['raw_text']) \
|
|
.with_near_vector({'vector': query_embedding, 'certainty': 0.7}) \
|
|
.with_limit(num_relevant) \
|
|
.do()
|
|
|
|
if len(results['data']['Get'][self.index]) > 0:
|
|
return [str(item['raw_text']) for item in results['data']['Get'][self.index]]
|
|
else:
|
|
return []
|
|
|
|
except Exception as err:
|
|
print(f'Unexpected error {err=}, {type(err)=}')
|
|
return []
|
|
|
|
def get_stats(self):
|
|
result = self.client.query.aggregate(self.index) \
|
|
.with_meta_count() \
|
|
.do()
|
|
class_data = result['data']['Aggregate'][self.index]
|
|
|
|
return class_data[0]['meta'] if class_data else {}
|