Files
azure-openai-proxy/azure/init.go
2023-06-15 15:57:53 +08:00

107 lines
2.8 KiB
Go

package azure
import (
"fmt"
"github.com/spf13/viper"
"github.com/stulzq/azure-openai-proxy/constant"
"github.com/stulzq/azure-openai-proxy/util"
"log"
"net/url"
"strings"
)
// AuthHeaderKey is the name of the HTTP header used to carry the API key
// for authentication.
const AuthHeaderKey = "api-key"

// C holds the configuration that InitFromConfigFile unmarshals from the
// config file.
var C Config

// ModelDeploymentConfig maps a model name to the settings of the deployment
// that serves it. It is filled in by Init and the InitFrom* helpers.
var ModelDeploymentConfig = make(map[string]DeploymentConfig)
// Init loads the deployment configuration and prepares package state.
//
// If both the endpoint and model-mapper settings (read through viper) are
// non-empty, deployments are built from them by InitFromEnvironmentVariables.
// In that case the API version defaults to "2023-03-15-preview" when unset.
// Otherwise the YAML config file is loaded by InitFromConfigFile.
//
// Init then normalizes the "api_base" setting to the form "/v1": one leading
// slash, no trailing slash, or "" when empty. Finally it parses every
// registered deployment's Endpoint into its EndpointUrl field.
//
// It returns an error if the config file cannot be loaded or an endpoint
// cannot be parsed.
func Init() error {
	apiVersion := viper.GetString(constant.ENV_AZURE_OPENAI_API_VER)
	endpoint := viper.GetString(constant.ENV_AZURE_OPENAI_ENDPOINT)
	openaiModelMapper := viper.GetString(constant.ENV_AZURE_OPENAI_MODEL_MAPPER)

	if endpoint != "" && openaiModelMapper != "" {
		if apiVersion == "" {
			apiVersion = "2023-03-15-preview"
		}
		InitFromEnvironmentVariables(apiVersion, endpoint, openaiModelMapper)
	} else if err := InitFromConfigFile(); err != nil {
		return err
	}

	// Ensure apiBase looks like "/v1". Trimming every surrounding slash first
	// also collapses inputs such as "v1//" or "//v1". An empty or "/" value
	// becomes "".
	apiBase := strings.Trim(viper.GetString("api_base"), "/")
	if apiBase != "" {
		apiBase = "/" + apiBase
	}
	viper.Set("api_base", apiBase)
	log.Printf("apiBase is: %s", apiBase)

	// Range over ModelDeploymentConfig, not C.DeploymentConfig.
	// InitFromEnvironmentVariables only fills the map, so ranging over C
	// would leave those deployments without an EndpointUrl. Assigning to keys
	// that already exist while ranging over a map is allowed by the Go spec.
	for modelName, itemConfig := range ModelDeploymentConfig {
		u, err := url.Parse(itemConfig.Endpoint)
		if err != nil {
			return fmt.Errorf("parse endpoint for model %q: %w", modelName, err)
		}
		itemConfig.EndpointUrl = u
		ModelDeploymentConfig[modelName] = itemConfig
	}
	return nil
}
// InitFromEnvironmentVariables registers one entry in ModelDeploymentConfig
// per model=deployment pair in openaiModelMapper. All entries share the given
// endpoint and apiVersion and have an empty ApiKey.
//
// openaiModelMapper is a comma-separated list of pairs, for example:
//
//	gpt-3.5-turbo=deployment_name_for_gpt_model,text-davinci-003=deployment_name_for_davinci_model
//
// Parsing rules:
//   - Whitespace around pairs, model names and deployment names is ignored.
//   - Empty pairs, such as one left by a trailing comma, are skipped.
//   - A pair that does not contain exactly one "=" with non-empty text on
//     both sides terminates the process via log.Fatalf.
//
// EndpointUrl is not set here.
func InitFromEnvironmentVariables(apiVersion, endpoint, openaiModelMapper string) {
	log.Println("Init from environment variables")
	for _, pair := range strings.Split(openaiModelMapper, ",") {
		pair = strings.TrimSpace(pair)
		if pair == "" {
			// Tolerate "a=b," and "a=b,,c=d" instead of aborting startup.
			continue
		}
		info := strings.Split(pair, "=")
		if len(info) != 2 {
			log.Fatalf("error parsing %s, invalid value %q", constant.ENV_AZURE_OPENAI_MODEL_MAPPER, pair)
		}
		modelName, deploymentName := strings.TrimSpace(info[0]), strings.TrimSpace(info[1])
		if modelName == "" || deploymentName == "" {
			log.Fatalf("error parsing %s, invalid value %q", constant.ENV_AZURE_OPENAI_MODEL_MAPPER, pair)
		}
		ModelDeploymentConfig[modelName] = DeploymentConfig{
			DeploymentName: deploymentName,
			ModelName:      modelName,
			Endpoint:       endpoint,
			ApiKey:         "",
			ApiVersion:     apiVersion,
		}
	}
}
func InitFromConfigFile() error {
log.Println("Init from config file")
workDir := util.GetWorkdir()
viper.SetConfigName("config")
viper.SetConfigType("yaml")
viper.AddConfigPath(fmt.Sprintf("%s/config", workDir))
if err := viper.ReadInConfig(); err != nil {
log.Printf("read config file error: %+v\n", err)
return err
}
if err := viper.Unmarshal(&C); err != nil {
log.Printf("unmarshal config file error: %+v\n", err)
return err
}
for _, configItem := range C.DeploymentConfig {
ModelDeploymentConfig[configItem.ModelName] = configItem
}
return nil
}