mirror of
https://github.com/stulzq/azure-openai-proxy.git
synced 2025-12-19 15:24:24 +01:00
101 lines
3.2 KiB
Go
101 lines
3.2 KiB
Go
package azure
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"github.com/pkg/errors"
|
|
"log"
|
|
"net/http"
|
|
"net/url"
|
|
"path"
|
|
"strings"
|
|
"text/template"
|
|
)
|
|
|
|
// DeploymentConfig holds the settings of a single Azure OpenAI deployment.
// The converters in this file read it to rewrite outgoing requests.
type DeploymentConfig struct {
	// DeploymentName is the Azure OpenAI deployment name.
	DeploymentName string `yaml:"deployment_name" json:"deployment_name" mapstructure:"deployment_name"`

	// ModelName is the corresponding model name in OpenAI.
	ModelName string `yaml:"model_name" json:"model_name" mapstructure:"model_name"`

	// Endpoint is the deployment endpoint.
	Endpoint string `yaml:"endpoint" json:"endpoint" mapstructure:"endpoint"`

	// ApiKey is the secret key (key 1 or key 2).
	ApiKey string `yaml:"api_key" json:"api_key" mapstructure:"api_key"`

	// ApiVersion is the deployment version; it is not required.
	ApiVersion string `yaml:"api_version" json:"api_version" mapstructure:"api_version"`

	// EndpointUrl is the url.URL form of the deployment endpoint. It carries
	// no struct tags.
	EndpointUrl *url.URL
}
|
|
|
|
type Config struct {
|
|
ApiBase string `yaml:"api_base" mapstructure:"api_base"` // if you use openai、langchain as sdk, it will be useful
|
|
DeploymentConfig []DeploymentConfig `yaml:"deployment_config" mapstructure:"deployment_config"` // deployment config
|
|
}
|
|
|
|
type RequestConverter interface {
|
|
Name() string
|
|
Convert(req *http.Request, config *DeploymentConfig) (*http.Request, error)
|
|
}
|
|
|
|
// StripPrefixConverter rewrites a request path by stripping a configured
// prefix and placing the remainder under the deployment path.
type StripPrefixConverter struct {
	// Prefix is the path prefix that Convert strips from the request path.
	Prefix string
}
|
|
|
|
func (c *StripPrefixConverter) Name() string {
|
|
return "StripPrefix"
|
|
}
|
|
func (c *StripPrefixConverter) Convert(req *http.Request, config *DeploymentConfig) (*http.Request, error) {
|
|
req.Host = config.EndpointUrl.Host
|
|
req.URL.Scheme = config.EndpointUrl.Scheme
|
|
req.URL.Host = config.EndpointUrl.Host
|
|
req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", config.DeploymentName), strings.Replace(req.URL.Path, c.Prefix+"/", "/", 1))
|
|
req.URL.RawPath = req.URL.EscapedPath()
|
|
|
|
query := req.URL.Query()
|
|
query.Add("api-version", config.ApiVersion)
|
|
req.URL.RawQuery = query.Encode()
|
|
return req, nil
|
|
}
|
|
func NewStripPrefixConverter(prefix string) *StripPrefixConverter {
|
|
return &StripPrefixConverter{
|
|
Prefix: prefix,
|
|
}
|
|
}
|
|
|
|
// TemplateConverter builds the request path by executing a text/template
// against the deployment configuration.
type TemplateConverter struct {
	// Tpl is the template source text.
	Tpl string

	// Tempalte is the parsed template that Convert executes. The misspelled
	// name is part of the exported API and is kept for compatibility.
	Tempalte *template.Template
}
|
|
|
|
func (c *TemplateConverter) Name() string {
|
|
return "Template"
|
|
}
|
|
func (c *TemplateConverter) Convert(req *http.Request, config *DeploymentConfig) (*http.Request, error) {
|
|
data := map[string]interface{}{
|
|
"DeploymentName": config.DeploymentName,
|
|
"ModelName": config.ModelName,
|
|
"Endpoint": config.Endpoint,
|
|
"ApiKey": config.ApiKey,
|
|
"ApiVersion": config.ApiVersion,
|
|
}
|
|
buff := new(bytes.Buffer)
|
|
if err := c.Tempalte.Execute(buff, data); err != nil {
|
|
return req, errors.Wrap(err, "template execute error")
|
|
}
|
|
|
|
req.Host = config.EndpointUrl.Host
|
|
req.URL.Scheme = config.EndpointUrl.Scheme
|
|
req.URL.Host = config.EndpointUrl.Host
|
|
req.URL.Path = buff.String()
|
|
req.URL.RawPath = req.URL.EscapedPath()
|
|
|
|
query := req.URL.Query()
|
|
query.Add("api-version", config.ApiVersion)
|
|
req.URL.RawQuery = query.Encode()
|
|
return req, nil
|
|
}
|
|
func NewTemplateConverter(tpl string) *TemplateConverter {
|
|
_template, err := template.New("template").Parse(tpl)
|
|
if err != nil {
|
|
log.Fatalf("template parse error: %s", err.Error())
|
|
}
|
|
return &TemplateConverter{
|
|
Tpl: tpl,
|
|
Tempalte: _template,
|
|
}
|
|
}
|