Merge pull request #1 from cs0lar/feature/weaviate-embedded

Feature/weaviate embedded
This commit is contained in:
cs0lar
2023-04-12 08:24:40 +01:00
committed by GitHub
5 changed files with 40 additions and 10 deletions

View File

@@ -10,8 +10,11 @@ USE_AZURE=False
OPENAI_API_BASE=your-base-url-for-azure
OPENAI_API_VERSION=api-version-for-azure
OPENAI_DEPLOYMENT_ID=deployment-id-for-azure
WEAVIATE_HOST="http://127.0.0.1"
WEAVIATE_PORT="8080"
WEAVIATE_HOST="127.0.0.1"
WEAVIATE_PORT=8080
WEAVIATE_PROTOCOL="http"
USE_WEAVIATE_EMBEDDED=False
WEAVIATE_EMBEDDED_PATH="/home/me/.local/share/weaviate"
WEAVIATE_USERNAME=
WEAVIATE_PASSWORD=
IMAGE_PROVIDER=dalle

View File

@@ -15,4 +15,4 @@ pinecone-client==2.2.1
redis
orjson
Pillow
weaviate-client==3.15.4
weaviate-client==3.15.5

View File

@@ -66,9 +66,12 @@ class Config(metaclass=Singleton):
self.weaviate_host = os.getenv("WEAVIATE_HOST")
self.weaviate_port = os.getenv("WEAVIATE_PORT")
self.weaviate_protocol = os.getenv("WEAVIATE_PROTOCOL", "http")
self.weaviate_username = os.getenv("WEAVIATE_USERNAME", None)
self.weaviate_password = os.getenv("WEAVIATE_PASSWORD", None)
self.weaviate_scopes = os.getenv("WEAVIATE_SCOPES", None)
self.weaviate_embedded_path = os.getenv('WEAVIATE_EMBEDDED_PATH', '~/.local/share/weaviate')
self.use_weaviate_embedded = os.getenv("USE_WEAVIATE_EMBEDDED", "False") == "True"
self.image_provider = os.getenv("IMAGE_PROVIDER")
self.huggingface_api_token = os.getenv("HUGGINGFACE_API_TOKEN")

View File

@@ -3,6 +3,7 @@ from memory.base import MemoryProviderSingleton, get_ada_embedding
import uuid
import weaviate
from weaviate import Client
from weaviate.embedded import EmbeddedOptions
from weaviate.util import generate_uuid5
def default_schema(weaviate_index):
@@ -21,9 +22,19 @@ class WeaviateMemory(MemoryProviderSingleton):
def __init__(self, cfg):
auth_credentials = self._build_auth_credentials(cfg)
url = f'{cfg.weaviate_host}:{cfg.weaviate_port}'
url = f'{cfg.weaviate_protocol}://{cfg.weaviate_host}:{cfg.weaviate_port}'
if cfg.use_weaviate_embedded:
self.client = Client(embedded_options=EmbeddedOptions(
hostname=cfg.weaviate_host,
port=int(cfg.weaviate_port),
persistence_data_path=cfg.weaviate_embedded_path
))
print(f"Weaviate Embedded running on: {url} with persistence path: {cfg.weaviate_embedded_path}")
else:
self.client = Client(url, auth_client_secret=auth_credentials)
self.client = Client(url, auth_client_secret=auth_credentials)
self.index = cfg.memory_index
self._create_schema()
@@ -59,7 +70,6 @@ class WeaviateMemory(MemoryProviderSingleton):
return f"Inserting data into memory at uuid: {doc_uuid}:\n data: {data}"
def get(self, data):
return self.get_relevant(data, 1)

View File

@@ -13,10 +13,11 @@ from memory.weaviate import WeaviateMemory
from memory.base import get_ada_embedding
@mock.patch.dict(os.environ, {
"WEAVIATE_HOST": "http://127.0.0.1",
"WEAVIATE_HOST": "127.0.0.1",
"WEAVIATE_PROTOCOL": "http",
"WEAVIATE_PORT": "8080",
"WEAVIATE_USERNAME": '',
"WEAVIATE_PASSWORD": '',
"WEAVIATE_USERNAME": "",
"WEAVIATE_PASSWORD": "",
"MEMORY_INDEX": "AutogptTests"
})
class TestWeaviateMemory(unittest.TestCase):
@@ -24,11 +25,24 @@ class TestWeaviateMemory(unittest.TestCase):
In order to run these tests you will need a local instance of
Weaviate running. Refer to https://weaviate.io/developers/weaviate/installation/docker-compose
for creating local instances using docker.
Alternatively in your .env file set the following environmental variables to run Weaviate embedded (see: https://weaviate.io/developers/weaviate/installation/embedded):
USE_WEAVIATE_EMBEDDED=True
WEAVIATE_EMBEDDED_PATH="/home/me/.local/share/weaviate"
"""
def setUp(self):
self.cfg = Config()
self.client = Client('http://127.0.0.1:8080')
if self.cfg.use_weaviate_embedded:
from weaviate.embedded import EmbeddedOptions
self.client = Client(embedded_options=EmbeddedOptions(
hostname=self.cfg.weaviate_host,
port=int(self.cfg.weaviate_port),
persistence_data_path=self.cfg.weaviate_embedded_path
))
else:
self.client = Client(f"{self.cfg.weaviate_protocol}://{self.cfg.weaviate_host}:{self.cfg.weaviate_port}")
try:
self.client.schema.delete_class(self.cfg.memory_index)