mirror of
https://github.com/stulzq/azure-openai-proxy.git
synced 2025-12-19 15:24:24 +01:00
🧱 update router
This commit is contained in:
@@ -8,7 +8,6 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
"path"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/bytedance/sonic"
|
"github.com/bytedance/sonic"
|
||||||
@@ -16,8 +15,14 @@ import (
|
|||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func ProxyWithConverter(requestConverter RequestConverter) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
Proxy(c, requestConverter)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Proxy Azure OpenAI
|
// Proxy Azure OpenAI
|
||||||
func Proxy(c *gin.Context) {
|
func Proxy(c *gin.Context, requestConverter RequestConverter) {
|
||||||
if c.Request.Method == http.MethodOptions {
|
if c.Request.Method == http.MethodOptions {
|
||||||
c.Header("Access-Control-Allow-Origin", "*")
|
c.Header("Access-Control-Allow-Origin", "*")
|
||||||
c.Header("Access-Control-Allow-Methods", "GET, OPTIONS, POST")
|
c.Header("Access-Control-Allow-Methods", "GET, OPTIONS, POST")
|
||||||
@@ -34,38 +39,48 @@ func Proxy(c *gin.Context) {
|
|||||||
body, _ := io.ReadAll(req.Body)
|
body, _ := io.ReadAll(req.Body)
|
||||||
req.Body = io.NopCloser(bytes.NewBuffer(body))
|
req.Body = io.NopCloser(bytes.NewBuffer(body))
|
||||||
|
|
||||||
// get model from body
|
// get model from url params or body
|
||||||
model, err := sonic.Get(body, "model")
|
model := c.Param("model")
|
||||||
if err != nil {
|
if model == "" {
|
||||||
util.SendError(c, errors.Wrap(err, "get model error"))
|
_model, err := sonic.Get(body, "model")
|
||||||
return
|
if err != nil {
|
||||||
|
util.SendError(c, errors.Wrap(err, "get model error"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_modelStr, err := _model.String()
|
||||||
|
if err != nil {
|
||||||
|
util.SendError(c, errors.Wrap(err, "get model name error"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
model = _modelStr
|
||||||
}
|
}
|
||||||
|
|
||||||
// get deployment from request
|
// get deployment from request
|
||||||
deployment, err := model.String()
|
deployment, err := GetDeploymentByModel(model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.SendError(c, errors.Wrap(err, "get deployment error"))
|
util.SendError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
deployment = GetDeploymentByModel(deployment)
|
|
||||||
|
|
||||||
// get auth token from header
|
// get auth token from header or deployemnt config
|
||||||
rawToken := req.Header.Get("Authorization")
|
token := deployment.ApiKey
|
||||||
token := strings.TrimPrefix(rawToken, "Bearer ")
|
if token == "" {
|
||||||
|
rawToken := req.Header.Get("Authorization")
|
||||||
|
token = strings.TrimPrefix(rawToken, "Bearer ")
|
||||||
|
}
|
||||||
|
if token == "" {
|
||||||
|
util.SendError(c, errors.New("token is empty"))
|
||||||
|
return
|
||||||
|
}
|
||||||
req.Header.Set(AuthHeaderKey, token)
|
req.Header.Set(AuthHeaderKey, token)
|
||||||
req.Header.Del("Authorization")
|
req.Header.Del("Authorization")
|
||||||
|
|
||||||
originURL := req.URL.String()
|
originURL := req.URL.String()
|
||||||
req.Host = AzureOpenAIEndpointParse.Host
|
req, err = requestConverter.Convert(req, deployment)
|
||||||
req.URL.Scheme = AzureOpenAIEndpointParse.Scheme
|
if err != nil {
|
||||||
req.URL.Host = AzureOpenAIEndpointParse.Host
|
util.SendError(c, errors.Wrap(err, "convert request error"))
|
||||||
req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), strings.Replace(req.URL.Path, "/v1/", "/", 1))
|
return
|
||||||
req.URL.RawPath = req.URL.EscapedPath()
|
}
|
||||||
|
|
||||||
query := req.URL.Query()
|
|
||||||
query.Add("api-version", AzureOpenAIAPIVer)
|
|
||||||
req.URL.RawQuery = query.Encode()
|
|
||||||
|
|
||||||
log.Printf("proxying request [%s] %s -> %s", model, originURL, req.URL.String())
|
log.Printf("proxying request [%s] %s -> %s", model, originURL, req.URL.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -80,10 +95,10 @@ func Proxy(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetDeploymentByModel(model string) string {
|
func GetDeploymentByModel(model string) (*DeploymentConfig, error) {
|
||||||
if v, ok := AzureOpenAIModelMapper[model]; ok {
|
deploymentConfig, exist := ModelDeploymentConfig[model]
|
||||||
return v
|
if !exist {
|
||||||
|
return nil, errors.New(fmt.Sprintf("deployment config for %s not found", model))
|
||||||
}
|
}
|
||||||
|
return &deploymentConfig, nil
|
||||||
return fallbackModelMapper.ReplaceAllString(model, "")
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/spf13/viper"
|
||||||
"github.com/stulzq/azure-openai-proxy/azure"
|
"github.com/stulzq/azure-openai-proxy/azure"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -14,7 +15,14 @@ func registerRoute(r *gin.Engine) {
|
|||||||
r.Any("/health", func(c *gin.Context) {
|
r.Any("/health", func(c *gin.Context) {
|
||||||
c.Status(200)
|
c.Status(200)
|
||||||
})
|
})
|
||||||
|
apiBase := viper.GetString("api_base")
|
||||||
r.Any("/v1/*path", azure.Proxy)
|
stripPrefixConverter := azure.NewStripPrefixConverter(apiBase)
|
||||||
|
templateConverter := azure.NewTemplateConverter("/openai/deployments/{{.DeploymentName}}/embeddings")
|
||||||
|
apiBasedRouter := r.Group(apiBase)
|
||||||
|
{
|
||||||
|
apiBasedRouter.Any("/engines/:model/embeddings", azure.ProxyWithConverter(templateConverter))
|
||||||
|
apiBasedRouter.Any("/completions", azure.ProxyWithConverter(stripPrefixConverter))
|
||||||
|
apiBasedRouter.Any("/chat/completions", azure.ProxyWithConverter(stripPrefixConverter))
|
||||||
|
apiBasedRouter.Any("/embeddings", azure.ProxyWithConverter(stripPrefixConverter))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user