mirror of https://github.com/stulzq/azure-openai-proxy.git, synced 2025-12-18 14:54:19 +01:00
feat: support multi model
@@ -1,5 +1,6 @@
FROM alpine:3

EXPOSE 8080
COPY ./bin/azure-openai-proxy /usr/bin

ENTRYPOINT ["/usr/bin/azure-openai-proxy"]
43 apis/chat.go
@@ -1,43 +0,0 @@
package apis

import (
	"github.com/gin-gonic/gin"
	"github.com/pkg/errors"
	"github.com/stulzq/azure-openai-proxy/openai"
	"io"
	"strings"
)

// ChatCompletions xxx
// Path: /v1/chat/completions
func ChatCompletions(c *gin.Context) {
	// get auth token from header
	rawToken := c.GetHeader("Authorization")
	token := strings.TrimPrefix(rawToken, "Bearer ")

	reqContent, err := io.ReadAll(c.Request.Body)
	if err != nil {
		SendError(c, errors.Wrap(err, "failed to read request body"))
		return
	}

	oaiResp, err := openai.ChatCompletions(token, reqContent)
	if err != nil {
		SendError(c, errors.Wrap(err, "failed to call Azure OpenAI"))
		return
	}

	// pass-through header
	extraHeaders := map[string]string{}
	for k, v := range oaiResp.Header {
		if _, ok := ignoreHeaders[k]; ok {
			continue
		}

		extraHeaders[k] = strings.Join(v, ",")
	}

	c.DataFromReader(oaiResp.StatusCode, oaiResp.ContentLength, oaiResp.Header.Get("Content-Type"), oaiResp.Response.Body, extraHeaders)

	_, _ = c.Writer.Write([]byte{'\n'}) // add a newline to the end of the response https://github.com/Chanzhaoyu/chatgpt-web/issues/831
}
@@ -1,9 +0,0 @@
package apis

var (
	ignoreHeaders = map[string]int{
		"Content-Type":      1,
		"Transfer-Encoding": 1,
		"Date":              1,
	}
)
58 azure/init.go Normal file
@@ -0,0 +1,58 @@
package azure

import (
	"log"
	"net/url"
	"os"
	"regexp"
	"strings"

	"github.com/stulzq/azure-openai-proxy/constant"
)

const (
	AuthHeaderKey = "api-key"
)

var (
	AzureOpenAIEndpoint      = ""
	AzureOpenAIEndpointParse *url.URL

	AzureOpenAIAPIVer = ""

	AzureOpenAIModelMapper = map[string]string{
		"gpt-3.5-turbo":      "gpt-35-turbo",
		"gpt-3.5-turbo-0301": "gpt-35-turbo-0301",
	}

	fallbackModelMapper = regexp.MustCompile(`[.:]`)
)

func init() {
	AzureOpenAIAPIVer = os.Getenv(constant.ENV_AZURE_OPENAI_API_VER)
	AzureOpenAIEndpoint = os.Getenv(constant.ENV_AZURE_OPENAI_ENDPOINT)

	if AzureOpenAIAPIVer == "" {
		AzureOpenAIAPIVer = "2023-03-15-preview"
	}

	var err error
	AzureOpenAIEndpointParse, err = url.Parse(AzureOpenAIEndpoint)
	if err != nil {
		log.Fatal("parse AzureOpenAIEndpoint error: ", err)
	}

	if v := os.Getenv(constant.ENV_AZURE_OPENAI_MODEL_MAPPER); v != "" {
		for _, pair := range strings.Split(v, ",") {
			info := strings.Split(pair, "=")
			if len(info) != 2 {
				log.Fatalf("error parsing %s, invalid value %s", constant.ENV_AZURE_OPENAI_MODEL_MAPPER, pair)
			}

			AzureOpenAIModelMapper[info[0]] = info[1]
		}
	}

	log.Println("AzureOpenAIAPIVer: ", AzureOpenAIAPIVer)
	log.Println("AzureOpenAIEndpoint: ", AzureOpenAIEndpoint)
	log.Println("AzureOpenAIModelMapper: ", AzureOpenAIModelMapper)
}
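Note: the init above reads AZURE_OPENAI_MODEL_MAPPER as a comma-separated list of model=deployment pairs layered on top of the built-in defaults. Below is a minimal standalone sketch of that parsing; the sample value and deployment names are hypothetical and not taken from this commit.

package main

import (
	"fmt"
	"strings"
)

func main() {
	// Hypothetical value; in the proxy this comes from the AZURE_OPENAI_MODEL_MAPPER env var.
	v := "gpt-3.5-turbo=my-35-deployment,gpt-4=my-gpt4-deployment"

	// Defaults mirrored from azure/init.go.
	mapper := map[string]string{
		"gpt-3.5-turbo":      "gpt-35-turbo",
		"gpt-3.5-turbo-0301": "gpt-35-turbo-0301",
	}
	for _, pair := range strings.Split(v, ",") {
		info := strings.Split(pair, "=")
		if len(info) != 2 {
			panic("invalid mapper pair: " + pair)
		}
		mapper[info[0]] = info[1]
	}

	fmt.Println(mapper["gpt-3.5-turbo"]) // my-35-deployment (env value overrides the default)
	fmt.Println(mapper["gpt-4"])         // my-gpt4-deployment
}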
81 azure/proxy.go Normal file
@@ -0,0 +1,81 @@
package azure

import (
	"bytes"
	"fmt"
	"github.com/bytedance/sonic"
	"github.com/pkg/errors"
	"github.com/stulzq/azure-openai-proxy/util"
	"io"
	"log"
	"net/http"
	"net/http/httputil"
	"path"
	"strings"

	"github.com/gin-gonic/gin"
)

// Proxy Azure OpenAI
func Proxy(c *gin.Context) {
	// improve performance some code from https://github.com/diemus/azure-openai-proxy/blob/main/pkg/azure/proxy.go
	director := func(req *http.Request) {
		if req.Body == nil {
			util.SendError(c, errors.New("request body is empty"))
			return
		}
		body, _ := io.ReadAll(req.Body)
		req.Body = io.NopCloser(bytes.NewBuffer(body))

		// get model from body
		model, err := sonic.Get(body, "model")
		if err != nil {
			util.SendError(c, errors.Wrap(err, "get model error"))
			return
		}

		deployment, err := model.String()
		if err != nil {
			util.SendError(c, errors.Wrap(err, "get deployment error"))
			return
		}
		deployment = GetDeploymentByModel(deployment)

		// get auth token from header
		rawToken := req.Header.Get("Authorization")
		token := strings.TrimPrefix(rawToken, "Bearer ")
		req.Header.Set(AuthHeaderKey, token)
		req.Header.Del("Authorization")

		originURL := req.URL.String()
		req.Host = AzureOpenAIEndpointParse.Host
		req.URL.Scheme = AzureOpenAIEndpointParse.Scheme
		req.URL.Host = AzureOpenAIEndpointParse.Host
		req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), strings.Replace(req.URL.Path, "/v1/", "/", 1))
		req.URL.RawPath = req.URL.EscapedPath()

		query := req.URL.Query()
		query.Add("api-version", AzureOpenAIAPIVer)
		req.URL.RawQuery = query.Encode()

		log.Printf("proxying request [%s] %s -> %s", model, originURL, req.URL.String())
	}

	proxy := &httputil.ReverseProxy{Director: director}
	proxy.ServeHTTP(c.Writer, c.Request)

	// https://github.com/Chanzhaoyu/chatgpt-web/issues/831
	if c.Writer.Header().Get("Content-Type") == "text/event-stream" {
		if _, err := c.Writer.Write([]byte{'\n'}); err != nil {
			log.Printf("rewrite response error: %v", err)
		}
	}
}

func GetDeploymentByModel(model string) string {
	if v, ok := AzureOpenAIModelMapper[model]; ok {
		return v
	}

	return fallbackModelMapper.ReplaceAllString(model, "")
}
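Note: the director above picks the Azure deployment for the requested model (explicit mapper entry first, otherwise the fallback that strips "." and ":") and rewrites the OpenAI-style /v1/... path to the Azure deployments path with an api-version query. A minimal sketch of that rewrite in isolation; the endpoint host and deployment names here are hypothetical.

package main

import (
	"fmt"
	"net/url"
	"path"
	"regexp"
	"strings"
)

var (
	modelMapper         = map[string]string{"gpt-3.5-turbo": "gpt-35-turbo"} // assumed mapping
	fallbackModelMapper = regexp.MustCompile(`[.:]`)
)

func getDeploymentByModel(model string) string {
	if v, ok := modelMapper[model]; ok {
		return v
	}
	// Fallback: strip "." and ":", e.g. a hypothetical "gpt-4.5:beta" becomes "gpt-45beta".
	return fallbackModelMapper.ReplaceAllString(model, "")
}

func main() {
	endpoint, _ := url.Parse("https://example-resource.openai.azure.com") // hypothetical Azure endpoint
	apiVer := "2023-03-15-preview"

	reqURL, _ := url.Parse("http://localhost:8080/v1/chat/completions")
	deployment := getDeploymentByModel("gpt-3.5-turbo")

	reqURL.Scheme = endpoint.Scheme
	reqURL.Host = endpoint.Host
	reqURL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), strings.Replace(reqURL.Path, "/v1/", "/", 1))
	query := reqURL.Query()
	query.Add("api-version", apiVer)
	reqURL.RawQuery = query.Encode()

	fmt.Println(reqURL.String())
	// https://example-resource.openai.azure.com/openai/deployments/gpt-35-turbo/chat/completions?api-version=2023-03-15-preview
}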
@@ -4,7 +4,6 @@ import (
	"context"
	"github.com/gin-gonic/gin"
	"github.com/pkg/errors"
	"github.com/stulzq/azure-openai-proxy/openai"
	"log"
	"net/http"
	"os"
@@ -13,8 +12,6 @@ import (
)

func main() {
	openai.Init()

	gin.SetMode(gin.ReleaseMode)
	r := gin.Default()
	registerRoute(r)
@@ -2,9 +2,11 @@ package main

import (
	"github.com/gin-gonic/gin"
	"github.com/stulzq/azure-openai-proxy/apis"
	"github.com/stulzq/azure-openai-proxy/azure"
)

// registerRoute registers all routes
func registerRoute(r *gin.Engine) {
	r.POST("/v1/chat/completions", apis.ChatCompletions)
	// https://platform.openai.com/docs/api-reference
	r.Any("*path", azure.Proxy)
}
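Note: with the catch-all route above, any OpenAI-style request is forwarded through azure.Proxy. A minimal client sketch against a proxy assumed to be listening on localhost:8080; the API key and model name are placeholders.

package main

import (
	"fmt"
	"io"
	"net/http"
	"strings"
)

func main() {
	body := `{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"Hello"}]}`

	req, err := http.NewRequest(http.MethodPost, "http://localhost:8080/v1/chat/completions", strings.NewReader(body))
	if err != nil {
		panic(err)
	}
	req.Header.Set("Content-Type", "application/json")
	// The proxy moves this bearer token into the api-key header expected by Azure.
	req.Header.Set("Authorization", "Bearer <your-azure-openai-api-key>")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	data, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.StatusCode, string(data))
}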
@@ -1,7 +1,7 @@
package constant

const (
	ENV_AZURE_OPENAI_ENDPOINT     = "AZURE_OPENAI_ENDPOINT"
	ENV_AZURE_OPENAI_API_VER      = "AZURE_OPENAI_API_VER"
	ENV_AZURE_OPENAI_DEPLOY       = "AZURE_OPENAI_DEPLOY"
	ENV_AZURE_OPENAI_ENDPOINT     = "AZURE_OPENAI_ENDPOINT"
	ENV_AZURE_OPENAI_API_VER      = "AZURE_OPENAI_API_VER"
	ENV_AZURE_OPENAI_MODEL_MAPPER = "AZURE_OPENAI_MODEL_MAPPER"
)
3 go.mod
@@ -35,6 +35,9 @@ require (
	github.com/quic-go/qtls-go1-19 v0.2.0 // indirect
	github.com/quic-go/qtls-go1-20 v0.1.0 // indirect
	github.com/quic-go/quic-go v0.32.0 // indirect
	github.com/tidwall/gjson v1.14.4 // indirect
	github.com/tidwall/match v1.1.1 // indirect
	github.com/tidwall/pretty v1.2.0 // indirect
	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
	github.com/ugorji/go/codec v1.2.11 // indirect
	golang.org/x/arch v0.3.0 // indirect
6 go.sum
@@ -95,6 +95,12 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM=
github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
@@ -1,11 +0,0 @@
package openai

import "github.com/imroc/req/v3"

func ChatCompletions(token string, body []byte) (*req.Response, error) {
	return client.R().
		SetHeader("Content-Type", "application/json").
		SetHeader(AuthHeaderKey, token).
		SetBodyBytes(body).
		Post(ChatCompletionsUrl)
}
@@ -1,20 +0,0 @@
package openai

import (
	"fmt"
	"github.com/stulzq/azure-openai-proxy/constant"
	"log"
	"os"
)

func Init() {
	AzureOpenAIAPIVer = os.Getenv(constant.ENV_AZURE_OPENAI_API_VER)
	AzureOpenAIDeploy = os.Getenv(constant.ENV_AZURE_OPENAI_DEPLOY)
	AzureOpenAIEndpoint = os.Getenv(constant.ENV_AZURE_OPENAI_ENDPOINT)

	log.Println("AzureOpenAIAPIVer: ", AzureOpenAIAPIVer)
	log.Println("AzureOpenAIDeploy: ", AzureOpenAIDeploy)
	log.Println("AzureOpenAIEndpoint: ", AzureOpenAIEndpoint)

	ChatCompletionsUrl = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", AzureOpenAIEndpoint, AzureOpenAIDeploy, AzureOpenAIAPIVer)
}
@@ -1,17 +0,0 @@
package openai

import "github.com/imroc/req/v3"

const (
	AuthHeaderKey = "api-key"
)

var (
	AzureOpenAIEndpoint = ""
	AzureOpenAIAPIVer   = ""
	AzureOpenAIDeploy   = ""

	ChatCompletionsUrl = ""

	client = req.C()
)
@@ -1,6 +1,8 @@
package apis
package util

import "github.com/gin-gonic/gin"
import (
	"github.com/gin-gonic/gin"
)

func SendError(c *gin.Context, err error) {
	c.JSON(500, ApiResponse{
@@ -1,4 +1,4 @@
package apis
package util

type ApiResponse struct {
	Error ErrorDescription `json:"error"`