mirror of
https://github.com/aljazceru/Auto-GPT.git
synced 2025-12-30 04:14:24 +01:00
* Extract open ai api calls and retry at lowest level * Forgot a test * Gotta fix my local docker config so I can let pre-commit hooks run, ugh * fix: merge artifact * Fix linting * Update memory.vector.utils * feat: make sure resp exists * fix: raise error message if created * feat: rename file * fix: partial test fix * fix: update comments * fix: linting * fix: remove broken test * fix: require a model to exist * fix: BaseError issue * fix: runtime error * Fix mock response in test_make_agent * add 429 as errors to retry --------- Co-authored-by: k-boikov <64261260+k-boikov@users.noreply.github.com> Co-authored-by: Nicholas Tindle <nick@ntindle.com> Co-authored-by: Reinier van der Leer <github@pwuts.nl> Co-authored-by: Nicholas Tindle <nicktindle@outlook.com> Co-authored-by: Luke K (pr-0f3t) <2609441+lc0rp@users.noreply.github.com> Co-authored-by: Merwane Hamadi <merwanehamadi@gmail.com>
67 lines
1.8 KiB
Python
67 lines
1.8 KiB
Python
from typing import Any, overload
|
|
|
|
import numpy as np
|
|
|
|
from autogpt.config import Config
|
|
from autogpt.llm.base import TText
|
|
from autogpt.llm.providers import openai as iopenai
|
|
from autogpt.logs import logger
|
|
|
|
# Public type alias: one embedding vector, either a plain Python list of
# float32 values or a numpy float32 array (shape not constrained here).
Embedding = list[np.float32] | np.ndarray[Any, np.dtype[np.float32]]
"""Embedding vector"""
|
|
|
|
|
|
@overload
def get_embedding(input: str | TText) -> Embedding:
    """Overload: a single text string or single token array yields one Embedding."""
    ...
|
|
|
|
|
|
@overload
def get_embedding(input: list[str] | list[TText]) -> list[Embedding]:
    """Overload: a batch (list of strings or token arrays) yields a list of Embeddings."""
    ...
|
|
|
|
|
|
def get_embedding(
    input: str | TText | list[str] | list[TText],
) -> Embedding | list[Embedding]:
    """Get an embedding from the ada model.

    Args:
        input: Input text to get embeddings for, encoded as a string or array
            of tokens. Multiple inputs may be given as a list of strings or
            token arrays.

    Returns:
        Embedding | list[Embedding]: A single embedding for a single input,
            or a list of embeddings (in input order) for a batch. An empty
            batch returns an empty list without calling the API.
    """
    cfg = Config()
    # A list[int] is ONE tokenized text; a list whose elements are not ints
    # (strings or token arrays) is a batch of inputs.  Note all() is True for
    # an empty list, so an empty batch is classified as `multiple`.
    multiple = isinstance(input, list) and all(not isinstance(i, int) for i in input)

    # Fix: an empty batch previously fell through to `input[0]` below and
    # raised IndexError; short-circuit with an empty result instead.
    if multiple and not input:
        return []

    # Replace newlines with spaces; presumably because the embedding model
    # performs better without them -- TODO confirm against provider guidance.
    if isinstance(input, str):
        input = input.replace("\n", " ")
    elif multiple and isinstance(input[0], str):
        input = [text.replace("\n", " ") for text in input]

    model = cfg.embedding_model
    if cfg.use_azure:
        # Azure OpenAI routes requests by deployment id ("engine"),
        # not by model name.
        kwargs = {"engine": cfg.get_azure_deployment_id_for_model(model)}
    else:
        kwargs = {"model": model}

    logger.debug(
        f"Getting embedding{f's for {len(input)} inputs' if multiple else ''}"
        f" with model '{model}'"
        + (f" via Azure deployment '{kwargs['engine']}'" if cfg.use_azure else "")
    )

    embeddings = iopenai.create_embedding(
        input,
        **kwargs,
        api_key=cfg.openai_api_key,
    ).data

    if not multiple:
        # Single input: unwrap the one result.
        return embeddings[0]["embedding"]

    # The API does not guarantee result order; restore input order by index.
    embeddings = sorted(embeddings, key=lambda x: x["index"])
    return [d["embedding"] for d in embeddings]
|