rework llm

This commit is contained in:
Kujtim Hoxha
2025-03-27 22:35:48 +01:00
parent 904061c243
commit afd9ad0560
61 changed files with 5882 additions and 2074 deletions

View File

@@ -4,13 +4,12 @@ import (
"context"
"database/sql"
"github.com/kujtimiihoxha/termai/internal/config"
"github.com/kujtimiihoxha/termai/internal/db"
"github.com/kujtimiihoxha/termai/internal/llm"
"github.com/kujtimiihoxha/termai/internal/logging"
"github.com/kujtimiihoxha/termai/internal/message"
"github.com/kujtimiihoxha/termai/internal/permission"
"github.com/kujtimiihoxha/termai/internal/session"
"github.com/spf13/viper"
)
type App struct {
@@ -19,7 +18,6 @@ type App struct {
Sessions session.Service
Messages message.Service
Permissions permission.Service
LLM llm.Service
Logger logging.Interface
}
@@ -27,18 +25,17 @@ type App struct {
func New(ctx context.Context, conn *sql.DB) *App {
q := db.New(conn)
log := logging.NewLogger(logging.Options{
Level: viper.GetString("log.level"),
Level: config.Get().Log.Level,
})
sessions := session.NewService(ctx, q)
messages := message.NewService(ctx, q)
llm := llm.NewService(ctx, log, sessions, messages)
return &App{
Context: ctx,
Sessions: sessions,
Messages: messages,
Permissions: permission.Default,
LLM: llm,
Logger: log,
}
}

180
internal/config/config.go Normal file
View File

@@ -0,0 +1,180 @@
package config
import (
"fmt"
"os"
"strings"
"github.com/kujtimiihoxha/termai/internal/llm/models"
"github.com/spf13/viper"
)
type MCPType string
const (
MCPStdio MCPType = "stdio"
MCPSse MCPType = "sse"
)
type MCPServer struct {
Command string `json:"command"`
Env []string `json:"env"`
Args []string `json:"args"`
Type MCPType `json:"type"`
URL string `json:"url"`
Headers map[string]string `json:"headers"`
// TODO: add permissions configuration
// TODO: add the ability to specify the tools to import
}
type Model struct {
Coder models.ModelID `json:"coder"`
CoderMaxTokens int64 `json:"coderMaxTokens"`
Task models.ModelID `json:"task"`
TaskMaxTokens int64 `json:"taskMaxTokens"`
// TODO: Maybe support multiple models for different purposes
}
type Provider struct {
APIKey string `json:"apiKey"`
Enabled bool `json:"enabled"`
}
type Data struct {
Directory string `json:"directory"`
}
type Log struct {
Level string `json:"level"`
}
type Config struct {
Data *Data `json:"data,omitempty"`
Log *Log `json:"log,omitempty"`
MCPServers map[string]MCPServer `json:"mcpServers,omitempty"`
Providers map[models.ModelProvider]Provider `json:"providers,omitempty"`
Model *Model `json:"model,omitempty"`
}
var cfg *Config
const (
defaultDataDirectory = ".termai"
defaultLogLevel = "info"
defaultMaxTokens = int64(5000)
termai = "termai"
)
func Load(debug bool) error {
if cfg != nil {
return nil
}
viper.SetConfigName(fmt.Sprintf(".%s", termai))
viper.SetConfigType("json")
viper.AddConfigPath("$HOME")
viper.AddConfigPath(fmt.Sprintf("$XDG_CONFIG_HOME/%s", termai))
viper.SetEnvPrefix(strings.ToUpper(termai))
// Add defaults
viper.SetDefault("data.directory", defaultDataDirectory)
if debug {
viper.Set("log.level", "debug")
} else {
viper.SetDefault("log.level", defaultLogLevel)
}
defaultModelSet := false
if os.Getenv("ANTHROPIC_API_KEY") != "" {
viper.SetDefault("providers.anthropic.apiKey", os.Getenv("ANTHROPIC_API_KEY"))
viper.SetDefault("providers.anthropic.enabled", true)
viper.SetDefault("model.coder", models.Claude37Sonnet)
viper.SetDefault("model.task", models.Claude37Sonnet)
defaultModelSet = true
}
if os.Getenv("OPENAI_API_KEY") != "" {
viper.SetDefault("providers.openai.apiKey", os.Getenv("OPENAI_API_KEY"))
viper.SetDefault("providers.openai.enabled", true)
if !defaultModelSet {
viper.SetDefault("model.coder", models.GPT4o)
viper.SetDefault("model.task", models.GPT4o)
defaultModelSet = true
}
}
if os.Getenv("GEMINI_API_KEY") != "" {
viper.SetDefault("providers.gemini.apiKey", os.Getenv("GEMINI_API_KEY"))
viper.SetDefault("providers.gemini.enabled", true)
if !defaultModelSet {
viper.SetDefault("model.coder", models.GRMINI20Flash)
viper.SetDefault("model.task", models.GRMINI20Flash)
defaultModelSet = true
}
}
if os.Getenv("GROQ_API_KEY") != "" {
viper.SetDefault("providers.groq.apiKey", os.Getenv("GROQ_API_KEY"))
viper.SetDefault("providers.groq.enabled", true)
if !defaultModelSet {
viper.SetDefault("model.coder", models.QWENQwq)
viper.SetDefault("model.task", models.QWENQwq)
defaultModelSet = true
}
}
// TODO: add more providers
cfg = &Config{}
err := viper.ReadInConfig()
if err != nil {
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
return err
}
}
local := viper.New()
local.SetConfigName(fmt.Sprintf(".%s", termai))
local.SetConfigType("json")
local.AddConfigPath(".")
// load local config, this will override the global config
if err = local.ReadInConfig(); err == nil {
viper.MergeConfigMap(local.AllSettings())
}
viper.Unmarshal(cfg)
if cfg.Model != nil && cfg.Model.CoderMaxTokens <= 0 {
cfg.Model.CoderMaxTokens = defaultMaxTokens
}
if cfg.Model != nil && cfg.Model.TaskMaxTokens <= 0 {
cfg.Model.TaskMaxTokens = defaultMaxTokens
}
for _, v := range cfg.MCPServers {
if v.Type == "" {
v.Type = MCPStdio
}
}
workdir, err := os.Getwd()
if err != nil {
return err
}
viper.Set("wd", workdir)
return nil
}
func Get() *Config {
if cfg == nil {
err := Load(false)
if err != nil {
panic(err)
}
}
return cfg
}
func WorkingDirectory() string {
return viper.GetString("wd")
}
func Write() error {
return viper.WriteConfig()
}

View File

@@ -0,0 +1,465 @@
package config
import (
"fmt"
"os"
"path/filepath"
"testing"
"github.com/kujtimiihoxha/termai/internal/llm/models"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestLoad(t *testing.T) {
setupTest(t)
t.Run("loads configuration successfully", func(t *testing.T) {
homeDir := t.TempDir()
t.Setenv("HOME", homeDir)
configPath := filepath.Join(homeDir, ".termai.json")
configContent := `{
"data": {
"directory": "custom-dir"
},
"log": {
"level": "debug"
},
"mcpServers": {
"test-server": {
"command": "test-command",
"env": ["TEST_ENV=value"],
"args": ["--arg1", "--arg2"],
"type": "stdio",
"url": "",
"headers": {}
},
"sse-server": {
"command": "",
"env": [],
"args": [],
"type": "sse",
"url": "https://api.example.com/events",
"headers": {
"Authorization": "Bearer token123",
"Content-Type": "application/json"
}
}
},
"providers": {
"anthropic": {
"apiKey": "test-api-key",
"enabled": true
}
},
"model": {
"coder": "claude-3-haiku",
"task": "claude-3-haiku"
}
}`
err := os.WriteFile(configPath, []byte(configContent), 0o644)
require.NoError(t, err)
cfg = nil
viper.Reset()
err = Load(false)
require.NoError(t, err)
config := Get()
assert.NotNil(t, config)
assert.Equal(t, "custom-dir", config.Data.Directory)
assert.Equal(t, "debug", config.Log.Level)
assert.Contains(t, config.MCPServers, "test-server")
stdioServer := config.MCPServers["test-server"]
assert.Equal(t, "test-command", stdioServer.Command)
assert.Equal(t, []string{"TEST_ENV=value"}, stdioServer.Env)
assert.Equal(t, []string{"--arg1", "--arg2"}, stdioServer.Args)
assert.Equal(t, MCPStdio, stdioServer.Type)
assert.Equal(t, "", stdioServer.URL)
assert.Empty(t, stdioServer.Headers)
assert.Contains(t, config.MCPServers, "sse-server")
sseServer := config.MCPServers["sse-server"]
assert.Equal(t, "", sseServer.Command)
assert.Empty(t, sseServer.Env)
assert.Empty(t, sseServer.Args)
assert.Equal(t, MCPSse, sseServer.Type)
assert.Equal(t, "https://api.example.com/events", sseServer.URL)
assert.Equal(t, map[string]string{
"authorization": "Bearer token123",
"content-type": "application/json",
}, sseServer.Headers)
assert.Contains(t, config.Providers, models.ModelProvider("anthropic"))
provider := config.Providers[models.ModelProvider("anthropic")]
assert.Equal(t, "test-api-key", provider.APIKey)
assert.True(t, provider.Enabled)
assert.NotNil(t, config.Model)
assert.Equal(t, models.Claude3Haiku, config.Model.Coder)
assert.Equal(t, models.Claude3Haiku, config.Model.Task)
assert.Equal(t, defaultMaxTokens, config.Model.CoderMaxTokens)
})
t.Run("loads configuration with environment variables", func(t *testing.T) {
homeDir := t.TempDir()
t.Setenv("HOME", homeDir)
configPath := filepath.Join(homeDir, ".termai.json")
err := os.WriteFile(configPath, []byte("{}"), 0o644)
require.NoError(t, err)
t.Setenv("ANTHROPIC_API_KEY", "env-anthropic-key")
t.Setenv("OPENAI_API_KEY", "env-openai-key")
t.Setenv("GEMINI_API_KEY", "env-gemini-key")
cfg = nil
viper.Reset()
err = Load(false)
require.NoError(t, err)
config := Get()
assert.NotNil(t, config)
assert.Equal(t, defaultDataDirectory, config.Data.Directory)
assert.Equal(t, defaultLogLevel, config.Log.Level)
assert.Contains(t, config.Providers, models.ModelProvider("anthropic"))
assert.Equal(t, "env-anthropic-key", config.Providers[models.ModelProvider("anthropic")].APIKey)
assert.True(t, config.Providers[models.ModelProvider("anthropic")].Enabled)
assert.Contains(t, config.Providers, models.ModelProvider("openai"))
assert.Equal(t, "env-openai-key", config.Providers[models.ModelProvider("openai")].APIKey)
assert.True(t, config.Providers[models.ModelProvider("openai")].Enabled)
assert.Contains(t, config.Providers, models.ModelProvider("gemini"))
assert.Equal(t, "env-gemini-key", config.Providers[models.ModelProvider("gemini")].APIKey)
assert.True(t, config.Providers[models.ModelProvider("gemini")].Enabled)
assert.Equal(t, models.Claude37Sonnet, config.Model.Coder)
})
t.Run("local config overrides global config", func(t *testing.T) {
homeDir := t.TempDir()
t.Setenv("HOME", homeDir)
globalConfigPath := filepath.Join(homeDir, ".termai.json")
globalConfig := `{
"data": {
"directory": "global-dir"
},
"log": {
"level": "info"
}
}`
err := os.WriteFile(globalConfigPath, []byte(globalConfig), 0o644)
require.NoError(t, err)
workDir := t.TempDir()
origDir, err := os.Getwd()
require.NoError(t, err)
defer os.Chdir(origDir)
err = os.Chdir(workDir)
require.NoError(t, err)
localConfigPath := filepath.Join(workDir, ".termai.json")
localConfig := `{
"data": {
"directory": "local-dir"
},
"log": {
"level": "debug"
}
}`
err = os.WriteFile(localConfigPath, []byte(localConfig), 0o644)
require.NoError(t, err)
cfg = nil
viper.Reset()
err = Load(false)
require.NoError(t, err)
config := Get()
assert.NotNil(t, config)
assert.Equal(t, "local-dir", config.Data.Directory)
assert.Equal(t, "debug", config.Log.Level)
})
t.Run("missing config file should not return error", func(t *testing.T) {
emptyDir := t.TempDir()
t.Setenv("HOME", emptyDir)
cfg = nil
viper.Reset()
err := Load(false)
assert.NoError(t, err)
})
t.Run("model priority and fallbacks", func(t *testing.T) {
testCases := []struct {
name string
anthropicKey string
openaiKey string
geminiKey string
expectedModel models.ModelID
explicitModel models.ModelID
useExplicitModel bool
}{
{
name: "anthropic has priority",
anthropicKey: "test-key",
openaiKey: "test-key",
geminiKey: "test-key",
expectedModel: models.Claude37Sonnet,
},
{
name: "fallback to openai when no anthropic",
anthropicKey: "",
openaiKey: "test-key",
geminiKey: "test-key",
expectedModel: models.GPT4o,
},
{
name: "fallback to gemini when no others",
anthropicKey: "",
openaiKey: "",
geminiKey: "test-key",
expectedModel: models.GRMINI20Flash,
},
{
name: "explicit model overrides defaults",
anthropicKey: "test-key",
openaiKey: "test-key",
geminiKey: "test-key",
explicitModel: models.GPT4o,
useExplicitModel: true,
expectedModel: models.GPT4o,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
homeDir := t.TempDir()
t.Setenv("HOME", homeDir)
configPath := filepath.Join(homeDir, ".termai.json")
configContent := "{}"
if tc.useExplicitModel {
configContent = fmt.Sprintf(`{"model":{"coder":"%s"}}`, tc.explicitModel)
}
err := os.WriteFile(configPath, []byte(configContent), 0o644)
require.NoError(t, err)
if tc.anthropicKey != "" {
t.Setenv("ANTHROPIC_API_KEY", tc.anthropicKey)
} else {
t.Setenv("ANTHROPIC_API_KEY", "")
}
if tc.openaiKey != "" {
t.Setenv("OPENAI_API_KEY", tc.openaiKey)
} else {
t.Setenv("OPENAI_API_KEY", "")
}
if tc.geminiKey != "" {
t.Setenv("GEMINI_API_KEY", tc.geminiKey)
} else {
t.Setenv("GEMINI_API_KEY", "")
}
cfg = nil
viper.Reset()
err = Load(false)
require.NoError(t, err)
config := Get()
assert.NotNil(t, config)
assert.Equal(t, tc.expectedModel, config.Model.Coder)
})
}
})
}
func TestGet(t *testing.T) {
t.Run("get returns same config instance", func(t *testing.T) {
setupTest(t)
homeDir := t.TempDir()
t.Setenv("HOME", homeDir)
configPath := filepath.Join(homeDir, ".termai.json")
err := os.WriteFile(configPath, []byte("{}"), 0o644)
require.NoError(t, err)
cfg = nil
viper.Reset()
config1 := Get()
require.NotNil(t, config1)
config2 := Get()
require.NotNil(t, config2)
assert.Same(t, config1, config2)
})
t.Run("get loads config if not loaded", func(t *testing.T) {
setupTest(t)
homeDir := t.TempDir()
t.Setenv("HOME", homeDir)
configPath := filepath.Join(homeDir, ".termai.json")
configContent := `{"data":{"directory":"test-dir"}}`
err := os.WriteFile(configPath, []byte(configContent), 0o644)
require.NoError(t, err)
cfg = nil
viper.Reset()
config := Get()
require.NotNil(t, config)
assert.Equal(t, "test-dir", config.Data.Directory)
})
}
func TestWorkingDirectory(t *testing.T) {
t.Run("returns current working directory", func(t *testing.T) {
setupTest(t)
homeDir := t.TempDir()
t.Setenv("HOME", homeDir)
configPath := filepath.Join(homeDir, ".termai.json")
err := os.WriteFile(configPath, []byte("{}"), 0o644)
require.NoError(t, err)
cfg = nil
viper.Reset()
err = Load(false)
require.NoError(t, err)
wd := WorkingDirectory()
expectedWd, err := os.Getwd()
require.NoError(t, err)
assert.Equal(t, expectedWd, wd)
})
}
func TestWrite(t *testing.T) {
t.Run("writes config to file", func(t *testing.T) {
setupTest(t)
homeDir := t.TempDir()
t.Setenv("HOME", homeDir)
configPath := filepath.Join(homeDir, ".termai.json")
err := os.WriteFile(configPath, []byte("{}"), 0o644)
require.NoError(t, err)
cfg = nil
viper.Reset()
err = Load(false)
require.NoError(t, err)
viper.Set("data.directory", "modified-dir")
err = Write()
require.NoError(t, err)
content, err := os.ReadFile(configPath)
require.NoError(t, err)
assert.Contains(t, string(content), "modified-dir")
})
}
func TestMCPType(t *testing.T) {
t.Run("MCPType constants", func(t *testing.T) {
assert.Equal(t, MCPType("stdio"), MCPStdio)
assert.Equal(t, MCPType("sse"), MCPSse)
})
t.Run("MCPType JSON unmarshaling", func(t *testing.T) {
homeDir := t.TempDir()
t.Setenv("HOME", homeDir)
configPath := filepath.Join(homeDir, ".termai.json")
configContent := `{
"mcpServers": {
"stdio-server": {
"type": "stdio"
},
"sse-server": {
"type": "sse"
},
"invalid-server": {
"type": "invalid"
}
}
}`
err := os.WriteFile(configPath, []byte(configContent), 0o644)
require.NoError(t, err)
cfg = nil
viper.Reset()
err = Load(false)
require.NoError(t, err)
config := Get()
assert.NotNil(t, config)
assert.Equal(t, MCPStdio, config.MCPServers["stdio-server"].Type)
assert.Equal(t, MCPSse, config.MCPServers["sse-server"].Type)
assert.Equal(t, MCPType("invalid"), config.MCPServers["invalid-server"].Type)
})
t.Run("default MCPType", func(t *testing.T) {
homeDir := t.TempDir()
t.Setenv("HOME", homeDir)
configPath := filepath.Join(homeDir, ".termai.json")
configContent := `{
"mcpServers": {
"test-server": {
"command": "test-command"
}
}
}`
err := os.WriteFile(configPath, []byte(configContent), 0o644)
require.NoError(t, err)
cfg = nil
viper.Reset()
err = Load(false)
require.NoError(t, err)
config := Get()
assert.NotNil(t, config)
assert.Equal(t, MCPType(""), config.MCPServers["test-server"].Type)
})
}
func setupTest(t *testing.T) {
origHome := os.Getenv("HOME")
origXdgConfigHome := os.Getenv("XDG_CONFIG_HOME")
origAnthropicKey := os.Getenv("ANTHROPIC_API_KEY")
origOpenAIKey := os.Getenv("OPENAI_API_KEY")
origGeminiKey := os.Getenv("GEMINI_API_KEY")
t.Cleanup(func() {
t.Setenv("HOME", origHome)
t.Setenv("XDG_CONFIG_HOME", origXdgConfigHome)
t.Setenv("ANTHROPIC_API_KEY", origAnthropicKey)
t.Setenv("OPENAI_API_KEY", origOpenAIKey)
t.Setenv("GEMINI_API_KEY", origGeminiKey)
cfg = nil
viper.Reset()
})
}

View File

@@ -12,14 +12,14 @@ import (
"github.com/golang-migrate/migrate/v4/database/sqlite3"
_ "github.com/mattn/go-sqlite3"
"github.com/kujtimiihoxha/termai/internal/config"
"github.com/kujtimiihoxha/termai/internal/logging"
"github.com/spf13/viper"
)
var log = logging.Get()
func Connect() (*sql.DB, error) {
dataDir := viper.GetString("data.dir")
dataDir := config.Get().Data.Directory
if dataDir == "" {
return nil, fmt.Errorf("data.dir is not set")
}

View File

@@ -51,6 +51,9 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) {
if q.listSessionsStmt, err = db.PrepareContext(ctx, listSessions); err != nil {
return nil, fmt.Errorf("error preparing query ListSessions: %w", err)
}
if q.updateMessageStmt, err = db.PrepareContext(ctx, updateMessage); err != nil {
return nil, fmt.Errorf("error preparing query UpdateMessage: %w", err)
}
if q.updateSessionStmt, err = db.PrepareContext(ctx, updateSession); err != nil {
return nil, fmt.Errorf("error preparing query UpdateSession: %w", err)
}
@@ -104,6 +107,11 @@ func (q *Queries) Close() error {
err = fmt.Errorf("error closing listSessionsStmt: %w", cerr)
}
}
if q.updateMessageStmt != nil {
if cerr := q.updateMessageStmt.Close(); cerr != nil {
err = fmt.Errorf("error closing updateMessageStmt: %w", cerr)
}
}
if q.updateSessionStmt != nil {
if cerr := q.updateSessionStmt.Close(); cerr != nil {
err = fmt.Errorf("error closing updateSessionStmt: %w", cerr)
@@ -157,6 +165,7 @@ type Queries struct {
getSessionByIDStmt *sql.Stmt
listMessagesBySessionStmt *sql.Stmt
listSessionsStmt *sql.Stmt
updateMessageStmt *sql.Stmt
updateSessionStmt *sql.Stmt
}
@@ -173,6 +182,7 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries {
getSessionByIDStmt: q.getSessionByIDStmt,
listMessagesBySessionStmt: q.listMessagesBySessionStmt,
listSessionsStmt: q.listSessionsStmt,
updateMessageStmt: q.updateMessageStmt,
updateSessionStmt: q.updateSessionStmt,
}
}

View File

@@ -7,34 +7,56 @@ package db
import (
"context"
"database/sql"
)
const createMessage = `-- name: CreateMessage :one
INSERT INTO messages (
id,
session_id,
message_data,
role,
finished,
content,
tool_calls,
tool_results,
created_at,
updated_at
) VALUES (
?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
)
RETURNING id, session_id, message_data, created_at, updated_at
RETURNING id, session_id, role, content, thinking, finished, tool_calls, tool_results, created_at, updated_at
`
type CreateMessageParams struct {
ID string `json:"id"`
SessionID string `json:"session_id"`
MessageData string `json:"message_data"`
ID string `json:"id"`
SessionID string `json:"session_id"`
Role string `json:"role"`
Finished bool `json:"finished"`
Content string `json:"content"`
ToolCalls sql.NullString `json:"tool_calls"`
ToolResults sql.NullString `json:"tool_results"`
}
func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (Message, error) {
row := q.queryRow(ctx, q.createMessageStmt, createMessage, arg.ID, arg.SessionID, arg.MessageData)
row := q.queryRow(ctx, q.createMessageStmt, createMessage,
arg.ID,
arg.SessionID,
arg.Role,
arg.Finished,
arg.Content,
arg.ToolCalls,
arg.ToolResults,
)
var i Message
err := row.Scan(
&i.ID,
&i.SessionID,
&i.MessageData,
&i.Role,
&i.Content,
&i.Thinking,
&i.Finished,
&i.ToolCalls,
&i.ToolResults,
&i.CreatedAt,
&i.UpdatedAt,
)
@@ -62,7 +84,7 @@ func (q *Queries) DeleteSessionMessages(ctx context.Context, sessionID string) e
}
const getMessage = `-- name: GetMessage :one
SELECT id, session_id, message_data, created_at, updated_at
SELECT id, session_id, role, content, thinking, finished, tool_calls, tool_results, created_at, updated_at
FROM messages
WHERE id = ? LIMIT 1
`
@@ -73,7 +95,12 @@ func (q *Queries) GetMessage(ctx context.Context, id string) (Message, error) {
err := row.Scan(
&i.ID,
&i.SessionID,
&i.MessageData,
&i.Role,
&i.Content,
&i.Thinking,
&i.Finished,
&i.ToolCalls,
&i.ToolResults,
&i.CreatedAt,
&i.UpdatedAt,
)
@@ -81,7 +108,7 @@ func (q *Queries) GetMessage(ctx context.Context, id string) (Message, error) {
}
const listMessagesBySession = `-- name: ListMessagesBySession :many
SELECT id, session_id, message_data, created_at, updated_at
SELECT id, session_id, role, content, thinking, finished, tool_calls, tool_results, created_at, updated_at
FROM messages
WHERE session_id = ?
ORDER BY created_at ASC
@@ -99,7 +126,12 @@ func (q *Queries) ListMessagesBySession(ctx context.Context, sessionID string) (
if err := rows.Scan(
&i.ID,
&i.SessionID,
&i.MessageData,
&i.Role,
&i.Content,
&i.Thinking,
&i.Finished,
&i.ToolCalls,
&i.ToolResults,
&i.CreatedAt,
&i.UpdatedAt,
); err != nil {
@@ -115,3 +147,36 @@ func (q *Queries) ListMessagesBySession(ctx context.Context, sessionID string) (
}
return items, nil
}
const updateMessage = `-- name: UpdateMessage :exec
UPDATE messages
SET
content = ?,
thinking = ?,
tool_calls = ?,
tool_results = ?,
finished = ?,
updated_at = strftime('%s', 'now')
WHERE id = ?
`
type UpdateMessageParams struct {
Content string `json:"content"`
Thinking string `json:"thinking"`
ToolCalls sql.NullString `json:"tool_calls"`
ToolResults sql.NullString `json:"tool_results"`
Finished bool `json:"finished"`
ID string `json:"id"`
}
func (q *Queries) UpdateMessage(ctx context.Context, arg UpdateMessageParams) error {
_, err := q.exec(ctx, q.updateMessageStmt, updateMessage,
arg.Content,
arg.Thinking,
arg.ToolCalls,
arg.ToolResults,
arg.Finished,
arg.ID,
)
return err
}

View File

@@ -1,6 +1,7 @@
-- Sessions
CREATE TABLE IF NOT EXISTS sessions (
id TEXT PRIMARY KEY,
parent_session_id TEXT,
title TEXT NOT NULL,
message_count INTEGER NOT NULL DEFAULT 0 CHECK (message_count >= 0),
prompt_tokens INTEGER NOT NULL DEFAULT 0 CHECK (prompt_tokens >= 0),
@@ -21,7 +22,12 @@ END;
CREATE TABLE IF NOT EXISTS messages (
id TEXT PRIMARY KEY,
session_id TEXT NOT NULL,
message_data TEXT NOT NULL, -- JSON string of message content
role TEXT NOT NULL,
content TEXT NOT NULL,
thinking Text NOT NULL DEFAULT '',
finished BOOLEAN NOT NULL DEFAULT 0,
tool_calls TEXT,
tool_results TEXT,
created_at INTEGER NOT NULL, -- Unix timestamp in milliseconds
updated_at INTEGER NOT NULL, -- Unix timestamp in milliseconds
FOREIGN KEY (session_id) REFERENCES sessions (id) ON DELETE CASCADE

View File

@@ -4,21 +4,31 @@
package db
import (
"database/sql"
)
type Message struct {
ID string `json:"id"`
SessionID string `json:"session_id"`
MessageData string `json:"message_data"`
CreatedAt int64 `json:"created_at"`
UpdatedAt int64 `json:"updated_at"`
ID string `json:"id"`
SessionID string `json:"session_id"`
Role string `json:"role"`
Content string `json:"content"`
Thinking string `json:"thinking"`
Finished bool `json:"finished"`
ToolCalls sql.NullString `json:"tool_calls"`
ToolResults sql.NullString `json:"tool_results"`
CreatedAt int64 `json:"created_at"`
UpdatedAt int64 `json:"updated_at"`
}
type Session struct {
ID string `json:"id"`
Title string `json:"title"`
MessageCount int64 `json:"message_count"`
PromptTokens int64 `json:"prompt_tokens"`
CompletionTokens int64 `json:"completion_tokens"`
Cost float64 `json:"cost"`
UpdatedAt int64 `json:"updated_at"`
CreatedAt int64 `json:"created_at"`
ID string `json:"id"`
ParentSessionID sql.NullString `json:"parent_session_id"`
Title string `json:"title"`
MessageCount int64 `json:"message_count"`
PromptTokens int64 `json:"prompt_tokens"`
CompletionTokens int64 `json:"completion_tokens"`
Cost float64 `json:"cost"`
UpdatedAt int64 `json:"updated_at"`
CreatedAt int64 `json:"created_at"`
}

View File

@@ -18,6 +18,7 @@ type Querier interface {
GetSessionByID(ctx context.Context, id string) (Session, error)
ListMessagesBySession(ctx context.Context, sessionID string) ([]Message, error)
ListSessions(ctx context.Context) ([]Session, error)
UpdateMessage(ctx context.Context, arg UpdateMessageParams) error
UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error)
}

View File

@@ -7,11 +7,13 @@ package db
import (
"context"
"database/sql"
)
const createSession = `-- name: CreateSession :one
INSERT INTO sessions (
id,
parent_session_id,
title,
message_count,
prompt_tokens,
@@ -26,23 +28,26 @@ INSERT INTO sessions (
?,
?,
?,
?,
strftime('%s', 'now'),
strftime('%s', 'now')
) RETURNING id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at
) RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at
`
type CreateSessionParams struct {
ID string `json:"id"`
Title string `json:"title"`
MessageCount int64 `json:"message_count"`
PromptTokens int64 `json:"prompt_tokens"`
CompletionTokens int64 `json:"completion_tokens"`
Cost float64 `json:"cost"`
ID string `json:"id"`
ParentSessionID sql.NullString `json:"parent_session_id"`
Title string `json:"title"`
MessageCount int64 `json:"message_count"`
PromptTokens int64 `json:"prompt_tokens"`
CompletionTokens int64 `json:"completion_tokens"`
Cost float64 `json:"cost"`
}
func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error) {
row := q.queryRow(ctx, q.createSessionStmt, createSession,
arg.ID,
arg.ParentSessionID,
arg.Title,
arg.MessageCount,
arg.PromptTokens,
@@ -52,6 +57,7 @@ func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (S
var i Session
err := row.Scan(
&i.ID,
&i.ParentSessionID,
&i.Title,
&i.MessageCount,
&i.PromptTokens,
@@ -74,7 +80,7 @@ func (q *Queries) DeleteSession(ctx context.Context, id string) error {
}
const getSessionByID = `-- name: GetSessionByID :one
SELECT id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at
SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at
FROM sessions
WHERE id = ? LIMIT 1
`
@@ -84,6 +90,7 @@ func (q *Queries) GetSessionByID(ctx context.Context, id string) (Session, error
var i Session
err := row.Scan(
&i.ID,
&i.ParentSessionID,
&i.Title,
&i.MessageCount,
&i.PromptTokens,
@@ -96,8 +103,9 @@ func (q *Queries) GetSessionByID(ctx context.Context, id string) (Session, error
}
const listSessions = `-- name: ListSessions :many
SELECT id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at
SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at
FROM sessions
WHERE parent_session_id is NULL
ORDER BY created_at DESC
`
@@ -112,6 +120,7 @@ func (q *Queries) ListSessions(ctx context.Context) ([]Session, error) {
var i Session
if err := rows.Scan(
&i.ID,
&i.ParentSessionID,
&i.Title,
&i.MessageCount,
&i.PromptTokens,
@@ -141,7 +150,7 @@ SET
completion_tokens = ?,
cost = ?
WHERE id = ?
RETURNING id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at
RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at
`
type UpdateSessionParams struct {
@@ -163,6 +172,7 @@ func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (S
var i Session
err := row.Scan(
&i.ID,
&i.ParentSessionID,
&i.Title,
&i.MessageCount,
&i.PromptTokens,

View File

@@ -13,14 +13,29 @@ ORDER BY created_at ASC;
INSERT INTO messages (
id,
session_id,
message_data,
role,
finished,
content,
tool_calls,
tool_results,
created_at,
updated_at
) VALUES (
?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
)
RETURNING *;
-- name: UpdateMessage :exec
UPDATE messages
SET
content = ?,
thinking = ?,
tool_calls = ?,
tool_results = ?,
finished = ?,
updated_at = strftime('%s', 'now')
WHERE id = ?;
-- name: DeleteMessage :exec
DELETE FROM messages
WHERE id = ?;

View File

@@ -1,6 +1,7 @@
-- name: CreateSession :one
INSERT INTO sessions (
id,
parent_session_id,
title,
message_count,
prompt_tokens,
@@ -15,6 +16,7 @@ INSERT INTO sessions (
?,
?,
?,
?,
strftime('%s', 'now'),
strftime('%s', 'now')
) RETURNING *;
@@ -27,6 +29,7 @@ WHERE id = ? LIMIT 1;
-- name: ListSessions :many
SELECT *
FROM sessions
WHERE parent_session_id is NULL
ORDER BY created_at DESC;
-- name: UpdateSession :one

View File

@@ -0,0 +1,102 @@
package agent
import (
"context"
"encoding/json"
"fmt"
"github.com/kujtimiihoxha/termai/internal/app"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
"github.com/kujtimiihoxha/termai/internal/message"
)
type agentTool struct {
parentSessionID string
app *app.App
}
const (
AgentToolName = "agent"
)
type AgentParams struct {
Prompt string `json:"prompt"`
}
func (b *agentTool) Info() tools.ToolInfo {
return tools.ToolInfo{
Name: AgentToolName,
Description: "Launch a new agent that has access to the following tools: GlobTool, GrepTool, LS, View. When you are searching for a keyword or file and are not confident that you will find the right match on the first try, use the Agent tool to perform the search for you. For example:\n\n- If you are searching for a keyword like \"config\" or \"logger\", or for questions like \"which file does X?\", the Agent tool is strongly recommended\n- If you want to read a specific file path, use the View or GlobTool tool instead of the Agent tool, to find the match more quickly\n- If you are searching for a specific class definition like \"class Foo\", use the GlobTool tool instead, to find the match more quickly\n\nUsage notes:\n1. Launch multiple agents concurrently whenever possible, to maximize performance; to do that, use a single message with multiple tool uses\n2. When the agent is done, it will return a single message back to you. The result returned by the agent is not visible to the user. To show the user the result, you should send a text message back to the user with a concise summary of the result.\n3. Each agent invocation is stateless. You will not be able to send additional messages to the agent, nor will the agent be able to communicate with you outside of its final report. Therefore, your prompt should contain a highly detailed task description for the agent to perform autonomously and you should specify exactly what information the agent should return back to you in its final and only message to you.\n4. The agent's outputs should generally be trusted\n5. IMPORTANT: The agent can not use Bash, Replace, Edit, so can not modify files. If you want to use these tools, use them directly instead of going through the agent.",
Parameters: map[string]any{
"prompt": map[string]any{
"type": "string",
"description": "The task for the agent to perform",
},
},
Required: []string{"prompt"},
}
}
func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
var params AgentParams
if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
}
if params.Prompt == "" {
return tools.NewTextErrorResponse("prompt is required"), nil
}
agent, err := NewTaskAgent(b.app)
if err != nil {
return tools.NewTextErrorResponse(fmt.Sprintf("error creating agent: %s", err)), nil
}
session, err := b.app.Sessions.CreateTaskSession(call.ID, b.parentSessionID, "New Agent Session")
if err != nil {
return tools.NewTextErrorResponse(fmt.Sprintf("error creating session: %s", err)), nil
}
err = agent.Generate(session.ID, params.Prompt)
if err != nil {
return tools.NewTextErrorResponse(fmt.Sprintf("error generating agent: %s", err)), nil
}
messages, err := b.app.Messages.List(session.ID)
if err != nil {
return tools.NewTextErrorResponse(fmt.Sprintf("error listing messages: %s", err)), nil
}
if len(messages) == 0 {
return tools.NewTextErrorResponse("no messages found"), nil
}
response := messages[len(messages)-1]
if response.Role != message.Assistant {
return tools.NewTextErrorResponse("no assistant message found"), nil
}
updatedSession, err := b.app.Sessions.Get(session.ID)
if err != nil {
return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil
}
parentSession, err := b.app.Sessions.Get(b.parentSessionID)
if err != nil {
return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil
}
parentSession.Cost += updatedSession.Cost
parentSession.PromptTokens += updatedSession.PromptTokens
parentSession.CompletionTokens += updatedSession.CompletionTokens
_, err = b.app.Sessions.Save(parentSession)
if err != nil {
return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil
}
return tools.NewTextResponse(response.Content), nil
}
func NewAgentTool(parentSessionID string, app *app.App) tools.BaseTool {
return &agentTool{
parentSessionID: parentSessionID,
app: app,
}
}

View File

@@ -2,16 +2,353 @@ package agent
import (
"context"
"errors"
"fmt"
"log"
"sync"
"github.com/cloudwego/eino/flow/agent/react"
"github.com/kujtimiihoxha/termai/internal/app"
"github.com/kujtimiihoxha/termai/internal/config"
"github.com/kujtimiihoxha/termai/internal/llm/models"
"github.com/kujtimiihoxha/termai/internal/llm/prompt"
"github.com/kujtimiihoxha/termai/internal/llm/provider"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
"github.com/kujtimiihoxha/termai/internal/message"
)
func GetAgent(ctx context.Context, name string) (*react.Agent, string, error) {
switch name {
case "coder":
agent, err := NewCoderAgent(ctx)
return agent, CoderSystemPrompt(), err
}
return nil, "", fmt.Errorf("agent %s not found", name)
type Agent interface {
Generate(sessionID string, content string) error
}
type agent struct {
*app.App
model models.Model
tools []tools.BaseTool
agent provider.Provider
titleGenerator provider.Provider
}
func (c *agent) handleTitleGeneration(sessionID, content string) {
response, err := c.titleGenerator.SendMessages(
c.Context,
[]message.Message{
{
Role: message.User,
Content: content,
},
},
nil,
)
if err != nil {
return
}
session, err := c.Sessions.Get(sessionID)
if err != nil {
return
}
if response.Content != "" {
session.Title = response.Content
c.Sessions.Save(session)
}
}
func (c *agent) TrackUsage(sessionID string, model models.Model, usage provider.TokenUsage) error {
session, err := c.Sessions.Get(sessionID)
if err != nil {
return err
}
cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
model.CostPer1MIn/1e6*float64(usage.InputTokens) +
model.CostPer1MOut/1e6*float64(usage.OutputTokens)
session.Cost += cost
session.CompletionTokens += usage.OutputTokens
session.PromptTokens += usage.InputTokens
_, err = c.Sessions.Save(session)
return err
}
func (c *agent) processEvent(
sessionID string,
assistantMsg *message.Message,
event provider.ProviderEvent,
) error {
switch event.Type {
case provider.EventThinkingDelta:
assistantMsg.Thinking += event.Thinking
return c.Messages.Update(*assistantMsg)
case provider.EventContentDelta:
assistantMsg.Content += event.Content
return c.Messages.Update(*assistantMsg)
case provider.EventError:
log.Println("error", event.Error)
return event.Error
case provider.EventComplete:
assistantMsg.ToolCalls = event.Response.ToolCalls
err := c.Messages.Update(*assistantMsg)
if err != nil {
return err
}
return c.TrackUsage(sessionID, c.model, event.Response.Usage)
}
return nil
}
func (c *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) {
var wg sync.WaitGroup
toolResults := make([]message.ToolResult, len(toolCalls))
mutex := &sync.Mutex{}
for i, tc := range toolCalls {
wg.Add(1)
go func(index int, toolCall message.ToolCall) {
defer wg.Done()
response := ""
isError := false
found := false
for _, tool := range tls {
if tool.Info().Name == toolCall.Name {
found = true
toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
ID: toolCall.ID,
Name: toolCall.Name,
Input: toolCall.Input,
})
if toolErr != nil {
response = fmt.Sprintf("error running tool: %s", toolErr)
isError = true
} else {
response = toolResult.Content
isError = toolResult.IsError
}
break
}
}
if !found {
response = fmt.Sprintf("tool not found: %s", toolCall.Name)
isError = true
}
mutex.Lock()
defer mutex.Unlock()
toolResults[index] = message.ToolResult{
ToolCallID: toolCall.ID,
Content: response,
IsError: isError,
}
}(i, tc)
}
wg.Wait()
return toolResults, nil
}
func (c *agent) handleToolExecution(
ctx context.Context,
assistantMsg message.Message,
) (*message.Message, error) {
if len(assistantMsg.ToolCalls) == 0 {
return nil, nil
}
toolResults, err := c.ExecuteTools(ctx, assistantMsg.ToolCalls, c.tools)
if err != nil {
return nil, err
}
msg, err := c.Messages.Create(assistantMsg.SessionID, message.CreateMessageParams{
Role: message.Tool,
ToolResults: toolResults,
})
return &msg, err
}
func (c *agent) generate(sessionID string, content string) error {
messages, err := c.Messages.List(sessionID)
if err != nil {
return err
}
if len(messages) == 0 {
go c.handleTitleGeneration(sessionID, content)
}
userMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
Role: message.User,
Content: content,
})
if err != nil {
return err
}
messages = append(messages, userMsg)
for {
eventChan, err := c.agent.StreamResponse(c.Context, messages, c.tools)
if err != nil {
return err
}
assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
Role: message.Assistant,
Content: "",
})
if err != nil {
return err
}
for event := range eventChan {
err = c.processEvent(sessionID, &assistantMsg, event)
if err != nil {
assistantMsg.Finished = true
c.Messages.Update(assistantMsg)
return err
}
}
msg, err := c.handleToolExecution(c.Context, assistantMsg)
assistantMsg.Finished = true
c.Messages.Update(assistantMsg)
if err != nil {
return err
}
if len(assistantMsg.ToolCalls) == 0 {
break
}
messages = append(messages, assistantMsg)
if msg != nil {
messages = append(messages, *msg)
}
}
return nil
}
func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) {
maxTokens := config.Get().Model.CoderMaxTokens
providerConfig, ok := config.Get().Providers[model.Provider]
if !ok || !providerConfig.Enabled {
return nil, nil, errors.New("provider is not enabled")
}
var agentProvider provider.Provider
var titleGenerator provider.Provider
switch model.Provider {
case models.ProviderOpenAI:
var err error
agentProvider, err = provider.NewOpenAIProvider(
provider.WithOpenAISystemMessage(
prompt.CoderOpenAISystemPrompt(),
),
provider.WithOpenAIMaxTokens(maxTokens),
provider.WithOpenAIModel(model),
provider.WithOpenAIKey(providerConfig.APIKey),
)
if err != nil {
return nil, nil, err
}
titleGenerator, err = provider.NewOpenAIProvider(
provider.WithOpenAISystemMessage(
prompt.TitlePrompt(),
),
provider.WithOpenAIMaxTokens(80),
provider.WithOpenAIModel(model),
provider.WithOpenAIKey(providerConfig.APIKey),
)
if err != nil {
return nil, nil, err
}
case models.ProviderAnthropic:
var err error
agentProvider, err = provider.NewAnthropicProvider(
provider.WithAnthropicSystemMessage(
prompt.CoderAnthropicSystemPrompt(),
),
provider.WithAnthropicMaxTokens(maxTokens),
provider.WithAnthropicKey(providerConfig.APIKey),
provider.WithAnthropicModel(model),
)
if err != nil {
return nil, nil, err
}
titleGenerator, err = provider.NewAnthropicProvider(
provider.WithAnthropicSystemMessage(
prompt.TitlePrompt(),
),
provider.WithAnthropicMaxTokens(80),
provider.WithAnthropicKey(providerConfig.APIKey),
provider.WithAnthropicModel(model),
)
if err != nil {
return nil, nil, err
}
case models.ProviderGemini:
var err error
agentProvider, err = provider.NewGeminiProvider(
ctx,
provider.WithGeminiSystemMessage(
prompt.CoderOpenAISystemPrompt(),
),
provider.WithGeminiMaxTokens(int32(maxTokens)),
provider.WithGeminiKey(providerConfig.APIKey),
provider.WithGeminiModel(model),
)
if err != nil {
return nil, nil, err
}
titleGenerator, err = provider.NewGeminiProvider(
ctx,
provider.WithGeminiSystemMessage(
prompt.TitlePrompt(),
),
provider.WithGeminiMaxTokens(80),
provider.WithGeminiKey(providerConfig.APIKey),
provider.WithGeminiModel(model),
)
if err != nil {
return nil, nil, err
}
case models.ProviderGROQ:
var err error
agentProvider, err = provider.NewOpenAIProvider(
provider.WithOpenAISystemMessage(
prompt.CoderAnthropicSystemPrompt(),
),
provider.WithOpenAIMaxTokens(maxTokens),
provider.WithOpenAIModel(model),
provider.WithOpenAIKey(providerConfig.APIKey),
provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
)
if err != nil {
return nil, nil, err
}
titleGenerator, err = provider.NewOpenAIProvider(
provider.WithOpenAISystemMessage(
prompt.TitlePrompt(),
),
provider.WithOpenAIMaxTokens(80),
provider.WithOpenAIModel(model),
provider.WithOpenAIKey(providerConfig.APIKey),
provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
)
if err != nil {
return nil, nil, err
}
}
return agentProvider, titleGenerator, nil
}

