diff --git a/scripts/llm_utils.py b/scripts/llm_utils.py index 35cc5ce0..731acae2 100644 --- a/scripts/llm_utils.py +++ b/scripts/llm_utils.py @@ -1,26 +1,52 @@ +import time import openai +from colorama import Fore from config import Config + cfg = Config() openai.api_key = cfg.openai_api_key + # Overly simple abstraction until we create something better +# simple retry mechanism when getting a rate error or a bad gateway def create_chat_completion(messages, model=None, temperature=cfg.temperature, max_tokens=None)->str: """Create a chat completion using the OpenAI API""" - if cfg.use_azure: - response = openai.ChatCompletion.create( - deployment_id=cfg.get_azure_deployment_id_for_model(model), - model=model, - messages=messages, - temperature=temperature, - max_tokens=max_tokens - ) - else: - response = openai.ChatCompletion.create( - model=model, - messages=messages, - temperature=temperature, - max_tokens=max_tokens - ) + response = None + num_retries = 5 + for attempt in range(num_retries): + try: + if cfg.use_azure: + response = openai.ChatCompletion.create( + deployment_id=cfg.get_azure_deployment_id_for_model(model), + model=model, + messages=messages, + temperature=temperature, + max_tokens=max_tokens + ) + else: + response = openai.ChatCompletion.create( + model=model, + messages=messages, + temperature=temperature, + max_tokens=max_tokens + ) + break + except openai.error.RateLimitError: + if cfg.debug_mode: + print(Fore.RED + "Error: ", "API Rate Limit Reached. Waiting 20 seconds..." + Fore.RESET) + time.sleep(20) + except openai.error.APIError as e: + if e.http_status == 502: + if cfg.debug_mode: + print(Fore.RED + "Error: ", "API Bad gateway. Waiting 20 seconds..." + Fore.RESET) + time.sleep(20) + else: + raise + if attempt == num_retries - 1: + raise + + if response is None: + raise RuntimeError("Failed to get response after 5 retries") return response.choices[0].message["content"]