mirror of
https://github.com/aljazceru/Auto-GPT.git
synced 2025-12-26 02:14:27 +01:00
* [Fix] Recover the azure config load function * [Style] Apply black, isort, mypy, autoflake * [Fix] Rename the return parameter from 'azure_model_map' to 'azure_model_to_deployment_id_map' * [Feat] Change the azure config file path to be dynamically configurable * [Test] Add azure_config and azure deployment_id_for_model * [Style] Apply black, isort, mypy, autoflake * [Style] Apply black, isort, mypy, autoflake * Refactor Azure configuration - Refactor the `azure_config_file` attribute in the `Config` class to be optional. - Refactor the `azure_model_to_deployment_id_map` attribute in the `Config` class to be optional and provide default values. - Update the `get_azure_deployment_id_for_model` function to accept additional parameters. - Update references to `get_azure_deployment_id_for_model` in `create_text_completion`, `create_chat_completion`, and `get_embedding` functions to pass the required parameters. * Clean up process for azure * Docstring * revert some unnecessary fiddling * Avoid altering args to models * Retry on 404s * Don't permanently change the environment * Formatting --------- Co-authored-by: Luke <2609441+lc0rp@users.noreply.github.com> Co-authored-by: lc0rp <2609411+lc0rp@users.noreply.github.com> Co-authored-by: collijk <collijk@uw.edu>
68 lines
1.8 KiB
Python
68 lines
1.8 KiB
Python
from typing import Any, overload
|
|
|
|
import numpy as np
|
|
|
|
from autogpt.config import Config
|
|
from autogpt.llm.base import TText
|
|
from autogpt.llm.providers import openai as iopenai
|
|
from autogpt.logs import logger
|
|
|
|
# Type alias: an embedding vector, either as a plain list of float32 values
# or as a 1-D numpy float32 array.
Embedding = list[np.float32] | np.ndarray[Any, np.dtype[np.float32]]
"""Embedding vector"""
|
|
|
|
|
@overload
def get_embedding(input: str | TText, config: Config) -> Embedding:
    # Single string or single token array -> single embedding.
    # NOTE: the implementation requires `config`, so the overload must
    # declare it too; omitting it advertised an uncallable signature.
    ...
|
|
|
|
|
|
@overload
def get_embedding(input: list[str] | list[TText], config: Config) -> list[Embedding]:
    # List of strings or token arrays -> list of embeddings, one per input.
    # NOTE: the implementation requires `config`, so the overload must
    # declare it too; omitting it advertised an uncallable signature.
    ...
|
|
|
|
|
|
def get_embedding(
    input: str | TText | list[str] | list[TText],
    config: Config,
) -> Embedding | list[Embedding]:
    """Get an embedding from the ada model.

    Args:
        input: Input text to get embeddings for, encoded as a string or array
            of tokens. Multiple inputs may be given as a list of strings or
            token arrays.
        config: App config supplying the embedding model name, the OpenAI API
            key, and (when `use_azure` is set) the Azure deployment kwargs.

    Returns:
        Embedding | list[Embedding]: A single embedding for a single input,
        or a list of embeddings (in input order) for a list of inputs.
    """
    # A list counts as "multiple inputs" only if it is not itself a single
    # token array (i.e. not a list of ints).
    multiple = isinstance(input, list) and all(not isinstance(i, int) for i in input)

    # Replace newlines with spaces before embedding.
    if isinstance(input, str):
        input = input.replace("\n", " ")
    elif multiple and isinstance(input[0], str):
        input = [text.replace("\n", " ") for text in input]

    model = config.embedding_model
    if config.use_azure:
        # Azure routes by deployment: kwargs include 'engine' instead of 'model'.
        kwargs = config.get_azure_kwargs(model)
    else:
        kwargs = {"model": model}

    logger.debug(
        f"Getting embedding{f's for {len(input)} inputs' if multiple else ''}"
        f" with model '{model}'"
        + (f" via Azure deployment '{kwargs['engine']}'" if config.use_azure else "")
    )

    # FIX: removed a stray `breakpoint()` left in the Azure path — it halted
    # the process under a debugger prompt whenever use_azure was enabled.
    embeddings = iopenai.create_embedding(
        input,
        **kwargs,
        api_key=config.openai_api_key,
    ).data

    if not multiple:
        return embeddings[0]["embedding"]

    # The API may return results out of order; restore input order via 'index'.
    embeddings = sorted(embeddings, key=lambda x: x["index"])
    return [d["embedding"] for d in embeddings]
|