diff --git a/failure_store.go b/failure_store.go
index 8e51743..30e881d 100644
--- a/failure_store.go
+++ b/failure_store.go
@@ -38,8 +38,21 @@ func (s *FailureStore) ShouldSkip(model string) (bool, error) {
 	if err != nil {
 		return false, err
 	}
-	if time.Since(time.Unix(ts, 0)) < 15*time.Minute {
+	// Reduced cooldown from 15 minutes to 5 minutes for faster recovery
+	if time.Since(time.Unix(ts, 0)) < 5*time.Minute {
 		return true, nil
 	}
 	return false, nil
 }
+
+// ClearFailure removes a model from the failure store (for successful requests)
+func (s *FailureStore) ClearFailure(model string) error {
+	_, err := s.db.Exec(`DELETE FROM failures WHERE model=?`, model)
+	return err
+}
+
+// ResetAllFailures clears all failure records (useful for testing or manual reset)
+func (s *FailureStore) ResetAllFailures() error {
+	_, err := s.db.Exec(`DELETE FROM failures`)
+	return err
+}
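
For context, MarkFailure is called by the helpers in main.go but is not part of this diff. A minimal sketch of a counterpart consistent with ShouldSkip reading a Unix timestamp — the failures table schema here is an assumption, not taken from the repository:

	// Sketch only, not from the diff: assumes failures(model TEXT PRIMARY KEY, ts INTEGER).
	func (s *FailureStore) MarkFailure(model string) error {
		// Upsert the current Unix time so ShouldSkip can compare it against the cooldown window.
		_, err := s.db.Exec(`INSERT OR REPLACE INTO failures(model, ts) VALUES(?, ?)`, model, time.Now().Unix())
		return err
	}
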
diff --git a/free_models.go b/free_models.go
index 3804fd4..b7c14b4 100644
--- a/free_models.go
+++ b/free_models.go
@@ -7,6 +7,7 @@ import (
 	"os"
 	"sort"
 	"strings"
+	"time"
 )
 
 type orModels struct {
@@ -64,24 +65,48 @@ func fetchFreeModels(apiKey string) ([]string, error) {
 }
 
 func ensureFreeModelFile(apiKey, path string) ([]string, error) {
-	if _, err := os.Stat(path); err == nil {
-		data, err := os.ReadFile(path)
-		if err != nil {
-			return nil, err
-		}
-		var models []string
-		for _, line := range strings.Split(string(data), "\n") {
-			line = strings.TrimSpace(line)
-			if line != "" {
-				models = append(models, line)
+	const cacheMaxAge = 24 * time.Hour // Refresh cache daily
+
+	if stat, err := os.Stat(path); err == nil {
+		// Check if cache is still fresh
+		if time.Since(stat.ModTime()) < cacheMaxAge {
+			data, err := os.ReadFile(path)
+			if err != nil {
+				return nil, err
 			}
+			var models []string
+			for _, line := range strings.Split(string(data), "\n") {
+				line = strings.TrimSpace(line)
+				if line != "" {
+					models = append(models, line)
+				}
+			}
+			return models, nil
 		}
-		return models, nil
+		// Cache is stale, will fetch fresh models below
 	}
+
+	// Fetch fresh models from API
 	models, err := fetchFreeModels(apiKey)
 	if err != nil {
+		// If fetch fails but we have a cached file (even if stale), use it
+		if _, statErr := os.Stat(path); statErr == nil {
+			data, readErr := os.ReadFile(path)
+			if readErr == nil {
+				var cachedModels []string
+				for _, line := range strings.Split(string(data), "\n") {
+					line = strings.TrimSpace(line)
+					if line != "" {
+						cachedModels = append(cachedModels, line)
+					}
+				}
+				return cachedModels, nil
+			}
+		}
 		return nil, err
 	}
+
+	// Save fresh models to cache
 	_ = os.WriteFile(path, []byte(strings.Join(models, "\n")), 0644)
 	return models, nil
 }
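
The net lookup order is now: cache fresher than 24h → live fetch → stale cache as a last resort. A sketch of how this is presumably wired up at startup — the env var name and cache path are illustrative assumptions, not taken from this diff:

	// Illustrative wiring only; the real call site in main() is not shown in this diff.
	apiKey := os.Getenv("OPENROUTER_API_KEY")                     // assumed env var name
	models, err := ensureFreeModelFile(apiKey, "free_models.txt") // assumed cache path
	if err != nil {
		// Reached only when the fetch fails and no cached file exists at all
		slog.Error("no free model list available", "error", err)
		os.Exit(1)
	}
	freeModels = models // main presumably keeps the list in a package-level variable
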
diff --git a/main.go b/main.go
index f3b66c2..72086a8 100644
--- a/main.go
+++ b/main.go
@@ -103,30 +103,68 @@ func main() {
 	})
 
 	r.GET("/api/tags", func(c *gin.Context) {
-		models, err := provider.GetModels()
-		if err != nil {
-			slog.Error("Error getting models", "Error", err)
-			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-			return
-		}
-		filter := modelFilter
-		// Construct a new array of model objects with extra fields
-		newModels := make([]map[string]interface{}, 0, len(models))
-		for _, m := range models {
-			// If the filter is empty, skip the check and take all models
-			if len(filter) > 0 {
-				if _, ok := filter[m.Model]; !ok {
-					continue
-				}
-			}
-			newModels = append(newModels, map[string]interface{}{
-				"name":        m.Name,
-				"model":       m.Model,
-				"modified_at": m.ModifiedAt,
-				"size":        270898672,
-				"digest":      "9077fe9d2ae1a4a41a868836b56b8163731a8fe16621397028c2c76f838c6907",
-				"details":     m.Details,
-			})
+		// Start with an empty slice so an empty result serializes as [] rather than null
+		newModels := make([]map[string]interface{}, 0)
+
+		if freeMode {
+			// In free mode, show only available free models
+			currentTime := time.Now().Format(time.RFC3339)
+			for _, freeModel := range freeModels {
+				// Check if the model should be skipped due to recent failures
+				skip, err := failureStore.ShouldSkip(freeModel)
+				if err != nil {
+					slog.Error("db error checking model", "model", freeModel, "error", err)
+					continue
+				}
+				if skip {
+					continue // Skip recently failed models
+				}
+
+				// Extract display name from full model name
+				parts := strings.Split(freeModel, "/")
+				displayName := parts[len(parts)-1]
+
+				newModels = append(newModels, map[string]interface{}{
+					"name":        displayName,
+					"model":       displayName,
+					"modified_at": currentTime,
+					"size":        270898672,
+					"digest":      "9077fe9d2ae1a4a41a868836b56b8163731a8fe16621397028c2c76f838c6907",
+					"details": map[string]interface{}{
+						"parent_model":       "",
+						"format":             "gguf",
+						"family":             "free",
+						"families":           []string{"free"},
+						"parameter_size":     "varies",
+						"quantization_level": "Q4_K_M",
+					},
+				})
+			}
+		} else {
+			// Non-free mode: use the original logic
+			models, err := provider.GetModels()
+			if err != nil {
+				slog.Error("Error getting models", "Error", err)
+				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+				return
+			}
+			filter := modelFilter
+			newModels = make([]map[string]interface{}, 0, len(models))
+			for _, m := range models {
+				// If the filter is empty, skip the check and take all models
+				if len(filter) > 0 {
+					if _, ok := filter[m.Model]; !ok {
+						continue
+					}
+				}
+				newModels = append(newModels, map[string]interface{}{
+					"name":        m.Name,
+					"model":       m.Model,
+					"modified_at": m.ModifiedAt,
+					"size":        270898672,
+					"digest":      "9077fe9d2ae1a4a41a868836b56b8163731a8fe16621397028c2c76f838c6907",
+					"details":     m.Details,
+				})
+			}
 		}
 
 		c.JSON(http.StatusOK, gin.H{"models": newModels})
@@ -183,7 +221,7 @@
 		var fullModelName string
 		var err error
 		if freeMode {
-			response, fullModelName, err = getFreeChat(provider, request.Messages)
+			response, fullModelName, err = getFreeChatForModel(provider, request.Messages, request.Model)
 			if err != nil {
 				slog.Error("free mode failed", "error", err)
 				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 				return
@@ -251,7 +289,7 @@
 		var fullModelName string
 		var err error
 		if freeMode {
-			stream, fullModelName, err = getFreeStream(provider, request.Messages)
+			stream, fullModelName, err = getFreeStreamForModel(provider, request.Messages, request.Model)
 			if err != nil {
 				slog.Error("free mode failed", "error", err)
 				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 				return
 			}
@@ -388,6 +426,196 @@
 		// --- End of fixes ---
 	})
 
+	// Add OpenAI-compatible endpoint for tools like Goose
+	r.POST("/v1/chat/completions", func(c *gin.Context) {
+		var request openai.ChatCompletionRequest
+		if err := c.ShouldBindJSON(&request); err != nil {
+			c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid JSON payload"})
+			return
+		}
+
+		slog.Info("OpenAI API request", "model", request.Model, "stream", request.Stream)
+
+		if request.Stream {
+			// Handle streaming request
+			var stream *openai.ChatCompletionStream
+			var fullModelName string
+			var err error
+
+			if freeMode {
+				stream, fullModelName, err = getFreeStreamForModel(provider, request.Messages, request.Model)
+				if err != nil {
+					slog.Error("free mode streaming failed", "error", err)
+					c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{"message": err.Error()}})
+					return
+				}
+			} else {
+				fullModelName, err = provider.GetFullModelName(request.Model)
+				if err != nil {
+					slog.Error("Error getting full model name", "Error", err, "model", request.Model)
+					c.JSON(http.StatusNotFound, gin.H{"error": gin.H{"message": err.Error()}})
+					return
+				}
+				stream, err = provider.ChatStream(request.Messages, fullModelName)
+				if err != nil {
+					slog.Error("Failed to create stream", "Error", err)
+					c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{"message": err.Error()}})
+					return
+				}
+			}
+			defer stream.Close()
+
+			// Set headers for Server-Sent Events (OpenAI format)
+			c.Writer.Header().Set("Content-Type", "text/event-stream")
+			c.Writer.Header().Set("Cache-Control", "no-cache")
+			c.Writer.Header().Set("Connection", "keep-alive")
+
+			w := c.Writer
+			flusher, ok := w.(http.Flusher)
+			if !ok {
+				slog.Error("Expected http.ResponseWriter to be an http.Flusher")
+				return
+			}
+
+			// Stream responses in OpenAI format
+			for {
+				response, err := stream.Recv()
+				if errors.Is(err, io.EOF) {
+					// Send final [DONE] message
+					fmt.Fprintf(w, "data: [DONE]\n\n")
+					flusher.Flush()
+					break
+				}
+				if err != nil {
+					slog.Error("Stream error", "Error", err)
+					break
+				}
+				// Guard against chunks with no choices before indexing into them
+				if len(response.Choices) == 0 {
+					continue
+				}
+
+				// Convert to OpenAI response format
+				openaiResponse := openai.ChatCompletionStreamResponse{
+					ID:      "chatcmpl-" + fmt.Sprintf("%d", time.Now().Unix()),
+					Object:  "chat.completion.chunk",
+					Created: time.Now().Unix(),
+					Model:   fullModelName,
+					Choices: []openai.ChatCompletionStreamChoice{
+						{
+							Index: 0,
+							Delta: openai.ChatCompletionStreamChoiceDelta{
+								Content: response.Choices[0].Delta.Content,
+							},
+						},
+					},
+				}
+
+				// Add finish reason if present
+				if response.Choices[0].FinishReason != "" {
+					openaiResponse.Choices[0].FinishReason = response.Choices[0].FinishReason
+				}
+
+				jsonData, err := json.Marshal(openaiResponse)
+				if err != nil {
+					slog.Error("Error marshaling response", "Error", err)
+					break
+				}
+
+				fmt.Fprintf(w, "data: %s\n\n", string(jsonData))
+				flusher.Flush()
+			}
+		} else {
+			// Handle non-streaming request
+			var response openai.ChatCompletionResponse
+			var fullModelName string
+			var err error
+
+			if freeMode {
+				response, fullModelName, err = getFreeChatForModel(provider, request.Messages, request.Model)
+				if err != nil {
+					slog.Error("free mode failed", "error", err)
+					c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{"message": err.Error()}})
+					return
+				}
+			} else {
+				fullModelName, err = provider.GetFullModelName(request.Model)
+				if err != nil {
+					slog.Error("Error getting full model name", "Error", err)
+					c.JSON(http.StatusNotFound, gin.H{"error": gin.H{"message": err.Error()}})
+					return
+				}
+				response, err = provider.Chat(request.Messages, fullModelName)
+				if err != nil {
+					slog.Error("Failed to get chat response", "Error", err)
+					c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{"message": err.Error()}})
+					return
+				}
+			}
+
+			// Return OpenAI-compatible response
+			response.ID = "chatcmpl-" + fmt.Sprintf("%d", time.Now().Unix())
+			response.Object = "chat.completion"
+			response.Created = time.Now().Unix()
+			response.Model = fullModelName
+
+			slog.Info("Used model", "model", fullModelName)
+			c.JSON(http.StatusOK, response)
+		}
+	})
+
+	// Add OpenAI-compatible models endpoint
+	r.GET("/v1/models", func(c *gin.Context) {
+		// Start with an empty slice so an empty result serializes as [] rather than null
+		models := make([]gin.H, 0)
+
+		if freeMode {
+			// In free mode, show only available free models
+			for _, freeModel := range freeModels {
+				skip, err := failureStore.ShouldSkip(freeModel)
+				if err != nil {
+					slog.Error("db error checking model", "model", freeModel, "error", err)
+					continue
+				}
+				if skip {
+					continue
+				}
+
+				parts := strings.Split(freeModel, "/")
+				displayName := parts[len(parts)-1]
+
+				models = append(models, gin.H{
+					"id":       displayName,
+					"object":   "model",
+					"created":  time.Now().Unix(),
+					"owned_by": "openrouter",
+				})
+			}
+		} else {
+			// Non-free mode: get all models from the provider
+			providerModels, err := provider.GetModels()
+			if err != nil {
+				slog.Error("Error getting models", "Error", err)
+				c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{"message": err.Error()}})
+				return
+			}
+
+			for _, m := range providerModels {
+				if len(modelFilter) > 0 {
+					if _, ok := modelFilter[m.Model]; !ok {
+						continue
+					}
+				}
+				models = append(models, gin.H{
+					"id":       m.Model,
+					"object":   "model",
+					"created":  time.Now().Unix(),
+					"owned_by": "openrouter",
+				})
+			}
+		}
+
+		c.JSON(http.StatusOK, gin.H{
+			"object": "list",
+			"data":   models,
+		})
+	})
+
 	r.Run(":11434")
 }
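
A quick smoke test for the new endpoint once the proxy is listening on :11434. The model name below is illustrative; in free mode an unrecognized name simply falls back to any available free model:

	// Standalone client sketch for the OpenAI-compatible endpoint added above.
	package main

	import (
		"bytes"
		"fmt"
		"io"
		"net/http"
	)

	func main() {
		body := []byte(`{"model":"deepseek-chat:free","messages":[{"role":"user","content":"Hello"}],"stream":false}`)
		resp, err := http.Post("http://localhost:11434/v1/chat/completions", "application/json", bytes.NewReader(body))
		if err != nil {
			panic(err)
		}
		defer resp.Body.Close()
		out, _ := io.ReadAll(resp.Body)
		fmt.Println(string(out))
	}
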
@@ -408,6 +636,8 @@ func getFreeChat(provider *OpenrouterProvider, msgs []openai.ChatCompletionMessage) (openai.ChatCompletionResponse, string, error) {
 			_ = failureStore.MarkFailure(m)
 			continue
 		}
+		// Clear failure record on successful request
+		_ = failureStore.ClearFailure(m)
 		return resp, m, nil
 	}
 	return resp, "", fmt.Errorf("no free models available")
 }
@@ -429,7 +659,75 @@ func getFreeStream(provider *OpenrouterProvider, msgs []openai.ChatCompletionMessage) (*openai.ChatCompletionStream, string, error) {
 			_ = failureStore.MarkFailure(m)
 			continue
 		}
+		// Clear failure record on successful request
+		_ = failureStore.ClearFailure(m)
 		return stream, m, nil
 	}
 	return nil, "", fmt.Errorf("no free models available")
 }
+
+// resolveDisplayNameToFullModel resolves a display name back to the full model name
+func resolveDisplayNameToFullModel(displayName string) string {
+	for _, fullModel := range freeModels {
+		parts := strings.Split(fullModel, "/")
+		modelDisplayName := parts[len(parts)-1]
+		if modelDisplayName == displayName {
+			return fullModel
+		}
+	}
+	return displayName // fall back to the original name if not found
+}
+
+// getFreeChatForModel tries the requested model first, then falls back to any available free model
+func getFreeChatForModel(provider *OpenrouterProvider, msgs []openai.ChatCompletionMessage, requestedModel string) (openai.ChatCompletionResponse, string, error) {
+	var resp openai.ChatCompletionResponse
+
+	// Try the requested model if its display name resolved to a known free model,
+	// or if the caller already passed a full model name from the free list
+	fullModelName := resolveDisplayNameToFullModel(requestedModel)
+	if fullModelName != requestedModel || contains(freeModels, fullModelName) {
+		skip, err := failureStore.ShouldSkip(fullModelName)
+		if err == nil && !skip {
+			resp, err = provider.Chat(msgs, fullModelName)
+			if err == nil {
+				_ = failureStore.ClearFailure(fullModelName)
+				return resp, fullModelName, nil
+			}
+			slog.Warn("requested model failed, trying fallback", "model", fullModelName, "error", err)
+			_ = failureStore.MarkFailure(fullModelName)
+		}
+	}
+
+	// Fall back to any available free model
+	return getFreeChat(provider, msgs)
+}
+
+// getFreeStreamForModel tries the requested model first, then falls back to any available free model
+func getFreeStreamForModel(provider *OpenrouterProvider, msgs []openai.ChatCompletionMessage, requestedModel string) (*openai.ChatCompletionStream, string, error) {
+	// Try the requested model if its display name resolved to a known free model,
+	// or if the caller already passed a full model name from the free list
+	fullModelName := resolveDisplayNameToFullModel(requestedModel)
+	if fullModelName != requestedModel || contains(freeModels, fullModelName) {
+		skip, err := failureStore.ShouldSkip(fullModelName)
+		if err == nil && !skip {
+			stream, err := provider.ChatStream(msgs, fullModelName)
+			if err == nil {
+				_ = failureStore.ClearFailure(fullModelName)
+				return stream, fullModelName, nil
+			}
+			slog.Warn("requested model failed, trying fallback", "model", fullModelName, "error", err)
+			_ = failureStore.MarkFailure(fullModelName)
+		}
+	}
+
+	// Fall back to any available free model
+	return getFreeStream(provider, msgs)
+}
+
+// contains checks if a slice contains a string
+func contains(slice []string, item string) bool {
+	for _, s := range slice {
+		if s == item {
+			return true
+		}
+	}
+	return false
+}
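
To make the resolution order concrete, a hypothetical walk-through (the model names are invented; the snippet reuses the package-level freeModels variable):

	func exampleResolution() {
		freeModels = []string{"deepseek/deepseek-chat:free"} // hypothetical free list

		fmt.Println(resolveDisplayNameToFullModel("deepseek-chat:free"))
		// -> "deepseek/deepseek-chat:free": the display name matched, so this model is tried first

		fmt.Println(resolveDisplayNameToFullModel("deepseek/deepseek-chat:free"))
		// -> unchanged, but contains(freeModels, ...) still accepts it as a known free model

		fmt.Println(resolveDisplayNameToFullModel("gpt-4"))
		// -> "gpt-4": no match, so getFreeChatForModel goes straight to the getFreeChat fallback
	}

Note that resolveDisplayNameToFullModel returns the first match in freeModels, so two providers exposing the same display name would collide; passing the full provider/model name avoids the ambiguity.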