mirror of
https://github.com/aljazceru/opencode.git
synced 2025-12-23 18:54:21 +01:00
structure tools the same
- add some tests - fix some tests - change how we handle permissions
This commit is contained in:
@@ -11,16 +11,6 @@ import (
|
||||
"github.com/kujtimiihoxha/termai/internal/permission"
|
||||
)
|
||||
|
||||
type bashTool struct{}
|
||||
|
||||
const (
|
||||
BashToolName = "bash"
|
||||
|
||||
DefaultTimeout = 1 * 60 * 1000 // 1 minutes in milliseconds
|
||||
MaxTimeout = 10 * 60 * 1000 // 10 minutes in milliseconds
|
||||
MaxOutputLength = 30000
|
||||
)
|
||||
|
||||
type BashParams struct {
|
||||
Command string `json:"command"`
|
||||
Timeout int `json:"timeout"`
|
||||
@@ -31,180 +21,36 @@ type BashPermissionsParams struct {
|
||||
Timeout int `json:"timeout"`
|
||||
}
|
||||
|
||||
var BannedCommands = []string{
|
||||
type bashTool struct {
|
||||
permissions permission.Service
|
||||
}
|
||||
|
||||
const (
|
||||
BashToolName = "bash"
|
||||
|
||||
DefaultTimeout = 1 * 60 * 1000 // 1 minutes in milliseconds
|
||||
MaxTimeout = 10 * 60 * 1000 // 10 minutes in milliseconds
|
||||
MaxOutputLength = 30000
|
||||
)
|
||||
|
||||
var bannedCommands = []string{
|
||||
"alias", "curl", "curlie", "wget", "axel", "aria2c",
|
||||
"nc", "telnet", "lynx", "w3m", "links", "httpie", "xh",
|
||||
"http-prompt", "chrome", "firefox", "safari",
|
||||
}
|
||||
|
||||
var SafeReadOnlyCommands = []string{
|
||||
// Basic shell commands
|
||||
var safeReadOnlyCommands = []string{
|
||||
"ls", "echo", "pwd", "date", "cal", "uptime", "whoami", "id", "groups", "env", "printenv", "set", "unset", "which", "type", "whereis",
|
||||
"whatis", "uname", "hostname", "df", "du", "free", "top", "ps", "kill", "killall", "nice", "nohup", "time", "timeout",
|
||||
|
||||
// Git read-only commands
|
||||
|
||||
"git status", "git log", "git diff", "git show", "git branch", "git tag", "git remote", "git ls-files", "git ls-remote",
|
||||
"git rev-parse", "git config --get", "git config --list", "git describe", "git blame", "git grep", "git shortlog",
|
||||
|
||||
// Go commands
|
||||
|
||||
"go version", "go list", "go env", "go doc", "go vet", "go fmt", "go mod", "go test", "go build", "go run", "go install", "go clean",
|
||||
|
||||
// Node.js commands
|
||||
"node", "npm", "npx", "yarn", "pnpm",
|
||||
|
||||
// Python commands
|
||||
"python", "python3", "pip", "pip3", "pytest", "pylint", "mypy", "black", "isort", "flake8", "ruff",
|
||||
|
||||
// Docker commands
|
||||
"docker ps", "docker images", "docker volume", "docker network", "docker info", "docker version",
|
||||
"docker-compose ps", "docker-compose config",
|
||||
|
||||
// Kubernetes commands
|
||||
"kubectl get", "kubectl describe", "kubectl logs", "kubectl version", "kubectl config",
|
||||
|
||||
// Rust commands
|
||||
"cargo", "rustc", "rustup",
|
||||
|
||||
// Java commands
|
||||
"java", "javac", "mvn", "gradle",
|
||||
|
||||
// Misc development tools
|
||||
"make", "cmake", "bazel", "terraform plan", "terraform validate", "ansible",
|
||||
}
|
||||
|
||||
func (b *bashTool) Info() ToolInfo {
|
||||
return ToolInfo{
|
||||
Name: BashToolName,
|
||||
Description: bashDescription(),
|
||||
Parameters: map[string]any{
|
||||
"command": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The command to execute",
|
||||
},
|
||||
"timeout": map[string]any{
|
||||
"type": "number",
|
||||
"desription": "Optional timeout in milliseconds (max 600000)",
|
||||
},
|
||||
},
|
||||
Required: []string{"command"},
|
||||
}
|
||||
}
|
||||
|
||||
// Handle implements Tool.
|
||||
func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
|
||||
var params BashParams
|
||||
if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
|
||||
return NewTextErrorResponse("invalid parameters"), nil
|
||||
}
|
||||
|
||||
if params.Timeout > MaxTimeout {
|
||||
params.Timeout = MaxTimeout
|
||||
} else if params.Timeout <= 0 {
|
||||
params.Timeout = DefaultTimeout
|
||||
}
|
||||
|
||||
if params.Command == "" {
|
||||
return NewTextErrorResponse("missing command"), nil
|
||||
}
|
||||
|
||||
// Check for banned commands (first word only)
|
||||
baseCmd := strings.Fields(params.Command)[0]
|
||||
for _, banned := range BannedCommands {
|
||||
if strings.EqualFold(baseCmd, banned) {
|
||||
return NewTextErrorResponse(fmt.Sprintf("command '%s' is not allowed", baseCmd)), nil
|
||||
}
|
||||
}
|
||||
|
||||
// Check for safe commands (can be multi-word)
|
||||
isSafeReadOnly := false
|
||||
cmdLower := strings.ToLower(params.Command)
|
||||
|
||||
for _, safe := range SafeReadOnlyCommands {
|
||||
// Check if command starts with the safe command pattern
|
||||
if strings.HasPrefix(cmdLower, strings.ToLower(safe)) {
|
||||
// Make sure it's either an exact match or followed by a space or flag
|
||||
if len(cmdLower) == len(safe) || cmdLower[len(safe)] == ' ' || cmdLower[len(safe)] == '-' {
|
||||
isSafeReadOnly = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !isSafeReadOnly {
|
||||
p := permission.Default.Request(
|
||||
permission.CreatePermissionRequest{
|
||||
Path: config.WorkingDirectory(),
|
||||
ToolName: BashToolName,
|
||||
Action: "execute",
|
||||
Description: fmt.Sprintf("Execute command: %s", params.Command),
|
||||
Params: BashPermissionsParams{
|
||||
Command: params.Command,
|
||||
},
|
||||
},
|
||||
)
|
||||
if !p {
|
||||
return NewTextErrorResponse("permission denied"), nil
|
||||
}
|
||||
}
|
||||
shell := shell.GetPersistentShell(config.WorkingDirectory())
|
||||
stdout, stderr, exitCode, interrupted, err := shell.Exec(ctx, params.Command, params.Timeout)
|
||||
if err != nil {
|
||||
return NewTextErrorResponse(fmt.Sprintf("error executing command: %s", err)), nil
|
||||
}
|
||||
|
||||
stdout = truncateOutput(stdout)
|
||||
stderr = truncateOutput(stderr)
|
||||
|
||||
errorMessage := stderr
|
||||
if interrupted {
|
||||
if errorMessage != "" {
|
||||
errorMessage += "\n"
|
||||
}
|
||||
errorMessage += "Command was aborted before completion"
|
||||
} else if exitCode != 0 {
|
||||
if errorMessage != "" {
|
||||
errorMessage += "\n"
|
||||
}
|
||||
errorMessage += fmt.Sprintf("Exit code %d", exitCode)
|
||||
}
|
||||
|
||||
hasBothOutputs := stdout != "" && stderr != ""
|
||||
|
||||
if hasBothOutputs {
|
||||
stdout += "\n"
|
||||
}
|
||||
|
||||
if errorMessage != "" {
|
||||
stdout += "\n" + errorMessage
|
||||
}
|
||||
|
||||
if stdout == "" {
|
||||
return NewTextResponse("no output"), nil
|
||||
}
|
||||
return NewTextResponse(stdout), nil
|
||||
}
|
||||
|
||||
func truncateOutput(content string) string {
|
||||
if len(content) <= MaxOutputLength {
|
||||
return content
|
||||
}
|
||||
|
||||
halfLength := MaxOutputLength / 2
|
||||
start := content[:halfLength]
|
||||
end := content[len(content)-halfLength:]
|
||||
|
||||
truncatedLinesCount := countLines(content[halfLength : len(content)-halfLength])
|
||||
return fmt.Sprintf("%s\n\n... [%d lines truncated] ...\n\n%s", start, truncatedLinesCount, end)
|
||||
}
|
||||
|
||||
func countLines(s string) int {
|
||||
if s == "" {
|
||||
return 0
|
||||
}
|
||||
return len(strings.Split(s, "\n"))
|
||||
}
|
||||
|
||||
func bashDescription() string {
|
||||
bannedCommandsStr := strings.Join(BannedCommands, ", ")
|
||||
bannedCommandsStr := strings.Join(bannedCommands, ", ")
|
||||
return fmt.Sprintf(`Executes a given bash command in a persistent shell session with optional timeout, ensuring proper handling and security measures.
|
||||
|
||||
Before executing the command, please follow these steps:
|
||||
@@ -352,6 +198,134 @@ Important:
|
||||
- Never update git config`, bannedCommandsStr, MaxOutputLength)
|
||||
}
|
||||
|
||||
func NewBashTool() BaseTool {
|
||||
return &bashTool{}
|
||||
func NewBashTool(permission permission.Service) BaseTool {
|
||||
return &bashTool{
|
||||
permissions: permission,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *bashTool) Info() ToolInfo {
|
||||
return ToolInfo{
|
||||
Name: BashToolName,
|
||||
Description: bashDescription(),
|
||||
Parameters: map[string]any{
|
||||
"command": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The command to execute",
|
||||
},
|
||||
"timeout": map[string]any{
|
||||
"type": "number",
|
||||
"desription": "Optional timeout in milliseconds (max 600000)",
|
||||
},
|
||||
},
|
||||
Required: []string{"command"},
|
||||
}
|
||||
}
|
||||
|
||||
func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
|
||||
var params BashParams
|
||||
if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
|
||||
return NewTextErrorResponse("invalid parameters"), nil
|
||||
}
|
||||
|
||||
if params.Timeout > MaxTimeout {
|
||||
params.Timeout = MaxTimeout
|
||||
} else if params.Timeout <= 0 {
|
||||
params.Timeout = DefaultTimeout
|
||||
}
|
||||
|
||||
if params.Command == "" {
|
||||
return NewTextErrorResponse("missing command"), nil
|
||||
}
|
||||
|
||||
baseCmd := strings.Fields(params.Command)[0]
|
||||
for _, banned := range bannedCommands {
|
||||
if strings.EqualFold(baseCmd, banned) {
|
||||
return NewTextErrorResponse(fmt.Sprintf("command '%s' is not allowed", baseCmd)), nil
|
||||
}
|
||||
}
|
||||
|
||||
isSafeReadOnly := false
|
||||
cmdLower := strings.ToLower(params.Command)
|
||||
|
||||
for _, safe := range safeReadOnlyCommands {
|
||||
if strings.HasPrefix(cmdLower, strings.ToLower(safe)) {
|
||||
if len(cmdLower) == len(safe) || cmdLower[len(safe)] == ' ' || cmdLower[len(safe)] == '-' {
|
||||
isSafeReadOnly = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !isSafeReadOnly {
|
||||
p := b.permissions.Request(
|
||||
permission.CreatePermissionRequest{
|
||||
Path: config.WorkingDirectory(),
|
||||
ToolName: BashToolName,
|
||||
Action: "execute",
|
||||
Description: fmt.Sprintf("Execute command: %s", params.Command),
|
||||
Params: BashPermissionsParams{
|
||||
Command: params.Command,
|
||||
},
|
||||
},
|
||||
)
|
||||
if !p {
|
||||
return NewTextErrorResponse("permission denied"), nil
|
||||
}
|
||||
}
|
||||
shell := shell.GetPersistentShell(config.WorkingDirectory())
|
||||
stdout, stderr, exitCode, interrupted, err := shell.Exec(ctx, params.Command, params.Timeout)
|
||||
if err != nil {
|
||||
return NewTextErrorResponse(fmt.Sprintf("error executing command: %s", err)), nil
|
||||
}
|
||||
|
||||
stdout = truncateOutput(stdout)
|
||||
stderr = truncateOutput(stderr)
|
||||
|
||||
errorMessage := stderr
|
||||
if interrupted {
|
||||
if errorMessage != "" {
|
||||
errorMessage += "\n"
|
||||
}
|
||||
errorMessage += "Command was aborted before completion"
|
||||
} else if exitCode != 0 {
|
||||
if errorMessage != "" {
|
||||
errorMessage += "\n"
|
||||
}
|
||||
errorMessage += fmt.Sprintf("Exit code %d", exitCode)
|
||||
}
|
||||
|
||||
hasBothOutputs := stdout != "" && stderr != ""
|
||||
|
||||
if hasBothOutputs {
|
||||
stdout += "\n"
|
||||
}
|
||||
|
||||
if errorMessage != "" {
|
||||
stdout += "\n" + errorMessage
|
||||
}
|
||||
|
||||
if stdout == "" {
|
||||
return NewTextResponse("no output"), nil
|
||||
}
|
||||
return NewTextResponse(stdout), nil
|
||||
}
|
||||
|
||||
func truncateOutput(content string) string {
|
||||
if len(content) <= MaxOutputLength {
|
||||
return content
|
||||
}
|
||||
|
||||
halfLength := MaxOutputLength / 2
|
||||
start := content[:halfLength]
|
||||
end := content[len(content)-halfLength:]
|
||||
|
||||
truncatedLinesCount := countLines(content[halfLength : len(content)-halfLength])
|
||||
return fmt.Sprintf("%s\n\n... [%d lines truncated] ...\n\n%s", start, truncatedLinesCount, end)
|
||||
}
|
||||
|
||||
func countLines(s string) int {
|
||||
if s == "" {
|
||||
return 0
|
||||
}
|
||||
return len(strings.Split(s, "\n"))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user