mirror of
https://github.com/aljazceru/Auto-GPT.git
synced 2025-12-18 22:44:21 +01:00
Fix regression: restore api_base and organization configurability (#4933)
This commit is contained in:
@@ -86,7 +86,18 @@ class Config(SystemSettings):
|
|||||||
plugins: list[str]
|
plugins: list[str]
|
||||||
authorise_key: str
|
authorise_key: str
|
||||||
|
|
||||||
def get_azure_kwargs(self, model: str) -> dict[str, str]:
|
def get_openai_credentials(self, model: str) -> dict[str, str]:
|
||||||
|
credentials = {
|
||||||
|
"api_key": self.openai_api_key,
|
||||||
|
"api_base": self.openai_api_base,
|
||||||
|
"organization": self.openai_organization,
|
||||||
|
}
|
||||||
|
if self.use_azure:
|
||||||
|
azure_credentials = self.get_azure_credentials(model)
|
||||||
|
credentials.update(azure_credentials)
|
||||||
|
return credentials
|
||||||
|
|
||||||
|
def get_azure_credentials(self, model: str) -> dict[str, str]:
|
||||||
"""Get the kwargs for the Azure API."""
|
"""Get the kwargs for the Azure API."""
|
||||||
|
|
||||||
# Fix --gpt3only and --gpt4only in combination with Azure
|
# Fix --gpt3only and --gpt4only in combination with Azure
|
||||||
|
|||||||
@@ -71,17 +71,14 @@ def create_text_completion(
|
|||||||
if temperature is None:
|
if temperature is None:
|
||||||
temperature = config.temperature
|
temperature = config.temperature
|
||||||
|
|
||||||
if config.use_azure:
|
kwargs = {"model": model}
|
||||||
kwargs = config.get_azure_kwargs(model)
|
kwargs.update(config.get_openai_credentials(model))
|
||||||
else:
|
|
||||||
kwargs = {"model": model}
|
|
||||||
|
|
||||||
response = iopenai.create_text_completion(
|
response = iopenai.create_text_completion(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
max_tokens=max_output_tokens,
|
max_tokens=max_output_tokens,
|
||||||
api_key=config.openai_api_key,
|
|
||||||
)
|
)
|
||||||
logger.debug(f"Response: {response}")
|
logger.debug(f"Response: {response}")
|
||||||
|
|
||||||
@@ -137,9 +134,7 @@ def create_chat_completion(
|
|||||||
if message is not None:
|
if message is not None:
|
||||||
return message
|
return message
|
||||||
|
|
||||||
chat_completion_kwargs["api_key"] = config.openai_api_key
|
chat_completion_kwargs.update(config.get_openai_credentials(model))
|
||||||
if config.use_azure:
|
|
||||||
chat_completion_kwargs.update(config.get_azure_kwargs(model))
|
|
||||||
|
|
||||||
if functions:
|
if functions:
|
||||||
chat_completion_kwargs["functions"] = [
|
chat_completion_kwargs["functions"] = [
|
||||||
@@ -179,12 +174,7 @@ def check_model(
|
|||||||
config: Config,
|
config: Config,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Check if model is available for use. If not, return gpt-3.5-turbo."""
|
"""Check if model is available for use. If not, return gpt-3.5-turbo."""
|
||||||
openai_credentials = {
|
openai_credentials = config.get_openai_credentials(model_name)
|
||||||
"api_key": config.openai_api_key,
|
|
||||||
}
|
|
||||||
if config.use_azure:
|
|
||||||
openai_credentials.update(config.get_azure_kwargs(model_name))
|
|
||||||
|
|
||||||
api_manager = ApiManager()
|
api_manager = ApiManager()
|
||||||
models = api_manager.get_models(**openai_credentials)
|
models = api_manager.get_models(**openai_credentials)
|
||||||
|
|
||||||
|
|||||||
@@ -41,10 +41,8 @@ def get_embedding(
|
|||||||
input = [text.replace("\n", " ") for text in input]
|
input = [text.replace("\n", " ") for text in input]
|
||||||
|
|
||||||
model = config.embedding_model
|
model = config.embedding_model
|
||||||
if config.use_azure:
|
kwargs = {"model": model}
|
||||||
kwargs = config.get_azure_kwargs(model)
|
kwargs.update(config.get_openai_credentials(model))
|
||||||
else:
|
|
||||||
kwargs = {"model": model}
|
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Getting embedding{f's for {len(input)} inputs' if multiple else ''}"
|
f"Getting embedding{f's for {len(input)} inputs' if multiple else ''}"
|
||||||
@@ -57,7 +55,6 @@ def get_embedding(
|
|||||||
embeddings = iopenai.create_embedding(
|
embeddings = iopenai.create_embedding(
|
||||||
input,
|
input,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
api_key=config.openai_api_key,
|
|
||||||
).data
|
).data
|
||||||
|
|
||||||
if not multiple:
|
if not multiple:
|
||||||
|
|||||||
@@ -174,18 +174,32 @@ azure_model_map:
|
|||||||
|
|
||||||
fast_llm = config.fast_llm
|
fast_llm = config.fast_llm
|
||||||
smart_llm = config.smart_llm
|
smart_llm = config.smart_llm
|
||||||
assert config.get_azure_kwargs(config.fast_llm)["deployment_id"] == "FAST-LLM_ID"
|
assert (
|
||||||
assert config.get_azure_kwargs(config.smart_llm)["deployment_id"] == "SMART-LLM_ID"
|
config.get_azure_credentials(config.fast_llm)["deployment_id"] == "FAST-LLM_ID"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
config.get_azure_credentials(config.smart_llm)["deployment_id"]
|
||||||
|
== "SMART-LLM_ID"
|
||||||
|
)
|
||||||
|
|
||||||
# Emulate --gpt4only
|
# Emulate --gpt4only
|
||||||
config.fast_llm = smart_llm
|
config.fast_llm = smart_llm
|
||||||
assert config.get_azure_kwargs(config.fast_llm)["deployment_id"] == "SMART-LLM_ID"
|
assert (
|
||||||
assert config.get_azure_kwargs(config.smart_llm)["deployment_id"] == "SMART-LLM_ID"
|
config.get_azure_credentials(config.fast_llm)["deployment_id"] == "SMART-LLM_ID"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
config.get_azure_credentials(config.smart_llm)["deployment_id"]
|
||||||
|
== "SMART-LLM_ID"
|
||||||
|
)
|
||||||
|
|
||||||
# Emulate --gpt3only
|
# Emulate --gpt3only
|
||||||
config.fast_llm = config.smart_llm = fast_llm
|
config.fast_llm = config.smart_llm = fast_llm
|
||||||
assert config.get_azure_kwargs(config.fast_llm)["deployment_id"] == "FAST-LLM_ID"
|
assert (
|
||||||
assert config.get_azure_kwargs(config.smart_llm)["deployment_id"] == "FAST-LLM_ID"
|
config.get_azure_credentials(config.fast_llm)["deployment_id"] == "FAST-LLM_ID"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
config.get_azure_credentials(config.smart_llm)["deployment_id"] == "FAST-LLM_ID"
|
||||||
|
)
|
||||||
|
|
||||||
del os.environ["USE_AZURE"]
|
del os.environ["USE_AZURE"]
|
||||||
del os.environ["AZURE_CONFIG_FILE"]
|
del os.environ["AZURE_CONFIG_FILE"]
|
||||||
|
|||||||
Reference in New Issue
Block a user