feat: support multi model

Zhiqiang Li
2023-03-28 14:34:21 +08:00
parent 7b9dff64c7
commit 1968022907
16 changed files with 162 additions and 112 deletions

@@ -1,5 +1,6 @@
FROM alpine:3
+EXPOSE 8080
COPY ./bin/azure-openai-proxy /usr/bin
ENTRYPOINT ["/usr/bin/azure-openai-proxy"]

@@ -1,43 +0,0 @@
package apis

import (
	"github.com/gin-gonic/gin"
	"github.com/pkg/errors"
	"github.com/stulzq/azure-openai-proxy/openai"
	"io"
	"strings"
)

// ChatCompletions xxx
// Path: /v1/chat/completions
func ChatCompletions(c *gin.Context) {
	// get auth token from header
	rawToken := c.GetHeader("Authorization")
	token := strings.TrimPrefix(rawToken, "Bearer ")

	reqContent, err := io.ReadAll(c.Request.Body)
	if err != nil {
		SendError(c, errors.Wrap(err, "failed to read request body"))
		return
	}

	oaiResp, err := openai.ChatCompletions(token, reqContent)
	if err != nil {
		SendError(c, errors.Wrap(err, "failed to call Azure OpenAI"))
		return
	}

	// pass-through header
	extraHeaders := map[string]string{}
	for k, v := range oaiResp.Header {
		if _, ok := ignoreHeaders[k]; ok {
			continue
		}
		extraHeaders[k] = strings.Join(v, ",")
	}

	c.DataFromReader(oaiResp.StatusCode, oaiResp.ContentLength, oaiResp.Header.Get("Content-Type"), oaiResp.Response.Body, extraHeaders)
	_, _ = c.Writer.Write([]byte{'\n'}) // add a newline to the end of the response https://github.com/Chanzhaoyu/chatgpt-web/issues/831
}

@@ -1,9 +0,0 @@
package apis

var (
	ignoreHeaders = map[string]int{
		"Content-Type":      1,
		"Transfer-Encoding": 1,
		"Date":              1,
	}
)

azure/init.go (new file)

@@ -0,0 +1,58 @@
package azure

import (
	"log"
	"net/url"
	"os"
	"regexp"
	"strings"

	"github.com/stulzq/azure-openai-proxy/constant"
)

const (
	AuthHeaderKey = "api-key"
)

var (
	AzureOpenAIEndpoint      = ""
	AzureOpenAIEndpointParse *url.URL
	AzureOpenAIAPIVer        = ""

	AzureOpenAIModelMapper = map[string]string{
		"gpt-3.5-turbo":      "gpt-35-turbo",
		"gpt-3.5-turbo-0301": "gpt-35-turbo-0301",
	}

	fallbackModelMapper = regexp.MustCompile(`[.:]`)
)

func init() {
	AzureOpenAIAPIVer = os.Getenv(constant.ENV_AZURE_OPENAI_API_VER)
	AzureOpenAIEndpoint = os.Getenv(constant.ENV_AZURE_OPENAI_ENDPOINT)

	if AzureOpenAIAPIVer == "" {
		AzureOpenAIAPIVer = "2023-03-15-preview"
	}

	var err error
	AzureOpenAIEndpointParse, err = url.Parse(AzureOpenAIEndpoint)
	if err != nil {
		log.Fatal("parse AzureOpenAIEndpoint error: ", err)
	}

	if v := os.Getenv(constant.ENV_AZURE_OPENAI_MODEL_MAPPER); v != "" {
		for _, pair := range strings.Split(v, ",") {
			info := strings.Split(pair, "=")
			if len(info) != 2 {
				log.Fatalf("error parsing %s, invalid value %s", constant.ENV_AZURE_OPENAI_MODEL_MAPPER, pair)
			}
			AzureOpenAIModelMapper[info[0]] = info[1]
		}
	}

	log.Println("AzureOpenAIAPIVer: ", AzureOpenAIAPIVer)
	log.Println("AzureOpenAIEndpoint: ", AzureOpenAIEndpoint)
	log.Println("AzureOpenAIModelMapper: ", AzureOpenAIModelMapper)
}
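For reference, init parses AZURE_OPENAI_MODEL_MAPPER as comma-separated model=deployment pairs, so a value like the following (the deployment names here are hypothetical) maps two models at once:

AZURE_OPENAI_MODEL_MAPPER=gpt-3.5-turbo=my-gpt35-deploy,gpt-4=my-gpt4-deploy

Any model without an entry falls through to fallbackModelMapper in proxy.go below.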

azure/proxy.go (new file)

@@ -0,0 +1,81 @@
package azure

import (
	"bytes"
	"fmt"
	"github.com/bytedance/sonic"
	"github.com/pkg/errors"
	"github.com/stulzq/azure-openai-proxy/util"
	"io"
	"log"
	"net/http"
	"net/http/httputil"
	"path"
	"strings"

	"github.com/gin-gonic/gin"
)

// Proxy forwards OpenAI-style requests to Azure OpenAI
func Proxy(c *gin.Context) {
	// for performance, some code adapted from https://github.com/diemus/azure-openai-proxy/blob/main/pkg/azure/proxy.go
	director := func(req *http.Request) {
		if req.Body == nil {
			util.SendError(c, errors.New("request body is empty"))
			return
		}
		body, _ := io.ReadAll(req.Body)
		req.Body = io.NopCloser(bytes.NewBuffer(body))

		// get model from body
		model, err := sonic.Get(body, "model")
		if err != nil {
			util.SendError(c, errors.Wrap(err, "get model error"))
			return
		}
		deployment, err := model.String()
		if err != nil {
			util.SendError(c, errors.Wrap(err, "get deployment error"))
			return
		}
		deployment = GetDeploymentByModel(deployment)

		// get auth token from header
		rawToken := req.Header.Get("Authorization")
		token := strings.TrimPrefix(rawToken, "Bearer ")
		req.Header.Set(AuthHeaderKey, token)
		req.Header.Del("Authorization")

		originURL := req.URL.String()
		req.Host = AzureOpenAIEndpointParse.Host
		req.URL.Scheme = AzureOpenAIEndpointParse.Scheme
		req.URL.Host = AzureOpenAIEndpointParse.Host
		req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), strings.Replace(req.URL.Path, "/v1/", "/", 1))
		req.URL.RawPath = req.URL.EscapedPath()

		query := req.URL.Query()
		query.Add("api-version", AzureOpenAIAPIVer)
		req.URL.RawQuery = query.Encode()

		log.Printf("proxying request [%s] %s -> %s", deployment, originURL, req.URL.String())
	}

	proxy := &httputil.ReverseProxy{Director: director}
	proxy.ServeHTTP(c.Writer, c.Request)

	// append a trailing newline to streaming responses, see https://github.com/Chanzhaoyu/chatgpt-web/issues/831
	if c.Writer.Header().Get("Content-Type") == "text/event-stream" {
		if _, err := c.Writer.Write([]byte{'\n'}); err != nil {
			log.Printf("rewrite response error: %v", err)
		}
	}
}

// GetDeploymentByModel maps an OpenAI model name to an Azure deployment name,
// falling back to stripping "." and ":" from the model name.
func GetDeploymentByModel(model string) string {
	if v, ok := AzureOpenAIModelMapper[model]; ok {
		return v
	}
	return fallbackModelMapper.ReplaceAllString(model, "")
}
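To illustrate the rewrite the director performs, a minimal sketch assuming the default mapper, the default API version, and a hypothetical endpoint https://example.openai.azure.com:

// GetDeploymentByModel("gpt-3.5-turbo") == "gpt-35-turbo"  (explicit mapper entry)
// GetDeploymentByModel("gpt-4.5-demo")  == "gpt-45-demo"   (fallback strips "." and ":")
//
// POST /v1/chat/completions with body {"model": "gpt-3.5-turbo"} is proxied to
// POST https://example.openai.azure.com/openai/deployments/gpt-35-turbo/chat/completions?api-version=2023-03-15-preview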

@@ -2,7 +2,7 @@
set -e
-VERSION=v1.0.0
+VERSION=v1.1.0
rm -rf bin

@@ -4,7 +4,6 @@ import (
	"context"
	"github.com/gin-gonic/gin"
	"github.com/pkg/errors"
-	"github.com/stulzq/azure-openai-proxy/openai"
	"log"
	"net/http"
	"os"
@@ -13,8 +12,6 @@ import (
)

func main() {
-	openai.Init()
	gin.SetMode(gin.ReleaseMode)
	r := gin.Default()
	registerRoute(r)

@@ -2,9 +2,11 @@ package main
import (
	"github.com/gin-gonic/gin"
-	"github.com/stulzq/azure-openai-proxy/apis"
+	"github.com/stulzq/azure-openai-proxy/azure"
)

// registerRoute registers all routes
func registerRoute(r *gin.Engine) {
-	r.POST("/v1/chat/completions", apis.ChatCompletions)
+	// https://platform.openai.com/docs/api-reference
+	r.Any("*path", azure.Proxy)
}
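Since every path now funnels through azure.Proxy, any OpenAI-compatible client can simply point at the proxy. A minimal client sketch, assuming the proxy listens on localhost:8080 (the port exposed in the Dockerfile) and with <your-azure-api-key> as a placeholder; the proxy rewrites the standard Authorization header into Azure's api-key header:

package main

import (
	"fmt"
	"io"
	"net/http"
	"strings"
)

func main() {
	body := strings.NewReader(`{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"hello"}]}`)
	req, _ := http.NewRequest("POST", "http://localhost:8080/v1/chat/completions", body)
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer <your-azure-api-key>") // converted to api-key by the proxy

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	out, _ := io.ReadAll(resp.Body)
	fmt.Println(string(out))
}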

@@ -1,7 +1,7 @@
package constant

const (
-	ENV_AZURE_OPENAI_ENDPOINT = "AZURE_OPENAI_ENDPOINT"
-	ENV_AZURE_OPENAI_API_VER  = "AZURE_OPENAI_API_VER"
-	ENV_AZURE_OPENAI_DEPLOY   = "AZURE_OPENAI_DEPLOY"
+	ENV_AZURE_OPENAI_ENDPOINT     = "AZURE_OPENAI_ENDPOINT"
+	ENV_AZURE_OPENAI_API_VER      = "AZURE_OPENAI_API_VER"
+	ENV_AZURE_OPENAI_MODEL_MAPPER = "AZURE_OPENAI_MODEL_MAPPER"
)

go.mod

@@ -35,6 +35,9 @@ require (
	github.com/quic-go/qtls-go1-19 v0.2.0 // indirect
	github.com/quic-go/qtls-go1-20 v0.1.0 // indirect
	github.com/quic-go/quic-go v0.32.0 // indirect
+	github.com/tidwall/gjson v1.14.4 // indirect
+	github.com/tidwall/match v1.1.1 // indirect
+	github.com/tidwall/pretty v1.2.0 // indirect
	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
	github.com/ugorji/go/codec v1.2.11 // indirect
	golang.org/x/arch v0.3.0 // indirect

go.sum

@@ -95,6 +95,12 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
+github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM=
+github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
+github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
+github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
+github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
+github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=

@@ -1,11 +0,0 @@
package openai

import "github.com/imroc/req/v3"

func ChatCompletions(token string, body []byte) (*req.Response, error) {
	return client.R().
		SetHeader("Content-Type", "application/json").
		SetHeader(AuthHeaderKey, token).
		SetBodyBytes(body).
		Post(ChatCompletionsUrl)
}

@@ -1,20 +0,0 @@
package openai

import (
	"fmt"
	"github.com/stulzq/azure-openai-proxy/constant"
	"log"
	"os"
)

func Init() {
	AzureOpenAIAPIVer = os.Getenv(constant.ENV_AZURE_OPENAI_API_VER)
	AzureOpenAIDeploy = os.Getenv(constant.ENV_AZURE_OPENAI_DEPLOY)
	AzureOpenAIEndpoint = os.Getenv(constant.ENV_AZURE_OPENAI_ENDPOINT)

	log.Println("AzureOpenAIAPIVer: ", AzureOpenAIAPIVer)
	log.Println("AzureOpenAIDeploy: ", AzureOpenAIDeploy)
	log.Println("AzureOpenAIEndpoint: ", AzureOpenAIEndpoint)

	ChatCompletionsUrl = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", AzureOpenAIEndpoint, AzureOpenAIDeploy, AzureOpenAIAPIVer)
}

@@ -1,17 +0,0 @@
package openai

import "github.com/imroc/req/v3"

const (
	AuthHeaderKey = "api-key"
)

var (
	AzureOpenAIEndpoint = ""
	AzureOpenAIAPIVer   = ""
	AzureOpenAIDeploy   = ""

	ChatCompletionsUrl = ""

	client = req.C()
)

@@ -1,6 +1,8 @@
-package apis
+package util

-import "github.com/gin-gonic/gin"
+import (
+	"github.com/gin-gonic/gin"
+)

func SendError(c *gin.Context, err error) {
	c.JSON(500, ApiResponse{

@@ -1,4 +1,4 @@
-package apis
+package util

type ApiResponse struct {
	Error ErrorDescription `json:"error"`