Mirror of https://github.com/aljazceru/opencode.git, synced 2025-12-24 11:14:23 +01:00
add bedrock support
@@ -36,6 +36,11 @@ type Model struct {
 	// TODO: Maybe support multiple models for different purposes
 }
 
+type AnthropicConfig struct {
+	DisableCache bool `json:"disableCache"`
+	UseBedrock   bool `json:"useBedrock"`
+}
+
 type Provider struct {
 	APIKey  string `json:"apiKey"`
 	Enabled bool   `json:"enabled"`
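
Note: the new AnthropicConfig struct carries the two switches added for Bedrock support, useBedrock and disableCache. A minimal, self-contained sketch of how the json tags map a config fragment onto the struct; the JSON fragment itself is illustrative, only the struct comes from the diff.

package main

import (
	"encoding/json"
	"fmt"
)

// Mirrors the struct added above.
type AnthropicConfig struct {
	DisableCache bool `json:"disableCache"`
	UseBedrock   bool `json:"useBedrock"`
}

func main() {
	// Hypothetical config fragment; only the field names are taken from the diff.
	raw := []byte(`{"disableCache": false, "useBedrock": true}`)

	var c AnthropicConfig
	if err := json.Unmarshal(raw, &c); err != nil {
		panic(err)
	}
	fmt.Printf("%+v\n", c) // prints {DisableCache:false UseBedrock:true}
}
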
@@ -130,6 +135,8 @@ func Load(debug bool) error {
 			defaultModelSet = true
 		}
 	}
+
+	viper.SetDefault("providers.bedrock.enabled", true)
 	// TODO: add more providers
 	cfg = &Config{}
 
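Note: with this default, the Bedrock provider counts as enabled unless the user's config turns it off. A tiny sketch of the default/override behaviour, assuming viper here is the usual github.com/spf13/viper; no config file is loaded in the sketch.

package main

import (
	"fmt"

	"github.com/spf13/viper"
)

func main() {
	viper.SetDefault("providers.bedrock.enabled", true)

	// Nothing set explicitly yet, so the default applies.
	fmt.Println(viper.GetBool("providers.bedrock.enabled")) // true

	// A value from the user's config (or an explicit Set) wins over the default.
	viper.Set("providers.bedrock.enabled", false)
	fmt.Println(viper.GetBool("providers.bedrock.enabled")) // false
}
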
@@ -380,6 +380,29 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) {
 			return nil, nil, err
 		}
+
+	case models.ProviderBedrock:
+		var err error
+		agentProvider, err = provider.NewBedrockProvider(
+			provider.WithBedrockSystemMessage(
+				prompt.CoderAnthropicSystemPrompt(),
+			),
+			provider.WithBedrockMaxTokens(maxTokens),
+			provider.WithBedrockModel(model),
+		)
+		if err != nil {
+			return nil, nil, err
+		}
+		titleGenerator, err = provider.NewBedrockProvider(
+			provider.WithBedrockSystemMessage(
+				prompt.TitlePrompt(),
+			),
+			provider.WithBedrockMaxTokens(maxTokens),
+			provider.WithBedrockModel(model),
+		)
+		if err != nil {
+			return nil, nil, err
+		}
 	}
 
 	return agentProvider, titleGenerator, nil
@@ -31,11 +31,15 @@ const (
 
 	// GROQ
 	QWENQwq ModelID = "qwen-qwq"
+
+	// Bedrock
+	BedrockClaude37Sonnet ModelID = "bedrock.claude-3.7-sonnet"
 )
 
 const (
 	ProviderOpenAI    ModelProvider = "openai"
 	ProviderAnthropic ModelProvider = "anthropic"
+	ProviderBedrock   ModelProvider = "bedrock"
 	ProviderGemini    ModelProvider = "gemini"
 	ProviderGROQ      ModelProvider = "groq"
 )
@@ -119,4 +123,16 @@ var SupportedModels = map[ModelID]Model{
 		CostPer1MOutCached: 0,
 		CostPer1MOut:       0,
 	},
+
+	// Bedrock
+	BedrockClaude37Sonnet: {
+		ID:                 BedrockClaude37Sonnet,
+		Name:               "Bedrock: Claude 3.7 Sonnet",
+		Provider:           ProviderBedrock,
+		APIModel:           "anthropic.claude-3-7-sonnet-20250219-v1:0",
+		CostPer1MIn:        3.0,
+		CostPer1MInCached:  3.75,
+		CostPer1MOutCached: 0.30,
+		CostPer1MOut:       15.0,
+	},
 }
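
Note: the Cost* fields are USD per one million tokens; the 3.0 / 15.0 figures are the base input and output rates, and the 3.75 / 0.30 cached figures line up with Anthropic's cache-write and cache-read rates for Claude 3.7 Sonnet. A small self-contained sketch of pricing a request from these fields; the requestCostUSD helper is hypothetical and not part of the diff.

package main

import "fmt"

// modelPricing mirrors only the pricing fields of the Model struct used here.
type modelPricing struct {
	CostPer1MIn  float64
	CostPer1MOut float64
}

// requestCostUSD is a hypothetical helper showing that the Cost* values are
// USD per one million tokens.
func requestCostUSD(p modelPricing, inTokens, outTokens int64) float64 {
	return float64(inTokens)/1e6*p.CostPer1MIn + float64(outTokens)/1e6*p.CostPer1MOut
}

func main() {
	claude37OnBedrock := modelPricing{CostPer1MIn: 3.0, CostPer1MOut: 15.0}

	// 10,000 input tokens and 2,000 output tokens:
	// 0.01 * 3.0 + 0.002 * 15.0 = 0.03 + 0.03 = 0.06 USD.
	fmt.Printf("$%.2f\n", requestCostUSD(claude37OnBedrock, 10_000, 2_000))
}
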
@@ -9,6 +9,7 @@ import (
 	"time"
 
 	"github.com/anthropics/anthropic-sdk-go"
+	"github.com/anthropics/anthropic-sdk-go/bedrock"
 	"github.com/anthropics/anthropic-sdk-go/option"
 	"github.com/kujtimiihoxha/termai/internal/llm/models"
 	"github.com/kujtimiihoxha/termai/internal/llm/tools"
@@ -21,6 +22,8 @@ type anthropicProvider struct {
 	maxTokens     int64
 	apiKey        string
 	systemMessage string
+	useBedrock    bool
+	disableCache  bool
 }
 
 type AnthropicOption func(*anthropicProvider)
@@ -49,6 +52,18 @@ func WithAnthropicKey(apiKey string) AnthropicOption {
 	}
 }
 
+func WithAnthropicBedrock() AnthropicOption {
+	return func(a *anthropicProvider) {
+		a.useBedrock = true
+	}
+}
+
+func WithAnthropicDisableCache() AnthropicOption {
+	return func(a *anthropicProvider) {
+		a.disableCache = true
+	}
+}
+
 func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) {
 	provider := &anthropicProvider{
 		maxTokens: 1024,
@@ -62,7 +77,16 @@ func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) {
 		return nil, errors.New("system message is required")
 	}
 
-	provider.client = anthropic.NewClient(option.WithAPIKey(provider.apiKey))
+	anthropicOptions := []option.RequestOption{}
+
+	if provider.apiKey != "" {
+		anthropicOptions = append(anthropicOptions, option.WithAPIKey(provider.apiKey))
+	}
+	if provider.useBedrock {
+		anthropicOptions = append(anthropicOptions, bedrock.WithLoadDefaultConfig(context.Background()))
+	}
+
+	provider.client = anthropic.NewClient(anthropicOptions...)
 	return provider, nil
 }
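
Note: the client options are now assembled conditionally rather than always passing an API key, which is what lets the same Anthropic provider sit behind Bedrock. A self-contained sketch of the two construction paths against the SDK directly; the USE_BEDROCK switch is hypothetical, the SDK calls are the ones used in the hunk above.

package main

import (
	"context"
	"os"

	"github.com/anthropics/anthropic-sdk-go"
	"github.com/anthropics/anthropic-sdk-go/bedrock"
	"github.com/anthropics/anthropic-sdk-go/option"
)

func main() {
	opts := []option.RequestOption{}

	// Direct Anthropic API: authenticate with an API key.
	if key := os.Getenv("ANTHROPIC_API_KEY"); key != "" {
		opts = append(opts, option.WithAPIKey(key))
	}

	// Bedrock: no Anthropic key; requests are signed via the default AWS
	// credential and region chain instead.
	if os.Getenv("USE_BEDROCK") == "1" { // hypothetical switch for this sketch
		opts = append(opts, bedrock.WithLoadDefaultConfig(context.Background()))
	}

	client := anthropic.NewClient(opts...)
	_ = client // construction only; no request is made in this sketch
}
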
@@ -338,7 +362,7 @@ func (a *anthropicProvider) convertToAnthropicTools(tools []tools.BaseTool) []an
 		},
 	}
 
-	if i == len(tools)-1 {
+	if i == len(tools)-1 && !a.disableCache {
 		toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
 			Type: "ephemeral",
 		}
@@ -358,7 +382,7 @@ func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Messag
 	switch msg.Role {
 	case message.User:
 		content := anthropic.NewTextBlock(msg.Content().String())
-		if cachedBlocks < 2 {
+		if cachedBlocks < 2 && !a.disableCache {
 			content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
 				Type: "ephemeral",
 			}
@@ -370,7 +394,7 @@ func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Messag
 		blocks := []anthropic.ContentBlockParamUnion{}
 		if msg.Content().String() != "" {
 			content := anthropic.NewTextBlock(msg.Content().String())
-			if cachedBlocks < 2 {
+			if cachedBlocks < 2 && !a.disableCache {
 				content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
 					Type: "ephemeral",
 				}
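
Note: all three cache-control sites now also check !a.disableCache, and the Bedrock path builds the Anthropic provider with WithAnthropicDisableCache() (see bedrock.go below), so no ephemeral cache markers are sent in that mode. A standalone illustration of the counting rule implied by the diff (the cachedBlocks increment itself is outside the shown hunks); markCached is purely illustrative, plain Go with no SDK types.

package main

import "fmt"

// markCached illustrates the rule from the diff: at most two blocks get a
// cache marker, and none at all when disableCache is set.
func markCached(blocks int, disableCache bool) []bool {
	marks := make([]bool, blocks)
	cachedBlocks := 0
	for i := range marks {
		if cachedBlocks < 2 && !disableCache {
			marks[i] = true
			cachedBlocks++
		}
	}
	return marks
}

func main() {
	fmt.Println(markCached(4, false)) // [true true false false]
	fmt.Println(markCached(4, true))  // [false false false false]
}
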
@@ -404,4 +428,3 @@ func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Messag
 
 	return anthropicMessages
 }
-
internal/llm/provider/bedrock.go (new file, 87 lines)
@@ -0,0 +1,87 @@
package provider

import (
	"context"
	"errors"
	"fmt"
	"os"
	"strings"

	"github.com/kujtimiihoxha/termai/internal/llm/models"
	"github.com/kujtimiihoxha/termai/internal/llm/tools"
	"github.com/kujtimiihoxha/termai/internal/message"
)

type bedrockProvider struct {
	childProvider Provider
	model         models.Model
	maxTokens     int64
	systemMessage string
}

func (b *bedrockProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
	return b.childProvider.SendMessages(ctx, messages, tools)
}

func (b *bedrockProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
	return b.childProvider.StreamResponse(ctx, messages, tools)
}

func NewBedrockProvider(opts ...BedrockOption) (Provider, error) {
	provider := &bedrockProvider{}
	for _, opt := range opts {
		opt(provider)
	}

	// Based on the AWS region, prefix the model name with us, eu, ap, sa, etc.
	region := os.Getenv("AWS_REGION")
	if region == "" {
		region = os.Getenv("AWS_DEFAULT_REGION")
	}

	if region == "" {
		return nil, errors.New("AWS_REGION or AWS_DEFAULT_REGION environment variable is required")
	}
	if len(region) < 2 {
		return nil, errors.New("AWS_REGION or AWS_DEFAULT_REGION environment variable is invalid")
	}
	regionPrefix := region[:2]
	provider.model.APIModel = fmt.Sprintf("%s.%s", regionPrefix, provider.model.APIModel)

	if strings.Contains(string(provider.model.APIModel), "anthropic") {
		anthropic, err := NewAnthropicProvider(
			WithAnthropicModel(provider.model),
			WithAnthropicMaxTokens(provider.maxTokens),
			WithAnthropicSystemMessage(provider.systemMessage),
			WithAnthropicBedrock(),
			WithAnthropicDisableCache(),
		)
		provider.childProvider = anthropic
		if err != nil {
			return nil, err
		}
	} else {
		return nil, errors.New("unsupported model for bedrock provider")
	}
	return provider, nil
}

type BedrockOption func(*bedrockProvider)

func WithBedrockSystemMessage(message string) BedrockOption {
	return func(a *bedrockProvider) {
		a.systemMessage = message
	}
}

func WithBedrockMaxTokens(maxTokens int64) BedrockOption {
	return func(a *bedrockProvider) {
		a.maxTokens = maxTokens
	}
}

func WithBedrockModel(model models.Model) BedrockOption {
	return func(a *bedrockProvider) {
		a.model = model
	}
}
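
Note: NewBedrockProvider rewrites the model identifier with a geography prefix taken from the first two characters of the AWS region, so the plain Anthropic model ID from SupportedModels becomes a region-scoped Bedrock model ID. A standalone sketch of that rewrite; prefixForRegion mirrors the logic above and is illustrative, not part of the committed file.

package main

import (
	"fmt"
	"os"
)

// prefixForRegion mirrors the region handling in NewBedrockProvider above.
func prefixForRegion() (string, error) {
	region := os.Getenv("AWS_REGION")
	if region == "" {
		region = os.Getenv("AWS_DEFAULT_REGION")
	}
	if len(region) < 2 {
		return "", fmt.Errorf("AWS_REGION or AWS_DEFAULT_REGION must be set to a valid region, got %q", region)
	}
	return region[:2], nil
}

func main() {
	os.Setenv("AWS_REGION", "us-east-1")

	prefix, err := prefixForRegion()
	if err != nil {
		panic(err)
	}
	// "anthropic.claude-3-7-sonnet-20250219-v1:0" -> "us.anthropic.claude-3-7-sonnet-20250219-v1:0"
	fmt.Println(prefix + ".anthropic.claude-3-7-sonnet-20250219-v1:0")
}

With AWS_REGION=eu-west-1 the same model resolves to eu.anthropic.claude-3-7-sonnet-20250219-v1:0, matching the region-prefixed model IDs Bedrock uses for cross-region inference.
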
@@ -1,6 +1,7 @@
 package repl
 
 import (
+	"log"
 	"strings"
 
 	"github.com/charmbracelet/bubbles/key"
@@ -138,11 +139,22 @@ func (m *editorCmp) SetSize(width int, height int) {
 
 func (m *editorCmp) Send() tea.Cmd {
 	return func() tea.Msg {
-		messages, _ := m.app.Messages.List(m.sessionID)
+		messages, err := m.app.Messages.List(m.sessionID)
+		log.Printf("error: %v", err)
+		log.Printf("messages: %v", messages)
+
+		if err != nil {
+			return util.ReportError(err)
+		}
 		if hasUnfinishedMessages(messages) {
 			return util.ReportWarn("Assistant is still working on the previous message")
 		}
-		a, _ := agent.NewCoderAgent(m.app)
+		a, err := agent.NewCoderAgent(m.app)
+		log.Printf("error: %v", err)
+		log.Printf("agent: %v", a)
+		if err != nil {
+			return util.ReportError(err)
+		}
 
 		content := strings.Join(m.editor.GetBuffer().Lines(), "\n")
 		go a.Generate(m.sessionID, content)