From 50e5ea4e54569ac4a0259d2cbfe07f453dfef5ad Mon Sep 17 00:00:00 2001 From: edwardsp Date: Thu, 29 Feb 2024 17:35:06 +0000 Subject: [PATCH] fix(agent/llm): Fix support for AzureOpenAI (#6927) * Fix unmasking of `azure_endpoint` in `OpenAICredentials.get_api_access_kwargs()` * Amend `ApiManager.get_models` to use `AzureOpenAI` client when `api_type` is set to `azure` --------- Co-authored-by: Reinier van der Leer --- .../core/resource/model_providers/openai.py | 3 ++- autogpts/autogpt/autogpt/llm/api_manager.py | 17 +++++++++++++---- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py b/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py index 37771ea4..1b564b6d 100644 --- a/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py +++ b/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py @@ -243,7 +243,8 @@ class OpenAICredentials(ModelProviderCredentials): } if self.api_type == "azure": kwargs["api_version"] = self.api_version - kwargs["azure_endpoint"] = self.azure_endpoint + assert self.azure_endpoint, "Azure endpoint not configured" + kwargs["azure_endpoint"] = self.azure_endpoint.get_secret_value() return kwargs def get_model_access_kwargs(self, model: str) -> dict[str, str]: diff --git a/autogpts/autogpt/autogpt/llm/api_manager.py b/autogpts/autogpt/autogpt/llm/api_manager.py index 35d28d63..4e5e0b95 100644 --- a/autogpts/autogpt/autogpt/llm/api_manager.py +++ b/autogpts/autogpt/autogpt/llm/api_manager.py @@ -3,7 +3,7 @@ from __future__ import annotations import logging from typing import List, Optional -from openai import OpenAI +from openai import OpenAI, AzureOpenAI from openai.types import Model from autogpt.core.resource.model_providers.openai import ( @@ -107,9 +107,18 @@ class ApiManager(metaclass=Singleton): list[Model]: List of available GPT models. """ if self.models is None: - all_models = ( - OpenAI(**openai_credentials.get_api_access_kwargs()).models.list().data - ) + if openai_credentials.api_type == "azure": + all_models = ( + AzureOpenAI(**openai_credentials.get_api_access_kwargs()) + .models.list() + .data + ) + else: + all_models = ( + OpenAI(**openai_credentials.get_api_access_kwargs()) + .models.list() + .data + ) self.models = [model for model in all_models if "gpt" in model.id] return self.models