mirror of
https://github.com/aljazceru/opencode.git
synced 2026-01-06 17:34:58 +01:00
rework llm
This commit is contained in:
@@ -2,16 +2,353 @@ package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
|
||||
"github.com/cloudwego/eino/flow/agent/react"
|
||||
"github.com/kujtimiihoxha/termai/internal/app"
|
||||
"github.com/kujtimiihoxha/termai/internal/config"
|
||||
"github.com/kujtimiihoxha/termai/internal/llm/models"
|
||||
"github.com/kujtimiihoxha/termai/internal/llm/prompt"
|
||||
"github.com/kujtimiihoxha/termai/internal/llm/provider"
|
||||
"github.com/kujtimiihoxha/termai/internal/llm/tools"
|
||||
"github.com/kujtimiihoxha/termai/internal/message"
|
||||
)
|
||||
|
||||
func GetAgent(ctx context.Context, name string) (*react.Agent, string, error) {
|
||||
switch name {
|
||||
case "coder":
|
||||
agent, err := NewCoderAgent(ctx)
|
||||
return agent, CoderSystemPrompt(), err
|
||||
}
|
||||
return nil, "", fmt.Errorf("agent %s not found", name)
|
||||
type Agent interface {
|
||||
Generate(sessionID string, content string) error
|
||||
}
|
||||
|
||||
type agent struct {
|
||||
*app.App
|
||||
model models.Model
|
||||
tools []tools.BaseTool
|
||||
agent provider.Provider
|
||||
titleGenerator provider.Provider
|
||||
}
|
||||
|
||||
func (c *agent) handleTitleGeneration(sessionID, content string) {
|
||||
response, err := c.titleGenerator.SendMessages(
|
||||
c.Context,
|
||||
[]message.Message{
|
||||
{
|
||||
Role: message.User,
|
||||
Content: content,
|
||||
},
|
||||
},
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
session, err := c.Sessions.Get(sessionID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if response.Content != "" {
|
||||
session.Title = response.Content
|
||||
c.Sessions.Save(session)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *agent) TrackUsage(sessionID string, model models.Model, usage provider.TokenUsage) error {
|
||||
session, err := c.Sessions.Get(sessionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
|
||||
model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
|
||||
model.CostPer1MIn/1e6*float64(usage.InputTokens) +
|
||||
model.CostPer1MOut/1e6*float64(usage.OutputTokens)
|
||||
|
||||
session.Cost += cost
|
||||
session.CompletionTokens += usage.OutputTokens
|
||||
session.PromptTokens += usage.InputTokens
|
||||
|
||||
_, err = c.Sessions.Save(session)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *agent) processEvent(
|
||||
sessionID string,
|
||||
assistantMsg *message.Message,
|
||||
event provider.ProviderEvent,
|
||||
) error {
|
||||
switch event.Type {
|
||||
case provider.EventThinkingDelta:
|
||||
assistantMsg.Thinking += event.Thinking
|
||||
return c.Messages.Update(*assistantMsg)
|
||||
case provider.EventContentDelta:
|
||||
assistantMsg.Content += event.Content
|
||||
return c.Messages.Update(*assistantMsg)
|
||||
case provider.EventError:
|
||||
log.Println("error", event.Error)
|
||||
return event.Error
|
||||
|
||||
case provider.EventComplete:
|
||||
assistantMsg.ToolCalls = event.Response.ToolCalls
|
||||
err := c.Messages.Update(*assistantMsg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.TrackUsage(sessionID, c.model, event.Response.Usage)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) {
|
||||
var wg sync.WaitGroup
|
||||
toolResults := make([]message.ToolResult, len(toolCalls))
|
||||
mutex := &sync.Mutex{}
|
||||
|
||||
for i, tc := range toolCalls {
|
||||
wg.Add(1)
|
||||
go func(index int, toolCall message.ToolCall) {
|
||||
defer wg.Done()
|
||||
|
||||
response := ""
|
||||
isError := false
|
||||
found := false
|
||||
|
||||
for _, tool := range tls {
|
||||
if tool.Info().Name == toolCall.Name {
|
||||
found = true
|
||||
toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
|
||||
ID: toolCall.ID,
|
||||
Name: toolCall.Name,
|
||||
Input: toolCall.Input,
|
||||
})
|
||||
if toolErr != nil {
|
||||
response = fmt.Sprintf("error running tool: %s", toolErr)
|
||||
isError = true
|
||||
} else {
|
||||
response = toolResult.Content
|
||||
isError = toolResult.IsError
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
response = fmt.Sprintf("tool not found: %s", toolCall.Name)
|
||||
isError = true
|
||||
}
|
||||
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
|
||||
toolResults[index] = message.ToolResult{
|
||||
ToolCallID: toolCall.ID,
|
||||
Content: response,
|
||||
IsError: isError,
|
||||
}
|
||||
}(i, tc)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
return toolResults, nil
|
||||
}
|
||||
|
||||
func (c *agent) handleToolExecution(
|
||||
ctx context.Context,
|
||||
assistantMsg message.Message,
|
||||
) (*message.Message, error) {
|
||||
if len(assistantMsg.ToolCalls) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
toolResults, err := c.ExecuteTools(ctx, assistantMsg.ToolCalls, c.tools)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg, err := c.Messages.Create(assistantMsg.SessionID, message.CreateMessageParams{
|
||||
Role: message.Tool,
|
||||
ToolResults: toolResults,
|
||||
})
|
||||
|
||||
return &msg, err
|
||||
}
|
||||
|
||||
func (c *agent) generate(sessionID string, content string) error {
|
||||
messages, err := c.Messages.List(sessionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(messages) == 0 {
|
||||
go c.handleTitleGeneration(sessionID, content)
|
||||
}
|
||||
|
||||
userMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
|
||||
Role: message.User,
|
||||
Content: content,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
messages = append(messages, userMsg)
|
||||
for {
|
||||
|
||||
eventChan, err := c.agent.StreamResponse(c.Context, messages, c.tools)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
|
||||
Role: message.Assistant,
|
||||
Content: "",
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for event := range eventChan {
|
||||
err = c.processEvent(sessionID, &assistantMsg, event)
|
||||
if err != nil {
|
||||
assistantMsg.Finished = true
|
||||
c.Messages.Update(assistantMsg)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
msg, err := c.handleToolExecution(c.Context, assistantMsg)
|
||||
assistantMsg.Finished = true
|
||||
c.Messages.Update(assistantMsg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(assistantMsg.ToolCalls) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
messages = append(messages, assistantMsg)
|
||||
if msg != nil {
|
||||
messages = append(messages, *msg)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) {
|
||||
maxTokens := config.Get().Model.CoderMaxTokens
|
||||
|
||||
providerConfig, ok := config.Get().Providers[model.Provider]
|
||||
if !ok || !providerConfig.Enabled {
|
||||
return nil, nil, errors.New("provider is not enabled")
|
||||
}
|
||||
var agentProvider provider.Provider
|
||||
var titleGenerator provider.Provider
|
||||
|
||||
switch model.Provider {
|
||||
case models.ProviderOpenAI:
|
||||
var err error
|
||||
agentProvider, err = provider.NewOpenAIProvider(
|
||||
provider.WithOpenAISystemMessage(
|
||||
prompt.CoderOpenAISystemPrompt(),
|
||||
),
|
||||
provider.WithOpenAIMaxTokens(maxTokens),
|
||||
provider.WithOpenAIModel(model),
|
||||
provider.WithOpenAIKey(providerConfig.APIKey),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
titleGenerator, err = provider.NewOpenAIProvider(
|
||||
provider.WithOpenAISystemMessage(
|
||||
prompt.TitlePrompt(),
|
||||
),
|
||||
provider.WithOpenAIMaxTokens(80),
|
||||
provider.WithOpenAIModel(model),
|
||||
provider.WithOpenAIKey(providerConfig.APIKey),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
case models.ProviderAnthropic:
|
||||
var err error
|
||||
agentProvider, err = provider.NewAnthropicProvider(
|
||||
provider.WithAnthropicSystemMessage(
|
||||
prompt.CoderAnthropicSystemPrompt(),
|
||||
),
|
||||
provider.WithAnthropicMaxTokens(maxTokens),
|
||||
provider.WithAnthropicKey(providerConfig.APIKey),
|
||||
provider.WithAnthropicModel(model),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
titleGenerator, err = provider.NewAnthropicProvider(
|
||||
provider.WithAnthropicSystemMessage(
|
||||
prompt.TitlePrompt(),
|
||||
),
|
||||
provider.WithAnthropicMaxTokens(80),
|
||||
provider.WithAnthropicKey(providerConfig.APIKey),
|
||||
provider.WithAnthropicModel(model),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
case models.ProviderGemini:
|
||||
var err error
|
||||
agentProvider, err = provider.NewGeminiProvider(
|
||||
ctx,
|
||||
provider.WithGeminiSystemMessage(
|
||||
prompt.CoderOpenAISystemPrompt(),
|
||||
),
|
||||
provider.WithGeminiMaxTokens(int32(maxTokens)),
|
||||
provider.WithGeminiKey(providerConfig.APIKey),
|
||||
provider.WithGeminiModel(model),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
titleGenerator, err = provider.NewGeminiProvider(
|
||||
ctx,
|
||||
provider.WithGeminiSystemMessage(
|
||||
prompt.TitlePrompt(),
|
||||
),
|
||||
provider.WithGeminiMaxTokens(80),
|
||||
provider.WithGeminiKey(providerConfig.APIKey),
|
||||
provider.WithGeminiModel(model),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
case models.ProviderGROQ:
|
||||
var err error
|
||||
agentProvider, err = provider.NewOpenAIProvider(
|
||||
provider.WithOpenAISystemMessage(
|
||||
prompt.CoderAnthropicSystemPrompt(),
|
||||
),
|
||||
provider.WithOpenAIMaxTokens(maxTokens),
|
||||
provider.WithOpenAIModel(model),
|
||||
provider.WithOpenAIKey(providerConfig.APIKey),
|
||||
provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
titleGenerator, err = provider.NewOpenAIProvider(
|
||||
provider.WithOpenAISystemMessage(
|
||||
prompt.TitlePrompt(),
|
||||
),
|
||||
provider.WithOpenAIMaxTokens(80),
|
||||
provider.WithOpenAIModel(model),
|
||||
provider.WithOpenAIKey(providerConfig.APIKey),
|
||||
provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return agentProvider, titleGenerator, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user