mirror of
https://github.com/aljazceru/Auto-GPT.git
synced 2026-01-01 21:34:27 +01:00
Integrate plugin.handle_text_embedding hook (#2804)
* add feature custom text embedding in plugin * black code format * _get_embedding_with_plugin() * Fix docstring & type hint --------- Co-authored-by: Reinier van der Leer <github@pwuts.nl>
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from contextlib import suppress
|
||||
from typing import Any, overload
|
||||
|
||||
import numpy as np
|
||||
@@ -12,12 +13,12 @@ Embedding = list[np.float32] | np.ndarray[Any, np.dtype[np.float32]]
|
||||
|
||||
|
||||
@overload
|
||||
def get_embedding(input: str | TText) -> Embedding:
|
||||
def get_embedding(input: str | TText, config: Config) -> Embedding:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def get_embedding(input: list[str] | list[TText]) -> list[Embedding]:
|
||||
def get_embedding(input: list[str] | list[TText], config: Config) -> list[Embedding]:
|
||||
...
|
||||
|
||||
|
||||
@@ -37,9 +38,16 @@ def get_embedding(
|
||||
|
||||
if isinstance(input, str):
|
||||
input = input.replace("\n", " ")
|
||||
|
||||
with suppress(NotImplementedError):
|
||||
return _get_embedding_with_plugin(input, config)
|
||||
|
||||
elif multiple and isinstance(input[0], str):
|
||||
input = [text.replace("\n", " ") for text in input]
|
||||
|
||||
with suppress(NotImplementedError):
|
||||
return [_get_embedding_with_plugin(i, config) for i in input]
|
||||
|
||||
model = config.embedding_model
|
||||
kwargs = {"model": model}
|
||||
kwargs.update(config.get_openai_credentials(model))
|
||||
@@ -62,3 +70,13 @@ def get_embedding(
|
||||
|
||||
embeddings = sorted(embeddings, key=lambda x: x["index"])
|
||||
return [d["embedding"] for d in embeddings]
|
||||
|
||||
|
||||
def _get_embedding_with_plugin(text: str, config: Config) -> Embedding:
|
||||
for plugin in config.plugins:
|
||||
if plugin.can_handle_text_embedding(text):
|
||||
embedding = plugin.handle_text_embedding(text)
|
||||
if embedding is not None:
|
||||
return embedding
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -198,18 +198,20 @@ class BaseOpenAIPlugin(AutoGPTPluginTemplate):
|
||||
def can_handle_text_embedding(self, text: str) -> bool:
|
||||
"""This method is called to check that the plugin can
|
||||
handle the text_embedding method.
|
||||
Args:
|
||||
text (str): The text to be convert to embedding.
|
||||
Returns:
|
||||
bool: True if the plugin can handle the text_embedding method."""
|
||||
return False
|
||||
|
||||
def handle_text_embedding(self, text: str) -> list:
|
||||
"""This method is called when the chat completion is done.
|
||||
Args:
|
||||
text (str): The text to be convert to embedding.
|
||||
Returns:
|
||||
list: The text embedding.
|
||||
bool: True if the plugin can handle the text_embedding method."""
|
||||
return False
|
||||
|
||||
def handle_text_embedding(self, text: str) -> list[float]:
|
||||
"""This method is called to create a text embedding.
|
||||
|
||||
Args:
|
||||
text (str): The text to be convert to embedding.
|
||||
Returns:
|
||||
list[float]: The created embedding vector.
|
||||
"""
|
||||
|
||||
def can_handle_user_input(self, user_input: str) -> bool:
|
||||
|
||||
@@ -54,6 +54,7 @@ def test_dummy_plugin_default_methods(dummy_plugin):
|
||||
assert not dummy_plugin.can_handle_pre_command()
|
||||
assert not dummy_plugin.can_handle_post_command()
|
||||
assert not dummy_plugin.can_handle_chat_completion(None, None, None, None)
|
||||
assert not dummy_plugin.can_handle_text_embedding(None)
|
||||
|
||||
assert dummy_plugin.on_response("hello") == "hello"
|
||||
assert dummy_plugin.post_prompt(None) is None
|
||||
@@ -77,3 +78,4 @@ def test_dummy_plugin_default_methods(dummy_plugin):
|
||||
assert isinstance(post_command, str)
|
||||
assert post_command == "upgraded successfully!"
|
||||
assert dummy_plugin.handle_chat_completion(None, None, None, None) is None
|
||||
assert dummy_plugin.handle_text_embedding(None) is None
|
||||
|
||||
Reference in New Issue
Block a user