add bedrock support

This commit is contained in:
Kujtim Hoxha
2025-04-09 17:45:41 +02:00
parent fde04bbf85
commit 939ae03f42
8 changed files with 217 additions and 7 deletions

View File

@@ -36,6 +36,11 @@ type Model struct {
// TODO: Maybe support multiple models for different purposes
}
// AnthropicConfig groups Anthropic-specific settings that are read from the
// JSON configuration under the keys "disableCache" and "useBedrock".
// NOTE(review): the field names mirror the provider options
// WithAnthropicDisableCache and WithAnthropicBedrock, but no code in this
// change reads the struct — confirm where it is consumed.
type AnthropicConfig struct {
// DisableCache is decoded from the "disableCache" key.
DisableCache bool `json:"disableCache"`
// UseBedrock is decoded from the "useBedrock" key.
UseBedrock bool `json:"useBedrock"`
}
type Provider struct {
APIKey string `json:"apiKey"`
Enabled bool `json:"enabled"`
@@ -130,6 +135,8 @@ func Load(debug bool) error {
defaultModelSet = true
}
}
viper.SetDefault("providers.bedrock.enabled", true)
// TODO: add more providers
cfg = &Config{}

View File

@@ -380,6 +380,29 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid
return nil, nil, err
}
case models.ProviderBedrock:
var err error
agentProvider, err = provider.NewBedrockProvider(
provider.WithBedrockSystemMessage(
prompt.CoderAnthropicSystemPrompt(),
),
provider.WithBedrockMaxTokens(maxTokens),
provider.WithBedrockModel(model),
)
if err != nil {
return nil, nil, err
}
titleGenerator, err = provider.NewBedrockProvider(
provider.WithBedrockSystemMessage(
prompt.TitlePrompt(),
),
provider.WithBedrockMaxTokens(maxTokens),
provider.WithBedrockModel(model),
)
if err != nil {
return nil, nil, err
}
}
return agentProvider, titleGenerator, nil

View File

@@ -31,11 +31,15 @@ const (
// GROQ
QWENQwq ModelID = "qwen-qwq"
// Bedrock
BedrockClaude37Sonnet ModelID = "bedrock.claude-3.7-sonnet"
)
// Identifiers of the model providers known to the application. Each value is
// the provider's lowercase name as a ModelProvider.
const (
ProviderOpenAI ModelProvider = "openai"
ProviderAnthropic ModelProvider = "anthropic"
// ProviderBedrock identifies models that are served through AWS Bedrock.
ProviderBedrock ModelProvider = "bedrock"
ProviderGemini ModelProvider = "gemini"
ProviderGROQ ModelProvider = "groq"
)
@@ -119,4 +123,16 @@ var SupportedModels = map[ModelID]Model{
CostPer1MOutCached: 0,
CostPer1MOut: 0,
},
// Bedrock
BedrockClaude37Sonnet: {
ID: BedrockClaude37Sonnet,
Name: "Bedrock: Claude 3.7 Sonnet",
Provider: ProviderBedrock,
APIModel: "anthropic.claude-3-7-sonnet-20250219-v1:0",
CostPer1MIn: 3.0,
CostPer1MInCached: 3.75,
CostPer1MOutCached: 0.30,
CostPer1MOut: 15.0,
},
}

View File

