fix logs and add cancellation

This commit is contained in:
Kujtim Hoxha
2025-04-10 13:29:44 +02:00
parent 0b007b9c77
commit 36f201d5d3
23 changed files with 343 additions and 283 deletions

View File

@@ -2,6 +2,7 @@ package cmd
import ( import (
"context" "context"
"log/slog"
"os" "os"
"sync" "sync"
@@ -10,6 +11,7 @@ import (
"github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/config"
"github.com/kujtimiihoxha/termai/internal/db" "github.com/kujtimiihoxha/termai/internal/db"
"github.com/kujtimiihoxha/termai/internal/llm/agent" "github.com/kujtimiihoxha/termai/internal/llm/agent"
"github.com/kujtimiihoxha/termai/internal/logging"
"github.com/kujtimiihoxha/termai/internal/tui" "github.com/kujtimiihoxha/termai/internal/tui"
zone "github.com/lrstanley/bubblezone" zone "github.com/lrstanley/bubblezone"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@@ -26,6 +28,16 @@ var rootCmd = &cobra.Command{
} }
debug, _ := cmd.Flags().GetBool("debug") debug, _ := cmd.Flags().GetBool("debug")
err := config.Load(debug) err := config.Load(debug)
cfg := config.Get()
defaultLevel := slog.LevelInfo
if cfg.Debug {
defaultLevel = slog.LevelDebug
}
logger := slog.New(slog.NewTextHandler(logging.NewWriter(), &slog.HandlerOptions{
Level: defaultLevel,
}))
slog.SetDefault(logger)
if err != nil { if err != nil {
return err return err
} }
@@ -37,14 +49,14 @@ var rootCmd = &cobra.Command{
app := app.New(ctx, conn) app := app.New(ctx, conn)
defer app.Close() defer app.Close()
app.Logger.Info("Starting termai...") logging.Info("Starting termai...")
zone.NewGlobal() zone.NewGlobal()
tui := tea.NewProgram( tui := tea.NewProgram(
tui.New(app), tui.New(app),
tea.WithAltScreen(), tea.WithAltScreen(),
tea.WithMouseCellMotion(), tea.WithMouseCellMotion(),
) )
app.Logger.Info("Setting up subscriptions...") logging.Info("Setting up subscriptions...")
ch, unsub := setupSubscriptions(app) ch, unsub := setupSubscriptions(app)
defer unsub() defer unsub()
@@ -66,9 +78,8 @@ func setupSubscriptions(app *app.App) (chan tea.Msg, func()) {
ch := make(chan tea.Msg) ch := make(chan tea.Msg)
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
ctx, cancel := context.WithCancel(app.Context) ctx, cancel := context.WithCancel(app.Context)
{ {
sub := app.Logger.Subscribe(ctx) sub := logging.Subscribe(ctx)
wg.Add(1) wg.Add(1)
go func() { go func() {
for ev := range sub { for ev := range sub {

4
go.mod
View File

@@ -33,7 +33,7 @@ require (
github.com/spf13/cobra v1.9.1 github.com/spf13/cobra v1.9.1
github.com/spf13/viper v1.20.0 github.com/spf13/viper v1.20.0
github.com/stretchr/testify v1.10.0 github.com/stretchr/testify v1.10.0
golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 golang.org/x/net v0.34.0
google.golang.org/api v0.215.0 google.golang.org/api v0.215.0
) )
@@ -116,10 +116,10 @@ require (
go.uber.org/multierr v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect
golang.design/x/clipboard v0.7.0 // indirect golang.design/x/clipboard v0.7.0 // indirect
golang.org/x/crypto v0.33.0 // indirect golang.org/x/crypto v0.33.0 // indirect
golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 // indirect
golang.org/x/exp/shiny v0.0.0-20250305212735-054e65f0b394 // indirect golang.org/x/exp/shiny v0.0.0-20250305212735-054e65f0b394 // indirect
golang.org/x/image v0.14.0 // indirect golang.org/x/image v0.14.0 // indirect
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a // indirect golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a // indirect
golang.org/x/net v0.34.0 // indirect
golang.org/x/oauth2 v0.25.0 // indirect golang.org/x/oauth2 v0.25.0 // indirect
golang.org/x/sync v0.12.0 // indirect golang.org/x/sync v0.12.0 // indirect
golang.org/x/sys v0.31.0 // indirect golang.org/x/sys v0.31.0 // indirect

View File

@@ -3,6 +3,7 @@ package app
import ( import (
"context" "context"
"database/sql" "database/sql"
"log/slog"
"github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/config"
"github.com/kujtimiihoxha/termai/internal/db" "github.com/kujtimiihoxha/termai/internal/db"
@@ -23,16 +24,14 @@ type App struct {
LSPClients map[string]*lsp.Client LSPClients map[string]*lsp.Client
Logger logging.Interface
ceanups []func() ceanups []func()
} }
func New(ctx context.Context, conn *sql.DB) *App { func New(ctx context.Context, conn *sql.DB) *App {
cfg := config.Get() cfg := config.Get()
logging.Info("Debug mode enabled")
q := db.New(conn) q := db.New(conn)
log := logging.Get()
log.SetLevel(cfg.Log.Level)
sessions := session.NewService(ctx, q) sessions := session.NewService(ctx, q)
messages := message.NewService(ctx, q) messages := message.NewService(ctx, q)
@@ -41,7 +40,6 @@ func New(ctx context.Context, conn *sql.DB) *App {
Sessions: sessions, Sessions: sessions,
Messages: messages, Messages: messages,
Permissions: permission.NewPermissionService(), Permissions: permission.NewPermissionService(),
Logger: log,
LSPClients: make(map[string]*lsp.Client), LSPClients: make(map[string]*lsp.Client),
} }
@@ -52,13 +50,13 @@ func New(ctx context.Context, conn *sql.DB) *App {
}) })
workspaceWatcher := watcher.NewWorkspaceWatcher(lspClient) workspaceWatcher := watcher.NewWorkspaceWatcher(lspClient)
if err != nil { if err != nil {
log.Error("Failed to create LSP client for", name, err) logging.Error("Failed to create LSP client for", name, err)
continue continue
} }
_, err = lspClient.InitializeLSPClient(ctx, config.WorkingDirectory()) _, err = lspClient.InitializeLSPClient(ctx, config.WorkingDirectory())
if err != nil { if err != nil {
log.Error("Initialize failed", "error", err) logging.Error("Initialize failed", "error", err)
continue continue
} }
go workspaceWatcher.WatchWorkspace(ctx, config.WorkingDirectory()) go workspaceWatcher.WatchWorkspace(ctx, config.WorkingDirectory())
@@ -74,5 +72,5 @@ func (a *App) Close() {
for _, client := range a.LSPClients { for _, client := range a.LSPClients {
client.Close() client.Close()
} }
a.Logger.Info("App closed") slog.Info("App closed")
} }

View File

@@ -16,8 +16,6 @@ import (
"github.com/kujtimiihoxha/termai/internal/logging" "github.com/kujtimiihoxha/termai/internal/logging"
) )
var log = logging.Get()
func Connect() (*sql.DB, error) { func Connect() (*sql.DB, error) {
dataDir := config.Get().Data.Directory dataDir := config.Get().Data.Directory
if dataDir == "" { if dataDir == "" {
@@ -50,43 +48,43 @@ func Connect() (*sql.DB, error) {
for _, pragma := range pragmas { for _, pragma := range pragmas {
if _, err = db.Exec(pragma); err != nil { if _, err = db.Exec(pragma); err != nil {
log.Warn("Failed to set pragma", pragma, err) logging.Warn("Failed to set pragma", pragma, err)
} else { } else {
log.Warn("Set pragma", "pragma", pragma) logging.Warn("Set pragma", "pragma", pragma)
} }
} }
// Initialize schema from embedded file // Initialize schema from embedded file
d, err := iofs.New(FS, "migrations") d, err := iofs.New(FS, "migrations")
if err != nil { if err != nil {
log.Error("Failed to open embedded migrations", "error", err) logging.Error("Failed to open embedded migrations", "error", err)
db.Close() db.Close()
return nil, fmt.Errorf("failed to open embedded migrations: %w", err) return nil, fmt.Errorf("failed to open embedded migrations: %w", err)
} }
driver, err := sqlite3.WithInstance(db, &sqlite3.Config{}) driver, err := sqlite3.WithInstance(db, &sqlite3.Config{})
if err != nil { if err != nil {
log.Error("Failed to create SQLite driver", "error", err) logging.Error("Failed to create SQLite driver", "error", err)
db.Close() db.Close()
return nil, fmt.Errorf("failed to create SQLite driver: %w", err) return nil, fmt.Errorf("failed to create SQLite driver: %w", err)
} }
m, err := migrate.NewWithInstance("iofs", d, "ql", driver) m, err := migrate.NewWithInstance("iofs", d, "ql", driver)
if err != nil { if err != nil {
log.Error("Failed to create migration instance", "error", err) logging.Error("Failed to create migration instance", "error", err)
db.Close() db.Close()
return nil, fmt.Errorf("failed to create migration instance: %w", err) return nil, fmt.Errorf("failed to create migration instance: %w", err)
} }
err = m.Up() err = m.Up()
if err != nil && err != migrate.ErrNoChange { if err != nil && err != migrate.ErrNoChange {
log.Error("Migration failed", "error", err) logging.Error("Migration failed", "error", err)
db.Close() db.Close()
return nil, fmt.Errorf("failed to apply schema: %w", err) return nil, fmt.Errorf("failed to apply schema: %w", err)
} else if err == migrate.ErrNoChange { } else if err == migrate.ErrNoChange {
log.Info("No schema changes to apply") logging.Info("No schema changes to apply")
} else { } else {
log.Info("Schema migration applied successfully") logging.Info("Schema migration applied successfully")
} }
return db, nil return db, nil

View File

@@ -56,7 +56,7 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes
return tools.NewTextErrorResponse(fmt.Sprintf("error creating session: %s", err)), nil return tools.NewTextErrorResponse(fmt.Sprintf("error creating session: %s", err)), nil
} }
err = agent.Generate(session.ID, params.Prompt) err = agent.Generate(ctx, session.ID, params.Prompt)
if err != nil { if err != nil {
return tools.NewTextErrorResponse(fmt.Sprintf("error generating agent: %s", err)), nil return tools.NewTextErrorResponse(fmt.Sprintf("error generating agent: %s", err)), nil
} }

View File

@@ -13,11 +13,12 @@ import (
"github.com/kujtimiihoxha/termai/internal/llm/prompt" "github.com/kujtimiihoxha/termai/internal/llm/prompt"
"github.com/kujtimiihoxha/termai/internal/llm/provider" "github.com/kujtimiihoxha/termai/internal/llm/provider"
"github.com/kujtimiihoxha/termai/internal/llm/tools" "github.com/kujtimiihoxha/termai/internal/llm/tools"
"github.com/kujtimiihoxha/termai/internal/logging"
"github.com/kujtimiihoxha/termai/internal/message" "github.com/kujtimiihoxha/termai/internal/message"
) )
type Agent interface { type Agent interface {
Generate(sessionID string, content string) error Generate(ctx context.Context, sessionID string, content string) error
} }
type agent struct { type agent struct {
@@ -28,9 +29,9 @@ type agent struct {
titleGenerator provider.Provider titleGenerator provider.Provider
} }
func (c *agent) handleTitleGeneration(sessionID, content string) { func (c *agent) handleTitleGeneration(ctx context.Context, sessionID, content string) {
response, err := c.titleGenerator.SendMessages( response, err := c.titleGenerator.SendMessages(
c.Context, ctx,
[]message.Message{ []message.Message{
{ {
Role: message.User, Role: message.User,
@@ -91,13 +92,16 @@ func (c *agent) processEvent(
assistantMsg.AppendContent(event.Content) assistantMsg.AppendContent(event.Content)
return c.Messages.Update(*assistantMsg) return c.Messages.Update(*assistantMsg)
case provider.EventError: case provider.EventError:
c.App.Logger.PersistError(event.Error.Error()) if errors.Is(event.Error, context.Canceled) {
return nil
}
logging.ErrorPersist(event.Error.Error())
return event.Error return event.Error
case provider.EventWarning: case provider.EventWarning:
c.App.Logger.PersistWarn(event.Info) logging.WarnPersist(event.Info)
return nil return nil
case provider.EventInfo: case provider.EventInfo:
c.App.Logger.PersistInfo(event.Info) logging.InfoPersist(event.Info)
case provider.EventComplete: case provider.EventComplete:
assistantMsg.SetToolCalls(event.Response.ToolCalls) assistantMsg.SetToolCalls(event.Response.ToolCalls)
assistantMsg.AddFinish(event.Response.FinishReason) assistantMsg.AddFinish(event.Response.FinishReason)
@@ -115,12 +119,37 @@ func (c *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall,
var wg sync.WaitGroup var wg sync.WaitGroup
toolResults := make([]message.ToolResult, len(toolCalls)) toolResults := make([]message.ToolResult, len(toolCalls))
mutex := &sync.Mutex{} mutex := &sync.Mutex{}
errChan := make(chan error, 1)
// Create a child context that can be canceled
ctx, cancel := context.WithCancel(ctx)
defer cancel()
for i, tc := range toolCalls { for i, tc := range toolCalls {
wg.Add(1) wg.Add(1)
go func(index int, toolCall message.ToolCall) { go func(index int, toolCall message.ToolCall) {
defer wg.Done() defer wg.Done()
// Check if context is already canceled
select {
case <-ctx.Done():
mutex.Lock()
toolResults[index] = message.ToolResult{
ToolCallID: toolCall.ID,
Content: "Tool execution canceled",
IsError: true,
}
mutex.Unlock()
// Send cancellation error to error channel if it's empty
select {
case errChan <- ctx.Err():
default:
}
return
default:
}
response := "" response := ""
isError := false isError := false
found := false found := false
@@ -133,8 +162,19 @@ func (c *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall,
Name: toolCall.Name, Name: toolCall.Name,
Input: toolCall.Input, Input: toolCall.Input,
}) })
if toolErr != nil { if toolErr != nil {
if errors.Is(toolErr, context.Canceled) {
response = "Tool execution canceled"
// Send cancellation error to error channel if it's empty
select {
case errChan <- ctx.Err():
default:
}
} else {
response = fmt.Sprintf("error running tool: %s", toolErr) response = fmt.Sprintf("error running tool: %s", toolErr)
}
isError = true isError = true
} else { } else {
response = toolResult.Content response = toolResult.Content
@@ -160,7 +200,24 @@ func (c *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall,
}(i, tc) }(i, tc)
} }
// Wait for all goroutines to finish or context to be canceled
done := make(chan struct{})
go func() {
wg.Wait() wg.Wait()
close(done)
}()
select {
case <-done:
// All tools completed successfully
case err := <-errChan:
// One of the tools encountered a cancellation
return toolResults, err
case <-ctx.Done():
// Context was canceled externally
return toolResults, ctx.Err()
}
return toolResults, nil return toolResults, nil
} }
@@ -188,14 +245,14 @@ func (c *agent) handleToolExecution(
return &msg, err return &msg, err
} }
func (c *agent) generate(sessionID string, content string) error { func (c *agent) generate(ctx context.Context, sessionID string, content string) error {
messages, err := c.Messages.List(sessionID) messages, err := c.Messages.List(sessionID)
if err != nil { if err != nil {
return err return err
} }
if len(messages) == 0 { if len(messages) == 0 {
go c.handleTitleGeneration(sessionID, content) go c.handleTitleGeneration(ctx, sessionID, content)
} }
userMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{ userMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
@@ -212,11 +269,38 @@ func (c *agent) generate(sessionID string, content string) error {
messages = append(messages, userMsg) messages = append(messages, userMsg)
for { for {
select {
eventChan, err := c.agent.StreamResponse(c.Context, messages, c.tools) case <-ctx.Done():
assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
Role: message.Assistant,
Parts: []message.ContentPart{},
})
if err != nil { if err != nil {
return err return err
} }
assistantMsg.AddFinish("canceled")
c.Messages.Update(assistantMsg)
return context.Canceled
default:
// Continue processing
}
eventChan, err := c.agent.StreamResponse(ctx, messages, c.tools)
if err != nil {
if errors.Is(err, context.Canceled) {
assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
Role: message.Assistant,
Parts: []message.ContentPart{},
})
if err != nil {
return err
}
assistantMsg.AddFinish("canceled")
c.Messages.Update(assistantMsg)
return context.Canceled
}
return err
}
assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{ assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
Role: message.Assistant, Role: message.Assistant,
@@ -228,19 +312,47 @@ func (c *agent) generate(sessionID string, content string) error {
for event := range eventChan { for event := range eventChan {
err = c.processEvent(sessionID, &assistantMsg, event) err = c.processEvent(sessionID, &assistantMsg, event)
if err != nil { if err != nil {
if errors.Is(err, context.Canceled) {
assistantMsg.AddFinish("canceled")
c.Messages.Update(assistantMsg)
return context.Canceled
}
assistantMsg.AddFinish("error:" + err.Error()) assistantMsg.AddFinish("error:" + err.Error())
c.Messages.Update(assistantMsg) c.Messages.Update(assistantMsg)
return err return err
} }
select {
case <-ctx.Done():
assistantMsg.AddFinish("canceled")
c.Messages.Update(assistantMsg)
return context.Canceled
default:
}
} }
msg, err := c.handleToolExecution(c.Context, assistantMsg) // Check for context cancellation before tool execution
select {
case <-ctx.Done():
assistantMsg.AddFinish("canceled")
c.Messages.Update(assistantMsg) c.Messages.Update(assistantMsg)
return context.Canceled
default:
// Continue processing
}
msg, err := c.handleToolExecution(ctx, assistantMsg)
if err != nil { if err != nil {
if errors.Is(err, context.Canceled) {
assistantMsg.AddFinish("canceled")
c.Messages.Update(assistantMsg)
return context.Canceled
}
return err return err
} }
c.Messages.Update(assistantMsg)
if len(assistantMsg.ToolCalls()) == 0 { if len(assistantMsg.ToolCalls()) == 0 {
break break
} }
@@ -249,6 +361,16 @@ func (c *agent) generate(sessionID string, content string) error {
if msg != nil { if msg != nil {
messages = append(messages, *msg) messages = append(messages, *msg)
} }
// Check for context cancellation after tool execution
select {
case <-ctx.Done():
assistantMsg.AddFinish("canceled")
c.Messages.Update(assistantMsg)
return context.Canceled
default:
// Continue processing
}
} }
return nil return nil
} }

View File

@@ -1,6 +1,7 @@
package agent package agent
import ( import (
"context"
"errors" "errors"
"github.com/kujtimiihoxha/termai/internal/app" "github.com/kujtimiihoxha/termai/internal/app"
@@ -28,9 +29,9 @@ func (c *coderAgent) setAgentTool(sessionID string) {
} }
} }
func (c *coderAgent) Generate(sessionID string, content string) error { func (c *coderAgent) Generate(ctx context.Context, sessionID string, content string) error {
c.setAgentTool(sessionID) c.setAgentTool(sessionID)
return c.generate(sessionID, content) return c.generate(ctx, sessionID, content)
} }
func NewCoderAgent(app *app.App) (Agent, error) { func NewCoderAgent(app *app.App) (Agent, error) {

View File

@@ -22,8 +22,6 @@ type mcpTool struct {
permissions permission.Service permissions permission.Service
} }
var logger = logging.Get()
type MCPClient interface { type MCPClient interface {
Initialize( Initialize(
ctx context.Context, ctx context.Context,
@@ -143,13 +141,13 @@ func getTools(ctx context.Context, name string, m config.MCPServer, permissions
_, err := c.Initialize(ctx, initRequest) _, err := c.Initialize(ctx, initRequest)
if err != nil { if err != nil {
logger.Error("error initializing mcp client", "error", err) logging.Error("error initializing mcp client", "error", err)
return stdioTools return stdioTools
} }
toolsRequest := mcp.ListToolsRequest{} toolsRequest := mcp.ListToolsRequest{}
tools, err := c.ListTools(ctx, toolsRequest) tools, err := c.ListTools(ctx, toolsRequest)
if err != nil { if err != nil {
logger.Error("error listing tools", "error", err) logging.Error("error listing tools", "error", err)
return stdioTools return stdioTools
} }
for _, t := range tools.Tools { for _, t := range tools.Tools {
@@ -172,7 +170,7 @@ func GetMcpTools(ctx context.Context, permissions permission.Service) []tools.Ba
m.Args..., m.Args...,
) )
if err != nil { if err != nil {
logger.Error("error creating mcp client", "error", err) logging.Error("error creating mcp client", "error", err)
continue continue
} }
@@ -183,7 +181,7 @@ func GetMcpTools(ctx context.Context, permissions permission.Service) []tools.Ba
client.WithHeaders(m.Headers), client.WithHeaders(m.Headers),
) )
if err != nil { if err != nil {
logger.Error("error creating mcp client", "error", err) logging.Error("error creating mcp client", "error", err)
continue continue
} }
mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...) mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...)

View File

@@ -1,6 +1,7 @@
package agent package agent
import ( import (
"context"
"errors" "errors"
"github.com/kujtimiihoxha/termai/internal/app" "github.com/kujtimiihoxha/termai/internal/app"
@@ -13,8 +14,8 @@ type taskAgent struct {
*agent *agent
} }
func (c *taskAgent) Generate(sessionID string, content string) error { func (c *taskAgent) Generate(ctx context.Context, sessionID string, content string) error {
return c.generate(sessionID, content) return c.generate(ctx, sessionID, content)
} }
func NewTaskAgent(app *app.App) (Agent, error) { func NewTaskAgent(app *app.App) (Agent, error) {

View File

@@ -1,12 +0,0 @@
package logging
var defaultLogger Interface
func Get() Interface {
if defaultLogger == nil {
defaultLogger = NewLogger(Options{
Level: "info",
})
}
return defaultLogger
}

View File

@@ -1,141 +1,39 @@
package logging package logging
import ( import "log/slog"
"context"
"io"
"log/slog"
"slices"
"github.com/kujtimiihoxha/termai/internal/pubsub" func Info(msg string, args ...any) {
"golang.org/x/exp/maps" slog.Info(msg, args...)
)
const DefaultLevel = "info"
const (
persistKeyArg = "$persist"
PersistTimeArg = "$persist_time"
)
var levels = map[string]slog.Level{
"debug": slog.LevelDebug,
DefaultLevel: slog.LevelInfo,
"warn": slog.LevelWarn,
"error": slog.LevelError,
} }
func ValidLevels() []string { func Debug(msg string, args ...any) {
keys := maps.Keys(levels) slog.Debug(msg, args...)
slices.SortFunc(keys, func(a, b string) int {
if a == DefaultLevel {
return -1
}
if b == DefaultLevel {
return 1
}
if a < b {
return -1
}
return 1
})
return keys
} }
func NewLogger(opts Options) Interface { func Warn(msg string, args ...any) {
logger := &Logger{} slog.Warn(msg, args...)
broker := pubsub.NewBroker[LogMessage]()
writer := &writer{
messages: []LogMessage{},
Broker: broker,
}
handler := slog.NewTextHandler(
io.MultiWriter(writer),
&slog.HandlerOptions{
Level: slog.Level(levels[opts.Level]),
},
)
logger.logger = slog.New(handler)
logger.writer = writer
return logger
} }
type Options struct { func Error(msg string, args ...any) {
Level string slog.Error(msg, args...)
} }
type Logger struct { func InfoPersist(msg string, args ...any) {
logger *slog.Logger
writer *writer
}
func (l *Logger) SetLevel(level string) {
if _, ok := levels[level]; !ok {
level = DefaultLevel
}
handler := slog.NewTextHandler(
io.MultiWriter(l.writer),
&slog.HandlerOptions{
Level: levels[level],
},
)
l.logger = slog.New(handler)
}
// PersistDebug implements Interface.
func (l *Logger) PersistDebug(msg string, args ...any) {
args = append(args, persistKeyArg, true) args = append(args, persistKeyArg, true)
l.Debug(msg, args...) slog.Info(msg, args...)
} }
// PersistError implements Interface. func DebugPersist(msg string, args ...any) {
func (l *Logger) PersistError(msg string, args ...any) {
args = append(args, persistKeyArg, true) args = append(args, persistKeyArg, true)
l.Error(msg, args...) slog.Debug(msg, args...)
} }
// PersistInfo implements Interface. func WarnPersist(msg string, args ...any) {
func (l *Logger) PersistInfo(msg string, args ...any) {
args = append(args, persistKeyArg, true) args = append(args, persistKeyArg, true)
l.Info(msg, args...) slog.Warn(msg, args...)
} }
// PersistWarn implements Interface. func ErrorPersist(msg string, args ...any) {
func (l *Logger) PersistWarn(msg string, args ...any) {
args = append(args, persistKeyArg, true) args = append(args, persistKeyArg, true)
l.Warn(msg, args...) slog.Error(msg, args...)
}
func (l *Logger) Debug(msg string, args ...any) {
l.logger.Debug(msg, args...)
}
func (l *Logger) Info(msg string, args ...any) {
l.logger.Info(msg, args...)
}
func (l *Logger) Warn(msg string, args ...any) {
l.logger.Warn(msg, args...)
}
func (l *Logger) Error(msg string, args ...any) {
l.logger.Error(msg, args...)
}
func (l *Logger) List() []LogMessage {
return l.writer.messages
}
func (l *Logger) Get(id string) (LogMessage, error) {
for _, msg := range l.writer.messages {
if msg.ID == id {
return msg, nil
}
}
return LogMessage{}, io.EOF
}
func (l *Logger) Subscribe(ctx context.Context) <-chan pubsub.Event[LogMessage] {
return l.writer.Subscribe(ctx)
} }

View File

@@ -1,23 +0,0 @@
package logging
import (
"context"
"github.com/kujtimiihoxha/termai/internal/pubsub"
)
type Interface interface {
Debug(msg string, args ...any)
Info(msg string, args ...any)
Warn(msg string, args ...any)
Error(msg string, args ...any)
Subscribe(ctx context.Context) <-chan pubsub.Event[LogMessage]
PersistDebug(msg string, args ...any)
PersistInfo(msg string, args ...any)
PersistWarn(msg string, args ...any)
PersistError(msg string, args ...any)
List() []LogMessage
SetLevel(level string)
}

View File

@@ -2,18 +2,47 @@ package logging
import ( import (
"bytes" "bytes"
"context"
"fmt" "fmt"
"strings"
"sync"
"time" "time"
"github.com/go-logfmt/logfmt" "github.com/go-logfmt/logfmt"
"github.com/kujtimiihoxha/termai/internal/pubsub" "github.com/kujtimiihoxha/termai/internal/pubsub"
) )
type writer struct { const (
persistKeyArg = "$_persist"
PersistTimeArg = "$_persist_time"
)
type LogData struct {
messages []LogMessage messages []LogMessage
*pubsub.Broker[LogMessage] *pubsub.Broker[LogMessage]
lock sync.Mutex
} }
func (l *LogData) Add(msg LogMessage) {
l.lock.Lock()
defer l.lock.Unlock()
l.messages = append(l.messages, msg)
l.Publish(pubsub.CreatedEvent, msg)
}
func (l *LogData) List() []LogMessage {
l.lock.Lock()
defer l.lock.Unlock()
return l.messages
}
var defaultLogData = &LogData{
messages: make([]LogMessage, 0),
Broker: pubsub.NewBroker[LogMessage](),
}
type writer struct{}
func (w *writer) Write(p []byte) (int, error) { func (w *writer) Write(p []byte) (int, error) {
d := logfmt.NewDecoder(bytes.NewReader(p)) d := logfmt.NewDecoder(bytes.NewReader(p))
for d.ScanRecord() { for d.ScanRecord() {
@@ -30,7 +59,7 @@ func (w *writer) Write(p []byte) (int, error) {
} }
msg.Time = parsed msg.Time = parsed
case "level": case "level":
msg.Level = string(d.Value()) msg.Level = strings.ToLower(string(d.Value()))
case "msg": case "msg":
msg.Message = string(d.Value()) msg.Message = string(d.Value())
default: default:
@@ -50,11 +79,23 @@ func (w *writer) Write(p []byte) (int, error) {
} }
} }
} }
w.messages = append(w.messages, msg) defaultLogData.Add(msg)
w.Publish(pubsub.CreatedEvent, msg)
} }
if d.Err() != nil { if d.Err() != nil {
return 0, d.Err() return 0, d.Err()
} }
return len(p), nil return len(p), nil
} }
func NewWriter() *writer {
w := &writer{}
return w
}
func Subscribe(ctx context.Context) <-chan pubsub.Event[LogMessage] {
return defaultLogData.Subscribe(ctx)
}
func List() []LogMessage {
return defaultLogData.List()
}

View File

@@ -18,8 +18,6 @@ import (
"github.com/kujtimiihoxha/termai/internal/lsp/protocol" "github.com/kujtimiihoxha/termai/internal/lsp/protocol"
) )
var logger = logging.Get()
type Client struct { type Client struct {
Cmd *exec.Cmd Cmd *exec.Cmd
stdin io.WriteCloser stdin io.WriteCloser
@@ -377,7 +375,7 @@ func (c *Client) CloseFile(ctx context.Context, filepath string) error {
} }
if cnf.Debug { if cnf.Debug {
logger.Debug("Closing file", "file", filepath) logging.Debug("Closing file", "file", filepath)
} }
if err := c.Notify(ctx, "textDocument/didClose", params); err != nil { if err := c.Notify(ctx, "textDocument/didClose", params); err != nil {
return err return err
@@ -416,12 +414,12 @@ func (c *Client) CloseAllFiles(ctx context.Context) {
for _, filePath := range filesToClose { for _, filePath := range filesToClose {
err := c.CloseFile(ctx, filePath) err := c.CloseFile(ctx, filePath)
if err != nil && cnf.Debug { if err != nil && cnf.Debug {
logger.Warn("Error closing file", "file", filePath, "error", err) logging.Warn("Error closing file", "file", filePath, "error", err)
} }
} }
if cnf.Debug { if cnf.Debug {
logger.Debug("Closed all files", "files", filesToClose) logging.Debug("Closed all files", "files", filesToClose)
} }
} }

View File

@@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/config"
"github.com/kujtimiihoxha/termai/internal/logging"
"github.com/kujtimiihoxha/termai/internal/lsp/protocol" "github.com/kujtimiihoxha/termai/internal/lsp/protocol"
"github.com/kujtimiihoxha/termai/internal/lsp/util" "github.com/kujtimiihoxha/termai/internal/lsp/util"
) )
@@ -17,7 +18,7 @@ func HandleWorkspaceConfiguration(params json.RawMessage) (any, error) {
func HandleRegisterCapability(params json.RawMessage) (any, error) { func HandleRegisterCapability(params json.RawMessage) (any, error) {
var registerParams protocol.RegistrationParams var registerParams protocol.RegistrationParams
if err := json.Unmarshal(params, &registerParams); err != nil { if err := json.Unmarshal(params, &registerParams); err != nil {
logger.Error("Error unmarshaling registration params", "error", err) logging.Error("Error unmarshaling registration params", "error", err)
return nil, err return nil, err
} }
@@ -27,13 +28,13 @@ func HandleRegisterCapability(params json.RawMessage) (any, error) {
// Parse the registration options // Parse the registration options
optionsJSON, err := json.Marshal(reg.RegisterOptions) optionsJSON, err := json.Marshal(reg.RegisterOptions)
if err != nil { if err != nil {
logger.Error("Error marshaling registration options", "error", err) logging.Error("Error marshaling registration options", "error", err)
continue continue
} }
var options protocol.DidChangeWatchedFilesRegistrationOptions var options protocol.DidChangeWatchedFilesRegistrationOptions
if err := json.Unmarshal(optionsJSON, &options); err != nil { if err := json.Unmarshal(optionsJSON, &options); err != nil {
logger.Error("Error unmarshaling registration options", "error", err) logging.Error("Error unmarshaling registration options", "error", err)
continue continue
} }
@@ -53,7 +54,7 @@ func HandleApplyEdit(params json.RawMessage) (any, error) {
err := util.ApplyWorkspaceEdit(edit.Edit) err := util.ApplyWorkspaceEdit(edit.Edit)
if err != nil { if err != nil {
logger.Error("Error applying workspace edit", "error", err) logging.Error("Error applying workspace edit", "error", err)
return protocol.ApplyWorkspaceEditResult{Applied: false, FailureReason: err.Error()}, nil return protocol.ApplyWorkspaceEditResult{Applied: false, FailureReason: err.Error()}, nil
} }
@@ -88,7 +89,7 @@ func HandleServerMessage(params json.RawMessage) {
} }
if err := json.Unmarshal(params, &msg); err == nil { if err := json.Unmarshal(params, &msg); err == nil {
if cnf.Debug { if cnf.Debug {
logger.Debug("Server message", "type", msg.Type, "message", msg.Message) logging.Debug("Server message", "type", msg.Type, "message", msg.Message)
} }
} }
} }
@@ -96,7 +97,7 @@ func HandleServerMessage(params json.RawMessage) {
func HandleDiagnostics(client *Client, params json.RawMessage) { func HandleDiagnostics(client *Client, params json.RawMessage) {
var diagParams protocol.PublishDiagnosticsParams var diagParams protocol.PublishDiagnosticsParams
if err := json.Unmarshal(params, &diagParams); err != nil { if err := json.Unmarshal(params, &diagParams); err != nil {
logger.Error("Error unmarshaling diagnostics params", "error", err) logging.Error("Error unmarshaling diagnostics params", "error", err)
return return
} }

View File

@@ -9,6 +9,7 @@ import (
"strings" "strings"
"github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/config"
"github.com/kujtimiihoxha/termai/internal/logging"
) )
// Write writes an LSP message to the given writer // Write writes an LSP message to the given writer
@@ -20,7 +21,7 @@ func WriteMessage(w io.Writer, msg *Message) error {
cnf := config.Get() cnf := config.Get()
if cnf.Debug { if cnf.Debug {
logger.Debug("Sending message to server", "method", msg.Method, "id", msg.ID) logging.Debug("Sending message to server", "method", msg.Method, "id", msg.ID)
} }
_, err = fmt.Fprintf(w, "Content-Length: %d\r\n\r\n", len(data)) _, err = fmt.Fprintf(w, "Content-Length: %d\r\n\r\n", len(data))
@@ -49,7 +50,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) {
line = strings.TrimSpace(line) line = strings.TrimSpace(line)
if cnf.Debug { if cnf.Debug {
logger.Debug("Received header", "line", line) logging.Debug("Received header", "line", line)
} }
if line == "" { if line == "" {
@@ -65,7 +66,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) {
} }
if cnf.Debug { if cnf.Debug {
logger.Debug("Content-Length", "length", contentLength) logging.Debug("Content-Length", "length", contentLength)
} }
// Read content // Read content
@@ -76,7 +77,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) {
} }
if cnf.Debug { if cnf.Debug {
logger.Debug("Received content", "content", string(content)) logging.Debug("Received content", "content", string(content))
} }
// Parse message // Parse message
@@ -95,7 +96,7 @@ func (c *Client) handleMessages() {
msg, err := ReadMessage(c.stdout) msg, err := ReadMessage(c.stdout)
if err != nil { if err != nil {
if cnf.Debug { if cnf.Debug {
logger.Error("Error reading message", "error", err) logging.Error("Error reading message", "error", err)
} }
return return
} }
@@ -103,7 +104,7 @@ func (c *Client) handleMessages() {
// Handle server->client request (has both Method and ID) // Handle server->client request (has both Method and ID)
if msg.Method != "" && msg.ID != 0 { if msg.Method != "" && msg.ID != 0 {
if cnf.Debug { if cnf.Debug {
logger.Debug("Received request from server", "method", msg.Method, "id", msg.ID) logging.Debug("Received request from server", "method", msg.Method, "id", msg.ID)
} }
response := &Message{ response := &Message{
@@ -143,7 +144,7 @@ func (c *Client) handleMessages() {
// Send response back to server // Send response back to server
if err := WriteMessage(c.stdin, response); err != nil { if err := WriteMessage(c.stdin, response); err != nil {
logger.Error("Error sending response to server", "error", err) logging.Error("Error sending response to server", "error", err)
} }
continue continue
@@ -157,11 +158,11 @@ func (c *Client) handleMessages() {
if ok { if ok {
if cnf.Debug { if cnf.Debug {
logger.Debug("Handling notification", "method", msg.Method) logging.Debug("Handling notification", "method", msg.Method)
} }
go handler(msg.Params) go handler(msg.Params)
} else if cnf.Debug { } else if cnf.Debug {
logger.Debug("No handler for notification", "method", msg.Method) logging.Debug("No handler for notification", "method", msg.Method)
} }
continue continue
} }
@@ -174,12 +175,12 @@ func (c *Client) handleMessages() {
if ok { if ok {
if cnf.Debug { if cnf.Debug {
logger.Debug("Received response for request", "id", msg.ID) logging.Debug("Received response for request", "id", msg.ID)
} }
ch <- msg ch <- msg
close(ch) close(ch)
} else if cnf.Debug { } else if cnf.Debug {
logger.Debug("No handler for response", "id", msg.ID) logging.Debug("No handler for response", "id", msg.ID)
} }
} }
} }
@@ -191,7 +192,7 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any
id := c.nextID.Add(1) id := c.nextID.Add(1)
if cnf.Debug { if cnf.Debug {
logger.Debug("Making call", "method", method, "id", id) logging.Debug("Making call", "method", method, "id", id)
} }
msg, err := NewRequest(id, method, params) msg, err := NewRequest(id, method, params)
@@ -217,14 +218,14 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any
} }
if cnf.Debug { if cnf.Debug {
logger.Debug("Request sent", "method", method, "id", id) logging.Debug("Request sent", "method", method, "id", id)
} }
// Wait for response // Wait for response
resp := <-ch resp := <-ch
if cnf.Debug { if cnf.Debug {
logger.Debug("Received response", "id", id) logging.Debug("Received response", "id", id)
} }
if resp.Error != nil { if resp.Error != nil {
@@ -250,7 +251,7 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any
func (c *Client) Notify(ctx context.Context, method string, params any) error { func (c *Client) Notify(ctx context.Context, method string, params any) error {
cnf := config.Get() cnf := config.Get()
if cnf.Debug { if cnf.Debug {
logger.Debug("Sending notification", "method", method) logging.Debug("Sending notification", "method", method)
} }
msg, err := NewNotification(method, params) msg, err := NewNotification(method, params)

View File

@@ -16,8 +16,6 @@ import (
"github.com/kujtimiihoxha/termai/internal/lsp/protocol" "github.com/kujtimiihoxha/termai/internal/lsp/protocol"
) )
var logger = logging.Get()
// WorkspaceWatcher manages LSP file watching // WorkspaceWatcher manages LSP file watching
type WorkspaceWatcher struct { type WorkspaceWatcher struct {
client *lsp.Client client *lsp.Client
@@ -53,7 +51,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
// Print detailed registration information for debugging // Print detailed registration information for debugging
if cnf.Debug { if cnf.Debug {
logger.Debug("Adding file watcher registrations", logging.Debug("Adding file watcher registrations",
"id", id, "id", id,
"watchers", len(watchers), "watchers", len(watchers),
"total", len(w.registrations), "total", len(w.registrations),
@@ -61,26 +59,26 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
) )
for i, watcher := range watchers { for i, watcher := range watchers {
logger.Debug("Registration", "index", i+1) logging.Debug("Registration", "index", i+1)
// Log the GlobPattern // Log the GlobPattern
switch v := watcher.GlobPattern.Value.(type) { switch v := watcher.GlobPattern.Value.(type) {
case string: case string:
logger.Debug("GlobPattern", "pattern", v) logging.Debug("GlobPattern", "pattern", v)
case protocol.RelativePattern: case protocol.RelativePattern:
logger.Debug("GlobPattern", "pattern", v.Pattern) logging.Debug("GlobPattern", "pattern", v.Pattern)
// Log BaseURI details // Log BaseURI details
switch u := v.BaseURI.Value.(type) { switch u := v.BaseURI.Value.(type) {
case string: case string:
logger.Debug("BaseURI", "baseURI", u) logging.Debug("BaseURI", "baseURI", u)
case protocol.DocumentUri: case protocol.DocumentUri:
logger.Debug("BaseURI", "baseURI", u) logging.Debug("BaseURI", "baseURI", u)
default: default:
logger.Debug("BaseURI", "baseURI", u) logging.Debug("BaseURI", "baseURI", u)
} }
default: default:
logger.Debug("GlobPattern", "unknown type", fmt.Sprintf("%T", v)) logging.Debug("GlobPattern", "unknown type", fmt.Sprintf("%T", v))
} }
// Log WatchKind // Log WatchKind
@@ -89,7 +87,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
watchKind = *watcher.Kind watchKind = *watcher.Kind
} }
logger.Debug("WatchKind", "kind", watchKind) logging.Debug("WatchKind", "kind", watchKind)
// Test match against some example paths // Test match against some example paths
testPaths := []string{ testPaths := []string{
@@ -99,7 +97,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
for _, testPath := range testPaths { for _, testPath := range testPaths {
isMatch := w.matchesPattern(testPath, watcher.GlobPattern) isMatch := w.matchesPattern(testPath, watcher.GlobPattern)
logger.Debug("Test path", "path", testPath, "matches", isMatch) logging.Debug("Test path", "path", testPath, "matches", isMatch)
} }
} }
} }
@@ -119,7 +117,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
if d.IsDir() { if d.IsDir() {
if path != w.workspacePath && shouldExcludeDir(path) { if path != w.workspacePath && shouldExcludeDir(path) {
if cnf.Debug { if cnf.Debug {
logger.Debug("Skipping excluded directory", "path", path) logging.Debug("Skipping excluded directory", "path", path)
} }
return filepath.SkipDir return filepath.SkipDir
} }
@@ -139,7 +137,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
elapsedTime := time.Since(startTime) elapsedTime := time.Since(startTime)
if cnf.Debug { if cnf.Debug {
logger.Debug("Workspace scan complete", logging.Debug("Workspace scan complete",
"filesOpened", filesOpened, "filesOpened", filesOpened,
"elapsedTime", elapsedTime.Seconds(), "elapsedTime", elapsedTime.Seconds(),
"workspacePath", w.workspacePath, "workspacePath", w.workspacePath,
@@ -147,7 +145,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
} }
if err != nil && cnf.Debug { if err != nil && cnf.Debug {
logger.Debug("Error scanning workspace for files to open", "error", err) logging.Debug("Error scanning workspace for files to open", "error", err)
} }
}() }()
} }
@@ -164,7 +162,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str
watcher, err := fsnotify.NewWatcher() watcher, err := fsnotify.NewWatcher()
if err != nil { if err != nil {
logger.Error("Error creating watcher", "error", err) logging.Error("Error creating watcher", "error", err)
} }
defer watcher.Close() defer watcher.Close()
@@ -178,7 +176,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str
if d.IsDir() && path != workspacePath { if d.IsDir() && path != workspacePath {
if shouldExcludeDir(path) { if shouldExcludeDir(path) {
if cnf.Debug { if cnf.Debug {
logger.Debug("Skipping excluded directory", "path", path) logging.Debug("Skipping excluded directory", "path", path)
} }
return filepath.SkipDir return filepath.SkipDir
} }
@@ -188,14 +186,14 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str
if d.IsDir() { if d.IsDir() {
err = watcher.Add(path) err = watcher.Add(path)
if err != nil { if err != nil {
logger.Error("Error watching path", "path", path, "error", err) logging.Error("Error watching path", "path", path, "error", err)
} }
} }
return nil return nil
}) })
if err != nil { if err != nil {
logger.Error("Error walking workspace", "error", err) logging.Error("Error walking workspace", "error", err)
} }
// Event loop // Event loop
@@ -217,7 +215,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str
// Skip excluded directories // Skip excluded directories
if !shouldExcludeDir(event.Name) { if !shouldExcludeDir(event.Name) {
if err := watcher.Add(event.Name); err != nil { if err := watcher.Add(event.Name); err != nil {
logger.Error("Error adding directory to watcher", "path", event.Name, "error", err) logging.Error("Error adding directory to watcher", "path", event.Name, "error", err)
} }
} }
} else { } else {
@@ -232,7 +230,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str
// Debug logging // Debug logging
if cnf.Debug { if cnf.Debug {
matched, kind := w.isPathWatched(event.Name) matched, kind := w.isPathWatched(event.Name)
logger.Debug("File event", logging.Debug("File event",
"path", event.Name, "path", event.Name,
"operation", event.Op.String(), "operation", event.Op.String(),
"watched", matched, "watched", matched,
@@ -277,7 +275,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str
if !ok { if !ok {
return return
} }
logger.Error("Error watching file", "error", err) logging.Error("Error watching file", "error", err)
} }
} }
} }
@@ -402,7 +400,7 @@ func matchesSimpleGlob(pattern, path string) bool {
// Fall back to simple matching for simpler patterns // Fall back to simple matching for simpler patterns
matched, err := filepath.Match(pattern, path) matched, err := filepath.Match(pattern, path)
if err != nil { if err != nil {
logger.Error("Error matching pattern", "pattern", pattern, "path", path, "error", err) logging.Error("Error matching pattern", "pattern", pattern, "path", path, "error", err)
return false return false
} }
@@ -413,7 +411,7 @@ func matchesSimpleGlob(pattern, path string) bool {
func (w *WorkspaceWatcher) matchesPattern(path string, pattern protocol.GlobPattern) bool { func (w *WorkspaceWatcher) matchesPattern(path string, pattern protocol.GlobPattern) bool {
patternInfo, err := pattern.AsPattern() patternInfo, err := pattern.AsPattern()
if err != nil { if err != nil {
logger.Error("Error parsing pattern", "pattern", pattern, "error", err) logging.Error("Error parsing pattern", "pattern", pattern, "error", err)
return false return false
} }
@@ -438,7 +436,7 @@ func (w *WorkspaceWatcher) matchesPattern(path string, pattern protocol.GlobPatt
// Make path relative to basePath for matching // Make path relative to basePath for matching
relPath, err := filepath.Rel(basePath, path) relPath, err := filepath.Rel(basePath, path)
if err != nil { if err != nil {
logger.Error("Error getting relative path", "path", path, "basePath", basePath, "error", err) logging.Error("Error getting relative path", "path", path, "basePath", basePath, "error", err)
return false return false
} }
relPath = filepath.ToSlash(relPath) relPath = filepath.ToSlash(relPath)
@@ -479,14 +477,14 @@ func (w *WorkspaceWatcher) handleFileEvent(ctx context.Context, uri string, chan
if changeType == protocol.FileChangeType(protocol.Changed) && w.client.IsFileOpen(filePath) { if changeType == protocol.FileChangeType(protocol.Changed) && w.client.IsFileOpen(filePath) {
err := w.client.NotifyChange(ctx, filePath) err := w.client.NotifyChange(ctx, filePath)
if err != nil { if err != nil {
logger.Error("Error notifying change", "error", err) logging.Error("Error notifying change", "error", err)
} }
return return
} }
// Notify LSP server about the file event using didChangeWatchedFiles // Notify LSP server about the file event using didChangeWatchedFiles
if err := w.notifyFileEvent(ctx, uri, changeType); err != nil { if err := w.notifyFileEvent(ctx, uri, changeType); err != nil {
logger.Error("Error notifying LSP server about file event", "error", err) logging.Error("Error notifying LSP server about file event", "error", err)
} }
} }
@@ -494,7 +492,7 @@ func (w *WorkspaceWatcher) handleFileEvent(ctx context.Context, uri string, chan
func (w *WorkspaceWatcher) notifyFileEvent(ctx context.Context, uri string, changeType protocol.FileChangeType) error { func (w *WorkspaceWatcher) notifyFileEvent(ctx context.Context, uri string, changeType protocol.FileChangeType) error {
cnf := config.Get() cnf := config.Get()
if cnf.Debug { if cnf.Debug {
logger.Debug("Notifying file event", logging.Debug("Notifying file event",
"uri", uri, "uri", uri,
"changeType", changeType, "changeType", changeType,
) )
@@ -618,7 +616,7 @@ func shouldExcludeFile(filePath string) bool {
// Skip large files // Skip large files
if info.Size() > maxFileSize { if info.Size() > maxFileSize {
if cnf.Debug { if cnf.Debug {
logger.Debug("Skipping large file", logging.Debug("Skipping large file",
"path", filePath, "path", filePath,
"size", info.Size(), "size", info.Size(),
"maxSize", maxFileSize, "maxSize", maxFileSize,
@@ -651,7 +649,7 @@ func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) {
if watched, _ := w.isPathWatched(path); watched { if watched, _ := w.isPathWatched(path); watched {
// Don't need to check if it's already open - the client.OpenFile handles that // Don't need to check if it's already open - the client.OpenFile handles that
if err := w.client.OpenFile(ctx, path); err != nil && cnf.Debug { if err := w.client.OpenFile(ctx, path); err != nil && cnf.Debug {
logger.Error("Error opening file", "path", path, "error", err) logging.Error("Error opening file", "path", path, "error", err)
} }
} }
} }

View File

@@ -13,7 +13,7 @@ import (
) )
type statusCmp struct { type statusCmp struct {
info *util.InfoMsg info util.InfoMsg
width int width int
messageTTL time.Duration messageTTL time.Duration
} }
@@ -35,14 +35,14 @@ func (m statusCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.width = msg.Width m.width = msg.Width
return m, nil return m, nil
case util.InfoMsg: case util.InfoMsg:
m.info = &msg m.info = msg
ttl := msg.TTL ttl := msg.TTL
if ttl == 0 { if ttl == 0 {
ttl = m.messageTTL ttl = m.messageTTL
} }
return m, m.clearMessageCmd(ttl) return m, m.clearMessageCmd(ttl)
case util.ClearStatusMsg: case util.ClearStatusMsg:
m.info = nil m.info = util.InfoMsg{}
} }
return m, nil return m, nil
} }
@@ -54,7 +54,7 @@ var (
func (m statusCmp) View() string { func (m statusCmp) View() string {
status := styles.Padded.Background(styles.Grey).Foreground(styles.Text).Render("? help") status := styles.Padded.Background(styles.Grey).Foreground(styles.Text).Render("? help")
if m.info != nil { if m.info.Msg != "" {
infoStyle := styles.Padded. infoStyle := styles.Padded.
Foreground(styles.Base). Foreground(styles.Base).
Width(m.availableFooterMsgWidth()) Width(m.availableFooterMsgWidth())

View File

@@ -30,7 +30,7 @@ type detailCmp struct {
} }
func (i *detailCmp) Init() tea.Cmd { func (i *detailCmp) Init() tea.Cmd {
messages := logging.Get().List() messages := logging.List()
if len(messages) == 0 { if len(messages) == 0 {
return nil return nil
} }

View File

@@ -22,8 +22,6 @@ type TableComponent interface {
layout.Bordered layout.Bordered
} }
var logger = logging.Get()
type tableCmp struct { type tableCmp struct {
table table.Model table table.Model
} }
@@ -57,7 +55,7 @@ func (i *tableCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if selectedRow != nil { if selectedRow != nil {
if prevSelectedRow == nil || selectedRow[0] == prevSelectedRow[0] { if prevSelectedRow == nil || selectedRow[0] == prevSelectedRow[0] {
var log logging.LogMessage var log logging.LogMessage
for _, row := range logging.Get().List() { for _, row := range logging.List() {
if row.ID == selectedRow[0] { if row.ID == selectedRow[0] {
log = row log = row
break break
@@ -112,7 +110,7 @@ func (i *tableCmp) BindingKeys() []key.Binding {
func (i *tableCmp) setRows() { func (i *tableCmp) setRows() {
rows := []table.Row{} rows := []table.Row{}
logs := logger.List() logs := logging.List()
slices.SortFunc(logs, func(a, b logging.LogMessage) int { slices.SortFunc(logs, func(a, b logging.LogMessage) int {
if a.Time.Before(b.Time) { if a.Time.Before(b.Time) {
return 1 return 1

View File

@@ -12,6 +12,7 @@ import (
"github.com/kujtimiihoxha/termai/internal/tui/styles" "github.com/kujtimiihoxha/termai/internal/tui/styles"
"github.com/kujtimiihoxha/termai/internal/tui/util" "github.com/kujtimiihoxha/termai/internal/tui/util"
"github.com/kujtimiihoxha/vimtea" "github.com/kujtimiihoxha/vimtea"
"golang.org/x/net/context"
) )
type EditorCmp interface { type EditorCmp interface {
@@ -30,11 +31,13 @@ type editorCmp struct {
focused bool focused bool
width int width int
height int height int
cancelMessage context.CancelFunc
} }
type editorKeyMap struct { type editorKeyMap struct {
SendMessage key.Binding SendMessage key.Binding
SendMessageI key.Binding SendMessageI key.Binding
CancelMessage key.Binding
InsertMode key.Binding InsertMode key.Binding
NormaMode key.Binding NormaMode key.Binding
VisualMode key.Binding VisualMode key.Binding
@@ -50,6 +53,10 @@ var editorKeyMapValue = editorKeyMap{
key.WithKeys("ctrl+s"), key.WithKeys("ctrl+s"),
key.WithHelp("ctrl+s", "send message insert mode"), key.WithHelp("ctrl+s", "send message insert mode"),
), ),
CancelMessage: key.NewBinding(
key.WithKeys("ctrl+x"),
key.WithHelp("ctrl+x", "cancel current message"),
),
InsertMode: key.NewBinding( InsertMode: key.NewBinding(
key.WithKeys("i"), key.WithKeys("i"),
key.WithHelp("i", "insert mode"), key.WithHelp("i", "insert mode"),
@@ -93,6 +100,8 @@ func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if m.editorMode == vimtea.ModeInsert { if m.editorMode == vimtea.ModeInsert {
return m, m.Send() return m, m.Send()
} }
case key.Matches(msg, editorKeyMapValue.CancelMessage):
return m, m.Cancel()
} }
} }
u, cmd := m.editor.Update(msg) u, cmd := m.editor.Update(msg)
@@ -136,6 +145,16 @@ func (m *editorCmp) SetSize(width int, height int) {
m.editor.SetSize(width, height) m.editor.SetSize(width, height)
} }
func (m *editorCmp) Cancel() tea.Cmd {
if m.cancelMessage == nil {
return util.ReportWarn("No message to cancel")
}
m.cancelMessage()
m.cancelMessage = nil
return util.ReportWarn("Message cancelled")
}
func (m *editorCmp) Send() tea.Cmd { func (m *editorCmp) Send() tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
messages, err := m.app.Messages.List(m.sessionID) messages, err := m.app.Messages.List(m.sessionID)
@@ -151,7 +170,13 @@ func (m *editorCmp) Send() tea.Cmd {
} }
content := strings.Join(m.editor.GetBuffer().Lines(), "\n") content := strings.Join(m.editor.GetBuffer().Lines(), "\n")
go a.Generate(m.sessionID, content) ctx, cancel := context.WithCancel(m.app.Context)
m.cancelMessage = cancel
go func() {
defer cancel()
a.Generate(ctx, m.sessionID, content)
m.cancelMessage = nil
}()
return m.editor.Reset() return m.editor.Reset()
} }

View File

@@ -309,7 +309,7 @@ func (m *messagesCmp) renderMessageWithToolCall(content string, tools []message.
} }
for _, msg := range futureMessages { for _, msg := range futureMessages {
if msg.Content().String() != "" { if msg.Content().String() != "" || msg.FinishReason() == "canceled" {
break break
} }
@@ -345,13 +345,18 @@ func (m *messagesCmp) renderView() {
prevMessageWasUser := false prevMessageWasUser := false
for inx, msg := range m.messages { for inx, msg := range m.messages {
content := msg.Content().String() content := msg.Content().String()
if content != "" || prevMessageWasUser { if content != "" || prevMessageWasUser || msg.FinishReason() == "canceled" {
if msg.ReasoningContent().String() != "" && content == "" { if msg.ReasoningContent().String() != "" && content == "" {
content = msg.ReasoningContent().String() content = msg.ReasoningContent().String()
} else if content == "" { } else if content == "" {
content = "..." content = "..."
} }
if msg.FinishReason() == "canceled" {
content, _ = r.Render(content) content, _ = r.Render(content)
content += lipgloss.NewStyle().Padding(1, 0, 0, 1).Foreground(styles.Error).Render(styles.ErrorIcon + " Canceled")
} else {
content, _ = r.Render(content)
}
isSelected := inx == m.selectedMsgIdx isSelected := inx == m.selectedMsgIdx

View File

@@ -101,7 +101,8 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
// Status // Status
case util.InfoMsg: case util.InfoMsg:
a.status, cmd = a.status.Update(msg) a.status, cmd = a.status.Update(msg)
return a, cmd cmds = append(cmds, cmd)
return a, tea.Batch(cmds...)
case pubsub.Event[logging.LogMessage]: case pubsub.Event[logging.LogMessage]:
if msg.Payload.Persist { if msg.Payload.Persist {
switch msg.Payload.Level { switch msg.Payload.Level {