replace github.com/google/generative-ai-go with github.com/googleapis/go-genai (#138)

* replace to github.com/googleapis/go-genai

* fix history logic

* small fixes

---------

Co-authored-by: Kujtim Hoxha <kujtimii.h@gmail.com>
This commit is contained in:
mineo
2025-05-09 21:15:38 +09:00
committed by adamdottv
parent 1d1a1ddcbf
commit f92b2b76dc
3 changed files with 82 additions and 91 deletions

View File

@@ -9,14 +9,12 @@ import (
"strings"
"time"
"github.com/google/generative-ai-go/genai"
"github.com/google/uuid"
"github.com/opencode-ai/opencode/internal/config"
"github.com/opencode-ai/opencode/internal/llm/tools"
"github.com/opencode-ai/opencode/internal/message"
"github.com/opencode-ai/opencode/internal/status"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
"google.golang.org/genai"
"log/slog"
)
@@ -40,7 +38,7 @@ func newGeminiClient(opts providerClientOptions) GeminiClient {
o(&geminiOpts)
}
client, err := genai.NewClient(context.Background(), option.WithAPIKey(opts.apiKey))
client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI})
if err != nil {
slog.Error("Failed to create Gemini client", "error", err)
return nil
@@ -58,11 +56,14 @@ func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Cont
for _, msg := range messages {
switch msg.Role {
case message.User:
var parts []genai.Part
parts = append(parts, genai.Text(msg.Content().String()))
var parts []*genai.Part
parts = append(parts, &genai.Part{Text: msg.Content().String()})
for _, binaryContent := range msg.BinaryContent() {
imageFormat := strings.Split(binaryContent.MIMEType, "/")
parts = append(parts, genai.ImageData(imageFormat[1], binaryContent.Data))
parts = append(parts, &genai.Part{InlineData: &genai.Blob{
MIMEType: imageFormat[1],
Data: binaryContent.Data,
}})
}
history = append(history, &genai.Content{
Parts: parts,
@@ -71,19 +72,21 @@ func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Cont
case message.Assistant:
content := &genai.Content{
Role: "model",
Parts: []genai.Part{},
Parts: []*genai.Part{},
}
if msg.Content().String() != "" {
content.Parts = append(content.Parts, genai.Text(msg.Content().String()))
content.Parts = append(content.Parts, &genai.Part{Text: msg.Content().String()})
}
if len(msg.ToolCalls()) > 0 {
for _, call := range msg.ToolCalls() {
args, _ := parseJsonToMap(call.Input)
content.Parts = append(content.Parts, genai.FunctionCall{
Name: call.Name,
Args: args,
content.Parts = append(content.Parts, &genai.Part{
FunctionCall: &genai.FunctionCall{
Name: call.Name,
Args: args,
},
})
}
}
@@ -111,10 +114,14 @@ func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Cont
}
history = append(history, &genai.Content{
Parts: []genai.Part{genai.FunctionResponse{
Name: toolCall.Name,
Response: response,
}},
Parts: []*genai.Part{
{
FunctionResponse: &genai.FunctionResponse{
Name: toolCall.Name,
Response: response,
},
},
},
Role: "function",
})
}
@@ -158,18 +165,6 @@ func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishRea
}
func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
model.SystemInstruction = &genai.Content{
Parts: []genai.Part{
genai.Text(g.providerOptions.systemMessage),
},
}
// Convert tools
if len(tools) > 0 {
model.Tools = g.convertTools(tools)
}
// Convert messages
geminiMessages := g.convertMessages(messages)
@@ -179,16 +174,26 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
slog.Debug("Prepared messages", "messages", string(jsonData))
}
history := geminiMessages[:len(geminiMessages)-1] // All but last message
lastMsg := geminiMessages[len(geminiMessages)-1]
chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, &genai.GenerateContentConfig{
MaxOutputTokens: int32(g.providerOptions.maxTokens),
SystemInstruction: &genai.Content{
Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
},
Tools: g.convertTools(tools),
}, history)
attempts := 0
for {
attempts++
var toolCalls []message.ToolCall
chat := model.StartChat()
chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message
lastMsg := geminiMessages[len(geminiMessages)-1]
resp, err := chat.SendMessage(ctx, lastMsg.Parts...)
var lastMsgParts []genai.Part
for _, part := range lastMsg.Parts {
lastMsgParts = append(lastMsgParts, *part)
}
resp, err := chat.SendMessage(ctx, lastMsgParts...)
// If there is an error we are going to see if we can retry the call
if err != nil {
retry, after, retryErr := g.shouldRetry(attempts, err)
@@ -211,15 +216,15 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
for _, part := range resp.Candidates[0].Content.Parts {
switch p := part.(type) {
case genai.Text:
content = string(p)
case genai.FunctionCall:
switch {
case part.Text != "":
content = string(part.Text)
case part.FunctionCall != nil:
id := "call_" + uuid.New().String()
args, _ := json.Marshal(p.Args)
args, _ := json.Marshal(part.FunctionCall.Args)
toolCalls = append(toolCalls, message.ToolCall{
ID: id,
Name: p.Name,
Name: part.FunctionCall.Name,
Input: string(args),
Type: "function",
Finished: true,
@@ -245,18 +250,6 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
}
func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
model.SystemInstruction = &genai.Content{
Parts: []genai.Part{
genai.Text(g.providerOptions.systemMessage),
},
}
// Convert tools
if len(tools) > 0 {
model.Tools = g.convertTools(tools)
}
// Convert messages
geminiMessages := g.convertMessages(messages)
@@ -266,6 +259,16 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
slog.Debug("Prepared messages", "messages", string(jsonData))
}
history := geminiMessages[:len(geminiMessages)-1] // All but last message
lastMsg := geminiMessages[len(geminiMessages)-1]
chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, &genai.GenerateContentConfig{
MaxOutputTokens: int32(g.providerOptions.maxTokens),
SystemInstruction: &genai.Content{
Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
},
Tools: g.convertTools(tools),
}, history)
attempts := 0
eventChan := make(chan ProviderEvent)
@@ -274,11 +277,6 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
for {
attempts++
chat := model.StartChat()
chat.History = geminiMessages[:len(geminiMessages)-1]
lastMsg := geminiMessages[len(geminiMessages)-1]
iter := chat.SendMessageStream(ctx, lastMsg.Parts...)
currentContent := ""
toolCalls := []message.ToolCall{}
@@ -286,11 +284,12 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
eventChan <- ProviderEvent{Type: EventContentStart}
for {
resp, err := iter.Next()
if err == iterator.Done {
break
}
var lastMsgParts []genai.Part
for _, part := range lastMsg.Parts {
lastMsgParts = append(lastMsgParts, *part)
}
for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) {
if err != nil {
retry, after, retryErr := g.shouldRetry(attempts, err)
if retryErr != nil {
@@ -319,9 +318,9 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
for _, part := range resp.Candidates[0].Content.Parts {
switch p := part.(type) {
case genai.Text:
delta := string(p)
switch {
case part.Text != "":
delta := string(part.Text)
if delta != "" {
eventChan <- ProviderEvent{
Type: EventContentDelta,
@@ -329,12 +328,12 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
}
currentContent += delta
}
case genai.FunctionCall:
case part.FunctionCall != nil:
id := "call_" + uuid.New().String()
args, _ := json.Marshal(p.Args)
args, _ := json.Marshal(part.FunctionCall.Args)
newCall := message.ToolCall{
ID: id,
Name: p.Name,
Name: part.FunctionCall.Name,
Input: string(args),
Type: "function",
Finished: true,
@@ -422,12 +421,12 @@ func (g *geminiClient) toolCalls(resp *genai.GenerateContentResponse) []message.
if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
for _, part := range resp.Candidates[0].Content.Parts {
if funcCall, ok := part.(genai.FunctionCall); ok {
if part.FunctionCall != nil {
id := "call_" + uuid.New().String()
args, _ := json.Marshal(funcCall.Args)
args, _ := json.Marshal(part.FunctionCall.Args)
toolCalls = append(toolCalls, message.ToolCall{
ID: id,
Name: funcCall.Name,
Name: part.FunctionCall.Name,
Input: string(args),
Type: "function",
})