@@ -9,6 +9,7 @@ import (
"time"
"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/bedrock"
"github.com/anthropics/anthropic-sdk-go/option"
"github.com/kujtimiihoxha/termai/internal/llm/models"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
@@ -21,6 +22,8 @@ type anthropicProvider struct {
maxTokens int64
apiKey string
systemMessage string
useBedrock bool
disableCache bool
}
type AnthropicOption func(*anthropicProvider)
@@ -49,6 +52,18 @@ func WithAnthropicKey(apiKey string) AnthropicOption {
}
}
// WithAnthropicBedrock returns an option that sets the provider's useBedrock
// flag. When the flag is set, NewAnthropicProvider adds the Bedrock default
// AWS configuration loader to the client's request options.
func WithAnthropicBedrock() AnthropicOption {
	enableBedrock := func(p *anthropicProvider) {
		p.useBedrock = true
	}
	return enableBedrock
}
// WithAnthropicDisableCache returns an option that sets the provider's
// disableCache flag. With the flag set, the tool and message converters no
// longer attach "ephemeral" cache-control markers to the request.
func WithAnthropicDisableCache() AnthropicOption {
	turnOffCache := func(p *anthropicProvider) {
		p.disableCache = true
	}
	return turnOffCache
}
func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) {
provider := &anthropicProvider{
maxTokens: 1024,
@@ -62,7 +77,16 @@ func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) {
return nil, errors.New("system message is required")
}
provider.client = anthropic.NewClient(option.WithAPIKey(provider.apiKey))
anthropicOptions := []option.RequestOption{}
if provider.apiKey != "" {
anthropicOptions = append(anthropicOptions, option.WithAPIKey(provider.apiKey))
}
if provider.useBedrock {
anthropicOptions = append(anthropicOptions, bedrock.WithLoadDefaultConfig(context.Background()))
}
provider.client = anthropic.NewClient(anthropicOptions...)
return provider, nil
}
@@ -338,7 +362,7 @@ func (a *anthropicProvider) convertToAnthropicTools(tools []tools.BaseTool) []an
},
}
if i == len(tools)-1 {
if i == len(tools)-1 && !a.disableCache {
toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
}
@@ -358,7 +382,7 @@ func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Messag
switch msg.Role {
case message.User:
content := anthropic.NewTextBlock(msg.Content().String())
if cachedBlocks < 2 {
if cachedBlocks < 2 && !a.disableCache {
content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
}
@@ -370,7 +394,7 @@ func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Messag
blocks := []anthropic.ContentBlockParamUnion{}
if msg.Content().String() != "" {
content := anthropic.NewTextBlock(msg.Content().String())
if cachedBlocks < 2 {
if cachedBlocks < 2 && !a.disableCache {
content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
}
@@ -404,4 +428,3 @@ func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Messag
return anthropicMessages
}

View File

@@ -0,0 +1,87 @@
package provider
import (
"context"
"errors"
"fmt"
"os"
"strings"
"github.com/kujtimiihoxha/termai/internal/llm/models"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
"github.com/kujtimiihoxha/termai/internal/message"
)
// bedrockProvider is a Provider that forwards every request to a child
// provider created by NewBedrockProvider.
type bedrockProvider struct {
// childProvider receives all SendMessages and StreamResponse calls.
childProvider Provider
// model is set via WithBedrockModel; NewBedrockProvider rewrites its
// APIModel with a region-derived prefix before handing it to the child.
model models.Model
// maxTokens is set via WithBedrockMaxTokens and passed on to the child.
maxTokens int64
// systemMessage is set via WithBedrockSystemMessage and passed on to the child.
systemMessage string
}
// SendMessages forwards the request unchanged to the child provider and
// returns the child's response and error as-is.
func (b *bedrockProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
	delegate := b.childProvider
	return delegate.SendMessages(ctx, messages, tools)
}
// StreamResponse forwards the request unchanged to the child provider and
// returns the child's event channel and error as-is.
func (b *bedrockProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
	delegate := b.childProvider
	return delegate.StreamResponse(ctx, messages, tools)
}
// NewBedrockProvider builds a Provider that serves a model through AWS Bedrock.
//
// The AWS region is read from AWS_REGION, falling back to AWS_DEFAULT_REGION.
// A geography prefix derived from that region (see bedrockRegionPrefix) is
// prepended to the model's APIModel, separated by a dot. Only models whose
// resulting identifier contains "anthropic" are supported; requests for them
// are delegated to an Anthropic provider created with the Bedrock option and
// with caching disabled.
//
// An error is returned when no usable region is configured, when the model is
// not supported, or when the underlying Anthropic provider cannot be created.
func NewBedrockProvider(opts ...BedrockOption) (Provider, error) {
	provider := &bedrockProvider{}
	for _, opt := range opts {
		opt(provider)
	}

	region := os.Getenv("AWS_REGION")
	if region == "" {
		region = os.Getenv("AWS_DEFAULT_REGION")
	}
	if region == "" {
		return nil, errors.New("AWS_REGION or AWS_DEFAULT_REGION environment variable is required")
	}
	if len(region) < 2 {
		return nil, errors.New("AWS_REGION or AWS_DEFAULT_REGION environment variable is invalid")
	}

	provider.model.APIModel = fmt.Sprintf("%s.%s", bedrockRegionPrefix(region), provider.model.APIModel)

	if !strings.Contains(string(provider.model.APIModel), "anthropic") {
		return nil, errors.New("unsupported model for bedrock provider")
	}

	child, err := NewAnthropicProvider(
		WithAnthropicModel(provider.model),
		WithAnthropicMaxTokens(provider.maxTokens),
		WithAnthropicSystemMessage(provider.systemMessage),
		WithAnthropicBedrock(),
		WithAnthropicDisableCache(),
	)
	if err != nil {
		return nil, fmt.Errorf("creating anthropic provider for bedrock: %w", err)
	}
	// Only keep the child once it is known to be valid.
	provider.childProvider = child
	return provider, nil
}

// bedrockRegionPrefix maps an AWS region name to the geography prefix used in
// Bedrock cross-region inference profile IDs. Asia-Pacific regions ("ap-...")
// use "apac" and GovCloud regions ("us-gov-...") use "us-gov"; every other
// region falls back to its first two characters (e.g. "us", "eu").
func bedrockRegionPrefix(region string) string {
	switch {
	case strings.HasPrefix(region, "us-gov-"):
		return "us-gov"
	case strings.HasPrefix(region, "ap-"):
		return "apac"
	case len(region) < 2:
		// Callers validate the length; avoid slicing out of range regardless.
		return region
	default:
		return region[:2]
	}
}
// BedrockOption configures a bedrockProvider; NewBedrockProvider applies the
// given options in order before constructing the child provider.
type BedrockOption func(*bedrockProvider)
// WithBedrockSystemMessage returns an option that stores message as the
// provider's system message.
func WithBedrockSystemMessage(message string) BedrockOption {
	setSystemMessage := func(p *bedrockProvider) {
		p.systemMessage = message
	}
	return setSystemMessage
}
// WithBedrockMaxTokens returns an option that stores maxTokens as the
// provider's token limit.
func WithBedrockMaxTokens(maxTokens int64) BedrockOption {
	setMaxTokens := func(p *bedrockProvider) {
		p.maxTokens = maxTokens
	}
	return setMaxTokens
}
// WithBedrockModel returns an option that stores model as the model the
// provider will serve.
func WithBedrockModel(model models.Model) BedrockOption {
	setModel := func(p *bedrockProvider) {
		p.model = model
	}
	return setModel
}

View File

@@ -1,6 +1,7 @@
package repl
import (
"log"
"strings"
"github.com/charmbracelet/bubbles/key"
@@ -138,11 +139,22 @@ func (m *editorCmp) SetSize(width int, height int) {
func (m *editorCmp) Send() tea.Cmd {
return func() tea.Msg {
messages, _ := m.app.Messages.List(m.sessionID)
messages, err := m.app.Messages.List(m.sessionID)
log.Printf("error: %v", err)
log.Printf("messages: %v", messages)
if err != nil {
return util.ReportError(err)
}
if hasUnfinishedMessages(messages) {
return util.ReportWarn("Assistant is still working on the previous message")
}
a, _ := agent.NewCoderAgent(m.app)
a, err := agent.NewCoderAgent(m.app)
log.Printf("error: %v", err)
log.Printf("agent: %v", a)
if err != nil {
return util.ReportError(err)
}
content := strings.Join(m.editor.GetBuffer().Lines(), "\n")
go a.Generate(m.sessionID, content)