diff --git a/packages/tui/internal/app/app.go b/packages/tui/internal/app/app.go index 7ef31fd5..a0e68b53 100644 --- a/packages/tui/internal/app/app.go +++ b/packages/tui/internal/app/app.go @@ -270,6 +270,50 @@ func (a *App) SwitchModeReverse() (*App, tea.Cmd) { return a.cycleMode(false) } +// findModelByFullID finds a model by its full ID in the format "provider/model" +func findModelByFullID(providers []opencode.Provider, fullModelID string) (*opencode.Provider, *opencode.Model) { + modelParts := strings.SplitN(fullModelID, "/", 2) + if len(modelParts) < 2 { + return nil, nil + } + + providerID := modelParts[0] + modelID := modelParts[1] + + return findModelByProviderAndModelID(providers, providerID, modelID) +} + +// findModelByProviderAndModelID finds a model by provider ID and model ID +func findModelByProviderAndModelID(providers []opencode.Provider, providerID, modelID string) (*opencode.Provider, *opencode.Model) { + for _, provider := range providers { + if provider.ID != providerID { + continue + } + + for _, model := range provider.Models { + if model.ID == modelID { + return &provider, &model + } + } + + // Provider found but model not found + return nil, nil + } + + // Provider not found + return nil, nil +} + +// findProviderByID finds a provider by its ID +func findProviderByID(providers []opencode.Provider, providerID string) *opencode.Provider { + for _, provider := range providers { + if provider.ID == providerID { + return &provider + } + } + return nil +} + func (a *App) InitializeProvider() tea.Cmd { providersResponse, err := a.Client.App.Providers(context.Background()) if err != nil { @@ -278,29 +322,6 @@ func (a *App) InitializeProvider() tea.Cmd { return nil } providers := providersResponse.Providers - var defaultProvider *opencode.Provider - var defaultModel *opencode.Model - - var anthropic *opencode.Provider - for _, provider := range providers { - if provider.ID == "anthropic" { - anthropic = &provider - } - } - - // default to anthropic if available - if anthropic != nil { - defaultProvider = anthropic - defaultModel = getDefaultModel(providersResponse, *anthropic) - } - - for _, provider := range providers { - if defaultProvider == nil || defaultModel == nil { - defaultProvider = &provider - defaultModel = getDefaultModel(providersResponse, provider) - } - providers = append(providers, provider) - } if len(providers) == 0 { slog.Error("No providers configured") return nil @@ -314,50 +335,86 @@ func (a *App) InitializeProvider() tea.Cmd { a.State.Model = model.ModelID } - var currentProvider *opencode.Provider - var currentModel *opencode.Model - for _, provider := range providers { - if provider.ID == a.State.Provider { - currentProvider = &provider + var selectedProvider *opencode.Provider + var selectedModel *opencode.Model - for _, model := range provider.Models { - if model.ID == a.State.Model { - currentModel = &model - } - } - } - } - if currentProvider == nil || currentModel == nil { - currentProvider = defaultProvider - currentModel = defaultModel - } - - var initialProvider *opencode.Provider - var initialModel *opencode.Model + // Priority 1: Command line --model flag (InitialModel) if a.InitialModel != nil && *a.InitialModel != "" { - splits := strings.Split(*a.InitialModel, "/") - for _, provider := range providers { - if provider.ID == splits[0] { - initialProvider = &provider - for _, model := range provider.Models { - modelID := strings.Join(splits[1:], "/") - if model.ID == modelID { - initialModel = &model - } - } + if provider, model := findModelByFullID(providers, *a.InitialModel); provider != nil && model != nil { + selectedProvider = provider + selectedModel = model + slog.Debug("Selected model from command line", "provider", provider.ID, "model", model.ID) + } else { + slog.Debug("Command line model not found", "model", *a.InitialModel) + } + } + + // Priority 2: Config file model setting + if selectedProvider == nil && a.Config.Model != "" { + if provider, model := findModelByFullID(providers, a.Config.Model); provider != nil && model != nil { + selectedProvider = provider + selectedModel = model + slog.Debug("Selected model from config", "provider", provider.ID, "model", model.ID) + } else { + slog.Debug("Config model not found", "model", a.Config.Model) + } + } + + // Priority 3: Recent model usage (most recently used model) + if selectedProvider == nil && len(a.State.RecentlyUsedModels) > 0 { + recentUsage := a.State.RecentlyUsedModels[0] // Most recent is first + if provider, model := findModelByProviderAndModelID(providers, recentUsage.ProviderID, recentUsage.ModelID); provider != nil && model != nil { + selectedProvider = provider + selectedModel = model + slog.Debug("Selected model from recent usage", "provider", provider.ID, "model", model.ID) + } else { + slog.Debug("Recent model not found", "provider", recentUsage.ProviderID, "model", recentUsage.ModelID) + } + } + + // Priority 4: State-based model (backwards compatibility) + if selectedProvider == nil && a.State.Provider != "" && a.State.Model != "" { + if provider, model := findModelByProviderAndModelID(providers, a.State.Provider, a.State.Model); provider != nil && model != nil { + selectedProvider = provider + selectedModel = model + slog.Debug("Selected model from state", "provider", provider.ID, "model", model.ID) + } else { + slog.Debug("State model not found", "provider", a.State.Provider, "model", a.State.Model) + } + } + + // Priority 5: Internal priority fallback (Anthropic preferred, then first available) + if selectedProvider == nil { + // Try Anthropic first as internal priority + if provider := findProviderByID(providers, "anthropic"); provider != nil { + if model := getDefaultModel(providersResponse, *provider); model != nil { + selectedProvider = provider + selectedModel = model + slog.Debug("Selected model from internal priority (Anthropic)", "provider", provider.ID, "model", model.ID) + } + } + + // If Anthropic not available, use first available provider + if selectedProvider == nil && len(providers) > 0 { + provider := &providers[0] + if model := getDefaultModel(providersResponse, *provider); model != nil { + selectedProvider = provider + selectedModel = model + slog.Debug("Selected model from fallback (first available)", "provider", provider.ID, "model", model.ID) } } } - if initialProvider != nil && initialModel != nil { - currentProvider = initialProvider - currentModel = initialModel + // Final safety check + if selectedProvider == nil || selectedModel == nil { + slog.Error("Failed to select any model") + return nil } var cmds []tea.Cmd cmds = append(cmds, util.CmdHandler(ModelSelectedMsg{ - Provider: *currentProvider, - Model: *currentModel, + Provider: *selectedProvider, + Model: *selectedModel, })) if a.InitialPrompt != nil && *a.InitialPrompt != "" { cmds = append(cmds, util.CmdHandler(SendPrompt{Text: *a.InitialPrompt})) diff --git a/packages/tui/internal/app/app_test.go b/packages/tui/internal/app/app_test.go new file mode 100644 index 00000000..9260a991 --- /dev/null +++ b/packages/tui/internal/app/app_test.go @@ -0,0 +1,228 @@ +package app + +import ( + "testing" + + "github.com/sst/opencode-sdk-go" +) + +// TestFindModelByFullID tests the findModelByFullID function +func TestFindModelByFullID(t *testing.T) { + // Create test providers with models + providers := []opencode.Provider{ + { + ID: "anthropic", + Models: map[string]opencode.Model{ + "claude-3-opus-20240229": {ID: "claude-3-opus-20240229"}, + "claude-3-sonnet-20240229": {ID: "claude-3-sonnet-20240229"}, + }, + }, + { + ID: "openai", + Models: map[string]opencode.Model{ + "gpt-4": {ID: "gpt-4"}, + "gpt-3.5-turbo": {ID: "gpt-3.5-turbo"}, + }, + }, + } + + tests := []struct { + name string + fullModelID string + expectedFound bool + expectedProviderID string + expectedModelID string + }{ + { + name: "valid full model ID", + fullModelID: "anthropic/claude-3-opus-20240229", + expectedFound: true, + expectedProviderID: "anthropic", + expectedModelID: "claude-3-opus-20240229", + }, + { + name: "valid full model ID with slash in model name", + fullModelID: "openai/gpt-3.5-turbo", + expectedFound: true, + expectedProviderID: "openai", + expectedModelID: "gpt-3.5-turbo", + }, + { + name: "invalid format - missing slash", + fullModelID: "anthropic", + expectedFound: false, + }, + { + name: "invalid format - empty string", + fullModelID: "", + expectedFound: false, + }, + { + name: "provider not found", + fullModelID: "nonexistent/model", + expectedFound: false, + }, + { + name: "model not found", + fullModelID: "anthropic/nonexistent-model", + expectedFound: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + provider, model := findModelByFullID(providers, tt.fullModelID) + + if tt.expectedFound { + if provider == nil || model == nil { + t.Errorf("Expected to find provider/model, but got nil") + return + } + + if provider.ID != tt.expectedProviderID { + t.Errorf("Expected provider ID %s, got %s", tt.expectedProviderID, provider.ID) + } + + if model.ID != tt.expectedModelID { + t.Errorf("Expected model ID %s, got %s", tt.expectedModelID, model.ID) + } + } else { + if provider != nil || model != nil { + t.Errorf("Expected not to find provider/model, but got provider: %v, model: %v", provider, model) + } + } + }) + } +} + +// TestFindModelByProviderAndModelID tests the findModelByProviderAndModelID function +func TestFindModelByProviderAndModelID(t *testing.T) { + // Create test providers with models + providers := []opencode.Provider{ + { + ID: "anthropic", + Models: map[string]opencode.Model{ + "claude-3-opus-20240229": {ID: "claude-3-opus-20240229"}, + "claude-3-sonnet-20240229": {ID: "claude-3-sonnet-20240229"}, + }, + }, + { + ID: "openai", + Models: map[string]opencode.Model{ + "gpt-4": {ID: "gpt-4"}, + "gpt-3.5-turbo": {ID: "gpt-3.5-turbo"}, + }, + }, + } + + tests := []struct { + name string + providerID string + modelID string + expectedFound bool + expectedProviderID string + expectedModelID string + }{ + { + name: "valid provider and model", + providerID: "anthropic", + modelID: "claude-3-opus-20240229", + expectedFound: true, + expectedProviderID: "anthropic", + expectedModelID: "claude-3-opus-20240229", + }, + { + name: "provider not found", + providerID: "nonexistent", + modelID: "claude-3-opus-20240229", + expectedFound: false, + }, + { + name: "model not found", + providerID: "anthropic", + modelID: "nonexistent-model", + expectedFound: false, + }, + { + name: "both provider and model not found", + providerID: "nonexistent", + modelID: "nonexistent-model", + expectedFound: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + provider, model := findModelByProviderAndModelID(providers, tt.providerID, tt.modelID) + + if tt.expectedFound { + if provider == nil || model == nil { + t.Errorf("Expected to find provider/model, but got nil") + return + } + + if provider.ID != tt.expectedProviderID { + t.Errorf("Expected provider ID %s, got %s", tt.expectedProviderID, provider.ID) + } + + if model.ID != tt.expectedModelID { + t.Errorf("Expected model ID %s, got %s", tt.expectedModelID, model.ID) + } + } else { + if provider != nil || model != nil { + t.Errorf("Expected not to find provider/model, but got provider: %v, model: %v", provider, model) + } + } + }) + } +} + +// TestFindProviderByID tests the findProviderByID function +func TestFindProviderByID(t *testing.T) { + // Create test providers + providers := []opencode.Provider{ + {ID: "anthropic"}, + {ID: "openai"}, + {ID: "google"}, + } + + tests := []struct { + name string + providerID string + expectedFound bool + expectedProviderID string + }{ + { + name: "provider found", + providerID: "anthropic", + expectedFound: true, + expectedProviderID: "anthropic", + }, + { + name: "provider not found", + providerID: "nonexistent", + expectedFound: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + provider := findProviderByID(providers, tt.providerID) + + if tt.expectedFound { + if provider == nil { + t.Errorf("Expected to find provider, but got nil") + return + } + + if provider.ID != tt.expectedProviderID { + t.Errorf("Expected provider ID %s, got %s", tt.expectedProviderID, provider.ID) + } + } else { + if provider != nil { + t.Errorf("Expected not to find provider, but got %v", provider) + } + } + }) + } +} diff --git a/packages/web/src/content/docs/docs/models.mdx b/packages/web/src/content/docs/docs/models.mdx index 591625f8..5308921a 100644 --- a/packages/web/src/content/docs/docs/models.mdx +++ b/packages/web/src/content/docs/docs/models.mdx @@ -66,9 +66,11 @@ If you've configured a [custom provider](/docs/providers#custom), the `provider_ ## Loading models -When opencode starts up, it checks for the following: +When opencode starts up, it checks for models in the following priority order: -1. The model list in the opencode config. +1. The `--model` or `-m` command line flag. The format is the same as in the config file: `provider_id/model_id`. + +2. The model list in the opencode config. ```json title="opencode.json" { @@ -79,6 +81,6 @@ When opencode starts up, it checks for the following: The format here is `provider/model`. -2. The last used model. +3. The last used model. -3. The first model using an internal priority. +4. The first model using an internal priority.