Files
ollama-free-model-proxy/main.go

436 lines
14 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"bufio"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"os"
"strings"
"time"
"github.com/gin-gonic/gin"
openai "github.com/sashabaranov/go-openai"
)
var modelFilter map[string]struct{}
var freeModels []string
var failureStore *FailureStore
var freeMode bool
func loadModelFilter(path string) (map[string]struct{}, error) {
file, err := os.Open(path)
if err != nil {
return nil, err
}
defer file.Close()
scanner := bufio.NewScanner(file)
filter := make(map[string]struct{})
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line != "" {
filter[line] = struct{}{}
}
}
if err := scanner.Err(); err != nil {
return nil, err
}
return filter, nil
}
func main() {
r := gin.Default()
// Load the API key from environment variables or command-line arguments.
apiKey := os.Getenv("OPENAI_API_KEY")
if apiKey == "" {
if len(os.Args) > 1 {
apiKey = os.Args[1]
} else {
slog.Error("OPENAI_API_KEY environment variable or command-line argument not set.")
return
}
}
freeMode = strings.ToLower(os.Getenv("FREE_MODE")) == "true"
if freeMode {
var err error
freeModels, err = ensureFreeModelFile(apiKey, "free-models")
if err != nil {
slog.Error("failed to load free models", "error", err)
return
}
failureStore, err = NewFailureStore("failures.db")
if err != nil {
slog.Error("failed to init failure store", "error", err)
return
}
defer failureStore.Close()
slog.Info("Free mode enabled", "models", len(freeModels))
}
provider := NewOpenrouterProvider(apiKey)
filter, err := loadModelFilter("models-filter")
if err != nil {
if os.IsNotExist(err) {
slog.Info("models-filter file not found. Skipping model filtering.")
modelFilter = make(map[string]struct{})
} else {
slog.Error("Error loading models filter", "Error", err)
return
}
} else {
modelFilter = filter
slog.Info("Loaded models from filter:")
for model := range modelFilter {
slog.Info(" - " + model)
}
}
r.GET("/", func(c *gin.Context) {
c.String(http.StatusOK, "Ollama is running")
})
r.HEAD("/", func(c *gin.Context) {
c.String(http.StatusOK, "")
})
r.GET("/api/tags", func(c *gin.Context) {
models, err := provider.GetModels()
if err != nil {
slog.Error("Error getting models", "Error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
filter := modelFilter
// Construct a new array of model objects with extra fields
newModels := make([]map[string]interface{}, 0, len(models))
for _, m := range models {
// Если фильтр пустой, значит пропускаем проверку и берём все модели
if len(filter) > 0 {
if _, ok := filter[m.Model]; !ok {
continue
}
}
newModels = append(newModels, map[string]interface{}{
"name": m.Name,
"model": m.Model,
"modified_at": m.ModifiedAt,
"size": 270898672,
"digest": "9077fe9d2ae1a4a41a868836b56b8163731a8fe16621397028c2c76f838c6907",
"details": m.Details,
})
}
c.JSON(http.StatusOK, gin.H{"models": newModels})
})
r.POST("/api/show", func(c *gin.Context) {
var request map[string]string
if err := c.BindJSON(&request); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid JSON payload"})
return
}
modelName := request["name"]
if modelName == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "Model name is required"})
return
}
details, err := provider.GetModelDetails(modelName)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, details)
})
r.POST("/api/chat", func(c *gin.Context) {
var request struct {
Model string `json:"model"`
Messages []openai.ChatCompletionMessage `json:"messages"`
Stream *bool `json:"stream"` // Добавим поле Stream
}
// Parse the JSON request
if err := c.ShouldBindJSON(&request); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid JSON payload"})
return
}
// Определяем, нужен ли стриминг (по умолчанию true, если не указано для /api/chat)
// ВАЖНО: Open WebUI может НЕ передавать "stream": true для /api/chat, подразумевая это.
// Нужно проверить, какой запрос шлет Open WebUI. Если не шлет, ставим true.
streamRequested := true
if request.Stream != nil {
streamRequested = *request.Stream
}
// Если стриминг не запрошен, нужно будет реализовать отдельную логику
// для сбора полного ответа и отправки его одним JSON.
// Пока реализуем только стриминг.
if !streamRequested {
var response openai.ChatCompletionResponse
var fullModelName string
var err error
if freeMode {
response, fullModelName, err = getFreeChat(provider, request.Messages)
if err != nil {
slog.Error("free mode failed", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
} else {
fullModelName, err = provider.GetFullModelName(request.Model)
if err != nil {
slog.Error("Error getting full model name", "Error", err)
// Ollama returns 404 for invalid model names
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
return
}
response, err = provider.Chat(request.Messages, fullModelName)
if err != nil {
slog.Error("Failed to get chat response", "Error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
// Format the response according to Ollama's format
if len(response.Choices) == 0 {
c.JSON(http.StatusInternalServerError, gin.H{"error": "No response from model"})
return
}
// Extract the content from the response
content := ""
if len(response.Choices) > 0 && response.Choices[0].Message.Content != "" {
content = response.Choices[0].Message.Content
}
// Get finish reason, default to "stop" if not provided
finishReason := "stop"
if response.Choices[0].FinishReason != "" {
finishReason = string(response.Choices[0].FinishReason)
}
// Create Ollama-compatible response
ollamaResponse := map[string]interface{}{
"model": fullModelName,
"created_at": time.Now().Format(time.RFC3339),
"message": map[string]string{
"role": "assistant",
"content": content,
},
"done": true,
"finish_reason": finishReason,
"total_duration": response.Usage.TotalTokens * 10, // Approximate duration based on token count
"load_duration": 0,
"prompt_eval_count": response.Usage.PromptTokens,
"eval_count": response.Usage.CompletionTokens,
"eval_duration": response.Usage.CompletionTokens * 10, // Approximate duration based on token count
}
slog.Info("Used model", "model", fullModelName)
c.JSON(http.StatusOK, ollamaResponse)
return
}
slog.Info("Requested model", "model", request.Model)
var stream *openai.ChatCompletionStream
var fullModelName string
var err error
if freeMode {
stream, fullModelName, err = getFreeStream(provider, request.Messages)
if err != nil {
slog.Error("free mode failed", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
} else {
fullModelName, err = provider.GetFullModelName(request.Model)
if err != nil {
slog.Error("Error getting full model name", "Error", err, "model", request.Model)
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
return
}
stream, err = provider.ChatStream(request.Messages, fullModelName)
if err != nil {
slog.Error("Failed to create stream", "Error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
slog.Info("Using model", "fullModelName", fullModelName)
// Call ChatStream to get the stream
if err != nil {
slog.Error("Failed to create stream", "Error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
defer stream.Close() // Ensure stream closure
// --- ИСПРАВЛЕНИЯ для NDJSON (Ollama-style) ---
// Set headers CORRECTLY for Newline Delimited JSON
c.Writer.Header().Set("Content-Type", "application/x-ndjson") // <--- КЛЮЧЕВОЕ ИЗМЕНЕНИЕ
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
// Transfer-Encoding: chunked устанавливается Gin автоматически
w := c.Writer // Получаем ResponseWriter
flusher, ok := w.(http.Flusher)
if !ok {
slog.Error("Expected http.ResponseWriter to be an http.Flusher")
// Отправить ошибку клиенту уже сложно, т.к. заголовки могли уйти
return
}
var lastFinishReason string
// Stream responses back to the client
for {
response, err := stream.Recv()
if errors.Is(err, io.EOF) {
// End of stream from the backend provider
break
}
if err != nil {
slog.Error("Backend stream error", "Error", err)
// Попытка отправить ошибку в формате NDJSON
// Ollama обычно просто обрывает соединение или шлет 500 перед этим
errorMsg := map[string]string{"error": "Stream error: " + err.Error()}
errorJson, _ := json.Marshal(errorMsg)
fmt.Fprintf(w, "%s\n", string(errorJson)) // Отправляем ошибку + \n
flusher.Flush()
return
}
// Сохраняем причину остановки, если она есть в чанке
if len(response.Choices) > 0 && response.Choices[0].FinishReason != "" {
lastFinishReason = string(response.Choices[0].FinishReason)
}
// Build JSON response structure for intermediate chunks (Ollama chat format)
responseJSON := map[string]interface{}{
"model": fullModelName,
"created_at": time.Now().Format(time.RFC3339),
"message": map[string]string{
"role": "assistant",
"content": response.Choices[0].Delta.Content, // Может быть ""
},
"done": false, // Всегда false для промежуточных чанков
}
// Marshal JSON
jsonData, err := json.Marshal(responseJSON)
if err != nil {
slog.Error("Error marshaling intermediate response JSON", "Error", err)
return // Прерываем, так как не можем отправить данные
}
// Send JSON object followed by a newline
fmt.Fprintf(w, "%s\n", string(jsonData)) // <--- ИЗМЕНЕНО: Формат NDJSON (JSON + \n)
// Flush data to send it immediately
flusher.Flush()
}
// --- Отправка финального сообщения (done: true) в стиле Ollama ---
// Определяем причину остановки (если бэкенд не дал, ставим 'stop')
// Ollama использует 'stop', 'length', 'content_filter', 'tool_calls'
if lastFinishReason == "" {
lastFinishReason = "stop"
}
// ВАЖНО: Замените nil на 0 для числовых полей статистики
finalResponse := map[string]interface{}{
"model": fullModelName,
"created_at": time.Now().Format(time.RFC3339),
"message": map[string]string{
"role": "assistant",
"content": "", // Пустой контент для финального сообщения
},
"done": true,
"finish_reason": lastFinishReason, // Необязательно для /api/chat Ollama, но не вредит
"total_duration": 0,
"load_duration": 0,
"prompt_eval_count": 0, // <--- ИЗМЕНЕНО: nil заменен на 0
"eval_count": 0, // <--- ИЗМЕНЕНО: nil заменен на 0
"eval_duration": 0,
}
finalJsonData, err := json.Marshal(finalResponse)
if err != nil {
slog.Error("Error marshaling final response JSON", "Error", err)
return
}
// Отправляем финальный JSON-объект + newline
fmt.Fprintf(w, "%s\n", string(finalJsonData)) // <--- ИЗМЕНЕНО: Формат NDJSON
flusher.Flush()
// ВАЖНО: Для NDJSON НЕТ 'data: [DONE]' маркера.
// Клиент понимает конец потока по получению объекта с "done": true
// и/или по закрытию соединения сервером (что Gin сделает автоматически после выхода из хендлера).
// --- Конец исправлений ---
})
r.Run(":11434")
}
func getFreeChat(provider *OpenrouterProvider, msgs []openai.ChatCompletionMessage) (openai.ChatCompletionResponse, string, error) {
var resp openai.ChatCompletionResponse
for _, m := range freeModels {
skip, err := failureStore.ShouldSkip(m)
if err != nil {
slog.Error("db error", "error", err)
continue
}
if skip {
continue
}
resp, err = provider.Chat(msgs, m)
if err != nil {
slog.Warn("model failed", "model", m, "error", err)
_ = failureStore.MarkFailure(m)
continue
}
return resp, m, nil
}
return resp, "", fmt.Errorf("no free models available")
}
func getFreeStream(provider *OpenrouterProvider, msgs []openai.ChatCompletionMessage) (*openai.ChatCompletionStream, string, error) {
for _, m := range freeModels {
skip, err := failureStore.ShouldSkip(m)
if err != nil {
slog.Error("db error", "error", err)
continue
}
if skip {
continue
}
stream, err := provider.ChatStream(msgs, m)
if err != nil {
slog.Warn("model failed", "model", m, "error", err)
_ = failureStore.MarkFailure(m)
continue
}
return stream, m, nil
}
return nil, "", fmt.Errorf("no free models available")
}