From 6fbac455d4e83a243d6276a4ce3becc0e2acbcbb Mon Sep 17 00:00:00 2001 From: James Collins Date: Tue, 25 Apr 2023 12:10:12 -0700 Subject: [PATCH] Remove import time loading of config from llm_utils (#3245) --- autogpt/llm_utils.py | 35 ++++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/autogpt/llm_utils.py b/autogpt/llm_utils.py index 7b565edd..2ef0422f 100644 --- a/autogpt/llm_utils.py +++ b/autogpt/llm_utils.py @@ -13,8 +13,6 @@ from autogpt.config import Config from autogpt.logs import logger from autogpt.types.openai import Message -CFG = Config() - def retry_openai_api( num_retries: int = 10, @@ -86,8 +84,9 @@ def call_ai_function( Returns: str: The response from the function """ + cfg = Config() if model is None: - model = CFG.smart_llm_model + model = cfg.smart_llm_model # For each arg, if any are None, convert to "None": args = [str(arg) if arg is not None else "None" for arg in args] # parse args to comma separated string @@ -109,7 +108,7 @@ def call_ai_function( def create_chat_completion( messages: List[Message], # type: ignore model: Optional[str] = None, - temperature: float = CFG.temperature, + temperature: float = None, max_tokens: Optional[int] = None, ) -> str: """Create a chat completion using the OpenAI API @@ -123,13 +122,17 @@ def create_chat_completion( Returns: str: The response from the chat completion """ + cfg = Config() + if temperature is None: + temperature = cfg.temperature + num_retries = 10 warned_user = False - if CFG.debug_mode: + if cfg.debug_mode: print( f"{Fore.GREEN}Creating chat completion with model {model}, temperature {temperature}, max_tokens {max_tokens}{Fore.RESET}" ) - for plugin in CFG.plugins: + for plugin in cfg.plugins: if plugin.can_handle_chat_completion( messages=messages, model=model, @@ -148,9 +151,9 @@ def create_chat_completion( for attempt in range(num_retries): backoff = 2 ** (attempt + 2) try: - if CFG.use_azure: + if cfg.use_azure: response = api_manager.create_chat_completion( - deployment_id=CFG.get_azure_deployment_id_for_model(model), + deployment_id=cfg.get_azure_deployment_id_for_model(model), model=model, messages=messages, temperature=temperature, @@ -165,7 +168,7 @@ def create_chat_completion( ) break except RateLimitError: - if CFG.debug_mode: + if cfg.debug_mode: print( f"{Fore.RED}Error: ", f"Reached rate limit, passing...{Fore.RESET}" ) @@ -180,7 +183,7 @@ def create_chat_completion( raise if attempt == num_retries - 1: raise - if CFG.debug_mode: + if cfg.debug_mode: print( f"{Fore.RED}Error: ", f"API Bad gateway. Waiting {backoff} seconds...{Fore.RESET}", @@ -194,12 +197,12 @@ def create_chat_completion( + f"Try running Auto-GPT again, and if the problem the persists try running it with `{Fore.CYAN}--debug{Fore.RESET}`.", ) logger.double_check() - if CFG.debug_mode: + if cfg.debug_mode: raise RuntimeError(f"Failed to get response after {num_retries} retries") else: quit(1) resp = response.choices[0].message["content"] - for plugin in CFG.plugins: + for plugin in cfg.plugins: if not plugin.can_handle_on_response(): continue resp = plugin.on_response(resp) @@ -215,11 +218,12 @@ def get_ada_embedding(text: str) -> List[int]: Returns: List[int]: The embedding. """ + cfg = Config() model = "text-embedding-ada-002" text = text.replace("\n", " ") - if CFG.use_azure: - kwargs = {"engine": CFG.get_azure_deployment_id_for_model(model)} + if cfg.use_azure: + kwargs = {"engine": cfg.get_azure_deployment_id_for_model(model)} else: kwargs = {"model": model} @@ -247,8 +251,9 @@ def create_embedding( Returns: openai.Embedding: The embedding object. """ + cfg = Config() return openai.Embedding.create( input=[text], - api_key=CFG.openai_api_key, + api_key=cfg.openai_api_key, **kwargs, )