Integrate plugin.handle_text_embedding hook (#2804)

* add feature custom text embedding in plugin

* black code format

* _get_embedding_with_plugin()

* Fix docstring & type hint

---------

Co-authored-by: Reinier van der Leer <github@pwuts.nl>
This commit is contained in:
Lei Zhang
2023-07-15 09:10:32 +08:00
committed by GitHub
parent c821b294c6
commit 5ae044f53d
3 changed files with 32 additions and 10 deletions

View File

@@ -1,3 +1,4 @@
from contextlib import suppress
from typing import Any, overload
import numpy as np
@@ -12,12 +13,12 @@ Embedding = list[np.float32] | np.ndarray[Any, np.dtype[np.float32]]
@overload
def get_embedding(input: str | TText) -> Embedding:
def get_embedding(input: str | TText, config: Config) -> Embedding:
...
@overload
def get_embedding(input: list[str] | list[TText]) -> list[Embedding]:
def get_embedding(input: list[str] | list[TText], config: Config) -> list[Embedding]:
...
@@ -37,9 +38,16 @@ def get_embedding(
if isinstance(input, str):
input = input.replace("\n", " ")
with suppress(NotImplementedError):
return _get_embedding_with_plugin(input, config)
elif multiple and isinstance(input[0], str):
input = [text.replace("\n", " ") for text in input]
with suppress(NotImplementedError):
return [_get_embedding_with_plugin(i, config) for i in input]
model = config.embedding_model
kwargs = {"model": model}
kwargs.update(config.get_openai_credentials(model))
@@ -62,3 +70,13 @@ def get_embedding(
embeddings = sorted(embeddings, key=lambda x: x["index"])
return [d["embedding"] for d in embeddings]
def _get_embedding_with_plugin(text: str, config: Config) -> Embedding:
for plugin in config.plugins:
if plugin.can_handle_text_embedding(text):
embedding = plugin.handle_text_embedding(text)
if embedding is not None:
return embedding
raise NotImplementedError

View File

@@ -198,18 +198,20 @@ class BaseOpenAIPlugin(AutoGPTPluginTemplate):
def can_handle_text_embedding(self, text: str) -> bool:
"""This method is called to check that the plugin can
handle the text_embedding method.
Args:
text (str): The text to be convert to embedding.
Returns:
bool: True if the plugin can handle the text_embedding method."""
return False
def handle_text_embedding(self, text: str) -> list:
"""This method is called when the chat completion is done.
Args:
text (str): The text to be convert to embedding.
Returns:
list: The text embedding.
bool: True if the plugin can handle the text_embedding method."""
return False
def handle_text_embedding(self, text: str) -> list[float]:
"""This method is called to create a text embedding.
Args:
text (str): The text to be convert to embedding.
Returns:
list[float]: The created embedding vector.
"""
def can_handle_user_input(self, user_input: str) -> bool:

View File

@@ -54,6 +54,7 @@ def test_dummy_plugin_default_methods(dummy_plugin):
assert not dummy_plugin.can_handle_pre_command()
assert not dummy_plugin.can_handle_post_command()
assert not dummy_plugin.can_handle_chat_completion(None, None, None, None)
assert not dummy_plugin.can_handle_text_embedding(None)
assert dummy_plugin.on_response("hello") == "hello"
assert dummy_plugin.post_prompt(None) is None
@@ -77,3 +78,4 @@ def test_dummy_plugin_default_methods(dummy_plugin):
assert isinstance(post_command, str)
assert post_command == "upgraded successfully!"
assert dummy_plugin.handle_chat_completion(None, None, None, None) is None
assert dummy_plugin.handle_text_embedding(None) is None