mirror of https://github.com/aljazceru/opencode.git (synced 2026-01-07 01:44:56 +01:00)

Commit: reimplement agent, provider, and add file history
@@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"

"github.com/kujtimiihoxha/termai/internal/config"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
"github.com/kujtimiihoxha/termai/internal/lsp"
"github.com/kujtimiihoxha/termai/internal/message"
@@ -53,7 +54,7 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes
return tools.ToolResponse{}, fmt.Errorf("session_id and message_id are required")
}

agent, err := NewTaskAgent(b.messages, b.sessions, b.lspClients)
agent, err := NewAgent(config.AgentTask, b.sessions, b.messages, TaskAgentTools(b.lspClients))
if err != nil {
return tools.ToolResponse{}, fmt.Errorf("error creating agent: %s", err)
}
@@ -63,21 +64,16 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes
return tools.ToolResponse{}, fmt.Errorf("error creating session: %s", err)
}

err = agent.Generate(ctx, session.ID, params.Prompt)
done, err := agent.Run(ctx, session.ID, params.Prompt)
if err != nil {
return tools.ToolResponse{}, fmt.Errorf("error generating agent: %s", err)
}

messages, err := b.messages.List(ctx, session.ID)
if err != nil {
return tools.ToolResponse{}, fmt.Errorf("error listing messages: %s", err)
result := <-done
if result.Err() != nil {
return tools.ToolResponse{}, fmt.Errorf("error generating agent: %s", result.Err())
}

if len(messages) == 0 {
return tools.NewTextErrorResponse("no response"), nil
}

response := messages[len(messages)-1]
response := result.Response()
if response.Role != message.Assistant {
return tools.NewTextErrorResponse("no response"), nil
}
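The hunk above swaps the old blocking Generate call plus manual message listing for an asynchronous Run that returns a completion channel. A minimal sketch of the new call pattern, written by the editor for illustration and not part of the commit; it assumes Run's channel yields a result whose Err/Response accessors behave as the diff suggests:

// Editor's sketch of the run-and-wait pattern introduced above.
func runTaskAgent(ctx context.Context, sessions session.Service, messages message.Service, lspClients map[string]*lsp.Client, sessionID, prompt string) (message.Message, error) {
	agent, err := NewAgent(config.AgentTask, sessions, messages, TaskAgentTools(lspClients))
	if err != nil {
		return message.Message{}, err
	}
	done, err := agent.Run(ctx, sessionID, prompt) // returns immediately; work continues in the background
	if err != nil {
		return message.Message{}, err
	}
	result := <-done // block until the agent finishes
	if result.Err() != nil {
		return message.Message{}, result.Err()
	}
	return result.Response(), nil // final assistant message (accessor assumed from the diff)
}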
[File diff suppressed because it is too large]
@@ -1,63 +0,0 @@
package agent

import (
	"context"
	"errors"

	"github.com/kujtimiihoxha/termai/internal/config"
	"github.com/kujtimiihoxha/termai/internal/llm/models"
	"github.com/kujtimiihoxha/termai/internal/llm/tools"
	"github.com/kujtimiihoxha/termai/internal/lsp"
	"github.com/kujtimiihoxha/termai/internal/message"
	"github.com/kujtimiihoxha/termai/internal/permission"
	"github.com/kujtimiihoxha/termai/internal/session"
)

type coderAgent struct {
	Service
}

func NewCoderAgent(
	permissions permission.Service,
	sessions session.Service,
	messages message.Service,
	lspClients map[string]*lsp.Client,
) (Service, error) {
	model, ok := models.SupportedModels[config.Get().Model.Coder]
	if !ok {
		return nil, errors.New("model not supported")
	}

	ctx := context.Background()
	otherTools := GetMcpTools(ctx, permissions)
	if len(lspClients) > 0 {
		otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
	}
	agent, err := NewAgent(
		ctx,
		sessions,
		messages,
		model,
		append(
			[]tools.BaseTool{
				tools.NewBashTool(permissions),
				tools.NewEditTool(lspClients, permissions),
				tools.NewFetchTool(permissions),
				tools.NewGlobTool(),
				tools.NewGrepTool(),
				tools.NewLsTool(),
				tools.NewSourcegraphTool(),
				tools.NewViewTool(lspClients),
				tools.NewWriteTool(lspClients, permissions),
				NewAgentTool(sessions, messages, lspClients),
			}, otherTools...,
		),
	)
	if err != nil {
		return nil, err
	}

	return &coderAgent{
		agent,
	}, nil
}
@@ -46,7 +46,7 @@ func runTool(ctx context.Context, c MCPClient, toolName string, input string) (t
initRequest := mcp.InitializeRequest{}
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
initRequest.Params.ClientInfo = mcp.Implementation{
Name: "termai",
Name: "OpenCode",
Version: version.Version,
}

@@ -135,7 +135,7 @@ func getTools(ctx context.Context, name string, m config.MCPServer, permissions
initRequest := mcp.InitializeRequest{}
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
initRequest.Params.ClientInfo = mcp.Implementation{
Name: "termai",
Name: "OpenCode",
Version: version.Version,
}

@@ -1,47 +0,0 @@
package agent

import (
	"context"
	"errors"

	"github.com/kujtimiihoxha/termai/internal/config"
	"github.com/kujtimiihoxha/termai/internal/llm/models"
	"github.com/kujtimiihoxha/termai/internal/llm/tools"
	"github.com/kujtimiihoxha/termai/internal/lsp"
	"github.com/kujtimiihoxha/termai/internal/message"
	"github.com/kujtimiihoxha/termai/internal/session"
)

type taskAgent struct {
	Service
}

func NewTaskAgent(messages message.Service, sessions session.Service, lspClients map[string]*lsp.Client) (Service, error) {
	model, ok := models.SupportedModels[config.Get().Model.Coder]
	if !ok {
		return nil, errors.New("model not supported")
	}

	ctx := context.Background()

	agent, err := NewAgent(
		ctx,
		sessions,
		messages,
		model,
		[]tools.BaseTool{
			tools.NewGlobTool(),
			tools.NewGrepTool(),
			tools.NewLsTool(),
			tools.NewSourcegraphTool(),
			tools.NewViewTool(lspClients),
		},
	)
	if err != nil {
		return nil, err
	}

	return &taskAgent{
		agent,
	}, nil
}
internal/llm/agent/tools.go (new file, 50 lines)
@@ -0,0 +1,50 @@
package agent

import (
	"context"

	"github.com/kujtimiihoxha/termai/internal/history"
	"github.com/kujtimiihoxha/termai/internal/llm/tools"
	"github.com/kujtimiihoxha/termai/internal/lsp"
	"github.com/kujtimiihoxha/termai/internal/message"
	"github.com/kujtimiihoxha/termai/internal/permission"
	"github.com/kujtimiihoxha/termai/internal/session"
)

func CoderAgentTools(
	permissions permission.Service,
	sessions session.Service,
	messages message.Service,
	history history.Service,
	lspClients map[string]*lsp.Client,
) []tools.BaseTool {
	ctx := context.Background()
	otherTools := GetMcpTools(ctx, permissions)
	if len(lspClients) > 0 {
		otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
	}
	return append(
		[]tools.BaseTool{
			tools.NewBashTool(permissions),
			tools.NewEditTool(lspClients, permissions, history),
			tools.NewFetchTool(permissions),
			tools.NewGlobTool(),
			tools.NewGrepTool(),
			tools.NewLsTool(),
			tools.NewSourcegraphTool(),
			tools.NewViewTool(lspClients),
			tools.NewWriteTool(lspClients, permissions, history),
			NewAgentTool(sessions, messages, lspClients),
		}, otherTools...,
	)
}

func TaskAgentTools(lspClients map[string]*lsp.Client) []tools.BaseTool {
	return []tools.BaseTool{
		tools.NewGlobTool(),
		tools.NewGrepTool(),
		tools.NewLsTool(),
		tools.NewSourcegraphTool(),
		tools.NewViewTool(lspClients),
	}
}
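This new file centralizes per-agent toolsets that the deleted coder-agent and task-agent files used to build inline. An editor's sketch of how the two lists get consumed; the task-agent call is copied from the first hunk, while the coder-agent call is assumed by analogy and is not shown in this commit:

taskAgent, err := NewAgent(config.AgentTask, sessions, messages, TaskAgentTools(lspClients))
// Assumed analogous call for the coder agent (the edit/write tools now also
// take a history.Service, which is what "add file history" refers to):
coderAgent, err := NewAgent(config.AgentCoder, sessions, messages,
	CoderAgentTools(permissions, sessions, messages, historyService, lspClients))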
internal/llm/models/anthropic.go (new file, 71 lines)
@@ -0,0 +1,71 @@
package models

const (
	ProviderAnthropic ModelProvider = "anthropic"

	// Models
	Claude35Sonnet ModelID = "claude-3.5-sonnet"
	Claude3Haiku   ModelID = "claude-3-haiku"
	Claude37Sonnet ModelID = "claude-3.7-sonnet"
	Claude35Haiku  ModelID = "claude-3.5-haiku"
	Claude3Opus    ModelID = "claude-3-opus"
)

var AnthropicModels = map[ModelID]Model{
	// Anthropic
	Claude35Sonnet: {
		ID:                 Claude35Sonnet,
		Name:               "Claude 3.5 Sonnet",
		Provider:           ProviderAnthropic,
		APIModel:           "claude-3-5-sonnet-latest",
		CostPer1MIn:        3.0,
		CostPer1MInCached:  3.75,
		CostPer1MOutCached: 0.30,
		CostPer1MOut:       15.0,
		ContextWindow:      200000,
	},
	Claude3Haiku: {
		ID:                 Claude3Haiku,
		Name:               "Claude 3 Haiku",
		Provider:           ProviderAnthropic,
		APIModel:           "claude-3-haiku-latest",
		CostPer1MIn:        0.25,
		CostPer1MInCached:  0.30,
		CostPer1MOutCached: 0.03,
		CostPer1MOut:       1.25,
		ContextWindow:      200000,
	},
	Claude37Sonnet: {
		ID:                 Claude37Sonnet,
		Name:               "Claude 3.7 Sonnet",
		Provider:           ProviderAnthropic,
		APIModel:           "claude-3-7-sonnet-latest",
		CostPer1MIn:        3.0,
		CostPer1MInCached:  3.75,
		CostPer1MOutCached: 0.30,
		CostPer1MOut:       15.0,
		ContextWindow:      200000,
	},
	Claude35Haiku: {
		ID:                 Claude35Haiku,
		Name:               "Claude 3.5 Haiku",
		Provider:           ProviderAnthropic,
		APIModel:           "claude-3-5-haiku-latest",
		CostPer1MIn:        0.80,
		CostPer1MInCached:  1.0,
		CostPer1MOutCached: 0.08,
		CostPer1MOut:       4.0,
		ContextWindow:      200000,
	},
	Claude3Opus: {
		ID:                 Claude3Opus,
		Name:               "Claude 3 Opus",
		Provider:           ProviderAnthropic,
		APIModel:           "claude-3-opus-latest",
		CostPer1MIn:        15.0,
		CostPer1MInCached:  18.75,
		CostPer1MOutCached: 1.50,
		CostPer1MOut:       75.0,
		ContextWindow:      200000,
	},
}
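The four cost fields are dollar rates per million tokens, so pricing a request is a small dot product over token counts. A hypothetical helper (editor's illustration, not part of the commit) showing how these fields combine:

// Editor's sketch: estimate the dollar cost of one request from a Model's
// per-1M-token rates. Field names are from the Model struct in this diff.
func estimateCost(m Model, in, out, cacheWrite, cacheRead int64) float64 {
	perTok := func(rate float64, n int64) float64 { return rate * float64(n) / 1_000_000 }
	return perTok(m.CostPer1MIn, in) +
		perTok(m.CostPer1MOut, out) +
		perTok(m.CostPer1MInCached, cacheWrite) +
		perTok(m.CostPer1MOutCached, cacheRead)
}

For example, 10,000 input and 1,000 output tokens on Claude 3.7 Sonnet cost roughly 0.03 + 0.015 = 0.045 dollars at the rates above.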
@@ -1,5 +1,7 @@
package models

import "maps"

type (
ModelID       string
ModelProvider string
@@ -14,15 +16,13 @@ type Model struct {
CostPer1MOut       float64 `json:"cost_per_1m_out"`
CostPer1MInCached  float64 `json:"cost_per_1m_in_cached"`
CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
ContextWindow      int64   `json:"context_window"`
}

// Model IDs
const (
// Anthropic
Claude35Sonnet ModelID = "claude-3.5-sonnet"
Claude3Haiku   ModelID = "claude-3-haiku"
Claude37Sonnet ModelID = "claude-3.7-sonnet"
// OpenAI
GPT4o ModelID = "gpt-4o"
GPT41 ModelID = "gpt-4.1"

// GEMINI
@@ -37,47 +37,59 @@ const (
)

const (
ProviderOpenAI    ModelProvider = "openai"
ProviderAnthropic ModelProvider = "anthropic"
ProviderBedrock   ModelProvider = "bedrock"
ProviderGemini    ModelProvider = "gemini"
ProviderGROQ      ModelProvider = "groq"
ProviderOpenAI  ModelProvider = "openai"
ProviderBedrock ModelProvider = "bedrock"
ProviderGemini  ModelProvider = "gemini"
ProviderGROQ    ModelProvider = "groq"

// ForTests
ProviderMock ModelProvider = "__mock"
)

var SupportedModels = map[ModelID]Model{
// Anthropic
Claude35Sonnet: {
ID: Claude35Sonnet,
Name: "Claude 3.5 Sonnet",
Provider: ProviderAnthropic,
APIModel: "claude-3-5-sonnet-latest",
CostPer1MIn: 3.0,
CostPer1MInCached: 3.75,
CostPer1MOutCached: 0.30,
CostPer1MOut: 15.0,
// // Anthropic
// Claude35Sonnet: {
// ID: Claude35Sonnet,
// Name: "Claude 3.5 Sonnet",
// Provider: ProviderAnthropic,
// APIModel: "claude-3-5-sonnet-latest",
// CostPer1MIn: 3.0,
// CostPer1MInCached: 3.75,
// CostPer1MOutCached: 0.30,
// CostPer1MOut: 15.0,
// },
// Claude3Haiku: {
// ID: Claude3Haiku,
// Name: "Claude 3 Haiku",
// Provider: ProviderAnthropic,
// APIModel: "claude-3-haiku-latest",
// CostPer1MIn: 0.80,
// CostPer1MInCached: 1,
// CostPer1MOutCached: 0.08,
// CostPer1MOut: 4,
// },
// Claude37Sonnet: {
// ID: Claude37Sonnet,
// Name: "Claude 3.7 Sonnet",
// Provider: ProviderAnthropic,
// APIModel: "claude-3-7-sonnet-latest",
// CostPer1MIn: 3.0,
// CostPer1MInCached: 3.75,
// CostPer1MOutCached: 0.30,
// CostPer1MOut: 15.0,
// },
//
// // OpenAI
GPT4o: {
ID: GPT4o,
Name: "GPT-4o",
Provider: ProviderOpenAI,
APIModel: "gpt-4.1",
CostPer1MIn: 2.00,
CostPer1MInCached: 0.50,
CostPer1MOutCached: 0,
CostPer1MOut: 8.00,
},
Claude3Haiku: {
ID: Claude3Haiku,
Name: "Claude 3 Haiku",
Provider: ProviderAnthropic,
APIModel: "claude-3-haiku-latest",
CostPer1MIn: 0.80,
CostPer1MInCached: 1,
CostPer1MOutCached: 0.08,
CostPer1MOut: 4,
},
Claude37Sonnet: {
ID: Claude37Sonnet,
Name: "Claude 3.7 Sonnet",
Provider: ProviderAnthropic,
APIModel: "claude-3-7-sonnet-latest",
CostPer1MIn: 3.0,
CostPer1MInCached: 3.75,
CostPer1MOutCached: 0.30,
CostPer1MOut: 15.0,
},

// OpenAI
GPT41: {
ID: GPT41,
Name: "GPT-4.1",
@@ -88,51 +100,55 @@ var SupportedModels = map[ModelID]Model{
CostPer1MOutCached: 0,
CostPer1MOut: 8.00,
},

// GEMINI
GEMINI25: {
ID: GEMINI25,
Name: "Gemini 2.5 Pro",
Provider: ProviderGemini,
APIModel: "gemini-2.5-pro-exp-03-25",
CostPer1MIn: 0,
CostPer1MInCached: 0,
CostPer1MOutCached: 0,
CostPer1MOut: 0,
},

GRMINI20Flash: {
ID: GRMINI20Flash,
Name: "Gemini 2.0 Flash",
Provider: ProviderGemini,
APIModel: "gemini-2.0-flash",
CostPer1MIn: 0.1,
CostPer1MInCached: 0,
CostPer1MOutCached: 0.025,
CostPer1MOut: 0.4,
},

// GROQ
QWENQwq: {
ID: QWENQwq,
Name: "Qwen Qwq",
Provider: ProviderGROQ,
APIModel: "qwen-qwq-32b",
CostPer1MIn: 0,
CostPer1MInCached: 0,
CostPer1MOutCached: 0,
CostPer1MOut: 0,
},

// Bedrock
BedrockClaude37Sonnet: {
ID: BedrockClaude37Sonnet,
Name: "Bedrock: Claude 3.7 Sonnet",
Provider: ProviderBedrock,
APIModel: "anthropic.claude-3-7-sonnet-20250219-v1:0",
CostPer1MIn: 3.0,
CostPer1MInCached: 3.75,
CostPer1MOutCached: 0.30,
CostPer1MOut: 15.0,
},
//
// // GEMINI
// GEMINI25: {
// ID: GEMINI25,
// Name: "Gemini 2.5 Pro",
// Provider: ProviderGemini,
// APIModel: "gemini-2.5-pro-exp-03-25",
// CostPer1MIn: 0,
// CostPer1MInCached: 0,
// CostPer1MOutCached: 0,
// CostPer1MOut: 0,
// },
//
// GRMINI20Flash: {
// ID: GRMINI20Flash,
// Name: "Gemini 2.0 Flash",
// Provider: ProviderGemini,
// APIModel: "gemini-2.0-flash",
// CostPer1MIn: 0.1,
// CostPer1MInCached: 0,
// CostPer1MOutCached: 0.025,
// CostPer1MOut: 0.4,
// },
//
// // GROQ
// QWENQwq: {
// ID: QWENQwq,
// Name: "Qwen Qwq",
// Provider: ProviderGROQ,
// APIModel: "qwen-qwq-32b",
// CostPer1MIn: 0,
// CostPer1MInCached: 0,
// CostPer1MOutCached: 0,
// CostPer1MOut: 0,
// },
//
// // Bedrock
// BedrockClaude37Sonnet: {
// ID: BedrockClaude37Sonnet,
// Name: "Bedrock: Claude 3.7 Sonnet",
// Provider: ProviderBedrock,
// APIModel: "anthropic.claude-3-7-sonnet-20250219-v1:0",
// CostPer1MIn: 3.0,
// CostPer1MInCached: 3.75,
// CostPer1MOutCached: 0.30,
// CostPer1MOut: 15.0,
// },
}

func init() {
	maps.Copy(SupportedModels, AnthropicModels)
}

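The new init hook uses the standard-library maps package (Go 1.21+): maps.Copy(dst, src) inserts every key/value pair of src into dst, overwriting on collision, which lets each provider file register its models into the shared SupportedModels map. Its semantics in miniature (editor's illustration):

// maps.Copy merges src into dst, overwriting existing keys.
dst := map[string]int{"a": 1}
src := map[string]int{"a": 2, "b": 3}
maps.Copy(dst, src) // dst is now map[a:2 b:3]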
@@ -9,11 +9,22 @@ import (
"time"

"github.com/kujtimiihoxha/termai/internal/config"
"github.com/kujtimiihoxha/termai/internal/llm/models"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
)

func CoderOpenAISystemPrompt() string {
basePrompt := `You are termAI, an autonomous CLI-based software engineer. Your job is to reduce user effort by proactively reasoning, inferring context, and solving software engineering tasks end-to-end with minimal prompting.
func CoderPrompt(provider models.ModelProvider) string {
basePrompt := baseAnthropicCoderPrompt
switch provider {
case models.ProviderOpenAI:
basePrompt = baseOpenAICoderPrompt
}
envInfo := getEnvironmentInfo()

return fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation())
}

const baseOpenAICoderPrompt = `You are termAI, an autonomous CLI-based software engineer. Your job is to reduce user effort by proactively reasoning, inferring context, and solving software engineering tasks end-to-end with minimal prompting.

# Your mindset
Act like a competent, efficient software engineer who is familiar with large codebases. You should:
@@ -65,13 +76,7 @@ assistant: [searches repo for references, returns file paths and lines]

Never commit changes unless the user explicitly asks you to.`

envInfo := getEnvironmentInfo()

return fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation())
}

func CoderAnthropicSystemPrompt() string {
basePrompt := `You are termAI, an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user.
const baseAnthropicCoderPrompt = `You are termAI, an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user.

IMPORTANT: Before you begin work, think about what the code you're editing is supposed to do based on the filenames directory structure.

@@ -166,11 +171,6 @@ NEVER commit changes unless the user explicitly asks you to. It is VERY IMPORTAN

You MUST answer concisely with fewer than 4 lines of text (not including tool use or code generation), unless user asks for detail.`

envInfo := getEnvironmentInfo()

return fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation())
}

func getEnvironmentInfo() string {
cwd := config.WorkingDirectory()
isGit := isGitRepo(cwd)

internal/llm/prompt/prompt.go (new file, 19 lines)
@@ -0,0 +1,19 @@
package prompt

import (
	"github.com/kujtimiihoxha/termai/internal/config"
	"github.com/kujtimiihoxha/termai/internal/llm/models"
)

func GetAgentPrompt(agentName config.AgentName, provider models.ModelProvider) string {
	switch agentName {
	case config.AgentCoder:
		return CoderPrompt(provider)
	case config.AgentTitle:
		return TitlePrompt(provider)
	case config.AgentTask:
		return TaskPrompt(provider)
	default:
		return "You are a helpful assistant"
	}
}
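This dispatcher replaces the per-agent, per-provider prompt functions with one entry point keyed on agent kind and provider, so callers no longer need to know which variants exist. Calling it looks like this (editor's illustration using names from this diff):

system := prompt.GetAgentPrompt(config.AgentCoder, models.ProviderAnthropic)
// Provider-specific wording (e.g. the OpenAI variant) is selected inside
// CoderPrompt; callers pass the provider and get the right base prompt back.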
@@ -2,11 +2,12 @@ package prompt

import (
"fmt"

"github.com/kujtimiihoxha/termai/internal/llm/models"
)

func TaskAgentSystemPrompt() string {
func TaskPrompt(_ models.ModelProvider) string {
agentPrompt := `You are an agent for termAI. Given the user's prompt, you should use the tools available to you to answer the user's question.

Notes:
1. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is <answer>.", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...".
2. When relevant, share file names and code snippets relevant to the query

@@ -1,6 +1,8 @@
package prompt

func TitlePrompt() string {
import "github.com/kujtimiihoxha/termai/internal/llm/models"

func TitlePrompt(_ models.ModelProvider) string {
return `you will generate a short title based on the first message a user begins a conversation with
- ensure it is not more than 50 characters long
- the title should be a summary of the user's message

@@ -12,187 +12,257 @@ import (
"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/bedrock"
"github.com/anthropics/anthropic-sdk-go/option"
"github.com/kujtimiihoxha/termai/internal/llm/models"
"github.com/kujtimiihoxha/termai/internal/config"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
"github.com/kujtimiihoxha/termai/internal/logging"
"github.com/kujtimiihoxha/termai/internal/message"
)

type anthropicProvider struct {
client        anthropic.Client
model         models.Model
maxTokens     int64
apiKey        string
systemMessage string
useBedrock    bool
disableCache  bool
type anthropicOptions struct {
useBedrock   bool
disableCache bool
shouldThink  func(userMessage string) bool
}

type AnthropicOption func(*anthropicProvider)
type AnthropicOption func(*anthropicOptions)

func WithAnthropicSystemMessage(message string) AnthropicOption {
return func(a *anthropicProvider) {
a.systemMessage = message
type anthropicClient struct {
providerOptions providerClientOptions
options         anthropicOptions
client          anthropic.Client
}

type AnthropicClient ProviderClient

func newAnthropicClient(opts providerClientOptions) AnthropicClient {
anthropicOpts := anthropicOptions{}
for _, o := range opts.anthropicOptions {
o(&anthropicOpts)
}

anthropicClientOptions := []option.RequestOption{}
if opts.apiKey != "" {
anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey))
}
if anthropicOpts.useBedrock {
anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background()))
}

client := anthropic.NewClient(anthropicClientOptions...)
return &anthropicClient{
providerOptions: opts,
options:         anthropicOpts,
client:          client,
}
}

func WithAnthropicMaxTokens(maxTokens int64) AnthropicOption {
return func(a *anthropicProvider) {
a.maxTokens = maxTokens
}
}
func (a *anthropicClient) convertMessages(messages []message.Message) (anthropicMessages []anthropic.MessageParam) {
cachedBlocks := 0
for _, msg := range messages {
switch msg.Role {
case message.User:
content := anthropic.NewTextBlock(msg.Content().String())
if cachedBlocks < 2 && !a.options.disableCache {
content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
}
cachedBlocks++
}
anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(content))

func WithAnthropicModel(model models.Model) AnthropicOption {
return func(a *anthropicProvider) {
a.model = model
}
}

func WithAnthropicKey(apiKey string) AnthropicOption {
return func(a *anthropicProvider) {
a.apiKey = apiKey
}
}

func WithAnthropicBedrock() AnthropicOption {
return func(a *anthropicProvider) {
a.useBedrock = true
}
}

func WithAnthropicDisableCache() AnthropicOption {
return func(a *anthropicProvider) {
a.disableCache = true
}
}

func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) {
provider := &anthropicProvider{
maxTokens: 1024,
}

for _, opt := range opts {
opt(provider)
}

if provider.systemMessage == "" {
return nil, errors.New("system message is required")
}

anthropicOptions := []option.RequestOption{}

if provider.apiKey != "" {
anthropicOptions = append(anthropicOptions, option.WithAPIKey(provider.apiKey))
}
if provider.useBedrock {
anthropicOptions = append(anthropicOptions, bedrock.WithLoadDefaultConfig(context.Background()))
}

provider.client = anthropic.NewClient(anthropicOptions...)
return provider, nil
}

func (a *anthropicProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
messages = cleanupMessages(messages)
anthropicMessages := a.convertToAnthropicMessages(messages)
anthropicTools := a.convertToAnthropicTools(tools)

response, err := a.client.Messages.New(
ctx,
anthropic.MessageNewParams{
Model: anthropic.Model(a.model.APIModel),
MaxTokens: a.maxTokens,
Temperature: anthropic.Float(0),
Messages: anthropicMessages,
Tools: anthropicTools,
System: []anthropic.TextBlockParam{
{
Text: a.systemMessage,
CacheControl: anthropic.CacheControlEphemeralParam{
case message.Assistant:
blocks := []anthropic.ContentBlockParamUnion{}
if msg.Content().String() != "" {
content := anthropic.NewTextBlock(msg.Content().String())
if cachedBlocks < 2 && !a.options.disableCache {
content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
},
}
cachedBlocks++
}
blocks = append(blocks, content)
}

for _, toolCall := range msg.ToolCalls() {
var inputMap map[string]any
err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
if err != nil {
continue
}
blocks = append(blocks, anthropic.ContentBlockParamOfRequestToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
}

if len(blocks) == 0 {
logging.Warn("There is a message without content, investigate")
// This should never happen, but we log it because it may indicate a bug in our cleanup method
continue
}
anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))

case message.Tool:
results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults()))
for i, toolResult := range msg.ToolResults() {
results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
}
anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...))
}
}
return
}

func (a *anthropicClient) convertTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
anthropicTools := make([]anthropic.ToolUnionParam, len(tools))

for i, tool := range tools {
info := tool.Info()
toolParam := anthropic.ToolParam{
Name: info.Name,
Description: anthropic.String(info.Description),
InputSchema: anthropic.ToolInputSchemaParam{
Properties: info.Parameters,
// TODO: figure out how we can tell claude the required fields?
},
}

if i == len(tools)-1 && !a.options.disableCache {
toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
}
}

anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
}

return anthropicTools
}
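The refactor above separates provider-agnostic settings (providerClientOptions) from Anthropic-specific ones, configured through AnthropicOption closures. That functional-options pattern in isolation, using only option constructors that appear in this diff (editor's illustration):

// Each With* constructor returns a closure that mutates the options struct
// before the client is built; newAnthropicClient applies them in order.
opts := anthropicOptions{}
for _, o := range []AnthropicOption{
	WithAnthropicDisableCache(),
	WithAnthropicShouldThinkFn(DefaultShouldThinkFn),
} {
	o(&opts)
}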

func (a *anthropicClient) finishReason(reason string) message.FinishReason {
switch reason {
case "end_turn":
return message.FinishReasonEndTurn
case "max_tokens":
return message.FinishReasonMaxTokens
case "tool_use":
return message.FinishReasonToolUse
case "stop_sequence":
return message.FinishReasonEndTurn
default:
return message.FinishReasonUnknown
}
}

func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, tools []anthropic.ToolUnionParam) anthropic.MessageNewParams {
var thinkingParam anthropic.ThinkingConfigParamUnion
lastMessage := messages[len(messages)-1]
isUser := lastMessage.Role == anthropic.MessageParamRoleUser
messageContent := ""
temperature := anthropic.Float(0)
if isUser {
for _, m := range lastMessage.Content {
if m.OfRequestTextBlock != nil && m.OfRequestTextBlock.Text != "" {
messageContent = m.OfRequestTextBlock.Text
}
}
if messageContent != "" && a.options.shouldThink != nil && a.options.shouldThink(messageContent) {
thinkingParam = anthropic.ThinkingConfigParamUnion{
OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{
BudgetTokens: int64(float64(a.providerOptions.maxTokens) * 0.8),
Type: "enabled",
},
}
temperature = anthropic.Float(1)
}
}

return anthropic.MessageNewParams{
Model: anthropic.Model(a.providerOptions.model.APIModel),
MaxTokens: a.providerOptions.maxTokens,
Temperature: temperature,
Messages: messages,
Tools: tools,
Thinking: thinkingParam,
System: []anthropic.TextBlockParam{
{
Text: a.providerOptions.systemMessage,
CacheControl: anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
},
},
},
)
if err != nil {
return nil, err
}

content := ""
for _, block := range response.Content {
if text, ok := block.AsAny().(anthropic.TextBlock); ok {
content += text.Text
}
}

toolCalls := a.extractToolCalls(response.Content)
tokenUsage := a.extractTokenUsage(response.Usage)

return &ProviderResponse{
Content: content,
ToolCalls: toolCalls,
Usage: tokenUsage,
}, nil
}

func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
messages = cleanupMessages(messages)
anthropicMessages := a.convertToAnthropicMessages(messages)
anthropicTools := a.convertToAnthropicTools(tools)

var thinkingParam anthropic.ThinkingConfigParamUnion
lastMessage := messages[len(messages)-1]
temperature := anthropic.Float(0)
if lastMessage.Role == message.User && strings.Contains(strings.ToLower(lastMessage.Content().String()), "think") {
thinkingParam = anthropic.ThinkingConfigParamUnion{
OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{
BudgetTokens: int64(float64(a.maxTokens) * 0.8),
Type: "enabled",
},
}
temperature = anthropic.Float(1)
func (a *anthropicClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
cfg := config.Get()
if cfg.Debug {
jsonData, _ := json.Marshal(preparedMessages)
logging.Debug("Prepared messages", "messages", string(jsonData))
}
attempts := 0
for {
attempts++
anthropicResponse, err := a.client.Messages.New(
ctx,
preparedMessages,
)
// If there is an error, check whether the call can be retried
if err != nil {
retry, after, retryErr := a.shouldRetry(attempts, err)
if retryErr != nil {
return nil, retryErr
}
if retry {
logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(time.Duration(after) * time.Millisecond):
continue
}
}
return nil, retryErr
}

content := ""
for _, block := range anthropicResponse.Content {
if text, ok := block.AsAny().(anthropic.TextBlock); ok {
content += text.Text
}
}

return &ProviderResponse{
Content: content,
ToolCalls: a.toolCalls(*anthropicResponse),
Usage: a.usage(*anthropicResponse),
}, nil
}
}

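Both the send path above and the stream path below wrap one provider request in the same retry loop. Its skeleton, extracted by the editor as an illustration (not a function in the commit; "call" is a hypothetical stand-in for one request):

// Editor's sketch of the retry skeleton shared by send/stream.
func withRetry(ctx context.Context, call func(context.Context) (*ProviderResponse, error), shouldRetry func(int, error) (bool, int64, error)) (*ProviderResponse, error) {
	for attempts := 1; ; attempts++ {
		resp, err := call(ctx)
		if err == nil {
			return resp, nil
		}
		retry, afterMs, retryErr := shouldRetry(attempts, err)
		if retryErr != nil || !retry {
			// non-retryable errors come back through retryErr
			return nil, retryErr
		}
		select {
		case <-ctx.Done():
			return nil, ctx.Err() // honor cancellation while sleeping
		case <-time.After(time.Duration(afterMs) * time.Millisecond):
		}
	}
}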
func (a *anthropicClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
cfg := config.Get()
if cfg.Debug {
jsonData, _ := json.Marshal(preparedMessages)
logging.Debug("Prepared messages", "messages", string(jsonData))
}
attempts := 0
eventChan := make(chan ProviderEvent)

go func() {
defer close(eventChan)

const maxRetries = 8
attempts := 0

for {

attempts++

stream := a.client.Messages.NewStreaming(
anthropicStream := a.client.Messages.NewStreaming(
ctx,
anthropic.MessageNewParams{
Model: anthropic.Model(a.model.APIModel),
MaxTokens: a.maxTokens,
Temperature: temperature,
Messages: anthropicMessages,
Tools: anthropicTools,
Thinking: thinkingParam,
System: []anthropic.TextBlockParam{
{
Text: a.systemMessage,
CacheControl: anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
},
},
},
},
preparedMessages,
)

accumulatedMessage := anthropic.Message{}

for stream.Next() {
event := stream.Current()
for anthropicStream.Next() {
event := anthropicStream.Current()
err := accumulatedMessage.Accumulate(event)
if err != nil {
eventChan <- ProviderEvent{Type: EventError, Error: err}
return // Don't retry on accumulation errors
continue
}

switch event := event.AsAny().(type) {
@@ -211,6 +281,7 @@ func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []messa
Content: event.Delta.Text,
}
}
// TODO: check if we can somehow stream tool calls

case anthropic.ContentBlockStopEvent:
eventChan <- ProviderEvent{Type: EventContentStop}
@@ -223,84 +294,87 @@ func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []messa
}
}

toolCalls := a.extractToolCalls(accumulatedMessage.Content)
tokenUsage := a.extractTokenUsage(accumulatedMessage.Usage)

eventChan <- ProviderEvent{
Type: EventComplete,
Response: &ProviderResponse{
Content: content,
ToolCalls: toolCalls,
Usage: tokenUsage,
FinishReason: string(accumulatedMessage.StopReason),
ToolCalls: a.toolCalls(accumulatedMessage),
Usage: a.usage(accumulatedMessage),
FinishReason: a.finishReason(string(accumulatedMessage.StopReason)),
},
}
}
}

err := stream.Err()
err := anthropicStream.Err()
if err == nil || errors.Is(err, io.EOF) {
close(eventChan)
return
}

var apierr *anthropic.Error
if !errors.As(err, &apierr) {
eventChan <- ProviderEvent{Type: EventError, Error: err}
// If there is an error, check whether the call can be retried
retry, after, retryErr := a.shouldRetry(attempts, err)
if retryErr != nil {
eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
close(eventChan)
return
}

if apierr.StatusCode != 429 && apierr.StatusCode != 529 {
eventChan <- ProviderEvent{Type: EventError, Error: err}
return
}

if attempts > maxRetries {
eventChan <- ProviderEvent{
Type: EventError,
Error: errors.New("maximum retry attempts reached for rate limit (429)"),
}
return
}

retryMs := 0
retryAfterValues := apierr.Response.Header.Values("Retry-After")
if len(retryAfterValues) > 0 {
var retryAfterSec int
if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryAfterSec); err == nil {
retryMs = retryAfterSec * 1000
eventChan <- ProviderEvent{
Type: EventWarning,
Info: fmt.Sprintf("[Rate limited: waiting %d seconds as specified by API]", retryAfterSec),
if retry {
logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
select {
case <-ctx.Done():
// context cancelled
if ctx.Err() != nil {
eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
}
close(eventChan)
return
case <-time.After(time.Duration(after) * time.Millisecond):
continue
}
} else {
eventChan <- ProviderEvent{
Type: EventWarning,
Info: fmt.Sprintf("[Retrying due to rate limit... attempt %d of %d]", attempts, maxRetries),
}

backoffMs := 2000 * (1 << (attempts - 1))
jitterMs := int(float64(backoffMs) * 0.2)
retryMs = backoffMs + jitterMs
}
select {
case <-ctx.Done():
if ctx.Err() != nil {
eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
return
case <-time.After(time.Duration(retryMs) * time.Millisecond):
continue
}

close(eventChan)
return
}
}()

return eventChan, nil
return eventChan
}

func (a *anthropicProvider) extractToolCalls(content []anthropic.ContentBlockUnion) []message.ToolCall {
func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, error) {
var apierr *anthropic.Error
if !errors.As(err, &apierr) {
return false, 0, err
}

if apierr.StatusCode != 429 && apierr.StatusCode != 529 {
return false, 0, err
}

if attempts > maxRetries {
return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
}

retryMs := 0
retryAfterValues := apierr.Response.Header.Values("Retry-After")

backoffMs := 2000 * (1 << (attempts - 1))
jitterMs := int(float64(backoffMs) * 0.2)
retryMs = backoffMs + jitterMs
if len(retryAfterValues) > 0 {
if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
retryMs = retryMs * 1000
}
}
return true, int64(retryMs), nil
}

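When no Retry-After header is present, shouldRetry doubles the delay each attempt from a 2-second base and adds a flat 20% bump (note the "jitter" here is deterministic, not randomized); a Retry-After value in seconds overrides the computed delay. The arithmetic worked out (editor's illustration, values in milliseconds):

// attempt 1: backoff 2000, +20% -> 2400
// attempt 2: backoff 4000, +20% -> 4800
// attempt 3: backoff 8000, +20% -> 9600
backoffMs := 2000 * (1 << (attempts - 1))
retryMs := backoffMs + int(float64(backoffMs)*0.2)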
func (a *anthropicClient) toolCalls(msg anthropic.Message) []message.ToolCall {
var toolCalls []message.ToolCall

for _, block := range content {
for _, block := range msg.Content {
switch variant := block.AsAny().(type) {
case anthropic.ToolUseBlock:
toolCall := message.ToolCall{
@@ -316,90 +390,33 @@ func (a *anthropicProvider) extractToolCalls(content []anthropic.ContentBlockUni
return toolCalls
}

func (a *anthropicProvider) extractTokenUsage(usage anthropic.Usage) TokenUsage {
func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage {
return TokenUsage{
InputTokens: usage.InputTokens,
OutputTokens: usage.OutputTokens,
CacheCreationTokens: usage.CacheCreationInputTokens,
CacheReadTokens: usage.CacheReadInputTokens,
InputTokens: msg.Usage.InputTokens,
OutputTokens: msg.Usage.OutputTokens,
CacheCreationTokens: msg.Usage.CacheCreationInputTokens,
CacheReadTokens: msg.Usage.CacheReadInputTokens,
}
}

func (a *anthropicProvider) convertToAnthropicTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
anthropicTools := make([]anthropic.ToolUnionParam, len(tools))

for i, tool := range tools {
info := tool.Info()
toolParam := anthropic.ToolParam{
Name: info.Name,
Description: anthropic.String(info.Description),
InputSchema: anthropic.ToolInputSchemaParam{
Properties: info.Parameters,
},
}

if i == len(tools)-1 && !a.disableCache {
toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
}
}

anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
func WithAnthropicBedrock(useBedrock bool) AnthropicOption {
return func(options *anthropicOptions) {
options.useBedrock = useBedrock
}

return anthropicTools
}

func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Message) []anthropic.MessageParam {
anthropicMessages := make([]anthropic.MessageParam, 0, len(messages))
cachedBlocks := 0

for _, msg := range messages {
switch msg.Role {
case message.User:
content := anthropic.NewTextBlock(msg.Content().String())
if cachedBlocks < 2 && !a.disableCache {
content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
}
cachedBlocks++
}
anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(content))

case message.Assistant:
blocks := []anthropic.ContentBlockParamUnion{}
if msg.Content().String() != "" {
content := anthropic.NewTextBlock(msg.Content().String())
if cachedBlocks < 2 && !a.disableCache {
content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
}
cachedBlocks++
}
blocks = append(blocks, content)
}

for _, toolCall := range msg.ToolCalls() {
var inputMap map[string]any
err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
if err != nil {
continue
}
blocks = append(blocks, anthropic.ContentBlockParamOfRequestToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
}

if len(blocks) > 0 {
anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
}

case message.Tool:
results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults()))
for i, toolResult := range msg.ToolResults() {
results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
}
anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...))
}
func WithAnthropicDisableCache() AnthropicOption {
return func(options *anthropicOptions) {
options.disableCache = true
}
}

func DefaultShouldThinkFn(s string) bool {
return strings.Contains(strings.ToLower(s), "think")
}

func WithAnthropicShouldThinkFn(fn func(string) bool) AnthropicOption {
return func(options *anthropicOptions) {
options.shouldThink = fn
}

return anthropicMessages
}

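The shouldThink hook turns the old hard-coded "think" substring check from StreamResponse into a pluggable predicate: preparedMessages runs it against the latest user message and, when it returns true, enables extended thinking with a budget of 80% of maxTokens and raises the temperature to 1. Wiring it up (editor's illustration using names from this diff):

// DefaultShouldThinkFn("please think hard about this") == true
opt := WithAnthropicShouldThinkFn(DefaultShouldThinkFn)
// A caller could substitute any func(string) bool, e.g. one keyed on an
// explicit flag rather than prompt wording.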
@@ -7,33 +7,29 @@ import (
"os"
"strings"

"github.com/kujtimiihoxha/termai/internal/llm/models"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
"github.com/kujtimiihoxha/termai/internal/message"
)

type bedrockProvider struct {
childProvider Provider
model         models.Model
maxTokens     int64
systemMessage string
type bedrockOptions struct {
// Bedrock specific options can be added here
}

func (b *bedrockProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
return b.childProvider.SendMessages(ctx, messages, tools)
type BedrockOption func(*bedrockOptions)

type bedrockClient struct {
providerOptions providerClientOptions
options         bedrockOptions
childProvider   ProviderClient
}

func (b *bedrockProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
return b.childProvider.StreamResponse(ctx, messages, tools)
}
type BedrockClient ProviderClient

func NewBedrockProvider(opts ...BedrockOption) (Provider, error) {
provider := &bedrockProvider{}
for _, opt := range opts {
opt(provider)
}
func newBedrockClient(opts providerClientOptions) BedrockClient {
bedrockOpts := bedrockOptions{}
// Apply bedrock specific options if they are added in the future

// Based on the AWS region, prefix the model name with us, eu, ap, sa, etc.
// Get AWS region from environment
region := os.Getenv("AWS_REGION")
if region == "" {
region = os.Getenv("AWS_DEFAULT_REGION")
@@ -43,45 +39,62 @@ func NewBedrockProvider(opts ...BedrockOption) (Provider, error) {
region = "us-east-1" // default region
}
if len(region) < 2 {
return nil, errors.New("AWS_REGION or AWS_DEFAULT_REGION environment variable is invalid")
return &bedrockClient{
providerOptions: opts,
options: bedrockOpts,
childProvider: nil, // Will cause an error when used
}
}
regionPrefix := region[:2]
provider.model.APIModel = fmt.Sprintf("%s.%s", regionPrefix, provider.model.APIModel)

if strings.Contains(string(provider.model.APIModel), "anthropic") {
anthropic, err := NewAnthropicProvider(
WithAnthropicModel(provider.model),
WithAnthropicMaxTokens(provider.maxTokens),
WithAnthropicSystemMessage(provider.systemMessage),
WithAnthropicBedrock(),
// Prefix the model name with region
regionPrefix := region[:2]
modelName := opts.model.APIModel
opts.model.APIModel = fmt.Sprintf("%s.%s", regionPrefix, modelName)

// Determine which provider to use based on the model
if strings.Contains(string(opts.model.APIModel), "anthropic") {
// Create Anthropic client with Bedrock configuration
anthropicOpts := opts
anthropicOpts.anthropicOptions = append(anthropicOpts.anthropicOptions,
WithAnthropicBedrock(true),
WithAnthropicDisableCache(),
)
provider.childProvider = anthropic
if err != nil {
return nil, err
return &bedrockClient{
providerOptions: opts,
options: bedrockOpts,
childProvider: newAnthropicClient(anthropicOpts),
}
} else {
}

// Return client with nil childProvider if model is not supported
// This will cause an error when used
return &bedrockClient{
providerOptions: opts,
options: bedrockOpts,
childProvider: nil,
}
}

func (b *bedrockClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
if b.childProvider == nil {
return nil, errors.New("unsupported model for bedrock provider")
}
return provider, nil
return b.childProvider.send(ctx, messages, tools)
}

type BedrockOption func(*bedrockProvider)

func WithBedrockSystemMessage(message string) BedrockOption {
return func(a *bedrockProvider) {
a.systemMessage = message
func (b *bedrockClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
eventChan := make(chan ProviderEvent)

if b.childProvider == nil {
go func() {
eventChan <- ProviderEvent{
Type: EventError,
Error: errors.New("unsupported model for bedrock provider"),
}
close(eventChan)
}()
return eventChan
}
}

func WithBedrockMaxTokens(maxTokens int64) BedrockOption {
return func(a *bedrockProvider) {
a.maxTokens = maxTokens
}
}

func WithBedrockModel(model models.Model) BedrockOption {
return func(a *bedrockProvider) {
a.model = model
}
}

return b.childProvider.stream(ctx, messages, tools)
}
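Bedrock namespaces model IDs by region, so the client prepends the first two characters of the AWS region to the model name before delegating to the Anthropic client. The computation in isolation (editor's illustration; the model string is the one used in this diff):

// "us-east-1" -> prefix "us"; "eu-west-1" would yield "eu".
region := "us-east-1"
apiModel := fmt.Sprintf("%s.%s", region[:2], "anthropic.claude-3-7-sonnet-20250219-v1:0")
// apiModel == "us.anthropic.claude-3-7-sonnet-20250219-v1:0"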
@@ -4,81 +4,69 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"strings"
"time"

"github.com/google/generative-ai-go/genai"
"github.com/google/uuid"
"github.com/kujtimiihoxha/termai/internal/llm/models"
"github.com/kujtimiihoxha/termai/internal/config"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
"github.com/kujtimiihoxha/termai/internal/logging"
"github.com/kujtimiihoxha/termai/internal/message"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
)

type geminiProvider struct {
client        *genai.Client
model         models.Model
maxTokens     int32
apiKey        string
systemMessage string
type geminiOptions struct {
disableCache bool
}

type GeminiOption func(*geminiProvider)
type GeminiOption func(*geminiOptions)

func NewGeminiProvider(ctx context.Context, opts ...GeminiOption) (Provider, error) {
provider := &geminiProvider{
maxTokens: 5000,
type geminiClient struct {
providerOptions providerClientOptions
options         geminiOptions
client          *genai.Client
}

type GeminiClient ProviderClient

func newGeminiClient(opts providerClientOptions) GeminiClient {
geminiOpts := geminiOptions{}
for _, o := range opts.geminiOptions {
o(&geminiOpts)
}

for _, opt := range opts {
opt(provider)
}

if provider.systemMessage == "" {
return nil, errors.New("system message is required")
}

client, err := genai.NewClient(ctx, option.WithAPIKey(provider.apiKey))
client, err := genai.NewClient(context.Background(), option.WithAPIKey(opts.apiKey))
if err != nil {
return nil, err
logging.Error("Failed to create Gemini client", "error", err)
return nil
}
provider.client = client

return provider, nil
}

func WithGeminiSystemMessage(message string) GeminiOption {
return func(p *geminiProvider) {
p.systemMessage = message
return &geminiClient{
providerOptions: opts,
options: geminiOpts,
client: client,
}
}

func WithGeminiMaxTokens(maxTokens int32) GeminiOption {
return func(p *geminiProvider) {
p.maxTokens = maxTokens
}
}

func WithGeminiModel(model models.Model) GeminiOption {
return func(p *geminiProvider) {
p.model = model
}
}

func WithGeminiKey(apiKey string) GeminiOption {
return func(p *geminiProvider) {
p.apiKey = apiKey
}
}

func (p *geminiProvider) Close() {
if p.client != nil {
p.client.Close()
}
}

func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*genai.Content {
func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
var history []*genai.Content

// Add system message first
history = append(history, &genai.Content{
Parts: []genai.Part{genai.Text(g.providerOptions.systemMessage)},
Role: "user",
})

// Add a system response to acknowledge the system message
history = append(history, &genai.Content{
Parts: []genai.Part{genai.Text("I'll help you with that.")},
Role: "model",
})

for _, msg := range messages {
switch msg.Role {
case message.User:
@@ -86,6 +74,7 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g
Parts: []genai.Part{genai.Text(msg.Content().String())},
Role: "user",
})

case message.Assistant:
content := &genai.Content{
Role: "model",
@@ -107,6 +96,7 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g
}

history = append(history, content)

case message.Tool:
for _, result := range msg.ToolResults() {
response := map[string]interface{}{"result": result.Content}
@@ -114,10 +104,11 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g
if err == nil {
response = parsed
}

var toolCall message.ToolCall
for _, msg := range messages {
if msg.Role == message.Assistant {
for _, call := range msg.ToolCalls() {
for _, m := range messages {
if m.Role == message.Assistant {
for _, call := range m.ToolCalls() {
if call.ID == result.ToolCallID {
toolCall = call
break
@@ -140,7 +131,335 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g
return history
}

func (p *geminiProvider) extractTokenUsage(resp *genai.GenerateContentResponse) TokenUsage {
func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
geminiTools := make([]*genai.Tool, 0, len(tools))

for _, tool := range tools {
info := tool.Info()
declaration := &genai.FunctionDeclaration{
Name: info.Name,
Description: info.Description,
Parameters: &genai.Schema{
Type: genai.TypeObject,
Properties: convertSchemaProperties(info.Parameters),
Required: info.Required,
},
}

geminiTools = append(geminiTools, &genai.Tool{
FunctionDeclarations: []*genai.FunctionDeclaration{declaration},
})
}

return geminiTools
}

func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
reasonStr := reason.String()
switch {
case reasonStr == "STOP":
return message.FinishReasonEndTurn
case reasonStr == "MAX_TOKENS":
return message.FinishReasonMaxTokens
case strings.Contains(reasonStr, "FUNCTION") || strings.Contains(reasonStr, "TOOL"):
return message.FinishReasonToolUse
default:
return message.FinishReasonUnknown
}
}

func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))

// Convert tools
if len(tools) > 0 {
model.Tools = g.convertTools(tools)
}

// Convert messages
geminiMessages := g.convertMessages(messages)

cfg := config.Get()
if cfg.Debug {
jsonData, _ := json.Marshal(geminiMessages)
logging.Debug("Prepared messages", "messages", string(jsonData))
}

attempts := 0
for {
attempts++
chat := model.StartChat()
chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message

lastMsg := geminiMessages[len(geminiMessages)-1]
var lastText string
for _, part := range lastMsg.Parts {
if text, ok := part.(genai.Text); ok {
lastText = string(text)
break
}
}

resp, err := chat.SendMessage(ctx, genai.Text(lastText))
// If there is an error, check whether the call can be retried
if err != nil {
retry, after, retryErr := g.shouldRetry(attempts, err)
if retryErr != nil {
return nil, retryErr
}
if retry {
logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(time.Duration(after) * time.Millisecond):
continue
}
}
return nil, retryErr
}

content := ""
var toolCalls []message.ToolCall

if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
for _, part := range resp.Candidates[0].Content.Parts {
switch p := part.(type) {
case genai.Text:
content = string(p)
case genai.FunctionCall:
id := "call_" + uuid.New().String()
args, _ := json.Marshal(p.Args)
toolCalls = append(toolCalls, message.ToolCall{
ID: id,
Name: p.Name,
Input: string(args),
Type: "function",
})
}
}
}

return &ProviderResponse{
Content: content,
ToolCalls: toolCalls,
Usage: g.usage(resp),
FinishReason: g.finishReason(resp.Candidates[0].FinishReason),
}, nil
}
}

func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
    model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
    model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))

    // Convert tools
    if len(tools) > 0 {
        model.Tools = g.convertTools(tools)
    }

    // Convert messages
    geminiMessages := g.convertMessages(messages)

    cfg := config.Get()
    if cfg.Debug {
        jsonData, _ := json.Marshal(geminiMessages)
        logging.Debug("Prepared messages", "messages", string(jsonData))
    }

    attempts := 0
    eventChan := make(chan ProviderEvent)

    go func() {
        defer close(eventChan)

        for {
            attempts++
            chat := model.StartChat()
            chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message

            lastMsg := geminiMessages[len(geminiMessages)-1]
            var lastText string
            for _, part := range lastMsg.Parts {
                if text, ok := part.(genai.Text); ok {
                    lastText = string(text)
                    break
                }
            }

            iter := chat.SendMessageStream(ctx, genai.Text(lastText))

            currentContent := ""
            toolCalls := []message.ToolCall{}
            var finalResp *genai.GenerateContentResponse

            eventChan <- ProviderEvent{Type: EventContentStart}

            for {
                resp, err := iter.Next()
                if err == iterator.Done {
                    break
                }
                if err != nil {
                    retry, after, retryErr := g.shouldRetry(attempts, err)
                    if retryErr != nil {
                        eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
                        return
                    }
                    if retry {
                        logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
                        select {
                        case <-ctx.Done():
                            if ctx.Err() != nil {
                                eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
                            }

                            return
                        case <-time.After(time.Duration(after) * time.Millisecond):
                            break
                        }
                    } else {
                        eventChan <- ProviderEvent{Type: EventError, Error: err}
                        return
                    }
                }

                finalResp = resp

                if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
                    for _, part := range resp.Candidates[0].Content.Parts {
                        switch p := part.(type) {
                        case genai.Text:
                            newText := string(p)
                            delta := newText[len(currentContent):]
                            if delta != "" {
                                eventChan <- ProviderEvent{
                                    Type:    EventContentDelta,
                                    Content: delta,
                                }
                                currentContent = newText
                            }
                        case genai.FunctionCall:
                            id := "call_" + uuid.New().String()
                            args, _ := json.Marshal(p.Args)
                            newCall := message.ToolCall{
                                ID:    id,
                                Name:  p.Name,
                                Input: string(args),
                                Type:  "function",
                            }

                            isNew := true
                            for _, existing := range toolCalls {
                                if existing.Name == newCall.Name && existing.Input == newCall.Input {
                                    isNew = false
                                    break
                                }
                            }

                            if isNew {
                                toolCalls = append(toolCalls, newCall)
                            }
                        }
                    }
                }
            }

            eventChan <- ProviderEvent{Type: EventContentStop}

            if finalResp != nil {
                eventChan <- ProviderEvent{
                    Type: EventComplete,
                    Response: &ProviderResponse{
                        Content:      currentContent,
                        ToolCalls:    toolCalls,
                        Usage:        g.usage(finalResp),
                        FinishReason: g.finishReason(finalResp.Candidates[0].FinishReason),
                    },
                }
                return
            }

            // If we get here, we need to retry
            if attempts > maxRetries {
                eventChan <- ProviderEvent{
                    Type:  EventError,
                    Error: fmt.Errorf("maximum retry attempts reached: %d retries", maxRetries),
                }
                return
            }

            // Wait before retrying
            select {
            case <-ctx.Done():
                if ctx.Err() != nil {
                    eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
                }
                return
            case <-time.After(time.Duration(2000*(1<<(attempts-1))) * time.Millisecond):
                continue
            }
        }
    }()

    return eventChan
}

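The stream loop above treats each genai.Text part as a cumulative snapshot of the text so far and slices off what it has already emitted (delta := newText[len(currentContent):]). A minimal, SDK-independent sketch of that pattern; the snapshots slice is invented for illustration, and the slicing assumes every chunk repeats the full text so far (true increments would be dropped by this slice):

package main

import "fmt"

// emitDeltas turns cumulative text snapshots into incremental deltas,
// mirroring how the gemini stream loop slices off already-sent content.
func emitDeltas(snapshots []string) []string {
    var deltas []string
    current := ""
    for _, snap := range snapshots {
        if len(snap) <= len(current) {
            continue // nothing new in this chunk
        }
        deltas = append(deltas, snap[len(current):])
        current = snap
    }
    return deltas
}

func main() {
    fmt.Println(emitDeltas([]string{"Hel", "Hello", "Hello, world"}))
    // Output: [Hel lo , world]
}
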
func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
    // Check if error is a rate limit error
    if attempts > maxRetries {
        return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
    }

    // Gemini doesn't have a standard error type we can check against
    // So we'll check the error message for rate limit indicators
    if errors.Is(err, io.EOF) {
        return false, 0, err
    }

    errMsg := err.Error()
    isRateLimit := false

    // Check for common rate limit error messages
    if contains(errMsg, "rate limit", "quota exceeded", "too many requests") {
        isRateLimit = true
    }

    if !isRateLimit {
        return false, 0, err
    }

    // Calculate backoff with jitter
    backoffMs := 2000 * (1 << (attempts - 1))
    jitterMs := int(float64(backoffMs) * 0.2)
    retryMs := backoffMs + jitterMs

    return true, int64(retryMs), nil
}

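Both the gemini and openai shouldRetry implementations share this backoff shape: a 2-second base doubled per attempt, plus a flat 20%. Note that the "jitter" is deterministic as written; randomized jitter would be needed to actually de-synchronize concurrent clients. A small sketch tabulating the delays it produces:

package main

import "fmt"

// backoffMs reproduces the retry delay computed above: 2s doubling per
// attempt, plus a fixed 20% "jitter" (deterministic, not random).
func backoffMs(attempt int) int {
    base := 2000 * (1 << (attempt - 1))
    jitter := int(float64(base) * 0.2)
    return base + jitter
}

func main() {
    for attempt := 1; attempt <= 5; attempt++ {
        fmt.Printf("attempt %d -> %d ms\n", attempt, backoffMs(attempt))
    }
    // attempt 1 -> 2400 ms, 2 -> 4800, 3 -> 9600, 4 -> 19200, 5 -> 38400
}
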
func (g *geminiClient) toolCalls(resp *genai.GenerateContentResponse) []message.ToolCall {
    var toolCalls []message.ToolCall

    if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
        for _, part := range resp.Candidates[0].Content.Parts {
            if funcCall, ok := part.(genai.FunctionCall); ok {
                id := "call_" + uuid.New().String()
                args, _ := json.Marshal(funcCall.Args)
                toolCalls = append(toolCalls, message.ToolCall{
                    ID:    id,
                    Name:  funcCall.Name,
                    Input: string(args),
                    Type:  "function",
                })
            }
        }
    }

    return toolCalls
}

func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
    if resp == nil || resp.UsageMetadata == nil {
        return TokenUsage{}
    }
@@ -153,173 +472,17 @@ func (p *geminiProvider) extractTokenUsage(resp *genai.GenerateContentResponse)
    }
}

func (p *geminiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
    messages = cleanupMessages(messages)
    model := p.client.GenerativeModel(p.model.APIModel)
    model.SetMaxOutputTokens(p.maxTokens)

    model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage))

    if len(tools) > 0 {
        declarations := p.convertToolsToGeminiFunctionDeclarations(tools)
        for _, declaration := range declarations {
            model.Tools = append(model.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{declaration}})
        }
func WithGeminiDisableCache() GeminiOption {
    return func(options *geminiOptions) {
        options.disableCache = true
    }

    chat := model.StartChat()
    chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message

    lastUserMsg := messages[len(messages)-1]
    resp, err := chat.SendMessage(ctx, genai.Text(lastUserMsg.Content().String()))
    if err != nil {
        return nil, err
    }

    var content string
    var toolCalls []message.ToolCall

    if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
        for _, part := range resp.Candidates[0].Content.Parts {
            switch p := part.(type) {
            case genai.Text:
                content = string(p)
            case genai.FunctionCall:
                id := "call_" + uuid.New().String()
                args, _ := json.Marshal(p.Args)
                toolCalls = append(toolCalls, message.ToolCall{
                    ID:    id,
                    Name:  p.Name,
                    Input: string(args),
                    Type:  "function",
                })
            }
        }
    }

    tokenUsage := p.extractTokenUsage(resp)

    return &ProviderResponse{
        Content:   content,
        ToolCalls: toolCalls,
        Usage:     tokenUsage,
    }, nil
}

func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
    messages = cleanupMessages(messages)
    model := p.client.GenerativeModel(p.model.APIModel)
    model.SetMaxOutputTokens(p.maxTokens)

    model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage))

    if len(tools) > 0 {
        declarations := p.convertToolsToGeminiFunctionDeclarations(tools)
        for _, declaration := range declarations {
            model.Tools = append(model.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{declaration}})
        }
    }

    chat := model.StartChat()
    chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message

    lastUserMsg := messages[len(messages)-1]

    iter := chat.SendMessageStream(ctx, genai.Text(lastUserMsg.Content().String()))

    eventChan := make(chan ProviderEvent)

    go func() {
        defer close(eventChan)

        var finalResp *genai.GenerateContentResponse
        currentContent := ""
        toolCalls := []message.ToolCall{}

        for {
            resp, err := iter.Next()
            if err == iterator.Done {
                break
            }
            if err != nil {
                eventChan <- ProviderEvent{
                    Type:  EventError,
                    Error: err,
                }
                return
            }

            finalResp = resp

            if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
                for _, part := range resp.Candidates[0].Content.Parts {
                    switch p := part.(type) {
                    case genai.Text:
                        newText := string(p)
                        eventChan <- ProviderEvent{
                            Type:    EventContentDelta,
                            Content: newText,
                        }
                        currentContent += newText
                    case genai.FunctionCall:
                        id := "call_" + uuid.New().String()
                        args, _ := json.Marshal(p.Args)
                        newCall := message.ToolCall{
                            ID:    id,
                            Name:  p.Name,
                            Input: string(args),
                            Type:  "function",
                        }

                        isNew := true
                        for _, existing := range toolCalls {
                            if existing.Name == newCall.Name && existing.Input == newCall.Input {
                                isNew = false
                                break
                            }
                        }

                        if isNew {
                            toolCalls = append(toolCalls, newCall)
                        }
                    }
                }
            }
        }

        tokenUsage := p.extractTokenUsage(finalResp)

        eventChan <- ProviderEvent{
            Type: EventComplete,
            Response: &ProviderResponse{
                Content:      currentContent,
                ToolCalls:    toolCalls,
                Usage:        tokenUsage,
                FinishReason: string(finalResp.Candidates[0].FinishReason.String()),
            },
        }
    }()

    return eventChan, nil
}

func (p *geminiProvider) convertToolsToGeminiFunctionDeclarations(tools []tools.BaseTool) []*genai.FunctionDeclaration {
    declarations := make([]*genai.FunctionDeclaration, len(tools))

    for i, tool := range tools {
        info := tool.Info()
        declarations[i] = &genai.FunctionDeclaration{
            Name:        info.Name,
            Description: info.Description,
            Parameters: &genai.Schema{
                Type:       genai.TypeObject,
                Properties: convertSchemaProperties(info.Parameters),
                Required:   info.Required,
            },
        }
    }

    return declarations
// Helper functions
func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
    var result map[string]interface{}
    err := json.Unmarshal([]byte(jsonStr), &result)
    return result, err
}

func convertSchemaProperties(parameters map[string]interface{}) map[string]*genai.Schema {
@@ -396,8 +559,12 @@ func mapJSONTypeToGenAI(jsonType string) genai.Type {
    }
}

func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
    var result map[string]interface{}
    err := json.Unmarshal([]byte(jsonStr), &result)
    return result, err
func contains(s string, substrs ...string) bool {
    for _, substr := range substrs {
        if strings.Contains(strings.ToLower(s), strings.ToLower(substr)) {
            return true
        }
    }
    return false
}

@@ -2,89 +2,65 @@ package provider

import (
    "context"
    "encoding/json"
    "errors"
    "fmt"
    "io"
    "time"

    "github.com/kujtimiihoxha/termai/internal/llm/models"
    "github.com/kujtimiihoxha/termai/internal/config"
    "github.com/kujtimiihoxha/termai/internal/llm/tools"
    "github.com/kujtimiihoxha/termai/internal/logging"
    "github.com/kujtimiihoxha/termai/internal/message"
    "github.com/openai/openai-go"
    "github.com/openai/openai-go/option"
)

type openaiProvider struct {
    client        openai.Client
    model         models.Model
    maxTokens     int64
    baseURL       string
    apiKey        string
    systemMessage string
type openaiOptions struct {
    baseURL      string
    disableCache bool
}

type OpenAIOption func(*openaiProvider)
type OpenAIOption func(*openaiOptions)

func NewOpenAIProvider(opts ...OpenAIOption) (Provider, error) {
    provider := &openaiProvider{
        maxTokens: 5000,
    }

    for _, opt := range opts {
        opt(provider)
    }

    clientOpts := []option.RequestOption{
        option.WithAPIKey(provider.apiKey),
    }
    if provider.baseURL != "" {
        clientOpts = append(clientOpts, option.WithBaseURL(provider.baseURL))
    }

    provider.client = openai.NewClient(clientOpts...)
    if provider.systemMessage == "" {
        return nil, errors.New("system message is required")
    }

    return provider, nil
type openaiClient struct {
    providerOptions providerClientOptions
    options         openaiOptions
    client          openai.Client
}

func WithOpenAISystemMessage(message string) OpenAIOption {
    return func(p *openaiProvider) {
        p.systemMessage = message
type OpenAIClient ProviderClient

func newOpenAIClient(opts providerClientOptions) OpenAIClient {
    openaiOpts := openaiOptions{}
    for _, o := range opts.openaiOptions {
        o(&openaiOpts)
    }

    openaiClientOptions := []option.RequestOption{}
    if opts.apiKey != "" {
        openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
    }
    if openaiOpts.baseURL != "" {
        openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(openaiOpts.baseURL))
    }

    client := openai.NewClient(openaiClientOptions...)
    return &openaiClient{
        providerOptions: opts,
        options:         openaiOpts,
        client:          client,
    }
}

func WithOpenAIMaxTokens(maxTokens int64) OpenAIOption {
    return func(p *openaiProvider) {
        p.maxTokens = maxTokens
    }
}

func WithOpenAIModel(model models.Model) OpenAIOption {
    return func(p *openaiProvider) {
        p.model = model
    }
}

func WithOpenAIBaseURL(baseURL string) OpenAIOption {
    return func(p *openaiProvider) {
        p.baseURL = baseURL
    }
}

func WithOpenAIKey(apiKey string) OpenAIOption {
    return func(p *openaiProvider) {
        p.apiKey = apiKey
    }
}

func (p *openaiProvider) convertToOpenAIMessages(messages []message.Message) []openai.ChatCompletionMessageParamUnion {
    var chatMessages []openai.ChatCompletionMessageParamUnion

    chatMessages = append(chatMessages, openai.SystemMessage(p.systemMessage))
func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
    // Add system message first
    openaiMessages = append(openaiMessages, openai.SystemMessage(o.providerOptions.systemMessage))

    for _, msg := range messages {
        switch msg.Role {
        case message.User:
            chatMessages = append(chatMessages, openai.UserMessage(msg.Content().String()))
            openaiMessages = append(openaiMessages, openai.UserMessage(msg.Content().String()))

        case message.Assistant:
            assistantMsg := openai.ChatCompletionAssistantMessageParam{
@@ -111,23 +87,23 @@ func (p *openaiProvider) convertToOpenAIMessages(messages []message.Message) []o
                }
            }

            chatMessages = append(chatMessages, openai.ChatCompletionMessageParamUnion{
            openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{
                OfAssistant: &assistantMsg,
            })

        case message.Tool:
            for _, result := range msg.ToolResults() {
                chatMessages = append(chatMessages,
                openaiMessages = append(openaiMessages,
                    openai.ToolMessage(result.Content, result.ToolCallID),
                )
            }
        }
    }

    return chatMessages
    return
}

func (p *openaiProvider) convertToOpenAITools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
func (o *openaiClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
    openaiTools := make([]openai.ChatCompletionToolParam, len(tools))

    for i, tool := range tools {
@@ -148,133 +124,238 @@ func (p *openaiProvider) convertToOpenAITools(tools []tools.BaseTool) []openai.C
    return openaiTools
}

func (p *openaiProvider) extractTokenUsage(usage openai.CompletionUsage) TokenUsage {
    cachedTokens := int64(0)

    cachedTokens = usage.PromptTokensDetails.CachedTokens
    inputTokens := usage.PromptTokens - cachedTokens

    return TokenUsage{
        InputTokens:         inputTokens,
        OutputTokens:        usage.CompletionTokens,
        CacheCreationTokens: 0, // OpenAI doesn't provide this directly
        CacheReadTokens:     cachedTokens,
func (o *openaiClient) finishReason(reason string) message.FinishReason {
    switch reason {
    case "stop":
        return message.FinishReasonEndTurn
    case "length":
        return message.FinishReasonMaxTokens
    case "tool_calls":
        return message.FinishReasonToolUse
    default:
        return message.FinishReasonUnknown
    }
}

func (p *openaiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
    messages = cleanupMessages(messages)
    chatMessages := p.convertToOpenAIMessages(messages)
    openaiTools := p.convertToOpenAITools(tools)
func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
    return openai.ChatCompletionNewParams{
        Model:     openai.ChatModel(o.providerOptions.model.APIModel),
        Messages:  messages,
        MaxTokens: openai.Int(o.providerOptions.maxTokens),
        Tools:     tools,
    }
}

    params := openai.ChatCompletionNewParams{
        Model:     openai.ChatModel(p.model.APIModel),
        Messages:  chatMessages,
        MaxTokens: openai.Int(p.maxTokens),
        Tools:     openaiTools,
func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
    params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
    cfg := config.Get()
    if cfg.Debug {
        jsonData, _ := json.Marshal(params)
        logging.Debug("Prepared messages", "messages", string(jsonData))
    }
    attempts := 0
    for {
        attempts++
        openaiResponse, err := o.client.Chat.Completions.New(
            ctx,
            params,
        )
        // If there is an error we are going to see if we can retry the call
        if err != nil {
            retry, after, retryErr := o.shouldRetry(attempts, err)
            if retryErr != nil {
                return nil, retryErr
            }
            if retry {
                logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
                select {
                case <-ctx.Done():
                    return nil, ctx.Err()
                case <-time.After(time.Duration(after) * time.Millisecond):
                    continue
                }
            }
            return nil, retryErr
        }

        content := ""
        if openaiResponse.Choices[0].Message.Content != "" {
            content = openaiResponse.Choices[0].Message.Content
        }

        return &ProviderResponse{
            Content:      content,
            ToolCalls:    o.toolCalls(*openaiResponse),
            Usage:        o.usage(*openaiResponse),
            FinishReason: o.finishReason(string(openaiResponse.Choices[0].FinishReason)),
        }, nil
    }
}

func (o *openaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
    params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
    params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
        IncludeUsage: openai.Bool(true),
    }

    response, err := p.client.Chat.Completions.New(ctx, params)
    if err != nil {
        return nil, err
    cfg := config.Get()
    if cfg.Debug {
        jsonData, _ := json.Marshal(params)
        logging.Debug("Prepared messages", "messages", string(jsonData))
    }

    content := ""
    if response.Choices[0].Message.Content != "" {
        content = response.Choices[0].Message.Content
    attempts := 0
    eventChan := make(chan ProviderEvent)

    go func() {
        for {
            attempts++
            openaiStream := o.client.Chat.Completions.NewStreaming(
                ctx,
                params,
            )

            acc := openai.ChatCompletionAccumulator{}
            currentContent := ""
            toolCalls := make([]message.ToolCall, 0)

            for openaiStream.Next() {
                chunk := openaiStream.Current()
                acc.AddChunk(chunk)

                if tool, ok := acc.JustFinishedToolCall(); ok {
                    toolCalls = append(toolCalls, message.ToolCall{
                        ID:    tool.Id,
                        Name:  tool.Name,
                        Input: tool.Arguments,
                        Type:  "function",
                    })
                }

                for _, choice := range chunk.Choices {
                    if choice.Delta.Content != "" {
                        eventChan <- ProviderEvent{
                            Type:    EventContentDelta,
                            Content: choice.Delta.Content,
                        }
                        currentContent += choice.Delta.Content
                    }
                }
            }

            err := openaiStream.Err()
            if err == nil || errors.Is(err, io.EOF) {
                // Stream completed successfully
                eventChan <- ProviderEvent{
                    Type: EventComplete,
                    Response: &ProviderResponse{
                        Content:      currentContent,
                        ToolCalls:    toolCalls,
                        Usage:        o.usage(acc.ChatCompletion),
                        FinishReason: o.finishReason(string(acc.ChatCompletion.Choices[0].FinishReason)),
                    },
                }
                close(eventChan)
                return
            }

            // If there is an error we are going to see if we can retry the call
            retry, after, retryErr := o.shouldRetry(attempts, err)
            if retryErr != nil {
                eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
                close(eventChan)
                return
            }
            if retry {
                logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
                select {
                case <-ctx.Done():
                    // context cancelled
                    if ctx.Err() == nil {
                        eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
                    }
                    close(eventChan)
                    return
                case <-time.After(time.Duration(after) * time.Millisecond):
                    continue
                }
            }
            eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
            close(eventChan)
            return
        }
    }()

    return eventChan
}

func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
    var apierr *openai.Error
    if !errors.As(err, &apierr) {
        return false, 0, err
    }

    if apierr.StatusCode != 429 && apierr.StatusCode != 500 {
        return false, 0, err
    }

    if attempts > maxRetries {
        return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
    }

    retryMs := 0
    retryAfterValues := apierr.Response.Header.Values("Retry-After")

    backoffMs := 2000 * (1 << (attempts - 1))
    jitterMs := int(float64(backoffMs) * 0.2)
    retryMs = backoffMs + jitterMs
    if len(retryAfterValues) > 0 {
        if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
            retryMs = retryMs * 1000
        }
    }
    return true, int64(retryMs), nil
}

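The OpenAI variant prefers the server's Retry-After header when it parses as an integer (seconds, converted to milliseconds) and otherwise falls back to the same exponential backoff. A sketch of that precedence; note that an HTTP-date Retry-After value will not parse with %d and silently falls back to the backoff:

package main

import "fmt"

// retryDelayMs mirrors the header handling above: prefer a numeric
// Retry-After value (seconds -> ms), otherwise use exponential backoff
// with a fixed 20% jitter.
func retryDelayMs(retryAfter string, attempt int) int {
    backoff := 2000 * (1 << (attempt - 1))
    delay := backoff + int(float64(backoff)*0.2)
    if retryAfter != "" {
        var secs int
        if _, err := fmt.Sscanf(retryAfter, "%d", &secs); err == nil {
            delay = secs * 1000
        }
    }
    return delay
}

func main() {
    fmt.Println(retryDelayMs("30", 1)) // 30000 (header wins)
    fmt.Println(retryDelayMs("", 3))   // 9600 (backoff fallback)
}
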
func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall {
    var toolCalls []message.ToolCall
    if len(response.Choices[0].Message.ToolCalls) > 0 {
        toolCalls = make([]message.ToolCall, len(response.Choices[0].Message.ToolCalls))
        for i, call := range response.Choices[0].Message.ToolCalls {
            toolCalls[i] = message.ToolCall{

    if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
        for _, call := range completion.Choices[0].Message.ToolCalls {
            toolCall := message.ToolCall{
                ID:    call.ID,
                Name:  call.Function.Name,
                Input: call.Function.Arguments,
                Type:  "function",
            }
            toolCalls = append(toolCalls, toolCall)
        }
    }

    tokenUsage := p.extractTokenUsage(response.Usage)

    return &ProviderResponse{
        Content:   content,
        ToolCalls: toolCalls,
        Usage:     tokenUsage,
    }, nil
    return toolCalls
}

func (p *openaiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
    messages = cleanupMessages(messages)
    chatMessages := p.convertToOpenAIMessages(messages)
    openaiTools := p.convertToOpenAITools(tools)
func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
    cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens
    inputTokens := completion.Usage.PromptTokens - cachedTokens

    params := openai.ChatCompletionNewParams{
        Model:     openai.ChatModel(p.model.APIModel),
        Messages:  chatMessages,
        MaxTokens: openai.Int(p.maxTokens),
        Tools:     openaiTools,
        StreamOptions: openai.ChatCompletionStreamOptionsParam{
            IncludeUsage: openai.Bool(true),
        },
    return TokenUsage{
        InputTokens:         inputTokens,
        OutputTokens:        completion.Usage.CompletionTokens,
        CacheCreationTokens: 0, // OpenAI doesn't provide this directly
        CacheReadTokens:     cachedTokens,
    }

    stream := p.client.Chat.Completions.NewStreaming(ctx, params)

    eventChan := make(chan ProviderEvent)

    toolCalls := make([]message.ToolCall, 0)
    go func() {
        defer close(eventChan)

        acc := openai.ChatCompletionAccumulator{}
        currentContent := ""

        for stream.Next() {
            chunk := stream.Current()
            acc.AddChunk(chunk)

            if tool, ok := acc.JustFinishedToolCall(); ok {
                toolCalls = append(toolCalls, message.ToolCall{
                    ID:    tool.Id,
                    Name:  tool.Name,
                    Input: tool.Arguments,
                    Type:  "function",
                })
            }

            for _, choice := range chunk.Choices {
                if choice.Delta.Content != "" {
                    eventChan <- ProviderEvent{
                        Type:    EventContentDelta,
                        Content: choice.Delta.Content,
                    }
                    currentContent += choice.Delta.Content
                }
            }
        }

        if err := stream.Err(); err != nil {
            eventChan <- ProviderEvent{
                Type:  EventError,
                Error: err,
            }
            return
        }

        tokenUsage := p.extractTokenUsage(acc.Usage)

        eventChan <- ProviderEvent{
            Type: EventComplete,
            Response: &ProviderResponse{
                Content:   currentContent,
                ToolCalls: toolCalls,
                Usage:     tokenUsage,
            },
        }
    }()

    return eventChan, nil
}

func WithOpenAIBaseURL(baseURL string) OpenAIOption {
    return func(options *openaiOptions) {
        options.baseURL = baseURL
    }
}

func WithOpenAIDisableCache() OpenAIOption {
    return func(options *openaiOptions) {
        options.disableCache = true
    }
}

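usage() above splits prompt tokens into cached and uncached so downstream accounting can price them separately. A sketch of the arithmetic with made-up counts (the struct mirrors the TokenUsage shape used by the providers):

package main

import "fmt"

type TokenUsage struct {
    InputTokens         int64
    OutputTokens        int64
    CacheCreationTokens int64
    CacheReadTokens     int64
}

// usageFromCounts reproduces the OpenAI accounting above: cached
// prompt tokens are reported separately and removed from InputTokens.
func usageFromCounts(promptTokens, cachedTokens, completionTokens int64) TokenUsage {
    return TokenUsage{
        InputTokens:         promptTokens - cachedTokens,
        OutputTokens:        completionTokens,
        CacheCreationTokens: 0, // not reported by OpenAI
        CacheReadTokens:     cachedTokens,
    }
}

func main() {
    fmt.Printf("%+v\n", usageFromCounts(1200, 800, 150))
    // {InputTokens:400 OutputTokens:150 CacheCreationTokens:0 CacheReadTokens:800}
}
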
@@ -2,14 +2,17 @@ package provider

import (
    "context"
    "fmt"

    "github.com/kujtimiihoxha/termai/internal/llm/models"
    "github.com/kujtimiihoxha/termai/internal/llm/tools"
    "github.com/kujtimiihoxha/termai/internal/message"
)

// EventType represents the type of streaming event
type EventType string

const maxRetries = 8

const (
    EventContentStart EventType = "content_start"
    EventContentDelta EventType = "content_delta"
@@ -18,7 +21,6 @@ const (
    EventComplete EventType = "complete"
    EventError    EventType = "error"
    EventWarning  EventType = "warning"
    EventInfo     EventType = "info"
)

type TokenUsage struct {
@@ -32,61 +34,152 @@ type ProviderResponse struct {
    Content      string
    ToolCalls    []message.ToolCall
    Usage        TokenUsage
    FinishReason string
    FinishReason message.FinishReason
}

type ProviderEvent struct {
    Type EventType
    Type EventType

    Content  string
    Thinking string
    ToolCall *message.ToolCall
    Error    error
    Response *ProviderResponse

    // Used for giving users info, e.g. on retry
    Info  string
    Error error
}

type Provider interface {
    SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)

    StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error)
    StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent

    Model() models.Model
}

func cleanupMessages(messages []message.Message) []message.Message {
    // First pass: filter out canceled messages
    var cleanedMessages []message.Message
type providerClientOptions struct {
    apiKey        string
    model         models.Model
    maxTokens     int64
    systemMessage string

    anthropicOptions []AnthropicOption
    openaiOptions    []OpenAIOption
    geminiOptions    []GeminiOption
    bedrockOptions   []BedrockOption
}

type ProviderClientOption func(*providerClientOptions)

type ProviderClient interface {
    send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
    stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
}

type baseProvider[C ProviderClient] struct {
    options providerClientOptions
    client  C
}

func NewProvider(providerName models.ModelProvider, opts ...ProviderClientOption) (Provider, error) {
    clientOptions := providerClientOptions{}
    for _, o := range opts {
        o(&clientOptions)
    }
    switch providerName {
    case models.ProviderAnthropic:
        return &baseProvider[AnthropicClient]{
            options: clientOptions,
            client:  newAnthropicClient(clientOptions),
        }, nil
    case models.ProviderOpenAI:
        return &baseProvider[OpenAIClient]{
            options: clientOptions,
            client:  newOpenAIClient(clientOptions),
        }, nil
    case models.ProviderGemini:
        return &baseProvider[GeminiClient]{
            options: clientOptions,
            client:  newGeminiClient(clientOptions),
        }, nil
    case models.ProviderBedrock:
        return &baseProvider[BedrockClient]{
            options: clientOptions,
            client:  newBedrockClient(clientOptions),
        }, nil
    case models.ProviderMock:
        // TODO: implement mock client for test
        panic("not implemented")
    }
    return nil, fmt.Errorf("provider not supported: %s", providerName)
}

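For context, a sketch of how a caller would assemble a provider with the new factory. The import paths are internal to this repo (so this only builds from inside it), the model is picked arbitrarily from models.SupportedModels rather than by a specific key, and the nil message list is purely illustrative:

package main

import (
    "context"
    "log"
    "os"

    "github.com/kujtimiihoxha/termai/internal/llm/models"
    "github.com/kujtimiihoxha/termai/internal/llm/provider"
)

func main() {
    // Pick any configured model; a real caller would choose a specific one.
    var model models.Model
    for _, m := range models.SupportedModels {
        model = m
        break
    }

    p, err := provider.NewProvider(
        models.ProviderOpenAI,
        provider.WithAPIKey(os.Getenv("OPENAI_API_KEY")),
        provider.WithModel(model),
        provider.WithMaxTokens(4000),
        provider.WithSystemMessage("You are a helpful coding assistant."),
    )
    if err != nil {
        log.Fatal(err)
    }

    resp, err := p.SendMessages(context.Background(), nil, nil)
    if err != nil {
        log.Fatal(err)
    }
    log.Println(resp.Content)
}
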
func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
    for _, msg := range messages {
        if msg.FinishReason() != "canceled" || len(msg.ToolCalls()) > 0 {
            // if there are toolCalls this means we want to return it to the LLM telling it that those tools have been
            // cancelled
            cleanedMessages = append(cleanedMessages, msg)
        // The message has no content
        if len(msg.Parts) == 0 {
            continue
        }
        cleaned = append(cleaned, msg)
    }
    return
}

func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
    messages = p.cleanMessages(messages)
    return p.client.send(ctx, messages, tools)
}

func (p *baseProvider[C]) Model() models.Model {
    return p.options.model
}

func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
    messages = p.cleanMessages(messages)
    return p.client.stream(ctx, messages, tools)
}

func WithAPIKey(apiKey string) ProviderClientOption {
    return func(options *providerClientOptions) {
        options.apiKey = apiKey
    }
}

func WithModel(model models.Model) ProviderClientOption {
    return func(options *providerClientOptions) {
        options.model = model
    }
}

func WithMaxTokens(maxTokens int64) ProviderClientOption {
    return func(options *providerClientOptions) {
        options.maxTokens = maxTokens
    }
}

func WithSystemMessage(systemMessage string) ProviderClientOption {
    return func(options *providerClientOptions) {
        options.systemMessage = systemMessage
    }
}

func WithAnthropicOptions(anthropicOptions ...AnthropicOption) ProviderClientOption {
    return func(options *providerClientOptions) {
        options.anthropicOptions = anthropicOptions
    }
}

func WithOpenAIOptions(openaiOptions ...OpenAIOption) ProviderClientOption {
    return func(options *providerClientOptions) {
        options.openaiOptions = openaiOptions
    }
}

func WithGeminiOptions(geminiOptions ...GeminiOption) ProviderClientOption {
    return func(options *providerClientOptions) {
        options.geminiOptions = geminiOptions
    }
}

func WithBedrockOptions(bedrockOptions ...BedrockOption) ProviderClientOption {
    return func(options *providerClientOptions) {
        options.bedrockOptions = bedrockOptions
    }

    // Second pass: filter out tool messages without a corresponding tool call
    var result []message.Message
    toolMessageIDs := make(map[string]bool)

    for _, msg := range cleanedMessages {
        if msg.Role == message.Assistant {
            for _, toolCall := range msg.ToolCalls() {
                toolMessageIDs[toolCall.ID] = true // Mark as referenced
            }
        }
    }

    // Keep only messages that aren't unreferenced tool messages
    for _, msg := range cleanedMessages {
        if msg.Role == message.Tool {
            for _, toolCall := range msg.ToolResults() {
                if referenced, exists := toolMessageIDs[toolCall.ToolCallID]; exists && referenced {
                    result = append(result, msg)
                }
            }
        } else {
            result = append(result, msg)
        }
    }
    return result
}

@@ -23,7 +23,8 @@ type BashPermissionsParams struct {
}

type BashResponseMetadata struct {
    Took      int64 `json:"took"`
    StartTime int64 `json:"start_time"`
    EndTime   int64 `json:"end_time"`
}
type bashTool struct {
    permissions permission.Service
@@ -282,7 +283,6 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
    if err != nil {
        return ToolResponse{}, fmt.Errorf("error executing command: %w", err)
    }
    took := time.Since(startTime).Milliseconds()

    stdout = truncateOutput(stdout)
    stderr = truncateOutput(stderr)
@@ -311,7 +311,8 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
    }

    metadata := BashResponseMetadata{
        Took:      took,
        StartTime: startTime.UnixMilli(),
        EndTime:   time.Now().UnixMilli(),
    }
    if stdout == "" {
        return WithResponseMetadata(NewTextResponse("no output"), metadata), nil
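The change above swaps a precomputed duration for a start/end timestamp pair, letting consumers derive the duration themselves. A condensed sketch of the shape, assuming it runs inside bashTool.Run where ctx and cmd exist; runCommand is a hypothetical stand-in for the tool's actual shell execution:

    start := time.Now()
    stdout, _, err := runCommand(ctx, cmd) // hypothetical helper
    if err != nil {
        return ToolResponse{}, err
    }
    metadata := BashResponseMetadata{
        StartTime: start.UnixMilli(),
        EndTime:   time.Now().UnixMilli(),
    }
    return WithResponseMetadata(NewTextResponse(stdout), metadata), nil
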
@@ -8,8 +8,6 @@ import (
    "testing"
    "time"

    "github.com/kujtimiihoxha/termai/internal/permission"
    "github.com/kujtimiihoxha/termai/internal/pubsub"
    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"
)
@@ -340,32 +338,3 @@ func TestCountLines(t *testing.T) {
        })
    }
}

// Mock permission service for testing
type mockPermissionService struct {
    *pubsub.Broker[permission.PermissionRequest]
    allow bool
}

func (m *mockPermissionService) GrantPersistant(permission permission.PermissionRequest) {
    // Not needed for tests
}

func (m *mockPermissionService) Grant(permission permission.PermissionRequest) {
    // Not needed for tests
}

func (m *mockPermissionService) Deny(permission permission.PermissionRequest) {
    // Not needed for tests
}

func (m *mockPermissionService) Request(opts permission.CreatePermissionRequest) bool {
    return m.allow
}

func newMockPermissionService(allow bool) permission.Service {
    return &mockPermissionService{
        Broker: pubsub.NewBroker[permission.PermissionRequest](),
        allow:  allow,
    }
}

@@ -11,6 +11,7 @@ import (

    "github.com/kujtimiihoxha/termai/internal/config"
    "github.com/kujtimiihoxha/termai/internal/diff"
    "github.com/kujtimiihoxha/termai/internal/history"
    "github.com/kujtimiihoxha/termai/internal/lsp"
    "github.com/kujtimiihoxha/termai/internal/permission"
)
@@ -35,6 +36,7 @@ type EditResponseMetadata struct {
type editTool struct {
    lspClients  map[string]*lsp.Client
    permissions permission.Service
    files       history.Service
}

const (
@@ -88,10 +90,11 @@ When making edits:
Remember: when making multiple file edits in a row to the same file, you should prefer to send all edits in a single message with multiple calls to this tool, rather than multiple messages with a single call each.`
)

func NewEditTool(lspClients map[string]*lsp.Client, permissions permission.Service) BaseTool {
func NewEditTool(lspClients map[string]*lsp.Client, permissions permission.Service, files history.Service) BaseTool {
    return &editTool{
        lspClients:  lspClients,
        permissions: permissions,
        files:       files,
    }
}

@@ -153,6 +156,11 @@ func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
    if err != nil {
        return response, nil
    }
    if response.IsError {
        // Return early if there was an error during content replacement
        // This prevents unnecessary LSP diagnostics processing
        return response, nil
    }

    waitForLspDiagnostics(ctx, params.FilePath, e.lspClients)
    text := fmt.Sprintf("<result>\n%s\n</result>\n", response.Content)
@@ -208,6 +216,20 @@ func (e *editTool) createNewFile(ctx context.Context, filePath, content string)
        return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
    }

    // File can't be in the history so we create a new file history
    _, err = e.files.Create(ctx, sessionID, filePath, "")
    if err != nil {
        // Log error but don't fail the operation
        return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
    }

    // Add the new content to the file history
    _, err = e.files.CreateVersion(ctx, sessionID, filePath, content)
    if err != nil {
        // Log error but don't fail the operation
        fmt.Printf("Error creating file history version: %v\n", err)
    }

    recordFileWrite(filePath)
    recordFileRead(filePath)

@@ -298,6 +320,29 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string
    if err != nil {
        return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
    }

    // Check if file exists in history
    file, err := e.files.GetByPathAndSession(ctx, filePath, sessionID)
    if err != nil {
        _, err = e.files.Create(ctx, sessionID, filePath, oldContent)
        if err != nil {
            // Log error but don't fail the operation
            return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
        }
    }
    if file.Content != oldContent {
        // User manually changed the content; store an intermediate version
        _, err = e.files.CreateVersion(ctx, sessionID, filePath, oldContent)
        if err != nil {
            fmt.Printf("Error creating file history version: %v\n", err)
        }
    }
    // Store the new version
    _, err = e.files.CreateVersion(ctx, sessionID, filePath, "")
    if err != nil {
        fmt.Printf("Error creating file history version: %v\n", err)
    }

    recordFileWrite(filePath)
    recordFileRead(filePath)

@@ -356,6 +401,9 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS

    newContent := oldContent[:index] + newString + oldContent[index+len(oldString):]

    if oldContent == newContent {
        return NewTextErrorResponse("new content is the same as old content. No changes made."), nil
    }
    sessionID, messageID := GetContextValues(ctx)

    if sessionID == "" || messageID == "" {
@@ -374,8 +422,7 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS
            Description: fmt.Sprintf("Replace content in file %s", filePath),
            Params: EditPermissionsParams{
                FilePath: filePath,

                Diff: diff,
                Diff:     diff,
            },
        },
    )
@@ -388,6 +435,28 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS
        return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
    }

    // Check if file exists in history
    file, err := e.files.GetByPathAndSession(ctx, filePath, sessionID)
    if err != nil {
        _, err = e.files.Create(ctx, sessionID, filePath, oldContent)
        if err != nil {
            // Log error but don't fail the operation
            return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
        }
    }
    if file.Content != oldContent {
        // User manually changed the content; store an intermediate version
        _, err = e.files.CreateVersion(ctx, sessionID, filePath, oldContent)
        if err != nil {
            fmt.Printf("Error creating file history version: %v\n", err)
        }
    }
    // Store the new version
    _, err = e.files.CreateVersion(ctx, sessionID, filePath, newContent)
    if err != nil {
        fmt.Printf("Error creating file history version: %v\n", err)
    }

    recordFileWrite(filePath)
    recordFileRead(filePath)

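The same check-then-version dance appears in both deleteContent and replaceContent above. A condensed sketch of the flow as a hypothetical helper (recordVersions does not exist in the diff); note that when the lookup fails, file is the zero value, so the intermediate-version branch also fires on first contact, mirroring the behavior as committed:

// recordVersions condenses the pattern above: seed the history on first
// sight of the file, snapshot any manual edit made since the last
// version, then store the new content. Version errors are non-fatal.
func recordVersions(ctx context.Context, files history.Service, sessionID, path, oldContent, newContent string) error {
    file, err := files.GetByPathAndSession(ctx, path, sessionID)
    if err != nil {
        // First time this session touches the file: seed with old content.
        if _, err := files.Create(ctx, sessionID, path, oldContent); err != nil {
            return fmt.Errorf("error creating file history: %w", err)
        }
    }
    if file.Content != oldContent {
        // The file changed outside the tool; keep that state too.
        if _, err := files.CreateVersion(ctx, sessionID, path, oldContent); err != nil {
            fmt.Printf("Error creating file history version: %v\n", err)
        }
    }
    if _, err := files.CreateVersion(ctx, sessionID, path, newContent); err != nil {
        fmt.Printf("Error creating file history version: %v\n", err)
    }
    return nil
}
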
@@ -14,7 +14,7 @@ import (
)

func TestEditTool_Info(t *testing.T) {
    tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
    tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
    info := tool.Info()

    assert.Equal(t, EditToolName, info.Name)
@@ -34,7 +34,7 @@ func TestEditTool_Run(t *testing.T) {
    defer os.RemoveAll(tempDir)

    t.Run("creates a new file successfully", func(t *testing.T) {
        tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
        tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())

        filePath := filepath.Join(tempDir, "new_file.txt")
        content := "This is a test content"
@@ -64,7 +64,7 @@ func TestEditTool_Run(t *testing.T) {
    })

    t.Run("creates file with nested directories", func(t *testing.T) {
        tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
        tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())

        filePath := filepath.Join(tempDir, "nested/dirs/new_file.txt")
        content := "Content in nested directory"
@@ -94,7 +94,7 @@ func TestEditTool_Run(t *testing.T) {
    })

    t.Run("fails to create file that already exists", func(t *testing.T) {
        tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
        tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())

        // Create a file first
        filePath := filepath.Join(tempDir, "existing_file.txt")
@@ -123,7 +123,7 @@ func TestEditTool_Run(t *testing.T) {
    })

    t.Run("fails to create file when path is a directory", func(t *testing.T) {
        tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
        tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())

        // Create a directory
        dirPath := filepath.Join(tempDir, "test_dir")
@@ -151,7 +151,7 @@ func TestEditTool_Run(t *testing.T) {
    })

    t.Run("replaces content successfully", func(t *testing.T) {
        tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
        tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())

        // Create a file first
        filePath := filepath.Join(tempDir, "replace_content.txt")
@@ -191,7 +191,7 @@ func TestEditTool_Run(t *testing.T) {
    })

    t.Run("deletes content successfully", func(t *testing.T) {
        tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
        tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())

        // Create a file first
        filePath := filepath.Join(tempDir, "delete_content.txt")
@@ -230,7 +230,7 @@ func TestEditTool_Run(t *testing.T) {
    })

    t.Run("handles invalid parameters", func(t *testing.T) {
        tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
        tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())

        call := ToolCall{
            Name: EditToolName,
@@ -243,7 +243,7 @@ func TestEditTool_Run(t *testing.T) {
    })

    t.Run("handles missing file_path", func(t *testing.T) {
        tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
        tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())

        params := EditParams{
            FilePath: "",
@@ -265,7 +265,7 @@ func TestEditTool_Run(t *testing.T) {
    })

    t.Run("handles file not found", func(t *testing.T) {
        tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
        tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())

        filePath := filepath.Join(tempDir, "non_existent_file.txt")
        params := EditParams{
@@ -288,7 +288,7 @@ func TestEditTool_Run(t *testing.T) {
    })

    t.Run("handles old_string not found in file", func(t *testing.T) {
        tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
        tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())

        // Create a file first
        filePath := filepath.Join(tempDir, "content_not_found.txt")
@@ -320,7 +320,7 @@ func TestEditTool_Run(t *testing.T) {
    })

    t.Run("handles multiple occurrences of old_string", func(t *testing.T) {
        tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
        tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())

        // Create a file with duplicate content
        filePath := filepath.Join(tempDir, "duplicate_content.txt")
@@ -352,7 +352,7 @@ func TestEditTool_Run(t *testing.T) {
    })

    t.Run("handles file modified since last read", func(t *testing.T) {
        tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
        tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())

        // Create a file
        filePath := filepath.Join(tempDir, "modified_file.txt")
@@ -394,7 +394,7 @@ func TestEditTool_Run(t *testing.T) {
    })

    t.Run("handles file not read before editing", func(t *testing.T) {
        tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
        tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())

        // Create a file
        filePath := filepath.Join(tempDir, "not_read_file.txt")
@@ -423,7 +423,7 @@ func TestEditTool_Run(t *testing.T) {
    })

    t.Run("handles permission denied", func(t *testing.T) {
        tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(false))
        tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(false), newMockFileHistoryService())

        // Create a file
        filePath := filepath.Join(tempDir, "permission_denied.txt")

@@ -3,8 +3,6 @@ package tools
import (
    "sync"
    "time"

    "github.com/kujtimiihoxha/termai/internal/config"
)

// File record to track when files were read/written
@@ -19,14 +17,6 @@ var (
    fileRecordMutex sync.RWMutex
)

func removeWorkingDirectoryPrefix(path string) string {
    wd := config.WorkingDirectory()
    if len(path) > len(wd) && path[:len(wd)] == wd {
        return path[len(wd)+1:]
    }
    return path
}

func recordFileRead(path string) {
    fileRecordMutex.Lock()
    defer fileRecordMutex.Unlock()

@@ -63,7 +63,7 @@ type GlobParams struct {
    Path string `json:"path"`
}

type GlobMetadata struct {
type GlobResponseMetadata struct {
    NumberOfFiles int  `json:"number_of_files"`
    Truncated     bool `json:"truncated"`
}
@@ -124,7 +124,7 @@ func (g *globTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)

    return WithResponseMetadata(
        NewTextResponse(output),
        GlobMetadata{
        GlobResponseMetadata{
            NumberOfFiles: len(files),
            Truncated:     truncated,
        },

@@ -27,7 +27,7 @@ type grepMatch struct {
    modTime time.Time
}

type GrepMetadata struct {
type GrepResponseMetadata struct {
    NumberOfMatches int  `json:"number_of_matches"`
    Truncated       bool `json:"truncated"`
}
@@ -134,7 +134,7 @@ func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)

    return WithResponseMetadata(
        NewTextResponse(output),
        GrepMetadata{
        GrepResponseMetadata{
            NumberOfMatches: len(matches),
            Truncated:       truncated,
        },

@@ -23,7 +23,7 @@ type TreeNode struct {
    Children []*TreeNode `json:"children,omitempty"`
}

type LSMetadata struct {
type LSResponseMetadata struct {
    NumberOfFiles int  `json:"number_of_files"`
    Truncated     bool `json:"truncated"`
}
@@ -121,7 +121,7 @@ func (l *lsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {

    return WithResponseMetadata(
        NewTextResponse(output),
        LSMetadata{
        LSResponseMetadata{
            NumberOfFiles: len(files),
            Truncated:     truncated,
        },

246 internal/llm/tools/mocks_test.go Normal file
@@ -0,0 +1,246 @@
package tools

import (
    "context"
    "fmt"
    "sort"
    "strconv"
    "strings"
    "time"

    "github.com/google/uuid"
    "github.com/kujtimiihoxha/termai/internal/history"
    "github.com/kujtimiihoxha/termai/internal/permission"
    "github.com/kujtimiihoxha/termai/internal/pubsub"
)

// Mock permission service for testing
type mockPermissionService struct {
    *pubsub.Broker[permission.PermissionRequest]
    allow bool
}

func (m *mockPermissionService) GrantPersistant(permission permission.PermissionRequest) {
    // Not needed for tests
}

func (m *mockPermissionService) Grant(permission permission.PermissionRequest) {
    // Not needed for tests
}

func (m *mockPermissionService) Deny(permission permission.PermissionRequest) {
    // Not needed for tests
}

func (m *mockPermissionService) Request(opts permission.CreatePermissionRequest) bool {
    return m.allow
}

func newMockPermissionService(allow bool) permission.Service {
    return &mockPermissionService{
        Broker: pubsub.NewBroker[permission.PermissionRequest](),
        allow:  allow,
    }
}

type mockFileHistoryService struct {
    *pubsub.Broker[history.File]
    files   map[string]history.File // ID -> File
    timeNow func() int64
}

// Create implements history.Service.
func (m *mockFileHistoryService) Create(ctx context.Context, sessionID string, path string, content string) (history.File, error) {
    return m.createWithVersion(ctx, sessionID, path, content, history.InitialVersion)
}

// CreateVersion implements history.Service.
func (m *mockFileHistoryService) CreateVersion(ctx context.Context, sessionID string, path string, content string) (history.File, error) {
    var files []history.File
    for _, file := range m.files {
        if file.Path == path {
            files = append(files, file)
        }
    }

    if len(files) == 0 {
        // No previous versions, create initial
        return m.Create(ctx, sessionID, path, content)
    }

    // Sort files by CreatedAt in descending order
    sort.Slice(files, func(i, j int) bool {
        return files[i].CreatedAt > files[j].CreatedAt
    })

    // Get the latest version
    latestFile := files[0]
    latestVersion := latestFile.Version

    // Generate the next version
    var nextVersion string
    if latestVersion == history.InitialVersion {
        nextVersion = "v1"
    } else if strings.HasPrefix(latestVersion, "v") {
        versionNum, err := strconv.Atoi(latestVersion[1:])
        if err != nil {
            // If we can't parse the version, just use a timestamp-based version
            nextVersion = fmt.Sprintf("v%d", latestFile.CreatedAt)
        } else {
            nextVersion = fmt.Sprintf("v%d", versionNum+1)
        }
    } else {
        // If the version format is unexpected, use a timestamp-based version
        nextVersion = fmt.Sprintf("v%d", latestFile.CreatedAt)
    }

    return m.createWithVersion(ctx, sessionID, path, content, nextVersion)
}

func (m *mockFileHistoryService) createWithVersion(_ context.Context, sessionID, path, content, version string) (history.File, error) {
    now := m.timeNow()
    file := history.File{
        ID:        uuid.New().String(),
        SessionID: sessionID,
        Path:      path,
        Content:   content,
        Version:   version,
        CreatedAt: now,
        UpdatedAt: now,
    }

    m.files[file.ID] = file
    m.Publish(pubsub.CreatedEvent, file)
    return file, nil
}

// Delete implements history.Service.
func (m *mockFileHistoryService) Delete(ctx context.Context, id string) error {
    file, ok := m.files[id]
    if !ok {
        return fmt.Errorf("file not found: %s", id)
    }

    delete(m.files, id)
    m.Publish(pubsub.DeletedEvent, file)
    return nil
}

// DeleteSessionFiles implements history.Service.
func (m *mockFileHistoryService) DeleteSessionFiles(ctx context.Context, sessionID string) error {
    files, err := m.ListBySession(ctx, sessionID)
    if err != nil {
        return err
    }

    for _, file := range files {
|
||||
err = m.Delete(ctx, file.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get implements history.Service.
|
||||
func (m *mockFileHistoryService) Get(ctx context.Context, id string) (history.File, error) {
|
||||
file, ok := m.files[id]
|
||||
if !ok {
|
||||
return history.File{}, fmt.Errorf("file not found: %s", id)
|
||||
}
|
||||
return file, nil
|
||||
}
|
||||
|
||||
// GetByPathAndSession implements history.Service.
|
||||
func (m *mockFileHistoryService) GetByPathAndSession(ctx context.Context, path string, sessionID string) (history.File, error) {
|
||||
var latestFile history.File
|
||||
var found bool
|
||||
var latestTime int64
|
||||
|
||||
for _, file := range m.files {
|
||||
if file.Path == path && file.SessionID == sessionID {
|
||||
if !found || file.CreatedAt > latestTime {
|
||||
latestFile = file
|
||||
latestTime = file.CreatedAt
|
||||
found = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
return history.File{}, fmt.Errorf("file not found: %s for session %s", path, sessionID)
|
||||
}
|
||||
return latestFile, nil
|
||||
}
|
||||
|
||||
// ListBySession implements history.Service.
|
||||
func (m *mockFileHistoryService) ListBySession(ctx context.Context, sessionID string) ([]history.File, error) {
|
||||
var files []history.File
|
||||
for _, file := range m.files {
|
||||
if file.SessionID == sessionID {
|
||||
files = append(files, file)
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by CreatedAt in descending order
|
||||
sort.Slice(files, func(i, j int) bool {
|
||||
return files[i].CreatedAt > files[j].CreatedAt
|
||||
})
|
||||
|
||||
return files, nil
|
||||
}
|
||||
|
||||
// ListLatestSessionFiles implements history.Service.
|
||||
func (m *mockFileHistoryService) ListLatestSessionFiles(ctx context.Context, sessionID string) ([]history.File, error) {
|
||||
// Map to track the latest file for each path
|
||||
latestFiles := make(map[string]history.File)
|
||||
|
||||
for _, file := range m.files {
|
||||
if file.SessionID == sessionID {
|
||||
existing, ok := latestFiles[file.Path]
|
||||
if !ok || file.CreatedAt > existing.CreatedAt {
|
||||
latestFiles[file.Path] = file
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert map to slice
|
||||
var result []history.File
|
||||
for _, file := range latestFiles {
|
||||
result = append(result, file)
|
||||
}
|
||||
|
||||
// Sort by CreatedAt in descending order
|
||||
sort.Slice(result, func(i, j int) bool {
|
||||
return result[i].CreatedAt > result[j].CreatedAt
|
||||
})
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Subscribe implements history.Service.
|
||||
func (m *mockFileHistoryService) Subscribe(ctx context.Context) <-chan pubsub.Event[history.File] {
|
||||
return m.Broker.Subscribe(ctx)
|
||||
}
|
||||
|
||||
// Update implements history.Service.
|
||||
func (m *mockFileHistoryService) Update(ctx context.Context, file history.File) (history.File, error) {
|
||||
_, ok := m.files[file.ID]
|
||||
if !ok {
|
||||
return history.File{}, fmt.Errorf("file not found: %s", file.ID)
|
||||
}
|
||||
|
||||
file.UpdatedAt = m.timeNow()
|
||||
m.files[file.ID] = file
|
||||
m.Publish(pubsub.UpdatedEvent, file)
|
||||
return file, nil
|
||||
}
|
||||
|
||||
func newMockFileHistoryService() history.Service {
|
||||
return &mockFileHistoryService{
|
||||
Broker: pubsub.NewBroker[history.File](),
|
||||
files: make(map[string]history.File),
|
||||
timeNow: func() int64 { return time.Now().Unix() },
|
||||
}
|
||||
}
|
||||
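The mock's versioning contract, in short: the first snapshot of a path gets history.InitialVersion, and each later CreateVersion bumps a v-prefixed counter (falling back to a timestamp when the previous label doesn't parse). A quick illustrative run — the literal value of history.InitialVersion isn't shown in this diff, and since the mock orders versions by CreatedAt at one-second resolution, same-second writes can tie, so assume the calls land on distinct timestamps:

	svc := newMockFileHistoryService()
	ctx := context.Background()

	first, _ := svc.Create(ctx, "sess-1", "main.go", "a")         // Version == history.InitialVersion
	second, _ := svc.CreateVersion(ctx, "sess-1", "main.go", "b") // Version == "v1"
	third, _ := svc.CreateVersion(ctx, "sess-1", "main.go", "c")  // Version == "v2"
	fmt.Println(first.Version, second.Version, third.Version)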
@@ -83,11 +83,21 @@ func newPersistentShell(cwd string) *PersistentShell {
		commandQueue: make(chan *commandExecution, 10),
	}

	go shell.processCommands()
	go func() {
		defer func() {
			if r := recover(); r != nil {
				fmt.Fprintf(os.Stderr, "Panic in shell command processor: %v\n", r)
				shell.isAlive = false
				close(shell.commandQueue)
			}
		}()
		shell.processCommands()
	}()

	go func() {
		err := cmd.Wait()
		if err != nil {
			// Log the error if needed
		}
		shell.isAlive = false
		close(shell.commandQueue)
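The replacement wraps processCommands in a recover so a panicking command handler downgrades to a stderr log instead of killing the process. The same pattern in isolation (illustrative names, not from the repo):

	// Sketch: run work in a goroutine, converting panics into log lines.
	func runProtected(work func()) {
		go func() {
			defer func() {
				if r := recover(); r != nil {
					fmt.Fprintf(os.Stderr, "panic in worker: %v\n", r)
				}
			}()
			work()
		}()
	}

One design note: both the recover path and the cmd.Wait goroutine above close commandQueue, and closing an already-closed channel itself panics, so if both paths ever fire a sync.Once around the close would make this airtight.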
@@ -18,7 +18,7 @@ type SourcegraphParams struct {
	Timeout int `json:"timeout,omitempty"`
}

type SourcegraphMetadata struct {
type SourcegraphResponseMetadata struct {
	NumberOfMatches int  `json:"number_of_matches"`
	Truncated       bool `json:"truncated"`
}
@@ -14,12 +14,17 @@ type ToolInfo struct {

type toolResponseType string

type (
	sessionIDContextKey string
	messageIDContextKey string
)

const (
	ToolResponseTypeText  toolResponseType = "text"
	ToolResponseTypeImage toolResponseType = "image"

	SessionIDContextKey = "session_id"
	MessageIDContextKey = "message_id"
	SessionIDContextKey sessionIDContextKey = "session_id"
	MessageIDContextKey messageIDContextKey = "message_id"
)

type ToolResponse struct {
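Moving the context keys from untyped strings to dedicated key types is the standard guard against context.WithValue collisions (it is what go vet and staticcheck's SA1029 check push you toward): another package storing a value under the plain string "session_id" can no longer shadow these. Usage stays almost identical:

	ctx := context.WithValue(context.Background(), SessionIDContextKey, "sess-123")
	ctx = context.WithValue(ctx, MessageIDContextKey, "msg-456")

	// Lookups must use the same typed key; a bare "session_id" string no longer matches.
	if sessionID, ok := ctx.Value(SessionIDContextKey).(string); ok {
		fmt.Println("session:", sessionID)
	}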
@@ -10,6 +10,7 @@ import (

	"github.com/kujtimiihoxha/termai/internal/config"
	"github.com/kujtimiihoxha/termai/internal/diff"
	"github.com/kujtimiihoxha/termai/internal/history"
	"github.com/kujtimiihoxha/termai/internal/lsp"
	"github.com/kujtimiihoxha/termai/internal/permission"
)
@@ -27,6 +28,7 @@ type WritePermissionsParams struct {
type writeTool struct {
	lspClients  map[string]*lsp.Client
	permissions permission.Service
	files       history.Service
}

type WriteResponseMetadata struct {
@@ -67,10 +69,11 @@ TIPS:
- Always include descriptive comments when making changes to existing code`
)

func NewWriteTool(lspClients map[string]*lsp.Client, permissions permission.Service) BaseTool {
func NewWriteTool(lspClients map[string]*lsp.Client, permissions permission.Service, files history.Service) BaseTool {
	return &writeTool{
		lspClients:  lspClients,
		permissions: permissions,
		files:       files,
	}
}
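Every construction site now has to supply a history.Service as well; a minimal wiring sketch using the test mocks from mocks_test.go in place of the real database-backed service:

	tool := NewWriteTool(
		make(map[string]*lsp.Client),   // no LSP clients needed for a bare run
		newMockPermissionService(true), // auto-approve writes
		newMockFileHistoryService(),    // in-memory file history
	)
	fmt.Println(tool.Info().Name) // prints WriteToolName's value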
@@ -176,6 +179,28 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
		return ToolResponse{}, fmt.Errorf("error writing file: %w", err)
	}

	// Check if file exists in history
	file, err := w.files.GetByPathAndSession(ctx, filePath, sessionID)
	if err != nil {
		_, err = w.files.Create(ctx, sessionID, filePath, oldContent)
		if err != nil {
			// Without a baseline snapshot we cannot version this write, so fail the operation
			return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
		}
	}
	if file.Content != oldContent {
		// The user manually changed the content; store an intermediate version
		_, err = w.files.CreateVersion(ctx, sessionID, filePath, oldContent)
		if err != nil {
			fmt.Printf("Error creating file history version: %v\n", err)
		}
	}
	// Store the new version
	_, err = w.files.CreateVersion(ctx, sessionID, filePath, params.Content)
	if err != nil {
		fmt.Printf("Error creating file history version: %v\n", err)
	}

	recordFileWrite(filePath)
	recordFileRead(filePath)
	waitForLspDiagnostics(ctx, filePath, w.lspClients)
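Spelled out, the per-write history flow is: (1) ensure a baseline entry exists holding the pre-write content, (2) if the file on disk diverged from the last recorded version (an out-of-band user edit), snapshot that divergence as an intermediate version, (3) record the freshly written content as the newest version. A compact trace of what accumulates for one session, using the mock's labels from above (the real service may label versions differently):

	// write #1 (new file):        baseline initial("") is created, then v1(new content)
	// user edits file on disk:    nothing is recorded yet
	// write #2 (disk != history): v2(user's manual content), then v3(second write's content)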
@@ -14,7 +14,7 @@ import (
)

func TestWriteTool_Info(t *testing.T) {
	tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true))
	tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
	info := tool.Info()

	assert.Equal(t, WriteToolName, info.Name)
@@ -32,7 +32,7 @@ func TestWriteTool_Run(t *testing.T) {
	defer os.RemoveAll(tempDir)

	t.Run("creates a new file successfully", func(t *testing.T) {
		tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true))
		tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())

		filePath := filepath.Join(tempDir, "new_file.txt")
		content := "This is a test content"
@@ -61,7 +61,7 @@ func TestWriteTool_Run(t *testing.T) {
	})

	t.Run("creates file with nested directories", func(t *testing.T) {
		tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true))
		tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())

		filePath := filepath.Join(tempDir, "nested/dirs/new_file.txt")
		content := "Content in nested directory"
@@ -90,7 +90,7 @@ func TestWriteTool_Run(t *testing.T) {
	})

	t.Run("updates existing file", func(t *testing.T) {
		tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true))
		tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())

		// Create a file first
		filePath := filepath.Join(tempDir, "existing_file.txt")
@@ -127,7 +127,7 @@ func TestWriteTool_Run(t *testing.T) {
	})

	t.Run("handles invalid parameters", func(t *testing.T) {
		tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true))
		tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())

		call := ToolCall{
			Name: WriteToolName,
@@ -140,7 +140,7 @@ func TestWriteTool_Run(t *testing.T) {
	})

	t.Run("handles missing file_path", func(t *testing.T) {
		tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true))
		tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())

		params := WriteParams{
			FilePath: "",
@@ -161,7 +161,7 @@ func TestWriteTool_Run(t *testing.T) {
	})

	t.Run("handles missing content", func(t *testing.T) {
		tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true))
		tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())

		params := WriteParams{
			FilePath: filepath.Join(tempDir, "file.txt"),
@@ -182,7 +182,7 @@ func TestWriteTool_Run(t *testing.T) {
	})

	t.Run("handles writing to a directory path", func(t *testing.T) {
		tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true))
		tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())

		// Create a directory
		dirPath := filepath.Join(tempDir, "test_dir")
@@ -208,7 +208,7 @@ func TestWriteTool_Run(t *testing.T) {
	})

	t.Run("handles permission denied", func(t *testing.T) {
		tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(false))
		tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(false), newMockFileHistoryService())

		filePath := filepath.Join(tempDir, "permission_denied.txt")
		params := WriteParams{
@@ -234,7 +234,7 @@ func TestWriteTool_Run(t *testing.T) {
	})

	t.Run("detects file modified since last read", func(t *testing.T) {
		tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true))
		tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())

		// Create a file
		filePath := filepath.Join(tempDir, "modified_file.txt")
@@ -275,7 +275,7 @@ func TestWriteTool_Run(t *testing.T) {
	})

	t.Run("skips writing when content is identical", func(t *testing.T) {
		tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true))
		tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())

		// Create a file
		filePath := filepath.Join(tempDir, "identical_content.txt")