mirror of
https://github.com/aljazceru/Auto-GPT.git
synced 2026-01-20 14:34:25 +01:00
Merge pull request #1 from cs0lar/feature/weaviate-embedded
Feature/weaviate embedded
This commit is contained in:
@@ -10,8 +10,11 @@ USE_AZURE=False
|
||||
OPENAI_API_BASE=your-base-url-for-azure
|
||||
OPENAI_API_VERSION=api-version-for-azure
|
||||
OPENAI_DEPLOYMENT_ID=deployment-id-for-azure
|
||||
WEAVIATE_HOST="http://127.0.0.1"
|
||||
WEAVIATE_PORT="8080"
|
||||
WEAVIATE_HOST="127.0.0.1"
|
||||
WEAVIATE_PORT=8080
|
||||
WEAVIATE_PROTOCOL="http"
|
||||
USE_WEAVIATE_EMBEDDED=False
|
||||
WEAVIATE_EMBEDDED_PATH="/home/me/.local/share/weaviate"
|
||||
WEAVIATE_USERNAME=
|
||||
WEAVIATE_PASSWORD=
|
||||
IMAGE_PROVIDER=dalle
|
||||
|
||||
@@ -15,4 +15,4 @@ pinecone-client==2.2.1
|
||||
redis
|
||||
orjson
|
||||
Pillow
|
||||
weaviate-client==3.15.4
|
||||
weaviate-client==3.15.5
|
||||
|
||||
@@ -66,9 +66,12 @@ class Config(metaclass=Singleton):
|
||||
|
||||
self.weaviate_host = os.getenv("WEAVIATE_HOST")
|
||||
self.weaviate_port = os.getenv("WEAVIATE_PORT")
|
||||
self.weaviate_protocol = os.getenv("WEAVIATE_PROTOCOL", "http")
|
||||
self.weaviate_username = os.getenv("WEAVIATE_USERNAME", None)
|
||||
self.weaviate_password = os.getenv("WEAVIATE_PASSWORD", None)
|
||||
self.weaviate_scopes = os.getenv("WEAVIATE_SCOPES", None)
|
||||
self.weaviate_embedded_path = os.getenv('WEAVIATE_EMBEDDED_PATH', '~/.local/share/weaviate')
|
||||
self.use_weaviate_embedded = os.getenv("USE_WEAVIATE_EMBEDDED", "False") == "True"
|
||||
|
||||
self.image_provider = os.getenv("IMAGE_PROVIDER")
|
||||
self.huggingface_api_token = os.getenv("HUGGINGFACE_API_TOKEN")
|
||||
|
||||
@@ -3,6 +3,7 @@ from memory.base import MemoryProviderSingleton, get_ada_embedding
|
||||
import uuid
|
||||
import weaviate
|
||||
from weaviate import Client
|
||||
from weaviate.embedded import EmbeddedOptions
|
||||
from weaviate.util import generate_uuid5
|
||||
|
||||
def default_schema(weaviate_index):
|
||||
@@ -21,9 +22,19 @@ class WeaviateMemory(MemoryProviderSingleton):
|
||||
def __init__(self, cfg):
|
||||
auth_credentials = self._build_auth_credentials(cfg)
|
||||
|
||||
url = f'{cfg.weaviate_host}:{cfg.weaviate_port}'
|
||||
url = f'{cfg.weaviate_protocol}://{cfg.weaviate_host}:{cfg.weaviate_port}'
|
||||
|
||||
if cfg.use_weaviate_embedded:
|
||||
self.client = Client(embedded_options=EmbeddedOptions(
|
||||
hostname=cfg.weaviate_host,
|
||||
port=int(cfg.weaviate_port),
|
||||
persistence_data_path=cfg.weaviate_embedded_path
|
||||
))
|
||||
|
||||
print(f"Weaviate Embedded running on: {url} with persistence path: {cfg.weaviate_embedded_path}")
|
||||
else:
|
||||
self.client = Client(url, auth_client_secret=auth_credentials)
|
||||
|
||||
self.client = Client(url, auth_client_secret=auth_credentials)
|
||||
self.index = cfg.memory_index
|
||||
self._create_schema()
|
||||
|
||||
@@ -59,7 +70,6 @@ class WeaviateMemory(MemoryProviderSingleton):
|
||||
|
||||
return f"Inserting data into memory at uuid: {doc_uuid}:\n data: {data}"
|
||||
|
||||
|
||||
def get(self, data):
|
||||
return self.get_relevant(data, 1)
|
||||
|
||||
|
||||
@@ -13,10 +13,11 @@ from memory.weaviate import WeaviateMemory
|
||||
from memory.base import get_ada_embedding
|
||||
|
||||
@mock.patch.dict(os.environ, {
|
||||
"WEAVIATE_HOST": "http://127.0.0.1",
|
||||
"WEAVIATE_HOST": "127.0.0.1",
|
||||
"WEAVIATE_PROTOCOL": "http",
|
||||
"WEAVIATE_PORT": "8080",
|
||||
"WEAVIATE_USERNAME": '',
|
||||
"WEAVIATE_PASSWORD": '',
|
||||
"WEAVIATE_USERNAME": "",
|
||||
"WEAVIATE_PASSWORD": "",
|
||||
"MEMORY_INDEX": "AutogptTests"
|
||||
})
|
||||
class TestWeaviateMemory(unittest.TestCase):
|
||||
@@ -24,11 +25,24 @@ class TestWeaviateMemory(unittest.TestCase):
|
||||
In order to run these tests you will need a local instance of
|
||||
Weaviate running. Refer to https://weaviate.io/developers/weaviate/installation/docker-compose
|
||||
for creating local instances using docker.
|
||||
Alternatively in your .env file set the following environmental variables to run Weaviate embedded (see: https://weaviate.io/developers/weaviate/installation/embedded):
|
||||
|
||||
USE_WEAVIATE_EMBEDDED=True
|
||||
WEAVIATE_EMBEDDED_PATH="/home/me/.local/share/weaviate"
|
||||
"""
|
||||
def setUp(self):
|
||||
self.cfg = Config()
|
||||
|
||||
self.client = Client('http://127.0.0.1:8080')
|
||||
if self.cfg.use_weaviate_embedded:
|
||||
from weaviate.embedded import EmbeddedOptions
|
||||
|
||||
self.client = Client(embedded_options=EmbeddedOptions(
|
||||
hostname=self.cfg.weaviate_host,
|
||||
port=int(self.cfg.weaviate_port),
|
||||
persistence_data_path=self.cfg.weaviate_embedded_path
|
||||
))
|
||||
else:
|
||||
self.client = Client(f"{self.cfg.weaviate_protocol}://{self.cfg.weaviate_host}:{self.cfg.weaviate_port}")
|
||||
|
||||
try:
|
||||
self.client.schema.delete_class(self.cfg.memory_index)
|
||||
|
||||
Reference in New Issue
Block a user