Files
Auto-GPT/autogpt/memory/vector/utils.py
James Collins 6e6e7fcc9a Extract openai API calls and retry at lowest level (#3696)
* Extract OpenAI API calls and retry at lowest level

* Forgot a test

* Gotta fix my local docker config so I can let pre-commit hooks run, ugh

* fix: merge artifact

* Fix linting

* Update memory.vector.utils

* feat: make sure resp exists

* fix: raise error message if created

* feat: rename file

* fix: partial test fix

* fix: update comments

* fix: linting

* fix: remove broken test

* fix: require a model to exist

* fix: BaseError issue

* fix: runtime error

* Fix mock response in test_make_agent

* add 429 as errors to retry

---------

Co-authored-by: k-boikov <64261260+k-boikov@users.noreply.github.com>
Co-authored-by: Nicholas Tindle <nick@ntindle.com>
Co-authored-by: Reinier van der Leer <github@pwuts.nl>
Co-authored-by: Nicholas Tindle <nicktindle@outlook.com>
Co-authored-by: Luke K (pr-0f3t) <2609441+lc0rp@users.noreply.github.com>
Co-authored-by: Merwane Hamadi <merwanehamadi@gmail.com>
2023-06-14 07:59:26 -07:00

67 lines
1.8 KiB
Python

from typing import Any, overload
import numpy as np
from autogpt.config import Config
from autogpt.llm.base import TText
from autogpt.llm.providers import openai as iopenai
from autogpt.logs import logger
# Type alias for a single embedding vector: either a plain list of ``np.float32``
# values or a NumPy array whose dtype is ``np.float32``.
Embedding = list[np.float32] | np.ndarray[Any, np.dtype[np.float32]]
"""Embedding vector"""
@overload
def get_embedding(input: str | TText) -> Embedding:
    ...


@overload
def get_embedding(input: list[str] | list[TText]) -> list[Embedding]:
    ...


def get_embedding(
    input: str | TText | list[str] | list[TText],
) -> Embedding | list[Embedding]:
    """Get embedding(s) for the input from the configured embedding model.

    The model is read from ``Config().embedding_model``; when ``use_azure`` is set
    in the config, the request is addressed to the matching Azure deployment
    (passed as ``engine``) instead of passing ``model``.

    Args:
        input: Input text to get embeddings for, encoded as a string or array of
            tokens. Multiple inputs may be given as a list of strings or token
            arrays.

    Returns:
        For a single input (a string or one token array): the embedding of that
        input. For multiple inputs: a list of embeddings, ordered by the
        ``index`` field of the response entries. An empty list yields an empty
        list without any request being made.
    """
    # An empty list carries nothing to embed. Return early: the checks below
    # would otherwise classify it as "multiple" (``all()`` of nothing is True)
    # and then fail with an IndexError when inspecting ``input[0]``.
    if isinstance(input, list) and not input:
        return []

    cfg = Config()

    # A list containing any int is treated as a single token array; any other
    # list is treated as a batch of inputs.
    multiple = isinstance(input, list) and all(not isinstance(i, int) for i in input)

    # Replace newlines with spaces in textual inputs before sending them.
    if isinstance(input, str):
        input = input.replace("\n", " ")
    elif multiple and isinstance(input[0], str):
        input = [text.replace("\n", " ") for text in input]

    model = cfg.embedding_model
    if cfg.use_azure:
        kwargs = {"engine": cfg.get_azure_deployment_id_for_model(model)}
    else:
        kwargs = {"model": model}

    logger.debug(
        f"Getting embedding{f's for {len(input)} inputs' if multiple else ''}"
        f" with model '{model}'"
        + (f" via Azure deployment '{kwargs['engine']}'" if cfg.use_azure else "")
    )

    embeddings = iopenai.create_embedding(
        input,
        **kwargs,
        api_key=cfg.openai_api_key,
    ).data

    if not multiple:
        return embeddings[0]["embedding"]

    # Order the results by their reported index so the output lines up with
    # the order of the inputs.
    embeddings = sorted(embeddings, key=lambda x: x["index"])
    return [d["embedding"] for d in embeddings]