🧱 update router

This commit is contained in:
warjiang
2023-06-15 12:19:13 +08:00
parent 932eefebfd
commit 9ba3a1e89e
2 changed files with 54 additions and 31 deletions

View File

@@ -8,7 +8,6 @@ import (
"log" "log"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"path"
"strings" "strings"
"github.com/bytedance/sonic" "github.com/bytedance/sonic"
@@ -16,8 +15,14 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
func ProxyWithConverter(requestConverter RequestConverter) gin.HandlerFunc {
return func(c *gin.Context) {
Proxy(c, requestConverter)
}
}
// Proxy Azure OpenAI // Proxy Azure OpenAI
func Proxy(c *gin.Context) { func Proxy(c *gin.Context, requestConverter RequestConverter) {
if c.Request.Method == http.MethodOptions { if c.Request.Method == http.MethodOptions {
c.Header("Access-Control-Allow-Origin", "*") c.Header("Access-Control-Allow-Origin", "*")
c.Header("Access-Control-Allow-Methods", "GET, OPTIONS, POST") c.Header("Access-Control-Allow-Methods", "GET, OPTIONS, POST")
@@ -34,38 +39,48 @@ func Proxy(c *gin.Context) {
body, _ := io.ReadAll(req.Body) body, _ := io.ReadAll(req.Body)
req.Body = io.NopCloser(bytes.NewBuffer(body)) req.Body = io.NopCloser(bytes.NewBuffer(body))
// get model from body // get model from url params or body
model, err := sonic.Get(body, "model") model := c.Param("model")
if err != nil { if model == "" {
util.SendError(c, errors.Wrap(err, "get model error")) _model, err := sonic.Get(body, "model")
return if err != nil {
util.SendError(c, errors.Wrap(err, "get model error"))
return
}
_modelStr, err := _model.String()
if err != nil {
util.SendError(c, errors.Wrap(err, "get model name error"))
return
}
model = _modelStr
} }
// get deployment from request // get deployment from request
deployment, err := model.String() deployment, err := GetDeploymentByModel(model)
if err != nil { if err != nil {
util.SendError(c, errors.Wrap(err, "get deployment error")) util.SendError(c, err)
return return
} }
deployment = GetDeploymentByModel(deployment)
// get auth token from header // get auth token from header or deployemnt config
rawToken := req.Header.Get("Authorization") token := deployment.ApiKey
token := strings.TrimPrefix(rawToken, "Bearer ") if token == "" {
rawToken := req.Header.Get("Authorization")
token = strings.TrimPrefix(rawToken, "Bearer ")
}
if token == "" {
util.SendError(c, errors.New("token is empty"))
return
}
req.Header.Set(AuthHeaderKey, token) req.Header.Set(AuthHeaderKey, token)
req.Header.Del("Authorization") req.Header.Del("Authorization")
originURL := req.URL.String() originURL := req.URL.String()
req.Host = AzureOpenAIEndpointParse.Host req, err = requestConverter.Convert(req, deployment)
req.URL.Scheme = AzureOpenAIEndpointParse.Scheme if err != nil {
req.URL.Host = AzureOpenAIEndpointParse.Host util.SendError(c, errors.Wrap(err, "convert request error"))
req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), strings.Replace(req.URL.Path, "/v1/", "/", 1)) return
req.URL.RawPath = req.URL.EscapedPath() }
query := req.URL.Query()
query.Add("api-version", AzureOpenAIAPIVer)
req.URL.RawQuery = query.Encode()
log.Printf("proxying request [%s] %s -> %s", model, originURL, req.URL.String()) log.Printf("proxying request [%s] %s -> %s", model, originURL, req.URL.String())
} }
@@ -80,10 +95,10 @@ func Proxy(c *gin.Context) {
} }
} }
func GetDeploymentByModel(model string) string { func GetDeploymentByModel(model string) (*DeploymentConfig, error) {
if v, ok := AzureOpenAIModelMapper[model]; ok { deploymentConfig, exist := ModelDeploymentConfig[model]
return v if !exist {
return nil, errors.New(fmt.Sprintf("deployment config for %s not found", model))
} }
return &deploymentConfig, nil
return fallbackModelMapper.ReplaceAllString(model, "")
} }

View File

@@ -2,6 +2,7 @@ package main
import ( import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/spf13/viper"
"github.com/stulzq/azure-openai-proxy/azure" "github.com/stulzq/azure-openai-proxy/azure"
) )
@@ -14,7 +15,14 @@ func registerRoute(r *gin.Engine) {
r.Any("/health", func(c *gin.Context) { r.Any("/health", func(c *gin.Context) {
c.Status(200) c.Status(200)
}) })
apiBase := viper.GetString("api_base")
r.Any("/v1/*path", azure.Proxy) stripPrefixConverter := azure.NewStripPrefixConverter(apiBase)
templateConverter := azure.NewTemplateConverter("/openai/deployments/{{.DeploymentName}}/embeddings")
apiBasedRouter := r.Group(apiBase)
{
apiBasedRouter.Any("/engines/:model/embeddings", azure.ProxyWithConverter(templateConverter))
apiBasedRouter.Any("/completions", azure.ProxyWithConverter(stripPrefixConverter))
apiBasedRouter.Any("/chat/completions", azure.ProxyWithConverter(stripPrefixConverter))
apiBasedRouter.Any("/embeddings", azure.ProxyWithConverter(stripPrefixConverter))
}
} }