Files
ollama-openrouter-proxy/provider.go
2025-08-25 16:03:56 +02:00

205 lines
5.4 KiB
Go

package main
import (
"context"
"fmt"
"net/http"
"strings"
"time"
"github.com/sashabaranov/go-openai"
)
// OpenrouterProvider adapts the OpenRouter HTTP API (OpenAI-compatible)
// so the proxy can serve Ollama-style requests against it.
type OpenrouterProvider struct {
	// client is the go-openai client pointed at the OpenRouter base URL.
	client *openai.Client
	// modelNames caches the full "vendor/model" IDs from the most recent
	// GetModels call; GetFullModelName reads it to resolve aliases.
	// NOTE(review): written by GetModels and read by GetFullModelName
	// without synchronization — confirm a provider instance is never
	// shared across goroutines, or guard this with a mutex.
	modelNames []string // Shared storage for model names
}
// NewOpenrouterProvider builds a provider that talks to OpenRouter's
// OpenAI-compatible endpoint using the given API key.
//
// Fix: openai.DefaultConfig pre-populates HTTPClient with a non-nil
// &http.Client{}, so the previous `if config.HTTPClient == nil` branch
// never executed and its 30-second timeout was never applied — the code
// only claimed to configure one. A client-wide http.Client.Timeout is
// deliberately NOT introduced here either: it covers reading the whole
// response body and would therefore cut off any ChatStream response
// that takes longer than the timeout. Per-call deadlines are applied
// via context in the individual request methods instead.
func NewOpenrouterProvider(apiKey string) *OpenrouterProvider {
	config := openai.DefaultConfig(apiKey)
	config.BaseURL = "https://openrouter.ai/api/v1/"
	// Explicit client with no Timeout (see note above); replaces the
	// equivalent bare client DefaultConfig installs, making the choice
	// visible rather than accidental.
	config.HTTPClient = &http.Client{}
	return &OpenrouterProvider{
		client:     openai.NewClientWithConfig(config),
		modelNames: []string{},
	}
}
// Chat performs a blocking (non-streaming) chat completion for the
// given messages against modelName and returns the full response.
// The request is bounded by a 30-second deadline so a stalled upstream
// cannot hang the caller indefinitely.
func (o *OpenrouterProvider) Chat(messages []openai.ChatCompletionMessage, modelName string) (openai.ChatCompletionResponse, error) {
	var empty openai.ChatCompletionResponse
	// Reject unusable input before spending a network round-trip.
	switch {
	case modelName == "":
		return empty, fmt.Errorf("model name cannot be empty")
	case len(messages) == 0:
		return empty, fmt.Errorf("messages cannot be empty")
	}
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	resp, err := o.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
		Model:    modelName,
		Messages: messages,
		Stream:   false,
	})
	if err != nil {
		return empty, fmt.Errorf("chat completion failed: %w", err)
	}
	return resp, nil
}
// ChatStream starts a streaming chat completion for the given messages
// and model. The caller owns the returned stream and must Close it when
// done.
//
// Fix: the previous implementation wrapped the request in a 60-second
// context.WithTimeout whose cancel func was never invoked on the
// success path (flagged by `go vet`'s lostcancel check); worse, the
// deadline would have forcibly terminated any stream still open after
// 60 seconds, truncating long generations mid-response. Streams are
// long-lived by design, so no artificial deadline is imposed here — the
// request ends when the upstream finishes or the caller closes the
// stream.
func (o *OpenrouterProvider) ChatStream(messages []openai.ChatCompletionMessage, modelName string) (*openai.ChatCompletionStream, error) {
	// Validate inputs before opening a connection.
	if modelName == "" {
		return nil, fmt.Errorf("model name cannot be empty")
	}
	if len(messages) == 0 {
		return nil, fmt.Errorf("messages cannot be empty")
	}
	req := openai.ChatCompletionRequest{
		Model:    modelName,
		Messages: messages,
		Stream:   true,
	}
	stream, err := o.client.CreateChatCompletionStream(context.Background(), req)
	if err != nil {
		return nil, fmt.Errorf("stream creation failed: %w", err)
	}
	return stream, nil
}
// ModelDetails mirrors the "details" object of an Ollama model listing
// entry. Every value this proxy produces for it is a hard-coded stub
// (see GetModels); none of these fields reflect real upstream metadata.
type ModelDetails struct {
	ParentModel       string   `json:"parent_model"`
	Format            string   `json:"format"`
	Family            string   `json:"family"`
	Families          []string `json:"families"`
	ParameterSize     string   `json:"parameter_size"`
	QuantizationLevel string   `json:"quantization_level"`
}
// Model mirrors one entry of an Ollama-style model listing response.
type Model struct {
	// Name is the short model name — the last path segment of the
	// OpenRouter ID (GetModels splits on "/").
	Name       string `json:"name"`
	Model      string `json:"model,omitempty"`
	ModifiedAt string `json:"modified_at,omitempty"`
	// Size is stubbed to 0 by GetModels; with omitempty it is therefore
	// always dropped from the JSON — TODO confirm consumers tolerate a
	// missing size.
	Size   int64  `json:"size,omitempty"`
	Digest string `json:"digest,omitempty"`
	// NOTE(review): `omitempty` has no effect on a struct-typed field in
	// encoding/json, so Details is always emitted regardless.
	Details ModelDetails `json:"details,omitempty"`
}
// GetModels fetches the model list from OpenRouter (10-second deadline)
// and converts it into Ollama-style Model records, refreshing the
// shared modelNames cache used by GetFullModelName.
//
// Size, Digest and every Details field are placeholders: the OpenRouter
// list endpoint does not expose file sizes, digests, families or
// quantization data, so stub values are emitted for API-shape
// compatibility.
//
// NOTE(review): o.modelNames is written here and read by
// GetFullModelName without synchronization — confirm single-goroutine
// use or add a mutex.
func (o *OpenrouterProvider) GetModels() ([]Model, error) {
	currentTime := time.Now().Format(time.RFC3339)
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	modelsResponse, err := o.client.ListModels(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to list models: %w", err)
	}
	// Pre-size both result slices; the final length is known.
	names := make([]string, 0, len(modelsResponse.Models))
	models := make([]Model, 0, len(modelsResponse.Models))
	for _, apiModel := range modelsResponse.Models {
		// OpenRouter IDs look like "vendor/model"; expose the short name.
		parts := strings.Split(apiModel.ID, "/")
		name := parts[len(parts)-1]
		names = append(names, apiModel.ID)
		models = append(models, Model{
			Name:       name,
			Model:      name,
			ModifiedAt: currentTime,
			Size:       0,    // stub: not reported by OpenRouter
			Digest:     name, // stub: no real digest available
			Details: ModelDetails{
				ParentModel:       "",
				Format:            "gguf",   // stub
				Family:            "claude", // stub: not the real model family
				Families:          []string{"claude"},
				ParameterSize:     "175B",   // stub
				QuantizationLevel: "Q4_K_M", // stub
			},
		})
	}
	// Swap in the fresh name cache only after a fully successful build,
	// so a failure never leaves the cache half-populated.
	o.modelNames = names
	return models, nil
}
// GetModelDetails returns stubbed Ollama-style metadata for a model.
// The modelName argument is currently ignored: every model receives the
// same placeholder payload until real details are wired up.
// NOTE(review): the "modifiedAt" key is camelCase while the Model
// struct emits snake_case "modified_at" — confirm which spelling the
// consumer expects.
func (o *OpenrouterProvider) GetModelDetails(modelName string) (map[string]interface{}, error) {
	details := map[string]interface{}{
		"format":             "gguf",
		"parameter_size":     "200B",
		"quantization_level": "Q4_K_M",
	}
	modelInfo := map[string]interface{}{
		"architecture":    "STUB",
		"context_length":  200000,
		"parameter_count": 200_000_000_000,
	}
	return map[string]interface{}{
		"license":    "STUB License",
		"system":     "STUB SYSTEM",
		"modifiedAt": time.Now().Format(time.RFC3339),
		"details":    details,
		"model_info": modelInfo,
	}, nil
}
// GetFullModelName resolves an alias to a full "vendor/model" ID.
// Resolution order: exact match against the cached name list, then
// suffix match; if neither matches, the alias is returned unchanged so
// callers can still address models that are absent from the list.
func (o *OpenrouterProvider) GetFullModelName(alias string) (string, error) {
	// Lazily populate the cache on first use.
	if len(o.modelNames) == 0 {
		if _, err := o.GetModels(); err != nil {
			return "", fmt.Errorf("failed to get models: %w", err)
		}
	}
	// Exact matches take priority over suffix matches.
	for _, candidate := range o.modelNames {
		if candidate == alias {
			return candidate, nil
		}
	}
	for _, candidate := range o.modelNames {
		if strings.HasSuffix(candidate, alias) {
			return candidate, nil
		}
	}
	// No match: pass the alias through to the upstream API as-is.
	return alias, nil
}