Bugfix/broken azure config (#4912)

This commit is contained in:
James Collins
2023-07-07 19:42:26 -07:00
committed by GitHub
parent 9a2a9f7439
commit 57315bddfb
3 changed files with 21 additions and 8 deletions

View File

@@ -88,15 +88,18 @@ def create_config(
# --gpt3only should always use gpt-3.5-turbo, despite user's FAST_LLM config
config.fast_llm = GPT_3_MODEL
config.smart_llm = GPT_3_MODEL
elif gpt4only and check_model(GPT_4_MODEL, model_type="smart_llm") == GPT_4_MODEL:
elif (
gpt4only
and check_model(GPT_4_MODEL, model_type="smart_llm", config=config)
== GPT_4_MODEL
):
logger.typewriter_log("GPT4 Only Mode: ", Fore.GREEN, "ENABLED")
# --gpt4only should always use gpt-4, despite user's SMART_LLM config
config.fast_llm = GPT_4_MODEL
config.smart_llm = GPT_4_MODEL
else:
config.fast_llm = check_model(config.fast_llm, "fast_llm")
config.smart_llm = check_model(config.smart_llm, "smart_llm")
config.fast_llm = check_model(config.fast_llm, "fast_llm", config=config)
config.smart_llm = check_model(config.smart_llm, "smart_llm", config=config)
if memory_type:
supported_memory = get_supported_memory_backends()

View File

@@ -95,7 +95,7 @@ class ApiManager(metaclass=Singleton):
"""
return self.total_budget
def get_models(self) -> List[Model]:
def get_models(self, **openai_credentials) -> List[Model]:
"""
Get list of available GPT models.
@@ -104,7 +104,7 @@ class ApiManager(metaclass=Singleton):
"""
if self.models is None:
all_models = openai.Model.list()["data"]
all_models = openai.Model.list(**openai_credentials)["data"]
self.models = [model for model in all_models if "gpt" in model["id"]]
return self.models

View File

@@ -173,10 +173,20 @@ def create_chat_completion(
)
def check_model(model_name: str, model_type: Literal["smart_llm", "fast_llm"]) -> str:
def check_model(
model_name: str,
model_type: Literal["smart_llm", "fast_llm"],
config: Config,
) -> str:
"""Check if model is available for use. If not, return gpt-3.5-turbo."""
openai_credentials = {
"api_key": config.openai_api_key,
}
if config.use_azure:
openai_credentials.update(config.get_azure_kwargs(model_name))
api_manager = ApiManager()
models = api_manager.get_models()
models = api_manager.get_models(**openai_credentials)
if any(model_name in m["id"] for m in models):
return model_name