mirror of
https://github.com/stulzq/azure-openai-proxy.git
synced 2025-12-19 07:14:21 +01:00
🧱 unify azure config from env or yaml file
This commit is contained in:
106
azure/init.go
106
azure/init.go
@@ -1,11 +1,12 @@
|
|||||||
package azure
|
package azure
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/spf13/viper"
|
||||||
"github.com/stulzq/azure-openai-proxy/constant"
|
"github.com/stulzq/azure-openai-proxy/constant"
|
||||||
|
"github.com/stulzq/azure-openai-proxy/util"
|
||||||
"log"
|
"log"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
|
||||||
"regexp"
|
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -14,43 +15,92 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
AzureOpenAIEndpoint = ""
|
C Config
|
||||||
AzureOpenAIEndpointParse *url.URL
|
ModelDeploymentConfig = map[string]DeploymentConfig{}
|
||||||
|
|
||||||
AzureOpenAIAPIVer = ""
|
|
||||||
|
|
||||||
AzureOpenAIModelMapper = map[string]string{
|
|
||||||
"gpt-3.5-turbo": "gpt-35-turbo",
|
|
||||||
}
|
|
||||||
fallbackModelMapper = regexp.MustCompile(`[.:]`)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func Init() {
|
func Init() error {
|
||||||
AzureOpenAIAPIVer = os.Getenv(constant.ENV_AZURE_OPENAI_API_VER)
|
var (
|
||||||
AzureOpenAIEndpoint = os.Getenv(constant.ENV_AZURE_OPENAI_ENDPOINT)
|
apiVersion string
|
||||||
|
endpoint string
|
||||||
|
openaiModelMapper string
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
if AzureOpenAIAPIVer == "" {
|
apiVersion = viper.GetString(constant.ENV_AZURE_OPENAI_API_VER)
|
||||||
AzureOpenAIAPIVer = "2023-03-15-preview"
|
endpoint = viper.GetString(constant.ENV_AZURE_OPENAI_ENDPOINT)
|
||||||
|
openaiModelMapper = viper.GetString(constant.ENV_AZURE_OPENAI_MODEL_MAPPER)
|
||||||
|
if endpoint != "" && openaiModelMapper != "" {
|
||||||
|
if apiVersion == "" {
|
||||||
|
apiVersion = "2023-03-15-preview"
|
||||||
|
}
|
||||||
|
InitFromEnvironmentVariables(apiVersion, endpoint, openaiModelMapper)
|
||||||
|
} else {
|
||||||
|
if err = InitFromConfigFile(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
// ensure apiBase likes /v1
|
||||||
AzureOpenAIEndpointParse, err = url.Parse(AzureOpenAIEndpoint)
|
apiBase := viper.GetString("api_base")
|
||||||
if err != nil {
|
if !strings.HasPrefix(apiBase, "/") {
|
||||||
log.Fatal("parse AzureOpenAIEndpoint error: ", err)
|
apiBase = "/" + apiBase
|
||||||
}
|
}
|
||||||
|
if strings.HasSuffix(apiBase, "/") {
|
||||||
|
apiBase = apiBase[:len(apiBase)-1]
|
||||||
|
}
|
||||||
|
viper.Set("api_base", apiBase)
|
||||||
|
log.Printf("apiBase is: %s", apiBase)
|
||||||
|
for _, itemConfig := range C.DeploymentConfig {
|
||||||
|
u, err := url.Parse(itemConfig.Endpoint)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse endpoint error: %w", err)
|
||||||
|
}
|
||||||
|
itemConfig.EndpointUrl = u
|
||||||
|
ModelDeploymentConfig[itemConfig.ModelName] = itemConfig
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
if v := os.Getenv(constant.ENV_AZURE_OPENAI_MODEL_MAPPER); v != "" {
|
func InitFromEnvironmentVariables(apiVersion, endpoint, openaiModelMapper string) {
|
||||||
for _, pair := range strings.Split(v, ",") {
|
log.Println("Init from environment variables")
|
||||||
|
if openaiModelMapper != "" {
|
||||||
|
// openaiModelMapper example:
|
||||||
|
// gpt-3.5-turbo=deployment_name_for_gpt_model,text-davinci-003=deployment_name_for_davinci_model
|
||||||
|
for _, pair := range strings.Split(openaiModelMapper, ",") {
|
||||||
info := strings.Split(pair, "=")
|
info := strings.Split(pair, "=")
|
||||||
if len(info) != 2 {
|
if len(info) != 2 {
|
||||||
log.Fatalf("error parsing %s, invalid value %s", constant.ENV_AZURE_OPENAI_MODEL_MAPPER, pair)
|
log.Fatalf("error parsing %s, invalid value %s", constant.ENV_AZURE_OPENAI_MODEL_MAPPER, pair)
|
||||||
}
|
}
|
||||||
|
modelName, deploymentName := info[0], info[1]
|
||||||
AzureOpenAIModelMapper[info[0]] = info[1]
|
ModelDeploymentConfig[modelName] = DeploymentConfig{
|
||||||
|
DeploymentName: deploymentName,
|
||||||
|
ModelName: modelName,
|
||||||
|
Endpoint: endpoint,
|
||||||
|
ApiKey: "",
|
||||||
|
ApiVersion: apiVersion,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
log.Println("AzureOpenAIAPIVer: ", AzureOpenAIAPIVer)
|
|
||||||
log.Println("AzureOpenAIEndpoint: ", AzureOpenAIEndpoint)
|
func InitFromConfigFile() error {
|
||||||
log.Println("AzureOpenAIModelMapper: ", AzureOpenAIModelMapper)
|
log.Println("Init from config file")
|
||||||
|
workDir := util.GetWorkdir()
|
||||||
|
viper.SetConfigName("config")
|
||||||
|
viper.SetConfigType("yaml")
|
||||||
|
viper.AddConfigPath(fmt.Sprintf("%s/config", workDir))
|
||||||
|
if err := viper.ReadInConfig(); err != nil {
|
||||||
|
log.Printf("read config file error: %+v\n", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := viper.Unmarshal(&C); err != nil {
|
||||||
|
log.Printf("unmarshal config file error: %+v\n", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, configItem := range C.DeploymentConfig {
|
||||||
|
ModelDeploymentConfig[configItem.ModelName] = configItem
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user