mirror of
https://github.com/stulzq/azure-openai-proxy.git
synced 2025-12-19 07:14:21 +01:00
feat: ability to mimic /v1/models/ api (#68)
* add: ability to mimic `/v1/models/` api * fix: missing json dependency * fix: accidentally remove dependency * fix: linter error caught by gh workflow
This commit is contained in:
@@ -2,6 +2,7 @@ package azure
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/stulzq/azure-openai-proxy/util"
|
||||
"io"
|
||||
@@ -21,6 +22,85 @@ func ProxyWithConverter(requestConverter RequestConverter) gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
type DeploymentInfo struct {
|
||||
Data []map[string]interface{} `json:"data"`
|
||||
Object string `json:"object"`
|
||||
}
|
||||
|
||||
func ModelProxy(c *gin.Context) {
|
||||
// Create a channel to receive the results of each request
|
||||
results := make(chan []map[string]interface{}, len(ModelDeploymentConfig))
|
||||
|
||||
// Send a request for each deployment in the map
|
||||
for _, deployment := range ModelDeploymentConfig {
|
||||
go func(deployment DeploymentConfig) {
|
||||
// Create the request
|
||||
req, err := http.NewRequest(http.MethodGet, deployment.Endpoint+"/openai/deployments?api-version=2022-12-01", nil)
|
||||
if err != nil {
|
||||
log.Printf("error parsing response body for deployment %s: %v", deployment.DeploymentName, err)
|
||||
results <- nil
|
||||
return
|
||||
}
|
||||
|
||||
// Set the auth header
|
||||
req.Header.Set(AuthHeaderKey, deployment.ApiKey)
|
||||
|
||||
// Send the request
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
log.Printf("error sending request for deployment %s: %v", deployment.DeploymentName, err)
|
||||
results <- nil
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Printf("unexpected status code %d for deployment %s", resp.StatusCode, deployment.DeploymentName)
|
||||
results <- nil
|
||||
return
|
||||
}
|
||||
|
||||
// Read the response body
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
log.Printf("error reading response body for deployment %s: %v", deployment.DeploymentName, err)
|
||||
results <- nil
|
||||
return
|
||||
}
|
||||
|
||||
// Parse the response body as JSON
|
||||
var deplotmentInfo DeploymentInfo
|
||||
err = json.Unmarshal(body, &deplotmentInfo)
|
||||
if err != nil {
|
||||
log.Printf("error parsing response body for deployment %s: %v", deployment.DeploymentName, err)
|
||||
results <- nil
|
||||
return
|
||||
}
|
||||
results <- deplotmentInfo.Data
|
||||
}(deployment)
|
||||
}
|
||||
|
||||
// Wait for all requests to finish and collect the results
|
||||
var allResults []map[string]interface{}
|
||||
for i := 0; i < len(ModelDeploymentConfig); i++ {
|
||||
result := <-results
|
||||
if result != nil {
|
||||
allResults = append(allResults, result...)
|
||||
}
|
||||
}
|
||||
var info = DeploymentInfo{Data: allResults, Object: "list"}
|
||||
combinedResults, err := json.Marshal(info)
|
||||
if err != nil {
|
||||
log.Printf("error marshalling results: %v", err)
|
||||
util.SendError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Set the response headers and body
|
||||
c.Header("Content-Type", "application/json")
|
||||
c.String(http.StatusOK, string(combinedResults))
|
||||
}
|
||||
|
||||
// Proxy Azure OpenAI
|
||||
func Proxy(c *gin.Context, requestConverter RequestConverter) {
|
||||
if c.Request.Method == http.MethodOptions {
|
||||
|
||||
@@ -17,6 +17,7 @@ func registerRoute(r *gin.Engine) {
|
||||
})
|
||||
apiBase := viper.GetString("api_base")
|
||||
stripPrefixConverter := azure.NewStripPrefixConverter(apiBase)
|
||||
r.GET(stripPrefixConverter.Prefix+"/models", azure.ModelProxy)
|
||||
templateConverter := azure.NewTemplateConverter("/openai/deployments/{{.DeploymentName}}/embeddings")
|
||||
apiBasedRouter := r.Group(apiBase)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user