From c69c9327da4a43a63928807fcf36b24755cfac18 Mon Sep 17 00:00:00 2001 From: adamdottv <2363879+adamdottv@users.noreply.github.com> Date: Fri, 30 May 2025 15:34:22 -0500 Subject: [PATCH] wip: refactoring tui --- internal/tui/app/app.go | 48 +++-- internal/tui/app/bridge.go | 27 --- internal/tui/app/interfaces.go | 1 - internal/tui/components/core/status.go | 5 +- internal/tui/components/dialog/models.go | 236 ++++++++++------------- internal/tui/state/state.go | 7 + internal/tui/tui.go | 58 ++++-- js/src/config/config.ts | 2 +- js/src/llm/llm.ts | 79 ++++---- js/src/provider/provider.ts | 5 +- js/src/server/server.ts | 8 +- pkg/client/gen/openapi.json | 20 +- pkg/client/generated-client.go | 11 +- 13 files changed, 244 insertions(+), 263 deletions(-) diff --git a/internal/tui/app/app.go b/internal/tui/app/app.go index b00a6d61..8320d815 100644 --- a/internal/tui/app/app.go +++ b/internal/tui/app/app.go @@ -3,7 +3,6 @@ package app import ( "context" "fmt" - "sync" "log/slog" @@ -20,20 +19,14 @@ import ( type App struct { Client *client.ClientWithResponses Events *client.Client + Provider *client.ProviderInfo + Model *client.ProviderModel Session *client.SessionInfo Messages []client.MessageInfo - - LogsOLD any // TODO: Define LogService interface when needed - HistoryOLD any // TODO: Define HistoryService interface when needed - PermissionsOLD any // TODO: Define PermissionService interface when needed - Status status.Service + Status status.Service PrimaryAgentOLD AgentService - watcherCancelFuncs []context.CancelFunc - cancelFuncsMutex sync.Mutex - watcherWG sync.WaitGroup - // UI state filepickerOpen bool completionDialogOpen bool @@ -70,13 +63,9 @@ func New(ctx context.Context) (*App, error) { Client: httpClient, Events: eventClient, Session: &client.SessionInfo{}, + Messages: []client.MessageInfo{}, PrimaryAgentOLD: agentBridge, Status: status.GetService(), - - // TODO: These services need API endpoints: - LogsOLD: nil, // logging.GetService(), - HistoryOLD: nil, // history.GetService(), - PermissionsOLD: nil, // permission.GetService(), } // Initialize theme based on configuration @@ -128,13 +117,12 @@ func (a *App) SendChatMessage(ctx context.Context, text string, attachments []At go a.Client.PostSessionChatWithResponse(ctx, client.PostSessionChatJSONRequestBody{ SessionID: a.Session.Id, Parts: parts, - ProviderID: "anthropic", - ModelID: "claude-sonnet-4-20250514", + ProviderID: a.Provider.Id, + ModelID: a.Model.Id, }) // The actual response will come through SSE // For now, just return success - return tea.Batch(cmds...) } @@ -169,6 +157,22 @@ func (a *App) ListMessages(ctx context.Context, sessionId string) ([]client.Mess return messages, nil } +func (a *App) ListProviders(ctx context.Context) ([]client.ProviderInfo, error) { + resp, err := a.Client.PostProviderListWithResponse(ctx) + if err != nil { + return nil, err + } + if resp.StatusCode() != 200 { + return nil, fmt.Errorf("failed to list sessions: %d", resp.StatusCode()) + } + if resp.JSON200 == nil { + return []client.ProviderInfo{}, nil + } + + providers := *resp.JSON200 + return providers, nil +} + // initTheme sets the application theme based on the configuration func (app *App) initTheme() { cfg := config.Get() @@ -207,11 +211,5 @@ func (app *App) SetCompletionDialogOpen(open bool) { // Shutdown performs a clean shutdown of the application func (app *App) Shutdown() { - // Cancel all watcher goroutines - app.cancelFuncsMutex.Lock() - for _, cancel := range app.watcherCancelFuncs { - cancel() - } - app.cancelFuncsMutex.Unlock() - app.watcherWG.Wait() + // TODO: cleanup? } diff --git a/internal/tui/app/bridge.go b/internal/tui/app/bridge.go index 3e8768ac..cd149f6b 100644 --- a/internal/tui/app/bridge.go +++ b/internal/tui/app/bridge.go @@ -17,33 +17,6 @@ func NewAgentServiceBridge(client *client.ClientWithResponses) *AgentServiceBrid return &AgentServiceBridge{client: client} } -// Run sends a message to the chat API -func (a *AgentServiceBridge) Run(ctx context.Context, sessionID string, text string, attachments ...Attachment) (string, error) { - // TODO: Handle attachments when API supports them - if len(attachments) > 0 { - // For now, ignore attachments - // return "", fmt.Errorf("attachments not supported yet") - } - - part := client.MessagePart{} - part.FromMessagePartText(client.MessagePartText{ - Type: "text", - Text: text, - }) - parts := []client.MessagePart{part} - - go a.client.PostSessionChatWithResponse(ctx, client.PostSessionChatJSONRequestBody{ - SessionID: sessionID, - Parts: parts, - ProviderID: "anthropic", - ModelID: "claude-sonnet-4-20250514", - }) - - // The actual response will come through SSE - // For now, just return success - return "", nil -} - // Cancel cancels the current generation - NOT IMPLEMENTED IN API YET func (a *AgentServiceBridge) Cancel(sessionID string) error { // TODO: Not implemented in TypeScript API yet diff --git a/internal/tui/app/interfaces.go b/internal/tui/app/interfaces.go index 4cc9b802..a396ef58 100644 --- a/internal/tui/app/interfaces.go +++ b/internal/tui/app/interfaces.go @@ -6,7 +6,6 @@ import ( // AgentService defines the interface for agent operations type AgentService interface { - Run(ctx context.Context, sessionID string, text string, attachments ...Attachment) (string, error) Cancel(sessionID string) error IsBusy() bool IsSessionBusy(sessionID string) bool diff --git a/internal/tui/components/core/status.go b/internal/tui/components/core/status.go index fd782ee4..18a0ad6b 100644 --- a/internal/tui/components/core/status.go +++ b/internal/tui/components/core/status.go @@ -335,7 +335,10 @@ func (m *statusCmp) projectDiagnostics() string { func (m statusCmp) model() string { t := theme.CurrentTheme() - model := "Claude Sonnet 4" // models.SupportedModels[coder.Model] + model := "None" + if m.app.Model != nil { + model = *m.app.Model.Name + } return styles.Padded(). Background(t.Secondary()). diff --git a/internal/tui/components/dialog/models.go b/internal/tui/components/dialog/models.go index 1d8c4f49..2dd1e2fe 100644 --- a/internal/tui/components/dialog/models.go +++ b/internal/tui/components/dialog/models.go @@ -1,14 +1,18 @@ package dialog import ( + "context" + "fmt" + "github.com/charmbracelet/bubbles/key" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" - "github.com/sst/opencode/internal/config" + "github.com/sst/opencode/internal/tui/app" "github.com/sst/opencode/internal/tui/layout" "github.com/sst/opencode/internal/tui/styles" "github.com/sst/opencode/internal/tui/theme" "github.com/sst/opencode/internal/tui/util" + "github.com/sst/opencode/pkg/client" ) const ( @@ -16,24 +20,25 @@ const ( maxDialogWidth = 40 ) -// ModelSelectedMsg is sent when a model is selected -type ModelSelectedMsg struct { - // Model models.Model -} - // CloseModelDialogMsg is sent when a model is selected -type CloseModelDialogMsg struct{} +type CloseModelDialogMsg struct { + Provider *client.ProviderInfo + Model *client.ProviderModel +} // ModelDialog interface for the model selection dialog type ModelDialog interface { tea.Model layout.Bindings + + SetProviders(providers []client.ProviderInfo) } type modelDialogCmp struct { - // models []models.Model - // provider models.ModelProvider - // availableProviders []models.ModelProvider + app *app.App + availableProviders []client.ProviderInfo + provider client.ProviderInfo + model *client.ProviderModel selectedIdx int width int @@ -100,10 +105,28 @@ var modelKeys = modelKeyMap{ } func (m *modelDialogCmp) Init() tea.Cmd { - m.setupModels() + // cfg := config.Get() + // modelInfo := GetSelectedModel(cfg) + // m.availableProviders = getEnabledProviders(cfg) + // m.hScrollPossible = len(m.availableProviders) > 1 + + // m.provider = modelInfo.Provider + // m.hScrollOffset = findProviderIndex(m.availableProviders, m.provider) + + // m.setupModelsForProvider(m.provider) + + m.availableProviders, _ = m.app.ListProviders(context.Background()) + m.hScrollOffset = 0 + m.hScrollPossible = len(m.availableProviders) > 1 + m.provider = m.availableProviders[m.hScrollOffset] + return nil } +func (m *modelDialogCmp) SetProviders(providers []client.ProviderInfo) { + m.availableProviders = providers +} + func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { switch msg := msg.(type) { case tea.KeyMsg: @@ -121,7 +144,7 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.switchProvider(1) } case key.Matches(msg, modelKeys.Enter): - // return m, util.CmdHandler(ModelSelectedMsg{Model: m.models[m.selectedIdx]}) + return m, util.CmdHandler(CloseModelDialogMsg{Provider: &m.provider, Model: &m.provider.Models[m.selectedIdx]}) case key.Matches(msg, modelKeys.Escape): return m, util.CmdHandler(CloseModelDialogMsg{}) } @@ -138,8 +161,8 @@ func (m *modelDialogCmp) moveSelectionUp() { if m.selectedIdx > 0 { m.selectedIdx-- } else { - // m.selectedIdx = len(m.models) - 1 - // m.scrollOffset = max(0, len(m.models)-numVisibleModels) + m.selectedIdx = len(m.provider.Models) - 1 + m.scrollOffset = max(0, len(m.provider.Models)-numVisibleModels) } // Keep selection visible @@ -150,12 +173,12 @@ func (m *modelDialogCmp) moveSelectionUp() { // moveSelectionDown moves the selection down or wraps to top func (m *modelDialogCmp) moveSelectionDown() { - // if m.selectedIdx < len(m.models)-1 { - // m.selectedIdx++ - // } else { - // m.selectedIdx = 0 - // m.scrollOffset = 0 - // } + if m.selectedIdx < len(m.provider.Models)-1 { + m.selectedIdx++ + } else { + m.selectedIdx = 0 + m.scrollOffset = 0 + } // Keep selection visible if m.selectedIdx >= m.scrollOffset+numVisibleModels { @@ -167,16 +190,16 @@ func (m *modelDialogCmp) switchProvider(offset int) { newOffset := m.hScrollOffset + offset // Ensure we stay within bounds - // if newOffset < 0 { - // newOffset = len(m.availableProviders) - 1 - // } - // if newOffset >= len(m.availableProviders) { - // newOffset = 0 - // } + if newOffset < 0 { + newOffset = len(m.availableProviders) - 1 + } + if newOffset >= len(m.availableProviders) { + newOffset = 0 + } m.hScrollOffset = newOffset - // m.provider = m.availableProviders[m.hScrollOffset] - // m.setupModelsForProvider(m.provider) + m.provider = m.availableProviders[m.hScrollOffset] + m.setupModelsForProvider(m.provider.Id) } func (m *modelDialogCmp) View() string { @@ -184,33 +207,32 @@ func (m *modelDialogCmp) View() string { baseStyle := styles.BaseStyle() // Capitalize first letter of provider name - // providerName := strings.ToUpper(string(m.provider)[:1]) + string(m.provider[1:]) - // title := baseStyle. - // Foreground(t.Primary()). - // Bold(true). - // Width(maxDialogWidth). - // Padding(0, 0, 1). - // Render(fmt.Sprintf("Select %s Model", providerName)) + title := baseStyle. + Foreground(t.Primary()). + Bold(true). + Width(maxDialogWidth). + Padding(0, 0, 1). + Render(fmt.Sprintf("Select %s Model", m.provider.Name)) // Render visible models - // endIdx := min(m.scrollOffset+numVisibleModels, len(m.models)) - // modelItems := make([]string, 0, endIdx-m.scrollOffset) - // - // for i := m.scrollOffset; i < endIdx; i++ { - // itemStyle := baseStyle.Width(maxDialogWidth) - // if i == m.selectedIdx { - // itemStyle = itemStyle.Background(t.Primary()). - // Foreground(t.Background()).Bold(true) - // } - // modelItems = append(modelItems, itemStyle.Render(m.models[i].Name)) - // } + endIdx := min(m.scrollOffset+numVisibleModels, len(m.provider.Models)) + modelItems := make([]string, 0, endIdx-m.scrollOffset) + + for i := m.scrollOffset; i < endIdx; i++ { + itemStyle := baseStyle.Width(maxDialogWidth) + if i == m.selectedIdx { + itemStyle = itemStyle.Background(t.Primary()). + Foreground(t.Background()).Bold(true) + } + modelItems = append(modelItems, itemStyle.Render(*m.provider.Models[i].Name)) + } scrollIndicator := m.getScrollIndicators(maxDialogWidth) content := lipgloss.JoinVertical( lipgloss.Left, - // title, - // baseStyle.Width(maxDialogWidth).Render(lipgloss.JoinVertical(lipgloss.Left, modelItems...)), + title, + baseStyle.Width(maxDialogWidth).Render(lipgloss.JoinVertical(lipgloss.Left, modelItems...)), scrollIndicator, ) @@ -225,22 +247,22 @@ func (m *modelDialogCmp) View() string { func (m *modelDialogCmp) getScrollIndicators(maxWidth int) string { var indicator string - // if len(m.models) > numVisibleModels { - // if m.scrollOffset > 0 { - // indicator += "↑ " - // } - // if m.scrollOffset+numVisibleModels < len(m.models) { - // indicator += "↓ " - // } - // } + if len(m.provider.Models) > numVisibleModels { + if m.scrollOffset > 0 { + indicator += "↑ " + } + if m.scrollOffset+numVisibleModels < len(m.provider.Models) { + indicator += "↓ " + } + } if m.hScrollPossible { if m.hScrollOffset > 0 { indicator = "← " + indicator } - // if m.hScrollOffset < len(m.availableProviders)-1 { - // indicator += "→" - // } + if m.hScrollOffset < len(m.availableProviders)-1 { + indicator += "→" + } } if indicator == "" { @@ -262,70 +284,26 @@ func (m *modelDialogCmp) BindingKeys() []key.Binding { return layout.KeyMapToSlice(modelKeys) } -func (m *modelDialogCmp) setupModels() { - // cfg := config.Get() - // modelInfo := GetSelectedModel(cfg) - // m.availableProviders = getEnabledProviders(cfg) - // m.hScrollPossible = len(m.availableProviders) > 1 - // - // m.provider = modelInfo.Provider - // m.hScrollOffset = findProviderIndex(m.availableProviders, m.provider) - // - // m.setupModelsForProvider(m.provider) -} - -func GetSelectedModel(cfg *config.Config) string { - return "Claude Sonnet 4" - // agentCfg := cfg.Agents[config.AgentPrimary] - // selectedModelId := agentCfg.Model - // return models.SupportedModels[selectedModelId] -} - -func getEnabledProviders(cfg *config.Config) []string { - return []string{"anthropic", "openai", "google"} - // var providers []models.ModelProvider - // for providerId, provider := range cfg.Providers { - // if !provider.Disabled { - // providers = append(providers, providerId) - // } - // } - // - // // Sort by provider popularity - // slices.SortFunc(providers, func(a, b models.ModelProvider) int { - // rA := models.ProviderPopularity[a] - // rB := models.ProviderPopularity[b] - // - // // models not included in popularity ranking default to last - // if rA == 0 { - // rA = 999 - // } - // if rB == 0 { - // rB = 999 - // } - // return rA - rB - // }) - // return providers -} - // findProviderIndex returns the index of the provider in the list, or -1 if not found -func findProviderIndex(providers []string, provider string) int { - for i, p := range providers { - if p == provider { - return i - } - } - return -1 -} +// func findProviderIndex(providers []string, provider string) int { +// for i, p := range providers { +// if p == provider { +// return i +// } +// } +// return -1 +// } + +func (m *modelDialogCmp) setupModelsForProvider(_ string) { + m.selectedIdx = 0 + m.scrollOffset = 0 -func (m *modelDialogCmp) setupModelsForProvider(provider string) { // cfg := config.Get() // agentCfg := cfg.Agents[config.AgentPrimary] // selectedModelId := agentCfg.Model // m.provider = provider // m.models = getModelsForProvider(provider) - m.selectedIdx = 0 - m.scrollOffset = 0 // Try to select the current model if it belongs to this provider // if provider == models.SupportedModels[selectedModelId].Provider { @@ -342,28 +320,8 @@ func (m *modelDialogCmp) setupModelsForProvider(provider string) { // } } -func getModelsForProvider(provider string) []string { - return []string{"Claude Sonnet 4"} - // var providerModels []models.Model - // for _, model := range models.SupportedModels { - // if model.Provider == provider { - // providerModels = append(providerModels, model) - // } - // } - - // reverse alphabetical order (if llm naming was consistent latest would appear first) - // slices.SortFunc(providerModels, func(a, b models.Model) int { - // if a.Name > b.Name { - // return -1 - // } else if a.Name < b.Name { - // return 1 - // } - // return 0 - // }) - - // return providerModels -} - -func NewModelDialogCmp() ModelDialog { - return &modelDialogCmp{} +func NewModelDialogCmp(app *app.App) ModelDialog { + return &modelDialogCmp{ + app: app, + } } diff --git a/internal/tui/state/state.go b/internal/tui/state/state.go index 33b1cc7f..6b117518 100644 --- a/internal/tui/state/state.go +++ b/internal/tui/state/state.go @@ -5,8 +5,15 @@ import ( ) type SessionSelectedMsg = *client.SessionInfo +type ModelSelectedMsg struct { + Provider client.ProviderInfo + Model client.ProviderModel +} + type SessionClearedMsg struct{} type CompactSessionMsg struct{} + +// TODO: remove type StateUpdatedMsg struct { State map[string]any } diff --git a/internal/tui/tui.go b/internal/tui/tui.go index 024abe00..1b8d4c08 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -168,16 +168,27 @@ func (a appModel) Init() tea.Cmd { return dialog.ShowInitDialogMsg{Show: shouldShow} }) + cmds = append(cmds, func() tea.Msg { + providers, _ := a.app.ListProviders(context.Background()) + return state.ModelSelectedMsg{Provider: providers[0], Model: providers[0].Models[0]} + }) + return tea.Batch(cmds...) } func (a appModel) updateAllPages(msg tea.Msg) (tea.Model, tea.Cmd) { var cmds []tea.Cmd var cmd tea.Cmd + for id := range a.pages { a.pages[id], cmd = a.pages[id].Update(msg) cmds = append(cmds, cmd) } + + s, cmd := a.status.Update(msg) + cmds = append(cmds, cmd) + a.status = s.(core.StatusCmp) + return a, tea.Batch(cmds...) } @@ -201,12 +212,10 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { for i, m := range a.app.Messages { if m.Id == msg.Properties.Info.Id { a.app.Messages[i] = msg.Properties.Info - slog.Debug("Updated message", "message", msg.Properties.Info) return a.updateAllPages(state.StateUpdatedMsg{State: nil}) } } a.app.Messages = append(a.app.Messages, msg.Properties.Info) - slog.Debug("Appended message", "message", msg.Properties.Info) return a.updateAllPages(state.StateUpdatedMsg{State: nil}) } @@ -287,6 +296,19 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { a.app.Messages, _ = a.app.ListMessages(context.Background(), msg.Id) return a.updateAllPages(msg) + case dialog.CloseModelDialogMsg: + a.showModelDialog = false + slog.Debug("closing model dialog", "msg", msg) + if msg.Provider != nil && msg.Model != nil { + return a, util.CmdHandler(state.ModelSelectedMsg{Provider: *msg.Provider, Model: *msg.Model}) + } + return a, nil + + case state.ModelSelectedMsg: + a.app.Provider = &msg.Provider + a.app.Model = &msg.Model + return a.updateAllPages(msg) + case dialog.CloseCommandDialogMsg: a.showCommandDialog = false return a, nil @@ -309,24 +331,6 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { status.Info("Theme changed to: " + msg.ThemeName) return a, cmd - case dialog.CloseModelDialogMsg: - a.showModelDialog = false - return a, nil - - case dialog.ModelSelectedMsg: - a.showModelDialog = false - - // TODO: Agent model update not implemented in API yet - // model, err := a.app.PrimaryAgent.Update(config.AgentPrimary, msg.Model.ID) - // if err != nil { - // status.Error(err.Error()) - // return a, nil - // } - - // status.Info(fmt.Sprintf("Model changed to %s", model.Name)) - status.Info("Model selection not implemented in API yet") - return a, nil - case dialog.ShowInitDialogMsg: a.showInitDialog = msg.Show return a, nil @@ -476,6 +480,18 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { a.showThemeDialog = false a.showFilepicker = false + // Load providers and show the dialog + providers, err := a.app.ListProviders(context.Background()) + if err != nil { + status.Error(err.Error()) + return a, nil + } + if len(providers) == 0 { + status.Warn("No providers available") + return a, nil + } + a.modelDialog.SetProviders(providers) + a.showModelDialog = true return a, nil } @@ -907,7 +923,7 @@ func New(app *app.App) tea.Model { quit: dialog.NewQuitCmp(), sessionDialog: dialog.NewSessionDialogCmp(), commandDialog: dialog.NewCommandDialogCmp(), - modelDialog: dialog.NewModelDialogCmp(), + modelDialog: dialog.NewModelDialogCmp(app), permissions: dialog.NewPermissionDialogCmp(), initDialog: dialog.NewInitDialogCmp(), themeDialog: dialog.NewThemeDialogCmp(), diff --git a/js/src/config/config.ts b/js/src/config/config.ts index 76181b79..8c374b3f 100644 --- a/js/src/config/config.ts +++ b/js/src/config/config.ts @@ -14,7 +14,7 @@ export namespace Config { export const Info = z .object({ - providers: z.record(z.string(), Provider.Info).optional(), + providers: Provider.Info.array().optional(), }) .strict(); diff --git a/js/src/llm/llm.ts b/js/src/llm/llm.ts index 990c3006..a7d31fb9 100644 --- a/js/src/llm/llm.ts +++ b/js/src/llm/llm.ts @@ -1,6 +1,6 @@ import { App } from "../app/app"; import { Log } from "../util/log"; -import { mergeDeep } from "remeda"; +import { concat } from "remeda"; import path from "path"; import { Provider } from "../provider/provider"; @@ -19,26 +19,32 @@ export namespace LLM { } } - const NATIVE_PROVIDERS: Record = { - anthropic: { - models: { - "claude-sonnet-4-20250514": { - name: "Claude 4 Sonnet", + const NATIVE_PROVIDERS: Provider.Info[] = [ + { + id: "anthropic", + name: "Anthropic", + models: [ + { + id: "claude-sonnet-4-20250514", + name: "Claude Sonnet 4", cost: { input: 3.0 / 1_000_000, output: 15.0 / 1_000_000, inputCached: 3.75 / 1_000_000, outputCached: 0.3 / 1_000_000, }, - contextWindow: 200000, - maxTokens: 50000, + contextWindow: 200_000, + maxOutputTokens: 50_000, attachment: true, }, - }, + ], }, - openai: { - models: { - "codex-mini-latest": { + { + id: "openai", + name: "OpenAI", + models: [ + { + id: "codex-mini-latest", name: "Codex Mini", cost: { input: 1.5 / 1_000_000, @@ -46,16 +52,19 @@ export namespace LLM { output: 6.0 / 1_000_000, outputCached: 0.0 / 1_000_000, }, - contextWindow: 200000, - maxTokens: 100000, + contextWindow: 200_000, + maxOutputTokens: 100_000, attachment: true, reasoning: true, }, - }, + ], }, - google: { - models: { - "gemini-2.5-pro-preview-03-25": { + { + id: "google", + name: "Google", + models: [ + { + id: "gemini-2.5-pro-preview-03-25", name: "Gemini 2.5 Pro", cost: { input: 1.25 / 1_000_000, @@ -63,18 +72,18 @@ export namespace LLM { output: 10 / 1_000_000, outputCached: 0 / 1_000_000, }, - contextWindow: 1000000, - maxTokens: 50000, + contextWindow: 1_000_000, + maxOutputTokens: 50_000, attachment: true, }, - }, + ], }, - }; + ]; const AUTODETECT: Record = { anthropic: ["ANTHROPIC_API_KEY"], openai: ["OPENAI_API_KEY"], - google: ["GOOGLE_GENERATIVE_AI_API_KEY"], + google: ["GOOGLE_GENERATIVE_AI_API_KEY", "GEMINI_API_KEY"], }; const state = App.state("llm", async () => { @@ -91,33 +100,33 @@ export namespace LLM { { info: Provider.Model; instance: LanguageModel } >(); - const list = mergeDeep(NATIVE_PROVIDERS, config.providers ?? {}); + const list = concat(NATIVE_PROVIDERS, config.providers ?? []); - for (const [providerID, providerInfo] of Object.entries(list)) { + for (const provider of list) { if ( - !config.providers?.[providerID] && - !AUTODETECT[providerID]?.some((env) => process.env[env]) + !config.providers?.find((p) => p.id === provider.id) && + !AUTODETECT[provider.id]?.some((env) => process.env[env]) ) continue; const dir = path.join( Global.cache(), `node_modules`, `@ai-sdk`, - providerID, + provider.id, ); if (!(await Bun.file(path.join(dir, "package.json")).exists())) { - BunProc.run(["add", "--exact", `@ai-sdk/${providerID}@alpha`], { + BunProc.run(["add", "--exact", `@ai-sdk/${provider.id}@alpha`], { cwd: Global.cache(), }); } const mod = await import( - path.join(Global.cache(), `node_modules`, `@ai-sdk`, providerID) + path.join(Global.cache(), `node_modules`, `@ai-sdk`, provider.id) ); const fn = mod[Object.keys(mod).find((key) => key.startsWith("create"))!]; - const loaded = fn(providerInfo.options); - log.info("loaded", { provider: providerID }); - providers[providerID] = { - info: providerInfo, + const loaded = fn(provider.options); + log.info("loaded", { provider: provider.id }); + providers[provider.id] = { + info: provider, instance: loaded, }; } @@ -142,7 +151,7 @@ export namespace LLM { providerID, modelID, }); - const info = provider.info.models[modelID]; + const info = provider.info.models.find((m) => m.id === modelID); if (!info) throw new ModelNotFoundError(modelID); try { const match = provider.instance.languageModel(modelID); diff --git a/js/src/provider/provider.ts b/js/src/provider/provider.ts index a35645e6..d4719ffb 100644 --- a/js/src/provider/provider.ts +++ b/js/src/provider/provider.ts @@ -3,6 +3,7 @@ import z from "zod"; export namespace Provider { export const Model = z .object({ + id: z.string(), name: z.string().optional(), cost: z.object({ input: z.number(), @@ -22,8 +23,10 @@ export namespace Provider { export const Info = z .object({ + id: z.string(), + name: z.string(), options: z.record(z.string(), z.any()).optional(), - models: z.record(z.string(), Model), + models: Model.array(), }) .openapi({ ref: "Provider.Info", diff --git a/js/src/server/server.ts b/js/src/server/server.ts index b93ca5a6..28591cbd 100644 --- a/js/src/server/server.ts +++ b/js/src/server/server.ts @@ -263,7 +263,7 @@ export namespace Server { description: "List of providers", content: { "application/json": { - schema: resolver(z.record(z.string(), Provider.Info)), + schema: resolver(Provider.Info.array()), }, }, }, @@ -271,9 +271,9 @@ export namespace Server { }), async (c) => { const providers = await LLM.providers(); - const result: Record = {}; - for (const [providerID, provider] of Object.entries(providers)) { - result[providerID] = provider.info; + const result = [] as (Provider.Info & { key: string })[]; + for (const [key, provider] of Object.entries(providers)) { + result.push({ ...provider.info, key }); } return c.json(result); }, diff --git a/pkg/client/gen/openapi.json b/pkg/client/gen/openapi.json index 43e33b97..c5f1c64b 100644 --- a/pkg/client/gen/openapi.json +++ b/pkg/client/gen/openapi.json @@ -280,8 +280,8 @@ "content": { "application/json": { "schema": { - "type": "object", - "additionalProperties": { + "type": "array", + "items": { "$ref": "#/components/schemas/Provider.Info" } } @@ -818,24 +818,35 @@ "Provider.Info": { "type": "object", "properties": { + "id": { + "type": "string" + }, + "name": { + "type": "string" + }, "options": { "type": "object", "additionalProperties": {} }, "models": { - "type": "object", - "additionalProperties": { + "type": "array", + "items": { "$ref": "#/components/schemas/Provider.Model" } } }, "required": [ + "id", + "name", "models" ] }, "Provider.Model": { "type": "object", "properties": { + "id": { + "type": "string" + }, "name": { "type": "string" }, @@ -876,6 +887,7 @@ } }, "required": [ + "id", "cost", "contextWindow", "attachment" diff --git a/pkg/client/generated-client.go b/pkg/client/generated-client.go index e8fbc388..925468a6 100644 --- a/pkg/client/generated-client.go +++ b/pkg/client/generated-client.go @@ -173,8 +173,10 @@ type MessageToolInvocationToolResult struct { // ProviderInfo defines model for Provider.Info. type ProviderInfo struct { - Models map[string]ProviderModel `json:"models"` - Options *map[string]interface{} `json:"options,omitempty"` + Id string `json:"id"` + Models []ProviderModel `json:"models"` + Name string `json:"name"` + Options *map[string]interface{} `json:"options,omitempty"` } // ProviderModel defines model for Provider.Model. @@ -187,6 +189,7 @@ type ProviderModel struct { Output float32 `json:"output"` OutputCached float32 `json:"outputCached"` } `json:"cost"` + Id string `json:"id"` MaxOutputTokens *float32 `json:"maxOutputTokens,omitempty"` Name *string `json:"name,omitempty"` Reasoning *bool `json:"reasoning,omitempty"` @@ -1421,7 +1424,7 @@ func (r GetEventResponse) StatusCode() int { type PostProviderListResponse struct { Body []byte HTTPResponse *http.Response - JSON200 *map[string]ProviderInfo + JSON200 *[]ProviderInfo } // Status returns HTTPResponse.Status @@ -1756,7 +1759,7 @@ func ParsePostProviderListResponse(rsp *http.Response) (*PostProviderListRespons switch { case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: - var dest map[string]ProviderInfo + var dest []ProviderInfo if err := json.Unmarshal(bodyBytes, &dest); err != nil { return nil, err }