mirror of
https://github.com/aljazceru/opencode.git
synced 2026-01-07 18:04:54 +01:00
add initial lsp support
This commit is contained in:
@@ -91,7 +91,7 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes
|
||||
if err != nil {
|
||||
return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil
|
||||
}
|
||||
return tools.NewTextResponse(response.Content), nil
|
||||
return tools.NewTextResponse(response.Content().String()), nil
|
||||
}
|
||||
|
||||
func NewAgentTool(parentSessionID string, app *app.App) tools.BaseTool {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/kujtimiihoxha/termai/internal/app"
|
||||
@@ -33,8 +34,12 @@ func (c *agent) handleTitleGeneration(sessionID, content string) {
|
||||
c.Context,
|
||||
[]message.Message{
|
||||
{
|
||||
Role: message.User,
|
||||
Content: content,
|
||||
Role: message.User,
|
||||
Parts: []message.ContentPart{
|
||||
message.TextContent{
|
||||
Text: content,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
nil,
|
||||
@@ -49,6 +54,8 @@ func (c *agent) handleTitleGeneration(sessionID, content string) {
|
||||
}
|
||||
if response.Content != "" {
|
||||
session.Title = response.Content
|
||||
session.Title = strings.TrimSpace(session.Title)
|
||||
session.Title = strings.ReplaceAll(session.Title, "\n", " ")
|
||||
c.Sessions.Save(session)
|
||||
}
|
||||
}
|
||||
@@ -79,17 +86,18 @@ func (c *agent) processEvent(
|
||||
) error {
|
||||
switch event.Type {
|
||||
case provider.EventThinkingDelta:
|
||||
assistantMsg.Thinking += event.Thinking
|
||||
assistantMsg.AppendReasoningContent(event.Content)
|
||||
return c.Messages.Update(*assistantMsg)
|
||||
case provider.EventContentDelta:
|
||||
assistantMsg.Content += event.Content
|
||||
assistantMsg.AppendContent(event.Content)
|
||||
return c.Messages.Update(*assistantMsg)
|
||||
case provider.EventError:
|
||||
log.Println("error", event.Error)
|
||||
return event.Error
|
||||
|
||||
case provider.EventComplete:
|
||||
assistantMsg.ToolCalls = event.Response.ToolCalls
|
||||
assistantMsg.SetToolCalls(event.Response.ToolCalls)
|
||||
assistantMsg.AddFinish(event.Response.FinishReason)
|
||||
err := c.Messages.Update(*assistantMsg)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -157,18 +165,21 @@ func (c *agent) handleToolExecution(
|
||||
ctx context.Context,
|
||||
assistantMsg message.Message,
|
||||
) (*message.Message, error) {
|
||||
if len(assistantMsg.ToolCalls) == 0 {
|
||||
if len(assistantMsg.ToolCalls()) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
toolResults, err := c.ExecuteTools(ctx, assistantMsg.ToolCalls, c.tools)
|
||||
toolResults, err := c.ExecuteTools(ctx, assistantMsg.ToolCalls(), c.tools)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parts := make([]message.ContentPart, 0)
|
||||
for _, toolResult := range toolResults {
|
||||
parts = append(parts, toolResult)
|
||||
}
|
||||
msg, err := c.Messages.Create(assistantMsg.SessionID, message.CreateMessageParams{
|
||||
Role: message.Tool,
|
||||
ToolResults: toolResults,
|
||||
Role: message.Tool,
|
||||
Parts: parts,
|
||||
})
|
||||
|
||||
return &msg, err
|
||||
@@ -185,8 +196,12 @@ func (c *agent) generate(sessionID string, content string) error {
|
||||
}
|
||||
|
||||
userMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
|
||||
Role: message.User,
|
||||
Content: content,
|
||||
Role: message.User,
|
||||
Parts: []message.ContentPart{
|
||||
message.TextContent{
|
||||
Text: content,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -201,8 +216,8 @@ func (c *agent) generate(sessionID string, content string) error {
|
||||
}
|
||||
|
||||
assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
|
||||
Role: message.Assistant,
|
||||
Content: "",
|
||||
Role: message.Assistant,
|
||||
Parts: []message.ContentPart{},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -210,20 +225,20 @@ func (c *agent) generate(sessionID string, content string) error {
|
||||
for event := range eventChan {
|
||||
err = c.processEvent(sessionID, &assistantMsg, event)
|
||||
if err != nil {
|
||||
assistantMsg.Finished = true
|
||||
assistantMsg.AddFinish("error:" + err.Error())
|
||||
c.Messages.Update(assistantMsg)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
msg, err := c.handleToolExecution(c.Context, assistantMsg)
|
||||
assistantMsg.Finished = true
|
||||
|
||||
c.Messages.Update(assistantMsg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(assistantMsg.ToolCalls) == 0 {
|
||||
if len(assistantMsg.ToolCalls()) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
|
||||
@@ -44,20 +44,23 @@ func NewCoderAgent(app *app.App) (Agent, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mcpTools := GetMcpTools(app.Context)
|
||||
otherTools := GetMcpTools(app.Context)
|
||||
if len(app.LSPClients) > 0 {
|
||||
otherTools = append(otherTools, tools.NewDiagnosticsTool(app.LSPClients))
|
||||
}
|
||||
return &coderAgent{
|
||||
agent: &agent{
|
||||
App: app,
|
||||
tools: append(
|
||||
[]tools.BaseTool{
|
||||
tools.NewBashTool(),
|
||||
tools.NewEditTool(),
|
||||
tools.NewEditTool(app.LSPClients),
|
||||
tools.NewGlobTool(),
|
||||
tools.NewGrepTool(),
|
||||
tools.NewLsTool(),
|
||||
tools.NewViewTool(),
|
||||
tools.NewWriteTool(),
|
||||
}, mcpTools...,
|
||||
tools.NewViewTool(app.LSPClients),
|
||||
tools.NewWriteTool(app.LSPClients),
|
||||
}, otherTools...,
|
||||
),
|
||||
model: model,
|
||||
agent: agentProvider,
|
||||
|
||||
@@ -34,7 +34,7 @@ func NewTaskAgent(app *app.App) (Agent, error) {
|
||||
tools.NewGlobTool(),
|
||||
tools.NewGrepTool(),
|
||||
tools.NewLsTool(),
|
||||
tools.NewViewTool(),
|
||||
tools.NewViewTool(app.LSPClients),
|
||||
},
|
||||
model: model,
|
||||
agent: agentProvider,
|
||||
|
||||
@@ -67,7 +67,7 @@ Never commit changes unless the user explicitly asks you to.`
|
||||
|
||||
envInfo := getEnvironmentInfo()
|
||||
|
||||
return fmt.Sprintf("%s\n\n%s", basePrompt, envInfo)
|
||||
return fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation())
|
||||
}
|
||||
|
||||
func CoderAnthropicSystemPrompt() string {
|
||||
@@ -168,7 +168,7 @@ You MUST answer concisely with fewer than 4 lines of text (not including tool us
|
||||
|
||||
envInfo := getEnvironmentInfo()
|
||||
|
||||
return fmt.Sprintf("%s\n\n%s", basePrompt, envInfo)
|
||||
return fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation())
|
||||
}
|
||||
|
||||
func getEnvironmentInfo() string {
|
||||
@@ -198,6 +198,25 @@ func isGitRepo(dir string) bool {
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func lspInformation() string {
|
||||
cfg := config.Get()
|
||||
hasLSP := false
|
||||
for _, v := range cfg.LSP {
|
||||
if !v.Disabled {
|
||||
hasLSP = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasLSP {
|
||||
return ""
|
||||
}
|
||||
return `# LSP Information
|
||||
Tools that support it will also include useful diagnostics such as linting and typechecking.
|
||||
These diagnostics will be automatically enabled when you run the tool, and will be displayed in the output at the bottom within the <file_diagnostics></file_diagnostics> and <project_diagnostics></project_diagnostics> tags.
|
||||
Take necessary actions to fix the issues.
|
||||
`
|
||||
}
|
||||
|
||||
func boolToYesNo(b bool) string {
|
||||
if b {
|
||||
return "Yes"
|
||||
|
||||
@@ -111,7 +111,7 @@ func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []messa
|
||||
var thinkingParam anthropic.ThinkingConfigParamUnion
|
||||
lastMessage := messages[len(messages)-1]
|
||||
temperature := anthropic.Float(0)
|
||||
if lastMessage.Role == message.User && strings.Contains(strings.ToLower(lastMessage.Content), "think") {
|
||||
if lastMessage.Role == message.User && strings.Contains(strings.ToLower(lastMessage.Content().String()), "think") {
|
||||
thinkingParam = anthropic.ThinkingConfigParamUnion{
|
||||
OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{
|
||||
BudgetTokens: int64(float64(a.maxTokens) * 0.8),
|
||||
@@ -187,9 +187,10 @@ func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []messa
|
||||
eventChan <- ProviderEvent{
|
||||
Type: EventComplete,
|
||||
Response: &ProviderResponse{
|
||||
Content: content,
|
||||
ToolCalls: toolCalls,
|
||||
Usage: tokenUsage,
|
||||
Content: content,
|
||||
ToolCalls: toolCalls,
|
||||
Usage: tokenUsage,
|
||||
FinishReason: string(accumulatedMessage.StopReason),
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -263,7 +264,7 @@ func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Messag
|
||||
for i, msg := range messages {
|
||||
switch msg.Role {
|
||||
case message.User:
|
||||
content := anthropic.NewTextBlock(msg.Content)
|
||||
content := anthropic.NewTextBlock(msg.Content().String())
|
||||
if cachedBlocks < 2 {
|
||||
content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
|
||||
Type: "ephemeral",
|
||||
@@ -274,8 +275,8 @@ func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Messag
|
||||
|
||||
case message.Assistant:
|
||||
blocks := []anthropic.ContentBlockParamUnion{}
|
||||
if msg.Content != "" {
|
||||
content := anthropic.NewTextBlock(msg.Content)
|
||||
if msg.Content().String() != "" {
|
||||
content := anthropic.NewTextBlock(msg.Content().String())
|
||||
if cachedBlocks < 2 {
|
||||
content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
|
||||
Type: "ephemeral",
|
||||
@@ -285,7 +286,7 @@ func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Messag
|
||||
blocks = append(blocks, content)
|
||||
}
|
||||
|
||||
for _, toolCall := range msg.ToolCalls {
|
||||
for _, toolCall := range msg.ToolCalls() {
|
||||
var inputMap map[string]any
|
||||
err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
|
||||
if err != nil {
|
||||
@@ -297,8 +298,8 @@ func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Messag
|
||||
anthropicMessages[i] = anthropic.NewAssistantMessage(blocks...)
|
||||
|
||||
case message.Tool:
|
||||
results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults))
|
||||
for i, toolResult := range msg.ToolResults {
|
||||
results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults()))
|
||||
for i, toolResult := range msg.ToolResults() {
|
||||
results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
|
||||
}
|
||||
anthropicMessages[i] = anthropic.NewUserMessage(results...)
|
||||
|
||||
@@ -78,7 +78,6 @@ func (p *geminiProvider) Close() {
|
||||
}
|
||||
}
|
||||
|
||||
// convertToGeminiHistory converts the message history to Gemini's format
|
||||
func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*genai.Content {
|
||||
var history []*genai.Content
|
||||
|
||||
@@ -86,7 +85,7 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g
|
||||
switch msg.Role {
|
||||
case message.User:
|
||||
history = append(history, &genai.Content{
|
||||
Parts: []genai.Part{genai.Text(msg.Content)},
|
||||
Parts: []genai.Part{genai.Text(msg.Content().String())},
|
||||
Role: "user",
|
||||
})
|
||||
case message.Assistant:
|
||||
@@ -95,14 +94,12 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g
|
||||
Parts: []genai.Part{},
|
||||
}
|
||||
|
||||
// Handle regular content
|
||||
if msg.Content != "" {
|
||||
content.Parts = append(content.Parts, genai.Text(msg.Content))
|
||||
if msg.Content().String() != "" {
|
||||
content.Parts = append(content.Parts, genai.Text(msg.Content().String()))
|
||||
}
|
||||
|
||||
// Handle tool calls if any
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
for _, call := range msg.ToolCalls {
|
||||
if len(msg.ToolCalls()) > 0 {
|
||||
for _, call := range msg.ToolCalls() {
|
||||
args, _ := parseJsonToMap(call.Input)
|
||||
content.Parts = append(content.Parts, genai.FunctionCall{
|
||||
Name: call.Name,
|
||||
@@ -113,8 +110,7 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g
|
||||
|
||||
history = append(history, content)
|
||||
case message.Tool:
|
||||
for _, result := range msg.ToolResults {
|
||||
// Parse response content to map if possible
|
||||
for _, result := range msg.ToolResults() {
|
||||
response := map[string]interface{}{"result": result.Content}
|
||||
parsed, err := parseJsonToMap(result.Content)
|
||||
if err == nil {
|
||||
@@ -123,7 +119,7 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g
|
||||
var toolCall message.ToolCall
|
||||
for _, msg := range messages {
|
||||
if msg.Role == message.Assistant {
|
||||
for _, call := range msg.ToolCalls {
|
||||
for _, call := range msg.ToolCalls() {
|
||||
if call.ID == result.ToolCallID {
|
||||
toolCall = call
|
||||
break
|
||||
@@ -146,108 +142,6 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g
|
||||
return history
|
||||
}
|
||||
|
||||
// convertToolsToGeminiFunctionDeclarations converts tool definitions to Gemini's function declarations
|
||||
func (p *geminiProvider) convertToolsToGeminiFunctionDeclarations(tools []tools.BaseTool) []*genai.FunctionDeclaration {
|
||||
declarations := make([]*genai.FunctionDeclaration, len(tools))
|
||||
|
||||
for i, tool := range tools {
|
||||
info := tool.Info()
|
||||
|
||||
// Convert parameters to genai.Schema format
|
||||
properties := make(map[string]*genai.Schema)
|
||||
for name, param := range info.Parameters {
|
||||
// Try to extract type and description from the parameter
|
||||
paramMap, ok := param.(map[string]interface{})
|
||||
if !ok {
|
||||
// Default to string if unable to determine type
|
||||
properties[name] = &genai.Schema{Type: genai.TypeString}
|
||||
continue
|
||||
}
|
||||
|
||||
schemaType := genai.TypeString // Default
|
||||
var description string
|
||||
var itemsTypeSchema *genai.Schema
|
||||
if typeVal, found := paramMap["type"]; found {
|
||||
if typeStr, ok := typeVal.(string); ok {
|
||||
switch typeStr {
|
||||
case "string":
|
||||
schemaType = genai.TypeString
|
||||
case "number":
|
||||
schemaType = genai.TypeNumber
|
||||
case "integer":
|
||||
schemaType = genai.TypeInteger
|
||||
case "boolean":
|
||||
schemaType = genai.TypeBoolean
|
||||
case "array":
|
||||
schemaType = genai.TypeArray
|
||||
items, found := paramMap["items"]
|
||||
if found {
|
||||
itemsMap, ok := items.(map[string]interface{})
|
||||
if ok {
|
||||
itemsType, found := itemsMap["type"]
|
||||
if found {
|
||||
itemsTypeStr, ok := itemsType.(string)
|
||||
if ok {
|
||||
switch itemsTypeStr {
|
||||
case "string":
|
||||
itemsTypeSchema = &genai.Schema{
|
||||
Type: genai.TypeString,
|
||||
}
|
||||
case "number":
|
||||
itemsTypeSchema = &genai.Schema{
|
||||
Type: genai.TypeNumber,
|
||||
}
|
||||
case "integer":
|
||||
itemsTypeSchema = &genai.Schema{
|
||||
Type: genai.TypeInteger,
|
||||
}
|
||||
case "boolean":
|
||||
itemsTypeSchema = &genai.Schema{
|
||||
Type: genai.TypeBoolean,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case "object":
|
||||
schemaType = genai.TypeObject
|
||||
if _, found := paramMap["properties"]; !found {
|
||||
continue
|
||||
}
|
||||
// TODO: Add support for other types
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if desc, found := paramMap["description"]; found {
|
||||
if descStr, ok := desc.(string); ok {
|
||||
description = descStr
|
||||
}
|
||||
}
|
||||
|
||||
properties[name] = &genai.Schema{
|
||||
Type: schemaType,
|
||||
Description: description,
|
||||
Items: itemsTypeSchema,
|
||||
}
|
||||
}
|
||||
|
||||
declarations[i] = &genai.FunctionDeclaration{
|
||||
Name: info.Name,
|
||||
Description: info.Description,
|
||||
Parameters: &genai.Schema{
|
||||
Type: genai.TypeObject,
|
||||
Properties: properties,
|
||||
Required: info.Required,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return declarations
|
||||
}
|
||||
|
||||
// extractTokenUsage extracts token usage information from Gemini's response
|
||||
func (p *geminiProvider) extractTokenUsage(resp *genai.GenerateContentResponse) TokenUsage {
|
||||
if resp == nil || resp.UsageMetadata == nil {
|
||||
return TokenUsage{}
|
||||
@@ -261,41 +155,28 @@ func (p *geminiProvider) extractTokenUsage(resp *genai.GenerateContentResponse)
|
||||
}
|
||||
}
|
||||
|
||||
// SendMessages sends a batch of messages to Gemini and returns the response
|
||||
func (p *geminiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
|
||||
// Create a generative model
|
||||
model := p.client.GenerativeModel(p.model.APIModel)
|
||||
model.SetMaxOutputTokens(p.maxTokens)
|
||||
|
||||
// Set system instruction
|
||||
model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage))
|
||||
|
||||
// Set up tools if provided
|
||||
if len(tools) > 0 {
|
||||
declarations := p.convertToolsToGeminiFunctionDeclarations(tools)
|
||||
model.Tools = []*genai.Tool{{FunctionDeclarations: declarations}}
|
||||
}
|
||||
|
||||
// Create chat session and set history
|
||||
chat := model.StartChat()
|
||||
chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message
|
||||
|
||||
// Get the most recent user message
|
||||
var lastUserMsg message.Message
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if messages[i].Role == message.User {
|
||||
lastUserMsg = messages[i]
|
||||
break
|
||||
for _, declaration := range declarations {
|
||||
model.Tools = append(model.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{declaration}})
|
||||
}
|
||||
}
|
||||
|
||||
// Send the message
|
||||
resp, err := chat.SendMessage(ctx, genai.Text(lastUserMsg.Content))
|
||||
chat := model.StartChat()
|
||||
chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message
|
||||
|
||||
lastUserMsg := messages[len(messages)-1]
|
||||
resp, err := chat.SendMessage(ctx, genai.Text(lastUserMsg.Content().String()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Process the response
|
||||
var content string
|
||||
var toolCalls []message.ToolCall
|
||||
|
||||
@@ -317,7 +198,6 @@ func (p *geminiProvider) SendMessages(ctx context.Context, messages []message.Me
|
||||
}
|
||||
}
|
||||
|
||||
// Extract token usage
|
||||
tokenUsage := p.extractTokenUsage(resp)
|
||||
|
||||
return &ProviderResponse{
|
||||
@@ -327,16 +207,12 @@ func (p *geminiProvider) SendMessages(ctx context.Context, messages []message.Me
|
||||
}, nil
|
||||
}
|
||||
|
||||
// StreamResponse streams the response from Gemini
|
||||
func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
|
||||
// Create a generative model
|
||||
model := p.client.GenerativeModel(p.model.APIModel)
|
||||
model.SetMaxOutputTokens(p.maxTokens)
|
||||
|
||||
// Set system instruction
|
||||
model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage))
|
||||
|
||||
// Set up tools if provided
|
||||
if len(tools) > 0 {
|
||||
declarations := p.convertToolsToGeminiFunctionDeclarations(tools)
|
||||
for _, declaration := range declarations {
|
||||
@@ -344,14 +220,12 @@ func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.
|
||||
}
|
||||
}
|
||||
|
||||
// Create chat session and set history
|
||||
chat := model.StartChat()
|
||||
chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message
|
||||
|
||||
lastUserMsg := messages[len(messages)-1]
|
||||
|
||||
// Start streaming
|
||||
iter := chat.SendMessageStream(ctx, genai.Text(lastUserMsg.Content))
|
||||
iter := chat.SendMessageStream(ctx, genai.Text(lastUserMsg.Content().String()))
|
||||
|
||||
eventChan := make(chan ProviderEvent)
|
||||
|
||||
@@ -392,7 +266,6 @@ func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.
|
||||
}
|
||||
currentContent += newText
|
||||
case genai.FunctionCall:
|
||||
// For function calls, we assume they come complete, not streamed in parts
|
||||
id := "call_" + uuid.New().String()
|
||||
args, _ := json.Marshal(p.Args)
|
||||
newCall := message.ToolCall{
|
||||
@@ -402,7 +275,6 @@ func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.
|
||||
Type: "function",
|
||||
}
|
||||
|
||||
// Check if this is a new tool call
|
||||
isNew := true
|
||||
for _, existing := range toolCalls {
|
||||
if existing.Name == newCall.Name && existing.Input == newCall.Input {
|
||||
@@ -419,15 +291,15 @@ func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.
|
||||
}
|
||||
}
|
||||
|
||||
// Extract token usage from the final response
|
||||
tokenUsage := p.extractTokenUsage(finalResp)
|
||||
|
||||
eventChan <- ProviderEvent{
|
||||
Type: EventComplete,
|
||||
Response: &ProviderResponse{
|
||||
Content: currentContent,
|
||||
ToolCalls: toolCalls,
|
||||
Usage: tokenUsage,
|
||||
Content: currentContent,
|
||||
ToolCalls: toolCalls,
|
||||
Usage: tokenUsage,
|
||||
FinishReason: string(finalResp.Candidates[0].FinishReason.String()),
|
||||
},
|
||||
}
|
||||
}()
|
||||
@@ -435,7 +307,99 @@ func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.
|
||||
return eventChan, nil
|
||||
}
|
||||
|
||||
// Helper function to parse JSON string into map
|
||||
func (p *geminiProvider) convertToolsToGeminiFunctionDeclarations(tools []tools.BaseTool) []*genai.FunctionDeclaration {
|
||||
declarations := make([]*genai.FunctionDeclaration, len(tools))
|
||||
|
||||
for i, tool := range tools {
|
||||
info := tool.Info()
|
||||
declarations[i] = &genai.FunctionDeclaration{
|
||||
Name: info.Name,
|
||||
Description: info.Description,
|
||||
Parameters: &genai.Schema{
|
||||
Type: genai.TypeObject,
|
||||
Properties: convertSchemaProperties(info.Parameters),
|
||||
Required: info.Required,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return declarations
|
||||
}
|
||||
|
||||
func convertSchemaProperties(parameters map[string]interface{}) map[string]*genai.Schema {
|
||||
properties := make(map[string]*genai.Schema)
|
||||
|
||||
for name, param := range parameters {
|
||||
properties[name] = convertToSchema(param)
|
||||
}
|
||||
|
||||
return properties
|
||||
}
|
||||
|
||||
func convertToSchema(param interface{}) *genai.Schema {
|
||||
schema := &genai.Schema{Type: genai.TypeString}
|
||||
|
||||
paramMap, ok := param.(map[string]interface{})
|
||||
if !ok {
|
||||
return schema
|
||||
}
|
||||
|
||||
if desc, ok := paramMap["description"].(string); ok {
|
||||
schema.Description = desc
|
||||
}
|
||||
|
||||
typeVal, hasType := paramMap["type"]
|
||||
if !hasType {
|
||||
return schema
|
||||
}
|
||||
|
||||
typeStr, ok := typeVal.(string)
|
||||
if !ok {
|
||||
return schema
|
||||
}
|
||||
|
||||
schema.Type = mapJSONTypeToGenAI(typeStr)
|
||||
|
||||
switch typeStr {
|
||||
case "array":
|
||||
schema.Items = processArrayItems(paramMap)
|
||||
case "object":
|
||||
if props, ok := paramMap["properties"].(map[string]interface{}); ok {
|
||||
schema.Properties = convertSchemaProperties(props)
|
||||
}
|
||||
}
|
||||
|
||||
return schema
|
||||
}
|
||||
|
||||
func processArrayItems(paramMap map[string]interface{}) *genai.Schema {
|
||||
items, ok := paramMap["items"].(map[string]interface{})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return convertToSchema(items)
|
||||
}
|
||||
|
||||
func mapJSONTypeToGenAI(jsonType string) genai.Type {
|
||||
switch jsonType {
|
||||
case "string":
|
||||
return genai.TypeString
|
||||
case "number":
|
||||
return genai.TypeNumber
|
||||
case "integer":
|
||||
return genai.TypeInteger
|
||||
case "boolean":
|
||||
return genai.TypeBoolean
|
||||
case "array":
|
||||
return genai.TypeArray
|
||||
case "object":
|
||||
return genai.TypeObject
|
||||
default:
|
||||
return genai.TypeString // Default to string for unknown types
|
||||
}
|
||||
}
|
||||
|
||||
func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
|
||||
var result map[string]interface{}
|
||||
err := json.Unmarshal([]byte(jsonStr), &result)
|
||||
|
||||
@@ -84,22 +84,22 @@ func (p *openaiProvider) convertToOpenAIMessages(messages []message.Message) []o
|
||||
for _, msg := range messages {
|
||||
switch msg.Role {
|
||||
case message.User:
|
||||
chatMessages = append(chatMessages, openai.UserMessage(msg.Content))
|
||||
chatMessages = append(chatMessages, openai.UserMessage(msg.Content().String()))
|
||||
|
||||
case message.Assistant:
|
||||
assistantMsg := openai.ChatCompletionAssistantMessageParam{
|
||||
Role: "assistant",
|
||||
}
|
||||
|
||||
if msg.Content != "" {
|
||||
if msg.Content().String() != "" {
|
||||
assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
|
||||
OfString: openai.String(msg.Content),
|
||||
OfString: openai.String(msg.Content().String()),
|
||||
}
|
||||
}
|
||||
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls))
|
||||
for i, call := range msg.ToolCalls {
|
||||
if len(msg.ToolCalls()) > 0 {
|
||||
assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
|
||||
for i, call := range msg.ToolCalls() {
|
||||
assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
|
||||
ID: call.ID,
|
||||
Type: "function",
|
||||
@@ -116,7 +116,7 @@ func (p *openaiProvider) convertToOpenAIMessages(messages []message.Message) []o
|
||||
})
|
||||
|
||||
case message.Tool:
|
||||
for _, result := range msg.ToolResults {
|
||||
for _, result := range msg.ToolResults() {
|
||||
chatMessages = append(chatMessages,
|
||||
openai.ToolMessage(result.Content, result.ToolCallID),
|
||||
)
|
||||
@@ -276,3 +276,4 @@ func (p *openaiProvider) StreamResponse(ctx context.Context, messages []message.
|
||||
|
||||
return eventChan, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -27,9 +27,10 @@ type TokenUsage struct {
|
||||
}
|
||||
|
||||
type ProviderResponse struct {
|
||||
Content string
|
||||
ToolCalls []message.ToolCall
|
||||
Usage TokenUsage
|
||||
Content string
|
||||
ToolCalls []message.ToolCall
|
||||
Usage TokenUsage
|
||||
FinishReason string
|
||||
}
|
||||
|
||||
type ProviderEvent struct {
|
||||
|
||||
229
internal/llm/tools/diagnostics.go
Normal file
229
internal/llm/tools/diagnostics.go
Normal file
@@ -0,0 +1,229 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/kujtimiihoxha/termai/internal/lsp"
|
||||
"github.com/kujtimiihoxha/termai/internal/lsp/protocol"
|
||||
)
|
||||
|
||||
type diagnosticsTool struct {
|
||||
lspClients map[string]*lsp.Client
|
||||
}
|
||||
|
||||
const (
|
||||
DiagnosticsToolName = "diagnostics"
|
||||
)
|
||||
|
||||
type DiagnosticsParams struct {
|
||||
FilePath string `json:"file_path"`
|
||||
}
|
||||
|
||||
func (b *diagnosticsTool) Info() ToolInfo {
|
||||
return ToolInfo{
|
||||
Name: DiagnosticsToolName,
|
||||
Description: "Get diagnostics for a file and/or project.",
|
||||
Parameters: map[string]any{
|
||||
"file_path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The path to the file to get diagnostics for (leave w empty for project diagnostics)",
|
||||
},
|
||||
},
|
||||
Required: []string{},
|
||||
}
|
||||
}
|
||||
|
||||
func (b *diagnosticsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
|
||||
var params DiagnosticsParams
|
||||
if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
|
||||
return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
|
||||
}
|
||||
|
||||
lsps := b.lspClients
|
||||
|
||||
if len(lsps) == 0 {
|
||||
return NewTextErrorResponse("no LSP clients available"), nil
|
||||
}
|
||||
|
||||
if params.FilePath == "" {
|
||||
notifyLspOpenFile(ctx, params.FilePath, lsps)
|
||||
}
|
||||
|
||||
output := appendDiagnostics(params.FilePath, lsps)
|
||||
|
||||
return NewTextResponse(output), nil
|
||||
}
|
||||
|
||||
func notifyLspOpenFile(ctx context.Context, filePath string, lsps map[string]*lsp.Client) {
|
||||
for _, client := range lsps {
|
||||
err := client.OpenFile(ctx, filePath)
|
||||
if err != nil {
|
||||
// Wait for the file to be opened and diagnostics to be received
|
||||
// TODO: see if we can do this in a more efficient way
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func appendDiagnostics(filePath string, lsps map[string]*lsp.Client) string {
|
||||
fileDiagnostics := []string{}
|
||||
projectDiagnostics := []string{}
|
||||
|
||||
// Enhanced format function that includes more diagnostic information
|
||||
formatDiagnostic := func(pth string, diagnostic protocol.Diagnostic, source string) string {
|
||||
// Base components
|
||||
severity := "Info"
|
||||
switch diagnostic.Severity {
|
||||
case protocol.SeverityError:
|
||||
severity = "Error"
|
||||
case protocol.SeverityWarning:
|
||||
severity = "Warn"
|
||||
case protocol.SeverityHint:
|
||||
severity = "Hint"
|
||||
}
|
||||
|
||||
// Location information
|
||||
location := fmt.Sprintf("%s:%d:%d", pth, diagnostic.Range.Start.Line+1, diagnostic.Range.Start.Character+1)
|
||||
|
||||
// Source information (LSP name)
|
||||
sourceInfo := ""
|
||||
if diagnostic.Source != "" {
|
||||
sourceInfo = diagnostic.Source
|
||||
} else if source != "" {
|
||||
sourceInfo = source
|
||||
}
|
||||
|
||||
// Code information
|
||||
codeInfo := ""
|
||||
if diagnostic.Code != nil {
|
||||
codeInfo = fmt.Sprintf("[%v]", diagnostic.Code)
|
||||
}
|
||||
|
||||
// Tags information
|
||||
tagsInfo := ""
|
||||
if len(diagnostic.Tags) > 0 {
|
||||
tags := []string{}
|
||||
for _, tag := range diagnostic.Tags {
|
||||
switch tag {
|
||||
case protocol.Unnecessary:
|
||||
tags = append(tags, "unnecessary")
|
||||
case protocol.Deprecated:
|
||||
tags = append(tags, "deprecated")
|
||||
}
|
||||
}
|
||||
if len(tags) > 0 {
|
||||
tagsInfo = fmt.Sprintf(" (%s)", strings.Join(tags, ", "))
|
||||
}
|
||||
}
|
||||
|
||||
// Assemble the full diagnostic message
|
||||
return fmt.Sprintf("%s: %s [%s]%s%s %s",
|
||||
severity,
|
||||
location,
|
||||
sourceInfo,
|
||||
codeInfo,
|
||||
tagsInfo,
|
||||
diagnostic.Message)
|
||||
}
|
||||
|
||||
for lspName, client := range lsps {
|
||||
diagnostics := client.GetDiagnostics()
|
||||
if len(diagnostics) > 0 {
|
||||
for location, diags := range diagnostics {
|
||||
isCurrentFile := location.Path() == filePath
|
||||
|
||||
// Group diagnostics by severity for better organization
|
||||
for _, diag := range diags {
|
||||
formattedDiag := formatDiagnostic(location.Path(), diag, lspName)
|
||||
|
||||
if isCurrentFile {
|
||||
fileDiagnostics = append(fileDiagnostics, formattedDiag)
|
||||
} else {
|
||||
projectDiagnostics = append(projectDiagnostics, formattedDiag)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort diagnostics by severity (errors first) and then by location
|
||||
sort.Slice(fileDiagnostics, func(i, j int) bool {
|
||||
iIsError := strings.HasPrefix(fileDiagnostics[i], "Error")
|
||||
jIsError := strings.HasPrefix(fileDiagnostics[j], "Error")
|
||||
if iIsError != jIsError {
|
||||
return iIsError // Errors come first
|
||||
}
|
||||
return fileDiagnostics[i] < fileDiagnostics[j] // Then alphabetically
|
||||
})
|
||||
|
||||
sort.Slice(projectDiagnostics, func(i, j int) bool {
|
||||
iIsError := strings.HasPrefix(projectDiagnostics[i], "Error")
|
||||
jIsError := strings.HasPrefix(projectDiagnostics[j], "Error")
|
||||
if iIsError != jIsError {
|
||||
return iIsError
|
||||
}
|
||||
return projectDiagnostics[i] < projectDiagnostics[j]
|
||||
})
|
||||
|
||||
output := ""
|
||||
|
||||
if len(fileDiagnostics) > 0 {
|
||||
output += "\n<file_diagnostics>\n"
|
||||
if len(fileDiagnostics) > 10 {
|
||||
output += strings.Join(fileDiagnostics[:10], "\n")
|
||||
output += fmt.Sprintf("\n... and %d more diagnostics", len(fileDiagnostics)-10)
|
||||
} else {
|
||||
output += strings.Join(fileDiagnostics, "\n")
|
||||
}
|
||||
output += "\n</file_diagnostics>\n"
|
||||
}
|
||||
|
||||
if len(projectDiagnostics) > 0 {
|
||||
output += "\n<project_diagnostics>\n"
|
||||
if len(projectDiagnostics) > 10 {
|
||||
output += strings.Join(projectDiagnostics[:10], "\n")
|
||||
output += fmt.Sprintf("\n... and %d more diagnostics", len(projectDiagnostics)-10)
|
||||
} else {
|
||||
output += strings.Join(projectDiagnostics, "\n")
|
||||
}
|
||||
output += "\n</project_diagnostics>\n"
|
||||
}
|
||||
|
||||
// Add summary counts
|
||||
if len(fileDiagnostics) > 0 || len(projectDiagnostics) > 0 {
|
||||
fileErrors := countSeverity(fileDiagnostics, "Error")
|
||||
fileWarnings := countSeverity(fileDiagnostics, "Warn")
|
||||
projectErrors := countSeverity(projectDiagnostics, "Error")
|
||||
projectWarnings := countSeverity(projectDiagnostics, "Warn")
|
||||
|
||||
output += "\n<diagnostic_summary>\n"
|
||||
output += fmt.Sprintf("Current file: %d errors, %d warnings\n", fileErrors, fileWarnings)
|
||||
output += fmt.Sprintf("Project: %d errors, %d warnings\n", projectErrors, projectWarnings)
|
||||
output += "</diagnostic_summary>\n"
|
||||
}
|
||||
|
||||
return output
|
||||
}
|
||||
|
||||
// Helper function to count diagnostics by severity
|
||||
func countSeverity(diagnostics []string, severity string) int {
|
||||
count := 0
|
||||
for _, diag := range diagnostics {
|
||||
if strings.HasPrefix(diag, severity) {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func NewDiagnosticsTool(lspClients map[string]*lsp.Client) BaseTool {
|
||||
return &diagnosticsTool{
|
||||
lspClients,
|
||||
}
|
||||
}
|
||||
@@ -10,11 +10,14 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/kujtimiihoxha/termai/internal/config"
|
||||
"github.com/kujtimiihoxha/termai/internal/lsp"
|
||||
"github.com/kujtimiihoxha/termai/internal/permission"
|
||||
"github.com/sergi/go-diff/diffmatchpatch"
|
||||
)
|
||||
|
||||
type editTool struct{}
|
||||
type editTool struct {
|
||||
lspClients map[string]*lsp.Client
|
||||
}
|
||||
|
||||
const (
|
||||
EditToolName = "edit"
|
||||
@@ -71,6 +74,7 @@ func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
|
||||
params.FilePath = filepath.Join(wd, params.FilePath)
|
||||
}
|
||||
|
||||
notifyLspOpenFile(ctx, params.FilePath, e.lspClients)
|
||||
if params.OldString == "" {
|
||||
result, err := createNewFile(params.FilePath, params.NewString)
|
||||
if err != nil {
|
||||
@@ -91,6 +95,9 @@ func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
|
||||
if err != nil {
|
||||
return NewTextErrorResponse(fmt.Sprintf("error replacing content: %s", err)), nil
|
||||
}
|
||||
|
||||
result = fmt.Sprintf("<result>\n%s\n</result>\n", result)
|
||||
result += appendDiagnostics(params.FilePath, e.lspClients)
|
||||
return NewTextResponse(result), nil
|
||||
}
|
||||
|
||||
@@ -296,18 +303,18 @@ func GenerateDiff(oldContent, newContent string) string {
|
||||
|
||||
switch diff.Type {
|
||||
case diffmatchpatch.DiffInsert:
|
||||
for _, line := range strings.Split(text, "\n") {
|
||||
for line := range strings.SplitSeq(text, "\n") {
|
||||
_, _ = buff.WriteString("+ " + line + "\n")
|
||||
}
|
||||
case diffmatchpatch.DiffDelete:
|
||||
for _, line := range strings.Split(text, "\n") {
|
||||
for line := range strings.SplitSeq(text, "\n") {
|
||||
_, _ = buff.WriteString("- " + line + "\n")
|
||||
}
|
||||
case diffmatchpatch.DiffEqual:
|
||||
if len(text) > 40 {
|
||||
_, _ = buff.WriteString(" " + text[:20] + "..." + text[len(text)-20:] + "\n")
|
||||
} else {
|
||||
for _, line := range strings.Split(text, "\n") {
|
||||
for line := range strings.SplitSeq(text, "\n") {
|
||||
_, _ = buff.WriteString(" " + line + "\n")
|
||||
}
|
||||
}
|
||||
@@ -366,6 +373,8 @@ When making edits:
|
||||
Remember: when making multiple file edits in a row to the same file, you should prefer to send all edits in a single message with multiple calls to this tool, rather than multiple messages with a single call each.`
|
||||
}
|
||||
|
||||
func NewEditTool() BaseTool {
|
||||
return &editTool{}
|
||||
func NewEditTool(lspClients map[string]*lsp.Client) BaseTool {
|
||||
return &editTool{
|
||||
lspClients,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -221,7 +221,7 @@ func (s *PersistentShell) killChildren() {
|
||||
return
|
||||
}
|
||||
|
||||
for _, pidStr := range strings.Split(string(output), "\n") {
|
||||
for pidStr := range strings.SplitSeq(string(output), "\n") {
|
||||
if pidStr = strings.TrimSpace(pidStr); pidStr != "" {
|
||||
var pid int
|
||||
fmt.Sscanf(pidStr, "%d", &pid)
|
||||
|
||||
@@ -11,9 +11,12 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/kujtimiihoxha/termai/internal/config"
|
||||
"github.com/kujtimiihoxha/termai/internal/lsp"
|
||||
)
|
||||
|
||||
type viewTool struct{}
|
||||
type viewTool struct {
|
||||
lspClients map[string]*lsp.Client
|
||||
}
|
||||
|
||||
const (
|
||||
ViewToolName = "view"
|
||||
@@ -127,15 +130,18 @@ func (v *viewTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
|
||||
return NewTextErrorResponse(fmt.Sprintf("Failed to read file: %s", err)), nil
|
||||
}
|
||||
|
||||
notifyLspOpenFile(ctx, filePath, v.lspClients)
|
||||
output := "<file>\n"
|
||||
// Format the output with line numbers
|
||||
output := addLineNumbers(content, params.Offset+1)
|
||||
output += addLineNumbers(content, params.Offset+1)
|
||||
|
||||
// Add a note if the content was truncated
|
||||
if lineCount > params.Offset+len(strings.Split(content, "\n")) {
|
||||
output += fmt.Sprintf("\n\n(File has more lines. Use 'offset' parameter to read beyond line %d)",
|
||||
params.Offset+len(strings.Split(content, "\n")))
|
||||
}
|
||||
|
||||
output += "\n</file>\n"
|
||||
output += appendDiagnostics(filePath, v.lspClients)
|
||||
recordFileRead(filePath)
|
||||
return NewTextResponse(output), nil
|
||||
}
|
||||
@@ -155,10 +161,10 @@ func addLineNumbers(content string, startLine int) string {
|
||||
numStr := fmt.Sprintf("%d", lineNum)
|
||||
|
||||
if len(numStr) >= 6 {
|
||||
result = append(result, fmt.Sprintf("%s\t%s", numStr, line))
|
||||
result = append(result, fmt.Sprintf("%s|%s", numStr, line))
|
||||
} else {
|
||||
paddedNum := fmt.Sprintf("%6s", numStr)
|
||||
result = append(result, fmt.Sprintf("%s\t|%s", paddedNum, line))
|
||||
result = append(result, fmt.Sprintf("%s|%s", paddedNum, line))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -173,8 +179,9 @@ func readTextFile(filePath string, offset, limit int) (string, int, error) {
|
||||
defer file.Close()
|
||||
|
||||
lineCount := 0
|
||||
|
||||
scanner := NewLineScanner(file)
|
||||
if offset > 0 {
|
||||
scanner := NewLineScanner(file)
|
||||
for lineCount < offset && scanner.Scan() {
|
||||
lineCount++
|
||||
}
|
||||
@@ -192,7 +199,6 @@ func readTextFile(filePath string, offset, limit int) (string, int, error) {
|
||||
|
||||
var lines []string
|
||||
lineCount = offset
|
||||
scanner := NewLineScanner(file)
|
||||
|
||||
for scanner.Scan() && len(lines) < limit {
|
||||
lineCount++
|
||||
@@ -290,6 +296,8 @@ TIPS:
|
||||
- When viewing large files, use the offset parameter to read specific sections`
|
||||
}
|
||||
|
||||
func NewViewTool() BaseTool {
|
||||
return &viewTool{}
|
||||
func NewViewTool(lspClients map[string]*lsp.Client) BaseTool {
|
||||
return &viewTool{
|
||||
lspClients,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,10 +9,13 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/kujtimiihoxha/termai/internal/config"
|
||||
"github.com/kujtimiihoxha/termai/internal/lsp"
|
||||
"github.com/kujtimiihoxha/termai/internal/permission"
|
||||
)
|
||||
|
||||
type writeTool struct{}
|
||||
type writeTool struct {
|
||||
lspClients map[string]*lsp.Client
|
||||
}
|
||||
|
||||
const (
|
||||
WriteToolName = "write"
|
||||
@@ -96,6 +99,8 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
|
||||
if err = os.MkdirAll(dir, 0o755); err != nil {
|
||||
return NewTextErrorResponse(fmt.Sprintf("Failed to create parent directories: %s", err)), nil
|
||||
}
|
||||
|
||||
notifyLspOpenFile(ctx, filePath, w.lspClients)
|
||||
p := permission.Default.Request(
|
||||
permission.CreatePermissionRequest{
|
||||
Path: filePath,
|
||||
@@ -122,7 +127,10 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
|
||||
recordFileWrite(filePath)
|
||||
recordFileRead(filePath)
|
||||
|
||||
return NewTextResponse(fmt.Sprintf("File successfully written: %s", filePath)), nil
|
||||
result := fmt.Sprintf("File successfully written: %s", filePath)
|
||||
result = fmt.Sprintf("<result>\n%s\n</result>", result)
|
||||
result += appendDiagnostics(filePath, w.lspClients)
|
||||
return NewTextResponse(result), nil
|
||||
}
|
||||
|
||||
func writeDescription() string {
|
||||
@@ -156,6 +164,8 @@ TIPS:
|
||||
- Always include descriptive comments when making changes to existing code`
|
||||
}
|
||||
|
||||
func NewWriteTool() BaseTool {
|
||||
return &writeTool{}
|
||||
func NewWriteTool(lspClients map[string]*lsp.Client) BaseTool {
|
||||
return &writeTool{
|
||||
lspClients,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,13 +8,14 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/kujtimiihoxha/termai/internal/lsp"
|
||||
"github.com/kujtimiihoxha/termai/internal/permission"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestWriteTool_Info(t *testing.T) {
|
||||
tool := NewWriteTool()
|
||||
tool := NewWriteTool(make(map[string]*lsp.Client))
|
||||
info := tool.Info()
|
||||
|
||||
assert.Equal(t, WriteToolName, info.Name)
|
||||
@@ -40,11 +41,11 @@ func TestWriteTool_Run(t *testing.T) {
|
||||
|
||||
t.Run("creates a new file successfully", func(t *testing.T) {
|
||||
permission.Default = newMockPermissionService(true)
|
||||
tool := NewWriteTool()
|
||||
|
||||
tool := NewWriteTool(make(map[string]*lsp.Client))
|
||||
|
||||
filePath := filepath.Join(tempDir, "new_file.txt")
|
||||
content := "This is a test content"
|
||||
|
||||
|
||||
params := WriteParams{
|
||||
FilePath: filePath,
|
||||
Content: content,
|
||||
@@ -70,11 +71,11 @@ func TestWriteTool_Run(t *testing.T) {
|
||||
|
||||
t.Run("creates file with nested directories", func(t *testing.T) {
|
||||
permission.Default = newMockPermissionService(true)
|
||||
tool := NewWriteTool()
|
||||
|
||||
tool := NewWriteTool(make(map[string]*lsp.Client))
|
||||
|
||||
filePath := filepath.Join(tempDir, "nested/dirs/new_file.txt")
|
||||
content := "Content in nested directory"
|
||||
|
||||
|
||||
params := WriteParams{
|
||||
FilePath: filePath,
|
||||
Content: content,
|
||||
@@ -100,17 +101,17 @@ func TestWriteTool_Run(t *testing.T) {
|
||||
|
||||
t.Run("updates existing file", func(t *testing.T) {
|
||||
permission.Default = newMockPermissionService(true)
|
||||
tool := NewWriteTool()
|
||||
|
||||
tool := NewWriteTool(make(map[string]*lsp.Client))
|
||||
|
||||
// Create a file first
|
||||
filePath := filepath.Join(tempDir, "existing_file.txt")
|
||||
initialContent := "Initial content"
|
||||
err := os.WriteFile(filePath, []byte(initialContent), 0644)
|
||||
err := os.WriteFile(filePath, []byte(initialContent), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
// Record the file read to avoid modification time check failure
|
||||
recordFileRead(filePath)
|
||||
|
||||
|
||||
// Update the file
|
||||
updatedContent := "Updated content"
|
||||
params := WriteParams{
|
||||
@@ -138,8 +139,8 @@ func TestWriteTool_Run(t *testing.T) {
|
||||
|
||||
t.Run("handles invalid parameters", func(t *testing.T) {
|
||||
permission.Default = newMockPermissionService(true)
|
||||
tool := NewWriteTool()
|
||||
|
||||
tool := NewWriteTool(make(map[string]*lsp.Client))
|
||||
|
||||
call := ToolCall{
|
||||
Name: WriteToolName,
|
||||
Input: "invalid json",
|
||||
@@ -152,8 +153,8 @@ func TestWriteTool_Run(t *testing.T) {
|
||||
|
||||
t.Run("handles missing file_path", func(t *testing.T) {
|
||||
permission.Default = newMockPermissionService(true)
|
||||
tool := NewWriteTool()
|
||||
|
||||
tool := NewWriteTool(make(map[string]*lsp.Client))
|
||||
|
||||
params := WriteParams{
|
||||
FilePath: "",
|
||||
Content: "Some content",
|
||||
@@ -174,8 +175,8 @@ func TestWriteTool_Run(t *testing.T) {
|
||||
|
||||
t.Run("handles missing content", func(t *testing.T) {
|
||||
permission.Default = newMockPermissionService(true)
|
||||
tool := NewWriteTool()
|
||||
|
||||
tool := NewWriteTool(make(map[string]*lsp.Client))
|
||||
|
||||
params := WriteParams{
|
||||
FilePath: filepath.Join(tempDir, "file.txt"),
|
||||
Content: "",
|
||||
@@ -196,13 +197,13 @@ func TestWriteTool_Run(t *testing.T) {
|
||||
|
||||
t.Run("handles writing to a directory path", func(t *testing.T) {
|
||||
permission.Default = newMockPermissionService(true)
|
||||
tool := NewWriteTool()
|
||||
|
||||
tool := NewWriteTool(make(map[string]*lsp.Client))
|
||||
|
||||
// Create a directory
|
||||
dirPath := filepath.Join(tempDir, "test_dir")
|
||||
err := os.Mkdir(dirPath, 0755)
|
||||
err := os.Mkdir(dirPath, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
params := WriteParams{
|
||||
FilePath: dirPath,
|
||||
Content: "Some content",
|
||||
@@ -223,8 +224,8 @@ func TestWriteTool_Run(t *testing.T) {
|
||||
|
||||
t.Run("handles permission denied", func(t *testing.T) {
|
||||
permission.Default = newMockPermissionService(false)
|
||||
tool := NewWriteTool()
|
||||
|
||||
tool := NewWriteTool(make(map[string]*lsp.Client))
|
||||
|
||||
filePath := filepath.Join(tempDir, "permission_denied.txt")
|
||||
params := WriteParams{
|
||||
FilePath: filePath,
|
||||
@@ -242,7 +243,7 @@ func TestWriteTool_Run(t *testing.T) {
|
||||
response, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, response.Content, "Permission denied")
|
||||
|
||||
|
||||
// Verify file was not created
|
||||
_, err = os.Stat(filePath)
|
||||
assert.True(t, os.IsNotExist(err))
|
||||
@@ -250,14 +251,14 @@ func TestWriteTool_Run(t *testing.T) {
|
||||
|
||||
t.Run("detects file modified since last read", func(t *testing.T) {
|
||||
permission.Default = newMockPermissionService(true)
|
||||
tool := NewWriteTool()
|
||||
|
||||
tool := NewWriteTool(make(map[string]*lsp.Client))
|
||||
|
||||
// Create a file
|
||||
filePath := filepath.Join(tempDir, "modified_file.txt")
|
||||
initialContent := "Initial content"
|
||||
err := os.WriteFile(filePath, []byte(initialContent), 0644)
|
||||
err := os.WriteFile(filePath, []byte(initialContent), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
// Record an old read time
|
||||
fileRecordMutex.Lock()
|
||||
fileRecords[filePath] = fileRecord{
|
||||
@@ -265,7 +266,7 @@ func TestWriteTool_Run(t *testing.T) {
|
||||
readTime: time.Now().Add(-1 * time.Hour),
|
||||
}
|
||||
fileRecordMutex.Unlock()
|
||||
|
||||
|
||||
// Try to update the file
|
||||
params := WriteParams{
|
||||
FilePath: filePath,
|
||||
@@ -283,7 +284,7 @@ func TestWriteTool_Run(t *testing.T) {
|
||||
response, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, response.Content, "has been modified since it was last read")
|
||||
|
||||
|
||||
// Verify file was not modified
|
||||
fileContent, err := os.ReadFile(filePath)
|
||||
require.NoError(t, err)
|
||||
@@ -292,17 +293,17 @@ func TestWriteTool_Run(t *testing.T) {
|
||||
|
||||
t.Run("skips writing when content is identical", func(t *testing.T) {
|
||||
permission.Default = newMockPermissionService(true)
|
||||
tool := NewWriteTool()
|
||||
|
||||
tool := NewWriteTool(make(map[string]*lsp.Client))
|
||||
|
||||
// Create a file
|
||||
filePath := filepath.Join(tempDir, "identical_content.txt")
|
||||
content := "Content that won't change"
|
||||
err := os.WriteFile(filePath, []byte(content), 0644)
|
||||
err := os.WriteFile(filePath, []byte(content), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
// Record a read time
|
||||
recordFileRead(filePath)
|
||||
|
||||
|
||||
// Try to write the same content
|
||||
params := WriteParams{
|
||||
FilePath: filePath,
|
||||
@@ -321,4 +322,5 @@ func TestWriteTool_Run(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, response.Content, "already contains the exact content")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user