Fix regression: restore api_base and organization configurability (#4933)

This commit is contained in:
James Collins
2023-07-09 19:32:04 -07:00
committed by GitHub
parent 43a62fdc7c
commit 9adcad8b8a
4 changed files with 38 additions and 26 deletions

View File

@@ -86,7 +86,18 @@ class Config(SystemSettings):
     plugins: list[str]
     authorise_key: str

-    def get_azure_kwargs(self, model: str) -> dict[str, str]:
+    def get_openai_credentials(self, model: str) -> dict[str, str]:
+        credentials = {
+            "api_key": self.openai_api_key,
+            "api_base": self.openai_api_base,
+            "organization": self.openai_organization,
+        }
+        if self.use_azure:
+            azure_credentials = self.get_azure_credentials(model)
+            credentials.update(azure_credentials)
+        return credentials
+
+    def get_azure_credentials(self, model: str) -> dict[str, str]:
         """Get the kwargs for the Azure API."""
         # Fix --gpt3only and --gpt4only in combination with Azure

View File

@@ -71,17 +71,14 @@ def create_text_completion(
     if temperature is None:
         temperature = config.temperature

-    if config.use_azure:
-        kwargs = config.get_azure_kwargs(model)
-    else:
-        kwargs = {"model": model}
+    kwargs = {"model": model}
+    kwargs.update(config.get_openai_credentials(model))

     response = iopenai.create_text_completion(
         prompt=prompt,
         **kwargs,
         temperature=temperature,
         max_tokens=max_output_tokens,
-        api_key=config.openai_api_key,
     )
     logger.debug(f"Response: {response}")
@@ -137,9 +134,7 @@ def create_chat_completion(
         if message is not None:
             return message

-    chat_completion_kwargs["api_key"] = config.openai_api_key
-    if config.use_azure:
-        chat_completion_kwargs.update(config.get_azure_kwargs(model))
+    chat_completion_kwargs.update(config.get_openai_credentials(model))

     if functions:
         chat_completion_kwargs["functions"] = [
@@ -179,12 +174,7 @@ def check_model(
     config: Config,
 ) -> str:
     """Check if model is available for use. If not, return gpt-3.5-turbo."""
-    openai_credentials = {
-        "api_key": config.openai_api_key,
-    }
-    if config.use_azure:
-        openai_credentials.update(config.get_azure_kwargs(model_name))
+    openai_credentials = config.get_openai_credentials(model_name)
     api_manager = ApiManager()
     models = api_manager.get_models(**openai_credentials)

View File

@@ -41,10 +41,8 @@ def get_embedding(
         input = [text.replace("\n", " ") for text in input]

     model = config.embedding_model
-    if config.use_azure:
-        kwargs = config.get_azure_kwargs(model)
-    else:
-        kwargs = {"model": model}
+    kwargs = {"model": model}
+    kwargs.update(config.get_openai_credentials(model))

     logger.debug(
         f"Getting embedding{f's for {len(input)} inputs' if multiple else ''}"
@@ -57,7 +55,6 @@ def get_embedding(
     embeddings = iopenai.create_embedding(
         input,
         **kwargs,
-        api_key=config.openai_api_key,
     ).data

     if not multiple:

View File

@@ -174,18 +174,32 @@ azure_model_map:
     fast_llm = config.fast_llm
     smart_llm = config.smart_llm
-    assert config.get_azure_kwargs(config.fast_llm)["deployment_id"] == "FAST-LLM_ID"
-    assert config.get_azure_kwargs(config.smart_llm)["deployment_id"] == "SMART-LLM_ID"
+    assert (
+        config.get_azure_credentials(config.fast_llm)["deployment_id"] == "FAST-LLM_ID"
+    )
+    assert (
+        config.get_azure_credentials(config.smart_llm)["deployment_id"]
+        == "SMART-LLM_ID"
+    )

     # Emulate --gpt4only
     config.fast_llm = smart_llm
-    assert config.get_azure_kwargs(config.fast_llm)["deployment_id"] == "SMART-LLM_ID"
-    assert config.get_azure_kwargs(config.smart_llm)["deployment_id"] == "SMART-LLM_ID"
+    assert (
+        config.get_azure_credentials(config.fast_llm)["deployment_id"] == "SMART-LLM_ID"
+    )
+    assert (
+        config.get_azure_credentials(config.smart_llm)["deployment_id"]
+        == "SMART-LLM_ID"
+    )

     # Emulate --gpt3only
     config.fast_llm = config.smart_llm = fast_llm
-    assert config.get_azure_kwargs(config.fast_llm)["deployment_id"] == "FAST-LLM_ID"
-    assert config.get_azure_kwargs(config.smart_llm)["deployment_id"] == "FAST-LLM_ID"
+    assert (
+        config.get_azure_credentials(config.fast_llm)["deployment_id"] == "FAST-LLM_ID"
+    )
+    assert (
+        config.get_azure_credentials(config.smart_llm)["deployment_id"] == "FAST-LLM_ID"
+    )

     del os.environ["USE_AZURE"]
     del os.environ["AZURE_CONFIG_FILE"]