Merge pull request #4 from mathd/github_ssh_repo

support downloading GitHub repo with ssh
This commit is contained in:
Zach
2025-04-07 21:38:09 -04:00
committed by GitHub
2 changed files with 101 additions and 7 deletions

View File

@@ -1,6 +1,7 @@
pocketflow>=0.0.1 pocketflow>=0.0.1
pyyaml>=6.0 pyyaml>=6.0
requests>=2.28.0 requests>=2.28.0
gitpython>=3.1.0
google-cloud-aiplatform>=1.25.0 google-cloud-aiplatform>=1.25.0
google-genai>=1.9.0 google-genai>=1.9.0
python-dotenv>=1.0.0 python-dotenv>=1.0.0

View File

@@ -1,6 +1,8 @@
import requests import requests
import base64 import base64
import os import os
import tempfile
import git
import time import time
import fnmatch import fnmatch
from typing import Union, Set, List, Dict, Tuple, Any from typing import Union, Set, List, Dict, Tuple, Any
@@ -16,18 +18,21 @@ def crawl_github_files(
): ):
""" """
Crawl files from a specific path in a GitHub repository at a specific commit. Crawl files from a specific path in a GitHub repository at a specific commit.
Args: Args:
repo_url (str): URL of the GitHub repository with specific path and commit repo_url (str): URL of the GitHub repository with specific path and commit
(e.g., 'https://github.com/microsoft/autogen/tree/e45a15766746d95f8cfaaa705b0371267bec812e/python/packages/autogen-core/src/autogen_core') (e.g., 'https://github.com/microsoft/autogen/tree/e45a15766746d95f8cfaaa705b0371267bec812e/python/packages/autogen-core/src/autogen_core')
token (str, optional): GitHub personal access token. Required for private repositories and recommended for public repos to avoid rate limits. token (str, optional): **GitHub personal access token.**
- **Required for private repositories.**
- **Recommended for public repos to avoid rate limits.**
- Can be passed explicitly or set via the `GITHUB_TOKEN` environment variable.
max_file_size (int, optional): Maximum file size in bytes to download (default: 1 MB) max_file_size (int, optional): Maximum file size in bytes to download (default: 1 MB)
use_relative_paths (bool, optional): If True, file paths will be relative to the specified subdirectory use_relative_paths (bool, optional): If True, file paths will be relative to the specified subdirectory
include_patterns (str or set of str, optional): Pattern or set of patterns specifying which files to include (e.g., "*.py", {"*.md", "*.txt"}). include_patterns (str or set of str, optional): Pattern or set of patterns specifying which files to include (e.g., "*.py", {"*.md", "*.txt"}).
If None, all files are included. If None, all files are included.
exclude_patterns (str or set of str, optional): Pattern or set of patterns specifying which files to exclude. exclude_patterns (str or set of str, optional): Pattern or set of patterns specifying which files to exclude.
If None, no files are excluded. If None, no files are excluded.
Returns: Returns:
dict: Dictionary with files and statistics dict: Dictionary with files and statistics
""" """
@@ -36,7 +41,89 @@ def crawl_github_files(
include_patterns = {include_patterns} include_patterns = {include_patterns}
if exclude_patterns and isinstance(exclude_patterns, str): if exclude_patterns and isinstance(exclude_patterns, str):
exclude_patterns = {exclude_patterns} exclude_patterns = {exclude_patterns}
def should_include_file(file_path: str, file_name: str) -> bool:
    """Decide whether a file passes the include and exclude filters.

    The filters are the ``include_patterns`` and ``exclude_patterns``
    values of the enclosing scope; both are treated as collections of
    ``fnmatch``-style glob patterns.

    Args:
        file_path: Path of the file; exclude patterns are tested against it.
        file_name: Bare name of the file; include patterns are tested
            against it.

    Returns:
        True when the file is accepted: it matches at least one include
        pattern (or no include patterns are set) and its path matches
        none of the exclude patterns. False otherwise.
    """
    # An include filter, when present, is applied to the bare file name.
    if include_patterns:
        name_accepted = False
        for glob in include_patterns:
            if fnmatch.fnmatch(file_name, glob):
                name_accepted = True
                break
        if not name_accepted:
            # Rejected by the include filter; exclusions are irrelevant.
            return False

    # Without an exclude filter, everything that got this far is kept.
    if not exclude_patterns:
        return True

    # The exclude filter is applied to the path, not the bare name.
    for glob in exclude_patterns:
        if fnmatch.fnmatch(file_path, glob):
            return False
    return True
# Detect SSH URL (git@ or .git suffix)
is_ssh_url = repo_url.startswith("git@") or repo_url.endswith(".git")
if is_ssh_url:
# Clone repo via SSH to temp dir
with tempfile.TemporaryDirectory() as tmpdirname:
print(f"Cloning SSH repo {repo_url} to temp dir {tmpdirname} ...")
try:
repo = git.Repo.clone_from(repo_url, tmpdirname)
except Exception as e:
print(f"Error cloning repo: {e}")
return {"files": {}, "stats": {"error": str(e)}}
# No specific commit or branch is checked out here: SSH clone URLs do not
# embed a ref or subdirectory, so the files are read from whatever the
# clone checks out by default.
# A future version of the API could accept an explicit ref argument.
# Walk directory
files = {}
skipped_files = []
for root, dirs, filenames in os.walk(tmpdirname):
for filename in filenames:
abs_path = os.path.join(root, filename)
rel_path = os.path.relpath(abs_path, tmpdirname)
# Check file size
try:
file_size = os.path.getsize(abs_path)
except OSError:
continue
if file_size > max_file_size:
skipped_files.append((rel_path, file_size))
print(f"Skipping {rel_path}: size {file_size} exceeds limit {max_file_size}")
continue
# Check include/exclude patterns
if not should_include_file(rel_path, filename):
print(f"Skipping {rel_path}: does not match include/exclude patterns")
continue
# Read content
try:
with open(abs_path, "r", encoding="utf-8") as f:
content = f.read()
files[rel_path] = content
print(f"Added {rel_path} ({file_size} bytes)")
except Exception as e:
print(f"Failed to read {rel_path}: {e}")
return {
"files": files,
"stats": {
"downloaded_count": len(files),
"skipped_count": len(skipped_files),
"skipped_files": skipped_files,
"base_path": None,
"include_patterns": include_patterns,
"exclude_patterns": exclude_patterns,
"source": "ssh_clone"
}
}
# Parse GitHub URL to extract owner, repo, commit/branch, and path # Parse GitHub URL to extract owner, repo, commit/branch, and path
parsed_url = urlparse(repo_url) parsed_url = urlparse(repo_url)
path_parts = parsed_url.path.strip('/').split('/') path_parts = parsed_url.path.strip('/').split('/')
@@ -101,9 +188,11 @@ def crawl_github_files(
if response.status_code == 404: if response.status_code == 404:
if not token: if not token:
print(f"Error 404: Repository not found or is private. If this is a private repository, you need to provide a token.") print(f"Error 404: Repository not found or is private.\n"
f"If this is a private repository, please provide a valid GitHub token via the 'token' argument or set the GITHUB_TOKEN environment variable.")
else: else:
print(f"Error 404: Path '{path}' not found in repository or insufficient permissions.") print(f"Error 404: Path '{path}' not found in repository or insufficient permissions with the provided token.\n"
f"Please verify the token has access to this repository and the path exists.")
return return
if response.status_code != 200: if response.status_code != 200:
@@ -201,8 +290,12 @@ def crawl_github_files(
# Example usage # Example usage
if __name__ == "__main__": if __name__ == "__main__":
# Get token from environment variable (more secure than hardcoding) # Get token from environment variable (recommended for private repos)
github_token = os.environ.get("GITHUB_TOKEN") github_token = os.environ.get("GITHUB_TOKEN")
if not github_token:
print("Warning: No GitHub token found in environment variable 'GITHUB_TOKEN'.\n"
"Private repositories will not be accessible without a token.\n"
"To access private repos, set the environment variable or pass the token explicitly.")
repo_url = "https://github.com/pydantic/pydantic/tree/6c38dc93f40a47f4d1350adca9ec0d72502e223f/pydantic" repo_url = "https://github.com/pydantic/pydantic/tree/6c38dc93f40a47f4d1350adca9ec0d72502e223f/pydantic"