View File

@@ -1,182 +1,67 @@
package agent
import (
"context"
"fmt"
"os"
"path/filepath"
"runtime"
"time"
"errors"
"github.com/cloudwego/eino/components/tool"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/flow/agent/react"
"github.com/kujtimiihoxha/termai/internal/app"
"github.com/kujtimiihoxha/termai/internal/config"
"github.com/kujtimiihoxha/termai/internal/llm/models"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
"github.com/spf13/viper"
)
func coderTools() []tool.BaseTool {
wd := viper.GetString("wd")
return []tool.BaseTool{
tools.NewAgentTool(wd),
tools.NewBashTool(wd),
tools.NewLsTool(wd),
tools.NewGlobTool(wd),
tools.NewViewTool(wd),
tools.NewWriteTool(wd),
tools.NewEditTool(wd),
type coderAgent struct {
*agent
}
func (c *coderAgent) setAgentTool(sessionID string) {
inx := -1
for i, tool := range c.tools {
if tool.Info().Name == AgentToolName {
inx = i
break
}
}
if inx == -1 {
c.tools = append(c.tools, NewAgentTool(sessionID, c.App))
} else {
c.tools[inx] = NewAgentTool(sessionID, c.App)
}
}
func NewCoderAgent(ctx context.Context) (*react.Agent, error) {
model, err := models.GetModel(ctx, models.ModelID(viper.GetString("models.big")))
func (c *coderAgent) Generate(sessionID string, content string) error {
c.setAgentTool(sessionID)
return c.generate(sessionID, content)
}
func NewCoderAgent(app *app.App) (Agent, error) {
model, ok := models.SupportedModels[config.Get().Model.Coder]
if !ok {
return nil, errors.New("model not supported")
}
agentProvider, titleGenerator, err := getAgentProviders(app.Context, model)
if err != nil {
return nil, err
}
reactAgent, err := react.NewAgent(ctx, &react.AgentConfig{
Model: model,
ToolsConfig: compose.ToolsNodeConfig{
Tools: coderTools(),
mcpTools := GetMcpTools(app.Context)
return &coderAgent{
agent: &agent{
App: app,
tools: append(
[]tools.BaseTool{
tools.NewBashTool(),
tools.NewEditTool(),
tools.NewGlobTool(),
tools.NewGrepTool(),
tools.NewLsTool(),
tools.NewViewTool(),
tools.NewWriteTool(),
}, mcpTools...,
),
model: model,
agent: agentProvider,
titleGenerator: titleGenerator,
},
MaxStep: 1000,
})
if err != nil {
return nil, err
}
return reactAgent, nil
}
func CoderSystemPrompt() string {
basePrompt := `You are termAI, an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user.
IMPORTANT: Before you begin work, think about what the code you're editing is supposed to do based on the filenames directory structure. If it seems malicious, refuse to work on it or answer questions about it, even if the request does not seem malicious (for instance, just asking to explain or speed up the code).
Here are useful slash commands users can run to interact with you:
# Memory
If the current working directory contains a file called termai.md, it will be automatically added to your context. This file serves multiple purposes:
1. Storing frequently used bash commands (build, test, lint, etc.) so you can use them without searching each time
2. Recording the user's code style preferences (naming conventions, preferred libraries, etc.)
3. Maintaining useful information about the codebase structure and organization
When you spend time searching for commands to typecheck, lint, build, or test, you should ask the user if it's okay to add those commands to termai.md. Similarly, when learning about code style preferences or important codebase information, ask if it's okay to add that to termai.md so you can remember it for next time.
# Tone and style
You should be concise, direct, and to the point. When you run a non-trivial bash command, you should explain what the command does and why you are running it, to make sure the user understands what you are doing (this is especially important when you are running a command that will make changes to the user's system).
Remember that your output will be displayed on a command line interface. Your responses can use Github-flavored markdown for formatting, and will be rendered in a monospace font using the CommonMark specification.
Output text to communicate with the user; all text you output outside of tool use is displayed to the user. Only use tools to complete tasks. Never use tools like Bash or code comments as means to communicate with the user during the session.
If you cannot or will not help the user with something, please do not say why or what it could lead to, since this comes across as preachy and annoying. Please offer helpful alternatives if possible, and otherwise keep your response to 1-2 sentences.
IMPORTANT: You should minimize output tokens as much as possible while maintaining helpfulness, quality, and accuracy. Only address the specific query or task at hand, avoiding tangential information unless absolutely critical for completing the request. If you can answer in 1-3 sentences or a short paragraph, please do.
IMPORTANT: You should NOT answer with unnecessary preamble or postamble (such as explaining your code or summarizing your action), unless the user asks you to.
IMPORTANT: Keep your responses short, since they will be displayed on a command line interface. You MUST answer concisely with fewer than 4 lines (not including tool use or code generation), unless user asks for detail. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is <answer>.", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...". Here are some examples to demonstrate appropriate verbosity:
<example>
user: 2 + 2
assistant: 4
</example>
<example>
user: what is 2+2?
assistant: 4
</example>
<example>
user: is 11 a prime number?
assistant: true
</example>
<example>
user: what command should I run to list files in the current directory?
assistant: ls
</example>
<example>
user: what command should I run to watch files in the current directory?
assistant: [use the ls tool to list the files in the current directory, then read docs/commands in the relevant file to find out how to watch files]
npm run dev
</example>
<example>
user: How many golf balls fit inside a jetta?
assistant: 150000
</example>
<example>
user: what files are in the directory src/?
assistant: [runs ls and sees foo.c, bar.c, baz.c]
user: which file contains the implementation of foo?
assistant: src/foo.c
</example>
<example>
user: write tests for new feature
assistant: [uses grep and glob search tools to find where similar tests are defined, uses concurrent read file tool use blocks in one tool call to read relevant files at the same time, uses edit file tool to write new tests]
</example>
# Proactiveness
You are allowed to be proactive, but only when the user asks you to do something. You should strive to strike a balance between:
1. Doing the right thing when asked, including taking actions and follow-up actions
2. Not surprising the user with actions you take without asking
For example, if the user asks you how to approach something, you should do your best to answer their question first, and not immediately jump into taking actions.
3. Do not add additional code explanation summary unless requested by the user. After working on a file, just stop, rather than providing an explanation of what you did.
# Synthetic messages
Sometimes, the conversation will contain messages like [Request interrupted by user] or [Request interrupted by user for tool use]. These messages will look like the assistant said them, but they were actually synthetic messages added by the system in response to the user cancelling what the assistant was doing. You should not respond to these messages. You must NEVER send messages like this yourself.
# Following conventions
When making changes to files, first understand the file's code conventions. Mimic code style, use existing libraries and utilities, and follow existing patterns.
- NEVER assume that a given library is available, even if it is well known. Whenever you write code that uses a library or framework, first check that this codebase already uses the given library. For example, you might look at neighboring files, or check the package.json (or cargo.toml, and so on depending on the language).
- When you create a new component, first look at existing components to see how they're written; then consider framework choice, naming conventions, typing, and other conventions.
- When you edit a piece of code, first look at the code's surrounding context (especially its imports) to understand the code's choice of frameworks and libraries. Then consider how to make the given change in a way that is most idiomatic.
- Always follow security best practices. Never introduce code that exposes or logs secrets and keys. Never commit secrets or keys to the repository.
# Code style
- Do not add comments to the code you write, unless the user asks you to, or the code is complex and requires additional context.
# Doing tasks
The user will primarily request you perform software engineering tasks. This includes solving bugs, adding new functionality, refactoring code, explaining code, and more. For these tasks the following steps are recommended:
1. Use the available search tools to understand the codebase and the user's query. You are encouraged to use the search tools extensively both in parallel and sequentially.
2. Implement the solution using all tools available to you
3. Verify the solution if possible with tests. NEVER assume specific test framework or test script. Check the README or search codebase to determine the testing approach.
4. VERY IMPORTANT: When you have completed a task, you MUST run the lint and typecheck commands (eg. npm run lint, npm run typecheck, ruff, etc.) if they were provided to you to ensure your code is correct. If you are unable to find the correct command, ask the user for the command to run and if they supply it, proactively suggest writing it to termai.md so that you will know to run it next time.
NEVER commit changes unless the user explicitly asks you to. It is VERY IMPORTANT to only commit when explicitly asked, otherwise the user will feel that you are being too proactive.
# Tool usage policy
- When doing file search, prefer to use the Agent tool in order to reduce context usage.
- If you intend to call multiple tools and there are no dependencies between the calls, make all of the independent calls in the same function_calls block.
You MUST answer concisely with fewer than 4 lines of text (not including tool use or code generation), unless user asks for detail.`
envInfo := getEnvironmentInfo()
return fmt.Sprintf("%s\n\n%s", basePrompt, envInfo)
}
func getEnvironmentInfo() string {
cwd := viper.GetString("wd")
isGit := isGitRepo(cwd)
platform := runtime.GOOS
date := time.Now().Format("1/2/2006")
return fmt.Sprintf(`Here is useful information about the environment you are running in:
<env>
Working directory: %s
Is directory a git repo: %s
Platform: %s
Today's date: %s
</env>`, cwd, boolToYesNo(isGit), platform, date)
}
func isGitRepo(dir string) bool {
_, err := os.Stat(filepath.Join(dir, ".git"))
return err == nil
}
func boolToYesNo(b bool) string {
if b {
return "Yes"
}
return "No"
}, nil
}

View File

@@ -0,0 +1,190 @@
package agent
import (
"context"
"encoding/json"
"fmt"
"log"
"github.com/kujtimiihoxha/termai/internal/config"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
"github.com/kujtimiihoxha/termai/internal/permission"
"github.com/kujtimiihoxha/termai/internal/version"
"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/mcp"
)
type mcpTool struct {
mcpName string
tool mcp.Tool
mcpConfig config.MCPServer
}
type MCPClient interface {
Initialize(
ctx context.Context,
request mcp.InitializeRequest,
) (*mcp.InitializeResult, error)
ListTools(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
CallTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
Close() error
}
func (b *mcpTool) Info() tools.ToolInfo {
return tools.ToolInfo{
Name: fmt.Sprintf("%s_%s", b.mcpName, b.tool.Name),
Description: b.tool.Description,
Parameters: b.tool.InputSchema.Properties,
Required: b.tool.InputSchema.Required,
}
}
func runTool(ctx context.Context, c MCPClient, toolName string, input string) (tools.ToolResponse, error) {
defer c.Close()
initRequest := mcp.InitializeRequest{}
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
initRequest.Params.ClientInfo = mcp.Implementation{
Name: "termai",
Version: version.Version,
}
_, err := c.Initialize(ctx, initRequest)
if err != nil {
return tools.NewTextErrorResponse(err.Error()), nil
}
toolRequest := mcp.CallToolRequest{}
toolRequest.Params.Name = toolName
var args map[string]any
if err = json.Unmarshal([]byte(input), &input); err != nil {
return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
}
toolRequest.Params.Arguments = args
result, err := c.CallTool(ctx, toolRequest)
if err != nil {
return tools.NewTextErrorResponse(err.Error()), nil
}
output := ""
for _, v := range result.Content {
if v, ok := v.(mcp.TextContent); ok {
output = v.Text
} else {
output = fmt.Sprintf("%v", v)
}
}
return tools.NewTextResponse(output), nil
}
func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
p := permission.Default.Request(
permission.CreatePermissionRequest{
Path: config.WorkingDirectory(),
ToolName: b.Info().Name,
Action: "execute",
Description: permissionDescription,
Params: params.Input,
},
)
if !p {
return tools.NewTextErrorResponse("permission denied"), nil
}
switch b.mcpConfig.Type {
case config.MCPStdio:
c, err := client.NewStdioMCPClient(
b.mcpConfig.Command,
b.mcpConfig.Env,
b.mcpConfig.Args...,
)
if err != nil {
return tools.NewTextErrorResponse(err.Error()), nil
}
return runTool(ctx, c, b.tool.Name, params.Input)
case config.MCPSse:
c, err := client.NewSSEMCPClient(
b.mcpConfig.URL,
client.WithHeaders(b.mcpConfig.Headers),
)
if err != nil {
return tools.NewTextErrorResponse(err.Error()), nil
}
return runTool(ctx, c, b.tool.Name, params.Input)
}
return tools.NewTextErrorResponse("invalid mcp type"), nil
}
func NewMcpTool(name string, tool mcp.Tool, mcpConfig config.MCPServer) tools.BaseTool {
return &mcpTool{
mcpName: name,
tool: tool,
mcpConfig: mcpConfig,
}
}
var mcpTools []tools.BaseTool
func getTools(ctx context.Context, name string, m config.MCPServer, c MCPClient) []tools.BaseTool {
var stdioTools []tools.BaseTool
initRequest := mcp.InitializeRequest{}
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
initRequest.Params.ClientInfo = mcp.Implementation{
Name: "termai",
Version: version.Version,
}
_, err := c.Initialize(ctx, initRequest)
if err != nil {
log.Printf("error initializing mcp client: %s", err)
return stdioTools
}
toolsRequest := mcp.ListToolsRequest{}
tools, err := c.ListTools(ctx, toolsRequest)
if err != nil {
log.Printf("error listing tools: %s", err)
return stdioTools
}
for _, t := range tools.Tools {
stdioTools = append(stdioTools, NewMcpTool(name, t, m))
}
defer c.Close()
return stdioTools
}
func GetMcpTools(ctx context.Context) []tools.BaseTool {
if len(mcpTools) > 0 {
return mcpTools
}
for name, m := range config.Get().MCPServers {
switch m.Type {
case config.MCPStdio:
c, err := client.NewStdioMCPClient(
m.Command,
m.Env,
m.Args...,
)
if err != nil {
log.Printf("error creating mcp client: %s", err)
continue
}
mcpTools = append(mcpTools, getTools(ctx, name, m, c)...)
case config.MCPSse:
c, err := client.NewSSEMCPClient(
m.URL,
client.WithHeaders(m.Headers),
)
if err != nil {
log.Printf("error creating mcp client: %s", err)
continue
}
mcpTools = append(mcpTools, getTools(ctx, name, m, c)...)
}
}
return mcpTools
}

View File

@@ -0,0 +1,44 @@
package agent
import (
"errors"
"github.com/kujtimiihoxha/termai/internal/app"
"github.com/kujtimiihoxha/termai/internal/config"
"github.com/kujtimiihoxha/termai/internal/llm/models"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
)
type taskAgent struct {
*agent
}
func (c *taskAgent) Generate(sessionID string, content string) error {
return c.generate(sessionID, content)
}
func NewTaskAgent(app *app.App) (Agent, error) {
model, ok := models.SupportedModels[config.Get().Model.Coder]
if !ok {
return nil, errors.New("model not supported")
}
agentProvider, titleGenerator, err := getAgentProviders(app.Context, model)
if err != nil {
return nil, err
}
return &taskAgent{
agent: &agent{
App: app,
tools: []tools.BaseTool{
tools.NewGlobTool(),
tools.NewGrepTool(),
tools.NewLsTool(),
tools.NewViewTool(),
},
model: model,
agent: agentProvider,
titleGenerator: titleGenerator,
},
}, nil
}

View File

@@ -1,31 +0,0 @@
package agent
import (
"context"
"github.com/cloudwego/eino/schema"
"github.com/kujtimiihoxha/termai/internal/llm/models"
"github.com/spf13/viper"
)
func GenerateTitle(ctx context.Context, content string) (string, error) {
model, err := models.GetModel(ctx, models.ModelID(viper.GetString("models.small")))
if err != nil {
return "", err
}
out, err := model.Generate(
ctx,
[]*schema.Message{
schema.SystemMessage(`- you will generate a short title based on the first message a user begins a conversation with
- ensure it is not more than 80 characters long
- the title should be a summary of the user's message
- do not use quotes or colons
- the entire text you return will be used as the title`),
schema.UserMessage(content),
},
)
if err != nil {
return "", err
}
return out.Content, nil
}

View File

@@ -1,229 +0,0 @@
package llm
import (
"context"
"log"
"sync"
"time"
"github.com/cloudwego/eino/callbacks"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
"github.com/google/uuid"
"github.com/kujtimiihoxha/termai/internal/llm/agent"
"github.com/kujtimiihoxha/termai/internal/llm/models"
"github.com/kujtimiihoxha/termai/internal/logging"
"github.com/kujtimiihoxha/termai/internal/message"
"github.com/kujtimiihoxha/termai/internal/pubsub"
"github.com/kujtimiihoxha/termai/internal/session"
eModel "github.com/cloudwego/eino/components/model"
enioAgent "github.com/cloudwego/eino/flow/agent"
"github.com/spf13/viper"
)
const (
AgentRequestoEvent pubsub.EventType = "agent_request"
AgentErrorEvent pubsub.EventType = "agent_error"
AgentResponseEvent pubsub.EventType = "agent_response"
)
type AgentMessageType int
const (
AgentMessageTypeNewUserMessage AgentMessageType = iota
AgentMessageTypeAgentResponse
AgentMessageTypeError
)
type agentID string
const (
RootAgent agentID = "root"
TaskAgent agentID = "task"
)
type AgentEvent struct {
ID string `json:"id"`
Type AgentMessageType `json:"type"`
AgentID agentID `json:"agent_id"`
MessageID string `json:"message_id"`
SessionID string `json:"session_id"`
Content string `json:"content"`
}
type Service interface {
pubsub.Suscriber[AgentEvent]
SendRequest(sessionID string, content string)
}
type service struct {
*pubsub.Broker[AgentEvent]
Requests sync.Map
ctx context.Context
activeRequests sync.Map
messages message.Service
sessions session.Service
logger logging.Interface
}
func (s *service) handleRequest(id string, sessionID string, content string) {
cancel, ok := s.activeRequests.Load(id)
if !ok {
return
}
defer cancel.(context.CancelFunc)()
defer s.activeRequests.Delete(id)
history, err := s.messages.List(sessionID)
if err != nil {
s.Publish(AgentErrorEvent, AgentEvent{
ID: id,
Type: AgentMessageTypeError,
AgentID: RootAgent,
MessageID: "",
SessionID: sessionID,
Content: err.Error(),
})
return
}
currentAgent, systemMessage, err := agent.GetAgent(s.ctx, viper.GetString("agents.default"))
if err != nil {
s.Publish(AgentErrorEvent, AgentEvent{
ID: id,
Type: AgentMessageTypeError,
AgentID: RootAgent,
MessageID: "",
SessionID: sessionID,
Content: err.Error(),
})
return
}
messages := []*schema.Message{
{
Role: schema.System,
Content: systemMessage,
},
}
for _, m := range history {
messages = append(messages, &m.MessageData)
}
builder := callbacks.NewHandlerBuilder()
builder.OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
i, ok := input.(*eModel.CallbackInput)
if info.Component == "ChatModel" && ok {
if len(messages) < len(i.Messages) {
// find new messages
newMessages := i.Messages[len(messages):]
for _, m := range newMessages {
_, err = s.messages.Create(sessionID, *m)
if err != nil {
s.Publish(AgentErrorEvent, AgentEvent{
ID: id,
Type: AgentMessageTypeError,
AgentID: RootAgent,
MessageID: "",
SessionID: sessionID,
Content: err.Error(),
})
}
messages = append(messages, m)
}
}
}
return ctx
})
builder.OnEndFn(func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context {
return ctx
})
out, err := currentAgent.Generate(s.ctx, messages, enioAgent.WithComposeOptions(compose.WithCallbacks(builder.Build())))
if err != nil {
s.Publish(AgentErrorEvent, AgentEvent{
ID: id,
Type: AgentMessageTypeError,
AgentID: RootAgent,
MessageID: "",
SessionID: sessionID,
Content: err.Error(),
})
return
}
usage := out.ResponseMeta.Usage
s.messages.Create(sessionID, *out)
if usage != nil {
log.Printf("Prompt Tokens: %d, Completion Tokens: %d, Total Tokens: %d", usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens)
session, err := s.sessions.Get(sessionID)
if err != nil {
s.Publish(AgentErrorEvent, AgentEvent{
ID: id,
Type: AgentMessageTypeError,
AgentID: RootAgent,
MessageID: "",
SessionID: sessionID,
Content: err.Error(),
})
return
}
session.PromptTokens += int64(usage.PromptTokens)
session.CompletionTokens += int64(usage.CompletionTokens)
model := models.SupportedModels[models.ModelID(viper.GetString("models.big"))]
session.Cost += float64(usage.PromptTokens)*(model.CostPer1MIn/1_000_000) +
float64(usage.CompletionTokens)*(model.CostPer1MOut/1_000_000)
var newTitle string
if len(history) == 1 {
// first message generate the title
newTitle, err = agent.GenerateTitle(s.ctx, content)
if err != nil {
s.Publish(AgentErrorEvent, AgentEvent{
ID: id,
Type: AgentMessageTypeError,
AgentID: RootAgent,
MessageID: "",
SessionID: sessionID,
Content: err.Error(),
})
return
}
}
if newTitle != "" {
session.Title = newTitle
}
_, err = s.sessions.Save(session)
if err != nil {
s.Publish(AgentErrorEvent, AgentEvent{
ID: id,
Type: AgentMessageTypeError,
AgentID: RootAgent,
MessageID: "",
SessionID: sessionID,
Content: err.Error(),
})
return
}
}
}
func (s *service) SendRequest(sessionID string, content string) {
id := uuid.New().String()
_, cancel := context.WithTimeout(s.ctx, 5*time.Minute)
s.activeRequests.Store(id, cancel)
log.Printf("Request: %s", content)
go s.handleRequest(id, sessionID, content)
}
func NewService(ctx context.Context, logger logging.Interface, sessions session.Service, messages message.Service) Service {
return &service{
Broker: pubsub.NewBroker[AgentEvent](),
ctx: ctx,
sessions: sessions,
messages: messages,
logger: logger,
}
}

View File

@@ -1,230 +1,122 @@
package models
import (
"context"
"errors"
"log"
"github.com/cloudwego/eino-ext/components/model/claude"
"github.com/cloudwego/eino-ext/components/model/openai"
"github.com/cloudwego/eino/components/model"
"github.com/spf13/viper"
)
type (
ModelID string
ModelProvider string
)
type Model struct {
ID ModelID `json:"id"`
Name string `json:"name"`
Provider ModelProvider `json:"provider"`
APIModel string `json:"api_model"`
CostPer1MIn float64 `json:"cost_per_1m_in"`
CostPer1MOut float64 `json:"cost_per_1m_out"`
ID ModelID `json:"id"`
Name string `json:"name"`
Provider ModelProvider `json:"provider"`
APIModel string `json:"api_model"`
CostPer1MIn float64 `json:"cost_per_1m_in"`
CostPer1MOut float64 `json:"cost_per_1m_out"`
CostPer1MInCached float64 `json:"cost_per_1m_in_cached"`
CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
}
const (
DefaultBigModel = Claude37Sonnet
DefaultLittleModel = Claude37Sonnet
)
// Model IDs
const (
// OpenAI
GPT4o ModelID = "gpt-4o"
GPT4oMini ModelID = "gpt-4o-mini"
GPT45 ModelID = "gpt-4.5"
O1 ModelID = "o1"
O1Mini ModelID = "o1-mini"
// Anthropic
Claude35Sonnet ModelID = "claude-3.5-sonnet"
Claude3Haiku ModelID = "claude-3-haiku"
Claude37Sonnet ModelID = "claude-3.7-sonnet"
// Google
Gemini20Pro ModelID = "gemini-2.0-pro"
Gemini15Flash ModelID = "gemini-1.5-flash"
Gemini20Flash ModelID = "gemini-2.0-flash"
// xAI
Grok3 ModelID = "grok-3"
Grok2Mini ModelID = "grok-2-mini"
// DeepSeek
DeepSeekR1 ModelID = "deepseek-r1"
DeepSeekCoder ModelID = "deepseek-coder"
// Meta
Llama3 ModelID = "llama-3"
Llama270B ModelID = "llama-2-70b"
// OpenAI
GPT4o ModelID = "gpt-4o"
// GEMINI
GEMINI25 ModelID = "gemini-2.5"
GRMINI20Flash ModelID = "gemini-2.0-flash"
// GROQ
GroqLlama3SpecDec ModelID = "groq-llama-3-spec-dec"
GroqQwen32BCoder ModelID = "qwen-2.5-coder-32b"
QWENQwq ModelID = "qwen-qwq"
)
const (
ProviderOpenAI ModelProvider = "openai"
ProviderAnthropic ModelProvider = "anthropic"
ProviderGoogle ModelProvider = "google"
ProviderXAI ModelProvider = "xai"
ProviderDeepSeek ModelProvider = "deepseek"
ProviderMeta ModelProvider = "meta"
ProviderGroq ModelProvider = "groq"
ProviderGemini ModelProvider = "gemini"
ProviderGROQ ModelProvider = "groq"
)
var SupportedModels = map[ModelID]Model{
// OpenAI
GPT4o: {
ID: GPT4o,
Name: "GPT-4o",
Provider: ProviderOpenAI,
APIModel: "gpt-4o",
},
GPT4oMini: {
ID: GPT4oMini,
Name: "GPT-4o Mini",
Provider: ProviderOpenAI,
APIModel: "gpt-4o-mini",
CostPer1MIn: 0.150,
CostPer1MOut: 0.600,
},
GPT45: {
ID: GPT45,
Name: "GPT-4.5",
Provider: ProviderOpenAI,
APIModel: "gpt-4.5",
},
O1: {
ID: O1,
Name: "o1",
Provider: ProviderOpenAI,
APIModel: "o1",
},
O1Mini: {
ID: O1Mini,
Name: "o1 Mini",
Provider: ProviderOpenAI,
APIModel: "o1-mini",
},
// Anthropic
Claude35Sonnet: {
ID: Claude35Sonnet,
Name: "Claude 3.5 Sonnet",
Provider: ProviderAnthropic,
APIModel: "claude-3.5-sonnet",
ID: Claude35Sonnet,
Name: "Claude 3.5 Sonnet",
Provider: ProviderAnthropic,
APIModel: "claude-3-5-sonnet-latest",
CostPer1MIn: 3.0,
CostPer1MInCached: 3.75,
CostPer1MOutCached: 0.30,
CostPer1MOut: 15.0,
},
Claude3Haiku: {
ID: Claude3Haiku,
Name: "Claude 3 Haiku",
Provider: ProviderAnthropic,
APIModel: "claude-3-haiku",
ID: Claude3Haiku,
Name: "Claude 3 Haiku",
Provider: ProviderAnthropic,
APIModel: "claude-3-haiku-latest",
CostPer1MIn: 0.80,
CostPer1MInCached: 1,
CostPer1MOutCached: 0.08,
CostPer1MOut: 4,
},
Claude37Sonnet: {
ID: Claude37Sonnet,
Name: "Claude 3.7 Sonnet",
Provider: ProviderAnthropic,
APIModel: "claude-3-7-sonnet-20250219",
CostPer1MIn: 3.0,
CostPer1MOut: 15.0,
ID: Claude37Sonnet,
Name: "Claude 3.7 Sonnet",
Provider: ProviderAnthropic,
APIModel: "claude-3-7-sonnet-latest",
CostPer1MIn: 3.0,
CostPer1MInCached: 3.75,
CostPer1MOutCached: 0.30,
CostPer1MOut: 15.0,
},
// Google
Gemini20Pro: {
ID: Gemini20Pro,
Name: "Gemini 2.0 Pro",
Provider: ProviderGoogle,
APIModel: "gemini-2.0-pro",
// OpenAI
GPT4o: {
ID: GPT4o,
Name: "GPT-4o",
Provider: ProviderOpenAI,
APIModel: "gpt-4o",
CostPer1MIn: 2.50,
CostPer1MInCached: 1.25,
CostPer1MOutCached: 0,
CostPer1MOut: 10.00,
},
Gemini15Flash: {
ID: Gemini15Flash,
Name: "Gemini 1.5 Flash",
Provider: ProviderGoogle,
APIModel: "gemini-1.5-flash",
// GEMINI
GEMINI25: {
ID: GEMINI25,
Name: "Gemini 2.5 Pro",
Provider: ProviderGemini,
APIModel: "gemini-2.5-pro-exp-03-25",
CostPer1MIn: 0,
CostPer1MInCached: 0,
CostPer1MOutCached: 0,
CostPer1MOut: 0,
},
Gemini20Flash: {
ID: Gemini20Flash,
Name: "Gemini 2.0 Flash",
Provider: ProviderGoogle,
APIModel: "gemini-2.0-flash",
},
// xAI
Grok3: {
ID: Grok3,
Name: "Grok 3",
Provider: ProviderXAI,
APIModel: "grok-3",
},
Grok2Mini: {
ID: Grok2Mini,
Name: "Grok 2 Mini",
Provider: ProviderXAI,
APIModel: "grok-2-mini",
},
// DeepSeek
DeepSeekR1: {
ID: DeepSeekR1,
Name: "DeepSeek R1",
Provider: ProviderDeepSeek,
APIModel: "deepseek-r1",
},
DeepSeekCoder: {
ID: DeepSeekCoder,
Name: "DeepSeek Coder",
Provider: ProviderDeepSeek,
APIModel: "deepseek-coder",
},
// Meta
Llama3: {
ID: Llama3,
Name: "LLaMA 3",
Provider: ProviderMeta,
APIModel: "llama-3",
},
Llama270B: {
ID: Llama270B,
Name: "LLaMA 2 70B",
Provider: ProviderMeta,
APIModel: "llama-2-70b",
GRMINI20Flash: {
ID: GRMINI20Flash,
Name: "Gemini 2.0 Flash",
Provider: ProviderGemini,
APIModel: "gemini-2.0-flash",
CostPer1MIn: 0.1,
CostPer1MInCached: 0,
CostPer1MOutCached: 0.025,
CostPer1MOut: 0.4,
},
// GROQ
GroqLlama3SpecDec: {
ID: GroqLlama3SpecDec,
Name: "GROQ LLaMA 3 SpecDec",
Provider: ProviderGroq,
APIModel: "llama-3.3-70b-specdec",
},
GroqQwen32BCoder: {
ID: GroqQwen32BCoder,
Name: "GROQ Qwen 2.5 Coder 32B",
Provider: ProviderGroq,
APIModel: "qwen-2.5-coder-32b",
QWENQwq: {
ID: QWENQwq,
Name: "Qwen Qwq",
Provider: ProviderGROQ,
APIModel: "qwen-qwq-32b",
CostPer1MIn: 0,
CostPer1MInCached: 0,
CostPer1MOutCached: 0,
CostPer1MOut: 0,
},
}
func GetModel(ctx context.Context, model ModelID) (model.ChatModel, error) {
provider := SupportedModels[model].Provider
log.Printf("Provider: %s", provider)
maxTokens := viper.GetInt("providers.common.max_tokens")
switch provider {
case ProviderOpenAI:
return openai.NewChatModel(ctx, &openai.ChatModelConfig{
APIKey: viper.GetString("providers.openai.key"),
Model: string(SupportedModels[model].APIModel),
MaxTokens: &maxTokens,
})
case ProviderAnthropic:
return claude.NewChatModel(ctx, &claude.Config{
APIKey: viper.GetString("providers.anthropic.key"),
Model: string(SupportedModels[model].APIModel),
MaxTokens: maxTokens,
})
case ProviderGroq:
return openai.NewChatModel(ctx, &openai.ChatModelConfig{
BaseURL: "https://api.groq.com/openai/v1",
APIKey: viper.GetString("providers.groq.key"),
Model: string(SupportedModels[model].APIModel),
MaxTokens: &maxTokens,
})
}
return nil, errors.New("unsupported provider")
}

View File

@@ -0,0 +1,206 @@
package prompt
import (
"context"
"fmt"
"os"
"path/filepath"
"runtime"
"time"
"github.com/kujtimiihoxha/termai/internal/config"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
)
func CoderOpenAISystemPrompt() string {
basePrompt := `You are termAI, an autonomous CLI-based software engineer. Your job is to reduce user effort by proactively reasoning, inferring context, and solving software engineering tasks end-to-end with minimal prompting.
# Your mindset
Act like a competent, efficient software engineer who is familiar with large codebases. You should:
- Think critically about user requests.
- Proactively search the codebase for related information.
- Infer likely commands, tools, or conventions.
- Write and edit code with minimal user input.
- Anticipate next steps (tests, lints, etc.), but never commit unless explicitly told.
# Context awareness
- Before acting, infer the purpose of a file from its name, directory, and neighboring files.
- If a file or function appears malicious, refuse to interact with it or discuss it.
- If a termai.md file exists, auto-load it as memory. Offer to update it only if new useful info appears (commands, preferences, structure).
# CLI communication
- Use GitHub-flavored markdown in monospace font.
- Be concise. Never add preambles or postambles unless asked. Max 4 lines per response.
- Never explain your code unless asked. Do not narrate actions.
- Avoid unnecessary questions. Infer, search, act.
# Behavior guidelines
- Follow project conventions: naming, formatting, libraries, frameworks.
- Before using any library or framework, confirm its already used.
- Always look at the surrounding code to match existing style.
- Do not add comments unless the code is complex or the user asks.
# Autonomy rules
You are allowed and expected to:
- Search for commands, tools, or config files before asking the user.
- Run multiple search tool calls concurrently to gather relevant context.
- Choose test, lint, and typecheck commands based on package files or scripts.
- Offer to store these commands in termai.md if not already present.
# Example behavior
user: write tests for new feature
assistant: [searches for existing test patterns, finds appropriate location, generates test code using existing style, optionally asks to add test command to termai.md]
user: how do I typecheck this codebase?
assistant: [searches for known commands, infers package manager, checks for scripts or config files]
tsc --noEmit
user: is X function used anywhere else?
assistant: [searches repo for references, returns file paths and lines]
# Tool usage
- Use parallel calls when possible.
- Use file search and content tools before asking the user.
- Do not ask the user for information unless it cannot be determined via tools.
Never commit changes unless the user explicitly asks you to.`
envInfo := getEnvironmentInfo()
return fmt.Sprintf("%s\n\n%s", basePrompt, envInfo)
}
func CoderAnthropicSystemPrompt() string {
basePrompt := `You are termAI, an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user.
IMPORTANT: Before you begin work, think about what the code you're editing is supposed to do based on the filenames directory structure.
# Memory
If the current working directory contains a file called termai.md, it will be automatically added to your context. This file serves multiple purposes:
1. Storing frequently used bash commands (build, test, lint, etc.) so you can use them without searching each time
2. Recording the user's code style preferences (naming conventions, preferred libraries, etc.)
3. Maintaining useful information about the codebase structure and organization
When you spend time searching for commands to typecheck, lint, build, or test, you should ask the user if it's okay to add those commands to termai.md. Similarly, when learning about code style preferences or important codebase information, ask if it's okay to add that to termai.md so you can remember it for next time.
# Tone and style
You should be concise, direct, and to the point. When you run a non-trivial bash command, you should explain what the command does and why you are running it, to make sure the user understands what you are doing (this is especially important when you are running a command that will make changes to the user's system).
Remember that your output will be displayed on a command line interface. Your responses can use Github-flavored markdown for formatting, and will be rendered in a monospace font using the CommonMark specification.
Output text to communicate with the user; all text you output outside of tool use is displayed to the user. Only use tools to complete tasks. Never use tools like Bash or code comments as means to communicate with the user during the session.
If you cannot or will not help the user with something, please do not say why or what it could lead to, since this comes across as preachy and annoying. Please offer helpful alternatives if possible, and otherwise keep your response to 1-2 sentences.
IMPORTANT: You should minimize output tokens as much as possible while maintaining helpfulness, quality, and accuracy. Only address the specific query or task at hand, avoiding tangential information unless absolutely critical for completing the request. If you can answer in 1-3 sentences or a short paragraph, please do.
IMPORTANT: You should NOT answer with unnecessary preamble or postamble (such as explaining your code or summarizing your action), unless the user asks you to.
IMPORTANT: Keep your responses short, since they will be displayed on a command line interface. You MUST answer concisely with fewer than 4 lines (not including tool use or code generation), unless user asks for detail. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is <answer>.", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...". Here are some examples to demonstrate appropriate verbosity:
<example>
user: 2 + 2
assistant: 4
</example>
<example>
user: what is 2+2?
assistant: 4
</example>
<example>
user: is 11 a prime number?
assistant: true
</example>
<example>
user: what command should I run to list files in the current directory?
assistant: ls
</example>
<example>
user: what command should I run to watch files in the current directory?
assistant: [use the ls tool to list the files in the current directory, then read docs/commands in the relevant file to find out how to watch files]
npm run dev
</example>
<example>
user: How many golf balls fit inside a jetta?
assistant: 150000
</example>
<example>
user: what files are in the directory src/?
assistant: [runs ls and sees foo.c, bar.c, baz.c]
user: which file contains the implementation of foo?
assistant: src/foo.c
</example>
<example>
user: write tests for new feature
assistant: [uses grep and glob search tools to find where similar tests are defined, uses concurrent read file tool use blocks in one tool call to read relevant files at the same time, uses edit file tool to write new tests]
</example>
# Proactiveness
You are allowed to be proactive, but only when the user asks you to do something. You should strive to strike a balance between:
1. Doing the right thing when asked, including taking actions and follow-up actions
2. Not surprising the user with actions you take without asking
For example, if the user asks you how to approach something, you should do your best to answer their question first, and not immediately jump into taking actions.
3. Do not add additional code explanation summary unless requested by the user. After working on a file, just stop, rather than providing an explanation of what you did.
# Following conventions
When making changes to files, first understand the file's code conventions. Mimic code style, use existing libraries and utilities, and follow existing patterns.
- NEVER assume that a given library is available, even if it is well known. Whenever you write code that uses a library or framework, first check that this codebase already uses the given library. For example, you might look at neighboring files, or check the package.json (or cargo.toml, and so on depending on the language).
- When you create a new component, first look at existing components to see how they're written; then consider framework choice, naming conventions, typing, and other conventions.
- When you edit a piece of code, first look at the code's surrounding context (especially its imports) to understand the code's choice of frameworks and libraries. Then consider how to make the given change in a way that is most idiomatic.
- Always follow security best practices. Never introduce code that exposes or logs secrets and keys. Never commit secrets or keys to the repository.
# Code style
- Do not add comments to the code you write, unless the user asks you to, or the code is complex and requires additional context.
# Doing tasks
The user will primarily request you perform software engineering tasks. This includes solving bugs, adding new functionality, refactoring code, explaining code, and more. For these tasks the following steps are recommended:
1. Use the available search tools to understand the codebase and the user's query. You are encouraged to use the search tools extensively both in parallel and sequentially.
2. Implement the solution using all tools available to you
3. Verify the solution if possible with tests. NEVER assume specific test framework or test script. Check the README or search codebase to determine the testing approach.
4. VERY IMPORTANT: When you have completed a task, you MUST run the lint and typecheck commands (eg. npm run lint, npm run typecheck, ruff, etc.) if they were provided to you to ensure your code is correct. If you are unable to find the correct command, ask the user for the command to run and if they supply it, proactively suggest writing it to termai.md so that you will know to run it next time.
NEVER commit changes unless the user explicitly asks you to. It is VERY IMPORTANT to only commit when explicitly asked, otherwise the user will feel that you are being too proactive.
# Tool usage policy
- When doing file search, prefer to use the Agent tool in order to reduce context usage.
- If you intend to call multiple tools and there are no dependencies between the calls, make all of the independent calls in the same function_calls block.
You MUST answer concisely with fewer than 4 lines of text (not including tool use or code generation), unless user asks for detail.`
envInfo := getEnvironmentInfo()
return fmt.Sprintf("%s\n\n%s", basePrompt, envInfo)
}
func getEnvironmentInfo() string {
cwd := config.WorkingDirectory()
isGit := isGitRepo(cwd)
platform := runtime.GOOS
date := time.Now().Format("1/2/2006")
ls := tools.NewLsTool()
r, _ := ls.Run(context.Background(), tools.ToolCall{
Input: `{"path":"."}`,
})
return fmt.Sprintf(`Here is useful information about the environment you are running in:
<env>
Working directory: %s
Is directory a git repo: %s
Platform: %s
Today's date: %s
</env>
<project>
%s
</project>
`, cwd, boolToYesNo(isGit), platform, date, r.Content)
}
func isGitRepo(dir string) bool {
_, err := os.Stat(filepath.Join(dir, ".git"))
return err == nil
}
func boolToYesNo(b bool) string {
if b {
return "Yes"
}
return "No"
}

View File

@@ -0,0 +1,16 @@
package prompt
import (
"fmt"
)
func TaskAgentSystemPrompt() string {
agentPrompt := `You are an agent for termAI. Given the user's prompt, you should use the tools available to you to answer the user's question.
Notes:
1. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is <answer>.", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...".
2. When relevant, share file names and code snippets relevant to the query
3. Any file paths you return in your final response MUST be absolute. DO NOT use relative paths.`
return fmt.Sprintf("%s\n%s\n", agentPrompt, getEnvironmentInfo())
}

View File

@@ -0,0 +1,9 @@
package prompt
func TitlePrompt() string {
return `you will generate a short title based on the first message a user begins a conversation with
- ensure it is not more than 50 characters long
- the title should be a summary of the user's message
- do not use quotes or colons
- the entire text you return will be used as the title`
}

View File

@@ -0,0 +1,309 @@
package provider
import (
"context"
"encoding/json"
"errors"
"strings"
"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/option"
"github.com/kujtimiihoxha/termai/internal/llm/models"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
"github.com/kujtimiihoxha/termai/internal/message"
)
type anthropicProvider struct {
client anthropic.Client
model models.Model
maxTokens int64
apiKey string
systemMessage string
}
type AnthropicOption func(*anthropicProvider)
func WithAnthropicSystemMessage(message string) AnthropicOption {
return func(a *anthropicProvider) {
a.systemMessage = message
}
}
func WithAnthropicMaxTokens(maxTokens int64) AnthropicOption {
return func(a *anthropicProvider) {
a.maxTokens = maxTokens
}
}
func WithAnthropicModel(model models.Model) AnthropicOption {
return func(a *anthropicProvider) {
a.model = model
}
}
func WithAnthropicKey(apiKey string) AnthropicOption {
return func(a *anthropicProvider) {
a.apiKey = apiKey
}
}
func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) {
provider := &anthropicProvider{
maxTokens: 1024,
}
for _, opt := range opts {
opt(provider)
}
if provider.systemMessage == "" {
return nil, errors.New("system message is required")
}
provider.client = anthropic.NewClient(option.WithAPIKey(provider.apiKey))
return provider, nil
}
func (a *anthropicProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
anthropicMessages := a.convertToAnthropicMessages(messages)
anthropicTools := a.convertToAnthropicTools(tools)
response, err := a.client.Messages.New(ctx, anthropic.MessageNewParams{
Model: anthropic.Model(a.model.APIModel),
MaxTokens: a.maxTokens,
Temperature: anthropic.Float(0),
Messages: anthropicMessages,
Tools: anthropicTools,
System: []anthropic.TextBlockParam{
{
Text: a.systemMessage,
CacheControl: anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
},
},
},
})
if err != nil {
return nil, err
}
content := ""
for _, block := range response.Content {
if text, ok := block.AsAny().(anthropic.TextBlock); ok {
content += text.Text
}
}
toolCalls := a.extractToolCalls(response.Content)
tokenUsage := a.extractTokenUsage(response.Usage)
return &ProviderResponse{
Content: content,
ToolCalls: toolCalls,
Usage: tokenUsage,
}, nil
}
func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
anthropicMessages := a.convertToAnthropicMessages(messages)
anthropicTools := a.convertToAnthropicTools(tools)
var thinkingParam anthropic.ThinkingConfigParamUnion
lastMessage := messages[len(messages)-1]
temperature := anthropic.Float(0)
if lastMessage.Role == message.User && strings.Contains(strings.ToLower(lastMessage.Content), "think") {
thinkingParam = anthropic.ThinkingConfigParamUnion{
OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{
BudgetTokens: int64(float64(a.maxTokens) * 0.8),
Type: "enabled",
},
}
temperature = anthropic.Float(1)
}
stream := a.client.Messages.NewStreaming(ctx, anthropic.MessageNewParams{
Model: anthropic.Model(a.model.APIModel),
MaxTokens: a.maxTokens,
Temperature: temperature,
Messages: anthropicMessages,
Tools: anthropicTools,
Thinking: thinkingParam,
System: []anthropic.TextBlockParam{
{
Text: a.systemMessage,
CacheControl: anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
},
},
},
})
eventChan := make(chan ProviderEvent)
go func() {
defer close(eventChan)
accumulatedMessage := anthropic.Message{}
for stream.Next() {
event := stream.Current()
err := accumulatedMessage.Accumulate(event)
if err != nil {
eventChan <- ProviderEvent{Type: EventError, Error: err}
return
}
switch event := event.AsAny().(type) {
case anthropic.ContentBlockStartEvent:
eventChan <- ProviderEvent{Type: EventContentStart}
case anthropic.ContentBlockDeltaEvent:
if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" {
eventChan <- ProviderEvent{
Type: EventThinkingDelta,
Thinking: event.Delta.Thinking,
}
} else if event.Delta.Type == "text_delta" && event.Delta.Text != "" {
eventChan <- ProviderEvent{
Type: EventContentDelta,
Content: event.Delta.Text,
}
}
case anthropic.ContentBlockStopEvent:
eventChan <- ProviderEvent{Type: EventContentStop}
case anthropic.MessageStopEvent:
content := ""
for _, block := range accumulatedMessage.Content {
if text, ok := block.AsAny().(anthropic.TextBlock); ok {
content += text.Text
}
}
toolCalls := a.extractToolCalls(accumulatedMessage.Content)
tokenUsage := a.extractTokenUsage(accumulatedMessage.Usage)
eventChan <- ProviderEvent{
Type: EventComplete,
Response: &ProviderResponse{
Content: content,
ToolCalls: toolCalls,
Usage: tokenUsage,
},
}
}
}
if stream.Err() != nil {
eventChan <- ProviderEvent{Type: EventError, Error: stream.Err()}
}
}()
return eventChan, nil
}
func (a *anthropicProvider) extractToolCalls(content []anthropic.ContentBlockUnion) []message.ToolCall {
var toolCalls []message.ToolCall
for _, block := range content {
switch variant := block.AsAny().(type) {
case anthropic.ToolUseBlock:
toolCall := message.ToolCall{
ID: variant.ID,
Name: variant.Name,
Input: string(variant.Input),
Type: string(variant.Type),
}
toolCalls = append(toolCalls, toolCall)
}
}
return toolCalls
}
func (a *anthropicProvider) extractTokenUsage(usage anthropic.Usage) TokenUsage {
return TokenUsage{
InputTokens: usage.InputTokens,
OutputTokens: usage.OutputTokens,
CacheCreationTokens: usage.CacheCreationInputTokens,
CacheReadTokens: usage.CacheReadInputTokens,
}
}
func (a *anthropicProvider) convertToAnthropicTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
anthropicTools := make([]anthropic.ToolUnionParam, len(tools))
for i, tool := range tools {
info := tool.Info()
toolParam := anthropic.ToolParam{
Name: info.Name,
Description: anthropic.String(info.Description),
InputSchema: anthropic.ToolInputSchemaParam{
Properties: info.Parameters,
},
}
if i == len(tools)-1 {
toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
}
}
anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
}
return anthropicTools
}
func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Message) []anthropic.MessageParam {
anthropicMessages := make([]anthropic.MessageParam, len(messages))
cachedBlocks := 0
for i, msg := range messages {
switch msg.Role {
case message.User:
content := anthropic.NewTextBlock(msg.Content)
if cachedBlocks < 2 {
content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
}
cachedBlocks++
}
anthropicMessages[i] = anthropic.NewUserMessage(content)
case message.Assistant:
blocks := []anthropic.ContentBlockParamUnion{}
if msg.Content != "" {
content := anthropic.NewTextBlock(msg.Content)
if cachedBlocks < 2 {
content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
}
cachedBlocks++
}
blocks = append(blocks, content)
}
for _, toolCall := range msg.ToolCalls {
var inputMap map[string]any
err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
if err != nil {
continue
}
blocks = append(blocks, anthropic.ContentBlockParamOfRequestToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
}
anthropicMessages[i] = anthropic.NewAssistantMessage(blocks...)
case message.Tool:
results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults))
for i, toolResult := range msg.ToolResults {
results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
}
anthropicMessages[i] = anthropic.NewUserMessage(results...)
}
}
return anthropicMessages
}

