# Auditor/theauditor/workset.py
"""Workset resolver - computes target file set from git diff and dependencies."""
import json
import os
import platform
import sqlite3
import subprocess
import tempfile
from datetime import UTC, datetime
from fnmatch import fnmatch
from pathlib import Path
from typing import Any
# Windows compatibility
IS_WINDOWS = platform.system() == "Windows"


def normalize_path(path: str) -> str:
    """Normalize path to POSIX style."""
    # Replace backslashes with forward slashes
    path = path.replace("\\", "/")
    # Use Path to collapse "." segments (note: pathlib keeps ".." untouched)
    path = str(Path(path).as_posix())
    # Remove leading ./ (defensive; Path already drops a single leading dot)
    if path.startswith("./"):
        path = path[2:]
    return path
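
# Illustrative behavior (pathlib collapses "." segments but keeps ".."):
#   normalize_path("src\\utils\\helpers.py")  -> "src/utils/helpers.py"
#   normalize_path("./src/./app.py")          -> "src/app.py"
#   normalize_path("../shared/types.ts")      -> "../shared/types.ts"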


def load_manifest(manifest_path: str) -> dict[str, str]:
    """Load manifest and create path -> sha256 mapping."""
    with open(manifest_path) as f:
        manifest = json.load(f)
    return {item["path"]: item["sha256"] for item in manifest}
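
# Assumed manifest.json shape (a JSON array, as consumed above):
#   [
#     {"path": "src/app.py", "sha256": "ab12..."},
#     ...
#   ]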


def get_git_diff_files(diff_spec: str, root_path: str = ".") -> list[str]:
    """Get list of changed files from git diff."""
    try:
        # Write output to temp files instead of PIPE to avoid blocking on large diffs
        with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='_stdout.txt', encoding='utf-8') as stdout_fp, \
                tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='_stderr.txt', encoding='utf-8') as stderr_fp:
            stdout_path = stdout_fp.name
            stderr_path = stderr_fp.name
            subprocess.run(
                ["git", "diff", "--name-only"] + diff_spec.split(".."),
                cwd=root_path,
                stdout=stdout_fp,
                stderr=stderr_fp,
                text=True,
                encoding='utf-8',
                errors='replace',
                check=True,
                shell=IS_WINDOWS,  # Windows compatibility fix
            )

        # Read the outputs back
        with open(stdout_path, 'r', encoding='utf-8') as f:
            stdout_content = f.read()
        with open(stderr_path, 'r', encoding='utf-8') as f:
            stderr_content = f.read()

        # Clean up temp files
        os.unlink(stdout_path)
        os.unlink(stderr_path)

        files = stdout_content.strip().split("\n") if stdout_content.strip() else []
        return [normalize_path(f) for f in files]
    except subprocess.CalledProcessError as e:
        # Read stderr for the error message
        try:
            with open(stderr_path, 'r', encoding='utf-8') as f:
                error_msg = f.read()
        except OSError:
            error_msg = 'git not available'
        finally:
            # Clean up temp files
            if 'stdout_path' in locals() and os.path.exists(stdout_path):
                os.unlink(stdout_path)
            if 'stderr_path' in locals() and os.path.exists(stderr_path):
                os.unlink(stderr_path)
        raise RuntimeError(f"Git diff failed: {error_msg}") from e
    except FileNotFoundError:
        # Clean up temp files if they exist
        if 'stdout_path' in locals() and os.path.exists(stdout_path):
            os.unlink(stdout_path)
        if 'stderr_path' in locals() and os.path.exists(stderr_path):
            os.unlink(stderr_path)
        raise RuntimeError("Git is not available. Use --files instead.") from None


def get_forward_deps(
    conn: sqlite3.Connection, file_path: str, manifest_paths: set[str]
) -> set[str]:
    """Get files that this file imports/uses."""
    cursor = conn.cursor()
    cursor.execute("SELECT value FROM refs WHERE src = ? AND kind = 'from'", (file_path,))
    deps = set()
    for (value,) in cursor.fetchall():
        # Skip tokens that are not paths
        if value in ["{", "}", "(", ")", "*"] or (value.startswith("'") and value.endswith("'")):
            continue
        # Skip scoped external packages (e.g. @scope/pkg); bare module names
        # with no slash fall through below with no candidates
        if value.startswith("@"):
            continue
        # Clean up the value - remove quotes if present
        value = value.strip("'\"")
        # Try to resolve the import path
        candidates = []
        if value.startswith("./") or value.startswith("../"):
            # Resolve relative to the file's directory; use normpath instead
            # of resolve() to stay relative
            file_dir = Path(file_path).parent
            resolved = os.path.normpath(str(file_dir / value))
            resolved = normalize_path(resolved)
            # Skip paths that escape the repo root
            if resolved.startswith(".."):
                continue
            candidates.append(resolved)
            # Try with common extensions
            for ext in [".ts", ".js", ".tsx", ".jsx", ".py"]:
                candidates.append(resolved + ext)
                candidates.append(resolved + "/index" + ext)
        elif "/" in value and not value.startswith("/"):
            # Could be a project path
            candidates.append(normalize_path(value))
            for ext in [".ts", ".js", ".tsx", ".jsx", ".py"]:
                candidates.append(normalize_path(value) + ext)
        # Check if any candidate exists in the manifest
        for candidate in candidates:
            if candidate in manifest_paths:
                deps.add(candidate)
                break
    return deps
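
# A minimal usage sketch (hypothetical paths; assumes an indexed repo_index.db):
#
#   conn = sqlite3.connect("repo_index.db")
#   manifest_paths = set(load_manifest("manifest.json"))
#   deps = get_forward_deps(conn, "src/api/app.ts", manifest_paths)
#   # e.g. an import of "../utils/helpers" in src/api/app.ts resolves to
#   # "src/utils/helpers" and is matched against the manifest as-is, with
#   # each extension, and as "src/utils/helpers/index.<ext>".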


def get_reverse_deps(
    conn: sqlite3.Connection, file_path: str, manifest_paths: set[str]
) -> set[str]:
    """Get files that import/use this file."""
    cursor = conn.cursor()
    # Find all refs that might point to this file
    deps = set()
    logged_paths = set()  # Track which paths we've already logged errors for
    # Get all 'from' refs
    cursor.execute("SELECT src, value FROM refs WHERE kind = 'from'")
    # Remove extension from target file for matching; keep POSIX form so it
    # compares correctly against normalized paths on Windows
    file_path_no_ext = Path(file_path).with_suffix("").as_posix()
    for src, value in cursor.fetchall():
        if src == file_path:
            continue
        # Clean up the value
        value = value.strip("'\"")
        # Skip non-path values
        if value in ["{", "}", "(", ")", "*"] or value.startswith("@"):
            continue
        # Try to resolve this import from the source file
        if value.startswith("./") or value.startswith("../"):
            # Resolve relative to the source file's directory
            src_dir = Path(src).parent
            try:
                resolved = os.path.normpath(str(src_dir / value))
                resolved = normalize_path(resolved)
                # Check if this resolves to our target file
                if resolved in (file_path_no_ext, file_path):
                    deps.add(src)
                    continue
                # Also check with common extensions
                for ext in [".ts", ".js", ".tsx", ".jsx", ".py"]:
                    if resolved + ext == file_path:
                        deps.add(src)
                        break
            except (OSError, ValueError) as e:
                # Log path resolution issue once per source file
                if src not in logged_paths:
                    logged_paths.add(src)
                    print(f"Debug: Could not resolve import from {src}: {type(e).__name__}")
                continue
    return deps


def expand_dependencies(
    conn: sqlite3.Connection,
    seed_files: set[str],
    manifest_paths: set[str],
    max_depth: int,
) -> set[str]:
    """Expand file set by following dependencies up to max_depth."""
    if max_depth == 0:
        return seed_files
    expanded = seed_files.copy()
    current_level = seed_files
    for _depth in range(max_depth):
        next_level = set()
        for file_path in current_level:
            # Forward dependencies
            forward = get_forward_deps(conn, file_path, manifest_paths)
            next_level.update(forward - expanded)
            # Reverse dependencies, filtered to files in the manifest
            reverse = get_reverse_deps(conn, file_path, manifest_paths)
            reverse = {f for f in reverse if f in manifest_paths}
            next_level.update(reverse - expanded)
        if not next_level:
            break
        expanded.update(next_level)
        current_level = next_level
    return expanded
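
# Behavior note: this is a depth-limited breadth-first walk. With seeds
# {"src/app.ts"} and max_depth=2 (hypothetical), depth 1 collects direct
# imports and importers of src/app.ts, depth 2 collects theirs; files already
# in `expanded` are never re-queued, so the walk terminates even with cycles.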


def apply_glob_filters(
    files: set[str],
    include_patterns: list[str],
    exclude_patterns: list[str],
) -> set[str]:
    """Apply include/exclude glob patterns to file set."""
    if not include_patterns:
        include_patterns = ["**"]
    filtered = set()
    for file_path in files:
        # Check if file matches any include pattern
        included = any(fnmatch(file_path, pattern) for pattern in include_patterns)
        # Check if file matches any exclude pattern
        excluded = any(fnmatch(file_path, pattern) for pattern in exclude_patterns)
        if included and not excluded:
            filtered.add(file_path)
    return filtered
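
# Caveat of the stdlib fnmatch used above: patterns are not path-aware, so "*"
# also matches "/". For example:
#
#   apply_glob_filters({"src/app.py", "docs/readme.md"}, ["*.py"], [])
#   # -> {"src/app.py"}   ("*.py" matches across directory separators)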


def compute_workset(
    root_path: str = ".",
    db_path: str = "repo_index.db",
    manifest_path: str = "manifest.json",
    all_files: bool = False,
    diff_spec: str | None = None,
    file_list: list[str] | None = None,
    include_patterns: list[str] | None = None,
    exclude_patterns: list[str] | None = None,
    max_depth: int = 2,
    output_path: str = "./.pf/workset.json",
    print_stats: bool = False,
) -> dict[str, Any]:
    """Compute workset from git diff, file list, or all files."""
    # Validate inputs
    if sum([bool(all_files), bool(diff_spec), bool(file_list)]) > 1:
        raise ValueError("Cannot specify multiple input modes (--all, --diff, --files)")
    if not all_files and not diff_spec and not file_list:
        raise ValueError("Must specify either --all, --diff, or --files")

    # Load manifest
    try:
        manifest_mapping = load_manifest(manifest_path)
        manifest_paths = set(manifest_mapping.keys())
    except FileNotFoundError:
        # Check if the user is in the wrong directory
        cwd = Path.cwd()
        helpful_msg = f"Manifest not found at {manifest_path}. Run 'aud index' first."
        if cwd.name in ["Desktop", "Documents", "Downloads"]:
            helpful_msg += f"\n\nAre you in the right directory? You're in: {cwd}"
            helpful_msg += "\nTry: cd <your-project-folder> then run this command again"
        raise RuntimeError(helpful_msg) from None

    # Connect to database
    if not Path(db_path).exists():
        raise RuntimeError(f"Database not found at {db_path}. Run 'aud index' first.")
    conn = sqlite3.connect(db_path)

    # Get seed files
    seed_files = set()
    seed_mode = None
    seed_value = None
    if all_files:
        seed_mode = "all"
        seed_value = "all_indexed_files"
        # Use all files from the manifest; no dependency expansion needed
        seed_files = manifest_paths.copy()
        max_depth = 0
    elif diff_spec:
        seed_mode = "diff"
        seed_value = diff_spec
        diff_files = get_git_diff_files(diff_spec, root_path)
        # Filter to files in the manifest
        seed_files = {f for f in diff_files if f in manifest_paths}
    else:
        seed_mode = "files"
        seed_value = ",".join(file_list)
        # Normalize and filter to the manifest
        seed_files = {normalize_path(f) for f in file_list if normalize_path(f) in manifest_paths}

    # Expand dependencies
    expanded_files = expand_dependencies(conn, seed_files, manifest_paths, max_depth)

    # Apply filters
    filtered_files = apply_glob_filters(
        expanded_files,
        include_patterns or [],
        exclude_patterns or [],
    )

    # Sort for deterministic output
    sorted_files = sorted(filtered_files)

    # Build output
    workset_data = {
        "generated_at": datetime.now(UTC).isoformat(),
        "root": root_path,
        "seed": {"mode": seed_mode, "value": seed_value},
        "max_depth": max_depth,
        "counts": {
            "seed_files": len(seed_files),
            "expanded_files": len(sorted_files),
        },
        "paths": [{"path": path, "sha256": manifest_mapping[path]} for path in sorted_files],
    }

    # Create output directory if needed
    output_dir = Path(output_path).parent
    output_dir.mkdir(parents=True, exist_ok=True)

    # Write output
    with open(output_path, "w") as f:
        json.dump(workset_data, f, indent=2)

    if print_stats:
        include_count = len(include_patterns) if include_patterns else 0
        exclude_count = len(exclude_patterns) if exclude_patterns else 0
        print(
            f"seed={len(seed_files)} expanded={len(sorted_files)} "
            f"depth={max_depth} include={include_count} exclude={exclude_count}"
        )

    conn.close()
    return {
        "success": True,
        "seed_count": len(seed_files),
        "expanded_count": len(sorted_files),
        "output_path": output_path,
    }
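

# A minimal usage sketch (hypothetical invocation; assumes 'aud index' has
# already produced repo_index.db and manifest.json in the working directory):
#
#   result = compute_workset(
#       diff_spec="main..HEAD",
#       exclude_patterns=["*test*"],
#       print_stats=True,
#   )
#   # result -> {"success": True, "seed_count": ..., "expanded_count": ...,
#   #            "output_path": "./.pf/workset.json"}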