structure tools the same

- add some tests
- fix some tests
- change how we handle permissions
This commit is contained in:
Kujtim Hoxha
2025-04-08 19:15:23 +02:00
parent 5acf0cba60
commit 94923948e1
20 changed files with 1210 additions and 910 deletions

View File

@@ -11,6 +11,17 @@ import (
"time"
)
type SourcegraphParams struct {
Query string `json:"query"`
Count int `json:"count,omitempty"`
ContextWindow int `json:"context_window,omitempty"`
Timeout int `json:"timeout,omitempty"`
}
type sourcegraphTool struct {
client *http.Client
}
const (
SourcegraphToolName = "sourcegraph"
sourcegraphToolDescription = `Search code across public repositories using Sourcegraph's GraphQL API.
@@ -110,17 +121,6 @@ TIPS:
- For more details on query syntax, visit: https://docs.sourcegraph.com/code_search/queries`
)
type SourcegraphParams struct {
Query string `json:"query"`
Count int `json:"count,omitempty"`
ContextWindow int `json:"context_window,omitempty"`
Timeout int `json:"timeout,omitempty"`
}
type sourcegraphTool struct {
client *http.Client
}
func NewSourcegraphTool() BaseTool {
return &sourcegraphTool{
client: &http.Client{
@@ -165,7 +165,6 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse,
return NewTextErrorResponse("Query parameter is required"), nil
}
// Set default count if not specified
if params.Count <= 0 {
params.Count = 10
} else if params.Count > 20 {
@@ -186,8 +185,6 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse,
}
}
// GraphQL query for Sourcegraph search
// Create a properly escaped JSON structure
type graphqlRequest struct {
Query string `json:"query"`
Variables struct {
@@ -200,14 +197,12 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse,
}
request.Variables.Query = params.Query
// Marshal to JSON to ensure proper escaping
graphqlQueryBytes, err := json.Marshal(request)
if err != nil {
return NewTextErrorResponse("Failed to create GraphQL request: " + err.Error()), nil
}
graphqlQuery := string(graphqlQueryBytes)
// Create request to Sourcegraph API
req, err := http.NewRequestWithContext(
ctx,
"POST",
@@ -228,7 +223,6 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse,
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
// log the error response
body, _ := io.ReadAll(resp.Body)
if len(body) > 0 {
return NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d, response: %s", resp.StatusCode, string(body))), nil
@@ -241,13 +235,11 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse,
return NewTextErrorResponse("Failed to read response body: " + err.Error()), nil
}
// Parse the GraphQL response
var result map[string]any
if err = json.Unmarshal(body, &result); err != nil {
return NewTextErrorResponse("Failed to parse response: " + err.Error()), nil
}
// Format the results in a readable way
formattedResults, err := formatSourcegraphResults(result, params.ContextWindow)
if err != nil {
return NewTextErrorResponse("Failed to format results: " + err.Error()), nil
@@ -259,7 +251,6 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse,
func formatSourcegraphResults(result map[string]any, contextWindow int) (string, error) {
var buffer strings.Builder
// Check for errors in the GraphQL response
if errors, ok := result["errors"].([]any); ok && len(errors) > 0 {
buffer.WriteString("## Sourcegraph API Error\n\n")
for _, err := range errors {
@@ -272,7 +263,6 @@ func formatSourcegraphResults(result map[string]any, contextWindow int) (string,
return buffer.String(), nil
}
// Extract data from the response
data, ok := result["data"].(map[string]any)
if !ok {
return "", fmt.Errorf("invalid response format: missing data field")
@@ -288,7 +278,6 @@ func formatSourcegraphResults(result map[string]any, contextWindow int) (string,
return "", fmt.Errorf("invalid response format: missing results field")
}
// Write search metadata
matchCount, _ := searchResults["matchCount"].(float64)
resultCount, _ := searchResults["resultCount"].(float64)
limitHit, _ := searchResults["limitHit"].(bool)
@@ -302,33 +291,28 @@ func formatSourcegraphResults(result map[string]any, contextWindow int) (string,
buffer.WriteString("\n")
// Process results
results, ok := searchResults["results"].([]any)
if !ok || len(results) == 0 {
buffer.WriteString("No results found. Try a different query.\n")
return buffer.String(), nil
}
// Limit to 10 results
maxResults := 10
if len(results) > maxResults {
results = results[:maxResults]
}
// Process each result
for i, res := range results {
fileMatch, ok := res.(map[string]any)
if !ok {
continue
}
// Skip non-FileMatch results
typeName, _ := fileMatch["__typename"].(string)
if typeName != "FileMatch" {
continue
}
// Extract repository and file information
repo, _ := fileMatch["repository"].(map[string]any)
file, _ := fileMatch["file"].(map[string]any)
lineMatches, _ := fileMatch["lineMatches"].([]any)
@@ -348,7 +332,6 @@ func formatSourcegraphResults(result map[string]any, contextWindow int) (string,
buffer.WriteString(fmt.Sprintf("URL: %s\n\n", fileURL))
}
// Show line matches with context
if len(lineMatches) > 0 {
for _, lm := range lineMatches {
lineMatch, ok := lm.(map[string]any)
@@ -359,13 +342,11 @@ func formatSourcegraphResults(result map[string]any, contextWindow int) (string,
lineNumber, _ := lineMatch["lineNumber"].(float64)
preview, _ := lineMatch["preview"].(string)
// Extract context from file content if available
if fileContent != "" {
lines := strings.Split(fileContent, "\n")
buffer.WriteString("```\n")
// Display context before the match (up to 10 lines)
startLine := max(1, int(lineNumber)-contextWindow)
for j := startLine - 1; j < int(lineNumber)-1 && j < len(lines); j++ {
@@ -374,10 +355,8 @@ func formatSourcegraphResults(result map[string]any, contextWindow int) (string,
}
}
// Display the matching line (highlighted)
buffer.WriteString(fmt.Sprintf("%d| %s\n", int(lineNumber), preview))
// Display context after the match (up to 10 lines)
endLine := int(lineNumber) + contextWindow
for j := int(lineNumber); j < endLine && j < len(lines); j++ {
@@ -388,7 +367,6 @@ func formatSourcegraphResults(result map[string]any, contextWindow int) (string,
buffer.WriteString("```\n\n")
} else {
// If file content is not available, just show the preview
buffer.WriteString("```\n")
buffer.WriteString(fmt.Sprintf("%d| %s\n", int(lineNumber), preview))
buffer.WriteString("```\n\n")