add initial lsp support

This commit is contained in:
Kujtim Hoxha
2025-04-03 15:20:15 +02:00
parent afd9ad0560
commit cfdd687216
47 changed files with 13996 additions and 456 deletions

View File

@@ -91,7 +91,7 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes
if err != nil {
return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil
}
return tools.NewTextResponse(response.Content), nil
return tools.NewTextResponse(response.Content().String()), nil
}
func NewAgentTool(parentSessionID string, app *app.App) tools.BaseTool {

View File

@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"log"
"strings"
"sync"
"github.com/kujtimiihoxha/termai/internal/app"
@@ -33,8 +34,12 @@ func (c *agent) handleTitleGeneration(sessionID, content string) {
c.Context,
[]message.Message{
{
Role: message.User,
Content: content,
Role: message.User,
Parts: []message.ContentPart{
message.TextContent{
Text: content,
},
},
},
},
nil,
@@ -49,6 +54,8 @@ func (c *agent) handleTitleGeneration(sessionID, content string) {
}
if response.Content != "" {
session.Title = response.Content
session.Title = strings.TrimSpace(session.Title)
session.Title = strings.ReplaceAll(session.Title, "\n", " ")
c.Sessions.Save(session)
}
}
@@ -79,17 +86,18 @@ func (c *agent) processEvent(
) error {
switch event.Type {
case provider.EventThinkingDelta:
assistantMsg.Thinking += event.Thinking
assistantMsg.AppendReasoningContent(event.Content)
return c.Messages.Update(*assistantMsg)
case provider.EventContentDelta:
assistantMsg.Content += event.Content
assistantMsg.AppendContent(event.Content)
return c.Messages.Update(*assistantMsg)
case provider.EventError:
log.Println("error", event.Error)
return event.Error
case provider.EventComplete:
assistantMsg.ToolCalls = event.Response.ToolCalls
assistantMsg.SetToolCalls(event.Response.ToolCalls)
assistantMsg.AddFinish(event.Response.FinishReason)
err := c.Messages.Update(*assistantMsg)
if err != nil {
return err
@@ -157,18 +165,21 @@ func (c *agent) handleToolExecution(
ctx context.Context,
assistantMsg message.Message,
) (*message.Message, error) {
if len(assistantMsg.ToolCalls) == 0 {
if len(assistantMsg.ToolCalls()) == 0 {
return nil, nil
}
toolResults, err := c.ExecuteTools(ctx, assistantMsg.ToolCalls, c.tools)
toolResults, err := c.ExecuteTools(ctx, assistantMsg.ToolCalls(), c.tools)
if err != nil {
return nil, err
}
parts := make([]message.ContentPart, 0)
for _, toolResult := range toolResults {
parts = append(parts, toolResult)
}
msg, err := c.Messages.Create(assistantMsg.SessionID, message.CreateMessageParams{
Role: message.Tool,
ToolResults: toolResults,
Role: message.Tool,
Parts: parts,
})
return &msg, err
@@ -185,8 +196,12 @@ func (c *agent) generate(sessionID string, content string) error {
}
userMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
Role: message.User,
Content: content,
Role: message.User,
Parts: []message.ContentPart{
message.TextContent{
Text: content,
},
},
})
if err != nil {
return err
@@ -201,8 +216,8 @@ func (c *agent) generate(sessionID string, content string) error {
}
assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
Role: message.Assistant,
Content: "",
Role: message.Assistant,
Parts: []message.ContentPart{},
})
if err != nil {
return err
@@ -210,20 +225,20 @@ func (c *agent) generate(sessionID string, content string) error {
for event := range eventChan {
err = c.processEvent(sessionID, &assistantMsg, event)
if err != nil {
assistantMsg.Finished = true
assistantMsg.AddFinish("error:" + err.Error())
c.Messages.Update(assistantMsg)
return err
}
}
msg, err := c.handleToolExecution(c.Context, assistantMsg)
assistantMsg.Finished = true
c.Messages.Update(assistantMsg)
if err != nil {
return err
}
if len(assistantMsg.ToolCalls) == 0 {
if len(assistantMsg.ToolCalls()) == 0 {
break
}

View File

@@ -44,20 +44,23 @@ func NewCoderAgent(app *app.App) (Agent, error) {
return nil, err
}
mcpTools := GetMcpTools(app.Context)
otherTools := GetMcpTools(app.Context)
if len(app.LSPClients) > 0 {
otherTools = append(otherTools, tools.NewDiagnosticsTool(app.LSPClients))
}
return &coderAgent{
agent: &agent{
App: app,
tools: append(
[]tools.BaseTool{
tools.NewBashTool(),
tools.NewEditTool(),
tools.NewEditTool(app.LSPClients),
tools.NewGlobTool(),
tools.NewGrepTool(),
tools.NewLsTool(),
tools.NewViewTool(),
tools.NewWriteTool(),
}, mcpTools...,
tools.NewViewTool(app.LSPClients),
tools.NewWriteTool(app.LSPClients),
}, otherTools...,
),
model: model,
agent: agentProvider,

View File

@@ -34,7 +34,7 @@ func NewTaskAgent(app *app.App) (Agent, error) {
tools.NewGlobTool(),
tools.NewGrepTool(),
tools.NewLsTool(),
tools.NewViewTool(),
tools.NewViewTool(app.LSPClients),
},
model: model,
agent: agentProvider,

View File

@@ -67,7 +67,7 @@ Never commit changes unless the user explicitly asks you to.`
envInfo := getEnvironmentInfo()
return fmt.Sprintf("%s\n\n%s", basePrompt, envInfo)
return fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation())
}
func CoderAnthropicSystemPrompt() string {
@@ -168,7 +168,7 @@ You MUST answer concisely with fewer than 4 lines of text (not including tool us
envInfo := getEnvironmentInfo()
return fmt.Sprintf("%s\n\n%s", basePrompt, envInfo)
return fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation())
}
func getEnvironmentInfo() string {
@@ -198,6 +198,25 @@ func isGitRepo(dir string) bool {
return err == nil
}
func lspInformation() string {
cfg := config.Get()
hasLSP := false
for _, v := range cfg.LSP {
if !v.Disabled {
hasLSP = true
break
}
}
if !hasLSP {
return ""
}
return `# LSP Information
Tools that support it will also include useful diagnostics such as linting and typechecking.
These diagnostics will be automatically enabled when you run the tool, and will be displayed in the output at the bottom within the <file_diagnostics></file_diagnostics> and <project_diagnostics></project_diagnostics> tags.
Take necessary actions to fix the issues.
`
}
func boolToYesNo(b bool) string {
if b {
return "Yes"

View File

@@ -111,7 +111,7 @@ func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []messa
var thinkingParam anthropic.ThinkingConfigParamUnion
lastMessage := messages[len(messages)-1]
temperature := anthropic.Float(0)
if lastMessage.Role == message.User && strings.Contains(strings.ToLower(lastMessage.Content), "think") {
if lastMessage.Role == message.User && strings.Contains(strings.ToLower(lastMessage.Content().String()), "think") {
thinkingParam = anthropic.ThinkingConfigParamUnion{
OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{
BudgetTokens: int64(float64(a.maxTokens) * 0.8),
@@ -187,9 +187,10 @@ func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []messa
eventChan <- ProviderEvent{
Type: EventComplete,
Response: &ProviderResponse{
Content: content,
ToolCalls: toolCalls,
Usage: tokenUsage,
Content: content,
ToolCalls: toolCalls,
Usage: tokenUsage,
FinishReason: string(accumulatedMessage.StopReason),
},
}
}
@@ -263,7 +264,7 @@ func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Messag
for i, msg := range messages {
switch msg.Role {
case message.User:
content := anthropic.NewTextBlock(msg.Content)
content := anthropic.NewTextBlock(msg.Content().String())
if cachedBlocks < 2 {
content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
@@ -274,8 +275,8 @@ func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Messag
case message.Assistant:
blocks := []anthropic.ContentBlockParamUnion{}
if msg.Content != "" {
content := anthropic.NewTextBlock(msg.Content)
if msg.Content().String() != "" {
content := anthropic.NewTextBlock(msg.Content().String())
if cachedBlocks < 2 {
content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
@@ -285,7 +286,7 @@ func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Messag
blocks = append(blocks, content)
}
for _, toolCall := range msg.ToolCalls {
for _, toolCall := range msg.ToolCalls() {
var inputMap map[string]any
err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
if err != nil {
@@ -297,8 +298,8 @@ func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Messag
anthropicMessages[i] = anthropic.NewAssistantMessage(blocks...)
case message.Tool:
results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults))
for i, toolResult := range msg.ToolResults {
results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults()))
for i, toolResult := range msg.ToolResults() {
results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
}
anthropicMessages[i] = anthropic.NewUserMessage(results...)

View File

@@ -78,7 +78,6 @@ func (p *geminiProvider) Close() {
}
}
// convertToGeminiHistory converts the message history to Gemini's format
func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*genai.Content {
var history []*genai.Content
@@ -86,7 +85,7 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g
switch msg.Role {
case message.User:
history = append(history, &genai.Content{
Parts: []genai.Part{genai.Text(msg.Content)},
Parts: []genai.Part{genai.Text(msg.Content().String())},
Role: "user",
})
case message.Assistant:
@@ -95,14 +94,12 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g
Parts: []genai.Part{},
}
// Handle regular content
if msg.Content != "" {
content.Parts = append(content.Parts, genai.Text(msg.Content))
if msg.Content().String() != "" {
content.Parts = append(content.Parts, genai.Text(msg.Content().String()))
}
// Handle tool calls if any
if len(msg.ToolCalls) > 0 {
for _, call := range msg.ToolCalls {
if len(msg.ToolCalls()) > 0 {
for _, call := range msg.ToolCalls() {
args, _ := parseJsonToMap(call.Input)
content.Parts = append(content.Parts, genai.FunctionCall{
Name: call.Name,
@@ -113,8 +110,7 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g
history = append(history, content)
case message.Tool:
for _, result := range msg.ToolResults {
// Parse response content to map if possible
for _, result := range msg.ToolResults() {
response := map[string]interface{}{"result": result.Content}
parsed, err := parseJsonToMap(result.Content)
if err == nil {
@@ -123,7 +119,7 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g
var toolCall message.ToolCall
for _, msg := range messages {
if msg.Role == message.Assistant {
for _, call := range msg.ToolCalls {
for _, call := range msg.ToolCalls() {
if call.ID == result.ToolCallID {
toolCall = call
break
@@ -146,108 +142,6 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g
return history
}
// convertToolsToGeminiFunctionDeclarations converts tool definitions to Gemini's function declarations
func (p *geminiProvider) convertToolsToGeminiFunctionDeclarations(tools []tools.BaseTool) []*genai.FunctionDeclaration {
declarations := make([]*genai.FunctionDeclaration, len(tools))
for i, tool := range tools {
info := tool.Info()
// Convert parameters to genai.Schema format
properties := make(map[string]*genai.Schema)
for name, param := range info.Parameters {
// Try to extract type and description from the parameter
paramMap, ok := param.(map[string]interface{})
if !ok {
// Default to string if unable to determine type
properties[name] = &genai.Schema{Type: genai.TypeString}
continue
}
schemaType := genai.TypeString // Default
var description string
var itemsTypeSchema *genai.Schema
if typeVal, found := paramMap["type"]; found {
if typeStr, ok := typeVal.(string); ok {
switch typeStr {
case "string":
schemaType = genai.TypeString
case "number":
schemaType = genai.TypeNumber
case "integer":
schemaType = genai.TypeInteger
case "boolean":
schemaType = genai.TypeBoolean
case "array":
schemaType = genai.TypeArray
items, found := paramMap["items"]
if found {
itemsMap, ok := items.(map[string]interface{})
if ok {
itemsType, found := itemsMap["type"]
if found {
itemsTypeStr, ok := itemsType.(string)
if ok {
switch itemsTypeStr {
case "string":
itemsTypeSchema = &genai.Schema{
Type: genai.TypeString,
}
case "number":
itemsTypeSchema = &genai.Schema{
Type: genai.TypeNumber,
}
case "integer":
itemsTypeSchema = &genai.Schema{
Type: genai.TypeInteger,
}
case "boolean":
itemsTypeSchema = &genai.Schema{
Type: genai.TypeBoolean,
}
}
}
}
}
}
case "object":
schemaType = genai.TypeObject
if _, found := paramMap["properties"]; !found {
continue
}
// TODO: Add support for other types
}
}
}
if desc, found := paramMap["description"]; found {
if descStr, ok := desc.(string); ok {
description = descStr
}
}
properties[name] = &genai.Schema{
Type: schemaType,
Description: description,
Items: itemsTypeSchema,
}
}
declarations[i] = &genai.FunctionDeclaration{
Name: info.Name,
Description: info.Description,
Parameters: &genai.Schema{
Type: genai.TypeObject,
Properties: properties,
Required: info.Required,
},
}
}
return declarations
}
// extractTokenUsage extracts token usage information from Gemini's response
func (p *geminiProvider) extractTokenUsage(resp *genai.GenerateContentResponse) TokenUsage {
if resp == nil || resp.UsageMetadata == nil {
return TokenUsage{}
@@ -261,41 +155,28 @@ func (p *geminiProvider) extractTokenUsage(resp *genai.GenerateContentResponse)
}
}
// SendMessages sends a batch of messages to Gemini and returns the response
func (p *geminiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
// Create a generative model
model := p.client.GenerativeModel(p.model.APIModel)
model.SetMaxOutputTokens(p.maxTokens)
// Set system instruction
model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage))
// Set up tools if provided
if len(tools) > 0 {
declarations := p.convertToolsToGeminiFunctionDeclarations(tools)
model.Tools = []*genai.Tool{{FunctionDeclarations: declarations}}
}
// Create chat session and set history
chat := model.StartChat()
chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message
// Get the most recent user message
var lastUserMsg message.Message
for i := len(messages) - 1; i >= 0; i-- {
if messages[i].Role == message.User {
lastUserMsg = messages[i]
break
for _, declaration := range declarations {
model.Tools = append(model.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{declaration}})
}
}
// Send the message
resp, err := chat.SendMessage(ctx, genai.Text(lastUserMsg.Content))
chat := model.StartChat()
chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message
lastUserMsg := messages[len(messages)-1]
resp, err := chat.SendMessage(ctx, genai.Text(lastUserMsg.Content().String()))
if err != nil {
return nil, err
}
// Process the response
var content string
var toolCalls []message.ToolCall
@@ -317,7 +198,6 @@ func (p *geminiProvider) SendMessages(ctx context.Context, messages []message.Me
}
}
// Extract token usage
tokenUsage := p.extractTokenUsage(resp)
return &ProviderResponse{
@@ -327,16 +207,12 @@ func (p *geminiProvider) SendMessages(ctx context.Context, messages []message.Me
}, nil
}
// StreamResponse streams the response from Gemini
func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
// Create a generative model
model := p.client.GenerativeModel(p.model.APIModel)
model.SetMaxOutputTokens(p.maxTokens)
// Set system instruction
model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage))
// Set up tools if provided
if len(tools) > 0 {
declarations := p.convertToolsToGeminiFunctionDeclarations(tools)
for _, declaration := range declarations {
@@ -344,14 +220,12 @@ func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.
}
}
// Create chat session and set history
chat := model.StartChat()
chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message
lastUserMsg := messages[len(messages)-1]
// Start streaming
iter := chat.SendMessageStream(ctx, genai.Text(lastUserMsg.Content))
iter := chat.SendMessageStream(ctx, genai.Text(lastUserMsg.Content().String()))
eventChan := make(chan ProviderEvent)
@@ -392,7 +266,6 @@ func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.
}
currentContent += newText
case genai.FunctionCall:
// For function calls, we assume they come complete, not streamed in parts
id := "call_" + uuid.New().String()
args, _ := json.Marshal(p.Args)
newCall := message.ToolCall{
@@ -402,7 +275,6 @@ func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.
Type: "function",
}
// Check if this is a new tool call
isNew := true
for _, existing := range toolCalls {
if existing.Name == newCall.Name && existing.Input == newCall.Input {
@@ -419,15 +291,15 @@ func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.
}
}
// Extract token usage from the final response
tokenUsage := p.extractTokenUsage(finalResp)
eventChan <- ProviderEvent{
Type: EventComplete,
Response: &ProviderResponse{
Content: currentContent,
ToolCalls: toolCalls,
Usage: tokenUsage,
Content: currentContent,
ToolCalls: toolCalls,
Usage: tokenUsage,
FinishReason: string(finalResp.Candidates[0].FinishReason.String()),
},
}
}()
@@ -435,7 +307,99 @@ func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.
return eventChan, nil
}
// Helper function to parse JSON string into map
func (p *geminiProvider) convertToolsToGeminiFunctionDeclarations(tools []tools.BaseTool) []*genai.FunctionDeclaration {
declarations := make([]*genai.FunctionDeclaration, len(tools))
for i, tool := range tools {
info := tool.Info()
declarations[i] = &genai.FunctionDeclaration{
Name: info.Name,
Description: info.Description,
Parameters: &genai.Schema{
Type: genai.TypeObject,
Properties: convertSchemaProperties(info.Parameters),
Required: info.Required,
},
}
}
return declarations
}
func convertSchemaProperties(parameters map[string]interface{}) map[string]*genai.Schema {
properties := make(map[string]*genai.Schema)
for name, param := range parameters {
properties[name] = convertToSchema(param)
}
return properties
}
func convertToSchema(param interface{}) *genai.Schema {
schema := &genai.Schema{Type: genai.TypeString}
paramMap, ok := param.(map[string]interface{})
if !ok {
return schema
}
if desc, ok := paramMap["description"].(string); ok {
schema.Description = desc
}
typeVal, hasType := paramMap["type"]
if !hasType {
return schema
}
typeStr, ok := typeVal.(string)
if !ok {
return schema
}
schema.Type = mapJSONTypeToGenAI(typeStr)
switch typeStr {
case "array":
schema.Items = processArrayItems(paramMap)
case "object":
if props, ok := paramMap["properties"].(map[string]interface{}); ok {
schema.Properties = convertSchemaProperties(props)
}
}
return schema
}
func processArrayItems(paramMap map[string]interface{}) *genai.Schema {
items, ok := paramMap["items"].(map[string]interface{})
if !ok {
return nil
}
return convertToSchema(items)
}
func mapJSONTypeToGenAI(jsonType string) genai.Type {
switch jsonType {
case "string":
return genai.TypeString
case "number":
return genai.TypeNumber
case "integer":
return genai.TypeInteger
case "boolean":
return genai.TypeBoolean
case "array":
return genai.TypeArray
case "object":
return genai.TypeObject
default:
return genai.TypeString // Default to string for unknown types
}
}
func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
var result map[string]interface{}
err := json.Unmarshal([]byte(jsonStr), &result)

View File

@@ -84,22 +84,22 @@ func (p *openaiProvider) convertToOpenAIMessages(messages []message.Message) []o
for _, msg := range messages {
switch msg.Role {
case message.User:
chatMessages = append(chatMessages, openai.UserMessage(msg.Content))
chatMessages = append(chatMessages, openai.UserMessage(msg.Content().String()))
case message.Assistant:
assistantMsg := openai.ChatCompletionAssistantMessageParam{
Role: "assistant",
}
if msg.Content != "" {
if msg.Content().String() != "" {
assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
OfString: openai.String(msg.Content),
OfString: openai.String(msg.Content().String()),
}
}
if len(msg.ToolCalls) > 0 {
assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls))
for i, call := range msg.ToolCalls {
if len(msg.ToolCalls()) > 0 {
assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
for i, call := range msg.ToolCalls() {
assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
ID: call.ID,
Type: "function",
@@ -116,7 +116,7 @@ func (p *openaiProvider) convertToOpenAIMessages(messages []message.Message) []o
})
case message.Tool:
for _, result := range msg.ToolResults {
for _, result := range msg.ToolResults() {
chatMessages = append(chatMessages,
openai.ToolMessage(result.Content, result.ToolCallID),
)
@@ -276,3 +276,4 @@ func (p *openaiProvider) StreamResponse(ctx context.Context, messages []message.
return eventChan, nil
}

View File

@@ -27,9 +27,10 @@ type TokenUsage struct {
}
type ProviderResponse struct {
Content string
ToolCalls []message.ToolCall
Usage TokenUsage
Content string
ToolCalls []message.ToolCall
Usage TokenUsage
FinishReason string
}
type ProviderEvent struct {

View File

@@ -0,0 +1,229 @@
package tools
import (
"context"
"encoding/json"
"fmt"
"sort"
"strings"
"time"
"github.com/kujtimiihoxha/termai/internal/lsp"
"github.com/kujtimiihoxha/termai/internal/lsp/protocol"
)
type diagnosticsTool struct {
lspClients map[string]*lsp.Client
}
const (
DiagnosticsToolName = "diagnostics"
)
type DiagnosticsParams struct {
FilePath string `json:"file_path"`
}
func (b *diagnosticsTool) Info() ToolInfo {
return ToolInfo{
Name: DiagnosticsToolName,
Description: "Get diagnostics for a file and/or project.",
Parameters: map[string]any{
"file_path": map[string]any{
"type": "string",
"description": "The path to the file to get diagnostics for (leave w empty for project diagnostics)",
},
},
Required: []string{},
}
}
func (b *diagnosticsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
var params DiagnosticsParams
if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
}
lsps := b.lspClients
if len(lsps) == 0 {
return NewTextErrorResponse("no LSP clients available"), nil
}
if params.FilePath == "" {
notifyLspOpenFile(ctx, params.FilePath, lsps)
}
output := appendDiagnostics(params.FilePath, lsps)
return NewTextResponse(output), nil
}
func notifyLspOpenFile(ctx context.Context, filePath string, lsps map[string]*lsp.Client) {
for _, client := range lsps {
err := client.OpenFile(ctx, filePath)
if err != nil {
// Wait for the file to be opened and diagnostics to be received
// TODO: see if we can do this in a more efficient way
time.Sleep(2 * time.Second)
}
}
}
func appendDiagnostics(filePath string, lsps map[string]*lsp.Client) string {
fileDiagnostics := []string{}
projectDiagnostics := []string{}
// Enhanced format function that includes more diagnostic information
formatDiagnostic := func(pth string, diagnostic protocol.Diagnostic, source string) string {
// Base components
severity := "Info"
switch diagnostic.Severity {
case protocol.SeverityError:
severity = "Error"
case protocol.SeverityWarning:
severity = "Warn"
case protocol.SeverityHint:
severity = "Hint"
}
// Location information
location := fmt.Sprintf("%s:%d:%d", pth, diagnostic.Range.Start.Line+1, diagnostic.Range.Start.Character+1)
// Source information (LSP name)
sourceInfo := ""
if diagnostic.Source != "" {
sourceInfo = diagnostic.Source
} else if source != "" {
sourceInfo = source
}
// Code information
codeInfo := ""
if diagnostic.Code != nil {
codeInfo = fmt.Sprintf("[%v]", diagnostic.Code)
}
// Tags information
tagsInfo := ""
if len(diagnostic.Tags) > 0 {
tags := []string{}
for _, tag := range diagnostic.Tags {
switch tag {
case protocol.Unnecessary:
tags = append(tags, "unnecessary")
case protocol.Deprecated:
tags = append(tags, "deprecated")
}
}
if len(tags) > 0 {
tagsInfo = fmt.Sprintf(" (%s)", strings.Join(tags, ", "))
}
}
// Assemble the full diagnostic message
return fmt.Sprintf("%s: %s [%s]%s%s %s",
severity,
location,
sourceInfo,
codeInfo,
tagsInfo,
diagnostic.Message)
}
for lspName, client := range lsps {
diagnostics := client.GetDiagnostics()
if len(diagnostics) > 0 {
for location, diags := range diagnostics {
isCurrentFile := location.Path() == filePath
// Group diagnostics by severity for better organization
for _, diag := range diags {
formattedDiag := formatDiagnostic(location.Path(), diag, lspName)
if isCurrentFile {
fileDiagnostics = append(fileDiagnostics, formattedDiag)
} else {
projectDiagnostics = append(projectDiagnostics, formattedDiag)
}
}
}
}
}
// Sort diagnostics by severity (errors first) and then by location
sort.Slice(fileDiagnostics, func(i, j int) bool {
iIsError := strings.HasPrefix(fileDiagnostics[i], "Error")
jIsError := strings.HasPrefix(fileDiagnostics[j], "Error")
if iIsError != jIsError {
return iIsError // Errors come first
}
return fileDiagnostics[i] < fileDiagnostics[j] // Then alphabetically
})
sort.Slice(projectDiagnostics, func(i, j int) bool {
iIsError := strings.HasPrefix(projectDiagnostics[i], "Error")
jIsError := strings.HasPrefix(projectDiagnostics[j], "Error")
if iIsError != jIsError {
return iIsError
}
return projectDiagnostics[i] < projectDiagnostics[j]
})
output := ""
if len(fileDiagnostics) > 0 {
output += "\n<file_diagnostics>\n"
if len(fileDiagnostics) > 10 {
output += strings.Join(fileDiagnostics[:10], "\n")
output += fmt.Sprintf("\n... and %d more diagnostics", len(fileDiagnostics)-10)
} else {
output += strings.Join(fileDiagnostics, "\n")
}
output += "\n</file_diagnostics>\n"
}
if len(projectDiagnostics) > 0 {
output += "\n<project_diagnostics>\n"
if len(projectDiagnostics) > 10 {
output += strings.Join(projectDiagnostics[:10], "\n")
output += fmt.Sprintf("\n... and %d more diagnostics", len(projectDiagnostics)-10)
} else {
output += strings.Join(projectDiagnostics, "\n")
}
output += "\n</project_diagnostics>\n"
}
// Add summary counts
if len(fileDiagnostics) > 0 || len(projectDiagnostics) > 0 {
fileErrors := countSeverity(fileDiagnostics, "Error")
fileWarnings := countSeverity(fileDiagnostics, "Warn")
projectErrors := countSeverity(projectDiagnostics, "Error")
projectWarnings := countSeverity(projectDiagnostics, "Warn")
output += "\n<diagnostic_summary>\n"
output += fmt.Sprintf("Current file: %d errors, %d warnings\n", fileErrors, fileWarnings)
output += fmt.Sprintf("Project: %d errors, %d warnings\n", projectErrors, projectWarnings)
output += "</diagnostic_summary>\n"
}
return output
}
// Helper function to count diagnostics by severity
func countSeverity(diagnostics []string, severity string) int {
count := 0
for _, diag := range diagnostics {
if strings.HasPrefix(diag, severity) {
count++
}
}
return count
}
func NewDiagnosticsTool(lspClients map[string]*lsp.Client) BaseTool {
return &diagnosticsTool{
lspClients,
}
}

View File

@@ -10,11 +10,14 @@ import (
"time"
"github.com/kujtimiihoxha/termai/internal/config"
"github.com/kujtimiihoxha/termai/internal/lsp"
"github.com/kujtimiihoxha/termai/internal/permission"
"github.com/sergi/go-diff/diffmatchpatch"
)
type editTool struct{}
type editTool struct {
lspClients map[string]*lsp.Client
}
const (
EditToolName = "edit"
@@ -71,6 +74,7 @@ func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
params.FilePath = filepath.Join(wd, params.FilePath)
}
notifyLspOpenFile(ctx, params.FilePath, e.lspClients)
if params.OldString == "" {
result, err := createNewFile(params.FilePath, params.NewString)
if err != nil {
@@ -91,6 +95,9 @@ func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
if err != nil {
return NewTextErrorResponse(fmt.Sprintf("error replacing content: %s", err)), nil
}
result = fmt.Sprintf("<result>\n%s\n</result>\n", result)
result += appendDiagnostics(params.FilePath, e.lspClients)
return NewTextResponse(result), nil
}
@@ -296,18 +303,18 @@ func GenerateDiff(oldContent, newContent string) string {
switch diff.Type {
case diffmatchpatch.DiffInsert:
for _, line := range strings.Split(text, "\n") {
for line := range strings.SplitSeq(text, "\n") {
_, _ = buff.WriteString("+ " + line + "\n")
}
case diffmatchpatch.DiffDelete:
for _, line := range strings.Split(text, "\n") {
for line := range strings.SplitSeq(text, "\n") {
_, _ = buff.WriteString("- " + line + "\n")
}
case diffmatchpatch.DiffEqual:
if len(text) > 40 {
_, _ = buff.WriteString(" " + text[:20] + "..." + text[len(text)-20:] + "\n")
} else {
for _, line := range strings.Split(text, "\n") {
for line := range strings.SplitSeq(text, "\n") {
_, _ = buff.WriteString(" " + line + "\n")
}
}
@@ -366,6 +373,8 @@ When making edits:
Remember: when making multiple file edits in a row to the same file, you should prefer to send all edits in a single message with multiple calls to this tool, rather than multiple messages with a single call each.`
}
func NewEditTool() BaseTool {
return &editTool{}
func NewEditTool(lspClients map[string]*lsp.Client) BaseTool {
return &editTool{
lspClients,
}
}

View File

@@ -221,7 +221,7 @@ func (s *PersistentShell) killChildren() {
return
}
for _, pidStr := range strings.Split(string(output), "\n") {
for pidStr := range strings.SplitSeq(string(output), "\n") {
if pidStr = strings.TrimSpace(pidStr); pidStr != "" {
var pid int
fmt.Sscanf(pidStr, "%d", &pid)

View File

@@ -11,9 +11,12 @@ import (
"strings"
"github.com/kujtimiihoxha/termai/internal/config"
"github.com/kujtimiihoxha/termai/internal/lsp"
)
type viewTool struct{}
type viewTool struct {
lspClients map[string]*lsp.Client
}
const (
ViewToolName = "view"
@@ -127,15 +130,18 @@ func (v *viewTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
return NewTextErrorResponse(fmt.Sprintf("Failed to read file: %s", err)), nil
}
notifyLspOpenFile(ctx, filePath, v.lspClients)
output := "<file>\n"
// Format the output with line numbers
output := addLineNumbers(content, params.Offset+1)
output += addLineNumbers(content, params.Offset+1)
// Add a note if the content was truncated
if lineCount > params.Offset+len(strings.Split(content, "\n")) {
output += fmt.Sprintf("\n\n(File has more lines. Use 'offset' parameter to read beyond line %d)",
params.Offset+len(strings.Split(content, "\n")))
}
output += "\n</file>\n"
output += appendDiagnostics(filePath, v.lspClients)
recordFileRead(filePath)
return NewTextResponse(output), nil
}
@@ -155,10 +161,10 @@ func addLineNumbers(content string, startLine int) string {
numStr := fmt.Sprintf("%d", lineNum)
if len(numStr) >= 6 {
result = append(result, fmt.Sprintf("%s\t%s", numStr, line))
result = append(result, fmt.Sprintf("%s|%s", numStr, line))
} else {
paddedNum := fmt.Sprintf("%6s", numStr)
result = append(result, fmt.Sprintf("%s\t|%s", paddedNum, line))
result = append(result, fmt.Sprintf("%s|%s", paddedNum, line))
}
}
@@ -173,8 +179,9 @@ func readTextFile(filePath string, offset, limit int) (string, int, error) {
defer file.Close()
lineCount := 0
scanner := NewLineScanner(file)
if offset > 0 {
scanner := NewLineScanner(file)
for lineCount < offset && scanner.Scan() {
lineCount++
}
@@ -192,7 +199,6 @@ func readTextFile(filePath string, offset, limit int) (string, int, error) {
var lines []string
lineCount = offset
scanner := NewLineScanner(file)
for scanner.Scan() && len(lines) < limit {
lineCount++
@@ -290,6 +296,8 @@ TIPS:
- When viewing large files, use the offset parameter to read specific sections`
}
func NewViewTool() BaseTool {
return &viewTool{}
func NewViewTool(lspClients map[string]*lsp.Client) BaseTool {
return &viewTool{
lspClients,
}
}

View File

@@ -9,10 +9,13 @@ import (
"time"
"github.com/kujtimiihoxha/termai/internal/config"
"github.com/kujtimiihoxha/termai/internal/lsp"
"github.com/kujtimiihoxha/termai/internal/permission"
)
type writeTool struct{}
type writeTool struct {
lspClients map[string]*lsp.Client
}
const (
WriteToolName = "write"
@@ -96,6 +99,8 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
if err = os.MkdirAll(dir, 0o755); err != nil {
return NewTextErrorResponse(fmt.Sprintf("Failed to create parent directories: %s", err)), nil
}
notifyLspOpenFile(ctx, filePath, w.lspClients)
p := permission.Default.Request(
permission.CreatePermissionRequest{
Path: filePath,
@@ -122,7 +127,10 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
recordFileWrite(filePath)
recordFileRead(filePath)
return NewTextResponse(fmt.Sprintf("File successfully written: %s", filePath)), nil
result := fmt.Sprintf("File successfully written: %s", filePath)
result = fmt.Sprintf("<result>\n%s\n</result>", result)
result += appendDiagnostics(filePath, w.lspClients)
return NewTextResponse(result), nil
}
func writeDescription() string {
@@ -156,6 +164,8 @@ TIPS:
- Always include descriptive comments when making changes to existing code`
}
func NewWriteTool() BaseTool {
return &writeTool{}
func NewWriteTool(lspClients map[string]*lsp.Client) BaseTool {
return &writeTool{
lspClients,
}
}

View File

@@ -8,13 +8,14 @@ import (
"testing"
"time"
"github.com/kujtimiihoxha/termai/internal/lsp"
"github.com/kujtimiihoxha/termai/internal/permission"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestWriteTool_Info(t *testing.T) {
tool := NewWriteTool()
tool := NewWriteTool(make(map[string]*lsp.Client))
info := tool.Info()
assert.Equal(t, WriteToolName, info.Name)
@@ -40,11 +41,11 @@ func TestWriteTool_Run(t *testing.T) {
t.Run("creates a new file successfully", func(t *testing.T) {
permission.Default = newMockPermissionService(true)
tool := NewWriteTool()
tool := NewWriteTool(make(map[string]*lsp.Client))
filePath := filepath.Join(tempDir, "new_file.txt")
content := "This is a test content"
params := WriteParams{
FilePath: filePath,
Content: content,
@@ -70,11 +71,11 @@ func TestWriteTool_Run(t *testing.T) {
t.Run("creates file with nested directories", func(t *testing.T) {
permission.Default = newMockPermissionService(true)
tool := NewWriteTool()
tool := NewWriteTool(make(map[string]*lsp.Client))
filePath := filepath.Join(tempDir, "nested/dirs/new_file.txt")
content := "Content in nested directory"
params := WriteParams{
FilePath: filePath,
Content: content,
@@ -100,17 +101,17 @@ func TestWriteTool_Run(t *testing.T) {
t.Run("updates existing file", func(t *testing.T) {
permission.Default = newMockPermissionService(true)
tool := NewWriteTool()
tool := NewWriteTool(make(map[string]*lsp.Client))
// Create a file first
filePath := filepath.Join(tempDir, "existing_file.txt")
initialContent := "Initial content"
err := os.WriteFile(filePath, []byte(initialContent), 0644)
err := os.WriteFile(filePath, []byte(initialContent), 0o644)
require.NoError(t, err)
// Record the file read to avoid modification time check failure
recordFileRead(filePath)
// Update the file
updatedContent := "Updated content"
params := WriteParams{
@@ -138,8 +139,8 @@ func TestWriteTool_Run(t *testing.T) {
t.Run("handles invalid parameters", func(t *testing.T) {
permission.Default = newMockPermissionService(true)
tool := NewWriteTool()
tool := NewWriteTool(make(map[string]*lsp.Client))
call := ToolCall{
Name: WriteToolName,
Input: "invalid json",
@@ -152,8 +153,8 @@ func TestWriteTool_Run(t *testing.T) {
t.Run("handles missing file_path", func(t *testing.T) {
permission.Default = newMockPermissionService(true)
tool := NewWriteTool()
tool := NewWriteTool(make(map[string]*lsp.Client))
params := WriteParams{
FilePath: "",
Content: "Some content",
@@ -174,8 +175,8 @@ func TestWriteTool_Run(t *testing.T) {
t.Run("handles missing content", func(t *testing.T) {
permission.Default = newMockPermissionService(true)
tool := NewWriteTool()
tool := NewWriteTool(make(map[string]*lsp.Client))
params := WriteParams{
FilePath: filepath.Join(tempDir, "file.txt"),
Content: "",
@@ -196,13 +197,13 @@ func TestWriteTool_Run(t *testing.T) {
t.Run("handles writing to a directory path", func(t *testing.T) {
permission.Default = newMockPermissionService(true)
tool := NewWriteTool()
tool := NewWriteTool(make(map[string]*lsp.Client))
// Create a directory
dirPath := filepath.Join(tempDir, "test_dir")
err := os.Mkdir(dirPath, 0755)
err := os.Mkdir(dirPath, 0o755)
require.NoError(t, err)
params := WriteParams{
FilePath: dirPath,
Content: "Some content",
@@ -223,8 +224,8 @@ func TestWriteTool_Run(t *testing.T) {
t.Run("handles permission denied", func(t *testing.T) {
permission.Default = newMockPermissionService(false)
tool := NewWriteTool()
tool := NewWriteTool(make(map[string]*lsp.Client))
filePath := filepath.Join(tempDir, "permission_denied.txt")
params := WriteParams{
FilePath: filePath,
@@ -242,7 +243,7 @@ func TestWriteTool_Run(t *testing.T) {
response, err := tool.Run(context.Background(), call)
require.NoError(t, err)
assert.Contains(t, response.Content, "Permission denied")
// Verify file was not created
_, err = os.Stat(filePath)
assert.True(t, os.IsNotExist(err))
@@ -250,14 +251,14 @@ func TestWriteTool_Run(t *testing.T) {
t.Run("detects file modified since last read", func(t *testing.T) {
permission.Default = newMockPermissionService(true)
tool := NewWriteTool()
tool := NewWriteTool(make(map[string]*lsp.Client))
// Create a file
filePath := filepath.Join(tempDir, "modified_file.txt")
initialContent := "Initial content"
err := os.WriteFile(filePath, []byte(initialContent), 0644)
err := os.WriteFile(filePath, []byte(initialContent), 0o644)
require.NoError(t, err)
// Record an old read time
fileRecordMutex.Lock()
fileRecords[filePath] = fileRecord{
@@ -265,7 +266,7 @@ func TestWriteTool_Run(t *testing.T) {
readTime: time.Now().Add(-1 * time.Hour),
}
fileRecordMutex.Unlock()
// Try to update the file
params := WriteParams{
FilePath: filePath,
@@ -283,7 +284,7 @@ func TestWriteTool_Run(t *testing.T) {
response, err := tool.Run(context.Background(), call)
require.NoError(t, err)
assert.Contains(t, response.Content, "has been modified since it was last read")
// Verify file was not modified
fileContent, err := os.ReadFile(filePath)
require.NoError(t, err)
@@ -292,17 +293,17 @@ func TestWriteTool_Run(t *testing.T) {
t.Run("skips writing when content is identical", func(t *testing.T) {
permission.Default = newMockPermissionService(true)
tool := NewWriteTool()
tool := NewWriteTool(make(map[string]*lsp.Client))
// Create a file
filePath := filepath.Join(tempDir, "identical_content.txt")
content := "Content that won't change"
err := os.WriteFile(filePath, []byte(content), 0644)
err := os.WriteFile(filePath, []byte(content), 0o644)
require.NoError(t, err)
// Record a read time
recordFileRead(filePath)
// Try to write the same content
params := WriteParams{
FilePath: filePath,
@@ -321,4 +322,5 @@ func TestWriteTool_Run(t *testing.T) {
require.NoError(t, err)
assert.Contains(t, response.Content, "already contains the exact content")
})
}
}