mirror of
https://github.com/aljazceru/Auto-GPT.git
synced 2026-02-08 07:44:22 +01:00
Fix --gpt3only and --gpt4only for Azure (#4098)
* Fix --gpt3only and --gpt4only * Fix and consolidate test_config.py::test_azure_config (x2) --------- Co-authored-by: Luke K (pr-0f3t) <2609441+lc0rp@users.noreply.github.com> Co-authored-by: Ryan <eimu.gray@gmail.com> Co-authored-by: Reinier van der Leer <github@pwuts.nl>
This commit is contained in:
@@ -13,7 +13,8 @@ from autogpt.core.configuration.schema import Configurable, SystemSettings
|
||||
from autogpt.plugins.plugins_config import PluginsConfig
|
||||
|
||||
AZURE_CONFIG_FILE = os.path.join(os.path.dirname(__file__), "../..", "azure.yaml")
|
||||
from typing import Optional
|
||||
GPT_4_MODEL = "gpt-4"
|
||||
GPT_3_MODEL = "gpt-3.5-turbo"
|
||||
|
||||
|
||||
class Config(SystemSettings):
|
||||
@@ -87,20 +88,39 @@ class Config(SystemSettings):
|
||||
|
||||
def get_azure_kwargs(self, model: str) -> dict[str, str]:
|
||||
"""Get the kwargs for the Azure API."""
|
||||
|
||||
# Fix --gpt3only and --gpt4only in combination with Azure
|
||||
fast_llm = (
|
||||
self.fast_llm
|
||||
if not (
|
||||
self.fast_llm == self.smart_llm
|
||||
and self.fast_llm.startswith(GPT_4_MODEL)
|
||||
)
|
||||
else f"not_{self.fast_llm}"
|
||||
)
|
||||
smart_llm = (
|
||||
self.smart_llm
|
||||
if not (
|
||||
self.smart_llm == self.fast_llm
|
||||
and self.smart_llm.startswith(GPT_3_MODEL)
|
||||
)
|
||||
else f"not_{self.smart_llm}"
|
||||
)
|
||||
|
||||
deployment_id = {
|
||||
self.fast_llm: self.azure_model_to_deployment_id_map.get(
|
||||
fast_llm: self.azure_model_to_deployment_id_map.get(
|
||||
"fast_llm_deployment_id",
|
||||
self.azure_model_to_deployment_id_map.get(
|
||||
"fast_llm_model_deployment_id" # backwards compatibility
|
||||
),
|
||||
),
|
||||
self.smart_llm: self.azure_model_to_deployment_id_map.get(
|
||||
smart_llm: self.azure_model_to_deployment_id_map.get(
|
||||
"smart_llm_deployment_id",
|
||||
self.azure_model_to_deployment_id_map.get(
|
||||
"smart_llm_model_deployment_id" # backwards compatibility
|
||||
),
|
||||
),
|
||||
"text-embedding-ada-002": self.azure_model_to_deployment_id_map.get(
|
||||
self.embedding_model: self.azure_model_to_deployment_id_map.get(
|
||||
"embedding_model_deployment_id"
|
||||
),
|
||||
}.get(model, None)
|
||||
@@ -110,7 +130,7 @@ class Config(SystemSettings):
|
||||
"api_base": self.openai_api_base,
|
||||
"api_version": self.openai_api_version,
|
||||
}
|
||||
if model == "text-embedding-ada-002":
|
||||
if model == self.embedding_model:
|
||||
kwargs["engine"] = deployment_id
|
||||
else:
|
||||
kwargs["deployment_id"] = deployment_id
|
||||
@@ -272,12 +292,7 @@ class ConfigBuilder(Configurable[Config]):
|
||||
|
||||
if config_dict["use_azure"]:
|
||||
azure_config = cls.load_azure_config(config_dict["azure_config_file"])
|
||||
config_dict["openai_api_type"] = azure_config["openai_api_type"]
|
||||
config_dict["openai_api_base"] = azure_config["openai_api_base"]
|
||||
config_dict["openai_api_version"] = azure_config["openai_api_version"]
|
||||
config_dict["azure_model_to_deployment_id_map"] = azure_config[
|
||||
"azure_model_to_deployment_id_map"
|
||||
]
|
||||
config_dict.update(azure_config)
|
||||
|
||||
elif os.getenv("OPENAI_API_BASE_URL"):
|
||||
config_dict["openai_api_base"] = os.getenv("OPENAI_API_BASE_URL")
|
||||
|
||||
@@ -7,6 +7,7 @@ import click
|
||||
from colorama import Back, Fore, Style
|
||||
|
||||
from autogpt import utils
|
||||
from autogpt.config.config import GPT_3_MODEL, GPT_4_MODEL
|
||||
from autogpt.llm.utils import check_model
|
||||
from autogpt.logs import logger
|
||||
from autogpt.memory.vector import get_supported_memory_backends
|
||||
@@ -14,9 +15,6 @@ from autogpt.memory.vector import get_supported_memory_backends
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.config import Config
|
||||
|
||||
GPT_4_MODEL = "gpt-4"
|
||||
GPT_3_MODEL = "gpt-3.5-turbo"
|
||||
|
||||
|
||||
def create_config(
|
||||
config: Config,
|
||||
|
||||
@@ -3,5 +3,5 @@ azure_api_base: your-base-url-for-azure
|
||||
azure_api_version: api-version-for-azure
|
||||
azure_model_map:
|
||||
fast_llm_deployment_id: gpt35-deployment-id-for-azure
|
||||
smart_llm_deployment_id: gpt4-deployment-id-for-azure
|
||||
smart_llm_deployment_id: gpt4-deployment-id-for-azure
|
||||
embedding_model_deployment_id: embedding-deployment-id-for-azure
|
||||
|
||||
@@ -146,18 +146,19 @@ def test_missing_azure_config(workspace: Workspace):
|
||||
assert azure_config["azure_model_to_deployment_id_map"] == {}
|
||||
|
||||
|
||||
def test_azure_config(workspace: Workspace) -> None:
|
||||
yaml_content = """
|
||||
def test_azure_config(config: Config, workspace: Workspace) -> None:
|
||||
config_file = workspace.get_path("azure_config.yaml")
|
||||
yaml_content = f"""
|
||||
azure_api_type: azure
|
||||
azure_api_base: https://dummy.openai.azure.com
|
||||
azure_api_version: 2023-06-01-preview
|
||||
azure_model_map:
|
||||
fast_llm_deployment_id: gpt-3.5-turbo
|
||||
smart_llm_deployment_id: gpt-4
|
||||
fast_llm_deployment_id: FAST-LLM_ID
|
||||
smart_llm_deployment_id: SMART-LLM_ID
|
||||
embedding_model_deployment_id: embedding-deployment-id-for-azure
|
||||
"""
|
||||
config_file = workspace.get_path("azure.yaml")
|
||||
config_file.write_text(yaml_content)
|
||||
|
||||
os.environ["USE_AZURE"] = "True"
|
||||
os.environ["AZURE_CONFIG_FILE"] = str(config_file)
|
||||
config = ConfigBuilder.build_config_from_env()
|
||||
@@ -166,53 +167,31 @@ azure_model_map:
|
||||
assert config.openai_api_base == "https://dummy.openai.azure.com"
|
||||
assert config.openai_api_version == "2023-06-01-preview"
|
||||
assert config.azure_model_to_deployment_id_map == {
|
||||
"fast_llm_deployment_id": "gpt-3.5-turbo",
|
||||
"smart_llm_deployment_id": "gpt-4",
|
||||
"fast_llm_deployment_id": "FAST-LLM_ID",
|
||||
"smart_llm_deployment_id": "SMART-LLM_ID",
|
||||
"embedding_model_deployment_id": "embedding-deployment-id-for-azure",
|
||||
}
|
||||
|
||||
del os.environ["USE_AZURE"]
|
||||
del os.environ["AZURE_CONFIG_FILE"]
|
||||
fast_llm = config.fast_llm
|
||||
smart_llm = config.smart_llm
|
||||
assert config.get_azure_kwargs(config.fast_llm)["deployment_id"] == "FAST-LLM_ID"
|
||||
assert config.get_azure_kwargs(config.smart_llm)["deployment_id"] == "SMART-LLM_ID"
|
||||
|
||||
# Emulate --gpt4only
|
||||
config.fast_llm = smart_llm
|
||||
assert config.get_azure_kwargs(config.fast_llm)["deployment_id"] == "SMART-LLM_ID"
|
||||
assert config.get_azure_kwargs(config.smart_llm)["deployment_id"] == "SMART-LLM_ID"
|
||||
|
||||
def test_azure_deployment_id_for_model(workspace: Workspace) -> None:
|
||||
yaml_content = """
|
||||
azure_api_type: azure
|
||||
azure_api_base: https://dummy.openai.azure.com
|
||||
azure_api_version: 2023-06-01-preview
|
||||
azure_model_map:
|
||||
fast_llm_deployment_id: gpt-3.5-turbo
|
||||
smart_llm_deployment_id: gpt-4
|
||||
embedding_model_deployment_id: embedding-deployment-id-for-azure
|
||||
"""
|
||||
config_file = workspace.get_path("azure.yaml")
|
||||
config_file.write_text(yaml_content)
|
||||
os.environ["USE_AZURE"] = "True"
|
||||
os.environ["AZURE_CONFIG_FILE"] = str(config_file)
|
||||
config = ConfigBuilder.build_config_from_env()
|
||||
|
||||
config.fast_llm = "fast_llm"
|
||||
config.smart_llm = "smart_llm"
|
||||
|
||||
def _get_deployment_id(model):
|
||||
kwargs = config.get_azure_kwargs(model)
|
||||
return kwargs.get("deployment_id", kwargs.get("engine"))
|
||||
|
||||
assert _get_deployment_id(config.fast_llm) == "gpt-3.5-turbo"
|
||||
assert _get_deployment_id(config.smart_llm) == "gpt-4"
|
||||
assert (
|
||||
_get_deployment_id("text-embedding-ada-002")
|
||||
== "embedding-deployment-id-for-azure"
|
||||
)
|
||||
assert _get_deployment_id("dummy") is None
|
||||
# Emulate --gpt3only
|
||||
config.fast_llm = config.smart_llm = fast_llm
|
||||
assert config.get_azure_kwargs(config.fast_llm)["deployment_id"] == "FAST-LLM_ID"
|
||||
assert config.get_azure_kwargs(config.smart_llm)["deployment_id"] == "FAST-LLM_ID"
|
||||
|
||||
del os.environ["USE_AZURE"]
|
||||
del os.environ["AZURE_CONFIG_FILE"]
|
||||
|
||||
|
||||
def test_create_config_gpt4only(config: Config) -> None:
|
||||
fast_llm = config.fast_llm
|
||||
smart_llm = config.smart_llm
|
||||
with mock.patch("autogpt.llm.api_manager.ApiManager.get_models") as mock_get_models:
|
||||
mock_get_models.return_value = [{"id": GPT_4_MODEL}]
|
||||
create_config(
|
||||
@@ -234,14 +213,8 @@ def test_create_config_gpt4only(config: Config) -> None:
|
||||
assert config.fast_llm == GPT_4_MODEL
|
||||
assert config.smart_llm == GPT_4_MODEL
|
||||
|
||||
# Reset config
|
||||
config.fast_llm = fast_llm
|
||||
config.smart_llm = smart_llm
|
||||
|
||||
|
||||
def test_create_config_gpt3only(config: Config) -> None:
|
||||
fast_llm = config.fast_llm
|
||||
smart_llm = config.smart_llm
|
||||
with mock.patch("autogpt.llm.api_manager.ApiManager.get_models") as mock_get_models:
|
||||
mock_get_models.return_value = [{"id": GPT_3_MODEL}]
|
||||
create_config(
|
||||
@@ -262,7 +235,3 @@ def test_create_config_gpt3only(config: Config) -> None:
|
||||
)
|
||||
assert config.fast_llm == GPT_3_MODEL
|
||||
assert config.smart_llm == GPT_3_MODEL
|
||||
|
||||
# Reset config
|
||||
config.fast_llm = fast_llm
|
||||
config.smart_llm = smart_llm
|
||||
|
||||
Reference in New Issue
Block a user