Bugfix/broken azure config (#4912)
@@ -88,15 +88,18 @@ def create_config(
         # --gpt3only should always use gpt-3.5-turbo, despite user's FAST_LLM config
         config.fast_llm = GPT_3_MODEL
         config.smart_llm = GPT_3_MODEL
 
-    elif gpt4only and check_model(GPT_4_MODEL, model_type="smart_llm") == GPT_4_MODEL:
+    elif (
+        gpt4only
+        and check_model(GPT_4_MODEL, model_type="smart_llm", config=config)
+        == GPT_4_MODEL
+    ):
         logger.typewriter_log("GPT4 Only Mode: ", Fore.GREEN, "ENABLED")
         # --gpt4only should always use gpt-4, despite user's SMART_LLM config
         config.fast_llm = GPT_4_MODEL
         config.smart_llm = GPT_4_MODEL
     else:
-        config.fast_llm = check_model(config.fast_llm, "fast_llm")
-        config.smart_llm = check_model(config.smart_llm, "smart_llm")
+        config.fast_llm = check_model(config.fast_llm, "fast_llm", config=config)
+        config.smart_llm = check_model(config.smart_llm, "smart_llm", config=config)
 
     if memory_type:
         supported_memory = get_supported_memory_backends()
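In this hunk the active Config is threaded into check_model, so the --gpt4only availability check runs against the user's own credentials, which on Azure are tied to a specific deployment. A minimal usage sketch, with assumed import paths:

# Sketch only (not part of the commit); import paths are assumptions.
from autogpt.config import Config          # assumed location of Config
from autogpt.llm.utils import check_model  # assumed location of check_model

GPT_4_MODEL = "gpt-4"

config = Config()
# Passing config lets check_model query the account (or Azure deployment)
# that will actually serve the model, instead of module-level OpenAI defaults.
if check_model(GPT_4_MODEL, model_type="smart_llm", config=config) == GPT_4_MODEL:
    config.fast_llm = GPT_4_MODEL
    config.smart_llm = GPT_4_MODEL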
@@ -95,7 +95,7 @@ class ApiManager(metaclass=Singleton):
         """
         return self.total_budget
 
-    def get_models(self) -> List[Model]:
+    def get_models(self, **openai_credentials) -> List[Model]:
         """
         Get list of available GPT models.
 
@@ -104,7 +104,7 @@ class ApiManager(metaclass=Singleton):
 
         """
         if self.models is None:
-            all_models = openai.Model.list()["data"]
+            all_models = openai.Model.list(**openai_credentials)["data"]
             self.models = [model for model in all_models if "gpt" in model["id"]]
 
         return self.models
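get_models now forwards arbitrary keyword arguments to openai.Model.list, so credentials can be supplied per call instead of being read from the openai module globals. A rough sketch of querying an Azure-hosted account this way; the endpoint, API version, and key are placeholders, and in the commit the real values come from the user's config via check_model below:

# Sketch only: placeholder Azure credentials and an assumed import path.
from autogpt.llm.api_manager import ApiManager  # assumed location of ApiManager

api_manager = ApiManager()
models = api_manager.get_models(
    api_type="azure",                                 # placeholder values; the
    api_base="https://my-resource.openai.azure.com",  # commit sources these from
    api_version="2023-05-15",                         # the user's Azure config
    api_key="<azure-api-key>",
)
print([m["id"] for m in models])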
@@ -173,10 +173,20 @@ def create_chat_completion(
     )
 
 
-def check_model(model_name: str, model_type: Literal["smart_llm", "fast_llm"]) -> str:
+def check_model(
+    model_name: str,
+    model_type: Literal["smart_llm", "fast_llm"],
+    config: Config,
+) -> str:
     """Check if model is available for use. If not, return gpt-3.5-turbo."""
+    openai_credentials = {
+        "api_key": config.openai_api_key,
+    }
+    if config.use_azure:
+        openai_credentials.update(config.get_azure_kwargs(model_name))
+
     api_manager = ApiManager()
-    models = api_manager.get_models()
+    models = api_manager.get_models(**openai_credentials)
 
     if any(model_name in m["id"] for m in models):
         return model_name
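check_model now assembles the credentials itself: the plain OpenAI API key by default, extended with the Azure kwargs for the model being checked when use_azure is set. Per its docstring it returns gpt-3.5-turbo when the requested model is not listed. A short usage sketch with assumed import paths:

# Sketch only (not part of the commit); import paths are assumptions.
from autogpt.config import Config          # assumed location of Config
from autogpt.llm.utils import check_model  # assumed location of check_model

config = Config()
# Resolve the configured models against what the account (or the mapped Azure
# deployment) actually exposes; unavailable models fall back to gpt-3.5-turbo.
config.fast_llm = check_model(config.fast_llm, "fast_llm", config=config)
config.smart_llm = check_model(config.smart_llm, "smart_llm", config=config)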