From 5ae044f53db4af1b8a54ef8c7e2afb17e67568b9 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Sat, 15 Jul 2023 09:10:32 +0800 Subject: [PATCH] Integrate `plugin.handle_text_embedding` hook (#2804) * add feature custom text embedding in plugin * black code format * _get_embedding_with_plugin() * Fix docstring & type hint --------- Co-authored-by: Reinier van der Leer --- autogpt/memory/vector/utils.py | 22 +++++++++++++++++-- autogpt/models/base_open_ai_plugin.py | 18 ++++++++------- .../unit/models/test_base_open_api_plugin.py | 2 ++ 3 files changed, 32 insertions(+), 10 deletions(-) diff --git a/autogpt/memory/vector/utils.py b/autogpt/memory/vector/utils.py index eb691256..1b050d56 100644 --- a/autogpt/memory/vector/utils.py +++ b/autogpt/memory/vector/utils.py @@ -1,3 +1,4 @@ +from contextlib import suppress from typing import Any, overload import numpy as np @@ -12,12 +13,12 @@ Embedding = list[np.float32] | np.ndarray[Any, np.dtype[np.float32]] @overload -def get_embedding(input: str | TText) -> Embedding: +def get_embedding(input: str | TText, config: Config) -> Embedding: ... @overload -def get_embedding(input: list[str] | list[TText]) -> list[Embedding]: +def get_embedding(input: list[str] | list[TText], config: Config) -> list[Embedding]: ... @@ -37,9 +38,16 @@ def get_embedding( if isinstance(input, str): input = input.replace("\n", " ") + + with suppress(NotImplementedError): + return _get_embedding_with_plugin(input, config) + elif multiple and isinstance(input[0], str): input = [text.replace("\n", " ") for text in input] + with suppress(NotImplementedError): + return [_get_embedding_with_plugin(i, config) for i in input] + model = config.embedding_model kwargs = {"model": model} kwargs.update(config.get_openai_credentials(model)) @@ -62,3 +70,13 @@ def get_embedding( embeddings = sorted(embeddings, key=lambda x: x["index"]) return [d["embedding"] for d in embeddings] + + +def _get_embedding_with_plugin(text: str, config: Config) -> Embedding: + for plugin in config.plugins: + if plugin.can_handle_text_embedding(text): + embedding = plugin.handle_text_embedding(text) + if embedding is not None: + return embedding + + raise NotImplementedError diff --git a/autogpt/models/base_open_ai_plugin.py b/autogpt/models/base_open_ai_plugin.py index c0aac8ed..60f6f91b 100644 --- a/autogpt/models/base_open_ai_plugin.py +++ b/autogpt/models/base_open_ai_plugin.py @@ -198,18 +198,20 @@ class BaseOpenAIPlugin(AutoGPTPluginTemplate): def can_handle_text_embedding(self, text: str) -> bool: """This method is called to check that the plugin can handle the text_embedding method. - Args: - text (str): The text to be convert to embedding. - Returns: - bool: True if the plugin can handle the text_embedding method.""" - return False - def handle_text_embedding(self, text: str) -> list: - """This method is called when the chat completion is done. Args: text (str): The text to be convert to embedding. Returns: - list: The text embedding. + bool: True if the plugin can handle the text_embedding method.""" + return False + + def handle_text_embedding(self, text: str) -> list[float]: + """This method is called to create a text embedding. + + Args: + text (str): The text to be convert to embedding. + Returns: + list[float]: The created embedding vector. """ def can_handle_user_input(self, user_input: str) -> bool: diff --git a/tests/unit/models/test_base_open_api_plugin.py b/tests/unit/models/test_base_open_api_plugin.py index 4d41eddd..e656f464 100644 --- a/tests/unit/models/test_base_open_api_plugin.py +++ b/tests/unit/models/test_base_open_api_plugin.py @@ -54,6 +54,7 @@ def test_dummy_plugin_default_methods(dummy_plugin): assert not dummy_plugin.can_handle_pre_command() assert not dummy_plugin.can_handle_post_command() assert not dummy_plugin.can_handle_chat_completion(None, None, None, None) + assert not dummy_plugin.can_handle_text_embedding(None) assert dummy_plugin.on_response("hello") == "hello" assert dummy_plugin.post_prompt(None) is None @@ -77,3 +78,4 @@ def test_dummy_plugin_default_methods(dummy_plugin): assert isinstance(post_command, str) assert post_command == "upgraded successfully!" assert dummy_plugin.handle_chat_completion(None, None, None, None) is None + assert dummy_plugin.handle_text_embedding(None) is None