Files
Auto-GPT/scripts/providers/weaviate.py

100 lines
2.8 KiB
Python

from config import Config
from providers.memory import Memory, get_ada_embedding
from weaviate import Client
import weaviate
import uuid
from weaviate.util import generate_uuid5
cfg = Config()
SCHEMA = {
"class": cfg.weaviate_index,
"properties": [
{
"name": "raw_text",
"dataType": ["text"],
"description": "original text for the embedding"
}
],
}
class WeaviateMemory(Memory):
def __init__(self):
auth_credentials = self._build_auth_credentials()
url = f'{cfg.weaviate_host}:{cfg.weaviate_port}'
self.client = Client(url, auth_client_secret=auth_credentials)
self._create_schema()
def _create_schema(self):
if not self.client.schema.contains(SCHEMA):
self.client.schema.create_class(SCHEMA)
@staticmethod
def _build_auth_credentials():
if cfg.weaviate_username and cfg.weaviate_password:
return weaviate_auth.AuthClientPassword(cfg.weaviate_username, cfg.weaviate_password)
else:
return None
def add(self, data):
vector = get_ada_embedding(data)
doc_uuid = generate_uuid5(data, cfg.weaviate_index)
data_object = {
'class': cfg.weaviate_index,
'raw_text': data
}
with self.client.batch as batch:
batch.add_data_object(
uuid=doc_uuid,
data_object=data_object,
class_name=cfg.weaviate_index,
vector=vector
)
batch.flush()
return f"Inserting data into memory at uuid: {doc_uuid}:\n data: {data}"
def get(self, data):
return self.get_relevant(data, 1)
def clear(self):
self.client.schema.delete_all()
# weaviate does not yet have a neat way to just remove the items in an index
# without removing the entire schema, therefore we need to re-create it
# after a call to delete_all
self._create_schema()
return 'Obliterated'
def get_relevant(self, data, num_relevant=5):
query_embedding = get_ada_embedding(data)
try:
results = self.client.query.get(cfg.weaviate_index, ['raw_text']) \
.with_near_vector({'vector': query_embedding, 'certainty': 0.7}) \
.with_limit(num_relevant) \
.do()
print(results)
if len(results['data']['Get'][cfg.weaviate_index]) > 0:
return [str(item['raw_text']) for item in results['data']['Get'][cfg.weaviate_index]]
else:
return []
except Exception as err:
print(f'Unexpected error {err=}, {type(err)=}')
return []
def get_stats(self):
return self.client.index_stats.get(cfg.weaviate_index)