mirror of
https://github.com/stulzq/azure-openai-proxy.git
synced 2025-12-19 15:24:24 +01:00
🧱 unify azure config from env or yaml file
This commit is contained in:
102
azure/init.go
102
azure/init.go
@@ -1,11 +1,12 @@
|
||||
package azure
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/spf13/viper"
|
||||
"github.com/stulzq/azure-openai-proxy/constant"
|
||||
"github.com/stulzq/azure-openai-proxy/util"
|
||||
"log"
|
||||
"net/url"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -14,43 +15,92 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
AzureOpenAIEndpoint = ""
|
||||
AzureOpenAIEndpointParse *url.URL
|
||||
|
||||
AzureOpenAIAPIVer = ""
|
||||
|
||||
AzureOpenAIModelMapper = map[string]string{
|
||||
"gpt-3.5-turbo": "gpt-35-turbo",
|
||||
}
|
||||
fallbackModelMapper = regexp.MustCompile(`[.:]`)
|
||||
C Config
|
||||
ModelDeploymentConfig = map[string]DeploymentConfig{}
|
||||
)
|
||||
|
||||
func Init() {
|
||||
AzureOpenAIAPIVer = os.Getenv(constant.ENV_AZURE_OPENAI_API_VER)
|
||||
AzureOpenAIEndpoint = os.Getenv(constant.ENV_AZURE_OPENAI_ENDPOINT)
|
||||
func Init() error {
|
||||
var (
|
||||
apiVersion string
|
||||
endpoint string
|
||||
openaiModelMapper string
|
||||
err error
|
||||
)
|
||||
|
||||
if AzureOpenAIAPIVer == "" {
|
||||
AzureOpenAIAPIVer = "2023-03-15-preview"
|
||||
apiVersion = viper.GetString(constant.ENV_AZURE_OPENAI_API_VER)
|
||||
endpoint = viper.GetString(constant.ENV_AZURE_OPENAI_ENDPOINT)
|
||||
openaiModelMapper = viper.GetString(constant.ENV_AZURE_OPENAI_MODEL_MAPPER)
|
||||
if endpoint != "" && openaiModelMapper != "" {
|
||||
if apiVersion == "" {
|
||||
apiVersion = "2023-03-15-preview"
|
||||
}
|
||||
InitFromEnvironmentVariables(apiVersion, endpoint, openaiModelMapper)
|
||||
} else {
|
||||
if err = InitFromConfigFile(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
AzureOpenAIEndpointParse, err = url.Parse(AzureOpenAIEndpoint)
|
||||
// ensure apiBase likes /v1
|
||||
apiBase := viper.GetString("api_base")
|
||||
if !strings.HasPrefix(apiBase, "/") {
|
||||
apiBase = "/" + apiBase
|
||||
}
|
||||
if strings.HasSuffix(apiBase, "/") {
|
||||
apiBase = apiBase[:len(apiBase)-1]
|
||||
}
|
||||
viper.Set("api_base", apiBase)
|
||||
log.Printf("apiBase is: %s", apiBase)
|
||||
for _, itemConfig := range C.DeploymentConfig {
|
||||
u, err := url.Parse(itemConfig.Endpoint)
|
||||
if err != nil {
|
||||
log.Fatal("parse AzureOpenAIEndpoint error: ", err)
|
||||
return fmt.Errorf("parse endpoint error: %w", err)
|
||||
}
|
||||
itemConfig.EndpointUrl = u
|
||||
ModelDeploymentConfig[itemConfig.ModelName] = itemConfig
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if v := os.Getenv(constant.ENV_AZURE_OPENAI_MODEL_MAPPER); v != "" {
|
||||
for _, pair := range strings.Split(v, ",") {
|
||||
func InitFromEnvironmentVariables(apiVersion, endpoint, openaiModelMapper string) {
|
||||
log.Println("Init from environment variables")
|
||||
if openaiModelMapper != "" {
|
||||
// openaiModelMapper example:
|
||||
// gpt-3.5-turbo=deployment_name_for_gpt_model,text-davinci-003=deployment_name_for_davinci_model
|
||||
for _, pair := range strings.Split(openaiModelMapper, ",") {
|
||||
info := strings.Split(pair, "=")
|
||||
if len(info) != 2 {
|
||||
log.Fatalf("error parsing %s, invalid value %s", constant.ENV_AZURE_OPENAI_MODEL_MAPPER, pair)
|
||||
}
|
||||
|
||||
AzureOpenAIModelMapper[info[0]] = info[1]
|
||||
modelName, deploymentName := info[0], info[1]
|
||||
ModelDeploymentConfig[modelName] = DeploymentConfig{
|
||||
DeploymentName: deploymentName,
|
||||
ModelName: modelName,
|
||||
Endpoint: endpoint,
|
||||
ApiKey: "",
|
||||
ApiVersion: apiVersion,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.Println("AzureOpenAIAPIVer: ", AzureOpenAIAPIVer)
|
||||
log.Println("AzureOpenAIEndpoint: ", AzureOpenAIEndpoint)
|
||||
log.Println("AzureOpenAIModelMapper: ", AzureOpenAIModelMapper)
|
||||
func InitFromConfigFile() error {
|
||||
log.Println("Init from config file")
|
||||
workDir := util.GetWorkdir()
|
||||
viper.SetConfigName("config")
|
||||
viper.SetConfigType("yaml")
|
||||
viper.AddConfigPath(fmt.Sprintf("%s/config", workDir))
|
||||
if err := viper.ReadInConfig(); err != nil {
|
||||
log.Printf("read config file error: %+v\n", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if err := viper.Unmarshal(&C); err != nil {
|
||||
log.Printf("unmarshal config file error: %+v\n", err)
|
||||
return err
|
||||
}
|
||||
for _, configItem := range C.DeploymentConfig {
|
||||
ModelDeploymentConfig[configItem.ModelName] = configItem
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user