mirror of https://github.com/stulzq/azure-openai-proxy.git (synced 2025-12-18 23:04:19 +01:00)
feat: support multi model
Dockerfile
@@ -1,5 +1,6 @@
 FROM alpine:3
 
+EXPOSE 8080
 COPY ./bin/azure-openai-proxy /usr/bin
 
 ENTRYPOINT ["/usr/bin/azure-openai-proxy"]
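For context, a minimal run sketch for the port exposed above; the image tag is a placeholder for whatever you build locally from this Dockerfile, and the endpoint value is only an example, not part of this commit:

    # build the binary and image first (see build.sh), then publish the proxy on 8080
    docker build -t azure-openai-proxy:local .
    docker run -d -p 8080:8080 \
      -e AZURE_OPENAI_ENDPOINT="https://example-resource.openai.azure.com/" \
      azure-openai-proxy:local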
43  apis/chat.go
@@ -1,43 +0,0 @@
-package apis
-
-import (
-	"github.com/gin-gonic/gin"
-	"github.com/pkg/errors"
-	"github.com/stulzq/azure-openai-proxy/openai"
-	"io"
-	"strings"
-)
-
-// ChatCompletions xxx
-// Path: /v1/chat/completions
-func ChatCompletions(c *gin.Context) {
-	// get auth token from header
-	rawToken := c.GetHeader("Authorization")
-	token := strings.TrimPrefix(rawToken, "Bearer ")
-
-	reqContent, err := io.ReadAll(c.Request.Body)
-	if err != nil {
-		SendError(c, errors.Wrap(err, "failed to read request body"))
-		return
-	}
-
-	oaiResp, err := openai.ChatCompletions(token, reqContent)
-	if err != nil {
-		SendError(c, errors.Wrap(err, "failed to call Azure OpenAI"))
-		return
-	}
-
-	// pass-through header
-	extraHeaders := map[string]string{}
-	for k, v := range oaiResp.Header {
-		if _, ok := ignoreHeaders[k]; ok {
-			continue
-		}
-
-		extraHeaders[k] = strings.Join(v, ",")
-	}
-
-	c.DataFromReader(oaiResp.StatusCode, oaiResp.ContentLength, oaiResp.Header.Get("Content-Type"), oaiResp.Response.Body, extraHeaders)
-
-	_, _ = c.Writer.Write([]byte{'\n'}) // add a newline to the end of the response https://github.com/Chanzhaoyu/chatgpt-web/issues/831
-}
@@ -1,9 +0,0 @@
-package apis
-
-var (
-	ignoreHeaders = map[string]int{
-		"Content-Type":      1,
-		"Transfer-Encoding": 1,
-		"Date":              1,
-	}
-)
58  azure/init.go  (new file)
@@ -0,0 +1,58 @@
+package azure
+
+import (
+	"log"
+	"net/url"
+	"os"
+	"regexp"
+	"strings"
+
+	"github.com/stulzq/azure-openai-proxy/constant"
+)
+
+const (
+	AuthHeaderKey = "api-key"
+)
+
+var (
+	AzureOpenAIEndpoint      = ""
+	AzureOpenAIEndpointParse *url.URL
+
+	AzureOpenAIAPIVer = ""
+
+	AzureOpenAIModelMapper = map[string]string{
+		"gpt-3.5-turbo":      "gpt-35-turbo",
+		"gpt-3.5-turbo-0301": "gpt-35-turbo-0301",
+	}
+	fallbackModelMapper = regexp.MustCompile(`[.:]`)
+)
+
+func init() {
+	AzureOpenAIAPIVer = os.Getenv(constant.ENV_AZURE_OPENAI_API_VER)
+	AzureOpenAIEndpoint = os.Getenv(constant.ENV_AZURE_OPENAI_ENDPOINT)
+
+	if AzureOpenAIAPIVer == "" {
+		AzureOpenAIAPIVer = "2023-03-15-preview"
+	}
+
+	var err error
+	AzureOpenAIEndpointParse, err = url.Parse(AzureOpenAIEndpoint)
+	if err != nil {
+		log.Fatal("parse AzureOpenAIEndpoint error: ", err)
+	}
+
+	if v := os.Getenv(constant.ENV_AZURE_OPENAI_MODEL_MAPPER); v != "" {
+		for _, pair := range strings.Split(v, ",") {
+			info := strings.Split(pair, "=")
+			if len(info) != 2 {
+				log.Fatalf("error parsing %s, invalid value %s", constant.ENV_AZURE_OPENAI_MODEL_MAPPER, pair)
+			}
+
+			AzureOpenAIModelMapper[info[0]] = info[1]
+		}
+	}
+
+	log.Println("AzureOpenAIAPIVer: ", AzureOpenAIAPIVer)
+	log.Println("AzureOpenAIEndpoint: ", AzureOpenAIEndpoint)
+	log.Println("AzureOpenAIModelMapper: ", AzureOpenAIModelMapper)
+}
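A configuration sketch implied by the parser above: AZURE_OPENAI_MODEL_MAPPER is a comma-separated list of model=deployment pairs that extends (or overrides) the built-in gpt-3.5-turbo mappings; the endpoint and deployment names below are placeholders:

    export AZURE_OPENAI_ENDPOINT="https://example-resource.openai.azure.com/"
    export AZURE_OPENAI_API_VER="2023-03-15-preview"   # optional; this is the default applied above
    # one model=deployment pair per entry, comma separated
    export AZURE_OPENAI_MODEL_MAPPER="gpt-3.5-turbo=my-gpt35-deployment,gpt-4=my-gpt4-deployment"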
81  azure/proxy.go  (new file)
@@ -0,0 +1,81 @@
+package azure
+
+import (
+	"bytes"
+	"fmt"
+	"github.com/bytedance/sonic"
+	"github.com/pkg/errors"
+	"github.com/stulzq/azure-openai-proxy/util"
+	"io"
+	"log"
+	"net/http"
+	"net/http/httputil"
+	"path"
+	"strings"
+
+	"github.com/gin-gonic/gin"
+)
+
+// Proxy Azure OpenAI
+func Proxy(c *gin.Context) {
+	// improve performance some code from https://github.com/diemus/azure-openai-proxy/blob/main/pkg/azure/proxy.go
+	director := func(req *http.Request) {
+		if req.Body == nil {
+			util.SendError(c, errors.New("request body is empty"))
+			return
+		}
+		body, _ := io.ReadAll(req.Body)
+		req.Body = io.NopCloser(bytes.NewBuffer(body))
+
+		// get model from body
+		model, err := sonic.Get(body, "model")
+		if err != nil {
+			util.SendError(c, errors.Wrap(err, "get model error"))
+			return
+		}
+
+		deployment, err := model.String()
+		if err != nil {
+			util.SendError(c, errors.Wrap(err, "get deployment error"))
+			return
+		}
+		deployment = GetDeploymentByModel(deployment)
+
+		// get auth token from header
+		rawToken := req.Header.Get("Authorization")
+		token := strings.TrimPrefix(rawToken, "Bearer ")
+		req.Header.Set(AuthHeaderKey, token)
+		req.Header.Del("Authorization")
+
+		originURL := req.URL.String()
+		req.Host = AzureOpenAIEndpointParse.Host
+		req.URL.Scheme = AzureOpenAIEndpointParse.Scheme
+		req.URL.Host = AzureOpenAIEndpointParse.Host
+		req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), strings.Replace(req.URL.Path, "/v1/", "/", 1))
+		req.URL.RawPath = req.URL.EscapedPath()
+
+		query := req.URL.Query()
+		query.Add("api-version", AzureOpenAIAPIVer)
+		req.URL.RawQuery = query.Encode()
+
+		log.Printf("proxying request [%s] %s -> %s", model, originURL, req.URL.String())
+	}
+
+	proxy := &httputil.ReverseProxy{Director: director}
+	proxy.ServeHTTP(c.Writer, c.Request)
+
+	// https://github.com/Chanzhaoyu/chatgpt-web/issues/831
+	if c.Writer.Header().Get("Content-Type") == "text/event-stream" {
+		if _, err := c.Writer.Write([]byte{'\n'}); err != nil {
+			log.Printf("rewrite response error: %v", err)
+		}
+	}
+}
+
+func GetDeploymentByModel(model string) string {
+	if v, ok := AzureOpenAIModelMapper[model]; ok {
+		return v
+	}
+
+	return fallbackModelMapper.ReplaceAllString(model, "")
+}
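A request sketch for the director above, with endpoint, key, and deployment names as placeholders: the client keeps calling the OpenAI-style path with a Bearer key, and the proxy moves the key into the api-key header, picks the deployment from the "model" field, and rewrites the URL onto the Azure endpoint:

    # client side: unchanged OpenAI-style call against the proxy (port 8080 as exposed by the Dockerfile)
    curl http://localhost:8080/v1/chat/completions \
      -H "Authorization: Bearer <azure-api-key>" \
      -H "Content-Type: application/json" \
      -d '{"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello"}]}'

    # upstream request produced by the director (gpt-3.5-turbo -> gpt-35-turbo via the default mapper):
    #   POST https://<your-endpoint>/openai/deployments/gpt-35-turbo/chat/completions?api-version=2023-03-15-preview
    #   api-key: <azure-api-key>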
2  build.sh
@@ -2,7 +2,7 @@
 
 set -e
 
-VERSION=v1.0.0
+VERSION=v1.1.0
 
 rm -rf bin
 
@@ -4,7 +4,6 @@ import (
 	"context"
 	"github.com/gin-gonic/gin"
 	"github.com/pkg/errors"
-	"github.com/stulzq/azure-openai-proxy/openai"
 	"log"
 	"net/http"
 	"os"
@@ -13,8 +12,6 @@ import (
 )
 
 func main() {
-	openai.Init()
-
 	gin.SetMode(gin.ReleaseMode)
 	r := gin.Default()
 	registerRoute(r)
@@ -2,9 +2,11 @@ package main
 
 import (
 	"github.com/gin-gonic/gin"
-	"github.com/stulzq/azure-openai-proxy/apis"
+	"github.com/stulzq/azure-openai-proxy/azure"
 )
 
+// registerRoute registers all routes
 func registerRoute(r *gin.Engine) {
-	r.POST("/v1/chat/completions", apis.ChatCompletions)
+	// https://platform.openai.com/docs/api-reference
+	r.Any("*path", azure.Proxy)
 }
@@ -1,7 +1,7 @@
 package constant
 
 const (
 	ENV_AZURE_OPENAI_ENDPOINT = "AZURE_OPENAI_ENDPOINT"
 	ENV_AZURE_OPENAI_API_VER  = "AZURE_OPENAI_API_VER"
-	ENV_AZURE_OPENAI_DEPLOY   = "AZURE_OPENAI_DEPLOY"
+	ENV_AZURE_OPENAI_MODEL_MAPPER = "AZURE_OPENAI_MODEL_MAPPER"
 )
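A migration sketch for this rename: a setup that previously pinned one deployment via AZURE_OPENAI_DEPLOY expresses the same thing as a single mapper entry (the deployment name is a placeholder):

    # before: one fixed deployment for every request
    export AZURE_OPENAI_DEPLOY="my-gpt35-deployment"
    # after: per-model mapping, so multiple models/deployments can coexist
    export AZURE_OPENAI_MODEL_MAPPER="gpt-3.5-turbo=my-gpt35-deployment"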
3  go.mod
@@ -35,6 +35,9 @@ require (
 	github.com/quic-go/qtls-go1-19 v0.2.0 // indirect
 	github.com/quic-go/qtls-go1-20 v0.1.0 // indirect
 	github.com/quic-go/quic-go v0.32.0 // indirect
+	github.com/tidwall/gjson v1.14.4 // indirect
+	github.com/tidwall/match v1.1.1 // indirect
+	github.com/tidwall/pretty v1.2.0 // indirect
 	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
 	github.com/ugorji/go/codec v1.2.11 // indirect
 	golang.org/x/arch v0.3.0 // indirect
6  go.sum
@@ -95,6 +95,12 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
 github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
 github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
 github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
+github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM=
+github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
+github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
+github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
+github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
+github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
 github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
 github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
 github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
@@ -1,11 +0,0 @@
-package openai
-
-import "github.com/imroc/req/v3"
-
-func ChatCompletions(token string, body []byte) (*req.Response, error) {
-	return client.R().
-		SetHeader("Content-Type", "application/json").
-		SetHeader(AuthHeaderKey, token).
-		SetBodyBytes(body).
-		Post(ChatCompletionsUrl)
-}
@@ -1,20 +0,0 @@
-package openai
-
-import (
-	"fmt"
-	"github.com/stulzq/azure-openai-proxy/constant"
-	"log"
-	"os"
-)
-
-func Init() {
-	AzureOpenAIAPIVer = os.Getenv(constant.ENV_AZURE_OPENAI_API_VER)
-	AzureOpenAIDeploy = os.Getenv(constant.ENV_AZURE_OPENAI_DEPLOY)
-	AzureOpenAIEndpoint = os.Getenv(constant.ENV_AZURE_OPENAI_ENDPOINT)
-
-	log.Println("AzureOpenAIAPIVer: ", AzureOpenAIAPIVer)
-	log.Println("AzureOpenAIDeploy: ", AzureOpenAIDeploy)
-	log.Println("AzureOpenAIEndpoint: ", AzureOpenAIEndpoint)
-
-	ChatCompletionsUrl = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", AzureOpenAIEndpoint, AzureOpenAIDeploy, AzureOpenAIAPIVer)
-}
@@ -1,17 +0,0 @@
-package openai
-
-import "github.com/imroc/req/v3"
-
-const (
-	AuthHeaderKey = "api-key"
-)
-
-var (
-	AzureOpenAIEndpoint = ""
-	AzureOpenAIAPIVer   = ""
-	AzureOpenAIDeploy   = ""
-
-	ChatCompletionsUrl = ""
-
-	client = req.C()
-)
@@ -1,6 +1,8 @@
-package apis
+package util
 
-import "github.com/gin-gonic/gin"
+import (
+	"github.com/gin-gonic/gin"
+)
 
 func SendError(c *gin.Context, err error) {
 	c.JSON(500, ApiResponse{
@@ -1,4 +1,4 @@
-package apis
+package util
 
 type ApiResponse struct {
 	Error ErrorDescription `json:"error"`