- Made the source argument mutually exclusive: you must specify either --repo or --dir (a sketch of the CLI side follows this list).
- Added a new crawl_local_files() function that mimics the interface of crawl_github_files().
- Modified the FetchRepo node to handle both cases.
- The project name is now derived from either:
  - the repository name (from the GitHub URL),
  - the directory name (from the local path),
  - or it can be specified manually with -n/--name.

The tool applies the same file pattern matching and size limits to both sources. All other functionality (generating abstractions, relationships, chapters, etc.) remains unchanged, since those steps work with the same file list format.
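A minimal sketch of how the mutually exclusive source flags could be wired up on the CLI side. The actual argument parsing lives in main.py (the other changed file, not shown in this diff), so the parser setup below is an assumption; only the flag names --repo, --dir and -n/--name come from this commit:

    import argparse

    # Hypothetical CLI wiring; only the flag names are taken from the commit message.
    parser = argparse.ArgumentParser()
    source = parser.add_mutually_exclusive_group(required=True)
    source.add_argument("--repo", help="GitHub repository URL to crawl")
    source.add_argument("--dir", help="Path to a local directory to crawl")
    parser.add_argument("-n", "--name", help="Project name (derived from the repo URL or directory name if omitted)")
    args = parser.parse_args()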
Author: SpeedOfSpin
Date: 2025-04-08 18:27:32 +00:00
Parent: 60b5467f68
Commit: b6ab52aaa1
2 changed files with 114 additions and 25 deletions

nodes.py (107 changed lines)

@@ -1,9 +1,73 @@
 import os
 import yaml
+import fnmatch
 from pocketflow import Node, BatchNode
 from utils.crawl_github_files import crawl_github_files
 from utils.call_llm import call_llm # Assuming you have this utility
 
+def crawl_local_files(directory, include_patterns=None, exclude_patterns=None, max_file_size=None, use_relative_paths=True):
+    """
+    Crawl files in a local directory with similar interface as crawl_github_files.
+
+    Args:
+        directory (str): Path to local directory
+        include_patterns (set): File patterns to include (e.g. {"*.py", "*.js"})
+        exclude_patterns (set): File patterns to exclude (e.g. {"tests/*"})
+        max_file_size (int): Maximum file size in bytes
+        use_relative_paths (bool): Whether to use paths relative to directory
+
+    Returns:
+        dict: {"files": {filepath: content}}
+    """
+    if not os.path.isdir(directory):
+        raise ValueError(f"Directory does not exist: {directory}")
+
+    files_dict = {}
+    for root, _, files in os.walk(directory):
+        for filename in files:
+            filepath = os.path.join(root, filename)
+
+            # Get path relative to directory if requested
+            if use_relative_paths:
+                relpath = os.path.relpath(filepath, directory)
+            else:
+                relpath = filepath
+
+            # Check if file matches any include pattern
+            included = False
+            if include_patterns:
+                for pattern in include_patterns:
+                    if fnmatch.fnmatch(relpath, pattern):
+                        included = True
+                        break
+            else:
+                included = True
+
+            # Check if file matches any exclude pattern
+            excluded = False
+            if exclude_patterns:
+                for pattern in exclude_patterns:
+                    if fnmatch.fnmatch(relpath, pattern):
+                        excluded = True
+                        break
+
+            if not included or excluded:
+                continue
+
+            # Check file size
+            if max_file_size and os.path.getsize(filepath) > max_file_size:
+                continue
+
+            try:
+                with open(filepath, 'r', encoding='utf-8') as f:
+                    content = f.read()
+                files_dict[relpath] = content
+            except Exception as e:
+                print(f"Warning: Could not read file {filepath}: {e}")
+
+    return {"files": files_dict}
+
 # Helper to create context from files, respecting limits (basic example)
 def create_llm_context(files_data):
     context = ""
@@ -26,20 +90,26 @@ def get_content_for_indices(files_data, indices):
 class FetchRepo(Node):
     def prep(self, shared):
-        repo_url = shared["repo_url"]
+        repo_url = shared.get("repo_url")
+        local_dir = shared.get("local_dir")
         project_name = shared.get("project_name")
         if not project_name:
-            # Basic name derivation from URL
-            project_name = repo_url.split('/')[-1].replace('.git', '')
+            # Basic name derivation from URL or directory
+            if repo_url:
+                project_name = repo_url.split('/')[-1].replace('.git', '')
+            else:
+                project_name = os.path.basename(os.path.abspath(local_dir))
             shared["project_name"] = project_name
-        # Get file patterns directly from shared (defaults are defined in main.py)
+        # Get file patterns directly from shared
         include_patterns = shared["include_patterns"]
         exclude_patterns = shared["exclude_patterns"]
         max_file_size = shared["max_file_size"]
         return {
             "repo_url": repo_url,
+            "local_dir": local_dir,
             "token": shared.get("github_token"),
             "include_patterns": include_patterns,
             "exclude_patterns": exclude_patterns,
@@ -48,15 +118,26 @@ class FetchRepo(Node):
         }
 
     def exec(self, prep_res):
-        print(f"Crawling repository: {prep_res['repo_url']}...")
-        result = crawl_github_files(
-            repo_url=prep_res["repo_url"],
-            token=prep_res["token"],
-            include_patterns=prep_res["include_patterns"],
-            exclude_patterns=prep_res["exclude_patterns"],
-            max_file_size=prep_res["max_file_size"],
-            use_relative_paths=prep_res["use_relative_paths"]
-        )
+        if prep_res["repo_url"]:
+            print(f"Crawling repository: {prep_res['repo_url']}...")
+            result = crawl_github_files(
+                repo_url=prep_res["repo_url"],
+                token=prep_res["token"],
+                include_patterns=prep_res["include_patterns"],
+                exclude_patterns=prep_res["exclude_patterns"],
+                max_file_size=prep_res["max_file_size"],
+                use_relative_paths=prep_res["use_relative_paths"]
+            )
+        else:
+            print(f"Crawling directory: {prep_res['local_dir']}...")
+            result = crawl_local_files(
+                directory=prep_res["local_dir"],
+                include_patterns=prep_res["include_patterns"],
+                exclude_patterns=prep_res["exclude_patterns"],
+                max_file_size=prep_res["max_file_size"],
+                use_relative_paths=prep_res["use_relative_paths"]
+            )
 
         # Convert dict to list of tuples: [(path, content), ...]
         files_list = list(result.get("files", {}).items())
         print(f"Fetched {len(files_list)} files.")