View File

@@ -0,0 +1,443 @@
package provider
import (
"context"
"encoding/json"
"errors"
"log"
"github.com/google/generative-ai-go/genai"
"github.com/google/uuid"
"github.com/kujtimiihoxha/termai/internal/llm/models"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
"github.com/kujtimiihoxha/termai/internal/message"
"google.golang.org/api/googleapi"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
)
type geminiProvider struct {
client *genai.Client
model models.Model
maxTokens int32
apiKey string
systemMessage string
}
type GeminiOption func(*geminiProvider)
func NewGeminiProvider(ctx context.Context, opts ...GeminiOption) (Provider, error) {
provider := &geminiProvider{
maxTokens: 5000,
}
for _, opt := range opts {
opt(provider)
}
if provider.systemMessage == "" {
return nil, errors.New("system message is required")
}
client, err := genai.NewClient(ctx, option.WithAPIKey(provider.apiKey))
if err != nil {
return nil, err
}
provider.client = client
return provider, nil
}
func WithGeminiSystemMessage(message string) GeminiOption {
return func(p *geminiProvider) {
p.systemMessage = message
}
}
func WithGeminiMaxTokens(maxTokens int32) GeminiOption {
return func(p *geminiProvider) {
p.maxTokens = maxTokens
}
}
func WithGeminiModel(model models.Model) GeminiOption {
return func(p *geminiProvider) {
p.model = model
}
}
func WithGeminiKey(apiKey string) GeminiOption {
return func(p *geminiProvider) {
p.apiKey = apiKey
}
}
func (p *geminiProvider) Close() {
if p.client != nil {
p.client.Close()
}
}
// convertToGeminiHistory converts the message history to Gemini's format
func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*genai.Content {
var history []*genai.Content
for _, msg := range messages {
switch msg.Role {
case message.User:
history = append(history, &genai.Content{
Parts: []genai.Part{genai.Text(msg.Content)},
Role: "user",
})
case message.Assistant:
content := &genai.Content{
Role: "model",
Parts: []genai.Part{},
}
// Handle regular content
if msg.Content != "" {
content.Parts = append(content.Parts, genai.Text(msg.Content))
}
// Handle tool calls if any
if len(msg.ToolCalls) > 0 {
for _, call := range msg.ToolCalls {
args, _ := parseJsonToMap(call.Input)
content.Parts = append(content.Parts, genai.FunctionCall{
Name: call.Name,
Args: args,
})
}
}
history = append(history, content)
case message.Tool:
for _, result := range msg.ToolResults {
// Parse response content to map if possible
response := map[string]interface{}{"result": result.Content}
parsed, err := parseJsonToMap(result.Content)
if err == nil {
response = parsed
}
var toolCall message.ToolCall
for _, msg := range messages {
if msg.Role == message.Assistant {
for _, call := range msg.ToolCalls {
if call.ID == result.ToolCallID {
toolCall = call
break
}
}
}
}
history = append(history, &genai.Content{
Parts: []genai.Part{genai.FunctionResponse{
Name: toolCall.Name,
Response: response,
}},
Role: "function",
})
}
}
}
return history
}
// convertToolsToGeminiFunctionDeclarations converts tool definitions to Gemini's function declarations
func (p *geminiProvider) convertToolsToGeminiFunctionDeclarations(tools []tools.BaseTool) []*genai.FunctionDeclaration {
declarations := make([]*genai.FunctionDeclaration, len(tools))
for i, tool := range tools {
info := tool.Info()
// Convert parameters to genai.Schema format
properties := make(map[string]*genai.Schema)
for name, param := range info.Parameters {
// Try to extract type and description from the parameter
paramMap, ok := param.(map[string]interface{})
if !ok {
// Default to string if unable to determine type
properties[name] = &genai.Schema{Type: genai.TypeString}
continue
}
schemaType := genai.TypeString // Default
var description string
var itemsTypeSchema *genai.Schema
if typeVal, found := paramMap["type"]; found {
if typeStr, ok := typeVal.(string); ok {
switch typeStr {
case "string":
schemaType = genai.TypeString
case "number":
schemaType = genai.TypeNumber
case "integer":
schemaType = genai.TypeInteger
case "boolean":
schemaType = genai.TypeBoolean
case "array":
schemaType = genai.TypeArray
items, found := paramMap["items"]
if found {
itemsMap, ok := items.(map[string]interface{})
if ok {
itemsType, found := itemsMap["type"]
if found {
itemsTypeStr, ok := itemsType.(string)
if ok {
switch itemsTypeStr {
case "string":
itemsTypeSchema = &genai.Schema{
Type: genai.TypeString,
}
case "number":
itemsTypeSchema = &genai.Schema{
Type: genai.TypeNumber,
}
case "integer":
itemsTypeSchema = &genai.Schema{
Type: genai.TypeInteger,
}
case "boolean":
itemsTypeSchema = &genai.Schema{
Type: genai.TypeBoolean,
}
}
}
}
}
}
case "object":
schemaType = genai.TypeObject
if _, found := paramMap["properties"]; !found {
continue
}
// TODO: Add support for other types
}
}
}
if desc, found := paramMap["description"]; found {
if descStr, ok := desc.(string); ok {
description = descStr
}
}
properties[name] = &genai.Schema{
Type: schemaType,
Description: description,
Items: itemsTypeSchema,
}
}
declarations[i] = &genai.FunctionDeclaration{
Name: info.Name,
Description: info.Description,
Parameters: &genai.Schema{
Type: genai.TypeObject,
Properties: properties,
Required: info.Required,
},
}
}
return declarations
}
// extractTokenUsage extracts token usage information from Gemini's response
func (p *geminiProvider) extractTokenUsage(resp *genai.GenerateContentResponse) TokenUsage {
if resp == nil || resp.UsageMetadata == nil {
return TokenUsage{}
}
return TokenUsage{
InputTokens: int64(resp.UsageMetadata.PromptTokenCount),
OutputTokens: int64(resp.UsageMetadata.CandidatesTokenCount),
CacheCreationTokens: 0, // Not directly provided by Gemini
CacheReadTokens: int64(resp.UsageMetadata.CachedContentTokenCount),
}
}
// SendMessages sends a batch of messages to Gemini and returns the response
func (p *geminiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
// Create a generative model
model := p.client.GenerativeModel(p.model.APIModel)
model.SetMaxOutputTokens(p.maxTokens)
// Set system instruction
model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage))
// Set up tools if provided
if len(tools) > 0 {
declarations := p.convertToolsToGeminiFunctionDeclarations(tools)
model.Tools = []*genai.Tool{{FunctionDeclarations: declarations}}
}
// Create chat session and set history
chat := model.StartChat()
chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message
// Get the most recent user message
var lastUserMsg message.Message
for i := len(messages) - 1; i >= 0; i-- {
if messages[i].Role == message.User {
lastUserMsg = messages[i]
break
}
}
// Send the message
resp, err := chat.SendMessage(ctx, genai.Text(lastUserMsg.Content))
if err != nil {
return nil, err
}
// Process the response
var content string
var toolCalls []message.ToolCall
if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
for _, part := range resp.Candidates[0].Content.Parts {
switch p := part.(type) {
case genai.Text:
content = string(p)
case genai.FunctionCall:
id := "call_" + uuid.New().String()
args, _ := json.Marshal(p.Args)
toolCalls = append(toolCalls, message.ToolCall{
ID: id,
Name: p.Name,
Input: string(args),
Type: "function",
})
}
}
}
// Extract token usage
tokenUsage := p.extractTokenUsage(resp)
return &ProviderResponse{
Content: content,
ToolCalls: toolCalls,
Usage: tokenUsage,
}, nil
}
// StreamResponse streams the response from Gemini
func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
// Create a generative model
model := p.client.GenerativeModel(p.model.APIModel)
model.SetMaxOutputTokens(p.maxTokens)
// Set system instruction
model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage))
// Set up tools if provided
if len(tools) > 0 {
declarations := p.convertToolsToGeminiFunctionDeclarations(tools)
for _, declaration := range declarations {
model.Tools = append(model.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{declaration}})
}
}
// Create chat session and set history
chat := model.StartChat()
chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message
lastUserMsg := messages[len(messages)-1]
// Start streaming
iter := chat.SendMessageStream(ctx, genai.Text(lastUserMsg.Content))
eventChan := make(chan ProviderEvent)
go func() {
defer close(eventChan)
var finalResp *genai.GenerateContentResponse
currentContent := ""
toolCalls := []message.ToolCall{}
for {
resp, err := iter.Next()
if err == iterator.Done {
break
}
if err != nil {
var apiErr *googleapi.Error
if errors.As(err, &apiErr) {
log.Printf("%s", apiErr.Body)
}
eventChan <- ProviderEvent{
Type: EventError,
Error: err,
}
return
}
finalResp = resp
if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
for _, part := range resp.Candidates[0].Content.Parts {
switch p := part.(type) {
case genai.Text:
newText := string(p)
eventChan <- ProviderEvent{
Type: EventContentDelta,
Content: newText,
}
currentContent += newText
case genai.FunctionCall:
// For function calls, we assume they come complete, not streamed in parts
id := "call_" + uuid.New().String()
args, _ := json.Marshal(p.Args)
newCall := message.ToolCall{
ID: id,
Name: p.Name,
Input: string(args),
Type: "function",
}
// Check if this is a new tool call
isNew := true
for _, existing := range toolCalls {
if existing.Name == newCall.Name && existing.Input == newCall.Input {
isNew = false
break
}
}
if isNew {
toolCalls = append(toolCalls, newCall)
}
}
}
}
}
// Extract token usage from the final response
tokenUsage := p.extractTokenUsage(finalResp)
eventChan <- ProviderEvent{
Type: EventComplete,
Response: &ProviderResponse{
Content: currentContent,
ToolCalls: toolCalls,
Usage: tokenUsage,
},
}
}()
return eventChan, nil
}
// Helper function to parse JSON string into map
func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
var result map[string]interface{}
err := json.Unmarshal([]byte(jsonStr), &result)
return result, err
}

