Files
Auto-GPT/autogpt/memory/vector/utils.py
Jayden ac17518663 Fix Azure OpenAI setup problems (#4875)
* [Fix] Recover the azure config load function

* [Style] Apply black, isort, mypy, autoflake

* [Fix] Rename the return parameter from 'azure_model_map' to 'azure_model_to_deployment_id_map'

* [Feat] Change the azure config file path to be dynamically configurable

* [Test] Add azure_config and azure deployment_id_for_model

* [Style] Apply black, isort, mypy, autoflake

* [Style] Apply black, isort, mypy, autoflake

* Refactor Azure configuration

- Refactor the `azure_config_file` attribute in the `Config` class to be optional.
- Refactor the `azure_model_to_deployment_id_map` attribute in the `Config` class to be optional and provide default values.
- Update the `get_azure_deployment_id_for_model` function to accept additional parameters.
- Update references to `get_azure_deployment_id_for_model` in `create_text_completion`, `create_chat_completion`, and `get_embedding` functions to pass the required parameters.

* Clean up process for azure

* Docstring

* revert some unnecessary fiddling

* Avoid altering args to models

* Retry on 404s

* Don't permanently change the environment

* Formatting

---------

Co-authored-by: Luke <2609441+lc0rp@users.noreply.github.com>
Co-authored-by: lc0rp <2609411+lc0rp@users.noreply.github.com>
Co-authored-by: collijk <collijk@uw.edu>
2023-07-06 17:51:59 -07:00

68 lines
1.8 KiB
Python

from typing import Any, overload
import numpy as np
from autogpt.config import Config
from autogpt.llm.base import TText
from autogpt.llm.providers import openai as iopenai
from autogpt.logs import logger
# Type alias: an embedding vector of float32 values, either as a plain
# Python list or as a numpy array.
Embedding = list[np.float32] | np.ndarray[Any, np.dtype[np.float32]]
"""Embedding vector"""
@overload
def get_embedding(input: str | TText, config: Config) -> Embedding:
    ...


@overload
def get_embedding(input: list[str] | list[TText], config: Config) -> list[Embedding]:
    ...


def get_embedding(
    input: str | TText | list[str] | list[TText], config: Config
) -> Embedding | list[Embedding]:
    """Get an embedding for the given input from the configured embedding model.

    Args:
        input: Input text to get embeddings for, encoded as a string or array of tokens.
            Multiple inputs may be given as a list of strings or token arrays.
        config: Application config providing the embedding model name, the
            OpenAI API key, and (optionally) Azure deployment settings.

    Returns:
        Embedding | list[Embedding]: A single embedding for a single input, or a
            list of embeddings (in input order) when a list of inputs was given.
    """
    # A single token array (list[int]) is one input; only a list whose elements
    # are not ints (strings or token arrays) is treated as multiple inputs.
    multiple = isinstance(input, list) and all(not isinstance(i, int) for i in input)

    # Strip newlines from text inputs before embedding.
    if isinstance(input, str):
        input = input.replace("\n", " ")
    elif multiple and isinstance(input[0], str):
        input = [text.replace("\n", " ") for text in input]

    model = config.embedding_model
    if config.use_azure:
        # Azure requires an 'engine' (deployment id) instead of a model name.
        kwargs = config.get_azure_kwargs(model)
    else:
        kwargs = {"model": model}

    logger.debug(
        f"Getting embedding{f's for {len(input)} inputs' if multiple else ''}"
        f" with model '{model}'"
        + (f" via Azure deployment '{kwargs['engine']}'" if config.use_azure else "")
    )

    # NOTE: a leftover debug `breakpoint()` (guarded by config.use_azure) was
    # removed here — it halted execution whenever Azure was in use.
    embeddings = iopenai.create_embedding(
        input,
        **kwargs,
        api_key=config.openai_api_key,
    ).data

    if not multiple:
        return embeddings[0]["embedding"]

    # Re-order results by their 'index' field so they match the input order.
    embeddings = sorted(embeddings, key=lambda x: x["index"])
    return [d["embedding"] for d in embeddings]