From 005b8ac16776512b2d4b1f22bd989da162ca1bad Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 24 Mar 2025 11:47:39 +0100 Subject: [PATCH] initial working agent --- cmd/root.go | 7 +-- internal/llm/agent/title.go | 31 ++++++++++ internal/llm/llm.go | 31 +++++++++- internal/llm/models/models.go | 48 ++++++++++++--- internal/tui/components/repl/messages.go | 75 ++++++++++++++++++++++-- internal/tui/components/repl/sessions.go | 31 +++++++++- 6 files changed, 201 insertions(+), 22 deletions(-) create mode 100644 internal/llm/agent/title.go diff --git a/cmd/root.go b/cmd/root.go index 3879a3cf..9e5ddbd7 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -109,8 +109,6 @@ func setupSubscriptions(app *app.App) (chan tea.Msg, func()) { } } -// Execute adds all child commands to the root command and sets flags appropriately. -// This is called by main.main(). It only needs to happen once to the rootCmd. func Execute() { err := rootCmd.Execute() if err != nil { @@ -131,13 +129,14 @@ func loadConfig() { // LLM viper.SetDefault("models.big", string(models.DefaultBigModel)) - viper.SetDefault("models.little", string(models.DefaultLittleModel)) + viper.SetDefault("models.small", string(models.DefaultLittleModel)) viper.SetDefault("providers.openai.key", os.Getenv("OPENAI_API_KEY")) viper.SetDefault("providers.anthropic.key", os.Getenv("ANTHROPIC_API_KEY")) + viper.SetDefault("providers.groq.key", os.Getenv("GROQ_API_KEY")) viper.SetDefault("providers.common.max_tokens", 4000) viper.SetDefault("agents.default", "coder") - // + viper.ReadInConfig() workdir, err := os.Getwd() diff --git a/internal/llm/agent/title.go b/internal/llm/agent/title.go new file mode 100644 index 00000000..1b9840cc --- /dev/null +++ b/internal/llm/agent/title.go @@ -0,0 +1,31 @@ +package agent + +import ( + "context" + + "github.com/cloudwego/eino/schema" + "github.com/kujtimiihoxha/termai/internal/llm/models" + "github.com/spf13/viper" +) + +func GenerateTitle(ctx context.Context, content string) (string, error) { + model, err := models.GetModel(ctx, models.ModelID(viper.GetString("models.small"))) + if err != nil { + return "", err + } + out, err := model.Generate( + ctx, + []*schema.Message{ + schema.SystemMessage(`- you will generate a short title based on the first message a user begins a conversation with + - ensure it is not more than 80 characters long + - the title should be a summary of the user's message + - do not use quotes or colons + - the entire text you return will be used as the title`), + schema.UserMessage(content), + }, + ) + if err != nil { + return "", err + } + return out.Content, nil +} diff --git a/internal/llm/llm.go b/internal/llm/llm.go index bbf9961e..2f87b225 100644 --- a/internal/llm/llm.go +++ b/internal/llm/llm.go @@ -11,6 +11,7 @@ import ( "github.com/cloudwego/eino/schema" "github.com/google/uuid" "github.com/kujtimiihoxha/termai/internal/llm/agent" + "github.com/kujtimiihoxha/termai/internal/llm/models" "github.com/kujtimiihoxha/termai/internal/logging" "github.com/kujtimiihoxha/termai/internal/message" "github.com/kujtimiihoxha/termai/internal/pubsub" @@ -88,7 +89,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) { } log.Printf("Request: %s", content) - agent, systemMessage, err := agent.GetAgent(s.ctx, viper.GetString("agents.default")) + currentAgent, systemMessage, err := agent.GetAgent(s.ctx, viper.GetString("agents.default")) if err != nil { s.Publish(AgentErrorEvent, AgentEvent{ ID: id, @@ -110,6 +111,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) { for _, m := range history { messages = append(messages, &m.MessageData) } + builder := callbacks.NewHandlerBuilder() builder.OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { i, ok := input.(*eModel.CallbackInput) @@ -140,7 +142,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) { return ctx }) - out, err := agent.Generate(s.ctx, messages, enioAgent.WithComposeOptions(compose.WithCallbacks(builder.Build()))) + out, err := currentAgent.Generate(s.ctx, messages, enioAgent.WithComposeOptions(compose.WithCallbacks(builder.Build()))) if err != nil { s.Publish(AgentErrorEvent, AgentEvent{ ID: id, @@ -153,6 +155,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) { return } usage := out.ResponseMeta.Usage + s.messages.Create(sessionID, *out) if usage != nil { log.Printf("Prompt Tokens: %d, Completion Tokens: %d, Total Tokens: %d", usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens) session, err := s.sessions.Get(sessionID) @@ -170,6 +173,29 @@ func (s *service) handleRequest(id string, sessionID string, content string) { session.PromptTokens += int64(usage.PromptTokens) session.CompletionTokens += int64(usage.CompletionTokens) // TODO: calculate cost + model := models.SupportedModels[models.ModelID(viper.GetString("models.big"))] + session.Cost += float64(usage.PromptTokens)*(model.CostPer1MIn/1_000_000) + + float64(usage.CompletionTokens)*(model.CostPer1MOut/1_000_000) + var newTitle string + if len(history) == 1 { + // first message generate the title + newTitle, err = agent.GenerateTitle(s.ctx, content) + if err != nil { + s.Publish(AgentErrorEvent, AgentEvent{ + ID: id, + Type: AgentMessageTypeError, + AgentID: RootAgent, + MessageID: "", + SessionID: sessionID, + Content: err.Error(), + }) + return + } + } + if newTitle != "" { + session.Title = newTitle + } + _, err = s.sessions.Save(session) if err != nil { s.Publish(AgentErrorEvent, AgentEvent{ @@ -183,7 +209,6 @@ func (s *service) handleRequest(id string, sessionID string, content string) { return } } - s.messages.Create(sessionID, *out) } func (s *service) SendRequest(sessionID string, content string) { diff --git a/internal/llm/models/models.go b/internal/llm/models/models.go index 1895e256..e59da194 100644 --- a/internal/llm/models/models.go +++ b/internal/llm/models/models.go @@ -3,6 +3,7 @@ package models import ( "context" "errors" + "log" "github.com/cloudwego/eino-ext/components/model/claude" "github.com/cloudwego/eino-ext/components/model/openai" @@ -16,10 +17,12 @@ type ( ) type Model struct { - ID ModelID `json:"id"` - Name string `json:"name"` - Provider ModelProvider `json:"provider"` - APIModel string `json:"api_model"` // Actual value used when calling the API + ID ModelID `json:"id"` + Name string `json:"name"` + Provider ModelProvider `json:"provider"` + APIModel string `json:"api_model"` + CostPer1MIn float64 `json:"cost_per_1m_in"` + CostPer1MOut float64 `json:"cost_per_1m_out"` } const ( @@ -52,6 +55,9 @@ const ( // Meta Llama3 ModelID = "llama-3" Llama270B ModelID = "llama-2-70b" + // GROQ + GroqLlama3SpecDec ModelID = "groq-llama-3-spec-dec" + GroqQwen32BCoder ModelID = "qwen-2.5-coder-32b" ) const ( @@ -61,6 +67,7 @@ const ( ProviderXAI ModelProvider = "xai" ProviderDeepSeek ModelProvider = "deepseek" ProviderMeta ModelProvider = "meta" + ProviderGroq ModelProvider = "groq" ) var SupportedModels = map[ModelID]Model{ @@ -72,10 +79,12 @@ var SupportedModels = map[ModelID]Model{ APIModel: "gpt-4o", }, GPT4oMini: { - ID: GPT4oMini, - Name: "GPT-4o Mini", - Provider: ProviderOpenAI, - APIModel: "gpt-4o-mini", + ID: GPT4oMini, + Name: "GPT-4o Mini", + Provider: ProviderOpenAI, + APIModel: "gpt-4o-mini", + CostPer1MIn: 0.150, + CostPer1MOut: 0.600, }, GPT45: { ID: GPT45, @@ -172,10 +181,25 @@ var SupportedModels = map[ModelID]Model{ Provider: ProviderMeta, APIModel: "llama-2-70b", }, + + // GROQ + GroqLlama3SpecDec: { + ID: GroqLlama3SpecDec, + Name: "GROQ LLaMA 3 SpecDec", + Provider: ProviderGroq, + APIModel: "llama-3.3-70b-specdec", + }, + GroqQwen32BCoder: { + ID: GroqQwen32BCoder, + Name: "GROQ Qwen 2.5 Coder 32B", + Provider: ProviderGroq, + APIModel: "qwen-2.5-coder-32b", + }, } func GetModel(ctx context.Context, model ModelID) (model.ChatModel, error) { provider := SupportedModels[model].Provider + log.Printf("Provider: %s", provider) maxTokens := viper.GetInt("providers.common.max_tokens") switch provider { case ProviderOpenAI: @@ -191,6 +215,14 @@ func GetModel(ctx context.Context, model ModelID) (model.ChatModel, error) { MaxTokens: maxTokens, }) + case ProviderGroq: + return openai.NewChatModel(ctx, &openai.ChatModelConfig{ + BaseURL: "https://api.groq.com/openai/v1", + APIKey: viper.GetString("providers.groq.key"), + Model: string(SupportedModels[model].APIModel), + MaxTokens: &maxTokens, + }) + } return nil, errors.New("unsupported provider") } diff --git a/internal/tui/components/repl/messages.go b/internal/tui/components/repl/messages.go index 9d95d98f..feddf7bf 100644 --- a/internal/tui/components/repl/messages.go +++ b/internal/tui/components/repl/messages.go @@ -1,22 +1,33 @@ package repl import ( + "github.com/charmbracelet/bubbles/key" + "github.com/charmbracelet/bubbles/viewport" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" "github.com/kujtimiihoxha/termai/internal/app" "github.com/kujtimiihoxha/termai/internal/message" "github.com/kujtimiihoxha/termai/internal/pubsub" "github.com/kujtimiihoxha/termai/internal/session" + "github.com/kujtimiihoxha/termai/internal/tui/layout" ) +type MessagesCmp interface { + tea.Model + layout.Focusable + layout.Bordered + layout.Sizeable + layout.Bindings +} + type messagesCmp struct { app *app.App messages []message.Message session session.Session -} - -func (m *messagesCmp) Init() tea.Cmd { - return nil + viewport viewport.Model + width int + height int + focused bool } func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { @@ -25,6 +36,12 @@ func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if msg.Type == pubsub.CreatedEvent { m.messages = append(m.messages, msg.Payload) } + case pubsub.Event[session.Session]: + if msg.Type == pubsub.UpdatedEvent { + if m.session.ID == msg.Payload.ID { + m.session = msg.Payload + } + } case SelectedSessionMsg: m.session, _ = m.app.Sessions.Get(msg.SessionID) m.messages, _ = m.app.Messages.List(m.session.ID) @@ -40,7 +57,55 @@ func (i *messagesCmp) View() string { return lipgloss.JoinVertical(lipgloss.Top, stringMessages...) } -func NewMessagesCmp(app *app.App) tea.Model { +// BindingKeys implements MessagesCmp. +func (m *messagesCmp) BindingKeys() []key.Binding { + return []key.Binding{} +} + +// Blur implements MessagesCmp. +func (m *messagesCmp) Blur() tea.Cmd { + m.focused = false + return nil +} + +// BorderText implements MessagesCmp. +func (m *messagesCmp) BorderText() map[layout.BorderPosition]string { + title := m.session.Title + if len(title) > 20 { + title = title[:20] + "..." + } + return map[layout.BorderPosition]string{ + layout.TopLeftBorder: title, + } +} + +// Focus implements MessagesCmp. +func (m *messagesCmp) Focus() tea.Cmd { + m.focused = true + return nil +} + +// GetSize implements MessagesCmp. +func (m *messagesCmp) GetSize() (int, int) { + return m.width, m.height +} + +// IsFocused implements MessagesCmp. +func (m *messagesCmp) IsFocused() bool { + return m.focused +} + +// SetSize implements MessagesCmp. +func (m *messagesCmp) SetSize(width int, height int) { + m.width = width + m.height = height +} + +func (m *messagesCmp) Init() tea.Cmd { + return nil +} + +func NewMessagesCmp(app *app.App) MessagesCmp { return &messagesCmp{ app: app, messages: []message.Message{}, diff --git a/internal/tui/components/repl/sessions.go b/internal/tui/components/repl/sessions.go index 44c87002..5d2411fb 100644 --- a/internal/tui/components/repl/sessions.go +++ b/internal/tui/components/repl/sessions.go @@ -2,6 +2,7 @@ package repl import ( "fmt" + "strings" "github.com/charmbracelet/bubbles/key" "github.com/charmbracelet/bubbles/list" @@ -82,7 +83,7 @@ func (i *sessionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { items[i] = listItem{ id: s.ID, title: s.Title, - desc: fmt.Sprintf("Tokens: %d, Cost: %.2f", s.PromptTokens+s.CompletionTokens, s.Cost), + desc: formatTokensAndCost(s.PromptTokens+s.CompletionTokens, s.Cost), } } return i, i.list.SetItems(items) @@ -94,7 +95,7 @@ func (i *sessionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { s := item.(listItem) if s.id == msg.Payload.ID { s.title = msg.Payload.Title - s.desc = fmt.Sprintf("Tokens: %d, Cost: %.2f", msg.Payload.PromptTokens+msg.Payload.CompletionTokens, msg.Payload.Cost) + s.desc = formatTokensAndCost(msg.Payload.PromptTokens+msg.Payload.CompletionTokens, msg.Payload.Cost) items[idx] = s break } @@ -169,6 +170,32 @@ func (i *sessionsCmp) BindingKeys() []key.Binding { return append(layout.KeyMapToSlice(i.list.KeyMap), sessionKeyMapValue.Select) } +func formatTokensAndCost(tokens int64, cost float64) string { + // Format tokens in human-readable format (e.g., 110K, 1.2M) + var formattedTokens string + switch { + case tokens >= 1_000_000: + formattedTokens = fmt.Sprintf("%.1fM", float64(tokens)/1_000_000) + case tokens >= 1_000: + formattedTokens = fmt.Sprintf("%.1fK", float64(tokens)/1_000) + default: + formattedTokens = fmt.Sprintf("%d", tokens) + } + + // Remove .0 suffix if present + if strings.HasSuffix(formattedTokens, ".0K") { + formattedTokens = strings.Replace(formattedTokens, ".0K", "K", 1) + } + if strings.HasSuffix(formattedTokens, ".0M") { + formattedTokens = strings.Replace(formattedTokens, ".0M", "M", 1) + } + + // Format cost with $ symbol and 2 decimal places + formattedCost := fmt.Sprintf("$%.2f", cost) + + return fmt.Sprintf("Tokens: %s, Cost: %s", formattedTokens, formattedCost) +} + func NewSessionsCmp(app *app.App) SessionsCmp { listDelegate := list.NewDefaultDelegate() defaultItemStyle := list.NewDefaultItemStyles()