# Default file patterns
DEFAULT_INCLUDE_PATTERNS = {
    # Source code
    "*.py", "*.pyi", "*.pyx",
    "*.js", "*.jsx", "*.ts", "*.tsx",
    "*.go", "*.java",
    "*.c", "*.cc", "*.cpp", "*.h",
    # Documentation
    "*.md", "*.rst",
    # Build / config files
    "Dockerfile", "Makefile", "*.yaml", "*.yml",
}

DEFAULT_EXCLUDE_PATTERNS = {
    # Tests, docs and sample code
    "*test*", "tests/*", "docs/*", "examples/*",
    # Old, generated or experimental trees
    "v1/*", "dist/*", "build/*", "experimental/*", "deprecated/*", "legacy/*",
    # Tooling and dependency directories
    ".git/*", ".github/*", ".next/*", ".vscode/*",
    "obj/*", "bin/*", "node_modules/*",
    # Noise
    "*.log",
}
def crawl_local_files(directory, include_patterns=None, exclude_patterns=None, max_file_size=None, use_relative_paths=True):
    """
    Crawl files in a local directory with similar interface as crawl_github_files.

    Args:
        directory (str): Path to local directory
        include_patterns (set): File patterns to include (e.g. {"*.py", "*.js"}).
            None or empty means "include everything".
        exclude_patterns (set): File patterns to exclude (e.g. {"tests/*"})
        max_file_size (int): Maximum file size in bytes (None means no limit)
        use_relative_paths (bool): Whether to use paths relative to directory

    Returns:
        dict: {"files": {filepath: content}}

    Raises:
        ValueError: If directory does not exist or is not a directory.
    """
    if not os.path.isdir(directory):
        raise ValueError(f"Directory does not exist: {directory}")

    files_dict = {}

    for root, dirs, files in os.walk(directory):
        # Prune excluded directories in place so os.walk never descends into
        # them (e.g. node_modules/, .git/). This is purely a performance fix:
        # fnmatch's '*' also matches '/', so every file under a pruned
        # "name/*" directory would have been excluded below anyway.
        if exclude_patterns:
            dirs[:] = [
                d for d in dirs
                if not any(
                    fnmatch.fnmatch(
                        os.path.relpath(os.path.join(root, d), directory) + "/",
                        pattern,
                    )
                    for pattern in exclude_patterns
                )
            ]

        for filename in files:
            filepath = os.path.join(root, filename)

            # Get path relative to directory if requested
            if use_relative_paths:
                relpath = os.path.relpath(filepath, directory)
            else:
                relpath = filepath

            # Check if file matches any include pattern
            # (empty/None include set means "include everything").
            included = (not include_patterns) or any(
                fnmatch.fnmatch(relpath, pattern) for pattern in include_patterns
            )

            # Check if file matches any exclude pattern
            excluded = bool(exclude_patterns) and any(
                fnmatch.fnmatch(relpath, pattern) for pattern in exclude_patterns
            )

            if not included or excluded:
                continue

            # Check file size. getsize can raise OSError (e.g. broken
            # symlink); warn and skip instead of aborting the whole crawl.
            if max_file_size:
                try:
                    if os.path.getsize(filepath) > max_file_size:
                        continue
                except OSError as e:
                    print(f"Warning: Could not stat file {filepath}: {e}")
                    continue

            try:
                with open(filepath, 'r', encoding='utf-8') as f:
                    files_dict[relpath] = f.read()
            except Exception as e:
                # Binary or undecodable files are reported and skipped, not fatal.
                print(f"Warning: Could not read file {filepath}: {e}")

    return {"files": files_dict}
project_name = os.path.basename(os.path.abspath(local_dir)) shared["project_name"] = project_name - # Get file patterns directly from shared (defaults are defined in main.py) + # Get file patterns directly from shared include_patterns = shared["include_patterns"] exclude_patterns = shared["exclude_patterns"] max_file_size = shared["max_file_size"] return { "repo_url": repo_url, + "local_dir": local_dir, "token": shared.get("github_token"), "include_patterns": include_patterns, "exclude_patterns": exclude_patterns, @@ -48,15 +118,26 @@ class FetchRepo(Node): } def exec(self, prep_res): - print(f"Crawling repository: {prep_res['repo_url']}...") - result = crawl_github_files( - repo_url=prep_res["repo_url"], - token=prep_res["token"], - include_patterns=prep_res["include_patterns"], - exclude_patterns=prep_res["exclude_patterns"], - max_file_size=prep_res["max_file_size"], - use_relative_paths=prep_res["use_relative_paths"] - ) + if prep_res["repo_url"]: + print(f"Crawling repository: {prep_res['repo_url']}...") + result = crawl_github_files( + repo_url=prep_res["repo_url"], + token=prep_res["token"], + include_patterns=prep_res["include_patterns"], + exclude_patterns=prep_res["exclude_patterns"], + max_file_size=prep_res["max_file_size"], + use_relative_paths=prep_res["use_relative_paths"] + ) + else: + print(f"Crawling directory: {prep_res['local_dir']}...") + result = crawl_local_files( + directory=prep_res["local_dir"], + include_patterns=prep_res["include_patterns"], + exclude_patterns=prep_res["exclude_patterns"], + max_file_size=prep_res["max_file_size"], + use_relative_paths=prep_res["use_relative_paths"] + ) + # Convert dict to list of tuples: [(path, content), ...] files_list = list(result.get("files", {}).items()) print(f"Fetched {len(files_list)} files.")