View File

@@ -0,0 +1,278 @@
package provider
import (
"context"
"errors"
"github.com/kujtimiihoxha/termai/internal/llm/models"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
"github.com/kujtimiihoxha/termai/internal/message"
"github.com/openai/openai-go"
"github.com/openai/openai-go/option"
)
type openaiProvider struct {
client openai.Client
model models.Model
maxTokens int64
baseURL string
apiKey string
systemMessage string
}
type OpenAIOption func(*openaiProvider)
func NewOpenAIProvider(opts ...OpenAIOption) (Provider, error) {
provider := &openaiProvider{
maxTokens: 5000,
}
for _, opt := range opts {
opt(provider)
}
clientOpts := []option.RequestOption{
option.WithAPIKey(provider.apiKey),
}
if provider.baseURL != "" {
clientOpts = append(clientOpts, option.WithBaseURL(provider.baseURL))
}
provider.client = openai.NewClient(clientOpts...)
if provider.systemMessage == "" {
return nil, errors.New("system message is required")
}
return provider, nil
}
func WithOpenAISystemMessage(message string) OpenAIOption {
return func(p *openaiProvider) {
p.systemMessage = message
}
}
func WithOpenAIMaxTokens(maxTokens int64) OpenAIOption {
return func(p *openaiProvider) {
p.maxTokens = maxTokens
}
}
func WithOpenAIModel(model models.Model) OpenAIOption {
return func(p *openaiProvider) {
p.model = model
}
}
func WithOpenAIBaseURL(baseURL string) OpenAIOption {
return func(p *openaiProvider) {
p.baseURL = baseURL
}
}
func WithOpenAIKey(apiKey string) OpenAIOption {
return func(p *openaiProvider) {
p.apiKey = apiKey
}
}
func (p *openaiProvider) convertToOpenAIMessages(messages []message.Message) []openai.ChatCompletionMessageParamUnion {
var chatMessages []openai.ChatCompletionMessageParamUnion
chatMessages = append(chatMessages, openai.SystemMessage(p.systemMessage))
for _, msg := range messages {
switch msg.Role {
case message.User:
chatMessages = append(chatMessages, openai.UserMessage(msg.Content))
case message.Assistant:
assistantMsg := openai.ChatCompletionAssistantMessageParam{
Role: "assistant",
}
if msg.Content != "" {
assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
OfString: openai.String(msg.Content),
}
}
if len(msg.ToolCalls) > 0 {
assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls))
for i, call := range msg.ToolCalls {
assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
ID: call.ID,
Type: "function",
Function: openai.ChatCompletionMessageToolCallFunctionParam{
Name: call.Name,
Arguments: call.Input,
},
}
}
}
chatMessages = append(chatMessages, openai.ChatCompletionMessageParamUnion{
OfAssistant: &assistantMsg,
})
case message.Tool:
for _, result := range msg.ToolResults {
chatMessages = append(chatMessages,
openai.ToolMessage(result.Content, result.ToolCallID),
)
}
}
}
return chatMessages
}
func (p *openaiProvider) convertToOpenAITools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
openaiTools := make([]openai.ChatCompletionToolParam, len(tools))
for i, tool := range tools {
info := tool.Info()
openaiTools[i] = openai.ChatCompletionToolParam{
Function: openai.FunctionDefinitionParam{
Name: info.Name,
Description: openai.String(info.Description),
Parameters: openai.FunctionParameters{
"type": "object",
"properties": info.Parameters,
"required": info.Required,
},
},
}
}
return openaiTools
}
func (p *openaiProvider) extractTokenUsage(usage openai.CompletionUsage) TokenUsage {
cachedTokens := int64(0)
cachedTokens = usage.PromptTokensDetails.CachedTokens
inputTokens := usage.PromptTokens - cachedTokens
return TokenUsage{
InputTokens: inputTokens,
OutputTokens: usage.CompletionTokens,
CacheCreationTokens: 0, // OpenAI doesn't provide this directly
CacheReadTokens: cachedTokens,
}
}
func (p *openaiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
chatMessages := p.convertToOpenAIMessages(messages)
openaiTools := p.convertToOpenAITools(tools)
params := openai.ChatCompletionNewParams{
Model: openai.ChatModel(p.model.APIModel),
Messages: chatMessages,
MaxTokens: openai.Int(p.maxTokens),
Tools: openaiTools,
}
response, err := p.client.Chat.Completions.New(ctx, params)
if err != nil {
return nil, err
}
content := ""
if response.Choices[0].Message.Content != "" {
content = response.Choices[0].Message.Content
}
var toolCalls []message.ToolCall
if len(response.Choices[0].Message.ToolCalls) > 0 {
toolCalls = make([]message.ToolCall, len(response.Choices[0].Message.ToolCalls))
for i, call := range response.Choices[0].Message.ToolCalls {
toolCalls[i] = message.ToolCall{
ID: call.ID,
Name: call.Function.Name,
Input: call.Function.Arguments,
Type: "function",
}
}
}
tokenUsage := p.extractTokenUsage(response.Usage)
return &ProviderResponse{
Content: content,
ToolCalls: toolCalls,
Usage: tokenUsage,
}, nil
}
func (p *openaiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
chatMessages := p.convertToOpenAIMessages(messages)
openaiTools := p.convertToOpenAITools(tools)
params := openai.ChatCompletionNewParams{
Model: openai.ChatModel(p.model.APIModel),
Messages: chatMessages,
MaxTokens: openai.Int(p.maxTokens),
Tools: openaiTools,
StreamOptions: openai.ChatCompletionStreamOptionsParam{
IncludeUsage: openai.Bool(true),
},
}
stream := p.client.Chat.Completions.NewStreaming(ctx, params)
eventChan := make(chan ProviderEvent)
toolCalls := make([]message.ToolCall, 0)
go func() {
defer close(eventChan)
acc := openai.ChatCompletionAccumulator{}
currentContent := ""
for stream.Next() {
chunk := stream.Current()
acc.AddChunk(chunk)
if tool, ok := acc.JustFinishedToolCall(); ok {
toolCalls = append(toolCalls, message.ToolCall{
ID: tool.Id,
Name: tool.Name,
Input: tool.Arguments,
Type: "function",
})
}
for _, choice := range chunk.Choices {
if choice.Delta.Content != "" {
eventChan <- ProviderEvent{
Type: EventContentDelta,
Content: choice.Delta.Content,
}
currentContent += choice.Delta.Content
}
}
}
if err := stream.Err(); err != nil {
eventChan <- ProviderEvent{
Type: EventError,
Error: err,
}
return
}
tokenUsage := p.extractTokenUsage(acc.Usage)
eventChan <- ProviderEvent{
Type: EventComplete,
Response: &ProviderResponse{
Content: currentContent,
ToolCalls: toolCalls,
Usage: tokenUsage,
},
}
}()
return eventChan, nil
}

View File

@@ -0,0 +1,48 @@
package provider
import (
"context"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
"github.com/kujtimiihoxha/termai/internal/message"
)
// EventType represents the type of streaming event
type EventType string
const (
EventContentStart EventType = "content_start"
EventContentDelta EventType = "content_delta"
EventThinkingDelta EventType = "thinking_delta"
EventContentStop EventType = "content_stop"
EventComplete EventType = "complete"
EventError EventType = "error"
)
type TokenUsage struct {
InputTokens int64
OutputTokens int64
CacheCreationTokens int64
CacheReadTokens int64
}
type ProviderResponse struct {
Content string
ToolCalls []message.ToolCall
Usage TokenUsage
}
type ProviderEvent struct {
Type EventType
Content string
Thinking string
ToolCall *message.ToolCall
Error error
Response *ProviderResponse
}
type Provider interface {
SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error)
}

View File

@@ -1,141 +0,0 @@
package tools
import (
"context"
"encoding/json"
"fmt"
"os"
"runtime"
"time"
"github.com/cloudwego/eino/components/tool"
"github.com/cloudwego/eino/schema"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/flow/agent/react"
"github.com/kujtimiihoxha/termai/internal/llm/models"
"github.com/spf13/viper"
)
type agentTool struct {
workingDir string
}
const (
AgentToolName = "agent"
)
type AgentParams struct {
Prompt string `json:"prompt"`
}
func taskAgentTools() []tool.BaseTool {
wd := viper.GetString("wd")
return []tool.BaseTool{
NewBashTool(wd),
NewLsTool(wd),
NewGlobTool(wd),
NewViewTool(wd),
NewWriteTool(wd),
NewEditTool(wd),
}
}
func NewTaskAgent(ctx context.Context) (*react.Agent, error) {
model, err := models.GetModel(ctx, models.ModelID(viper.GetString("models.big")))
if err != nil {
return nil, err
}
reactAgent, err := react.NewAgent(ctx, &react.AgentConfig{
Model: model,
ToolsConfig: compose.ToolsNodeConfig{
Tools: taskAgentTools(),
},
MaxStep: 1000,
})
if err != nil {
return nil, err
}
return reactAgent, nil
}
func TaskAgentSystemPrompt() string {
agentPrompt := `You are an agent for Orbitowl. Given the user's prompt, you should use the tools available to you to answer the user's question.
Notes:
1. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is <answer>.", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...".
2. When relevant, share file names and code snippets relevant to the query
3. Any file paths you return in your final response MUST be absolute. DO NOT use relative paths.
Here is useful information about the environment you are running in:
<env>
Working directory: %s
Platform: %s
Today's date: %s
</env>`
cwd, err := os.Getwd()
if err != nil {
cwd = "unknown"
}
platform := runtime.GOOS
switch platform {
case "darwin":
platform = "macos"
case "windows":
platform = "windows"
case "linux":
platform = "linux"
}
return fmt.Sprintf(agentPrompt, cwd, platform, time.Now().Format("1/2/2006"))
}
func (b *agentTool) Info(ctx context.Context) (*schema.ToolInfo, error) {
return &schema.ToolInfo{
Name: AgentToolName,
Desc: "Launch a new agent that has access to the following tools: GlobTool, GrepTool, LS, View, ReadNotebook. When you are searching for a keyword or file and are not confident that you will find the right match on the first try, use the Agent tool to perform the search for you. For example:\n\n- If you are searching for a keyword like \"config\" or \"logger\", or for questions like \"which file does X?\", the Agent tool is strongly recommended\n- If you want to read a specific file path, use the View or GlobTool tool instead of the Agent tool, to find the match more quickly\n- If you are searching for a specific class definition like \"class Foo\", use the GlobTool tool instead, to find the match more quickly\n\nUsage notes:\n1. Launch multiple agents concurrently whenever possible, to maximize performance; to do that, use a single message with multiple tool uses\n2. When the agent is done, it will return a single message back to you. The result returned by the agent is not visible to the user. To show the user the result, you should send a text message back to the user with a concise summary of the result.\n3. Each agent invocation is stateless. You will not be able to send additional messages to the agent, nor will the agent be able to communicate with you outside of its final report. Therefore, your prompt should contain a highly detailed task description for the agent to perform autonomously and you should specify exactly what information the agent should return back to you in its final and only message to you.\n4. The agent's outputs should generally be trusted\n5. IMPORTANT: The agent can not use Bash, Replace, Edit, NotebookEditCell, so can not modify files. If you want to use these tools, use them directly instead of going through the agent.",
ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
"prompt": {
Type: "string",
Desc: "The task for the agent to perform",
Required: true,
},
}),
}, nil
}
func (b *agentTool) InvokableRun(ctx context.Context, args string, opts ...tool.Option) (string, error) {
var params AgentParams
if err := json.Unmarshal([]byte(args), &params); err != nil {
return "", err
}
if params.Prompt == "" {
return "prompt is required", nil
}
a, err := NewTaskAgent(ctx)
if err != nil {
return "", err
}
out, err := a.Generate(
ctx,
[]*schema.Message{
schema.SystemMessage(TaskAgentSystemPrompt()),
schema.UserMessage(params.Prompt),
},
)
if err != nil {
return "", err
}
return out.Content, nil
}
func NewAgentTool(wd string) tool.InvokableTool {
return &agentTool{
workingDir: wd,
}
}

View File

