mirror of
https://github.com/aljazceru/opencode.git
synced 2025-12-21 01:34:22 +01:00
feat(tui): handle --model and --prompt flags
This commit is contained in:
@@ -9,6 +9,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
tea "github.com/charmbracelet/bubbletea/v2"
|
tea "github.com/charmbracelet/bubbletea/v2"
|
||||||
|
flag "github.com/spf13/pflag"
|
||||||
"github.com/sst/opencode-sdk-go"
|
"github.com/sst/opencode-sdk-go"
|
||||||
"github.com/sst/opencode-sdk-go/option"
|
"github.com/sst/opencode-sdk-go/option"
|
||||||
"github.com/sst/opencode/internal/app"
|
"github.com/sst/opencode/internal/app"
|
||||||
@@ -23,6 +24,10 @@ func main() {
|
|||||||
version = "v" + Version
|
version = "v" + Version
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var model *string = flag.String("model", "", "model to begin with")
|
||||||
|
var prompt *string = flag.String("prompt", "", "prompt to begin with")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
url := os.Getenv("OPENCODE_SERVER")
|
url := os.Getenv("OPENCODE_SERVER")
|
||||||
|
|
||||||
appInfoStr := os.Getenv("OPENCODE_APP_INFO")
|
appInfoStr := os.Getenv("OPENCODE_APP_INFO")
|
||||||
@@ -65,7 +70,7 @@ func main() {
|
|||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
app_, err := app.New(ctx, version, appInfo, httpClient)
|
app_, err := app.New(ctx, version, appInfo, httpClient, model, prompt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ require (
|
|||||||
github.com/muesli/termenv v0.16.0
|
github.com/muesli/termenv v0.16.0
|
||||||
github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3
|
github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3
|
||||||
github.com/sst/opencode-sdk-go v0.1.0-alpha.8
|
github.com/sst/opencode-sdk-go v0.1.0-alpha.8
|
||||||
github.com/tidwall/gjson v1.14.4
|
|
||||||
rsc.io/qr v0.2.0
|
rsc.io/qr v0.2.0
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -50,6 +49,7 @@ require (
|
|||||||
github.com/sosodev/duration v1.3.1 // indirect
|
github.com/sosodev/duration v1.3.1 // indirect
|
||||||
github.com/speakeasy-api/openapi-overlay v0.9.0 // indirect
|
github.com/speakeasy-api/openapi-overlay v0.9.0 // indirect
|
||||||
github.com/spf13/cobra v1.9.1 // indirect
|
github.com/spf13/cobra v1.9.1 // indirect
|
||||||
|
github.com/tidwall/gjson v1.14.4 // indirect
|
||||||
github.com/tidwall/match v1.1.1 // indirect
|
github.com/tidwall/match v1.1.1 // indirect
|
||||||
github.com/tidwall/pretty v1.2.1 // indirect
|
github.com/tidwall/pretty v1.2.1 // indirect
|
||||||
github.com/tidwall/sjson v1.2.5 // indirect
|
github.com/tidwall/sjson v1.2.5 // indirect
|
||||||
|
|||||||
@@ -21,17 +21,19 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type App struct {
|
type App struct {
|
||||||
Info opencode.App
|
Info opencode.App
|
||||||
Version string
|
Version string
|
||||||
StatePath string
|
StatePath string
|
||||||
Config *opencode.Config
|
Config *opencode.Config
|
||||||
Client *opencode.Client
|
Client *opencode.Client
|
||||||
State *config.State
|
State *config.State
|
||||||
Provider *opencode.Provider
|
Provider *opencode.Provider
|
||||||
Model *opencode.Model
|
Model *opencode.Model
|
||||||
Session *opencode.Session
|
Session *opencode.Session
|
||||||
Messages []opencode.MessageUnion
|
Messages []opencode.MessageUnion
|
||||||
Commands commands.CommandRegistry
|
Commands commands.CommandRegistry
|
||||||
|
InitialModel *string
|
||||||
|
InitialPrompt *string
|
||||||
}
|
}
|
||||||
|
|
||||||
type SessionSelectedMsg = *opencode.Session
|
type SessionSelectedMsg = *opencode.Session
|
||||||
@@ -58,6 +60,8 @@ func New(
|
|||||||
version string,
|
version string,
|
||||||
appInfo opencode.App,
|
appInfo opencode.App,
|
||||||
httpClient *opencode.Client,
|
httpClient *opencode.Client,
|
||||||
|
model *string,
|
||||||
|
prompt *string,
|
||||||
) (*App, error) {
|
) (*App, error) {
|
||||||
util.RootPath = appInfo.Path.Root
|
util.RootPath = appInfo.Path.Root
|
||||||
util.CwdPath = appInfo.Path.Cwd
|
util.CwdPath = appInfo.Path.Cwd
|
||||||
@@ -109,15 +113,17 @@ func New(
|
|||||||
slog.Debug("Loaded config", "config", configInfo)
|
slog.Debug("Loaded config", "config", configInfo)
|
||||||
|
|
||||||
app := &App{
|
app := &App{
|
||||||
Info: appInfo,
|
Info: appInfo,
|
||||||
Version: version,
|
Version: version,
|
||||||
StatePath: appStatePath,
|
StatePath: appStatePath,
|
||||||
Config: configInfo,
|
Config: configInfo,
|
||||||
State: appState,
|
State: appState,
|
||||||
Client: httpClient,
|
Client: httpClient,
|
||||||
Session: &opencode.Session{},
|
Session: &opencode.Session{},
|
||||||
Messages: []opencode.MessageUnion{},
|
Messages: []opencode.MessageUnion{},
|
||||||
Commands: commands.LoadFromConfig(configInfo),
|
Commands: commands.LoadFromConfig(configInfo),
|
||||||
|
InitialModel: model,
|
||||||
|
InitialPrompt: prompt,
|
||||||
}
|
}
|
||||||
|
|
||||||
return app, nil
|
return app, nil
|
||||||
@@ -141,65 +147,89 @@ func (a *App) Key(commandName commands.CommandName) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *App) InitializeProvider() tea.Cmd {
|
func (a *App) InitializeProvider() tea.Cmd {
|
||||||
return func() tea.Msg {
|
providersResponse, err := a.Client.Config.Providers(context.Background())
|
||||||
providersResponse, err := a.Client.Config.Providers(context.Background())
|
if err != nil {
|
||||||
if err != nil {
|
slog.Error("Failed to list providers", "error", err)
|
||||||
slog.Error("Failed to list providers", "error", err)
|
// TODO: notify user
|
||||||
// TODO: notify user
|
return nil
|
||||||
return nil
|
}
|
||||||
}
|
providers := providersResponse.Providers
|
||||||
providers := providersResponse.Providers
|
var defaultProvider *opencode.Provider
|
||||||
var defaultProvider *opencode.Provider
|
var defaultModel *opencode.Model
|
||||||
var defaultModel *opencode.Model
|
|
||||||
|
|
||||||
var anthropic *opencode.Provider
|
var anthropic *opencode.Provider
|
||||||
for _, provider := range providers {
|
for _, provider := range providers {
|
||||||
if provider.ID == "anthropic" {
|
if provider.ID == "anthropic" {
|
||||||
anthropic = &provider
|
anthropic = &provider
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// default to anthropic if available
|
||||||
|
if anthropic != nil {
|
||||||
|
defaultProvider = anthropic
|
||||||
|
defaultModel = getDefaultModel(providersResponse, *anthropic)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, provider := range providers {
|
||||||
|
if defaultProvider == nil || defaultModel == nil {
|
||||||
|
defaultProvider = &provider
|
||||||
|
defaultModel = getDefaultModel(providersResponse, provider)
|
||||||
|
}
|
||||||
|
providers = append(providers, provider)
|
||||||
|
}
|
||||||
|
if len(providers) == 0 {
|
||||||
|
slog.Error("No providers configured")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var currentProvider *opencode.Provider
|
||||||
|
var currentModel *opencode.Model
|
||||||
|
for _, provider := range providers {
|
||||||
|
if provider.ID == a.State.Provider {
|
||||||
|
currentProvider = &provider
|
||||||
|
|
||||||
|
for _, model := range provider.Models {
|
||||||
|
if model.ID == a.State.Model {
|
||||||
|
currentModel = &model
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
if currentProvider == nil || currentModel == nil {
|
||||||
|
currentProvider = defaultProvider
|
||||||
|
currentModel = defaultModel
|
||||||
|
}
|
||||||
|
|
||||||
// default to anthropic if available
|
var initialProvider *opencode.Provider
|
||||||
if anthropic != nil {
|
var initialModel *opencode.Model
|
||||||
defaultProvider = anthropic
|
if a.InitialModel != nil && *a.InitialModel != "" {
|
||||||
defaultModel = getDefaultModel(providersResponse, *anthropic)
|
splits := strings.Split(*a.InitialModel, "/")
|
||||||
}
|
|
||||||
|
|
||||||
for _, provider := range providers {
|
for _, provider := range providers {
|
||||||
if defaultProvider == nil || defaultModel == nil {
|
if provider.ID == splits[0] {
|
||||||
defaultProvider = &provider
|
initialProvider = &provider
|
||||||
defaultModel = getDefaultModel(providersResponse, provider)
|
|
||||||
}
|
|
||||||
providers = append(providers, provider)
|
|
||||||
}
|
|
||||||
if len(providers) == 0 {
|
|
||||||
slog.Error("No providers configured")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var currentProvider *opencode.Provider
|
|
||||||
var currentModel *opencode.Model
|
|
||||||
for _, provider := range providers {
|
|
||||||
if provider.ID == a.State.Provider {
|
|
||||||
currentProvider = &provider
|
|
||||||
|
|
||||||
for _, model := range provider.Models {
|
for _, model := range provider.Models {
|
||||||
if model.ID == a.State.Model {
|
if model.ID == splits[1] {
|
||||||
currentModel = &model
|
initialModel = &model
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if currentProvider == nil || currentModel == nil {
|
|
||||||
currentProvider = defaultProvider
|
|
||||||
currentModel = defaultModel
|
|
||||||
}
|
|
||||||
|
|
||||||
return ModelSelectedMsg{
|
|
||||||
Provider: *currentProvider,
|
|
||||||
Model: *currentModel,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if initialProvider != nil && initialModel != nil {
|
||||||
|
currentProvider = initialProvider
|
||||||
|
currentModel = initialModel
|
||||||
|
}
|
||||||
|
|
||||||
|
var cmds []tea.Cmd
|
||||||
|
cmds = append(cmds, util.CmdHandler(ModelSelectedMsg{
|
||||||
|
Provider: *currentProvider,
|
||||||
|
Model: *currentModel,
|
||||||
|
}))
|
||||||
|
if a.InitialPrompt != nil && *a.InitialPrompt != "" {
|
||||||
|
cmds = append(cmds, util.CmdHandler(SendMsg{Text: *a.InitialPrompt}))
|
||||||
|
}
|
||||||
|
return tea.Sequence(cmds...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getDefaultModel(
|
func getDefaultModel(
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ func (m *editorComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|||||||
return m, tea.Batch(cmds...)
|
return m, tea.Batch(cmds...)
|
||||||
}
|
}
|
||||||
case dialog.ThemeSelectedMsg:
|
case dialog.ThemeSelectedMsg:
|
||||||
m.textarea = createTextArea(&m.textarea)
|
m.textarea = m.resetTextareaStyles()
|
||||||
m.spinner = createSpinner()
|
m.spinner = createSpinner()
|
||||||
return m, tea.Batch(m.spinner.Tick, m.textarea.Focus())
|
return m, tea.Batch(m.spinner.Tick, m.textarea.Focus())
|
||||||
case dialog.CompletionSelectedMsg:
|
case dialog.CompletionSelectedMsg:
|
||||||
@@ -306,13 +306,13 @@ func (m *editorComponent) getSubmitKeyText() string {
|
|||||||
return m.app.Commands[commands.InputSubmitCommand].Keys()[0]
|
return m.app.Commands[commands.InputSubmitCommand].Keys()[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
func createTextArea(existing *textarea.Model) textarea.Model {
|
func (m *editorComponent) resetTextareaStyles() textarea.Model {
|
||||||
t := theme.CurrentTheme()
|
t := theme.CurrentTheme()
|
||||||
bgColor := t.BackgroundElement()
|
bgColor := t.BackgroundElement()
|
||||||
textColor := t.Text()
|
textColor := t.Text()
|
||||||
textMutedColor := t.TextMuted()
|
textMutedColor := t.TextMuted()
|
||||||
|
|
||||||
ta := textarea.New()
|
ta := m.textarea
|
||||||
|
|
||||||
ta.Styles.Blurred.Base = styles.NewStyle().Foreground(textColor).Background(bgColor).Lipgloss()
|
ta.Styles.Blurred.Base = styles.NewStyle().Foreground(textColor).Background(bgColor).Lipgloss()
|
||||||
ta.Styles.Blurred.CursorLine = styles.NewStyle().Background(bgColor).Lipgloss()
|
ta.Styles.Blurred.CursorLine = styles.NewStyle().Background(bgColor).Lipgloss()
|
||||||
@@ -337,17 +337,6 @@ func createTextArea(existing *textarea.Model) textarea.Model {
|
|||||||
Background(t.Secondary()).
|
Background(t.Secondary()).
|
||||||
Lipgloss()
|
Lipgloss()
|
||||||
ta.Styles.Cursor.Color = t.Primary()
|
ta.Styles.Cursor.Color = t.Primary()
|
||||||
|
|
||||||
ta.Prompt = " "
|
|
||||||
ta.ShowLineNumbers = false
|
|
||||||
ta.CharLimit = -1
|
|
||||||
|
|
||||||
if existing != nil {
|
|
||||||
ta.SetValue(existing.Value())
|
|
||||||
// ta.SetWidth(existing.Width())
|
|
||||||
ta.SetHeight(existing.Height())
|
|
||||||
}
|
|
||||||
|
|
||||||
return ta
|
return ta
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -367,12 +356,19 @@ func createSpinner() spinner.Model {
|
|||||||
|
|
||||||
func NewEditorComponent(app *app.App) EditorComponent {
|
func NewEditorComponent(app *app.App) EditorComponent {
|
||||||
s := createSpinner()
|
s := createSpinner()
|
||||||
ta := createTextArea(nil)
|
|
||||||
|
|
||||||
return &editorComponent{
|
ta := textarea.New()
|
||||||
|
ta.Prompt = " "
|
||||||
|
ta.ShowLineNumbers = false
|
||||||
|
ta.CharLimit = -1
|
||||||
|
|
||||||
|
m := &editorComponent{
|
||||||
app: app,
|
app: app,
|
||||||
textarea: ta,
|
textarea: ta,
|
||||||
spinner: s,
|
spinner: s,
|
||||||
interruptKeyInDebounce: false,
|
interruptKeyInDebounce: false,
|
||||||
}
|
}
|
||||||
|
m.resetTextareaStyles()
|
||||||
|
|
||||||
|
return m
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user