handle errors correctly in the agent tool

This commit is contained in:
Kujtim Hoxha
2025-04-14 10:49:32 +02:00
parent cdc5f209dc
commit 80cd75c4fb

View File

@@ -50,44 +50,45 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes
sessionID, messageID := tools.GetContextValues(ctx) sessionID, messageID := tools.GetContextValues(ctx)
if sessionID == "" || messageID == "" { if sessionID == "" || messageID == "" {
return tools.NewTextErrorResponse("session ID and message ID are required"), nil return tools.ToolResponse{}, fmt.Errorf("session_id and message_id are required")
} }
agent, err := NewTaskAgent(b.lspClients) agent, err := NewTaskAgent(b.lspClients)
if err != nil { if err != nil {
return tools.NewTextErrorResponse(fmt.Sprintf("error creating agent: %s", err)), nil return tools.ToolResponse{}, fmt.Errorf("error creating agent: %s", err)
} }
session, err := b.sessions.CreateTaskSession(ctx, call.ID, sessionID, "New Agent Session") session, err := b.sessions.CreateTaskSession(ctx, call.ID, sessionID, "New Agent Session")
if err != nil { if err != nil {
return tools.NewTextErrorResponse(fmt.Sprintf("error creating session: %s", err)), nil return tools.ToolResponse{}, fmt.Errorf("error creating session: %s", err)
} }
err = agent.Generate(ctx, session.ID, params.Prompt) err = agent.Generate(ctx, session.ID, params.Prompt)
if err != nil { if err != nil {
return tools.NewTextErrorResponse(fmt.Sprintf("error generating agent: %s", err)), nil return tools.ToolResponse{}, fmt.Errorf("error generating agent: %s", err)
} }
messages, err := b.messages.List(ctx, session.ID) messages, err := b.messages.List(ctx, session.ID)
if err != nil { if err != nil {
return tools.NewTextErrorResponse(fmt.Sprintf("error listing messages: %s", err)), nil return tools.ToolResponse{}, fmt.Errorf("error listing messages: %s", err)
} }
if len(messages) == 0 { if len(messages) == 0 {
return tools.NewTextErrorResponse("no messages found"), nil return tools.NewTextErrorResponse("no response"), nil
} }
response := messages[len(messages)-1] response := messages[len(messages)-1]
if response.Role != message.Assistant { if response.Role != message.Assistant {
return tools.NewTextErrorResponse("no assistant message found"), nil return tools.NewTextErrorResponse("no response"), nil
} }
updatedSession, err := b.sessions.Get(ctx, session.ID) updatedSession, err := b.sessions.Get(ctx, session.ID)
if err != nil { if err != nil {
return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil return tools.ToolResponse{}, fmt.Errorf("error getting session: %s", err)
} }
parentSession, err := b.sessions.Get(ctx, sessionID) parentSession, err := b.sessions.Get(ctx, sessionID)
if err != nil { if err != nil {
return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil return tools.ToolResponse{}, fmt.Errorf("error getting parent session: %s", err)
} }
parentSession.Cost += updatedSession.Cost parentSession.Cost += updatedSession.Cost
@@ -96,7 +97,7 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes
_, err = b.sessions.Save(ctx, parentSession) _, err = b.sessions.Save(ctx, parentSession)
if err != nil { if err != nil {
return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil return tools.ToolResponse{}, fmt.Errorf("error saving parent session: %s", err)
} }
return tools.NewTextResponse(response.Content().String()), nil return tools.NewTextResponse(response.Content().String()), nil
} }