Choose model according to the docs (#1536)

This commit is contained in:
Dominik Engelhardt
2025-08-02 16:29:03 +02:00
committed by GitHub
parent 8ad83f71a9
commit 42a5fcead4
3 changed files with 348 additions and 61 deletions

View File

@@ -270,6 +270,50 @@ func (a *App) SwitchModeReverse() (*App, tea.Cmd) {
return a.cycleMode(false) return a.cycleMode(false)
} }
// findModelByFullID finds a model by its full ID in the format "provider/model"
func findModelByFullID(providers []opencode.Provider, fullModelID string) (*opencode.Provider, *opencode.Model) {
modelParts := strings.SplitN(fullModelID, "/", 2)
if len(modelParts) < 2 {
return nil, nil
}
providerID := modelParts[0]
modelID := modelParts[1]
return findModelByProviderAndModelID(providers, providerID, modelID)
}
// findModelByProviderAndModelID finds a model by provider ID and model ID
func findModelByProviderAndModelID(providers []opencode.Provider, providerID, modelID string) (*opencode.Provider, *opencode.Model) {
for _, provider := range providers {
if provider.ID != providerID {
continue
}
for _, model := range provider.Models {
if model.ID == modelID {
return &provider, &model
}
}
// Provider found but model not found
return nil, nil
}
// Provider not found
return nil, nil
}
// findProviderByID finds a provider by its ID
func findProviderByID(providers []opencode.Provider, providerID string) *opencode.Provider {
for _, provider := range providers {
if provider.ID == providerID {
return &provider
}
}
return nil
}
func (a *App) InitializeProvider() tea.Cmd { func (a *App) InitializeProvider() tea.Cmd {
providersResponse, err := a.Client.App.Providers(context.Background()) providersResponse, err := a.Client.App.Providers(context.Background())
if err != nil { if err != nil {
@@ -278,29 +322,6 @@ func (a *App) InitializeProvider() tea.Cmd {
return nil return nil
} }
providers := providersResponse.Providers providers := providersResponse.Providers
var defaultProvider *opencode.Provider
var defaultModel *opencode.Model
var anthropic *opencode.Provider
for _, provider := range providers {
if provider.ID == "anthropic" {
anthropic = &provider
}
}
// default to anthropic if available
if anthropic != nil {
defaultProvider = anthropic
defaultModel = getDefaultModel(providersResponse, *anthropic)
}
for _, provider := range providers {
if defaultProvider == nil || defaultModel == nil {
defaultProvider = &provider
defaultModel = getDefaultModel(providersResponse, provider)
}
providers = append(providers, provider)
}
if len(providers) == 0 { if len(providers) == 0 {
slog.Error("No providers configured") slog.Error("No providers configured")
return nil return nil
@@ -314,50 +335,86 @@ func (a *App) InitializeProvider() tea.Cmd {
a.State.Model = model.ModelID a.State.Model = model.ModelID
} }
var currentProvider *opencode.Provider var selectedProvider *opencode.Provider
var currentModel *opencode.Model var selectedModel *opencode.Model
for _, provider := range providers {
if provider.ID == a.State.Provider {
currentProvider = &provider
for _, model := range provider.Models { // Priority 1: Command line --model flag (InitialModel)
if model.ID == a.State.Model {
currentModel = &model
}
}
}
}
if currentProvider == nil || currentModel == nil {
currentProvider = defaultProvider
currentModel = defaultModel
}
var initialProvider *opencode.Provider
var initialModel *opencode.Model
if a.InitialModel != nil && *a.InitialModel != "" { if a.InitialModel != nil && *a.InitialModel != "" {
splits := strings.Split(*a.InitialModel, "/") if provider, model := findModelByFullID(providers, *a.InitialModel); provider != nil && model != nil {
for _, provider := range providers { selectedProvider = provider
if provider.ID == splits[0] { selectedModel = model
initialProvider = &provider slog.Debug("Selected model from command line", "provider", provider.ID, "model", model.ID)
for _, model := range provider.Models { } else {
modelID := strings.Join(splits[1:], "/") slog.Debug("Command line model not found", "model", *a.InitialModel)
if model.ID == modelID { }
initialModel = &model }
}
} // Priority 2: Config file model setting
if selectedProvider == nil && a.Config.Model != "" {
if provider, model := findModelByFullID(providers, a.Config.Model); provider != nil && model != nil {
selectedProvider = provider
selectedModel = model
slog.Debug("Selected model from config", "provider", provider.ID, "model", model.ID)
} else {
slog.Debug("Config model not found", "model", a.Config.Model)
}
}
// Priority 3: Recent model usage (most recently used model)
if selectedProvider == nil && len(a.State.RecentlyUsedModels) > 0 {
recentUsage := a.State.RecentlyUsedModels[0] // Most recent is first
if provider, model := findModelByProviderAndModelID(providers, recentUsage.ProviderID, recentUsage.ModelID); provider != nil && model != nil {
selectedProvider = provider
selectedModel = model
slog.Debug("Selected model from recent usage", "provider", provider.ID, "model", model.ID)
} else {
slog.Debug("Recent model not found", "provider", recentUsage.ProviderID, "model", recentUsage.ModelID)
}
}
// Priority 4: State-based model (backwards compatibility)
if selectedProvider == nil && a.State.Provider != "" && a.State.Model != "" {
if provider, model := findModelByProviderAndModelID(providers, a.State.Provider, a.State.Model); provider != nil && model != nil {
selectedProvider = provider
selectedModel = model
slog.Debug("Selected model from state", "provider", provider.ID, "model", model.ID)
} else {
slog.Debug("State model not found", "provider", a.State.Provider, "model", a.State.Model)
}
}
// Priority 5: Internal priority fallback (Anthropic preferred, then first available)
if selectedProvider == nil {
// Try Anthropic first as internal priority
if provider := findProviderByID(providers, "anthropic"); provider != nil {
if model := getDefaultModel(providersResponse, *provider); model != nil {
selectedProvider = provider
selectedModel = model
slog.Debug("Selected model from internal priority (Anthropic)", "provider", provider.ID, "model", model.ID)
}
}
// If Anthropic not available, use first available provider
if selectedProvider == nil && len(providers) > 0 {
provider := &providers[0]
if model := getDefaultModel(providersResponse, *provider); model != nil {
selectedProvider = provider
selectedModel = model
slog.Debug("Selected model from fallback (first available)", "provider", provider.ID, "model", model.ID)
} }
} }
} }
if initialProvider != nil && initialModel != nil { // Final safety check
currentProvider = initialProvider if selectedProvider == nil || selectedModel == nil {
currentModel = initialModel slog.Error("Failed to select any model")
return nil
} }
var cmds []tea.Cmd var cmds []tea.Cmd
cmds = append(cmds, util.CmdHandler(ModelSelectedMsg{ cmds = append(cmds, util.CmdHandler(ModelSelectedMsg{
Provider: *currentProvider, Provider: *selectedProvider,
Model: *currentModel, Model: *selectedModel,
})) }))
if a.InitialPrompt != nil && *a.InitialPrompt != "" { if a.InitialPrompt != nil && *a.InitialPrompt != "" {
cmds = append(cmds, util.CmdHandler(SendPrompt{Text: *a.InitialPrompt})) cmds = append(cmds, util.CmdHandler(SendPrompt{Text: *a.InitialPrompt}))

View File

@@ -0,0 +1,228 @@
package app
import (
"testing"
"github.com/sst/opencode-sdk-go"
)
// TestFindModelByFullID tests the findModelByFullID function
func TestFindModelByFullID(t *testing.T) {
// Create test providers with models
providers := []opencode.Provider{
{
ID: "anthropic",
Models: map[string]opencode.Model{
"claude-3-opus-20240229": {ID: "claude-3-opus-20240229"},
"claude-3-sonnet-20240229": {ID: "claude-3-sonnet-20240229"},
},
},
{
ID: "openai",
Models: map[string]opencode.Model{
"gpt-4": {ID: "gpt-4"},
"gpt-3.5-turbo": {ID: "gpt-3.5-turbo"},
},
},
}
tests := []struct {
name string
fullModelID string
expectedFound bool
expectedProviderID string
expectedModelID string
}{
{
name: "valid full model ID",
fullModelID: "anthropic/claude-3-opus-20240229",
expectedFound: true,
expectedProviderID: "anthropic",
expectedModelID: "claude-3-opus-20240229",
},
{
name: "valid full model ID with slash in model name",
fullModelID: "openai/gpt-3.5-turbo",
expectedFound: true,
expectedProviderID: "openai",
expectedModelID: "gpt-3.5-turbo",
},
{
name: "invalid format - missing slash",
fullModelID: "anthropic",
expectedFound: false,
},
{
name: "invalid format - empty string",
fullModelID: "",
expectedFound: false,
},
{
name: "provider not found",
fullModelID: "nonexistent/model",
expectedFound: false,
},
{
name: "model not found",
fullModelID: "anthropic/nonexistent-model",
expectedFound: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
provider, model := findModelByFullID(providers, tt.fullModelID)
if tt.expectedFound {
if provider == nil || model == nil {
t.Errorf("Expected to find provider/model, but got nil")
return
}
if provider.ID != tt.expectedProviderID {
t.Errorf("Expected provider ID %s, got %s", tt.expectedProviderID, provider.ID)
}
if model.ID != tt.expectedModelID {
t.Errorf("Expected model ID %s, got %s", tt.expectedModelID, model.ID)
}
} else {
if provider != nil || model != nil {
t.Errorf("Expected not to find provider/model, but got provider: %v, model: %v", provider, model)
}
}
})
}
}
// TestFindModelByProviderAndModelID tests the findModelByProviderAndModelID function
func TestFindModelByProviderAndModelID(t *testing.T) {
// Create test providers with models
providers := []opencode.Provider{
{
ID: "anthropic",
Models: map[string]opencode.Model{
"claude-3-opus-20240229": {ID: "claude-3-opus-20240229"},
"claude-3-sonnet-20240229": {ID: "claude-3-sonnet-20240229"},
},
},
{
ID: "openai",
Models: map[string]opencode.Model{
"gpt-4": {ID: "gpt-4"},
"gpt-3.5-turbo": {ID: "gpt-3.5-turbo"},
},
},
}
tests := []struct {
name string
providerID string
modelID string
expectedFound bool
expectedProviderID string
expectedModelID string
}{
{
name: "valid provider and model",
providerID: "anthropic",
modelID: "claude-3-opus-20240229",
expectedFound: true,
expectedProviderID: "anthropic",
expectedModelID: "claude-3-opus-20240229",
},
{
name: "provider not found",
providerID: "nonexistent",
modelID: "claude-3-opus-20240229",
expectedFound: false,
},
{
name: "model not found",
providerID: "anthropic",
modelID: "nonexistent-model",
expectedFound: false,
},
{
name: "both provider and model not found",
providerID: "nonexistent",
modelID: "nonexistent-model",
expectedFound: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
provider, model := findModelByProviderAndModelID(providers, tt.providerID, tt.modelID)
if tt.expectedFound {
if provider == nil || model == nil {
t.Errorf("Expected to find provider/model, but got nil")
return
}
if provider.ID != tt.expectedProviderID {
t.Errorf("Expected provider ID %s, got %s", tt.expectedProviderID, provider.ID)
}
if model.ID != tt.expectedModelID {
t.Errorf("Expected model ID %s, got %s", tt.expectedModelID, model.ID)
}
} else {
if provider != nil || model != nil {
t.Errorf("Expected not to find provider/model, but got provider: %v, model: %v", provider, model)
}
}
})
}
}
// TestFindProviderByID tests the findProviderByID function
func TestFindProviderByID(t *testing.T) {
// Create test providers
providers := []opencode.Provider{
{ID: "anthropic"},
{ID: "openai"},
{ID: "google"},
}
tests := []struct {
name string
providerID string
expectedFound bool
expectedProviderID string
}{
{
name: "provider found",
providerID: "anthropic",
expectedFound: true,
expectedProviderID: "anthropic",
},
{
name: "provider not found",
providerID: "nonexistent",
expectedFound: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
provider := findProviderByID(providers, tt.providerID)
if tt.expectedFound {
if provider == nil {
t.Errorf("Expected to find provider, but got nil")
return
}
if provider.ID != tt.expectedProviderID {
t.Errorf("Expected provider ID %s, got %s", tt.expectedProviderID, provider.ID)
}
} else {
if provider != nil {
t.Errorf("Expected not to find provider, but got %v", provider)
}
}
})
}
}

View File

@@ -66,9 +66,11 @@ If you've configured a [custom provider](/docs/providers#custom), the `provider_
## Loading models ## Loading models
When opencode starts up, it checks for the following: When opencode starts up, it checks for models in the following priority order:
1. The model list in the opencode config. 1. The `--model` or `-m` command line flag. The format is the same as in the config file: `provider_id/model_id`.
2. The model list in the opencode config.
```json title="opencode.json" ```json title="opencode.json"
{ {
@@ -79,6 +81,6 @@ When opencode starts up, it checks for the following:
The format here is `provider/model`. The format here is `provider/model`.
2. The last used model. 3. The last used model.
3. The first model using an internal priority. 4. The first model using an internal priority.