"""Tree-sitter generic AST extraction implementations. This module contains Tree-sitter extraction logic that works across multiple languages. """ from typing import Any, List, Dict, Optional from .base import ( find_containing_function_tree_sitter, extract_vars_from_tree_sitter_expr ) def extract_treesitter_functions(tree: Dict, parser_self, language: str) -> List[Dict]: """Extract function definitions from Tree-sitter AST.""" actual_tree = tree.get("tree") if not actual_tree: return [] if not parser_self.has_tree_sitter: return [] return _extract_tree_sitter_functions(actual_tree.root_node, language) def _extract_tree_sitter_functions(node: Any, language: str) -> List[Dict]: """Extract functions from Tree-sitter AST.""" functions = [] if node is None: return functions # Function node types per language function_types = { "python": ["function_definition"], "javascript": ["function_declaration", "arrow_function", "function_expression", "method_definition"], "typescript": ["function_declaration", "arrow_function", "function_expression", "method_definition"], } node_types = function_types.get(language, []) if node.type in node_types: # Extract function name name = "anonymous" for child in node.children: if child.type in ["identifier", "property_identifier"]: name = child.text.decode("utf-8", errors="ignore") break functions.append({ "name": name, "line": node.start_point[0] + 1, "type": node.type, }) # Recursively search children for child in node.children: functions.extend(_extract_tree_sitter_functions(child, language)) return functions def extract_treesitter_classes(tree: Dict, parser_self, language: str) -> List[Dict]: """Extract class definitions from Tree-sitter AST.""" actual_tree = tree.get("tree") if not actual_tree: return [] if not parser_self.has_tree_sitter: return [] return _extract_tree_sitter_classes(actual_tree.root_node, language) def _extract_tree_sitter_classes(node: Any, language: str) -> List[Dict]: """Extract classes from Tree-sitter AST.""" classes = [] if node is None: return classes # Class node types per language class_types = { "python": ["class_definition"], "javascript": ["class_declaration"], "typescript": ["class_declaration", "interface_declaration"], } node_types = class_types.get(language, []) if node.type in node_types: # Extract class name name = "anonymous" for child in node.children: if child.type in ["identifier", "type_identifier"]: name = child.text.decode("utf-8", errors="ignore") break classes.append({ "name": name, "line": node.start_point[0] + 1, "column": node.start_point[1], "type": node.type, }) # Recursively search children for child in node.children: classes.extend(_extract_tree_sitter_classes(child, language)) return classes def extract_treesitter_calls(tree: Dict, parser_self, language: str) -> List[Dict]: """Extract function calls from Tree-sitter AST.""" actual_tree = tree.get("tree") if not actual_tree: return [] if not parser_self.has_tree_sitter: return [] return _extract_tree_sitter_calls(actual_tree.root_node, language) def _extract_tree_sitter_calls(node: Any, language: str) -> List[Dict]: """Extract function calls from Tree-sitter AST.""" calls = [] if node is None: return calls # Call node types per language call_types = { "python": ["call"], "javascript": ["call_expression"], "typescript": ["call_expression"], } node_types = call_types.get(language, []) if node.type in node_types: # Extract function name being called name = "unknown" for child in node.children: if child.type in ["identifier", "member_expression", "attribute"]: name = child.text.decode("utf-8", errors="ignore") break # Also handle property access patterns for methods like res.send() elif child.type == "member_access_expression": name = child.text.decode("utf-8", errors="ignore") break calls.append({ "name": name, "line": node.start_point[0] + 1, "column": node.start_point[1], "type": "call", # Always use "call" type for database consistency }) # Recursively search children for child in node.children: calls.extend(_extract_tree_sitter_calls(child, language)) return calls def extract_treesitter_imports(tree: Dict, parser_self, language: str) -> List[Dict[str, Any]]: """Extract import statements from Tree-sitter AST.""" actual_tree = tree.get("tree") if not actual_tree: return [] if not parser_self.has_tree_sitter: return [] return _extract_tree_sitter_imports(actual_tree.root_node, language) def _extract_tree_sitter_imports(node: Any, language: str) -> List[Dict[str, Any]]: """Extract imports from Tree-sitter AST with language-specific handling.""" imports = [] if node is None: return imports # Import node types per language import_types = { "javascript": ["import_statement", "import_clause", "require_call"], "typescript": ["import_statement", "import_clause", "require_call", "import_type"], "python": ["import_statement", "import_from_statement"], } node_types = import_types.get(language, []) if node.type in node_types: # Parse based on node type if node.type == "import_statement": # Handle: import foo from 'bar' source_node = None specifiers = [] for child in node.children: if child.type == "string": source_node = child.text.decode("utf-8", errors="ignore").strip("\"'") elif child.type == "import_clause": # Extract imported names for spec_child in child.children: if spec_child.type == "identifier": specifiers.append(spec_child.text.decode("utf-8", errors="ignore")) if source_node: imports.append({ "source": "import", "target": source_node, "type": "import", "line": node.start_point[0] + 1, "specifiers": specifiers }) elif node.type == "require_call": # Handle: const foo = require('bar') for child in node.children: if child.type == "string": target = child.text.decode("utf-8", errors="ignore").strip("\"'") imports.append({ "source": "require", "target": target, "type": "require", "line": node.start_point[0] + 1, "specifiers": [] }) # Recursively search children for child in node.children: imports.extend(_extract_tree_sitter_imports(child, language)) return imports def extract_treesitter_exports(tree: Dict, parser_self, language: str) -> List[Dict[str, Any]]: """Extract export statements from Tree-sitter AST.""" actual_tree = tree.get("tree") if not actual_tree: return [] if not parser_self.has_tree_sitter: return [] return _extract_tree_sitter_exports(actual_tree.root_node, language) def _extract_tree_sitter_exports(node: Any, language: str) -> List[Dict[str, Any]]: """Extract exports from Tree-sitter AST.""" exports = [] if node is None: return exports # Export node types per language export_types = { "javascript": ["export_statement", "export_default_declaration"], "typescript": ["export_statement", "export_default_declaration", "export_type"], } node_types = export_types.get(language, []) if node.type in node_types: is_default = "default" in node.type # Extract exported name name = "unknown" export_type = "unknown" for child in node.children: if child.type in ["identifier", "type_identifier"]: name = child.text.decode("utf-8", errors="ignore") elif child.type == "function_declaration": export_type = "function" for subchild in child.children: if subchild.type == "identifier": name = subchild.text.decode("utf-8", errors="ignore") break elif child.type == "class_declaration": export_type = "class" for subchild in child.children: if subchild.type in ["identifier", "type_identifier"]: name = subchild.text.decode("utf-8", errors="ignore") break exports.append({ "name": name, "type": export_type, "line": node.start_point[0] + 1, "default": is_default }) # Recursively search children for child in node.children: exports.extend(_extract_tree_sitter_exports(child, language)) return exports def extract_treesitter_properties(tree: Dict, parser_self, language: str) -> List[Dict]: """Extract property accesses from Tree-sitter AST.""" actual_tree = tree.get("tree") if not actual_tree: return [] if not parser_self.has_tree_sitter: return [] return _extract_tree_sitter_properties(actual_tree.root_node, language) def _extract_tree_sitter_properties(node: Any, language: str) -> List[Dict]: """Extract property accesses from Tree-sitter AST.""" properties = [] if node is None: return properties # Property access node types per language property_types = { "javascript": ["member_expression", "property_access_expression"], "typescript": ["member_expression", "property_access_expression"], "python": ["attribute"], } node_types = property_types.get(language, []) if node.type in node_types: # Extract the full property access chain prop_text = node.text.decode("utf-8", errors="ignore") if node.text else "" # Filter for patterns that look like taint sources (req.*, request.*, ctx.*, etc.) if any(pattern in prop_text for pattern in ["req.", "request.", "ctx.", "body", "query", "params", "headers", "cookies"]): properties.append({ "name": prop_text, "line": node.start_point[0] + 1, "column": node.start_point[1], "type": "property" }) # Recursively search children for child in node.children: properties.extend(_extract_tree_sitter_properties(child, language)) return properties def extract_treesitter_assignments(tree: Dict, parser_self, language: str) -> List[Dict[str, Any]]: """Extract variable assignments from Tree-sitter AST.""" actual_tree = tree.get("tree") content = tree.get("content", "") if not actual_tree: return [] if not parser_self.has_tree_sitter: return [] return _extract_tree_sitter_assignments(actual_tree.root_node, language, content) def _extract_tree_sitter_assignments(node: Any, language: str, content: str) -> List[Dict[str, Any]]: """Extract assignments from Tree-sitter AST.""" import os import sys debug = os.environ.get("THEAUDITOR_DEBUG") assignments = [] if node is None: return assignments # Assignment node types per language assignment_types = { # Don't include variable_declarator - it's handled inside lexical_declaration/variable_declaration "javascript": ["assignment_expression", "lexical_declaration", "variable_declaration"], "typescript": ["assignment_expression", "lexical_declaration", "variable_declaration"], "python": ["assignment"], } node_types = assignment_types.get(language, []) if node.type in node_types: target_var = None source_expr = None source_vars = [] if node.type in ["lexical_declaration", "variable_declaration"]: # Handle lexical_declaration (const/let) and variable_declaration (var) # Both contain variable_declarator children # Process all variable_declarators within (const a = 1, b = 2) for child in node.children: if child.type == "variable_declarator": name_node = child.child_by_field_name('name') value_node = child.child_by_field_name('value') if name_node and value_node: in_function = find_containing_function_tree_sitter(child, content, language) or "global" if debug: print(f"[DEBUG] Found assignment: {name_node.text.decode('utf-8')} = {value_node.text.decode('utf-8')[:50]}", file=sys.stderr) assignments.append({ "target_var": name_node.text.decode("utf-8", errors="ignore"), "source_expr": value_node.text.decode("utf-8", errors="ignore"), "line": child.start_point[0] + 1, "in_function": in_function, "source_vars": extract_vars_from_tree_sitter_expr( value_node.text.decode("utf-8", errors="ignore") ) }) elif node.type == "assignment_expression": # x = value (JavaScript/TypeScript) - Use field-based API left_node = node.child_by_field_name('left') right_node = node.child_by_field_name('right') if left_node: target_var = left_node.text.decode("utf-8", errors="ignore") if right_node: source_expr = right_node.text.decode("utf-8", errors="ignore") source_vars = extract_vars_from_tree_sitter_expr(source_expr) elif node.type == "assignment": # x = value (Python) # Python assignment has structure: [target, "=", value] left_node = None right_node = None for child in node.children: if child.type != "=" and left_node is None: left_node = child elif child.type != "=" and left_node is not None: right_node = child if left_node: target_var = left_node.text.decode("utf-8", errors="ignore") if left_node.text else "" if right_node: source_expr = right_node.text.decode("utf-8", errors="ignore") if right_node.text else "" # Only create assignment record if we have both target and source # (Skip lexical_declaration/variable_declaration as they're handled above with their children) if target_var and source_expr and node.type not in ["lexical_declaration", "variable_declaration"]: # Find containing function in_function = find_containing_function_tree_sitter(node, content, language) assignments.append({ "target_var": target_var, "source_expr": source_expr, "line": node.start_point[0] + 1, "in_function": in_function or "global", "source_vars": source_vars if source_vars else extract_vars_from_tree_sitter_expr(source_expr) }) # Recursively search children for child in node.children: assignments.extend(_extract_tree_sitter_assignments(child, language, content)) return assignments def extract_treesitter_function_params(tree: Dict, parser_self, language: str) -> Dict[str, List[str]]: """Extract function parameters from Tree-sitter AST.""" actual_tree = tree.get("tree") if not actual_tree: return {} if not parser_self.has_tree_sitter: return {} return _extract_tree_sitter_function_params(actual_tree.root_node, language) def _extract_tree_sitter_function_params(node: Any, language: str) -> Dict[str, List[str]]: """Extract function parameters from Tree-sitter AST.""" func_params = {} if node is None: return func_params # Function definition node types if language in ["javascript", "typescript"]: if node.type in ["function_declaration", "function_expression", "arrow_function", "method_definition"]: func_name = "anonymous" params = [] # Use field-based API for function nodes name_node = node.child_by_field_name('name') params_node = node.child_by_field_name('parameters') if name_node: func_name = name_node.text.decode("utf-8", errors="ignore") # Fall back to child iteration if field access fails if not params_node: for child in node.children: if child.type in ["formal_parameters", "parameters"]: params_node = child break if params_node: # Extract parameter names for param_child in params_node.children: if param_child.type in ["identifier", "required_parameter", "optional_parameter"]: if param_child.type == "identifier": params.append(param_child.text.decode("utf-8", errors="ignore")) else: # For required/optional parameters, use field API pattern_node = param_child.child_by_field_name('pattern') if pattern_node and pattern_node.type == "identifier": params.append(pattern_node.text.decode("utf-8", errors="ignore")) if func_name and params: func_params[func_name] = params elif language == "python": if node.type == "function_definition": func_name = None params = [] for child in node.children: if child.type == "identifier": func_name = child.text.decode("utf-8", errors="ignore") elif child.type == "parameters": for param_child in child.children: if param_child.type == "identifier": params.append(param_child.text.decode("utf-8", errors="ignore")) if func_name: func_params[func_name] = params # Recursively search children for child in node.children: func_params.update(_extract_tree_sitter_function_params(child, language)) return func_params def extract_treesitter_calls_with_args( tree: Dict, function_params: Dict[str, List[str]], parser_self, language: str ) -> List[Dict[str, Any]]: """Extract function calls with arguments from Tree-sitter AST.""" actual_tree = tree.get("tree") content = tree.get("content", "") if not actual_tree: return [] if not parser_self.has_tree_sitter: return [] return _extract_tree_sitter_calls_with_args( actual_tree.root_node, language, content, function_params ) def _extract_tree_sitter_calls_with_args( node: Any, language: str, content: str, function_params: Dict[str, List[str]] ) -> List[Dict[str, Any]]: """Extract function calls with arguments from Tree-sitter AST.""" calls = [] if node is None: return calls # Call expression node types if language in ["javascript", "typescript"] and node.type == "call_expression": # Extract function name using field-based API func_node = node.child_by_field_name('function') func_name = "unknown" if func_node: func_name = func_node.text.decode("utf-8", errors="ignore") if func_node.text else "unknown" else: # Fallback to child iteration for child in node.children: if child.type in ["identifier", "member_expression"]: func_name = child.text.decode("utf-8", errors="ignore") if child.text else "unknown" break # Find caller function caller_function = find_containing_function_tree_sitter(node, content, language) or "global" # Get callee parameters callee_params = function_params.get(func_name.split(".")[-1], []) # Extract arguments using field-based API args_node = node.child_by_field_name('arguments') arg_index = 0 if args_node: for arg_child in args_node.children: if arg_child.type not in ["(", ")", ","]: arg_expr = arg_child.text.decode("utf-8", errors="ignore") if arg_child.text else "" param_name = callee_params[arg_index] if arg_index < len(callee_params) else f"arg{arg_index}" calls.append({ "line": node.start_point[0] + 1, "caller_function": caller_function, "callee_function": func_name, "argument_index": arg_index, "argument_expr": arg_expr, "param_name": param_name }) arg_index += 1 elif language == "python" and node.type == "call": # Similar logic for Python func_name = "unknown" for child in node.children: if child.type in ["identifier", "attribute"]: func_name = child.text.decode("utf-8", errors="ignore") if child.text else "unknown" break caller_function = find_containing_function_tree_sitter(node, content, language) or "global" callee_params = function_params.get(func_name.split(".")[-1], []) arg_index = 0 for child in node.children: if child.type == "argument_list": for arg_child in child.children: if arg_child.type not in ["(", ")", ","]: arg_expr = arg_child.text.decode("utf-8", errors="ignore") if arg_child.text else "" param_name = callee_params[arg_index] if arg_index < len(callee_params) else f"arg{arg_index}" calls.append({ "line": node.start_point[0] + 1, "caller_function": caller_function, "callee_function": func_name, "argument_index": arg_index, "argument_expr": arg_expr, "param_name": param_name }) arg_index += 1 # Recursively search children for child in node.children: calls.extend(_extract_tree_sitter_calls_with_args(child, language, content, function_params)) return calls def extract_treesitter_returns(tree: Dict, parser_self, language: str) -> List[Dict[str, Any]]: """Extract return statements from Tree-sitter AST.""" actual_tree = tree.get("tree") content = tree.get("content", "") if not actual_tree: return [] if not parser_self.has_tree_sitter: return [] return _extract_tree_sitter_returns(actual_tree.root_node, language, content) def _extract_tree_sitter_returns(node: Any, language: str, content: str) -> List[Dict[str, Any]]: """Extract return statements from Tree-sitter AST.""" returns = [] if node is None: return returns # Return statement node types if language in ["javascript", "typescript"] and node.type == "return_statement": # Find containing function function_name = find_containing_function_tree_sitter(node, content, language) or "global" # Extract return expression return_expr = "" for child in node.children: if child.type != "return": return_expr = child.text.decode("utf-8", errors="ignore") if child.text else "" break if not return_expr: return_expr = "undefined" returns.append({ "function_name": function_name, "line": node.start_point[0] + 1, "return_expr": return_expr, "return_vars": extract_vars_from_tree_sitter_expr(return_expr) }) elif language == "python" and node.type == "return_statement": # Find containing function function_name = find_containing_function_tree_sitter(node, content, language) or "global" # Extract return expression return_expr = "" for child in node.children: if child.type != "return": return_expr = child.text.decode("utf-8", errors="ignore") if child.text else "" break if not return_expr: return_expr = "None" returns.append({ "function_name": function_name, "line": node.start_point[0] + 1, "return_expr": return_expr, "return_vars": extract_vars_from_tree_sitter_expr(return_expr) }) # Recursively search children for child in node.children: returns.extend(_extract_tree_sitter_returns(child, language, content)) return returns