From 9ba3a1e89ea5141263ecb58320fc67ef77ee5d60 Mon Sep 17 00:00:00 2001 From: warjiang <1096409085@qq.com> Date: Thu, 15 Jun 2023 12:19:13 +0800 Subject: [PATCH] :bricks: update router --- azure/proxy.go | 71 ++++++++++++++++++++++++++++++-------------------- cmd/router.go | 14 +++++++--- 2 files changed, 54 insertions(+), 31 deletions(-) diff --git a/azure/proxy.go b/azure/proxy.go index 63b85c9..5d73f0f 100644 --- a/azure/proxy.go +++ b/azure/proxy.go @@ -8,7 +8,6 @@ import ( "log" "net/http" "net/http/httputil" - "path" "strings" "github.com/bytedance/sonic" @@ -16,8 +15,14 @@ import ( "github.com/pkg/errors" ) +func ProxyWithConverter(requestConverter RequestConverter) gin.HandlerFunc { + return func(c *gin.Context) { + Proxy(c, requestConverter) + } +} + // Proxy Azure OpenAI -func Proxy(c *gin.Context) { +func Proxy(c *gin.Context, requestConverter RequestConverter) { if c.Request.Method == http.MethodOptions { c.Header("Access-Control-Allow-Origin", "*") c.Header("Access-Control-Allow-Methods", "GET, OPTIONS, POST") @@ -34,38 +39,48 @@ func Proxy(c *gin.Context) { body, _ := io.ReadAll(req.Body) req.Body = io.NopCloser(bytes.NewBuffer(body)) - // get model from body - model, err := sonic.Get(body, "model") - if err != nil { - util.SendError(c, errors.Wrap(err, "get model error")) - return + // get model from url params or body + model := c.Param("model") + if model == "" { + _model, err := sonic.Get(body, "model") + if err != nil { + util.SendError(c, errors.Wrap(err, "get model error")) + return + } + _modelStr, err := _model.String() + if err != nil { + util.SendError(c, errors.Wrap(err, "get model name error")) + return + } + model = _modelStr } // get deployment from request - deployment, err := model.String() + deployment, err := GetDeploymentByModel(model) if err != nil { - util.SendError(c, errors.Wrap(err, "get deployment error")) + util.SendError(c, err) return } - deployment = GetDeploymentByModel(deployment) - // get auth token from header - rawToken := req.Header.Get("Authorization") - token := strings.TrimPrefix(rawToken, "Bearer ") + // get auth token from header or deployemnt config + token := deployment.ApiKey + if token == "" { + rawToken := req.Header.Get("Authorization") + token = strings.TrimPrefix(rawToken, "Bearer ") + } + if token == "" { + util.SendError(c, errors.New("token is empty")) + return + } req.Header.Set(AuthHeaderKey, token) req.Header.Del("Authorization") originURL := req.URL.String() - req.Host = AzureOpenAIEndpointParse.Host - req.URL.Scheme = AzureOpenAIEndpointParse.Scheme - req.URL.Host = AzureOpenAIEndpointParse.Host - req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), strings.Replace(req.URL.Path, "/v1/", "/", 1)) - req.URL.RawPath = req.URL.EscapedPath() - - query := req.URL.Query() - query.Add("api-version", AzureOpenAIAPIVer) - req.URL.RawQuery = query.Encode() - + req, err = requestConverter.Convert(req, deployment) + if err != nil { + util.SendError(c, errors.Wrap(err, "convert request error")) + return + } log.Printf("proxying request [%s] %s -> %s", model, originURL, req.URL.String()) } @@ -80,10 +95,10 @@ func Proxy(c *gin.Context) { } } -func GetDeploymentByModel(model string) string { - if v, ok := AzureOpenAIModelMapper[model]; ok { - return v +func GetDeploymentByModel(model string) (*DeploymentConfig, error) { + deploymentConfig, exist := ModelDeploymentConfig[model] + if !exist { + return nil, errors.New(fmt.Sprintf("deployment config for %s not found", model)) } - - return fallbackModelMapper.ReplaceAllString(model, "") + return &deploymentConfig, nil } diff --git a/cmd/router.go b/cmd/router.go index 727d99a..1b02c5a 100644 --- a/cmd/router.go +++ b/cmd/router.go @@ -2,6 +2,7 @@ package main import ( "github.com/gin-gonic/gin" + "github.com/spf13/viper" "github.com/stulzq/azure-openai-proxy/azure" ) @@ -14,7 +15,14 @@ func registerRoute(r *gin.Engine) { r.Any("/health", func(c *gin.Context) { c.Status(200) }) - - r.Any("/v1/*path", azure.Proxy) - + apiBase := viper.GetString("api_base") + stripPrefixConverter := azure.NewStripPrefixConverter(apiBase) + templateConverter := azure.NewTemplateConverter("/openai/deployments/{{.DeploymentName}}/embeddings") + apiBasedRouter := r.Group(apiBase) + { + apiBasedRouter.Any("/engines/:model/embeddings", azure.ProxyWithConverter(templateConverter)) + apiBasedRouter.Any("/completions", azure.ProxyWithConverter(stripPrefixConverter)) + apiBasedRouter.Any("/chat/completions", azure.ProxyWithConverter(stripPrefixConverter)) + apiBasedRouter.Any("/embeddings", azure.ProxyWithConverter(stripPrefixConverter)) + } }