diff --git a/autogpt/config/config.py b/autogpt/config/config.py index 05590eb6..b1ff0a0a 100644 --- a/autogpt/config/config.py +++ b/autogpt/config/config.py @@ -86,7 +86,18 @@ class Config(SystemSettings): plugins: list[str] authorise_key: str - def get_azure_kwargs(self, model: str) -> dict[str, str]: + def get_openai_credentials(self, model: str) -> dict[str, str]: + credentials = { + "api_key": self.openai_api_key, + "api_base": self.openai_api_base, + "organization": self.openai_organization, + } + if self.use_azure: + azure_credentials = self.get_azure_credentials(model) + credentials.update(azure_credentials) + return credentials + + def get_azure_credentials(self, model: str) -> dict[str, str]: """Get the kwargs for the Azure API.""" # Fix --gpt3only and --gpt4only in combination with Azure diff --git a/autogpt/llm/utils/__init__.py b/autogpt/llm/utils/__init__.py index 3c2835b7..e0ff1473 100644 --- a/autogpt/llm/utils/__init__.py +++ b/autogpt/llm/utils/__init__.py @@ -71,17 +71,14 @@ def create_text_completion( if temperature is None: temperature = config.temperature - if config.use_azure: - kwargs = config.get_azure_kwargs(model) - else: - kwargs = {"model": model} + kwargs = {"model": model} + kwargs.update(config.get_openai_credentials(model)) response = iopenai.create_text_completion( prompt=prompt, **kwargs, temperature=temperature, max_tokens=max_output_tokens, - api_key=config.openai_api_key, ) logger.debug(f"Response: {response}") @@ -137,9 +134,7 @@ def create_chat_completion( if message is not None: return message - chat_completion_kwargs["api_key"] = config.openai_api_key - if config.use_azure: - chat_completion_kwargs.update(config.get_azure_kwargs(model)) + chat_completion_kwargs.update(config.get_openai_credentials(model)) if functions: chat_completion_kwargs["functions"] = [ @@ -179,12 +174,7 @@ def check_model( config: Config, ) -> str: """Check if model is available for use. If not, return gpt-3.5-turbo.""" - openai_credentials = { - "api_key": config.openai_api_key, - } - if config.use_azure: - openai_credentials.update(config.get_azure_kwargs(model_name)) - + openai_credentials = config.get_openai_credentials(model_name) api_manager = ApiManager() models = api_manager.get_models(**openai_credentials) diff --git a/autogpt/memory/vector/utils.py b/autogpt/memory/vector/utils.py index 74438f28..eb691256 100644 --- a/autogpt/memory/vector/utils.py +++ b/autogpt/memory/vector/utils.py @@ -41,10 +41,8 @@ def get_embedding( input = [text.replace("\n", " ") for text in input] model = config.embedding_model - if config.use_azure: - kwargs = config.get_azure_kwargs(model) - else: - kwargs = {"model": model} + kwargs = {"model": model} + kwargs.update(config.get_openai_credentials(model)) logger.debug( f"Getting embedding{f's for {len(input)} inputs' if multiple else ''}" @@ -57,7 +55,6 @@ def get_embedding( embeddings = iopenai.create_embedding( input, **kwargs, - api_key=config.openai_api_key, ).data if not multiple: diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index b441aa94..7abbfcd5 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -174,18 +174,32 @@ azure_model_map: fast_llm = config.fast_llm smart_llm = config.smart_llm - assert config.get_azure_kwargs(config.fast_llm)["deployment_id"] == "FAST-LLM_ID" - assert config.get_azure_kwargs(config.smart_llm)["deployment_id"] == "SMART-LLM_ID" + assert ( + config.get_azure_credentials(config.fast_llm)["deployment_id"] == "FAST-LLM_ID" + ) + assert ( + config.get_azure_credentials(config.smart_llm)["deployment_id"] + == "SMART-LLM_ID" + ) # Emulate --gpt4only config.fast_llm = smart_llm - assert config.get_azure_kwargs(config.fast_llm)["deployment_id"] == "SMART-LLM_ID" - assert config.get_azure_kwargs(config.smart_llm)["deployment_id"] == "SMART-LLM_ID" + assert ( + config.get_azure_credentials(config.fast_llm)["deployment_id"] == "SMART-LLM_ID" + ) + assert ( + config.get_azure_credentials(config.smart_llm)["deployment_id"] + == "SMART-LLM_ID" + ) # Emulate --gpt3only config.fast_llm = config.smart_llm = fast_llm - assert config.get_azure_kwargs(config.fast_llm)["deployment_id"] == "FAST-LLM_ID" - assert config.get_azure_kwargs(config.smart_llm)["deployment_id"] == "FAST-LLM_ID" + assert ( + config.get_azure_credentials(config.fast_llm)["deployment_id"] == "FAST-LLM_ID" + ) + assert ( + config.get_azure_credentials(config.smart_llm)["deployment_id"] == "FAST-LLM_ID" + ) del os.environ["USE_AZURE"] del os.environ["AZURE_CONFIG_FILE"]