structure tools the same

- add some tests
- fix some tests
- change how we handle permissions
This commit is contained in:
Kujtim Hoxha
2025-04-08 19:15:23 +02:00
parent 5acf0cba60
commit 94923948e1
20 changed files with 1210 additions and 910 deletions

View File

@@ -11,16 +11,6 @@ import (
"github.com/kujtimiihoxha/termai/internal/permission"
)
type bashTool struct{}
const (
BashToolName = "bash"
DefaultTimeout = 1 * 60 * 1000 // 1 minutes in milliseconds
MaxTimeout = 10 * 60 * 1000 // 10 minutes in milliseconds
MaxOutputLength = 30000
)
type BashParams struct {
Command string `json:"command"`
Timeout int `json:"timeout"`
@@ -31,180 +21,36 @@ type BashPermissionsParams struct {
Timeout int `json:"timeout"`
}
var BannedCommands = []string{
type bashTool struct {
permissions permission.Service
}
const (
BashToolName = "bash"
DefaultTimeout = 1 * 60 * 1000 // 1 minutes in milliseconds
MaxTimeout = 10 * 60 * 1000 // 10 minutes in milliseconds
MaxOutputLength = 30000
)
var bannedCommands = []string{
"alias", "curl", "curlie", "wget", "axel", "aria2c",
"nc", "telnet", "lynx", "w3m", "links", "httpie", "xh",
"http-prompt", "chrome", "firefox", "safari",
}
var SafeReadOnlyCommands = []string{
// Basic shell commands
var safeReadOnlyCommands = []string{
"ls", "echo", "pwd", "date", "cal", "uptime", "whoami", "id", "groups", "env", "printenv", "set", "unset", "which", "type", "whereis",
"whatis", "uname", "hostname", "df", "du", "free", "top", "ps", "kill", "killall", "nice", "nohup", "time", "timeout",
// Git read-only commands
"git status", "git log", "git diff", "git show", "git branch", "git tag", "git remote", "git ls-files", "git ls-remote",
"git rev-parse", "git config --get", "git config --list", "git describe", "git blame", "git grep", "git shortlog",
// Go commands
"go version", "go list", "go env", "go doc", "go vet", "go fmt", "go mod", "go test", "go build", "go run", "go install", "go clean",
// Node.js commands
"node", "npm", "npx", "yarn", "pnpm",
// Python commands
"python", "python3", "pip", "pip3", "pytest", "pylint", "mypy", "black", "isort", "flake8", "ruff",
// Docker commands
"docker ps", "docker images", "docker volume", "docker network", "docker info", "docker version",
"docker-compose ps", "docker-compose config",
// Kubernetes commands
"kubectl get", "kubectl describe", "kubectl logs", "kubectl version", "kubectl config",
// Rust commands
"cargo", "rustc", "rustup",
// Java commands
"java", "javac", "mvn", "gradle",
// Misc development tools
"make", "cmake", "bazel", "terraform plan", "terraform validate", "ansible",
}
func (b *bashTool) Info() ToolInfo {
return ToolInfo{
Name: BashToolName,
Description: bashDescription(),
Parameters: map[string]any{
"command": map[string]any{
"type": "string",
"description": "The command to execute",
},
"timeout": map[string]any{
"type": "number",
"desription": "Optional timeout in milliseconds (max 600000)",
},
},
Required: []string{"command"},
}
}
// Handle implements Tool.
func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
var params BashParams
if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
return NewTextErrorResponse("invalid parameters"), nil
}
if params.Timeout > MaxTimeout {
params.Timeout = MaxTimeout
} else if params.Timeout <= 0 {
params.Timeout = DefaultTimeout
}
if params.Command == "" {
return NewTextErrorResponse("missing command"), nil
}
// Check for banned commands (first word only)
baseCmd := strings.Fields(params.Command)[0]
for _, banned := range BannedCommands {
if strings.EqualFold(baseCmd, banned) {
return NewTextErrorResponse(fmt.Sprintf("command '%s' is not allowed", baseCmd)), nil
}
}
// Check for safe commands (can be multi-word)
isSafeReadOnly := false
cmdLower := strings.ToLower(params.Command)
for _, safe := range SafeReadOnlyCommands {
// Check if command starts with the safe command pattern
if strings.HasPrefix(cmdLower, strings.ToLower(safe)) {
// Make sure it's either an exact match or followed by a space or flag
if len(cmdLower) == len(safe) || cmdLower[len(safe)] == ' ' || cmdLower[len(safe)] == '-' {
isSafeReadOnly = true
break
}
}
}
if !isSafeReadOnly {
p := permission.Default.Request(
permission.CreatePermissionRequest{
Path: config.WorkingDirectory(),
ToolName: BashToolName,
Action: "execute",
Description: fmt.Sprintf("Execute command: %s", params.Command),
Params: BashPermissionsParams{
Command: params.Command,
},
},
)
if !p {
return NewTextErrorResponse("permission denied"), nil
}
}
shell := shell.GetPersistentShell(config.WorkingDirectory())
stdout, stderr, exitCode, interrupted, err := shell.Exec(ctx, params.Command, params.Timeout)
if err != nil {
return NewTextErrorResponse(fmt.Sprintf("error executing command: %s", err)), nil
}
stdout = truncateOutput(stdout)
stderr = truncateOutput(stderr)
errorMessage := stderr
if interrupted {
if errorMessage != "" {
errorMessage += "\n"
}
errorMessage += "Command was aborted before completion"
} else if exitCode != 0 {
if errorMessage != "" {
errorMessage += "\n"
}
errorMessage += fmt.Sprintf("Exit code %d", exitCode)
}
hasBothOutputs := stdout != "" && stderr != ""
if hasBothOutputs {
stdout += "\n"
}
if errorMessage != "" {
stdout += "\n" + errorMessage
}
if stdout == "" {
return NewTextResponse("no output"), nil
}
return NewTextResponse(stdout), nil
}
func truncateOutput(content string) string {
if len(content) <= MaxOutputLength {
return content
}
halfLength := MaxOutputLength / 2
start := content[:halfLength]
end := content[len(content)-halfLength:]
truncatedLinesCount := countLines(content[halfLength : len(content)-halfLength])
return fmt.Sprintf("%s\n\n... [%d lines truncated] ...\n\n%s", start, truncatedLinesCount, end)
}
func countLines(s string) int {
if s == "" {
return 0
}
return len(strings.Split(s, "\n"))
}
func bashDescription() string {
bannedCommandsStr := strings.Join(BannedCommands, ", ")
bannedCommandsStr := strings.Join(bannedCommands, ", ")
return fmt.Sprintf(`Executes a given bash command in a persistent shell session with optional timeout, ensuring proper handling and security measures.
Before executing the command, please follow these steps:
@@ -352,6 +198,134 @@ Important:
- Never update git config`, bannedCommandsStr, MaxOutputLength)
}
func NewBashTool() BaseTool {
return &bashTool{}
func NewBashTool(permission permission.Service) BaseTool {
return &bashTool{
permissions: permission,
}
}
func (b *bashTool) Info() ToolInfo {
return ToolInfo{
Name: BashToolName,
Description: bashDescription(),
Parameters: map[string]any{
"command": map[string]any{
"type": "string",
"description": "The command to execute",
},
"timeout": map[string]any{
"type": "number",
"desription": "Optional timeout in milliseconds (max 600000)",
},
},
Required: []string{"command"},
}
}
func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
var params BashParams
if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
return NewTextErrorResponse("invalid parameters"), nil
}
if params.Timeout > MaxTimeout {
params.Timeout = MaxTimeout
} else if params.Timeout <= 0 {
params.Timeout = DefaultTimeout
}
if params.Command == "" {
return NewTextErrorResponse("missing command"), nil
}
baseCmd := strings.Fields(params.Command)[0]
for _, banned := range bannedCommands {
if strings.EqualFold(baseCmd, banned) {
return NewTextErrorResponse(fmt.Sprintf("command '%s' is not allowed", baseCmd)), nil
}
}
isSafeReadOnly := false
cmdLower := strings.ToLower(params.Command)
for _, safe := range safeReadOnlyCommands {
if strings.HasPrefix(cmdLower, strings.ToLower(safe)) {
if len(cmdLower) == len(safe) || cmdLower[len(safe)] == ' ' || cmdLower[len(safe)] == '-' {
isSafeReadOnly = true
break
}
}
}
if !isSafeReadOnly {
p := b.permissions.Request(
permission.CreatePermissionRequest{
Path: config.WorkingDirectory(),
ToolName: BashToolName,
Action: "execute",
Description: fmt.Sprintf("Execute command: %s", params.Command),
Params: BashPermissionsParams{
Command: params.Command,
},
},
)
if !p {
return NewTextErrorResponse("permission denied"), nil
}
}
shell := shell.GetPersistentShell(config.WorkingDirectory())
stdout, stderr, exitCode, interrupted, err := shell.Exec(ctx, params.Command, params.Timeout)
if err != nil {
return NewTextErrorResponse(fmt.Sprintf("error executing command: %s", err)), nil
}
stdout = truncateOutput(stdout)
stderr = truncateOutput(stderr)
errorMessage := stderr
if interrupted {
if errorMessage != "" {
errorMessage += "\n"
}
errorMessage += "Command was aborted before completion"
} else if exitCode != 0 {
if errorMessage != "" {
errorMessage += "\n"
}
errorMessage += fmt.Sprintf("Exit code %d", exitCode)
}
hasBothOutputs := stdout != "" && stderr != ""
if hasBothOutputs {
stdout += "\n"
}
if errorMessage != "" {
stdout += "\n" + errorMessage
}
if stdout == "" {
return NewTextResponse("no output"), nil
}
return NewTextResponse(stdout), nil
}
func truncateOutput(content string) string {
if len(content) <= MaxOutputLength {
return content
}
halfLength := MaxOutputLength / 2
start := content[:halfLength]
end := content[len(content)-halfLength:]
truncatedLinesCount := countLines(content[halfLength : len(content)-halfLength])
return fmt.Sprintf("%s\n\n... [%d lines truncated] ...\n\n%s", start, truncatedLinesCount, end)
}
func countLines(s string) int {
if s == "" {
return 0
}
return len(strings.Split(s, "\n"))
}