reimplement agent,provider and add file history

This commit is contained in:
Kujtim Hoxha
2025-04-16 20:06:23 +02:00
parent 76b4065f17
commit bbfa60c787
73 changed files with 3742 additions and 4026 deletions

View File

@@ -2,14 +2,17 @@ package provider
import (
"context"
"fmt"
"github.com/kujtimiihoxha/termai/internal/llm/models"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
"github.com/kujtimiihoxha/termai/internal/message"
)
// EventType represents the type of streaming event
type EventType string
const maxRetries = 8
const (
EventContentStart EventType = "content_start"
EventContentDelta EventType = "content_delta"
@@ -18,7 +21,6 @@ const (
EventComplete EventType = "complete"
EventError EventType = "error"
EventWarning EventType = "warning"
EventInfo EventType = "info"
)
type TokenUsage struct {
@@ -32,61 +34,152 @@ type ProviderResponse struct {
Content string
ToolCalls []message.ToolCall
Usage TokenUsage
FinishReason string
FinishReason message.FinishReason
}
type ProviderEvent struct {
Type EventType
Type EventType
Content string
Thinking string
ToolCall *message.ToolCall
Error error
Response *ProviderResponse
// Used for giving users info on e.x retry
Info string
Error error
}
type Provider interface {
SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error)
StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
Model() models.Model
}
func cleanupMessages(messages []message.Message) []message.Message {
// First pass: filter out canceled messages
var cleanedMessages []message.Message
type providerClientOptions struct {
apiKey string
model models.Model
maxTokens int64
systemMessage string
anthropicOptions []AnthropicOption
openaiOptions []OpenAIOption
geminiOptions []GeminiOption
bedrockOptions []BedrockOption
}
type ProviderClientOption func(*providerClientOptions)
type ProviderClient interface {
send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
}
type baseProvider[C ProviderClient] struct {
options providerClientOptions
client C
}
func NewProvider(providerName models.ModelProvider, opts ...ProviderClientOption) (Provider, error) {
clientOptions := providerClientOptions{}
for _, o := range opts {
o(&clientOptions)
}
switch providerName {
case models.ProviderAnthropic:
return &baseProvider[AnthropicClient]{
options: clientOptions,
client: newAnthropicClient(clientOptions),
}, nil
case models.ProviderOpenAI:
return &baseProvider[OpenAIClient]{
options: clientOptions,
client: newOpenAIClient(clientOptions),
}, nil
case models.ProviderGemini:
return &baseProvider[GeminiClient]{
options: clientOptions,
client: newGeminiClient(clientOptions),
}, nil
case models.ProviderBedrock:
return &baseProvider[BedrockClient]{
options: clientOptions,
client: newBedrockClient(clientOptions),
}, nil
case models.ProviderMock:
// TODO: implement mock client for test
panic("not implemented")
}
return nil, fmt.Errorf("provider not supported: %s", providerName)
}
func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
for _, msg := range messages {
if msg.FinishReason() != "canceled" || len(msg.ToolCalls()) > 0 {
// if there are toolCalls this means we want to return it to the LLM telling it that those tools have been
// cancelled
cleanedMessages = append(cleanedMessages, msg)
// The message has no content
if len(msg.Parts) == 0 {
continue
}
cleaned = append(cleaned, msg)
}
return
}
func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
messages = p.cleanMessages(messages)
return p.client.send(ctx, messages, tools)
}
func (p *baseProvider[C]) Model() models.Model {
return p.options.model
}
func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
messages = p.cleanMessages(messages)
return p.client.stream(ctx, messages, tools)
}
func WithAPIKey(apiKey string) ProviderClientOption {
return func(options *providerClientOptions) {
options.apiKey = apiKey
}
}
func WithModel(model models.Model) ProviderClientOption {
return func(options *providerClientOptions) {
options.model = model
}
}
func WithMaxTokens(maxTokens int64) ProviderClientOption {
return func(options *providerClientOptions) {
options.maxTokens = maxTokens
}
}
func WithSystemMessage(systemMessage string) ProviderClientOption {
return func(options *providerClientOptions) {
options.systemMessage = systemMessage
}
}
func WithAnthropicOptions(anthropicOptions ...AnthropicOption) ProviderClientOption {
return func(options *providerClientOptions) {
options.anthropicOptions = anthropicOptions
}
}
func WithOpenAIOptions(openaiOptions ...OpenAIOption) ProviderClientOption {
return func(options *providerClientOptions) {
options.openaiOptions = openaiOptions
}
}
func WithGeminiOptions(geminiOptions ...GeminiOption) ProviderClientOption {
return func(options *providerClientOptions) {
options.geminiOptions = geminiOptions
}
}
func WithBedrockOptions(bedrockOptions ...BedrockOption) ProviderClientOption {
return func(options *providerClientOptions) {
options.bedrockOptions = bedrockOptions
}
// Second pass: filter out tool messages without a corresponding tool call
var result []message.Message
toolMessageIDs := make(map[string]bool)
for _, msg := range cleanedMessages {
if msg.Role == message.Assistant {
for _, toolCall := range msg.ToolCalls() {
toolMessageIDs[toolCall.ID] = true // Mark as referenced
}
}
}
// Keep only messages that aren't unreferenced tool messages
for _, msg := range cleanedMessages {
if msg.Role == message.Tool {
for _, toolCall := range msg.ToolResults() {
if referenced, exists := toolMessageIDs[toolCall.ToolCallID]; exists && referenced {
result = append(result, msg)
}
}
} else {
result = append(result, msg)
}
}
return result
}