@@ -6,20 +6,17 @@ import (
"fmt"
"strings"
"github.com/cloudwego/eino/components/tool"
"github.com/cloudwego/eino/schema"
"github.com/kujtimiihoxha/termai/internal/config"
"github.com/kujtimiihoxha/termai/internal/llm/tools/shell"
"github.com/kujtimiihoxha/termai/internal/permission"
)
type bashTool struct {
workingDir string
}
type bashTool struct{}
const (
BashToolName = "bash"
DefaultTimeout = 30 * 60 * 1000 // 30 minutes in milliseconds
DefaultTimeout = 1 * 60 * 1000 // 1 minutes in milliseconds
MaxTimeout = 10 * 60 * 1000 // 10 minutes in milliseconds
MaxOutputLength = 30000
)
@@ -29,6 +26,11 @@ type BashParams struct {
Timeout int `json:"timeout"`
}
type BashPermissionsParams struct {
Command string `json:"command"`
Timeout int `json:"timeout"`
}
var BannedCommands = []string{
"alias", "curl", "curlie", "wget", "axel", "aria2c",
"nc", "telnet", "lynx", "w3m", "links", "httpie", "xh",
@@ -40,29 +42,29 @@ var SafeReadOnlyCommands = []string{
"whatis", //...
}
func (b *bashTool) Info(ctx context.Context) (*schema.ToolInfo, error) {
return &schema.ToolInfo{
Name: BashToolName,
Desc: bashDescription(),
ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
"command": {
Type: "string",
Desc: "The command to execute",
Required: true,
func (b *bashTool) Info() ToolInfo {
return ToolInfo{
Name: BashToolName,
Description: bashDescription(),
Parameters: map[string]any{
"command": map[string]any{
"type": "string",
"description": "The command to execute",
},
"timeout": {
Type: "number",
Desc: "Optional timeout in milliseconds (max 600000)",
"timeout": map[string]any{
"type": "number",
"desription": "Optional timeout in milliseconds (max 600000)",
},
}),
}, nil
},
Required: []string{"command"},
}
}
// Handle implements Tool.
func (b *bashTool) InvokableRun(ctx context.Context, args string, opts ...tool.Option) (string, error) {
func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
var params BashParams
if err := json.Unmarshal([]byte(args), &params); err != nil {
return "", err
if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
return NewTextErrorResponse("invalid parameters"), nil
}
if params.Timeout > MaxTimeout {
@@ -72,13 +74,13 @@ func (b *bashTool) InvokableRun(ctx context.Context, args string, opts ...tool.O
}
if params.Command == "" {
return "missing command", nil
return NewTextErrorResponse("missing command"), nil
}
baseCmd := strings.Fields(params.Command)[0]
for _, banned := range BannedCommands {
if strings.EqualFold(baseCmd, banned) {
return fmt.Sprintf("command '%s' is not allowed", baseCmd), nil
return NewTextErrorResponse(fmt.Sprintf("command '%s' is not allowed", baseCmd)), nil
}
}
isSafeReadOnly := false
@@ -91,39 +93,21 @@ func (b *bashTool) InvokableRun(ctx context.Context, args string, opts ...tool.O
if !isSafeReadOnly {
p := permission.Default.Request(
permission.CreatePermissionRequest{
Path: b.workingDir,
Path: config.WorkingDirectory(),
ToolName: BashToolName,
Action: "execute",
Description: fmt.Sprintf("Execute command: %s", params.Command),
Params: map[string]interface{}{
"command": params.Command,
"timeout": params.Timeout,
},
Params: BashPermissionsParams(params),
},
)
if !p {
return "", fmt.Errorf("permission denied for command: %s", params.Command)
return NewTextErrorResponse("permission denied"), nil
}
}
// p := b.permission.Request(permission.CreatePermissionRequest{
// Path: b.workingDir,
// ToolName: BashToolName,
// Action: "execute",
// Description: fmt.Sprintf("Execute command: %s", params.Command),
// Params: map[string]any{
// "command": params.Command,
// "timeout": params.Timeout,
// },
// })
// if !p {
// return "", errors.New("permission denied")
// }
shell := shell.GetPersistentShell(b.workingDir)
shell := shell.GetPersistentShell(config.WorkingDirectory())
stdout, stderr, exitCode, interrupted, err := shell.Exec(ctx, params.Command, params.Timeout)
if err != nil {
return "", err
return NewTextErrorResponse(fmt.Sprintf("error executing command: %s", err)), nil
}
stdout = truncateOutput(stdout)
@@ -153,9 +137,9 @@ func (b *bashTool) InvokableRun(ctx context.Context, args string, opts ...tool.O
}
if stdout == "" {
return "no output", nil
return NewTextResponse("no output"), nil
}
return stdout, nil
return NewTextResponse(stdout), nil
}
func truncateOutput(content string) string {
@@ -327,8 +311,6 @@ Important:
- Never update git config`, bannedCommandsStr, MaxOutputLength)
}
func NewBashTool(workingDir string) tool.InvokableTool {
return &bashTool{
workingDir: workingDir,
}
func NewBashTool() BaseTool {
return &bashTool{}
}

View File

@@ -0,0 +1,389 @@
package tools
import (
"context"
"encoding/json"
"os"
"strings"
"testing"
"time"
"github.com/kujtimiihoxha/termai/internal/permission"
"github.com/kujtimiihoxha/termai/internal/pubsub"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestBashTool_Info(t *testing.T) {
tool := NewBashTool()
info := tool.Info()
assert.Equal(t, BashToolName, info.Name)
assert.NotEmpty(t, info.Description)
assert.Contains(t, info.Parameters, "command")
assert.Contains(t, info.Parameters, "timeout")
assert.Contains(t, info.Required, "command")
}
func TestBashTool_Run(t *testing.T) {
// Setup a mock permission handler that always allows
origPermission := permission.Default
defer func() {
permission.Default = origPermission
}()
permission.Default = newMockPermissionService(true)
// Save original working directory
origWd, err := os.Getwd()
require.NoError(t, err)
defer func() {
os.Chdir(origWd)
}()
t.Run("executes command successfully", func(t *testing.T) {
permission.Default = newMockPermissionService(true)
tool := NewBashTool()
params := BashParams{
Command: "echo 'Hello World'",
}
paramsJSON, err := json.Marshal(params)
require.NoError(t, err)
call := ToolCall{
Name: BashToolName,
Input: string(paramsJSON),
}
response, err := tool.Run(context.Background(), call)
require.NoError(t, err)
assert.Equal(t, "Hello World\n", response.Content)
})
t.Run("handles invalid parameters", func(t *testing.T) {
permission.Default = newMockPermissionService(true)
tool := NewBashTool()
call := ToolCall{
Name: BashToolName,
Input: "invalid json",
}
response, err := tool.Run(context.Background(), call)
require.NoError(t, err)
assert.Contains(t, response.Content, "invalid parameters")
})
t.Run("handles missing command", func(t *testing.T) {
permission.Default = newMockPermissionService(true)
tool := NewBashTool()
params := BashParams{
Command: "",
}
paramsJSON, err := json.Marshal(params)
require.NoError(t, err)
call := ToolCall{
Name: BashToolName,
Input: string(paramsJSON),
}
response, err := tool.Run(context.Background(), call)
require.NoError(t, err)
assert.Contains(t, response.Content, "missing command")
})
t.Run("handles banned commands", func(t *testing.T) {
permission.Default = newMockPermissionService(true)
tool := NewBashTool()
for _, bannedCmd := range BannedCommands {
params := BashParams{
Command: bannedCmd + " arg1 arg2",
}
paramsJSON, err := json.Marshal(params)
require.NoError(t, err)
call := ToolCall{
Name: BashToolName,
Input: string(paramsJSON),
}
response, err := tool.Run(context.Background(), call)
require.NoError(t, err)
assert.Contains(t, response.Content, "not allowed", "Command %s should be blocked", bannedCmd)
}
})
t.Run("handles safe read-only commands without permission check", func(t *testing.T) {
permission.Default = newMockPermissionService(false)
tool := NewBashTool()
// Test with a safe read-only command
params := BashParams{
Command: "echo 'test'",
}
paramsJSON, err := json.Marshal(params)
require.NoError(t, err)
call := ToolCall{
Name: BashToolName,
Input: string(paramsJSON),
}
response, err := tool.Run(context.Background(), call)
require.NoError(t, err)
assert.Equal(t, "test\n", response.Content)
})
t.Run("handles permission denied", func(t *testing.T) {
permission.Default = newMockPermissionService(false)
tool := NewBashTool()
// Test with a command that requires permission
params := BashParams{
Command: "mkdir test_dir",
}
paramsJSON, err := json.Marshal(params)
require.NoError(t, err)
call := ToolCall{
Name: BashToolName,
Input: string(paramsJSON),
}
response, err := tool.Run(context.Background(), call)
require.NoError(t, err)
assert.Contains(t, response.Content, "permission denied")
})
t.Run("handles command timeout", func(t *testing.T) {
permission.Default = newMockPermissionService(true)
tool := NewBashTool()
params := BashParams{
Command: "sleep 2",
Timeout: 100, // 100ms timeout
}
paramsJSON, err := json.Marshal(params)
require.NoError(t, err)
call := ToolCall{
Name: BashToolName,
Input: string(paramsJSON),
}
response, err := tool.Run(context.Background(), call)
require.NoError(t, err)
assert.Contains(t, response.Content, "aborted")
})
t.Run("handles command with stderr output", func(t *testing.T) {
permission.Default = newMockPermissionService(true)
tool := NewBashTool()
params := BashParams{
Command: "echo 'error message' >&2",
}
paramsJSON, err := json.Marshal(params)
require.NoError(t, err)
call := ToolCall{
Name: BashToolName,
Input: string(paramsJSON),
}
response, err := tool.Run(context.Background(), call)
require.NoError(t, err)
assert.Contains(t, response.Content, "error message")
})
t.Run("handles command with both stdout and stderr", func(t *testing.T) {
permission.Default = newMockPermissionService(true)
tool := NewBashTool()
params := BashParams{
Command: "echo 'stdout message' && echo 'stderr message' >&2",
}
paramsJSON, err := json.Marshal(params)
require.NoError(t, err)
call := ToolCall{
Name: BashToolName,
Input: string(paramsJSON),
}
response, err := tool.Run(context.Background(), call)
require.NoError(t, err)
assert.Contains(t, response.Content, "stdout message")
assert.Contains(t, response.Content, "stderr message")
})
t.Run("handles context cancellation", func(t *testing.T) {
permission.Default = newMockPermissionService(true)
tool := NewBashTool()
params := BashParams{
Command: "sleep 5",
}
paramsJSON, err := json.Marshal(params)
require.NoError(t, err)
call := ToolCall{
Name: BashToolName,
Input: string(paramsJSON),
}
ctx, cancel := context.WithCancel(context.Background())
// Cancel the context after a short delay
go func() {
time.Sleep(100 * time.Millisecond)
cancel()
}()
response, err := tool.Run(ctx, call)
require.NoError(t, err)
assert.Contains(t, response.Content, "aborted")
})
t.Run("respects max timeout", func(t *testing.T) {
permission.Default = newMockPermissionService(true)
tool := NewBashTool()
params := BashParams{
Command: "echo 'test'",
Timeout: MaxTimeout + 1000, // Exceeds max timeout
}
paramsJSON, err := json.Marshal(params)
require.NoError(t, err)
call := ToolCall{
Name: BashToolName,
Input: string(paramsJSON),
}
response, err := tool.Run(context.Background(), call)
require.NoError(t, err)
assert.Equal(t, "test\n", response.Content)
})
t.Run("uses default timeout for zero or negative timeout", func(t *testing.T) {
permission.Default = newMockPermissionService(true)
tool := NewBashTool()
params := BashParams{
Command: "echo 'test'",
Timeout: -100, // Negative timeout
}
paramsJSON, err := json.Marshal(params)
require.NoError(t, err)
call := ToolCall{
Name: BashToolName,
Input: string(paramsJSON),
}
response, err := tool.Run(context.Background(), call)
require.NoError(t, err)
assert.Equal(t, "test\n", response.Content)
})
}
func TestTruncateOutput(t *testing.T) {
t.Run("does not truncate short output", func(t *testing.T) {
output := "short output"
result := truncateOutput(output)
assert.Equal(t, output, result)
})
t.Run("truncates long output", func(t *testing.T) {
// Create a string longer than MaxOutputLength
longOutput := strings.Repeat("a\n", MaxOutputLength)
result := truncateOutput(longOutput)
// Check that the result is shorter than the original
assert.Less(t, len(result), len(longOutput))
// Check that the truncation message is included
assert.Contains(t, result, "lines truncated")
// Check that we have the beginning and end of the original string
assert.True(t, strings.HasPrefix(result, "a\n"))
assert.True(t, strings.HasSuffix(result, "a\n"))
})
}
func TestCountLines(t *testing.T) {
testCases := []struct {
name string
input string
expected int
}{
{
name: "empty string",
input: "",
expected: 0,
},
{
name: "single line",
input: "line1",
expected: 1,
},
{
name: "multiple lines",
input: "line1\nline2\nline3",
expected: 3,
},
{
name: "trailing newline",
input: "line1\nline2\n",
expected: 3, // Empty string after last newline counts as a line
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := countLines(tc.input)
assert.Equal(t, tc.expected, result)
})
}
}
// Mock permission service for testing
type mockPermissionService struct {
*pubsub.Broker[permission.PermissionRequest]
allow bool
}
func (m *mockPermissionService) GrantPersistant(permission permission.PermissionRequest) {
// Not needed for tests
}
func (m *mockPermissionService) Grant(permission permission.PermissionRequest) {
// Not needed for tests
}
func (m *mockPermissionService) Deny(permission permission.PermissionRequest) {
// Not needed for tests
}
func (m *mockPermissionService) Request(opts permission.CreatePermissionRequest) bool {
return m.allow
}
func newMockPermissionService(allow bool) permission.Service {
return &mockPermissionService{
Broker: pubsub.NewBroker[permission.PermissionRequest](),
allow: allow,
}
}

View File

@@ -3,22 +3,18 @@ package tools
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/cloudwego/eino/components/tool"
"github.com/cloudwego/eino/schema"
"github.com/kujtimiihoxha/termai/internal/config"
"github.com/kujtimiihoxha/termai/internal/permission"
"github.com/sergi/go-diff/diffmatchpatch"
)
type editTool struct {
workingDir string
}
type editTool struct{}
const (
EditToolName = "edit"
@@ -30,100 +26,72 @@ type EditParams struct {
NewString string `json:"new_string"`
}
func (b *editTool) Info(ctx context.Context) (*schema.ToolInfo, error) {
return &schema.ToolInfo{
Name: EditToolName,
Desc: `This is a tool for editing files. For moving or renaming files, you should generally use the Bash tool with the 'mv' command instead. For larger edits, use the Write tool to overwrite files. F.
Before using this tool:
1. Use the View tool to understand the file's contents and context
2. Verify the directory path is correct (only applicable when creating new files):
- Use the LS tool to verify the parent directory exists and is the correct location
To make a file edit, provide the following:
1. file_path: The absolute path to the file to modify (must be absolute, not relative)
2. old_string: The text to replace (must be unique within the file, and must match the file contents exactly, including all whitespace and indentation)
3. new_string: The edited text to replace the old_string
The tool will replace ONE occurrence of old_string with new_string in the specified file.
CRITICAL REQUIREMENTS FOR USING THIS TOOL:
1. UNIQUENESS: The old_string MUST uniquely identify the specific instance you want to change. This means:
- Include AT LEAST 3-5 lines of context BEFORE the change point
- Include AT LEAST 3-5 lines of context AFTER the change point
- Include all whitespace, indentation, and surrounding code exactly as it appears in the file
2. SINGLE INSTANCE: This tool can only change ONE instance at a time. If you need to change multiple instances:
- Make separate calls to this tool for each instance
- Each call must uniquely identify its specific instance using extensive context
3. VERIFICATION: Before using this tool:
- Check how many instances of the target text exist in the file
- If multiple instances exist, gather enough context to uniquely identify each one
- Plan separate tool calls for each instance
WARNING: If you do not follow these requirements:
- The tool will fail if old_string matches multiple locations
- The tool will fail if old_string doesn't match exactly (including whitespace)
- You may change the wrong instance if you don't include enough context
When making edits:
- Ensure the edit results in idiomatic, correct code
- Do not leave the code in a broken state
- Always use absolute file paths (starting with /)
If you want to create a new file, use:
- A new file path, including dir name if needed
- An empty old_string
- The new file's contents as new_string
Remember: when making multiple file edits in a row to the same file, you should prefer to send all edits in a single message with multiple calls to this tool, rather than multiple messages with a single call each.`,
ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
"file_path": {
Type: "string",
Desc: "The absolute path to the file to modify",
Required: true,
},
"old_string": {
Type: "string",
Desc: "The text to replace",
Required: true,
},
"new_string": {
Type: "string",
Desc: "The text to replace it with",
Required: true,
},
}),
}, nil
type EditPermissionsParams struct {
FilePath string `json:"file_path"`
OldString string `json:"old_string"`
NewString string `json:"new_string"`
Diff string `json:"diff"`
}
func (b *editTool) InvokableRun(ctx context.Context, args string, opts ...tool.Option) (string, error) {
func (e *editTool) Info() ToolInfo {
return ToolInfo{
Name: EditToolName,
Description: editDescription(),
Parameters: map[string]any{
"file_path": map[string]any{
"type": "string",
"description": "The absolute path to the file to modify",
},
"old_string": map[string]any{
"type": "string",
"description": "The text to replace",
},
"new_string": map[string]any{
"type": "string",
"description": "The text to replace it with",
},
},
Required: []string{"file_path", "old_string", "new_string"},
}
}
// Run implements Tool.
func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
var params EditParams
if err := json.Unmarshal([]byte(args), &params); err != nil {
return "", err
if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
return NewTextErrorResponse("invalid parameters"), nil
}
if params.FilePath == "" {
return "", errors.New("file_path is required")
return NewTextErrorResponse("file_path is required"), nil
}
if !filepath.IsAbs(params.FilePath) {
return "", fmt.Errorf("file path must be absolute, got: %s", params.FilePath)
wd := config.WorkingDirectory()
params.FilePath = filepath.Join(wd, params.FilePath)
}
if params.OldString == "" {
return createNewFile(params.FilePath, params.NewString)
result, err := createNewFile(params.FilePath, params.NewString)
if err != nil {
return NewTextErrorResponse(fmt.Sprintf("error creating file: %s", err)), nil
}
return NewTextErrorResponse(result), nil
}
if params.NewString == "" {
return deleteContent(params.FilePath, params.OldString)
result, err := deleteContent(params.FilePath, params.OldString)
if err != nil {
return NewTextErrorResponse(fmt.Sprintf("error deleting content: %s", err)), nil
}
return NewTextErrorResponse(result), nil
}
return replaceContent(params.FilePath, params.OldString, params.NewString)
result, err := replaceContent(params.FilePath, params.OldString, params.NewString)
if err != nil {
return NewTextErrorResponse(fmt.Sprintf("error replacing content: %s", err)), nil
}
return NewTextResponse(result), nil
}
func createNewFile(filePath, content string) (string, error) {
@@ -148,9 +116,11 @@ func createNewFile(filePath, content string) (string, error) {
ToolName: EditToolName,
Action: "create",
Description: fmt.Sprintf("Create file %s", filePath),
Params: map[string]interface{}{
"file_path": filePath,
"content": content,
Params: EditPermissionsParams{
FilePath: filePath,
OldString: "",
NewString: content,
Diff: GenerateDiff("", content),
},
},
)
@@ -166,19 +136,6 @@ func createNewFile(filePath, content string) (string, error) {
recordFileWrite(filePath)
recordFileRead(filePath)
// result := FileEditResult{
// FilePath: filePath,
// Created: true,
// Updated: false,
// Deleted: false,
// Diff: generateDiff("", content),
// }
//
// resultJSON, err := json.Marshal(result)
// if err != nil {
// return "", fmt.Errorf("failed to serialize result: %w", err)
// }
//
return "File created: " + filePath, nil
}
@@ -231,9 +188,11 @@ func deleteContent(filePath, oldString string) (string, error) {
ToolName: EditToolName,
Action: "delete",
Description: fmt.Sprintf("Delete content from file %s", filePath),
Params: map[string]interface{}{
"file_path": filePath,
"content": content,
Params: EditPermissionsParams{
FilePath: filePath,
OldString: oldString,
NewString: "",
Diff: GenerateDiff(oldContent, newContent),
},
},
)
@@ -247,21 +206,7 @@ func deleteContent(filePath, oldString string) (string, error) {
}
recordFileWrite(filePath)
// result := FileEditResult{
// FilePath: filePath,
// Created: false,
// Updated: true,
// Deleted: true,
// Diff: generateDiff(oldContent, newContent),
// SnippetBefore: getContextSnippet(oldContent, index, len(oldString)),
// SnippetAfter: getContextSnippet(newContent, index, 0),
// }
//
// resultJSON, err := json.Marshal(result)
// if err != nil {
// return "", fmt.Errorf("failed to serialize result: %w", err)
// }
recordFileRead(filePath)
return "Content deleted from file: " + filePath, nil
}
@@ -270,44 +215,45 @@ func replaceContent(filePath, oldString, newString string) (string, error) {
fileInfo, err := os.Stat(filePath)
if err != nil {
if os.IsNotExist(err) {
return fmt.Sprintf("file not found: %s", filePath), nil
return "", fmt.Errorf("file not found: %s", filePath)
}
return fmt.Sprintf("failed to access file: %s", err), nil
return "", fmt.Errorf("failed to access file: %w", err)
}
if fileInfo.IsDir() {
return fmt.Sprintf("path is a directory, not a file: %s", filePath), nil
return "", fmt.Errorf("path is a directory, not a file: %s", filePath)
}
if getLastReadTime(filePath).IsZero() {
return "you must read the file before editing it. Use the View tool first", nil
return "", fmt.Errorf("you must read the file before editing it. Use the View tool first")
}
modTime := fileInfo.ModTime()
lastRead := getLastReadTime(filePath)
if modTime.After(lastRead) {
return fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)",
filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339)), nil
return "", fmt.Errorf("file %s has been modified since it was last read (mod time: %s, last read: %s)",
filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339))
}
content, err := os.ReadFile(filePath)
if err != nil {
return fmt.Sprintf("failed to read file: %s", err), nil
return "", fmt.Errorf("failed to read file: %w", err)
}
oldContent := string(content)
index := strings.Index(oldContent, oldString)
if index == -1 {
return "old_string not found in file. Make sure it matches exactly, including whitespace and line breaks", nil
return "", fmt.Errorf("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks")
}
lastIndex := strings.LastIndex(oldContent, oldString)
if index != lastIndex {
return "old_string appears multiple times in the file. Please provide more context to ensure a unique match", nil
return "", fmt.Errorf("old_string appears multiple times in the file. Please provide more context to ensure a unique match")
}
newContent := oldContent[:index] + newString + oldContent[index+len(oldString):]
diff := GenerateDiff(oldString, newContent)
p := permission.Default.Request(
permission.CreatePermissionRequest{
@@ -315,10 +261,11 @@ func replaceContent(filePath, oldString, newString string) (string, error) {
ToolName: EditToolName,
Action: "replace",
Description: fmt.Sprintf("Replace content in file %s", filePath),
Params: map[string]interface{}{
"file_path": filePath,
"old_string": oldString,
"new_string": newString,
Params: EditPermissionsParams{
FilePath: filePath,
OldString: oldString,
NewString: newString,
Diff: diff,
},
},
)
@@ -328,93 +275,97 @@ func replaceContent(filePath, oldString, newString string) (string, error) {
err = os.WriteFile(filePath, []byte(newContent), 0o644)
if err != nil {
return fmt.Sprintf("failed to write file: %s", err), nil
return "", fmt.Errorf("failed to write file: %w", err)
}
recordFileWrite(filePath)
// result := FileEditResult{
// FilePath: filePath,
// Created: false,
// Updated: true,
// Deleted: false,
// Diff: generateDiff(oldContent, newContent),
// SnippetBefore: getContextSnippet(oldContent, index, len(oldString)),
// SnippetAfter: getContextSnippet(newContent, index, len(newString)),
// }
//
// resultJSON, err := json.Marshal(result)
// if err != nil {
// return "", fmt.Errorf("failed to serialize result: %w", err)
// }
recordFileRead(filePath)
return "Content replaced in file: " + filePath, nil
}
func getContextSnippet(content string, position, length int) string {
contextLines := 3
lines := strings.Split(content, "\n")
lineIndex := 0
currentPos := 0
for i, line := range lines {
if currentPos <= position && position < currentPos+len(line)+1 {
lineIndex = i
break
}
currentPos += len(line) + 1 // +1 for the newline
}
startLine := max(0, lineIndex-contextLines)
endLine := min(len(lines), lineIndex+contextLines+1)
var snippetBuilder strings.Builder
for i := startLine; i < endLine; i++ {
if i == lineIndex {
snippetBuilder.WriteString(fmt.Sprintf("> %s\n", lines[i]))
} else {
snippetBuilder.WriteString(fmt.Sprintf(" %s\n", lines[i]))
}
}
return snippetBuilder.String()
}
func generateDiff(oldContent, newContent string) string {
func GenerateDiff(oldContent, newContent string) string {
dmp := diffmatchpatch.New()
fileAdmp, fileBdmp, dmpStrings := dmp.DiffLinesToChars(oldContent, newContent)
diffs := dmp.DiffMain(fileAdmp, fileBdmp, false)
diffs = dmp.DiffCharsToLines(diffs, dmpStrings)
diffs = dmp.DiffCleanupSemantic(diffs)
buff := strings.Builder{}
for _, diff := range diffs {
text := diff.Text
diffs := dmp.DiffMain(oldContent, newContent, false)
patches := dmp.PatchMake(oldContent, diffs)
patchText := dmp.PatchToText(patches)
if patchText == "" && (oldContent != newContent) {
var result strings.Builder
result.WriteString("@@ Diff @@\n")
for _, diff := range diffs {
switch diff.Type {
case diffmatchpatch.DiffInsert:
result.WriteString("+ " + diff.Text + "\n")
case diffmatchpatch.DiffDelete:
result.WriteString("- " + diff.Text + "\n")
case diffmatchpatch.DiffEqual:
if len(diff.Text) > 40 {
result.WriteString(" " + diff.Text[:20] + "..." + diff.Text[len(diff.Text)-20:] + "\n")
} else {
result.WriteString(" " + diff.Text + "\n")
switch diff.Type {
case diffmatchpatch.DiffInsert:
for _, line := range strings.Split(text, "\n") {
_, _ = buff.WriteString("+ " + line + "\n")
}
case diffmatchpatch.DiffDelete:
for _, line := range strings.Split(text, "\n") {
_, _ = buff.WriteString("- " + line + "\n")
}
case diffmatchpatch.DiffEqual:
if len(text) > 40 {
_, _ = buff.WriteString(" " + text[:20] + "..." + text[len(text)-20:] + "\n")
} else {
for _, line := range strings.Split(text, "\n") {
_, _ = buff.WriteString(" " + line + "\n")
}
}
}
return result.String()
}
return patchText
return buff.String()
}
func NewEditTool(workingDir string) tool.InvokableTool {
return &editTool{
workingDir: workingDir,
}
func editDescription() string {
return `Edits files by replacing text, creating new files, or deleting content. For moving or renaming files, use the Bash tool with the 'mv' command instead. For larger file edits, use the FileWrite tool to overwrite files.
Before using this tool:
1. Use the FileRead tool to understand the file's contents and context
2. Verify the directory path is correct (only applicable when creating new files):
- Use the LS tool to verify the parent directory exists and is the correct location
To make a file edit, provide the following:
1. file_path: The absolute path to the file to modify (must be absolute, not relative)
2. old_string: The text to replace (must be unique within the file, and must match the file contents exactly, including all whitespace and indentation)
3. new_string: The edited text to replace the old_string
Special cases:
- To create a new file: provide file_path and new_string, leave old_string empty
- To delete content: provide file_path and old_string, leave new_string empty
The tool will replace ONE occurrence of old_string with new_string in the specified file.
CRITICAL REQUIREMENTS FOR USING THIS TOOL:
1. UNIQUENESS: The old_string MUST uniquely identify the specific instance you want to change. This means:
- Include AT LEAST 3-5 lines of context BEFORE the change point
- Include AT LEAST 3-5 lines of context AFTER the change point
- Include all whitespace, indentation, and surrounding code exactly as it appears in the file
2. SINGLE INSTANCE: This tool can only change ONE instance at a time. If you need to change multiple instances:
- Make separate calls to this tool for each instance
- Each call must uniquely identify its specific instance using extensive context
3. VERIFICATION: Before using this tool:
- Check how many instances of the target text exist in the file
- If multiple instances exist, gather enough context to uniquely identify each one
- Plan separate tool calls for each instance
WARNING: If you do not follow these requirements:
- The tool will fail if old_string matches multiple locations
- The tool will fail if old_string doesn't match exactly (including whitespace)
- You may change the wrong instance if you don't include enough context
When making edits:
- Ensure the edit results in idiomatic, correct code
- Do not leave the code in a broken state
- Always use absolute file paths (starting with /)
Remember: when making multiple file edits in a row to the same file, you should prefer to send all edits in a single message with multiple calls to this tool, rather than multiple messages with a single call each.`
}
func NewEditTool() BaseTool {
return &editTool{}
}

View File

@@ -11,15 +11,11 @@ import (
"strings"
"time"
"github.com/cloudwego/eino/components/tool"
"github.com/cloudwego/eino/schema"
"github.com/bmatcuk/doublestar/v4"
"github.com/kujtimiihoxha/termai/internal/config"
)
type globTool struct {
workingDir string
}
type globTool struct{}
const (
GlobToolName = "glob"
@@ -35,43 +31,44 @@ type GlobParams struct {
Path string `json:"path"`
}
func (b *globTool) Info(ctx context.Context) (*schema.ToolInfo, error) {
return &schema.ToolInfo{
Name: GlobToolName,
Desc: `- Fast file pattern matching tool that works with any codebase size
- Supports glob patterns like "**/*.js" or "src/**/*.ts"
- Returns matching file paths sorted by modification time
- Use this tool when you need to find files by name patterns
- When you are doing an open ended search that may require multiple rounds of globbing and grepping, use the Agent tool instead`,
ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
"pattern": {
Type: "string",
Desc: "The glob pattern to match files against",
Required: true,
func (g *globTool) Info() ToolInfo {
return ToolInfo{
Name: GlobToolName,
Description: globDescription(),
Parameters: map[string]any{
"pattern": map[string]any{
"type": "string",
"description": "The glob pattern to match files against",
},
"path": {
Type: "string",
Desc: "The directory to search in. Defaults to the current working directory.",
"path": map[string]any{
"type": "string",
"description": "The directory to search in. Defaults to the current working directory.",
},
}),
}, nil
},
Required: []string{"pattern"},
}
}
func (b *globTool) InvokableRun(ctx context.Context, args string, opts ...tool.Option) (string, error) {
// Run implements Tool.
func (g *globTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
var params GlobParams
if err := json.Unmarshal([]byte(args), &params); err != nil {
return fmt.Sprintf("error parsing parameters: %s", err), nil
if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
}
if params.Pattern == "" {
return NewTextErrorResponse("pattern is required"), nil
}
// If path is empty, use current working directory
searchPath := params.Path
if searchPath == "" {
searchPath = b.workingDir
searchPath = config.WorkingDirectory()
}
files, truncated, err := globFiles(params.Pattern, searchPath, 100)
if err != nil {
return fmt.Sprintf("error performing glob search: %s", err), nil
return NewTextErrorResponse(fmt.Sprintf("error performing glob search: %s", err)), nil
}
// Format the output for the assistant
@@ -81,11 +78,11 @@ func (b *globTool) InvokableRun(ctx context.Context, args string, opts ...tool.O
} else {
output = strings.Join(files, "\n")
if truncated {
output += "\n(Results are truncated. Consider using a more specific path or pattern.)"
output += "\n\n(Results are truncated. Consider using a more specific path or pattern.)"
}
}
return output, nil
return NewTextResponse(output), nil
}
func globFiles(pattern, searchPath string, limit int) ([]string, bool, error) {
@@ -167,8 +164,43 @@ func skipHidden(path string) bool {
return base != "." && strings.HasPrefix(base, ".")
}
func NewGlobTool(workingDir string) tool.InvokableTool {
return &globTool{
workingDir,
}
func globDescription() string {
return `Fast file pattern matching tool that finds files by name and pattern, returning matching paths sorted by modification time (newest first).
WHEN TO USE THIS TOOL:
- Use when you need to find files by name patterns or extensions
- Great for finding specific file types across a directory structure
- Useful for discovering files that match certain naming conventions
HOW TO USE:
- Provide a glob pattern to match against file paths
- Optionally specify a starting directory (defaults to current working directory)
- Results are sorted with most recently modified files first
GLOB PATTERN SYNTAX:
- '*' matches any sequence of non-separator characters
- '**' matches any sequence of characters, including separators
- '?' matches any single non-separator character
- '[...]' matches any character in the brackets
- '[!...]' matches any character not in the brackets
COMMON PATTERN EXAMPLES:
- '*.js' - Find all JavaScript files in the current directory
- '**/*.js' - Find all JavaScript files in any subdirectory
- 'src/**/*.{ts,tsx}' - Find all TypeScript files in the src directory
- '*.{html,css,js}' - Find all HTML, CSS, and JS files
LIMITATIONS:
- Results are limited to 100 files (newest first)
- Does not search file contents (use Grep tool for that)
- Hidden files (starting with '.') are skipped
TIPS:
- For the most useful results, combine with the Grep tool: first find files with Glob, then search their contents with Grep
- When doing iterative exploration that may require multiple rounds of searching, consider using the Agent tool instead
- Always check if results are truncated and refine your search pattern if needed`
}
func NewGlobTool() BaseTool {
return &globTool{}
}

View File

@@ -13,18 +13,13 @@ import (
"strings"
"time"
"github.com/cloudwego/eino/components/tool"
"github.com/cloudwego/eino/schema"
"github.com/kujtimiihoxha/termai/internal/config"
)
type grepTool struct {
workingDir string
}
type grepTool struct{}
const (
GrepToolName = "grep"
MaxGrepResults = 100
)
type GrepParams struct {
@@ -38,83 +33,66 @@ type grepMatch struct {
modTime time.Time
}
func (b *grepTool) Info(ctx context.Context) (*schema.ToolInfo, error) {
return &schema.ToolInfo{
Name: GrepToolName,
Desc: `- Fast content search tool that works with any codebase size
- Searches file contents using regular expressions
- Supports full regex syntax (eg. "log.*Error", "function\\s+\\w+", etc.)
- Filter files by pattern with the include parameter (eg. "*.js", "*.{ts,tsx}")
- Returns matching file paths sorted by modification time
- Use this tool when you need to find files containing specific patterns
- When you are doing an open ended search that may require multiple rounds of globbing and grepping, use the Agent tool instead`,
ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
"command": {
Type: "string",
Desc: "The command to execute",
Required: true,
func (g *grepTool) Info() ToolInfo {
return ToolInfo{
Name: GrepToolName,
Description: grepDescription(),
Parameters: map[string]any{
"pattern": map[string]any{
"type": "string",
"description": "The regex pattern to search for in file contents",
},
"timeout": {
Type: "number",
Desc: "Optional timeout in milliseconds (max 600000)",
"path": map[string]any{
"type": "string",
"description": "The directory to search in. Defaults to the current working directory.",
},
}),
}, nil
"include": map[string]any{
"type": "string",
"description": "File pattern to include in the search (e.g. \"*.js\", \"*.{ts,tsx}\")",
},
},
Required: []string{"pattern"},
}
}
func (b *grepTool) InvokableRun(ctx context.Context, args string, opts ...tool.Option) (string, error) {
// Run implements Tool.
func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
var params GrepParams
if err := json.Unmarshal([]byte(args), &params); err != nil {
return "", err
if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
}
if params.Pattern == "" {
return NewTextErrorResponse("pattern is required"), nil
}
// If path is empty, use current working directory
searchPath := params.Path
if searchPath == "" {
var err error
searchPath, err = os.Getwd()
if err != nil {
return fmt.Sprintf("unable to get current working directory: %s", err), nil
}
searchPath = config.WorkingDirectory()
}
matches, err := searchWithRipgrep(params.Pattern, searchPath, params.Include)
matches, truncated, err := searchFiles(params.Pattern, searchPath, params.Include, 100)
if err != nil {
matches, err = searchFilesWithRegex(params.Pattern, searchPath, params.Include)
if err != nil {
return fmt.Sprintf("error searching files: %s", err), nil
}
}
sort.Slice(matches, func(i, j int) bool {
return matches[i].modTime.After(matches[j].modTime)
})
truncated := false
if len(matches) > MaxGrepResults {
truncated = true
matches = matches[:MaxGrepResults]
}
filenames := make([]string, len(matches))
for i, m := range matches {
filenames[i] = m.path
return NewTextErrorResponse(fmt.Sprintf("error searching files: %s", err)), nil
}
// Format the output for the assistant
var output string
if len(filenames) == 0 {
if len(matches) == 0 {
output = "No files found"
} else {
output = fmt.Sprintf("Found %d file%s\n%s",
len(filenames),
pluralize(len(filenames)),
strings.Join(filenames, "\n"))
len(matches),
pluralize(len(matches)),
strings.Join(matches, "\n"))
if truncated {
output += "\n(Results are truncated. Consider using a more specific path or pattern.)"
output += "\n\n(Results are truncated. Consider using a more specific path or pattern.)"
}
}
return output, nil
return NewTextResponse(output), nil
}
func pluralize(count int) string {
@@ -124,6 +102,37 @@ func pluralize(count int) string {
return "s"
}
func searchFiles(pattern, rootPath, include string, limit int) ([]string, bool, error) {
// First try using ripgrep if available for better performance
matches, err := searchWithRipgrep(pattern, rootPath, include)
if err != nil {
// Fall back to manual regex search if ripgrep is not available
matches, err = searchFilesWithRegex(pattern, rootPath, include)
if err != nil {
return nil, false, err
}
}
// Sort files by modification time (newest first)
sort.Slice(matches, func(i, j int) bool {
return matches[i].modTime.After(matches[j].modTime)
})
// Check if we need to truncate the results
truncated := len(matches) > limit
if truncated {
matches = matches[:limit]
}
// Extract just the paths
results := make([]string, len(matches))
for i, m := range matches {
results[i] = m.path
}
return results, truncated, nil
}
func searchWithRipgrep(pattern, path, include string) ([]grepMatch, error) {
_, err := exec.LookPath("rg")
if err != nil {
@@ -140,6 +149,7 @@ func searchWithRipgrep(pattern, path, include string) ([]grepMatch, error) {
output, err := cmd.Output()
if err != nil {
if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
// Exit code 1 means no matches, which isn't an error for our purposes
return []grepMatch{}, nil
}
return nil, err
@@ -155,7 +165,7 @@ func searchWithRipgrep(pattern, path, include string) ([]grepMatch, error) {
fileInfo, err := os.Stat(line)
if err != nil {
continue
continue // Skip files we can't access
}
matches = append(matches, grepMatch{
@@ -186,20 +196,27 @@ func searchFilesWithRegex(pattern, rootPath, include string) ([]grepMatch, error
err = filepath.Walk(rootPath, func(path string, info os.FileInfo, err error) error {
if err != nil {
return nil
return nil // Skip errors
}
if info.IsDir() {
return nil // Skip directories
}
// Skip hidden files
if skipHidden(path) {
return nil
}
// Check include pattern if provided
if includePattern != nil && !includePattern.MatchString(path) {
return nil
}
// Check file contents for the pattern
match, err := fileContainsPattern(path, regex)
if err != nil {
return nil
return nil // Skip files we can't read
}
if match {
@@ -207,6 +224,11 @@ func searchFilesWithRegex(pattern, rootPath, include string) ([]grepMatch, error
path: path,
modTime: info.ModTime(),
})
// Check if we've hit the limit (collect double for sorting)
if len(matches) >= 200 {
return filepath.SkipAll
}
}
return nil
@@ -232,11 +254,7 @@ func fileContainsPattern(filePath string, pattern *regexp.Regexp) (bool, error)
}
}
if err := scanner.Err(); err != nil {
return false, err
}
return false, nil
return false, scanner.Err()
}
func globToRegex(glob string) string {
@@ -250,11 +268,46 @@ func globToRegex(glob string) string {
return "(" + strings.ReplaceAll(inner, ",", "|") + ")"
})
return "^" + regexPattern + "$"
return regexPattern
}
func NewGrepTool(workingDir string) tool.InvokableTool {
return &grepTool{
workingDir,
}
func grepDescription() string {
return `Fast content search tool that finds files containing specific text or patterns, returning matching file paths sorted by modification time (newest first).
WHEN TO USE THIS TOOL:
- Use when you need to find files containing specific text or patterns
- Great for searching code bases for function names, variable declarations, or error messages
- Useful for finding all files that use a particular API or pattern
HOW TO USE:
- Provide a regex pattern to search for within file contents
- Optionally specify a starting directory (defaults to current working directory)
- Optionally provide an include pattern to filter which files to search
- Results are sorted with most recently modified files first
REGEX PATTERN SYNTAX:
- Supports standard regular expression syntax
- 'function' searches for the literal text "function"
- 'log\..*Error' finds text starting with "log." and ending with "Error"
- 'import\s+.*\s+from' finds import statements in JavaScript/TypeScript
COMMON INCLUDE PATTERN EXAMPLES:
- '*.js' - Only search JavaScript files
- '*.{ts,tsx}' - Only search TypeScript files
- '*.go' - Only search Go files
LIMITATIONS:
- Results are limited to 100 files (newest first)
- Performance depends on the number of files being searched
- Very large binary files may be skipped
- Hidden files (starting with '.') are skipped
TIPS:
- For faster, more targeted searches, first use Glob to find relevant files, then use Grep
- When doing iterative exploration that may require multiple rounds of searching, consider using the Agent tool instead
- Always check if results are truncated and refine your search pattern if needed`
}
func NewGrepTool() BaseTool {
return &grepTool{}
}

View File

@@ -8,19 +8,14 @@ import (
"path/filepath"
"strings"
"github.com/cloudwego/eino/components/tool"
"github.com/cloudwego/eino/schema"
"github.com/kujtimiihoxha/termai/internal/config"
)
type lsTool struct {
workingDir string
}
type lsTool struct{}
const (
LSToolName = "ls"
MaxFiles = 1000
TruncatedMessage = "There are more than 1000 files in the repository. Use the LS tool (passing a specific path), Bash tool, and other tools to explore nested directories. The first 1000 files and directories are included below:\n\n"
MaxLSFiles = 1000
)
type LSParams struct {
@@ -28,61 +23,82 @@ type LSParams struct {
Ignore []string `json:"ignore"`
}
func (b *lsTool) Info(ctx context.Context) (*schema.ToolInfo, error) {
return &schema.ToolInfo{
Name: LSToolName,
Desc: "Lists files and directories in a given path. The path parameter must be an absolute path, not a relative path. You can optionally provide an array of glob patterns to ignore with the ignore parameter. You should generally prefer the Glob and Grep tools, if you know which directories to search.",
ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
"path": {
Type: "string",
Desc: "The absolute path to the directory to list (must be absolute, not relative)",
Required: true,
},
"ignore": {
Type: "array",
ElemInfo: &schema.ParameterInfo{
Type: schema.String,
Desc: "List of glob patterns to ignore",
},
},
}),
}, nil
type TreeNode struct {
Name string `json:"name"`
Path string `json:"path"`
Type string `json:"type"` // "file" or "directory"
Children []*TreeNode `json:"children,omitempty"`
}
func (b *lsTool) InvokableRun(ctx context.Context, args string, opts ...tool.Option) (string, error) {
func (l *lsTool) Info() ToolInfo {
return ToolInfo{
Name: LSToolName,
Description: lsDescription(),
Parameters: map[string]any{
"path": map[string]any{
"type": "string",
"description": "The path to the directory to list (defaults to current working directory)",
},
"ignore": map[string]any{
"type": "array",
"description": "List of glob patterns to ignore",
"items": map[string]any{
"type": "string",
},
},
},
Required: []string{"path"},
}
}
// Run implements Tool.
func (l *lsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
var params LSParams
if err := json.Unmarshal([]byte(args), &params); err != nil {
return "", err
if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
}
if !filepath.IsAbs(params.Path) {
return fmt.Sprintf("path must be absolute, got: %s", params.Path), nil
// If path is empty, use current working directory
searchPath := params.Path
if searchPath == "" {
searchPath = config.WorkingDirectory()
}
files, err := b.listDirectory(params.Path)
// Ensure the path is absolute
if !filepath.IsAbs(searchPath) {
searchPath = filepath.Join(config.WorkingDirectory(), searchPath)
}
// Check if the path exists
if _, err := os.Stat(searchPath); os.IsNotExist(err) {
return NewTextErrorResponse(fmt.Sprintf("path does not exist: %s", searchPath)), nil
}
files, truncated, err := listDirectory(searchPath, params.Ignore, MaxLSFiles)
if err != nil {
return fmt.Sprintf("error listing directory: %s", err), nil
return NewTextErrorResponse(fmt.Sprintf("error listing directory: %s", err)), nil
}
tree := createFileTree(files)
output := printTree(tree, params.Path)
output := printTree(tree, searchPath)
if len(files) >= MaxFiles {
output = TruncatedMessage + output
if truncated {
output = fmt.Sprintf("There are more than %d files in the directory. Use a more specific path or use the Glob tool to find specific files. The first %d files and directories are included below:\n\n%s", MaxLSFiles, MaxLSFiles, output)
}
return output, nil
return NewTextResponse(output), nil
}
func (b *lsTool) listDirectory(initialPath string) ([]string, error) {
func listDirectory(initialPath string, ignorePatterns []string, limit int) ([]string, bool, error) {
var results []string
truncated := false
err := filepath.Walk(initialPath, func(path string, info os.FileInfo, err error) error {
if err != nil {
return nil // Skip files we don't have permission to access
}
if shouldSkip(path) {
if shouldSkip(path, ignorePatterns) {
if info.IsDir() {
return filepath.SkipDir
}
@@ -93,137 +109,212 @@ func (b *lsTool) listDirectory(initialPath string) ([]string, error) {
if info.IsDir() {
path = path + string(filepath.Separator)
}
relPath, err := filepath.Rel(b.workingDir, path)
if err == nil {
results = append(results, relPath)
} else {
results = append(results, path)
}
results = append(results, path)
}
if len(results) >= MaxFiles {
return fmt.Errorf("max files reached")
if len(results) >= limit {
truncated = true
return filepath.SkipAll
}
return nil
})
if err != nil && err.Error() != "max files reached" {
return nil, err
if err != nil {
return nil, truncated, err
}
return results, nil
return results, truncated, nil
}
func shouldSkip(path string) bool {
func shouldSkip(path string, ignorePatterns []string) bool {
base := filepath.Base(path)
// Skip hidden files and directories
if base != "." && strings.HasPrefix(base, ".") {
return true
}
// Skip common directories and files
commonIgnored := []string{
"__pycache__",
"node_modules",
"dist",
"build",
"target",
"vendor",
"bin",
"obj",
".git",
".idea",
".vscode",
".DS_Store",
"*.pyc",
"*.pyo",
"*.pyd",
"*.so",
"*.dll",
"*.exe",
}
// Skip __pycache__ directories
if strings.Contains(path, filepath.Join("__pycache__", "")) {
return true
}
// Check against common ignored patterns
for _, ignored := range commonIgnored {
if strings.HasSuffix(ignored, "/") {
// Directory pattern
if strings.Contains(path, filepath.Join(ignored[:len(ignored)-1], "")) {
return true
}
} else if strings.HasPrefix(ignored, "*.") {
// File extension pattern
if strings.HasSuffix(base, ignored[1:]) {
return true
}
} else {
// Exact match
if base == ignored {
return true
}
}
}
// Check against ignore patterns
for _, pattern := range ignorePatterns {
matched, err := filepath.Match(pattern, base)
if err == nil && matched {
return true
}
}
return false
}
type TreeNode struct {
Name string `json:"name"`
Path string `json:"path"`
Type string `json:"type"` // "file" or "directory"
Children []TreeNode `json:"children,omitempty"`
}
func createFileTree(sortedPaths []string) []TreeNode {
root := []TreeNode{}
func createFileTree(sortedPaths []string) []*TreeNode {
root := []*TreeNode{}
pathMap := make(map[string]*TreeNode)
for _, path := range sortedPaths {
parts := strings.Split(path, string(filepath.Separator))
currentLevel := &root
currentPath := ""
var parentPath string
var cleanParts []string
for _, part := range parts {
if part != "" {
cleanParts = append(cleanParts, part)
}
}
parts = cleanParts
if len(parts) == 0 {
continue
}
for i, part := range parts {
if part == "" {
continue
}
if currentPath == "" {
currentPath = part
} else {
currentPath = filepath.Join(currentPath, part)
}
if _, exists := pathMap[currentPath]; exists {
parentPath = currentPath
continue
}
isLastPart := i == len(parts)-1
isDir := !isLastPart || strings.HasSuffix(path, string(filepath.Separator))
found := false
for i := range *currentLevel {
if (*currentLevel)[i].Name == part {
found = true
if (*currentLevel)[i].Children != nil {
currentLevel = &(*currentLevel)[i].Children
}
break
}
nodeType := "file"
if isDir {
nodeType = "directory"
}
newNode := &TreeNode{
Name: part,
Path: currentPath,
Type: nodeType,
Children: []*TreeNode{},
}
if !found {
nodeType := "file"
if isDir {
nodeType = "directory"
}
pathMap[currentPath] = newNode
newNode := TreeNode{
Name: part,
Path: currentPath,
Type: nodeType,
}
if isDir {
newNode.Children = []TreeNode{}
*currentLevel = append(*currentLevel, newNode)
currentLevel = &(*currentLevel)[len(*currentLevel)-1].Children
} else {
*currentLevel = append(*currentLevel, newNode)
if i > 0 && parentPath != "" {
if parent, ok := pathMap[parentPath]; ok {
parent.Children = append(parent.Children, newNode)
}
} else {
root = append(root, newNode)
}
parentPath = currentPath
}
}
return root
}
func printTree(tree []TreeNode, rootPath string) string {
func printTree(tree []*TreeNode, rootPath string) string {
var result strings.Builder
result.WriteString(fmt.Sprintf("- %s%s\n", rootPath, string(filepath.Separator)))
printTreeRecursive(&result, tree, 0, " ")
for _, node := range tree {
printNode(&result, node, 1)
}
return result.String()
}
func printTreeRecursive(builder *strings.Builder, tree []TreeNode, level int, prefix string) {
for _, node := range tree {
linePrefix := prefix + "- "
func printNode(builder *strings.Builder, node *TreeNode, level int) {
indent := strings.Repeat(" ", level)
nodeName := node.Name
if node.Type == "directory" {
nodeName += string(filepath.Separator)
}
fmt.Fprintf(builder, "%s%s\n", linePrefix, nodeName)
nodeName := node.Name
if node.Type == "directory" {
nodeName += string(filepath.Separator)
}
if node.Type == "directory" && len(node.Children) > 0 {
printTreeRecursive(builder, node.Children, level+1, prefix+" ")
fmt.Fprintf(builder, "%s- %s\n", indent, nodeName)
if node.Type == "directory" && len(node.Children) > 0 {
for _, child := range node.Children {
printNode(builder, child, level+1)
}
}
}
func NewLsTool(workingDir string) tool.InvokableTool {
return &lsTool{
workingDir,
}
func lsDescription() string {
return `Directory listing tool that shows files and subdirectories in a tree structure, helping you explore and understand the project organization.
WHEN TO USE THIS TOOL:
- Use when you need to explore the structure of a directory
- Helpful for understanding the organization of a project
- Good first step when getting familiar with a new codebase
HOW TO USE:
- Provide a path to list (defaults to current working directory)
- Optionally specify glob patterns to ignore
- Results are displayed in a tree structure
FEATURES:
- Displays a hierarchical view of files and directories
- Automatically skips hidden files/directories (starting with '.')
- Skips common system directories like __pycache__
- Can filter out files matching specific patterns
LIMITATIONS:
- Results are limited to 1000 files
- Very large directories will be truncated
- Does not show file sizes or permissions
- Cannot recursively list all directories in a large project
TIPS:
- Use Glob tool for finding files by name patterns instead of browsing
- Use Grep tool for searching file contents
- Combine with other tools for more effective exploration`
}
func NewLsTool() BaseTool {
return &lsTool{}
}

View File

@@ -0,0 +1,457 @@
package tools
import (
"context"
"encoding/json"
"os"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestLsTool_Info(t *testing.T) {
tool := NewLsTool()
info := tool.Info()
assert.Equal(t, LSToolName, info.Name)
assert.NotEmpty(t, info.Description)
assert.Contains(t, info.Parameters, "path")
assert.Contains(t, info.Parameters, "ignore")
assert.Contains(t, info.Required, "path")
}
func TestLsTool_Run(t *testing.T) {
// Create a temporary directory for testing
tempDir, err := os.MkdirTemp("", "ls_tool_test")
require.NoError(t, err)
defer os.RemoveAll(tempDir)
// Create a test directory structure
testDirs := []string{
"dir1",
"dir2",
"dir2/subdir1",
"dir2/subdir2",
"dir3",
"dir3/.hidden_dir",
"__pycache__",
}
testFiles := []string{
"file1.txt",
"file2.txt",
"dir1/file3.txt",
"dir2/file4.txt",
"dir2/subdir1/file5.txt",
"dir2/subdir2/file6.txt",
"dir3/file7.txt",
"dir3/.hidden_file.txt",
"__pycache__/cache.pyc",
".hidden_root_file.txt",
}
// Create directories
for _, dir := range testDirs {
dirPath := filepath.Join(tempDir, dir)
err := os.MkdirAll(dirPath, 0755)
require.NoError(t, err)
}
// Create files
for _, file := range testFiles {
filePath := filepath.Join(tempDir, file)
err := os.WriteFile(filePath, []byte("test content"), 0644)
require.NoError(t, err)
}
t.Run("lists directory successfully", func(t *testing.T) {
tool := NewLsTool()
params := LSParams{
Path: tempDir,
}
paramsJSON, err := json.Marshal(params)
require.NoError(t, err)
call := ToolCall{
Name: LSToolName,
Input: string(paramsJSON),
}
response, err := tool.Run(context.Background(), call)
require.NoError(t, err)
// Check that visible directories and files are included
assert.Contains(t, response.Content, "dir1")
assert.Contains(t, response.Content, "dir2")
assert.Contains(t, response.Content, "dir3")
assert.Contains(t, response.Content, "file1.txt")
assert.Contains(t, response.Content, "file2.txt")
// Check that hidden files and directories are not included
assert.NotContains(t, response.Content, ".hidden_dir")
assert.NotContains(t, response.Content, ".hidden_file.txt")
assert.NotContains(t, response.Content, ".hidden_root_file.txt")
// Check that __pycache__ is not included
assert.NotContains(t, response.Content, "__pycache__")
})
t.Run("handles non-existent path", func(t *testing.T) {
tool := NewLsTool()
params := LSParams{
Path: filepath.Join(tempDir, "non_existent_dir"),
}
paramsJSON, err := json.Marshal(params)
require.NoError(t, err)
call := ToolCall{
Name: LSToolName,
Input: string(paramsJSON),
}
response, err := tool.Run(context.Background(), call)
require.NoError(t, err)
assert.Contains(t, response.Content, "path does not exist")
})
t.Run("handles empty path parameter", func(t *testing.T) {
// For this test, we need to mock the config.WorkingDirectory function
// Since we can't easily do that, we'll just check that the response doesn't contain an error message
tool := NewLsTool()
params := LSParams{
Path: "",
}
paramsJSON, err := json.Marshal(params)
require.NoError(t, err)
call := ToolCall{
Name: LSToolName,
Input: string(paramsJSON),
}
response, err := tool.Run(context.Background(), call)
require.NoError(t, err)
// The response should either contain a valid directory listing or an error
// We'll just check that it's not empty
assert.NotEmpty(t, response.Content)
})
t.Run("handles invalid parameters", func(t *testing.T) {
tool := NewLsTool()
call := ToolCall{
Name: LSToolName,
Input: "invalid json",
}
response, err := tool.Run(context.Background(), call)
require.NoError(t, err)
assert.Contains(t, response.Content, "error parsing parameters")
})
t.Run("respects ignore patterns", func(t *testing.T) {
tool := NewLsTool()
params := LSParams{
Path: tempDir,
Ignore: []string{"file1.txt", "dir1"},
}
paramsJSON, err := json.Marshal(params)
require.NoError(t, err)
call := ToolCall{
Name: LSToolName,
Input: string(paramsJSON),
}
response, err := tool.Run(context.Background(), call)
require.NoError(t, err)
// The output format is a tree, so we need to check for specific patterns
// Check that file1.txt is not directly mentioned
assert.NotContains(t, response.Content, "- file1.txt")
// Check that dir1/ is not directly mentioned
assert.NotContains(t, response.Content, "- dir1/")
})
t.Run("handles relative path", func(t *testing.T) {
// Save original working directory
origWd, err := os.Getwd()
require.NoError(t, err)
defer func() {
os.Chdir(origWd)
}()
// Change to a directory above the temp directory
parentDir := filepath.Dir(tempDir)
err = os.Chdir(parentDir)
require.NoError(t, err)
tool := NewLsTool()
params := LSParams{
Path: filepath.Base(tempDir),
}
paramsJSON, err := json.Marshal(params)
require.NoError(t, err)
call := ToolCall{
Name: LSToolName,
Input: string(paramsJSON),
}
response, err := tool.Run(context.Background(), call)
require.NoError(t, err)
// Should list the temp directory contents
assert.Contains(t, response.Content, "dir1")
assert.Contains(t, response.Content, "file1.txt")
})
}
func TestShouldSkip(t *testing.T) {
testCases := []struct {
name string
path string
ignorePatterns []string
expected bool
}{
{
name: "hidden file",
path: "/path/to/.hidden_file",
ignorePatterns: []string{},
expected: true,
},
{
name: "hidden directory",
path: "/path/to/.hidden_dir",
ignorePatterns: []string{},
expected: true,
},
{
name: "pycache directory",
path: "/path/to/__pycache__/file.pyc",
ignorePatterns: []string{},
expected: true,
},
{
name: "node_modules directory",
path: "/path/to/node_modules/package",
ignorePatterns: []string{},
expected: false, // The shouldSkip function doesn't directly check for node_modules in the path
},
{
name: "normal file",
path: "/path/to/normal_file.txt",
ignorePatterns: []string{},
expected: false,
},
{
name: "normal directory",
path: "/path/to/normal_dir",
ignorePatterns: []string{},
expected: false,
},
{
name: "ignored by pattern",
path: "/path/to/ignore_me.txt",
ignorePatterns: []string{"ignore_*.txt"},
expected: true,
},
{
name: "not ignored by pattern",
path: "/path/to/keep_me.txt",
ignorePatterns: []string{"ignore_*.txt"},
expected: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := shouldSkip(tc.path, tc.ignorePatterns)
assert.Equal(t, tc.expected, result)
})
}
}
func TestCreateFileTree(t *testing.T) {
paths := []string{
"/path/to/file1.txt",
"/path/to/dir1/file2.txt",
"/path/to/dir1/subdir/file3.txt",
"/path/to/dir2/file4.txt",
}
tree := createFileTree(paths)
// Check the structure of the tree
assert.Len(t, tree, 1) // Should have one root node
// Check the root node
rootNode := tree[0]
assert.Equal(t, "path", rootNode.Name)
assert.Equal(t, "directory", rootNode.Type)
assert.Len(t, rootNode.Children, 1)
// Check the "to" node
toNode := rootNode.Children[0]
assert.Equal(t, "to", toNode.Name)
assert.Equal(t, "directory", toNode.Type)
assert.Len(t, toNode.Children, 3) // file1.txt, dir1, dir2
// Find the dir1 node
var dir1Node *TreeNode
for _, child := range toNode.Children {
if child.Name == "dir1" {
dir1Node = child
break
}
}
require.NotNil(t, dir1Node)
assert.Equal(t, "directory", dir1Node.Type)
assert.Len(t, dir1Node.Children, 2) // file2.txt and subdir
}
func TestPrintTree(t *testing.T) {
// Create a simple tree
tree := []*TreeNode{
{
Name: "dir1",
Path: "dir1",
Type: "directory",
Children: []*TreeNode{
{
Name: "file1.txt",
Path: "dir1/file1.txt",
Type: "file",
},
{
Name: "subdir",
Path: "dir1/subdir",
Type: "directory",
Children: []*TreeNode{
{
Name: "file2.txt",
Path: "dir1/subdir/file2.txt",
Type: "file",
},
},
},
},
},
{
Name: "file3.txt",
Path: "file3.txt",
Type: "file",
},
}
result := printTree(tree, "/root")
// Check the output format
assert.Contains(t, result, "- /root/")
assert.Contains(t, result, " - dir1/")
assert.Contains(t, result, " - file1.txt")
assert.Contains(t, result, " - subdir/")
assert.Contains(t, result, " - file2.txt")
assert.Contains(t, result, " - file3.txt")
}
func TestListDirectory(t *testing.T) {
// Create a temporary directory for testing
tempDir, err := os.MkdirTemp("", "list_directory_test")
require.NoError(t, err)
defer os.RemoveAll(tempDir)
// Create a test directory structure
testDirs := []string{
"dir1",
"dir1/subdir1",
".hidden_dir",
}
testFiles := []string{
"file1.txt",
"file2.txt",
"dir1/file3.txt",
"dir1/subdir1/file4.txt",
".hidden_file.txt",
}
// Create directories
for _, dir := range testDirs {
dirPath := filepath.Join(tempDir, dir)
err := os.MkdirAll(dirPath, 0755)
require.NoError(t, err)
}
// Create files
for _, file := range testFiles {
filePath := filepath.Join(tempDir, file)
err := os.WriteFile(filePath, []byte("test content"), 0644)
require.NoError(t, err)
}
t.Run("lists files with no limit", func(t *testing.T) {
files, truncated, err := listDirectory(tempDir, []string{}, 1000)
require.NoError(t, err)
assert.False(t, truncated)
// Check that visible files and directories are included
containsPath := func(paths []string, target string) bool {
targetPath := filepath.Join(tempDir, target)
for _, path := range paths {
if strings.HasPrefix(path, targetPath) {
return true
}
}
return false
}
assert.True(t, containsPath(files, "dir1"))
assert.True(t, containsPath(files, "file1.txt"))
assert.True(t, containsPath(files, "file2.txt"))
assert.True(t, containsPath(files, "dir1/file3.txt"))
// Check that hidden files and directories are not included
assert.False(t, containsPath(files, ".hidden_dir"))
assert.False(t, containsPath(files, ".hidden_file.txt"))
})
t.Run("respects limit and returns truncated flag", func(t *testing.T) {
files, truncated, err := listDirectory(tempDir, []string{}, 2)
require.NoError(t, err)
assert.True(t, truncated)
assert.Len(t, files, 2)
})
t.Run("respects ignore patterns", func(t *testing.T) {
files, truncated, err := listDirectory(tempDir, []string{"*.txt"}, 1000)
require.NoError(t, err)
assert.False(t, truncated)
// Check that no .txt files are included
for _, file := range files {
assert.False(t, strings.HasSuffix(file, ".txt"), "Found .txt file: %s", file)
}
// But directories should still be included
containsDir := false
for _, file := range files {
if strings.Contains(file, "dir1") {
containsDir = true
break
}
}
assert.True(t, containsDir)
})
}

View File

@@ -116,10 +116,10 @@ func (s *PersistentShell) execCommand(command string, timeout time.Duration, ctx
}
tempDir := os.TempDir()
stdoutFile := filepath.Join(tempDir, fmt.Sprintf("orbitowl-stdout-%d", time.Now().UnixNano()))
stderrFile := filepath.Join(tempDir, fmt.Sprintf("orbitowl-stderr-%d", time.Now().UnixNano()))
statusFile := filepath.Join(tempDir, fmt.Sprintf("orbitowl-status-%d", time.Now().UnixNano()))
cwdFile := filepath.Join(tempDir, fmt.Sprintf("orbitowl-cwd-%d", time.Now().UnixNano()))
stdoutFile := filepath.Join(tempDir, fmt.Sprintf("termai-stdout-%d", time.Now().UnixNano()))
stderrFile := filepath.Join(tempDir, fmt.Sprintf("termai-stderr-%d", time.Now().UnixNano()))
statusFile := filepath.Join(tempDir, fmt.Sprintf("termai-status-%d", time.Now().UnixNano()))
cwdFile := filepath.Join(tempDir, fmt.Sprintf("termai-cwd-%d", time.Now().UnixNano()))
defer func() {
os.Remove(stdoutFile)

View File

@@ -0,0 +1,49 @@
package tools
import "context"
type ToolInfo struct {
Name string
Description string
Parameters map[string]any
Required []string
}
type toolResponseType string
const (
ToolResponseTypeText toolResponseType = "text"
ToolResponseTypeImage toolResponseType = "image"
)
type ToolResponse struct {
Type toolResponseType `json:"type"`
Content string `json:"content"`
IsError bool `json:"is_error"`
}
func NewTextResponse(content string) ToolResponse {
return ToolResponse{
Type: ToolResponseTypeText,
Content: content,
}
}
func NewTextErrorResponse(content string) ToolResponse {
return ToolResponse{
Type: ToolResponseTypeText,
Content: content,
IsError: true,
}
}
type ToolCall struct {
ID string `json:"id"`
Name string `json:"name"`
Input string `json:"input"`
}
type BaseTool interface {
Info() ToolInfo
Run(ctx context.Context, params ToolCall) (ToolResponse, error)
}

View File

@@ -10,77 +10,77 @@ import (
"path/filepath"
"strings"
"github.com/cloudwego/eino/components/tool"
"github.com/cloudwego/eino/schema"
"github.com/kujtimiihoxha/termai/internal/config"
)
type viewTool struct {
workingDir string
}
type viewTool struct{}
const (
ViewToolName = "view"
MaxReadSize = 250 * 1024
ViewToolName = "view"
MaxReadSize = 250 * 1024
DefaultReadLimit = 2000
MaxLineLength = 2000
MaxLineLength = 2000
)
type ViewPatams struct {
type ViewParams struct {
FilePath string `json:"file_path"`
Offset int `json:"offset"`
Limit int `json:"limit"`
}
func (b *viewTool) Info(ctx context.Context) (*schema.ToolInfo, error) {
return &schema.ToolInfo{
Name: ViewToolName,
Desc: `Reads a file from the local filesystem. The file_path parameter must be an absolute path, not a relative path. By default, it reads up to 2000 lines starting from the beginning of the file. You can optionally specify a line offset and limit (especially handy for long files), but it's recommended to read the whole file by not providing these parameters. Any lines longer than 2000 characters will be truncated. For image files, the tool will display the image for you.`,
ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
"file_path": {
Type: "string",
Desc: "The absolute path to the file to read",
Required: true,
func (v *viewTool) Info() ToolInfo {
return ToolInfo{
Name: ViewToolName,
Description: viewDescription(),
Parameters: map[string]any{
"file_path": map[string]any{
"type": "string",
"description": "The path to the file to read",
},
"offset": {
Type: "int",
Desc: "The line number to start reading from. Only provide if the file is too large to read at once",
"offset": map[string]any{
"type": "integer",
"description": "The line number to start reading from (0-based)",
},
"limit": {
Type: "int",
Desc: "The number of lines to read. Only provide if the file is too large to read at once.",
"limit": map[string]any{
"type": "integer",
"description": "The number of lines to read (defaults to 2000)",
},
}),
}, nil
},
Required: []string{"file_path"},
}
}
func (b *viewTool) InvokableRun(ctx context.Context, args string, opts ...tool.Option) (string, error) {
var params ViewPatams
if err := json.Unmarshal([]byte(args), &params); err != nil {
return fmt.Sprintf("failed to parse parameters: %s", err), nil
// Run implements Tool.
func (v *viewTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
var params ViewParams
if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
}
if params.FilePath == "" {
return "file_path is required", nil
return NewTextErrorResponse("file_path is required"), nil
}
if !filepath.IsAbs(params.FilePath) {
return fmt.Sprintf("file path must be absolute, got: %s", params.FilePath), nil
// Handle relative paths
filePath := params.FilePath
if !filepath.IsAbs(filePath) {
filePath = filepath.Join(config.WorkingDirectory(), filePath)
}
fileInfo, err := os.Stat(params.FilePath)
// Check if file exists
fileInfo, err := os.Stat(filePath)
if err != nil {
if os.IsNotExist(err) {
dir := filepath.Dir(params.FilePath)
base := filepath.Base(params.FilePath)
// Try to offer suggestions for similarly named files
dir := filepath.Dir(filePath)
base := filepath.Base(filePath)
dirEntries, dirErr := os.ReadDir(dir)
if dirErr == nil {
var suggestions []string
for _, entry := range dirEntries {
if strings.Contains(entry.Name(), base) || strings.Contains(base, entry.Name()) {
if strings.Contains(strings.ToLower(entry.Name()), strings.ToLower(base)) ||
strings.Contains(strings.ToLower(base), strings.ToLower(entry.Name())) {
suggestions = append(suggestions, filepath.Join(dir, entry.Name()))
if len(suggestions) >= 3 {
break
@@ -89,43 +89,55 @@ func (b *viewTool) InvokableRun(ctx context.Context, args string, opts ...tool.O
}
if len(suggestions) > 0 {
return fmt.Sprintf("file not found: %s. Did you mean one of these?\n%s",
params.FilePath, strings.Join(suggestions, "\n")), nil
return NewTextErrorResponse(fmt.Sprintf("File not found: %s\n\nDid you mean one of these?\n%s",
filePath, strings.Join(suggestions, "\n"))), nil
}
}
return fmt.Sprintf("file not found: %s", params.FilePath), nil
return NewTextErrorResponse(fmt.Sprintf("File not found: %s", filePath)), nil
}
return fmt.Sprintf("failed to access file: %s", err), nil
return NewTextErrorResponse(fmt.Sprintf("Failed to access file: %s", err)), nil
}
// Check if it's a directory
if fileInfo.IsDir() {
return fmt.Sprintf("path is a directory, not a file: %s", params.FilePath), nil
return NewTextErrorResponse(fmt.Sprintf("Path is a directory, not a file: %s", filePath)), nil
}
// Check file size
if fileInfo.Size() > MaxReadSize {
return fmt.Sprintf("file is too large (%d bytes). Maximum size is %d bytes",
fileInfo.Size(), MaxReadSize), nil
return NewTextErrorResponse(fmt.Sprintf("File is too large (%d bytes). Maximum size is %d bytes",
fileInfo.Size(), MaxReadSize)), nil
}
// Set default limit if not provided
if params.Limit <= 0 {
params.Limit = DefaultReadLimit
}
isImage, _ := isImageFile(params.FilePath)
// Check if it's an image file
isImage, imageType := isImageFile(filePath)
if isImage {
// TODO: Implement image reading
return "reading images is not supported", nil
return NewTextErrorResponse(fmt.Sprintf("This is an image file of type: %s\nUse a different tool to process images", imageType)), nil
}
content, _, err := readTextFile(params.FilePath, params.Offset, params.Limit)
// Read the file content
content, lineCount, err := readTextFile(filePath, params.Offset, params.Limit)
if err != nil {
return fmt.Sprintf("failed to read file: %s", err), nil
return NewTextErrorResponse(fmt.Sprintf("Failed to read file: %s", err)), nil
}
recordFileRead(params.FilePath)
// Format the output with line numbers
output := addLineNumbers(content, params.Offset+1)
return addLineNumbers(content, params.Offset+1), nil
// Add a note if the content was truncated
if lineCount > params.Offset+len(strings.Split(content, "\n")) {
output += fmt.Sprintf("\n\n(File has more lines. Use 'offset' parameter to read beyond line %d)",
params.Offset+len(strings.Split(content, "\n")))
}
recordFileRead(filePath)
return NewTextResponse(output), nil
}
func addLineNumbers(content string, startLine int) string {
@@ -191,6 +203,11 @@ func readTextFile(filePath string, offset, limit int) (string, int, error) {
lines = append(lines, lineText)
}
// Continue scanning to get total line count
for scanner.Scan() {
lineCount++
}
if err := scanner.Err(); err != nil {
return "", 0, err
}
@@ -202,17 +219,17 @@ func isImageFile(filePath string) (bool, string) {
ext := strings.ToLower(filepath.Ext(filePath))
switch ext {
case ".jpg", ".jpeg":
return true, "jpeg"
return true, "JPEG"
case ".png":
return true, "png"
return true, "PNG"
case ".gif":
return true, "gif"
return true, "GIF"
case ".bmp":
return true, "bmp"
return true, "BMP"
case ".svg":
return true, "svg"
return true, "SVG"
case ".webp":
return true, "webp"
return true, "WebP"
default:
return false, ""
}
@@ -240,8 +257,39 @@ func (s *LineScanner) Err() error {
return s.scanner.Err()
}
func NewViewTool(workingDir string) tool.InvokableTool {
return &viewTool{
workingDir,
}
func viewDescription() string {
return `File viewing tool that reads and displays the contents of files with line numbers, allowing you to examine code, logs, or text data.
WHEN TO USE THIS TOOL:
- Use when you need to read the contents of a specific file
- Helpful for examining source code, configuration files, or log files
- Perfect for looking at text-based file formats
HOW TO USE:
- Provide the path to the file you want to view
- Optionally specify an offset to start reading from a specific line
- Optionally specify a limit to control how many lines are read
FEATURES:
- Displays file contents with line numbers for easy reference
- Can read from any position in a file using the offset parameter
- Handles large files by limiting the number of lines read
- Automatically truncates very long lines for better display
- Suggests similar file names when the requested file isn't found
LIMITATIONS:
- Maximum file size is 250KB
- Default reading limit is 2000 lines
- Lines longer than 2000 characters are truncated
- Cannot display binary files or images
- Images can be identified but not displayed
TIPS:
- Use with Glob tool to first find files you want to view
- For code exploration, first use Grep to find relevant files, then View to examine them
- When viewing large files, use the offset parameter to read specific sections`
}
func NewViewTool() BaseTool {
return &viewTool{}
}

View File

@@ -6,17 +6,13 @@ import (
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/cloudwego/eino/components/tool"
"github.com/cloudwego/eino/schema"
"github.com/kujtimiihoxha/termai/internal/config"
"github.com/kujtimiihoxha/termai/internal/permission"
)
type writeTool struct {
workingDir string
}
type writeTool struct{}
const (
WriteToolName = "write"
@@ -27,139 +23,139 @@ type WriteParams struct {
Content string `json:"content"`
}
func (b *writeTool) Info(ctx context.Context) (*schema.ToolInfo, error) {
return &schema.ToolInfo{
Name: WriteToolName,
Desc: "Write a file to the local filesystem. Overwrites the existing file if there is one.\n\nBefore using this tool:\n\n1. Use the ReadFile tool to understand the file's contents and context\n\n2. Directory Verification (only applicable when creating new files):\n - Use the LS tool to verify the parent directory exists and is the correct location",
ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
"file_path": {
Type: "string",
Desc: "The absolute path to the file to write (must be absolute, not relative)",
Required: true,
},
"content": {
Type: "string",
Desc: "The content to write to the file",
Required: true,
},
}),
}, nil
type WritePermissionsParams struct {
FilePath string `json:"file_path"`
Content string `json:"content"`
}
func (b *writeTool) InvokableRun(ctx context.Context, args string, opts ...tool.Option) (string, error) {
func (w *writeTool) Info() ToolInfo {
return ToolInfo{
Name: WriteToolName,
Description: writeDescription(),
Parameters: map[string]any{
"file_path": map[string]any{
"type": "string",
"description": "The path to the file to write",
},
"content": map[string]any{
"type": "string",
"description": "The content to write to the file",
},
},
Required: []string{"file_path", "content"},
}
}
// Run implements Tool.
func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
var params WriteParams
if err := json.Unmarshal([]byte(args), &params); err != nil {
return "", fmt.Errorf("failed to parse parameters: %w", err)
if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
}
if params.FilePath == "" {
return "file_path is required", nil
return NewTextErrorResponse("file_path is required"), nil
}
if !filepath.IsAbs(params.FilePath) {
return fmt.Sprintf("file path must be absolute, got: %s", params.FilePath), nil
if params.Content == "" {
return NewTextErrorResponse("content is required"), nil
}
// fileExists := false
// oldContent := ""
fileInfo, err := os.Stat(params.FilePath)
// Handle relative paths
filePath := params.FilePath
if !filepath.IsAbs(filePath) {
filePath = filepath.Join(config.WorkingDirectory(), filePath)
}
// Check if file exists and is a directory
fileInfo, err := os.Stat(filePath)
if err == nil {
if fileInfo.IsDir() {
return fmt.Sprintf("path is a directory, not a file: %s", params.FilePath), nil
return NewTextErrorResponse(fmt.Sprintf("Path is a directory, not a file: %s", filePath)), nil
}
// Check if file was modified since last read
modTime := fileInfo.ModTime()
lastRead := getLastReadTime(params.FilePath)
lastRead := getLastReadTime(filePath)
if modTime.After(lastRead) {
return fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)",
params.FilePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339)), nil
return NewTextErrorResponse(fmt.Sprintf("File %s has been modified since it was last read.\nLast modification: %s\nLast read: %s\n\nPlease read the file again before modifying it.",
filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339))), nil
}
// oldContentBytes, readErr := os.ReadFile(params.FilePath)
// if readErr != nil {
// oldContent = string(oldContentBytes)
// }
// Optional: Get old content for diff
oldContent, readErr := os.ReadFile(filePath)
if readErr == nil && string(oldContent) == params.Content {
return NewTextErrorResponse(fmt.Sprintf("File %s already contains the exact content. No changes made.", filePath)), nil
}
} else if !os.IsNotExist(err) {
return fmt.Sprintf("failed to access file: %s", err), nil
return NewTextErrorResponse(fmt.Sprintf("Failed to access file: %s", err)), nil
}
// Create parent directories if needed
dir := filepath.Dir(filePath)
if err = os.MkdirAll(dir, 0o755); err != nil {
return NewTextErrorResponse(fmt.Sprintf("Failed to create parent directories: %s", err)), nil
}
p := permission.Default.Request(
permission.CreatePermissionRequest{
Path: b.workingDir,
Path: filePath,
ToolName: WriteToolName,
Action: "write",
Description: fmt.Sprintf("Write to file %s", params.FilePath),
Params: map[string]interface{}{
"file_path": params.FilePath,
"contnet": params.Content,
Action: "create",
Description: fmt.Sprintf("Create file %s", filePath),
Params: WritePermissionsParams{
FilePath: filePath,
Content: GenerateDiff("", params.Content),
},
},
)
if !p {
return "", fmt.Errorf("permission denied")
}
dir := filepath.Dir(params.FilePath)
if err = os.MkdirAll(dir, 0o755); err != nil {
return fmt.Sprintf("failed to create parent directories: %s", err), nil
return NewTextErrorResponse(fmt.Sprintf("Permission denied to create file: %s", filePath)), nil
}
err = os.WriteFile(params.FilePath, []byte(params.Content), 0o644)
// Write the file
err = os.WriteFile(filePath, []byte(params.Content), 0o644)
if err != nil {
return fmt.Sprintf("failed to write file: %s", err), nil
return NewTextErrorResponse(fmt.Sprintf("Failed to write file: %s", err)), nil
}
recordFileWrite(params.FilePath)
// Record the file write
recordFileWrite(filePath)
recordFileRead(filePath)
output := "File written: " + params.FilePath
// if fileExists && oldContent != params.Content {
// output = generateSimpleDiff(oldContent, params.Content)
// }
return output, nil
return NewTextResponse(fmt.Sprintf("File successfully written: %s", filePath)), nil
}
func generateSimpleDiff(oldContent, newContent string) string {
if oldContent == newContent {
return "[No changes]"
}
func writeDescription() string {
return `File writing tool that creates or updates files in the filesystem, allowing you to save or modify text content.
oldLines := strings.Split(oldContent, "\n")
newLines := strings.Split(newContent, "\n")
WHEN TO USE THIS TOOL:
- Use when you need to create a new file
- Helpful for updating existing files with modified content
- Perfect for saving generated code, configurations, or text data
var diffBuilder strings.Builder
diffBuilder.WriteString(fmt.Sprintf("@@ -%d,+%d @@\n", len(oldLines), len(newLines)))
HOW TO USE:
- Provide the path to the file you want to write
- Include the content to be written to the file
- The tool will create any necessary parent directories
maxLines := max(len(oldLines), len(newLines))
for i := range maxLines {
oldLine := ""
newLine := ""
FEATURES:
- Can create new files or overwrite existing ones
- Creates parent directories automatically if they don't exist
- Checks if the file has been modified since last read for safety
- Avoids unnecessary writes when content hasn't changed
if i < len(oldLines) {
oldLine = oldLines[i]
}
LIMITATIONS:
- You should read a file before writing to it to avoid conflicts
- Cannot append to files (rewrites the entire file)
if i < len(newLines) {
newLine = newLines[i]
}
if oldLine != newLine {
if i < len(oldLines) {
diffBuilder.WriteString(fmt.Sprintf("- %s\n", oldLine))
}
if i < len(newLines) {
diffBuilder.WriteString(fmt.Sprintf("+ %s\n", newLine))
}
} else {
diffBuilder.WriteString(fmt.Sprintf(" %s\n", oldLine))
}
}
return diffBuilder.String()
TIPS:
- Use the View tool first to examine existing files before modifying them
- Use the LS tool to verify the correct location when creating new files
- Combine with Glob and Grep tools to find and modify multiple files
- Always include descriptive comments when making changes to existing code`
}
func NewWriteTool(workingDir string) tool.InvokableTool {
return &writeTool{
workingDir: workingDir,
}
func NewWriteTool() BaseTool {
return &writeTool{}
}

View File

@@ -0,0 +1,324 @@
package tools
import (
"context"
"encoding/json"
"os"
"path/filepath"
"testing"
"time"
"github.com/kujtimiihoxha/termai/internal/permission"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestWriteTool_Info(t *testing.T) {
tool := NewWriteTool()
info := tool.Info()
assert.Equal(t, WriteToolName, info.Name)
assert.NotEmpty(t, info.Description)
assert.Contains(t, info.Parameters, "file_path")
assert.Contains(t, info.Parameters, "content")
assert.Contains(t, info.Required, "file_path")
assert.Contains(t, info.Required, "content")
}
func TestWriteTool_Run(t *testing.T) {
// Setup a mock permission handler that always allows
origPermission := permission.Default
defer func() {
permission.Default = origPermission
}()
permission.Default = newMockPermissionService(true)
// Create a temporary directory for testing
tempDir, err := os.MkdirTemp("", "write_tool_test")
require.NoError(t, err)
defer os.RemoveAll(tempDir)
t.Run("creates a new file successfully", func(t *testing.T) {
permission.Default = newMockPermissionService(true)
tool := NewWriteTool()
filePath := filepath.Join(tempDir, "new_file.txt")
content := "This is a test content"
params := WriteParams{
FilePath: filePath,
Content: content,
}
paramsJSON, err := json.Marshal(params)
require.NoError(t, err)
call := ToolCall{
Name: WriteToolName,
Input: string(paramsJSON),
}
response, err := tool.Run(context.Background(), call)
require.NoError(t, err)
assert.Contains(t, response.Content, "successfully written")
// Verify file was created with correct content
fileContent, err := os.ReadFile(filePath)
require.NoError(t, err)
assert.Equal(t, content, string(fileContent))
})
t.Run("creates file with nested directories", func(t *testing.T) {
permission.Default = newMockPermissionService(true)
tool := NewWriteTool()
filePath := filepath.Join(tempDir, "nested/dirs/new_file.txt")
content := "Content in nested directory"
params := WriteParams{
FilePath: filePath,
Content: content,
}
paramsJSON, err := json.Marshal(params)
require.NoError(t, err)
call := ToolCall{
Name: WriteToolName,
Input: string(paramsJSON),
}
response, err := tool.Run(context.Background(), call)
require.NoError(t, err)
assert.Contains(t, response.Content, "successfully written")
// Verify file was created with correct content
fileContent, err := os.ReadFile(filePath)
require.NoError(t, err)
assert.Equal(t, content, string(fileContent))
})
t.Run("updates existing file", func(t *testing.T) {
permission.Default = newMockPermissionService(true)
tool := NewWriteTool()
// Create a file first
filePath := filepath.Join(tempDir, "existing_file.txt")
initialContent := "Initial content"
err := os.WriteFile(filePath, []byte(initialContent), 0644)
require.NoError(t, err)
// Record the file read to avoid modification time check failure
recordFileRead(filePath)
// Update the file
updatedContent := "Updated content"
params := WriteParams{
FilePath: filePath,
Content: updatedContent,
}
paramsJSON, err := json.Marshal(params)
require.NoError(t, err)
call := ToolCall{
Name: WriteToolName,
Input: string(paramsJSON),
}
response, err := tool.Run(context.Background(), call)
require.NoError(t, err)
assert.Contains(t, response.Content, "successfully written")
// Verify file was updated with correct content
fileContent, err := os.ReadFile(filePath)
require.NoError(t, err)
assert.Equal(t, updatedContent, string(fileContent))
})
t.Run("handles invalid parameters", func(t *testing.T) {
permission.Default = newMockPermissionService(true)
tool := NewWriteTool()
call := ToolCall{
Name: WriteToolName,
Input: "invalid json",
}
response, err := tool.Run(context.Background(), call)
require.NoError(t, err)
assert.Contains(t, response.Content, "error parsing parameters")
})
t.Run("handles missing file_path", func(t *testing.T) {
permission.Default = newMockPermissionService(true)
tool := NewWriteTool()
params := WriteParams{
FilePath: "",
Content: "Some content",
}
paramsJSON, err := json.Marshal(params)
require.NoError(t, err)
call := ToolCall{
Name: WriteToolName,
Input: string(paramsJSON),
}
response, err := tool.Run(context.Background(), call)
require.NoError(t, err)
assert.Contains(t, response.Content, "file_path is required")
})
t.Run("handles missing content", func(t *testing.T) {
permission.Default = newMockPermissionService(true)
tool := NewWriteTool()
params := WriteParams{
FilePath: filepath.Join(tempDir, "file.txt"),
Content: "",
}
paramsJSON, err := json.Marshal(params)
require.NoError(t, err)
call := ToolCall{
Name: WriteToolName,
Input: string(paramsJSON),
}
response, err := tool.Run(context.Background(), call)
require.NoError(t, err)
assert.Contains(t, response.Content, "content is required")
})
t.Run("handles writing to a directory path", func(t *testing.T) {
permission.Default = newMockPermissionService(true)
tool := NewWriteTool()
// Create a directory
dirPath := filepath.Join(tempDir, "test_dir")
err := os.Mkdir(dirPath, 0755)
require.NoError(t, err)
params := WriteParams{
FilePath: dirPath,
Content: "Some content",
}
paramsJSON, err := json.Marshal(params)
require.NoError(t, err)
call := ToolCall{
Name: WriteToolName,
Input: string(paramsJSON),
}
response, err := tool.Run(context.Background(), call)
require.NoError(t, err)
assert.Contains(t, response.Content, "Path is a directory")
})
t.Run("handles permission denied", func(t *testing.T) {
permission.Default = newMockPermissionService(false)
tool := NewWriteTool()
filePath := filepath.Join(tempDir, "permission_denied.txt")
params := WriteParams{
FilePath: filePath,
Content: "Content that should not be written",
}
paramsJSON, err := json.Marshal(params)
require.NoError(t, err)
call := ToolCall{
Name: WriteToolName,
Input: string(paramsJSON),
}
response, err := tool.Run(context.Background(), call)
require.NoError(t, err)
assert.Contains(t, response.Content, "Permission denied")
// Verify file was not created
_, err = os.Stat(filePath)
assert.True(t, os.IsNotExist(err))
})
t.Run("detects file modified since last read", func(t *testing.T) {
permission.Default = newMockPermissionService(true)
tool := NewWriteTool()
// Create a file
filePath := filepath.Join(tempDir, "modified_file.txt")
initialContent := "Initial content"
err := os.WriteFile(filePath, []byte(initialContent), 0644)
require.NoError(t, err)
// Record an old read time
fileRecordMutex.Lock()
fileRecords[filePath] = fileRecord{
path: filePath,
readTime: time.Now().Add(-1 * time.Hour),
}
fileRecordMutex.Unlock()
// Try to update the file
params := WriteParams{
FilePath: filePath,
Content: "Updated content",
}
paramsJSON, err := json.Marshal(params)
require.NoError(t, err)
call := ToolCall{
Name: WriteToolName,
Input: string(paramsJSON),
}
response, err := tool.Run(context.Background(), call)
require.NoError(t, err)
assert.Contains(t, response.Content, "has been modified since it was last read")
// Verify file was not modified
fileContent, err := os.ReadFile(filePath)
require.NoError(t, err)
assert.Equal(t, initialContent, string(fileContent))
})
t.Run("skips writing when content is identical", func(t *testing.T) {
permission.Default = newMockPermissionService(true)
tool := NewWriteTool()
// Create a file
filePath := filepath.Join(tempDir, "identical_content.txt")
content := "Content that won't change"
err := os.WriteFile(filePath, []byte(content), 0644)
require.NoError(t, err)
// Record a read time
recordFileRead(filePath)
// Try to write the same content
params := WriteParams{
FilePath: filePath,
Content: content,
}
paramsJSON, err := json.Marshal(params)
require.NoError(t, err)
call := ToolCall{
Name: WriteToolName,
Input: string(paramsJSON),
}
response, err := tool.Run(context.Background(), call)
require.NoError(t, err)
assert.Contains(t, response.Content, "already contains the exact content")
})
}

View File

@@ -2,26 +2,65 @@ package message
import (
"context"
"database/sql"
"encoding/json"
"github.com/cloudwego/eino/schema"
"github.com/google/uuid"
"github.com/kujtimiihoxha/termai/internal/db"
"github.com/kujtimiihoxha/termai/internal/pubsub"
)
type Message struct {
ID string
SessionID string
MessageData schema.Message
type MessageRole string
CreatedAt int64
UpdatedAt int64
const (
Assistant MessageRole = "assistant"
User MessageRole = "user"
System MessageRole = "system"
Tool MessageRole = "tool"
)
type ToolResult struct {
ToolCallID string
Content string
IsError bool
// TODO: support for images
}
type ToolCall struct {
ID string
Name string
Input string
Type string
}
type Message struct {
ID string
SessionID string
// NEW
Role MessageRole
Content string
Thinking string
Finished bool
ToolResults []ToolResult
ToolCalls []ToolCall
CreatedAt int64
UpdatedAt int64
}
type CreateMessageParams struct {
Role MessageRole
Content string
ToolCalls []ToolCall
ToolResults []ToolResult
}
type Service interface {
pubsub.Suscriber[Message]
Create(sessionID string, messageData schema.Message) (Message, error)
Create(sessionID string, params CreateMessageParams) (Message, error)
Update(message Message) error
Get(id string) (Message, error)
List(sessionID string) ([]Message, error)
Delete(id string) error
@@ -34,24 +73,6 @@ type service struct {
ctx context.Context
}
func (s *service) Create(sessionID string, messageData schema.Message) (Message, error) {
messageDataJSON, err := json.Marshal(messageData)
if err != nil {
return Message{}, err
}
dbMessage, err := s.q.CreateMessage(s.ctx, db.CreateMessageParams{
ID: uuid.New().String(),
SessionID: sessionID,
MessageData: string(messageDataJSON),
})
if err != nil {
return Message{}, err
}
message := s.fromDBItem(dbMessage)
s.Publish(pubsub.CreatedEvent, message)
return message, nil
}
func (s *service) Delete(id string) error {
message, err := s.Get(id)
if err != nil {
@@ -65,6 +86,35 @@ func (s *service) Delete(id string) error {
return nil
}
func (s *service) Create(sessionID string, params CreateMessageParams) (Message, error) {
toolCallsStr, err := json.Marshal(params.ToolCalls)
if err != nil {
return Message{}, err
}
toolResultsStr, err := json.Marshal(params.ToolResults)
if err != nil {
return Message{}, err
}
dbMessage, err := s.q.CreateMessage(s.ctx, db.CreateMessageParams{
ID: uuid.New().String(),
SessionID: sessionID,
Role: string(params.Role),
Finished: params.Role != Assistant,
Content: params.Content,
ToolCalls: sql.NullString{String: string(toolCallsStr), Valid: true},
ToolResults: sql.NullString{String: string(toolResultsStr), Valid: true},
})
if err != nil {
return Message{}, err
}
message, err := s.fromDBItem(dbMessage)
if err != nil {
return Message{}, err
}
s.Publish(pubsub.CreatedEvent, message)
return message, nil
}
func (s *service) DeleteSessionMessages(sessionID string) error {
messages, err := s.List(sessionID)
if err != nil {
@@ -81,12 +131,36 @@ func (s *service) DeleteSessionMessages(sessionID string) error {
return nil
}
func (s *service) Update(message Message) error {
toolCallsStr, err := json.Marshal(message.ToolCalls)
if err != nil {
return err
}
toolResultsStr, err := json.Marshal(message.ToolResults)
if err != nil {
return err
}
err = s.q.UpdateMessage(s.ctx, db.UpdateMessageParams{
ID: message.ID,
Content: message.Content,
Thinking: message.Thinking,
Finished: message.Finished,
ToolCalls: sql.NullString{String: string(toolCallsStr), Valid: true},
ToolResults: sql.NullString{String: string(toolResultsStr), Valid: true},
})
if err != nil {
return err
}
s.Publish(pubsub.UpdatedEvent, message)
return nil
}
func (s *service) Get(id string) (Message, error) {
dbMessage, err := s.q.GetMessage(s.ctx, id)
if err != nil {
return Message{}, err
}
return s.fromDBItem(dbMessage), nil
return s.fromDBItem(dbMessage)
}
func (s *service) List(sessionID string) ([]Message, error) {
@@ -96,21 +170,43 @@ func (s *service) List(sessionID string) ([]Message, error) {
}
messages := make([]Message, len(dbMessages))
for i, dbMessage := range dbMessages {
messages[i] = s.fromDBItem(dbMessage)
messages[i], err = s.fromDBItem(dbMessage)
if err != nil {
return nil, err
}
}
return messages, nil
}
func (s *service) fromDBItem(item db.Message) Message {
var messageData schema.Message
json.Unmarshal([]byte(item.MessageData), &messageData)
func (s *service) fromDBItem(item db.Message) (Message, error) {
toolCalls := make([]ToolCall, 0)
if item.ToolCalls.Valid {
err := json.Unmarshal([]byte(item.ToolCalls.String), &toolCalls)
if err != nil {
return Message{}, err
}
}
toolResults := make([]ToolResult, 0)
if item.ToolResults.Valid {
err := json.Unmarshal([]byte(item.ToolResults.String), &toolResults)
if err != nil {
return Message{}, err
}
}
return Message{
ID: item.ID,
SessionID: item.SessionID,
MessageData: messageData,
Role: MessageRole(item.Role),
Content: item.Content,
Thinking: item.Thinking,
Finished: item.Finished,
ToolCalls: toolCalls,
ToolResults: toolResults,
CreatedAt: item.CreatedAt,
UpdatedAt: item.UpdatedAt,
}
}, nil
}
func NewService(ctx context.Context, q db.Querier) Service {

View File

@@ -65,6 +65,7 @@ func (s *permissionService) Deny(permission PermissionRequest) {
func (s *permissionService) Request(opts CreatePermissionRequest) bool {
permission := PermissionRequest{
ID: uuid.New().String(),
Path: opts.Path,
ToolName: opts.ToolName,
Description: opts.Description,
Action: opts.Action,

View File

@@ -2,6 +2,7 @@ package session
import (
"context"
"database/sql"
"github.com/google/uuid"
"github.com/kujtimiihoxha/termai/internal/db"
@@ -10,6 +11,7 @@ import (
type Session struct {
ID string
ParentSessionID string
Title string
MessageCount int64
PromptTokens int64
@@ -22,6 +24,7 @@ type Session struct {
type Service interface {
pubsub.Suscriber[Session]
Create(title string) (Session, error)
CreateTaskSession(toolCallID, parentSessionID, title string) (Session, error)
Get(id string) (Session, error)
List() ([]Session, error)
Save(session Session) (Session, error)
@@ -47,6 +50,20 @@ func (s *service) Create(title string) (Session, error) {
return session, nil
}
func (s *service) CreateTaskSession(toolCallID, parentSessionID, title string) (Session, error) {
dbSession, err := s.q.CreateSession(s.ctx, db.CreateSessionParams{
ID: toolCallID,
ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
Title: title,
})
if err != nil {
return Session{}, err
}
session := s.fromDBItem(dbSession)
s.Publish(pubsub.CreatedEvent, session)
return session, nil
}
func (s *service) Delete(id string) error {
session, err := s.Get(id)
if err != nil {
@@ -99,6 +116,7 @@ func (s *service) List() ([]Session, error) {
func (s service) fromDBItem(item db.Session) Session {
return Session{
ID: item.ID,
ParentSessionID: item.ParentSessionID.String,
Title: item.Title,
MessageCount: item.MessageCount,
PromptTokens: item.PromptTokens,

View File

@@ -14,7 +14,12 @@ type SizeableModel interface {
}
type DialogMsg struct {
Content SizeableModel
Content SizeableModel
WidthRatio float64
HeightRatio float64
MinWidth int
MinHeight int
}
type DialogCloseMsg struct{}
@@ -36,7 +41,18 @@ type DialogCmp interface {
}
type dialogCmp struct {
content SizeableModel
content SizeableModel
screenWidth int
screenHeight int
widthRatio float64
heightRatio float64
minWidth int
minHeight int
width int
height int
}
func (d *dialogCmp) Init() tea.Cmd {
@@ -45,8 +61,26 @@ func (d *dialogCmp) Init() tea.Cmd {
func (d *dialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case tea.WindowSizeMsg:
d.screenWidth = msg.Width
d.screenHeight = msg.Height
d.width = max(int(float64(d.screenWidth)*d.widthRatio), d.minWidth)
d.height = max(int(float64(d.screenHeight)*d.heightRatio), d.minHeight)
if d.content != nil {
d.content.SetSize(d.width, d.height)
}
return d, nil
case DialogMsg:
d.content = msg.Content
d.widthRatio = msg.WidthRatio
d.heightRatio = msg.HeightRatio
d.minWidth = msg.MinWidth
d.minHeight = msg.MinHeight
d.width = max(int(float64(d.screenWidth)*d.widthRatio), d.minWidth)
d.height = max(int(float64(d.screenHeight)*d.heightRatio), d.minHeight)
if d.content != nil {
d.content.SetSize(d.width, d.height)
}
case DialogCloseMsg:
d.content = nil
return d, nil
@@ -75,8 +109,7 @@ func (d *dialogCmp) BindingKeys() []key.Binding {
}
func (d *dialogCmp) View() string {
w, h := d.content.GetSize()
return lipgloss.NewStyle().Width(w).Height(h).Render(d.content.View())
return lipgloss.NewStyle().Width(d.width).Height(d.height).Render(d.content.View())
}
func NewDialogCmp() DialogCmp {

View File

@@ -3,6 +3,8 @@ package core
import (
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/kujtimiihoxha/termai/internal/config"
"github.com/kujtimiihoxha/termai/internal/llm/models"
"github.com/kujtimiihoxha/termai/internal/tui/styles"
"github.com/kujtimiihoxha/termai/internal/tui/util"
"github.com/kujtimiihoxha/termai/internal/version"
@@ -57,14 +59,19 @@ func (m statusCmp) View() string {
Width(m.availableFooterMsgWidth()).
Render(m.info)
}
status += m.model()
status += versionWidget
return status
}
func (m statusCmp) availableFooterMsgWidth() int {
// -2 to accommodate padding
return max(0, m.width-lipgloss.Width(helpWidget)-lipgloss.Width(versionWidget))
return max(0, m.width-lipgloss.Width(helpWidget)-lipgloss.Width(versionWidget)-lipgloss.Width(m.model()))
}
func (m statusCmp) model() string {
model := models.SupportedModels[config.Get().Model.Coder]
return styles.Padded.Background(styles.Grey).Foreground(styles.Text).Render(model.Name)
}
func NewStatusCmp() tea.Model {

View File

@@ -1,9 +1,14 @@
package dialog
import (
"fmt"
"github.com/charmbracelet/bubbles/key"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/glamour"
"github.com/charmbracelet/lipgloss"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
"github.com/kujtimiihoxha/termai/internal/permission"
"github.com/kujtimiihoxha/termai/internal/tui/components/core"
"github.com/kujtimiihoxha/termai/internal/tui/layout"
@@ -28,12 +33,6 @@ type PermissionResponseMsg struct {
Action PermissionAction
}
// Width and height constants for the dialog
var (
permissionWidth = 60
permissionHeight = 10
)
// PermissionDialog interface for permission dialog component
type PermissionDialog interface {
tea.Model
@@ -41,13 +40,28 @@ type PermissionDialog interface {
layout.Bindings
}
type keyMap struct {
ChangeFocus key.Binding
}
var keyMapValue = keyMap{
ChangeFocus: key.NewBinding(
key.WithKeys("tab"),
key.WithHelp("tab", "change focus"),
),
}
// permissionDialogCmp is the implementation of PermissionDialog
type permissionDialogCmp struct {
form *huh.Form
content string
width int
height int
permission permission.PermissionRequest
form *huh.Form
width int
height int
permission permission.PermissionRequest
windowSize tea.WindowSizeMsg
r *glamour.TermRenderer
contentViewPort viewport.Model
isViewportFocus bool
selectOption *huh.Select[string]
}
func (p *permissionDialogCmp) Init() tea.Cmd {
@@ -57,41 +71,101 @@ func (p *permissionDialogCmp) Init() tea.Cmd {
func (p *permissionDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmds []tea.Cmd
// Process the form
form, cmd := p.form.Update(msg)
if f, ok := form.(*huh.Form); ok {
p.form = f
switch msg := msg.(type) {
case tea.WindowSizeMsg:
p.windowSize = msg
case tea.KeyMsg:
if key.Matches(msg, keyMapValue.ChangeFocus) {
p.isViewportFocus = !p.isViewportFocus
if p.isViewportFocus {
p.selectOption.Blur()
} else {
p.selectOption.Focus()
}
return p, nil
}
}
if p.isViewportFocus {
viewPort, cmd := p.contentViewPort.Update(msg)
p.contentViewPort = viewPort
cmds = append(cmds, cmd)
} else {
form, cmd := p.form.Update(msg)
if f, ok := form.(*huh.Form); ok {
p.form = f
cmds = append(cmds, cmd)
}
if p.form.State == huh.StateCompleted {
// Get the selected action
action := p.form.GetString("action")
// Close the dialog and return the response
return p, tea.Batch(
util.CmdHandler(core.DialogCloseMsg{}),
util.CmdHandler(PermissionResponseMsg{Action: PermissionAction(action), Permission: p.permission}),
)
}
}
if p.form.State == huh.StateCompleted {
// Get the selected action
action := p.form.GetString("action")
// Close the dialog and return the response
return p, tea.Batch(
util.CmdHandler(core.DialogCloseMsg{}),
util.CmdHandler(PermissionResponseMsg{Action: PermissionAction(action), Permission: p.permission}),
)
}
return p, tea.Batch(cmds...)
}
func (p *permissionDialogCmp) View() string {
contentStyle := lipgloss.NewStyle().
Width(p.width).
Padding(1, 0).
Foreground(styles.Text).
Align(lipgloss.Center)
func (p *permissionDialogCmp) render() string {
form := p.form.View()
keyStyle := lipgloss.NewStyle().Bold(true).Foreground(styles.Rosewater)
valueStyle := lipgloss.NewStyle().Foreground(styles.Peach)
headerParts := []string{
lipgloss.JoinHorizontal(lipgloss.Left, keyStyle.Render("Tool:"), " ", valueStyle.Render(p.permission.ToolName)),
" ",
lipgloss.JoinHorizontal(lipgloss.Left, keyStyle.Render("Path:"), " ", valueStyle.Render(p.permission.Path)),
" ",
}
r, _ := glamour.NewTermRenderer(
glamour.WithStyles(styles.CatppuccinMarkdownStyle()),
glamour.WithWordWrap(p.width-10),
glamour.WithEmoji(),
)
content := ""
switch p.permission.ToolName {
case tools.BashToolName:
pr := p.permission.Params.(tools.BashPermissionsParams)
headerParts = append(headerParts, keyStyle.Render("Command:"))
content, _ = r.Render(fmt.Sprintf("```bash\n%s\n```", pr.Command))
case tools.EditToolName:
pr := p.permission.Params.(tools.EditPermissionsParams)
headerParts = append(headerParts, keyStyle.Render("Update:"))
content, _ = r.Render(fmt.Sprintf("```diff\n%s\n```", pr.Diff))
case tools.WriteToolName:
pr := p.permission.Params.(tools.WritePermissionsParams)
headerParts = append(headerParts, keyStyle.Render("Content:"))
content, _ = r.Render(fmt.Sprintf("```diff\n%s\n```", pr.Content))
default:
content, _ = r.Render(p.permission.Description)
}
headerContent := lipgloss.NewStyle().Padding(0, 1).Render(lipgloss.JoinVertical(lipgloss.Left, headerParts...))
p.contentViewPort.Width = p.width - 2 - 2
p.contentViewPort.Height = p.height - lipgloss.Height(headerContent) - lipgloss.Height(form) - 2 - 2 - 1
p.contentViewPort.SetContent(content)
contentBorder := lipgloss.RoundedBorder()
if p.isViewportFocus {
contentBorder = lipgloss.DoubleBorder()
}
cotentStyle := lipgloss.NewStyle().MarginTop(1).Padding(0, 1).Border(contentBorder).BorderForeground(styles.Flamingo)
return lipgloss.JoinVertical(
lipgloss.Center,
contentStyle.Render(p.content),
p.form.View(),
lipgloss.Top,
headerContent,
cotentStyle.Render(p.contentViewPort.View()),
form,
)
}
func (p *permissionDialogCmp) View() string {
return p.render()
}
func (p *permissionDialogCmp) GetSize() (int, int) {
return p.width, p.height
}
@@ -99,13 +173,14 @@ func (p *permissionDialogCmp) GetSize() (int, int) {
func (p *permissionDialogCmp) SetSize(width int, height int) {
p.width = width
p.height = height
p.form = p.form.WithWidth(width)
}
func (p *permissionDialogCmp) BindingKeys() []key.Binding {
return p.form.KeyBinds()
}
func newPermissionDialogCmp(permission permission.PermissionRequest, content string) PermissionDialog {
func newPermissionDialogCmp(permission permission.PermissionRequest) PermissionDialog {
// Create a note field for displaying the content
// Create select field for the permission options
@@ -116,14 +191,13 @@ func newPermissionDialogCmp(permission permission.PermissionRequest, content str
huh.NewOption("Allow for this session", string(PermissionAllowForSession)),
huh.NewOption("Deny", string(PermissionDeny)),
).
Title("Permission Request")
Title("Select an action")
// Apply theme
theme := styles.HuhTheme()
// Setup form width and height
form := huh.NewForm(huh.NewGroup(selectOption)).
WithWidth(permissionWidth - 2).
WithShowHelp(false).
WithTheme(theme).
WithShowErrors(false)
@@ -132,25 +206,22 @@ func newPermissionDialogCmp(permission permission.PermissionRequest, content str
selectOption.Focus()
return &permissionDialogCmp{
permission: permission,
form: form,
content: content,
width: permissionWidth,
height: permissionHeight,
permission: permission,
form: form,
selectOption: selectOption,
}
}
// NewPermissionDialogCmd creates a new permission dialog command
func NewPermissionDialogCmd(permission permission.PermissionRequest, content string) tea.Cmd {
permDialog := newPermissionDialogCmp(permission, content)
func NewPermissionDialogCmd(permission permission.PermissionRequest) tea.Cmd {
permDialog := newPermissionDialogCmp(permission)
// Create the dialog layout
dialogPane := layout.NewSinglePane(
permDialog.(*permissionDialogCmp),
layout.WithSignlePaneSize(permissionWidth+2, permissionHeight+2),
layout.WithSinglePaneBordered(true),
layout.WithSinglePaneFocusable(true),
layout.WithSinglePaneActiveColor(styles.Blue),
layout.WithSinglePaneActiveColor(styles.Warning),
layout.WithSignlePaneBorderText(map[layout.BorderPosition]string{
layout.TopMiddleBorder: " Permission Required ",
}),
@@ -158,10 +229,24 @@ func NewPermissionDialogCmd(permission permission.PermissionRequest, content str
// Focus the dialog
dialogPane.Focus()
widthRatio := 0.7
heightRatio := 0.6
minWidth := 100
minHeight := 30
switch permission.ToolName {
case tools.BashToolName:
widthRatio = 0.5
heightRatio = 0.3
minWidth = 80
minHeight = 20
}
// Return the dialog command
return util.CmdHandler(core.DialogMsg{
Content: dialogPane,
Content: dialogPane,
WidthRatio: widthRatio,
HeightRatio: heightRatio,
MinWidth: minWidth,
MinHeight: minHeight,
})
}

View File

@@ -3,7 +3,6 @@ package dialog
import (
"github.com/charmbracelet/bubbles/key"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/kujtimiihoxha/termai/internal/tui/components/core"
"github.com/kujtimiihoxha/termai/internal/tui/layout"
"github.com/kujtimiihoxha/termai/internal/tui/styles"
@@ -14,11 +13,6 @@ import (
const question = "Are you sure you want to quit?"
var (
width = lipgloss.Width(question) + 6
height = 3
)
type QuitDialog interface {
tea.Model
layout.Sizeable
@@ -37,8 +31,6 @@ func (q *quitDialogCmp) Init() tea.Cmd {
func (q *quitDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmds []tea.Cmd
// Process the form
form, cmd := q.form.Update(msg)
if f, ok := form.(*huh.Form); ok {
q.form = f
@@ -67,6 +59,7 @@ func (q *quitDialogCmp) GetSize() (int, int) {
func (q *quitDialogCmp) SetSize(width int, height int) {
q.width = width
q.height = height
q.form = q.form.WithWidth(width).WithHeight(height)
}
func (q *quitDialogCmp) BindingKeys() []key.Binding {
@@ -84,28 +77,30 @@ func newQuitDialogCmp() QuitDialog {
theme.Focused.FocusedButton = theme.Focused.FocusedButton.Background(styles.Warning)
theme.Blurred.FocusedButton = theme.Blurred.FocusedButton.Background(styles.Warning)
form := huh.NewForm(huh.NewGroup(confirm)).
WithWidth(width).
WithHeight(height).
WithShowHelp(false).
WithWidth(0).
WithHeight(0).
WithTheme(theme).
WithShowErrors(false)
confirm.Focus()
return &quitDialogCmp{
form: form,
width: width,
form: form,
}
}
func NewQuitDialogCmd() tea.Cmd {
content := layout.NewSinglePane(
newQuitDialogCmp().(*quitDialogCmp),
layout.WithSignlePaneSize(width+2, height+2),
layout.WithSinglePaneBordered(true),
layout.WithSinglePaneFocusable(true),
layout.WithSinglePaneActiveColor(styles.Warning),
)
content.Focus()
return util.CmdHandler(core.DialogMsg{
Content: content,
Content: content,
WidthRatio: 0.2,
HeightRatio: 0.1,
MinWidth: 40,
MinHeight: 5,
})
}

View File

@@ -1,108 +0,0 @@
package messages
import (
"fmt"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/cloudwego/eino/schema"
"github.com/kujtimiihoxha/termai/internal/message"
"github.com/kujtimiihoxha/termai/internal/tui/layout"
"github.com/kujtimiihoxha/termai/internal/tui/styles"
)
const (
maxHeight = 10
)
type MessagesCmp interface {
tea.Model
layout.Focusable
layout.Bordered
layout.Sizeable
}
type messageCmp struct {
message message.Message
width int
height int
focused bool
expanded bool
}
func (m *messageCmp) Init() tea.Cmd {
return nil
}
func (m *messageCmp) Update(tea.Msg) (tea.Model, tea.Cmd) {
return m, nil
}
func (m *messageCmp) View() string {
wrapper := layout.NewSinglePane(
m,
layout.WithSinglePaneBordered(true),
layout.WithSinglePaneFocusable(true),
layout.WithSinglePanePadding(1),
layout.WithSinglePaneActiveColor(m.borderColor()),
)
if m.focused {
wrapper.Focus()
}
wrapper.SetSize(m.width, m.height)
return wrapper.View()
}
func (m *messageCmp) Blur() tea.Cmd {
m.focused = false
return nil
}
func (m *messageCmp) borderColor() lipgloss.TerminalColor {
switch m.message.MessageData.Role {
case schema.Assistant:
return styles.Mauve
case schema.User:
return styles.Flamingo
}
return styles.Blue
}
func (m *messageCmp) BorderText() map[layout.BorderPosition]string {
role := ""
icon := ""
switch m.message.MessageData.Role {
case schema.Assistant:
role = "Assistant"
icon = styles.BotIcon
case schema.User:
role = "User"
icon = styles.UserIcon
}
return map[layout.BorderPosition]string{
layout.TopLeftBorder: fmt.Sprintf("%s %s ", role, icon),
}
}
func (m *messageCmp) Focus() tea.Cmd {
m.focused = true
return nil
}
func (m *messageCmp) IsFocused() bool {
return m.focused
}
func (m *messageCmp) GetSize() (int, int) {
return m.width, 0
}
func (m *messageCmp) SetSize(width int, height int) {
m.width = width
}
func NewMessageCmp(msg message.Message) MessagesCmp {
return &messageCmp{
message: msg,
}
}

View File

@@ -6,10 +6,11 @@ import (
"github.com/charmbracelet/bubbles/key"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/cloudwego/eino/schema"
"github.com/kujtimiihoxha/termai/internal/app"
"github.com/kujtimiihoxha/termai/internal/llm/agent"
"github.com/kujtimiihoxha/termai/internal/tui/layout"
"github.com/kujtimiihoxha/termai/internal/tui/styles"
"github.com/kujtimiihoxha/termai/internal/tui/util"
"github.com/kujtimiihoxha/vimtea"
)
@@ -112,7 +113,7 @@ func (m *editorCmp) BorderText() map[layout.BorderPosition]string {
title = lipgloss.NewStyle().Foreground(styles.Primary).Render(title)
}
return map[layout.BorderPosition]string{
layout.TopLeftBorder: title,
layout.BottomLeftBorder: title,
}
}
@@ -137,9 +138,15 @@ func (m *editorCmp) SetSize(width int, height int) {
func (m *editorCmp) Send() tea.Cmd {
return func() tea.Msg {
messages, _ := m.app.Messages.List(m.sessionID)
if hasUnfinishedMessages(messages) {
return util.InfoMsg("Assistant is still working on the previous message")
}
a, _ := agent.NewCoderAgent(m.app)
content := strings.Join(m.editor.GetBuffer().Lines(), "\n")
m.app.Messages.Create(m.sessionID, *schema.UserMessage(content))
m.app.LLM.SendRequest(m.sessionID, content)
go a.Generate(m.sessionID, content)
return m.editor.Reset()
}
}
@@ -153,10 +160,11 @@ func (m *editorCmp) BindingKeys() []key.Binding {
}
func NewEditorCmp(app *app.App) EditorCmp {
editor := vimtea.NewEditor(
vimtea.WithFileName("message.md"),
)
return &editorCmp{
app: app,
editor: vimtea.NewEditor(
vimtea.WithFileName("message.md"),
),
app: app,
editor: editor,
}
}

View File

@@ -1,8 +1,9 @@
package repl
import (
"encoding/json"
"fmt"
"slices"
"sort"
"strings"
"github.com/charmbracelet/bubbles/key"
@@ -10,8 +11,8 @@ import (
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/glamour"
"github.com/charmbracelet/lipgloss"
"github.com/cloudwego/eino/schema"
"github.com/kujtimiihoxha/termai/internal/app"
"github.com/kujtimiihoxha/termai/internal/llm/agent"
"github.com/kujtimiihoxha/termai/internal/message"
"github.com/kujtimiihoxha/termai/internal/pubsub"
"github.com/kujtimiihoxha/termai/internal/session"
@@ -28,30 +29,50 @@ type MessagesCmp interface {
}
type messagesCmp struct {
app *app.App
messages []message.Message
session session.Session
viewport viewport.Model
mdRenderer *glamour.TermRenderer
width int
height int
focused bool
cachedView string
app *app.App
messages []message.Message
selectedMsgIdx int // Index of the selected message
session session.Session
viewport viewport.Model
mdRenderer *glamour.TermRenderer
width int
height int
focused bool
cachedView string
}
func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case pubsub.Event[message.Message]:
if msg.Type == pubsub.CreatedEvent {
m.messages = append(m.messages, msg.Payload)
m.renderView()
m.viewport.GotoBottom()
if msg.Payload.SessionID == m.session.ID {
m.messages = append(m.messages, msg.Payload)
m.renderView()
m.viewport.GotoBottom()
}
for _, v := range m.messages {
for _, c := range v.ToolCalls {
if c.ID == msg.Payload.SessionID {
m.renderView()
m.viewport.GotoBottom()
}
}
}
} else if msg.Type == pubsub.UpdatedEvent && msg.Payload.SessionID == m.session.ID {
for i, v := range m.messages {
if v.ID == msg.Payload.ID {
m.messages[i] = msg.Payload
m.renderView()
if i == len(m.messages)-1 {
m.viewport.GotoBottom()
}
break
}
}
}
case pubsub.Event[session.Session]:
if msg.Type == pubsub.UpdatedEvent {
if m.session.ID == msg.Payload.ID {
m.session = msg.Payload
}
if msg.Type == pubsub.UpdatedEvent && m.session.ID == msg.Payload.ID {
m.session = msg.Payload
}
case SelectedSessionMsg:
m.session, _ = m.app.Sessions.Get(msg.SessionID)
@@ -67,26 +88,24 @@ func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, nil
}
func borderColor(role schema.RoleType) lipgloss.TerminalColor {
func borderColor(role message.MessageRole) lipgloss.TerminalColor {
switch role {
case schema.Assistant:
case message.Assistant:
return styles.Mauve
case schema.User:
case message.User:
return styles.Rosewater
case schema.Tool:
return styles.Peach
}
return styles.Blue
}
func borderText(msgRole schema.RoleType, currentMessage int) map[layout.BorderPosition]string {
func borderText(msgRole message.MessageRole, currentMessage int) map[layout.BorderPosition]string {
role := ""
icon := ""
switch msgRole {
case schema.Assistant:
case message.Assistant:
role = "Assistant"
icon = styles.BotIcon
case schema.User:
case message.User:
role = "User"
icon = styles.UserIcon
}
@@ -106,81 +125,259 @@ func borderText(msgRole schema.RoleType, currentMessage int) map[layout.BorderPo
}
}
func hasUnfinishedMessages(messages []message.Message) bool {
if len(messages) == 0 {
return false
}
for _, msg := range messages {
if !msg.Finished {
return true
}
}
lastMessage := messages[len(messages)-1]
return lastMessage.Role != message.Assistant
}
func (m *messagesCmp) renderMessageWithToolCall(content string, tools []message.ToolCall, futureMessages []message.Message) string {
allParts := []string{content}
leftPaddingValue := 4
connectorStyle := lipgloss.NewStyle().
Foreground(styles.Peach).
Bold(true)
toolCallStyle := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(styles.Peach).
Width(m.width-leftPaddingValue-5).
Padding(0, 1)
toolResultStyle := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(styles.Green).
Width(m.width-leftPaddingValue-5).
Padding(0, 1)
leftPadding := lipgloss.NewStyle().Padding(0, 0, 0, leftPaddingValue)
runningStyle := lipgloss.NewStyle().
Foreground(styles.Peach).
Bold(true)
renderTool := func(toolCall message.ToolCall) string {
toolHeader := lipgloss.NewStyle().
Bold(true).
Foreground(styles.Blue).
Render(fmt.Sprintf("%s %s", styles.ToolIcon, toolCall.Name))
var paramLines []string
var args map[string]interface{}
var paramOrder []string
json.Unmarshal([]byte(toolCall.Input), &args)
for key := range args {
paramOrder = append(paramOrder, key)
}
sort.Strings(paramOrder)
for _, name := range paramOrder {
value := args[name]
paramName := lipgloss.NewStyle().
Foreground(styles.Peach).
Bold(true).
Render(name)
truncate := m.width - leftPaddingValue*2 - 10
if len(fmt.Sprintf("%v", value)) > truncate {
value = fmt.Sprintf("%v", value)[:truncate] + lipgloss.NewStyle().Foreground(styles.Blue).Render("... (truncated)")
}
paramValue := fmt.Sprintf("%v", value)
paramLines = append(paramLines, fmt.Sprintf(" %s: %s", paramName, paramValue))
}
paramBlock := lipgloss.JoinVertical(lipgloss.Left, paramLines...)
toolContent := lipgloss.JoinVertical(lipgloss.Left, toolHeader, paramBlock)
return toolCallStyle.Render(toolContent)
}
findToolResult := func(toolCallID string, messages []message.Message) *message.ToolResult {
for _, msg := range messages {
if msg.Role == message.Tool {
for _, result := range msg.ToolResults {
if result.ToolCallID == toolCallID {
return &result
}
}
}
}
return nil
}
renderToolResult := func(result message.ToolResult) string {
resultHeader := lipgloss.NewStyle().
Bold(true).
Foreground(styles.Green).
Render(fmt.Sprintf("%s Result", styles.CheckIcon))
if result.IsError {
resultHeader = lipgloss.NewStyle().
Bold(true).
Foreground(styles.Red).
Render(fmt.Sprintf("%s Error", styles.ErrorIcon))
}
truncate := 200
content := result.Content
if len(content) > truncate {
content = content[:truncate] + lipgloss.NewStyle().Foreground(styles.Blue).Render("... (truncated)")
}
resultContent := lipgloss.JoinVertical(lipgloss.Left, resultHeader, content)
return toolResultStyle.Render(resultContent)
}
connector := connectorStyle.Render("└─> Tool Calls:")
allParts = append(allParts, connector)
for _, toolCall := range tools {
toolOutput := renderTool(toolCall)
allParts = append(allParts, leftPadding.Render(toolOutput))
result := findToolResult(toolCall.ID, futureMessages)
if result != nil {
resultOutput := renderToolResult(*result)
allParts = append(allParts, leftPadding.Render(resultOutput))
} else if toolCall.Name == agent.AgentToolName {
runningIndicator := runningStyle.Render(fmt.Sprintf("%s Running...", styles.SpinnerIcon))
allParts = append(allParts, leftPadding.Render(runningIndicator))
taskSessionMessages, _ := m.app.Messages.List(toolCall.ID)
for _, msg := range taskSessionMessages {
if msg.Role == message.Assistant {
for _, toolCall := range msg.ToolCalls {
toolHeader := lipgloss.NewStyle().
Bold(true).
Foreground(styles.Blue).
Render(fmt.Sprintf("%s %s", styles.ToolIcon, toolCall.Name))
var paramLines []string
var args map[string]interface{}
var paramOrder []string
json.Unmarshal([]byte(toolCall.Input), &args)
for key := range args {
paramOrder = append(paramOrder, key)
}
sort.Strings(paramOrder)
for _, name := range paramOrder {
value := args[name]
paramName := lipgloss.NewStyle().
Foreground(styles.Peach).
Bold(true).
Render(name)
truncate := 50
if len(fmt.Sprintf("%v", value)) > truncate {
value = fmt.Sprintf("%v", value)[:truncate] + lipgloss.NewStyle().Foreground(styles.Blue).Render("... (truncated)")
}
paramValue := fmt.Sprintf("%v", value)
paramLines = append(paramLines, fmt.Sprintf(" %s: %s", paramName, paramValue))
}
paramBlock := lipgloss.JoinVertical(lipgloss.Left, paramLines...)
toolContent := lipgloss.JoinVertical(lipgloss.Left, toolHeader, paramBlock)
toolOutput := toolCallStyle.BorderForeground(styles.Teal).MaxWidth(m.width - leftPaddingValue*2 - 2).Render(toolContent)
allParts = append(allParts, lipgloss.NewStyle().Padding(0, 0, 0, leftPaddingValue*2).Render(toolOutput))
}
}
}
} else {
runningIndicator := runningStyle.Render(fmt.Sprintf("%s Running...", styles.SpinnerIcon))
allParts = append(allParts, " "+runningIndicator)
}
}
for _, msg := range futureMessages {
if msg.Content != "" {
break
}
for _, toolCall := range msg.ToolCalls {
toolOutput := renderTool(toolCall)
allParts = append(allParts, " "+strings.ReplaceAll(toolOutput, "\n", "\n "))
result := findToolResult(toolCall.ID, futureMessages)
if result != nil {
resultOutput := renderToolResult(*result)
allParts = append(allParts, " "+strings.ReplaceAll(resultOutput, "\n", "\n "))
} else {
runningIndicator := runningStyle.Render(fmt.Sprintf("%s Running...", styles.SpinnerIcon))
allParts = append(allParts, " "+runningIndicator)
}
}
}
return lipgloss.JoinVertical(lipgloss.Left, allParts...)
}
func (m *messagesCmp) renderView() {
stringMessages := make([]string, 0)
r, _ := glamour.NewTermRenderer(
glamour.WithStyles(styles.CatppuccinMarkdownStyle()),
glamour.WithWordWrap(m.width-10),
glamour.WithWordWrap(m.width-20),
glamour.WithEmoji(),
)
textStyle := lipgloss.NewStyle().Width(m.width - 4)
currentMessage := 1
for _, msg := range m.messages {
if msg.MessageData.Role == schema.Tool {
continue
}
content := msg.MessageData.Content
if content != "" {
content, _ = r.Render(msg.MessageData.Content)
stringMessages = append(stringMessages, layout.Borderize(
textStyle.Render(content),
layout.BorderOptions{
InactiveBorder: lipgloss.DoubleBorder(),
ActiveBorder: lipgloss.DoubleBorder(),
ActiveColor: borderColor(msg.MessageData.Role),
InactiveColor: borderColor(msg.MessageData.Role),
EmbeddedText: borderText(msg.MessageData.Role, currentMessage),
},
))
currentMessage++
}
for _, toolCall := range msg.MessageData.ToolCalls {
resultInx := slices.IndexFunc(m.messages, func(m message.Message) bool {
return m.MessageData.ToolCallID == toolCall.ID
})
content := fmt.Sprintf("**Arguments**\n```json\n%s\n```\n", toolCall.Function.Arguments)
if resultInx == -1 {
content += "Running..."
} else {
result := m.messages[resultInx].MessageData.Content
if result != "" {
lines := strings.Split(result, "\n")
if len(lines) > 15 {
result = strings.Join(lines[:15], "\n")
}
content += fmt.Sprintf("**Result**\n```\n%s\n```\n", result)
if len(lines) > 15 {
content += fmt.Sprintf("\n\n *...%d lines are truncated* ", len(lines)-15)
}
}
displayedMsgCount := 0 // Track the actual displayed messages count
prevMessageWasUser := false
for inx, msg := range m.messages {
content := msg.Content
if content != "" || prevMessageWasUser {
if msg.Thinking != "" && content == "" {
content = msg.Thinking
} else if content == "" {
content = "..."
}
content, _ = r.Render(content)
stringMessages = append(stringMessages, layout.Borderize(
isSelected := inx == m.selectedMsgIdx
border := lipgloss.DoubleBorder()
activeColor := borderColor(msg.Role)
if isSelected {
activeColor = styles.Primary // Use primary color for selected message
}
content = layout.Borderize(
textStyle.Render(content),
layout.BorderOptions{
InactiveBorder: lipgloss.DoubleBorder(),
ActiveBorder: lipgloss.DoubleBorder(),
ActiveColor: borderColor(schema.Tool),
InactiveColor: borderColor(schema.Tool),
EmbeddedText: map[layout.BorderPosition]string{
layout.TopLeftBorder: lipgloss.NewStyle().
Padding(0, 1).
Bold(true).
Foreground(styles.Crust).
Background(borderColor(schema.Tool)).
Render(
fmt.Sprintf("Tool [%s] %s ", toolCall.Function.Name, styles.ToolIcon),
),
layout.TopRightBorder: lipgloss.NewStyle().
Padding(0, 1).
Bold(true).
Foreground(styles.Crust).
Background(borderColor(schema.Tool)).
Render(fmt.Sprintf("#%d ", currentMessage)),
},
InactiveBorder: border,
ActiveBorder: border,
ActiveColor: activeColor,
InactiveColor: borderColor(msg.Role),
EmbeddedText: borderText(msg.Role, currentMessage),
},
))
)
if len(msg.ToolCalls) > 0 {
content = m.renderMessageWithToolCall(content, msg.ToolCalls, m.messages[inx+1:])
}
stringMessages = append(stringMessages, content)
currentMessage++
displayedMsgCount++
}
if msg.Role == message.User && msg.Content != "" {
prevMessageWasUser = true
} else {
prevMessageWasUser = false
}
}
m.viewport.SetContent(lipgloss.JoinVertical(lipgloss.Top, stringMessages...))
@@ -191,7 +388,9 @@ func (m *messagesCmp) View() string {
}
func (m *messagesCmp) BindingKeys() []key.Binding {
return layout.KeyMapToSlice(m.viewport.KeyMap)
keys := layout.KeyMapToSlice(m.viewport.KeyMap)
return keys
}
func (m *messagesCmp) Blur() tea.Cmd {
@@ -208,10 +407,17 @@ func (m *messagesCmp) BorderText() map[layout.BorderPosition]string {
if m.focused {
title = lipgloss.NewStyle().Foreground(styles.Primary).Render(title)
}
return map[layout.BorderPosition]string{
borderTest := map[layout.BorderPosition]string{
layout.TopLeftBorder: title,
layout.BottomRightBorder: formatTokensAndCost(m.session.CompletionTokens+m.session.PromptTokens, m.session.Cost),
}
if hasUnfinishedMessages(m.messages) {
borderTest[layout.BottomLeftBorder] = lipgloss.NewStyle().Foreground(styles.Peach).Render("Thinking...")
} else {
borderTest[layout.BottomLeftBorder] = lipgloss.NewStyle().Foreground(styles.Text).Render("Sleeping " + styles.SleepIcon + " ")
}
return borderTest
}
func (m *messagesCmp) Focus() tea.Cmd {
@@ -232,6 +438,7 @@ func (m *messagesCmp) SetSize(width int, height int) {
m.height = height
m.viewport.Width = width - 2 // padding
m.viewport.Height = height - 2 // padding
m.renderView()
}
func (m *messagesCmp) Init() tea.Cmd {

View File

@@ -89,7 +89,23 @@ func (i *sessionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}
return i, i.list.SetItems(items)
case pubsub.Event[session.Session]:
if msg.Type == pubsub.UpdatedEvent {
if msg.Type == pubsub.CreatedEvent && msg.Payload.ParentSessionID == "" {
// Check if the session is already in the list
items := i.list.Items()
for _, item := range items {
s := item.(listItem)
if s.id == msg.Payload.ID {
return i, nil
}
}
// insert the new session at the top of the list
items = append([]list.Item{listItem{
id: msg.Payload.ID,
title: msg.Payload.Title,
desc: formatTokensAndCost(msg.Payload.PromptTokens+msg.Payload.CompletionTokens, msg.Payload.Cost),
}}, items...)
return i, i.list.SetItems(items)
} else if msg.Type == pubsub.UpdatedEvent {
// update the session in the list
items := i.list.Items()
for idx, item := range items {
@@ -229,3 +245,4 @@ func NewSessionsCmp(app *app.App) SessionsCmp {
focused: false,
}
}

View File

@@ -78,12 +78,12 @@ func (i *initPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
// Save configuration to file
configPath := filepath.Join(os.Getenv("HOME"), ".termai.yaml")
maxTokens, _ := strconv.Atoi(i.maxTokens)
config := map[string]interface{}{
config := map[string]any{
"models": map[string]string{
"big": i.bigModel,
"small": i.smallModel,
},
"providers": map[string]interface{}{
"providers": map[string]any{
"openai": map[string]string{
"key": i.openAIKey,
},
@@ -192,8 +192,8 @@ func NewInitPage() tea.Model {
// Init page with form
initModel := &initPage{
modelOpts: modelOpts,
bigModel: string(models.DefaultBigModel),
smallModel: string(models.DefaultLittleModel),
bigModel: string(models.Claude37Sonnet),
smallModel: string(models.Claude37Sonnet),
maxTokens: "4000",
dataDir: ".termai",
agent: "coder",

View File

@@ -8,5 +8,9 @@ const (
ToolIcon string = ""
UserIcon string = ""
CheckIcon string = "✓"
ErrorIcon string = "✗"
SpinnerIcon string = "..."
SleepIcon string = "󰒲"
)

View File

@@ -5,7 +5,7 @@ import (
"github.com/charmbracelet/lipgloss"
)
const defaultMargin = 2
const defaultMargin = 1
// Helper functions for style pointers
func boolPtr(b bool) *bool { return &b }
@@ -25,7 +25,7 @@ var catppuccinDark = ansi.StyleConfig{
Document: ansi.StyleBlock{
StylePrimitive: ansi.StylePrimitive{
BlockPrefix: "\n",
BlockSuffix: "\n",
BlockSuffix: "",
Color: stringPtr(dark.Text().Hex),
},
Margin: uintPtr(defaultMargin),
@@ -153,7 +153,7 @@ var catppuccinDark = ansi.StyleConfig{
CodeBlock: ansi.StyleCodeBlock{
StyleBlock: ansi.StyleBlock{
StylePrimitive: ansi.StylePrimitive{
Prefix: " ",
Prefix: " ",
Color: stringPtr(dark.Text().Hex),
},

View File

@@ -20,8 +20,7 @@ var (
DoubleBorder = Regular.Border(lipgloss.DoubleBorder())
// Colors
White = lipgloss.Color("#ffffff")
White = lipgloss.Color("#ffffff")
Surface0 = lipgloss.AdaptiveColor{
Dark: dark.Surface0().Hex,
Light: light.Surface0().Hex,

View File

@@ -1,20 +1,15 @@
package tui
import (
"fmt"
"log"
"os"
"path/filepath"
"github.com/charmbracelet/bubbles/key"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/kujtimiihoxha/termai/internal/app"
"github.com/kujtimiihoxha/termai/internal/llm"
"github.com/kujtimiihoxha/termai/internal/permission"
"github.com/kujtimiihoxha/termai/internal/pubsub"
"github.com/kujtimiihoxha/termai/internal/tui/components/core"
"github.com/kujtimiihoxha/termai/internal/tui/components/dialog"
"github.com/kujtimiihoxha/termai/internal/tui/components/repl"
"github.com/kujtimiihoxha/termai/internal/tui/layout"
"github.com/kujtimiihoxha/termai/internal/tui/page"
"github.com/kujtimiihoxha/termai/internal/tui/util"
@@ -52,9 +47,9 @@ var keys = keyMap{
),
}
var editorKeyMap = key.NewBinding(
key.WithKeys("i"),
key.WithHelp("i", "insert mode"),
var replKeyMap = key.NewBinding(
key.WithKeys("N"),
key.WithHelp("N", "new session"),
)
type appModel struct {
@@ -66,6 +61,7 @@ type appModel struct {
status tea.Model
help core.HelpCmp
dialog core.DialogCmp
app *app.App
dialogVisible bool
editorMode vimtea.EditorMode
showHelp bool
@@ -79,19 +75,8 @@ func (a appModel) Init() tea.Cmd {
func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case pubsub.Event[llm.AgentEvent]:
log.Println("AgentEvent")
log.Println(msg)
case pubsub.Event[permission.PermissionRequest]:
return a, dialog.NewPermissionDialogCmd(
msg.Payload,
fmt.Sprintf(
"Tool: %s\nAction: %s\nParams: %v",
msg.Payload.ToolName,
msg.Payload.Action,
msg.Payload.Params,
),
)
return a, dialog.NewPermissionDialogCmd(msg.Payload)
case dialog.PermissionResponseMsg:
switch msg.Action {
case dialog.PermissionAllow:
@@ -104,6 +89,7 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
case vimtea.EditorModeMsg:
a.editorMode = msg.Mode
case tea.WindowSizeMsg:
var cmds []tea.Cmd
msg.Height -= 1 // Make space for the status bar
a.width, a.height = msg.Width, msg.Height
@@ -113,8 +99,14 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
a.help = uh.(core.HelpCmp)
p, cmd := a.pages[a.currentPage].Update(msg)
cmds = append(cmds, cmd)
a.pages[a.currentPage] = p
return a, cmd
d, cmd := a.dialog.Update(msg)
cmds = append(cmds, cmd)
a.dialog = d.(core.DialogCmp)
return a, tea.Batch(cmds...)
case core.DialogMsg:
d, cmd := a.dialog.Update(msg)
a.dialog = d.(core.DialogCmp)
@@ -145,6 +137,22 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
a.ToggleHelp()
return a, nil
}
case key.Matches(msg, replKeyMap):
if a.currentPage == page.ReplPage {
sessions, err := a.app.Sessions.List()
if err != nil {
return a, util.CmdHandler(util.ErrorMsg(err))
}
lastSession := sessions[0]
if lastSession.MessageCount == 0 {
return a, util.CmdHandler(repl.SelectedSessionMsg{SessionID: lastSession.ID})
}
s, err := a.app.Sessions.Create("New Session")
if err != nil {
return a, util.CmdHandler(util.ErrorMsg(err))
}
return a, util.CmdHandler(repl.SelectedSessionMsg{SessionID: s.ID})
}
case key.Matches(msg, keys.Logs):
return a, a.moveToPage(page.LogsPage)
case key.Matches(msg, keys.Help):
@@ -205,6 +213,9 @@ func (a appModel) View() string {
if a.dialogVisible {
bindings = append(bindings, a.dialog.BindingKeys()...)
}
if a.currentPage == page.ReplPage {
bindings = append(bindings, replKeyMap)
}
a.help.SetBindings(bindings)
components = append(components, a.help.View())
}
@@ -231,14 +242,13 @@ func (a appModel) View() string {
}
func New(app *app.App) tea.Model {
// Check if config file exists, if not, start with init page
homedir, _ := os.UserHomeDir()
configPath := filepath.Join(homedir, ".termai.yaml")
// homedir, _ := os.UserHomeDir()
// configPath := filepath.Join(homedir, ".termai.yaml")
//
startPage := page.ReplPage
if _, err := os.Stat(configPath); os.IsNotExist(err) {
startPage = page.InitPage
}
// if _, err := os.Stat(configPath); os.IsNotExist(err) {
// startPage = page.InitPage
// }
return &appModel{
currentPage: startPage,
@@ -246,6 +256,7 @@ func New(app *app.App) tea.Model {
status: core.NewStatusCmp(),
help: core.NewHelpCmp(),
dialog: core.NewDialogCmp(),
app: app,
pages: map[page.PageID]tea.Model{
page.LogsPage: page.NewLogsPage(),
page.InitPage: page.NewInitPage(),