diff --git a/autogpt/config/config.py b/autogpt/config/config.py index fc76d084..05590eb6 100644 --- a/autogpt/config/config.py +++ b/autogpt/config/config.py @@ -13,7 +13,8 @@ from autogpt.core.configuration.schema import Configurable, SystemSettings from autogpt.plugins.plugins_config import PluginsConfig AZURE_CONFIG_FILE = os.path.join(os.path.dirname(__file__), "../..", "azure.yaml") -from typing import Optional +GPT_4_MODEL = "gpt-4" +GPT_3_MODEL = "gpt-3.5-turbo" class Config(SystemSettings): @@ -87,20 +88,39 @@ class Config(SystemSettings): def get_azure_kwargs(self, model: str) -> dict[str, str]: """Get the kwargs for the Azure API.""" + + # Fix --gpt3only and --gpt4only in combination with Azure + fast_llm = ( + self.fast_llm + if not ( + self.fast_llm == self.smart_llm + and self.fast_llm.startswith(GPT_4_MODEL) + ) + else f"not_{self.fast_llm}" + ) + smart_llm = ( + self.smart_llm + if not ( + self.smart_llm == self.fast_llm + and self.smart_llm.startswith(GPT_3_MODEL) + ) + else f"not_{self.smart_llm}" + ) + deployment_id = { - self.fast_llm: self.azure_model_to_deployment_id_map.get( + fast_llm: self.azure_model_to_deployment_id_map.get( "fast_llm_deployment_id", self.azure_model_to_deployment_id_map.get( "fast_llm_model_deployment_id" # backwards compatibility ), ), - self.smart_llm: self.azure_model_to_deployment_id_map.get( + smart_llm: self.azure_model_to_deployment_id_map.get( "smart_llm_deployment_id", self.azure_model_to_deployment_id_map.get( "smart_llm_model_deployment_id" # backwards compatibility ), ), - "text-embedding-ada-002": self.azure_model_to_deployment_id_map.get( + self.embedding_model: self.azure_model_to_deployment_id_map.get( "embedding_model_deployment_id" ), }.get(model, None) @@ -110,7 +130,7 @@ class Config(SystemSettings): "api_base": self.openai_api_base, "api_version": self.openai_api_version, } - if model == "text-embedding-ada-002": + if model == self.embedding_model: kwargs["engine"] = deployment_id else: kwargs["deployment_id"] = deployment_id @@ -272,12 +292,7 @@ class ConfigBuilder(Configurable[Config]): if config_dict["use_azure"]: azure_config = cls.load_azure_config(config_dict["azure_config_file"]) - config_dict["openai_api_type"] = azure_config["openai_api_type"] - config_dict["openai_api_base"] = azure_config["openai_api_base"] - config_dict["openai_api_version"] = azure_config["openai_api_version"] - config_dict["azure_model_to_deployment_id_map"] = azure_config[ - "azure_model_to_deployment_id_map" - ] + config_dict.update(azure_config) elif os.getenv("OPENAI_API_BASE_URL"): config_dict["openai_api_base"] = os.getenv("OPENAI_API_BASE_URL") diff --git a/autogpt/configurator.py b/autogpt/configurator.py index 9d22f092..2da5c58b 100644 --- a/autogpt/configurator.py +++ b/autogpt/configurator.py @@ -7,6 +7,7 @@ import click from colorama import Back, Fore, Style from autogpt import utils +from autogpt.config.config import GPT_3_MODEL, GPT_4_MODEL from autogpt.llm.utils import check_model from autogpt.logs import logger from autogpt.memory.vector import get_supported_memory_backends @@ -14,9 +15,6 @@ from autogpt.memory.vector import get_supported_memory_backends if TYPE_CHECKING: from autogpt.config import Config -GPT_4_MODEL = "gpt-4" -GPT_3_MODEL = "gpt-3.5-turbo" - def create_config( config: Config, diff --git a/azure.yaml.template b/azure.yaml.template index 6fe2af7a..685b7087 100644 --- a/azure.yaml.template +++ b/azure.yaml.template @@ -3,5 +3,5 @@ azure_api_base: your-base-url-for-azure azure_api_version: api-version-for-azure azure_model_map: fast_llm_deployment_id: gpt35-deployment-id-for-azure - smart_llm_deployment_id: gpt4-deployment-id-for-azure + smart_llm_deployment_id: gpt4-deployment-id-for-azure embedding_model_deployment_id: embedding-deployment-id-for-azure diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index d5c9d97d..b441aa94 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -146,18 +146,19 @@ def test_missing_azure_config(workspace: Workspace): assert azure_config["azure_model_to_deployment_id_map"] == {} -def test_azure_config(workspace: Workspace) -> None: - yaml_content = """ +def test_azure_config(config: Config, workspace: Workspace) -> None: + config_file = workspace.get_path("azure_config.yaml") + yaml_content = f""" azure_api_type: azure azure_api_base: https://dummy.openai.azure.com azure_api_version: 2023-06-01-preview azure_model_map: - fast_llm_deployment_id: gpt-3.5-turbo - smart_llm_deployment_id: gpt-4 + fast_llm_deployment_id: FAST-LLM_ID + smart_llm_deployment_id: SMART-LLM_ID embedding_model_deployment_id: embedding-deployment-id-for-azure """ - config_file = workspace.get_path("azure.yaml") config_file.write_text(yaml_content) + os.environ["USE_AZURE"] = "True" os.environ["AZURE_CONFIG_FILE"] = str(config_file) config = ConfigBuilder.build_config_from_env() @@ -166,53 +167,31 @@ azure_model_map: assert config.openai_api_base == "https://dummy.openai.azure.com" assert config.openai_api_version == "2023-06-01-preview" assert config.azure_model_to_deployment_id_map == { - "fast_llm_deployment_id": "gpt-3.5-turbo", - "smart_llm_deployment_id": "gpt-4", + "fast_llm_deployment_id": "FAST-LLM_ID", + "smart_llm_deployment_id": "SMART-LLM_ID", "embedding_model_deployment_id": "embedding-deployment-id-for-azure", } - del os.environ["USE_AZURE"] - del os.environ["AZURE_CONFIG_FILE"] + fast_llm = config.fast_llm + smart_llm = config.smart_llm + assert config.get_azure_kwargs(config.fast_llm)["deployment_id"] == "FAST-LLM_ID" + assert config.get_azure_kwargs(config.smart_llm)["deployment_id"] == "SMART-LLM_ID" + # Emulate --gpt4only + config.fast_llm = smart_llm + assert config.get_azure_kwargs(config.fast_llm)["deployment_id"] == "SMART-LLM_ID" + assert config.get_azure_kwargs(config.smart_llm)["deployment_id"] == "SMART-LLM_ID" -def test_azure_deployment_id_for_model(workspace: Workspace) -> None: - yaml_content = """ -azure_api_type: azure -azure_api_base: https://dummy.openai.azure.com -azure_api_version: 2023-06-01-preview -azure_model_map: - fast_llm_deployment_id: gpt-3.5-turbo - smart_llm_deployment_id: gpt-4 - embedding_model_deployment_id: embedding-deployment-id-for-azure -""" - config_file = workspace.get_path("azure.yaml") - config_file.write_text(yaml_content) - os.environ["USE_AZURE"] = "True" - os.environ["AZURE_CONFIG_FILE"] = str(config_file) - config = ConfigBuilder.build_config_from_env() - - config.fast_llm = "fast_llm" - config.smart_llm = "smart_llm" - - def _get_deployment_id(model): - kwargs = config.get_azure_kwargs(model) - return kwargs.get("deployment_id", kwargs.get("engine")) - - assert _get_deployment_id(config.fast_llm) == "gpt-3.5-turbo" - assert _get_deployment_id(config.smart_llm) == "gpt-4" - assert ( - _get_deployment_id("text-embedding-ada-002") - == "embedding-deployment-id-for-azure" - ) - assert _get_deployment_id("dummy") is None + # Emulate --gpt3only + config.fast_llm = config.smart_llm = fast_llm + assert config.get_azure_kwargs(config.fast_llm)["deployment_id"] == "FAST-LLM_ID" + assert config.get_azure_kwargs(config.smart_llm)["deployment_id"] == "FAST-LLM_ID" del os.environ["USE_AZURE"] del os.environ["AZURE_CONFIG_FILE"] def test_create_config_gpt4only(config: Config) -> None: - fast_llm = config.fast_llm - smart_llm = config.smart_llm with mock.patch("autogpt.llm.api_manager.ApiManager.get_models") as mock_get_models: mock_get_models.return_value = [{"id": GPT_4_MODEL}] create_config( @@ -234,14 +213,8 @@ def test_create_config_gpt4only(config: Config) -> None: assert config.fast_llm == GPT_4_MODEL assert config.smart_llm == GPT_4_MODEL - # Reset config - config.fast_llm = fast_llm - config.smart_llm = smart_llm - def test_create_config_gpt3only(config: Config) -> None: - fast_llm = config.fast_llm - smart_llm = config.smart_llm with mock.patch("autogpt.llm.api_manager.ApiManager.get_models") as mock_get_models: mock_get_models.return_value = [{"id": GPT_3_MODEL}] create_config( @@ -262,7 +235,3 @@ def test_create_config_gpt3only(config: Config) -> None: ) assert config.fast_llm == GPT_3_MODEL assert config.smart_llm == GPT_3_MODEL - - # Reset config - config.fast_llm = fast_llm - config.smart_llm = smart_llm