🧱 unify azure config from env or yaml file

This commit is contained in:
warjiang
2023-06-15 12:17:53 +08:00
parent 8f6b8efc11
commit 932eefebfd

View File

@@ -1,11 +1,12 @@
package azure
import (
"fmt"
"github.com/spf13/viper"
"github.com/stulzq/azure-openai-proxy/constant"
"github.com/stulzq/azure-openai-proxy/util"
"log"
"net/url"
"os"
"regexp"
"strings"
)
@@ -14,43 +15,92 @@ const (
)
var (
AzureOpenAIEndpoint = ""
AzureOpenAIEndpointParse *url.URL
AzureOpenAIAPIVer = ""
AzureOpenAIModelMapper = map[string]string{
"gpt-3.5-turbo": "gpt-35-turbo",
}
fallbackModelMapper = regexp.MustCompile(`[.:]`)
C Config
ModelDeploymentConfig = map[string]DeploymentConfig{}
)
func Init() {
AzureOpenAIAPIVer = os.Getenv(constant.ENV_AZURE_OPENAI_API_VER)
AzureOpenAIEndpoint = os.Getenv(constant.ENV_AZURE_OPENAI_ENDPOINT)
func Init() error {
var (
apiVersion string
endpoint string
openaiModelMapper string
err error
)
if AzureOpenAIAPIVer == "" {
AzureOpenAIAPIVer = "2023-03-15-preview"
apiVersion = viper.GetString(constant.ENV_AZURE_OPENAI_API_VER)
endpoint = viper.GetString(constant.ENV_AZURE_OPENAI_ENDPOINT)
openaiModelMapper = viper.GetString(constant.ENV_AZURE_OPENAI_MODEL_MAPPER)
if endpoint != "" && openaiModelMapper != "" {
if apiVersion == "" {
apiVersion = "2023-03-15-preview"
}
InitFromEnvironmentVariables(apiVersion, endpoint, openaiModelMapper)
} else {
if err = InitFromConfigFile(); err != nil {
return err
}
}
var err error
AzureOpenAIEndpointParse, err = url.Parse(AzureOpenAIEndpoint)
// ensure apiBase likes /v1
apiBase := viper.GetString("api_base")
if !strings.HasPrefix(apiBase, "/") {
apiBase = "/" + apiBase
}
if strings.HasSuffix(apiBase, "/") {
apiBase = apiBase[:len(apiBase)-1]
}
viper.Set("api_base", apiBase)
log.Printf("apiBase is: %s", apiBase)
for _, itemConfig := range C.DeploymentConfig {
u, err := url.Parse(itemConfig.Endpoint)
if err != nil {
log.Fatal("parse AzureOpenAIEndpoint error: ", err)
return fmt.Errorf("parse endpoint error: %w", err)
}
itemConfig.EndpointUrl = u
ModelDeploymentConfig[itemConfig.ModelName] = itemConfig
}
return err
}
if v := os.Getenv(constant.ENV_AZURE_OPENAI_MODEL_MAPPER); v != "" {
for _, pair := range strings.Split(v, ",") {
func InitFromEnvironmentVariables(apiVersion, endpoint, openaiModelMapper string) {
log.Println("Init from environment variables")
if openaiModelMapper != "" {
// openaiModelMapper example:
// gpt-3.5-turbo=deployment_name_for_gpt_model,text-davinci-003=deployment_name_for_davinci_model
for _, pair := range strings.Split(openaiModelMapper, ",") {
info := strings.Split(pair, "=")
if len(info) != 2 {
log.Fatalf("error parsing %s, invalid value %s", constant.ENV_AZURE_OPENAI_MODEL_MAPPER, pair)
}
AzureOpenAIModelMapper[info[0]] = info[1]
modelName, deploymentName := info[0], info[1]
ModelDeploymentConfig[modelName] = DeploymentConfig{
DeploymentName: deploymentName,
ModelName: modelName,
Endpoint: endpoint,
ApiKey: "",
ApiVersion: apiVersion,
}
}
}
log.Println("AzureOpenAIAPIVer: ", AzureOpenAIAPIVer)
log.Println("AzureOpenAIEndpoint: ", AzureOpenAIEndpoint)
log.Println("AzureOpenAIModelMapper: ", AzureOpenAIModelMapper)
}
func InitFromConfigFile() error {
log.Println("Init from config file")
workDir := util.GetWorkdir()
viper.SetConfigName("config")
viper.SetConfigType("yaml")
viper.AddConfigPath(fmt.Sprintf("%s/config", workDir))
if err := viper.ReadInConfig(); err != nil {
log.Printf("read config file error: %+v\n", err)
return err
}
if err := viper.Unmarshal(&C); err != nil {
log.Printf("unmarshal config file error: %+v\n", err)
return err
}
for _, configItem := range C.DeploymentConfig {
ModelDeploymentConfig[configItem.ModelName] = configItem
}
return nil
}