mirror of
https://github.com/aljazceru/opencode.git
synced 2025-12-20 17:24:22 +01:00
Choose model according to the docs (#1536)
This commit is contained in:
committed by
GitHub
parent
8ad83f71a9
commit
42a5fcead4
@@ -270,6 +270,50 @@ func (a *App) SwitchModeReverse() (*App, tea.Cmd) {
|
||||
return a.cycleMode(false)
|
||||
}
|
||||
|
||||
// findModelByFullID finds a model by its full ID in the format "provider/model"
|
||||
func findModelByFullID(providers []opencode.Provider, fullModelID string) (*opencode.Provider, *opencode.Model) {
|
||||
modelParts := strings.SplitN(fullModelID, "/", 2)
|
||||
if len(modelParts) < 2 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
providerID := modelParts[0]
|
||||
modelID := modelParts[1]
|
||||
|
||||
return findModelByProviderAndModelID(providers, providerID, modelID)
|
||||
}
|
||||
|
||||
// findModelByProviderAndModelID finds a model by provider ID and model ID
|
||||
func findModelByProviderAndModelID(providers []opencode.Provider, providerID, modelID string) (*opencode.Provider, *opencode.Model) {
|
||||
for _, provider := range providers {
|
||||
if provider.ID != providerID {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, model := range provider.Models {
|
||||
if model.ID == modelID {
|
||||
return &provider, &model
|
||||
}
|
||||
}
|
||||
|
||||
// Provider found but model not found
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Provider not found
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// findProviderByID finds a provider by its ID
|
||||
func findProviderByID(providers []opencode.Provider, providerID string) *opencode.Provider {
|
||||
for _, provider := range providers {
|
||||
if provider.ID == providerID {
|
||||
return &provider
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *App) InitializeProvider() tea.Cmd {
|
||||
providersResponse, err := a.Client.App.Providers(context.Background())
|
||||
if err != nil {
|
||||
@@ -278,29 +322,6 @@ func (a *App) InitializeProvider() tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
providers := providersResponse.Providers
|
||||
var defaultProvider *opencode.Provider
|
||||
var defaultModel *opencode.Model
|
||||
|
||||
var anthropic *opencode.Provider
|
||||
for _, provider := range providers {
|
||||
if provider.ID == "anthropic" {
|
||||
anthropic = &provider
|
||||
}
|
||||
}
|
||||
|
||||
// default to anthropic if available
|
||||
if anthropic != nil {
|
||||
defaultProvider = anthropic
|
||||
defaultModel = getDefaultModel(providersResponse, *anthropic)
|
||||
}
|
||||
|
||||
for _, provider := range providers {
|
||||
if defaultProvider == nil || defaultModel == nil {
|
||||
defaultProvider = &provider
|
||||
defaultModel = getDefaultModel(providersResponse, provider)
|
||||
}
|
||||
providers = append(providers, provider)
|
||||
}
|
||||
if len(providers) == 0 {
|
||||
slog.Error("No providers configured")
|
||||
return nil
|
||||
@@ -314,50 +335,86 @@ func (a *App) InitializeProvider() tea.Cmd {
|
||||
a.State.Model = model.ModelID
|
||||
}
|
||||
|
||||
var currentProvider *opencode.Provider
|
||||
var currentModel *opencode.Model
|
||||
for _, provider := range providers {
|
||||
if provider.ID == a.State.Provider {
|
||||
currentProvider = &provider
|
||||
var selectedProvider *opencode.Provider
|
||||
var selectedModel *opencode.Model
|
||||
|
||||
for _, model := range provider.Models {
|
||||
if model.ID == a.State.Model {
|
||||
currentModel = &model
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if currentProvider == nil || currentModel == nil {
|
||||
currentProvider = defaultProvider
|
||||
currentModel = defaultModel
|
||||
}
|
||||
|
||||
var initialProvider *opencode.Provider
|
||||
var initialModel *opencode.Model
|
||||
// Priority 1: Command line --model flag (InitialModel)
|
||||
if a.InitialModel != nil && *a.InitialModel != "" {
|
||||
splits := strings.Split(*a.InitialModel, "/")
|
||||
for _, provider := range providers {
|
||||
if provider.ID == splits[0] {
|
||||
initialProvider = &provider
|
||||
for _, model := range provider.Models {
|
||||
modelID := strings.Join(splits[1:], "/")
|
||||
if model.ID == modelID {
|
||||
initialModel = &model
|
||||
}
|
||||
}
|
||||
if provider, model := findModelByFullID(providers, *a.InitialModel); provider != nil && model != nil {
|
||||
selectedProvider = provider
|
||||
selectedModel = model
|
||||
slog.Debug("Selected model from command line", "provider", provider.ID, "model", model.ID)
|
||||
} else {
|
||||
slog.Debug("Command line model not found", "model", *a.InitialModel)
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 2: Config file model setting
|
||||
if selectedProvider == nil && a.Config.Model != "" {
|
||||
if provider, model := findModelByFullID(providers, a.Config.Model); provider != nil && model != nil {
|
||||
selectedProvider = provider
|
||||
selectedModel = model
|
||||
slog.Debug("Selected model from config", "provider", provider.ID, "model", model.ID)
|
||||
} else {
|
||||
slog.Debug("Config model not found", "model", a.Config.Model)
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 3: Recent model usage (most recently used model)
|
||||
if selectedProvider == nil && len(a.State.RecentlyUsedModels) > 0 {
|
||||
recentUsage := a.State.RecentlyUsedModels[0] // Most recent is first
|
||||
if provider, model := findModelByProviderAndModelID(providers, recentUsage.ProviderID, recentUsage.ModelID); provider != nil && model != nil {
|
||||
selectedProvider = provider
|
||||
selectedModel = model
|
||||
slog.Debug("Selected model from recent usage", "provider", provider.ID, "model", model.ID)
|
||||
} else {
|
||||
slog.Debug("Recent model not found", "provider", recentUsage.ProviderID, "model", recentUsage.ModelID)
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 4: State-based model (backwards compatibility)
|
||||
if selectedProvider == nil && a.State.Provider != "" && a.State.Model != "" {
|
||||
if provider, model := findModelByProviderAndModelID(providers, a.State.Provider, a.State.Model); provider != nil && model != nil {
|
||||
selectedProvider = provider
|
||||
selectedModel = model
|
||||
slog.Debug("Selected model from state", "provider", provider.ID, "model", model.ID)
|
||||
} else {
|
||||
slog.Debug("State model not found", "provider", a.State.Provider, "model", a.State.Model)
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 5: Internal priority fallback (Anthropic preferred, then first available)
|
||||
if selectedProvider == nil {
|
||||
// Try Anthropic first as internal priority
|
||||
if provider := findProviderByID(providers, "anthropic"); provider != nil {
|
||||
if model := getDefaultModel(providersResponse, *provider); model != nil {
|
||||
selectedProvider = provider
|
||||
selectedModel = model
|
||||
slog.Debug("Selected model from internal priority (Anthropic)", "provider", provider.ID, "model", model.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// If Anthropic not available, use first available provider
|
||||
if selectedProvider == nil && len(providers) > 0 {
|
||||
provider := &providers[0]
|
||||
if model := getDefaultModel(providersResponse, *provider); model != nil {
|
||||
selectedProvider = provider
|
||||
selectedModel = model
|
||||
slog.Debug("Selected model from fallback (first available)", "provider", provider.ID, "model", model.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if initialProvider != nil && initialModel != nil {
|
||||
currentProvider = initialProvider
|
||||
currentModel = initialModel
|
||||
// Final safety check
|
||||
if selectedProvider == nil || selectedModel == nil {
|
||||
slog.Error("Failed to select any model")
|
||||
return nil
|
||||
}
|
||||
|
||||
var cmds []tea.Cmd
|
||||
cmds = append(cmds, util.CmdHandler(ModelSelectedMsg{
|
||||
Provider: *currentProvider,
|
||||
Model: *currentModel,
|
||||
Provider: *selectedProvider,
|
||||
Model: *selectedModel,
|
||||
}))
|
||||
if a.InitialPrompt != nil && *a.InitialPrompt != "" {
|
||||
cmds = append(cmds, util.CmdHandler(SendPrompt{Text: *a.InitialPrompt}))
|
||||
|
||||
228
packages/tui/internal/app/app_test.go
Normal file
228
packages/tui/internal/app/app_test.go
Normal file
@@ -0,0 +1,228 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/sst/opencode-sdk-go"
|
||||
)
|
||||
|
||||
// TestFindModelByFullID tests the findModelByFullID function
|
||||
func TestFindModelByFullID(t *testing.T) {
|
||||
// Create test providers with models
|
||||
providers := []opencode.Provider{
|
||||
{
|
||||
ID: "anthropic",
|
||||
Models: map[string]opencode.Model{
|
||||
"claude-3-opus-20240229": {ID: "claude-3-opus-20240229"},
|
||||
"claude-3-sonnet-20240229": {ID: "claude-3-sonnet-20240229"},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "openai",
|
||||
Models: map[string]opencode.Model{
|
||||
"gpt-4": {ID: "gpt-4"},
|
||||
"gpt-3.5-turbo": {ID: "gpt-3.5-turbo"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
fullModelID string
|
||||
expectedFound bool
|
||||
expectedProviderID string
|
||||
expectedModelID string
|
||||
}{
|
||||
{
|
||||
name: "valid full model ID",
|
||||
fullModelID: "anthropic/claude-3-opus-20240229",
|
||||
expectedFound: true,
|
||||
expectedProviderID: "anthropic",
|
||||
expectedModelID: "claude-3-opus-20240229",
|
||||
},
|
||||
{
|
||||
name: "valid full model ID with slash in model name",
|
||||
fullModelID: "openai/gpt-3.5-turbo",
|
||||
expectedFound: true,
|
||||
expectedProviderID: "openai",
|
||||
expectedModelID: "gpt-3.5-turbo",
|
||||
},
|
||||
{
|
||||
name: "invalid format - missing slash",
|
||||
fullModelID: "anthropic",
|
||||
expectedFound: false,
|
||||
},
|
||||
{
|
||||
name: "invalid format - empty string",
|
||||
fullModelID: "",
|
||||
expectedFound: false,
|
||||
},
|
||||
{
|
||||
name: "provider not found",
|
||||
fullModelID: "nonexistent/model",
|
||||
expectedFound: false,
|
||||
},
|
||||
{
|
||||
name: "model not found",
|
||||
fullModelID: "anthropic/nonexistent-model",
|
||||
expectedFound: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
provider, model := findModelByFullID(providers, tt.fullModelID)
|
||||
|
||||
if tt.expectedFound {
|
||||
if provider == nil || model == nil {
|
||||
t.Errorf("Expected to find provider/model, but got nil")
|
||||
return
|
||||
}
|
||||
|
||||
if provider.ID != tt.expectedProviderID {
|
||||
t.Errorf("Expected provider ID %s, got %s", tt.expectedProviderID, provider.ID)
|
||||
}
|
||||
|
||||
if model.ID != tt.expectedModelID {
|
||||
t.Errorf("Expected model ID %s, got %s", tt.expectedModelID, model.ID)
|
||||
}
|
||||
} else {
|
||||
if provider != nil || model != nil {
|
||||
t.Errorf("Expected not to find provider/model, but got provider: %v, model: %v", provider, model)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestFindModelByProviderAndModelID tests the findModelByProviderAndModelID function
|
||||
func TestFindModelByProviderAndModelID(t *testing.T) {
|
||||
// Create test providers with models
|
||||
providers := []opencode.Provider{
|
||||
{
|
||||
ID: "anthropic",
|
||||
Models: map[string]opencode.Model{
|
||||
"claude-3-opus-20240229": {ID: "claude-3-opus-20240229"},
|
||||
"claude-3-sonnet-20240229": {ID: "claude-3-sonnet-20240229"},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "openai",
|
||||
Models: map[string]opencode.Model{
|
||||
"gpt-4": {ID: "gpt-4"},
|
||||
"gpt-3.5-turbo": {ID: "gpt-3.5-turbo"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
providerID string
|
||||
modelID string
|
||||
expectedFound bool
|
||||
expectedProviderID string
|
||||
expectedModelID string
|
||||
}{
|
||||
{
|
||||
name: "valid provider and model",
|
||||
providerID: "anthropic",
|
||||
modelID: "claude-3-opus-20240229",
|
||||
expectedFound: true,
|
||||
expectedProviderID: "anthropic",
|
||||
expectedModelID: "claude-3-opus-20240229",
|
||||
},
|
||||
{
|
||||
name: "provider not found",
|
||||
providerID: "nonexistent",
|
||||
modelID: "claude-3-opus-20240229",
|
||||
expectedFound: false,
|
||||
},
|
||||
{
|
||||
name: "model not found",
|
||||
providerID: "anthropic",
|
||||
modelID: "nonexistent-model",
|
||||
expectedFound: false,
|
||||
},
|
||||
{
|
||||
name: "both provider and model not found",
|
||||
providerID: "nonexistent",
|
||||
modelID: "nonexistent-model",
|
||||
expectedFound: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
provider, model := findModelByProviderAndModelID(providers, tt.providerID, tt.modelID)
|
||||
|
||||
if tt.expectedFound {
|
||||
if provider == nil || model == nil {
|
||||
t.Errorf("Expected to find provider/model, but got nil")
|
||||
return
|
||||
}
|
||||
|
||||
if provider.ID != tt.expectedProviderID {
|
||||
t.Errorf("Expected provider ID %s, got %s", tt.expectedProviderID, provider.ID)
|
||||
}
|
||||
|
||||
if model.ID != tt.expectedModelID {
|
||||
t.Errorf("Expected model ID %s, got %s", tt.expectedModelID, model.ID)
|
||||
}
|
||||
} else {
|
||||
if provider != nil || model != nil {
|
||||
t.Errorf("Expected not to find provider/model, but got provider: %v, model: %v", provider, model)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestFindProviderByID tests the findProviderByID function
|
||||
func TestFindProviderByID(t *testing.T) {
|
||||
// Create test providers
|
||||
providers := []opencode.Provider{
|
||||
{ID: "anthropic"},
|
||||
{ID: "openai"},
|
||||
{ID: "google"},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
providerID string
|
||||
expectedFound bool
|
||||
expectedProviderID string
|
||||
}{
|
||||
{
|
||||
name: "provider found",
|
||||
providerID: "anthropic",
|
||||
expectedFound: true,
|
||||
expectedProviderID: "anthropic",
|
||||
},
|
||||
{
|
||||
name: "provider not found",
|
||||
providerID: "nonexistent",
|
||||
expectedFound: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
provider := findProviderByID(providers, tt.providerID)
|
||||
|
||||
if tt.expectedFound {
|
||||
if provider == nil {
|
||||
t.Errorf("Expected to find provider, but got nil")
|
||||
return
|
||||
}
|
||||
|
||||
if provider.ID != tt.expectedProviderID {
|
||||
t.Errorf("Expected provider ID %s, got %s", tt.expectedProviderID, provider.ID)
|
||||
}
|
||||
} else {
|
||||
if provider != nil {
|
||||
t.Errorf("Expected not to find provider, but got %v", provider)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -66,9 +66,11 @@ If you've configured a [custom provider](/docs/providers#custom), the `provider_
|
||||
|
||||
## Loading models
|
||||
|
||||
When opencode starts up, it checks for the following:
|
||||
When opencode starts up, it checks for models in the following priority order:
|
||||
|
||||
1. The model list in the opencode config.
|
||||
1. The `--model` or `-m` command line flag. The format is the same as in the config file: `provider_id/model_id`.
|
||||
|
||||
2. The model list in the opencode config.
|
||||
|
||||
```json title="opencode.json"
|
||||
{
|
||||
@@ -79,6 +81,6 @@ When opencode starts up, it checks for the following:
|
||||
|
||||
The format here is `provider/model`.
|
||||
|
||||
2. The last used model.
|
||||
3. The last used model.
|
||||
|
||||
3. The first model using an internal priority.
|
||||
4. The first model using an internal priority.
|
||||
|
||||
Reference in New Issue
Block a user