diff --git a/src/analysis/analyzer.py b/src/analysis/analyzer.py
index 1a5afd4..d1bb6ab 100644
--- a/src/analysis/analyzer.py
+++ b/src/analysis/analyzer.py
@@ -2,7 +2,7 @@
 import logging
 from typing import List, Dict, Any, Optional, Tuple
-from datetime import datetime
+from datetime import datetime, timedelta
 import numpy as np
 from rich.console import Console
 from rich.table import Table
@@ -18,9 +18,10 @@ console = Console()

 class ChannelMetrics:
     """Calculated metrics for a channel"""
-
-    def __init__(self, channel: Channel):
+
+    def __init__(self, channel: Channel, config: Optional[Config] = None):
         self.channel = channel
+        self.config = config
         self.calculate_metrics()

     def calculate_metrics(self):
@@ -58,49 +59,66 @@ class ChannelMetrics:
             self.roi = float('inf')

         # Performance scores
-        self.profitability_score = self._calculate_profitability_score()
-        self.activity_score = self._calculate_activity_score()
-        self.efficiency_score = self._calculate_efficiency_score()
+        self.profitability_score = self._calculate_profitability_score(self.config)
+        self.activity_score = self._calculate_activity_score(self.config)
+        self.efficiency_score = self._calculate_efficiency_score(self.config)
         self.flow_efficiency = self._calculate_flow_efficiency()
         self.overall_score = (self.profitability_score + self.activity_score + self.efficiency_score) / 3

-    def _calculate_profitability_score(self) -> float:
+    def _calculate_profitability_score(self, config: Optional[Config] = None) -> float:
         """Score based on net profit and ROI (0-100)"""
         if self.net_profit <= 0:
             return 0
-
-        # Normalize profit (assume 10k sats/month is excellent)
-        profit_score = min(100, (self.net_profit / 10000) * 100)
-
-        # ROI score (assume 200% ROI is excellent)
-        roi_score = min(100, (self.roi / 2.0) * 100) if self.roi != float('inf') else 100
-
+
+        # Get thresholds from config or use defaults
+        excellent_profit = 10000
+        excellent_roi = 2.0
+        if config:
+            excellent_profit = config.optimization.excellent_monthly_profit_sats
+            excellent_roi = config.optimization.excellent_roi_ratio
+
+        # Normalize profit
+        profit_score = min(100, (self.net_profit / excellent_profit) * 100)
+
+        # ROI score
+        roi_score = min(100, (self.roi / excellent_roi) * 100) if self.roi != float('inf') else 100
+
         return (profit_score + roi_score) / 2

-    def _calculate_activity_score(self) -> float:
+    def _calculate_activity_score(self, config: Optional[Config] = None) -> float:
         """Score based on flow volume and consistency (0-100)"""
         if self.monthly_flow == 0:
             return 0
-
-        # Normalize flow (assume 10M sats/month is excellent)
-        flow_score = min(100, (self.monthly_flow / 10_000_000) * 100)
-
+
+        # Get threshold from config or use default
+        excellent_flow = 10_000_000
+        if config:
+            excellent_flow = config.optimization.excellent_monthly_flow_sats
+
+        # Normalize flow
+        flow_score = min(100, (self.monthly_flow / excellent_flow) * 100)
+
         # Balance score (perfect balance = 100)
         balance_score = (1 - self.flow_imbalance) * 100
-
+
         return (flow_score + balance_score) / 2

-    def _calculate_efficiency_score(self) -> float:
+    def _calculate_efficiency_score(self, config: Optional[Config] = None) -> float:
         """Score based on earnings efficiency (0-100)"""
-        # Earnings per million sats routed (assume 1000 ppm is excellent)
-        efficiency = min(100, (self.earnings_per_million / 1000) * 100)
-
+        # Get threshold from config or use default
+        excellent_earnings_ppm = 1000
+        if config:
+            excellent_earnings_ppm = config.optimization.excellent_earnings_per_million_ppm
+
+        # Earnings per million sats routed
+        efficiency = min(100, (self.earnings_per_million / excellent_earnings_ppm) * 100)
+
         # Penalty for high rebalance costs
         if self.monthly_earnings > 0:
             cost_ratio = self.rebalance_costs / self.monthly_earnings
             cost_penalty = max(0, 1 - cost_ratio) * 100
             return (efficiency + cost_penalty) / 2
-
+
         return efficiency

     def _calculate_flow_efficiency(self) -> float:
@@ -114,39 +132,94 @@ class ChannelMetrics:

 class ChannelAnalyzer:
     """Analyze channel performance and prepare optimization data"""
-
-    def __init__(self, client: LndManageClient, config: Config):
+
+    def __init__(self, client: LndManageClient, config: Config, cache_ttl_seconds: int = 300):
         self.client = client
         self.config = config
+        self.cache_ttl_seconds = cache_ttl_seconds
+        self._metrics_cache: Dict[str, Tuple[ChannelMetrics, datetime]] = {}
+        self._last_cache_cleanup = datetime.utcnow()

-    async def analyze_channels(self, channel_ids: List[str]) -> Dict[str, ChannelMetrics]:
-        """Analyze all channels and return metrics"""
-        # Fetch all channel data
-        channel_data = await self.client.fetch_all_channel_data(channel_ids)
-
-        # Convert to Channel models and calculate metrics
+    async def analyze_channels(self, channel_ids: List[str], use_cache: bool = True) -> Dict[str, ChannelMetrics]:
+        """Analyze all channels and return metrics with optional caching"""
+        # Cleanup old cache entries periodically (every hour)
+        if (datetime.utcnow() - self._last_cache_cleanup).total_seconds() > 3600:
+            self._cleanup_cache()
+
         metrics = {}
-        for data in channel_data:
-            try:
-                # Add timestamp if not present
-                if 'timestamp' not in data:
-                    data['timestamp'] = datetime.utcnow().isoformat()
-
-                channel = Channel(**data)
-                channel_id = channel.channel_id_compact
-                metrics[channel_id] = ChannelMetrics(channel)
-
-                logger.debug(f"Analyzed channel {channel_id}: {metrics[channel_id].overall_score:.1f} score")
-
-            except Exception as e:
-                channel_id = data.get('channelIdCompact', data.get('channel_id', 'unknown'))
-                logger.error(f"Failed to analyze channel {channel_id}: {e}")
-                logger.debug(f"Channel data keys: {list(data.keys())}")
-
+        channels_to_fetch = []
+
+        # Check cache first if enabled
+        if use_cache:
+            cache_cutoff = datetime.utcnow() - timedelta(seconds=self.cache_ttl_seconds)
+            for channel_id in channel_ids:
+                if channel_id in self._metrics_cache:
+                    cached_metric, cache_time = self._metrics_cache[channel_id]
+                    if cache_time > cache_cutoff:
+                        metrics[channel_id] = cached_metric
+                        logger.debug(f"Using cached metrics for channel {channel_id}")
+                    else:
+                        channels_to_fetch.append(channel_id)
+                else:
+                    channels_to_fetch.append(channel_id)
+        else:
+            channels_to_fetch = channel_ids
+
+        # Fetch data only for channels not in cache or expired
+        if channels_to_fetch:
+            logger.info(f"Fetching fresh data for {len(channels_to_fetch)} channels "
+                        f"(using cache for {len(metrics)})")
+            channel_data = await self.client.fetch_all_channel_data(channels_to_fetch)
+
+            # Convert to Channel models and calculate metrics
+            for data in channel_data:
+                try:
+                    # Add timestamp if not present
+                    if 'timestamp' not in data:
+                        data['timestamp'] = datetime.utcnow().isoformat()
+
+                    channel = Channel(**data)
+                    channel_id = channel.channel_id_compact
+                    channel_metrics = ChannelMetrics(channel, self.config)
+                    metrics[channel_id] = channel_metrics
+
+                    # Update cache
+                    if use_cache:
+                        self._metrics_cache[channel_id] = (channel_metrics, datetime.utcnow())
+
+                    logger.debug(f"Analyzed channel {channel_id}: {metrics[channel_id].overall_score:.1f} score")
+
+                except Exception as e:
+                    channel_id = data.get('channelIdCompact', data.get('channel_id', 'unknown'))
+                    logger.error(f"Failed to analyze channel {channel_id}: {e}")
+                    logger.debug(f"Channel data keys: {list(data.keys())}")
+
         return metrics
+
+    def _cleanup_cache(self) -> None:
+        """Remove expired entries from the metrics cache"""
+        cache_cutoff = datetime.utcnow() - timedelta(seconds=self.cache_ttl_seconds * 2)
+        expired_keys = [
+            channel_id for channel_id, (_, cache_time) in self._metrics_cache.items()
+            if cache_time < cache_cutoff
+        ]
+
+        for channel_id in expired_keys:
+            del self._metrics_cache[channel_id]
+
+        if expired_keys:
+            logger.debug(f"Cleaned up {len(expired_keys)} expired cache entries")
+
+        self._last_cache_cleanup = datetime.utcnow()
+
+    def clear_cache(self) -> None:
+        """Manually clear the metrics cache"""
+        count = len(self._metrics_cache)
+        self._metrics_cache.clear()
+        logger.info(f"Cleared {count} entries from metrics cache")

     def categorize_channels(self, metrics: Dict[str, ChannelMetrics]) -> Dict[str, List[ChannelMetrics]]:
-        """Categorize channels by performance"""
+        """Categorize channels by performance using single-pass algorithm with configurable thresholds"""
         categories = {
             'high_performers': [],
             'profitable': [],
@@ -154,19 +227,26 @@ class ChannelAnalyzer:
             'inactive': [],
             'problematic': []
         }
-
+
+        # Get thresholds from config
+        high_score = self.config.optimization.high_performance_score
+        min_profit = self.config.optimization.min_profitable_sats
+        min_flow = self.config.optimization.min_active_flow_sats
+
+        # Single pass through all metrics with optimized conditional logic
         for channel_metrics in metrics.values():
-            if channel_metrics.overall_score >= 70:
+            # Use elif chain for mutually exclusive categories (only one category per channel)
+            if channel_metrics.overall_score >= high_score:
                 categories['high_performers'].append(channel_metrics)
-            elif channel_metrics.net_profit > 100:  # 100 sats profit
+            elif channel_metrics.net_profit > min_profit:
                 categories['profitable'].append(channel_metrics)
-            elif channel_metrics.monthly_flow > 1_000_000:  # 1M sats flow
+            elif channel_metrics.monthly_flow > min_flow:
                 categories['active_unprofitable'].append(channel_metrics)
             elif channel_metrics.monthly_flow == 0:
                 categories['inactive'].append(channel_metrics)
             else:
                 categories['problematic'].append(channel_metrics)
-
+
         return categories

     def print_analysis(self, metrics: Dict[str, ChannelMetrics]):
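
For reference, the TTL cache added to ChannelAnalyzer.analyze_channels() above can be exercised as in the following minimal sketch. The import paths, the bare Config() construction and the channel id are illustrative assumptions, not taken from this diff.

import asyncio

from src.analysis.analyzer import ChannelAnalyzer
from src.api.client import LndManageClient
from src.utils.config import Config

async def main():
    config = Config()  # assumption: adapt to however Config is actually constructed
    async with LndManageClient("http://localhost:18081", max_concurrent=10) as client:
        analyzer = ChannelAnalyzer(client, config, cache_ttl_seconds=300)
        channel_ids = ["799999x1234x0"]  # illustrative channel id
        first = await analyzer.analyze_channels(channel_ids)                   # hits the API
        second = await analyzer.analyze_channels(channel_ids)                  # served from cache while < 300 s old
        fresh = await analyzer.analyze_channels(channel_ids, use_cache=False)  # forces a refetch
        analyzer.clear_cache()                                                 # manual invalidation

asyncio.run(main())
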
diff --git a/src/api/client.py b/src/api/client.py
index 11ab121..bc43463 100644
--- a/src/api/client.py
+++ b/src/api/client.py
@@ -11,13 +11,18 @@ logger = logging.getLogger(__name__)

 class LndManageClient:
     """Client for interacting with LND Manage API"""
-
-    def __init__(self, base_url: str = "http://localhost:18081"):
+
+    def __init__(self, base_url: str = "http://localhost:18081", max_concurrent: int = 10):
         self.base_url = base_url.rstrip('/')
         self.client: Optional[httpx.AsyncClient] = None
+        self.max_concurrent = max_concurrent
+        self._semaphore: Optional[asyncio.Semaphore] = None

     async def __aenter__(self):
-        self.client = httpx.AsyncClient(timeout=30.0)
+        # Use connection pooling with limits
+        limits = httpx.Limits(max_connections=50, max_keepalive_connections=20)
+        self.client = httpx.AsyncClient(timeout=30.0, limits=limits)
+        self._semaphore = asyncio.Semaphore(self.max_concurrent)
         return self

     async def __aexit__(self, exc_type, exc_val, exc_tb):
@@ -138,7 +143,7 @@ class LndManageClient:
         return await self._get(f"/api/node/{pubkey}/warnings")

     async def fetch_all_channel_data(self, channel_ids: Optional[List[str]] = None) -> List[Dict[str, Any]]:
-        """Fetch comprehensive data for all channels using the /details endpoint"""
+        """Fetch comprehensive data for all channels using the /details endpoint with concurrency limiting"""
         if channel_ids is None:
             # Get channel IDs from the API response
             response = await self.get_open_channels()
@@ -146,16 +151,16 @@
             channel_ids = response['channels']
         else:
             channel_ids = response if isinstance(response, list) else []
-
-        logger.info(f"Fetching data for {len(channel_ids)} channels")
-
-        # Fetch data for all channels concurrently
+
+        logger.info(f"Fetching data for {len(channel_ids)} channels (max {self.max_concurrent} concurrent)")
+
+        # Fetch data for all channels concurrently with semaphore limiting
         tasks = []
         for channel_id in channel_ids:
-            tasks.append(self._fetch_single_channel_data(channel_id))
-
+            tasks.append(self._fetch_single_channel_data_limited(channel_id))
+
         results = await asyncio.gather(*tasks, return_exceptions=True)
-
+
         # Filter out failed requests
         channel_data = []
         for i, result in enumerate(results):
@@ -163,8 +168,18 @@
                 logger.error(f"Failed to fetch data for channel {channel_ids[i]}: {result}")
             else:
                 channel_data.append(result)
-
+
+        logger.info(f"Successfully fetched data for {len(channel_data)}/{len(channel_ids)} channels")
         return channel_data
+
+    async def _fetch_single_channel_data_limited(self, channel_id: str) -> Dict[str, Any]:
+        """Fetch channel data with semaphore limiting to prevent overwhelming the API"""
+        if self._semaphore is None:
+            # Fallback if semaphore not initialized (shouldn't happen in normal use)
+            return await self._fetch_single_channel_data(channel_id)
+
+        async with self._semaphore:
+            return await self._fetch_single_channel_data(channel_id)

     async def _fetch_single_channel_data(self, channel_id: str) -> Dict[str, Any]:
         """Fetch all data for a single channel using the details endpoint"""
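
The concurrency cap in fetch_all_channel_data() is the usual semaphore-plus-gather pattern; the stand-alone sketch below shows it in isolation. All names here (URLS, fetch_one, fetch_all) are made up for the example and are not part of this client.

import asyncio
import httpx

URLS = [f"https://example.invalid/item/{i}" for i in range(100)]

async def fetch_all(max_concurrent: int = 10) -> list:
    semaphore = asyncio.Semaphore(max_concurrent)
    async with httpx.AsyncClient(timeout=30.0) as client:

        async def fetch_one(url: str):
            # At most max_concurrent coroutines get past this point at once;
            # the rest wait here while gather() keeps them all scheduled.
            async with semaphore:
                response = await client.get(url)
                response.raise_for_status()
                return response.json()

        results = await asyncio.gather(*(fetch_one(u) for u in URLS), return_exceptions=True)

    # With return_exceptions=True a single failure does not cancel the batch;
    # failed requests come back as exception objects and are filtered out here.
    return [r for r in results if not isinstance(r, Exception)]
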
diff --git a/src/data_fetcher.py b/src/data_fetcher.py
index 91d0e0e..5e87e99 100644
--- a/src/data_fetcher.py
+++ b/src/data_fetcher.py
@@ -1,4 +1,5 @@
-import requests
+import httpx
+import asyncio
 import json
 from typing import Dict, List, Optional, Any
 from dataclasses import dataclass
@@ -22,19 +23,39 @@ class ChannelData:
     warnings: List[str]

 class LightningDataFetcher:
-    def __init__(self, base_url: str = "http://localhost:18081/api"):
+    """Async Lightning Network data fetcher using httpx for non-blocking I/O"""
+
+    def __init__(self, base_url: str = "http://localhost:18081/api", max_concurrent: int = 10):
         self.base_url = base_url
-        self.session = requests.Session()
-
-    def _get(self, endpoint: str) -> Optional[Any]:
-        """Make GET request to API endpoint"""
+        self.max_concurrent = max_concurrent
+        self.client: Optional[httpx.AsyncClient] = None
+        self._semaphore: Optional[asyncio.Semaphore] = None
+
+    async def __aenter__(self):
+        """Async context manager entry"""
+        limits = httpx.Limits(max_connections=50, max_keepalive_connections=20)
+        self.client = httpx.AsyncClient(timeout=10.0, limits=limits)
+        self._semaphore = asyncio.Semaphore(self.max_concurrent)
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        """Async context manager exit"""
+        if self.client:
+            await self.client.aclose()
+
+    async def _get(self, endpoint: str) -> Optional[Any]:
+        """Make async GET request to API endpoint"""
+        if not self.client:
+            raise RuntimeError("Client not initialized. Use async with statement.")
+
         try:
             url = f"{self.base_url}{endpoint}"
-            response = self.session.get(url, timeout=10)
+            response = await self.client.get(url)
             if response.status_code == 200:
-                try:
+                content_type = response.headers.get('content-type', '')
+                if 'application/json' in content_type:
                     return response.json()
-                except json.JSONDecodeError:
+                else:
                     return response.text.strip()
             else:
                 logger.warning(f"Failed to fetch {endpoint}: {response.status_code}")
@@ -43,124 +64,168 @@ class LightningDataFetcher:
             logger.error(f"Error fetching {endpoint}: {e}")
             return None

-    def check_sync_status(self) -> bool:
+    async def check_sync_status(self) -> bool:
         """Check if lnd is synced to chain"""
-        result = self._get("/status/synced-to-chain")
+        result = await self._get("/status/synced-to-chain")
         return result == "true" if result else False
-
-    def get_block_height(self) -> Optional[int]:
+
+    async def get_block_height(self) -> Optional[int]:
         """Get current block height"""
-        result = self._get("/status/block-height")
+        result = await self._get("/status/block-height")
         return int(result) if result else None
-
-    def get_open_channels(self) -> List[str]:
+
+    async def get_open_channels(self) -> List[str]:
         """Get list of all open channel IDs"""
-        result = self._get("/status/open-channels")
+        result = await self._get("/status/open-channels")
         return result if isinstance(result, list) else []
-
-    def get_all_channels(self) -> List[str]:
+
+    async def get_all_channels(self) -> List[str]:
         """Get list of all channel IDs (open, closed, etc)"""
-        result = self._get("/status/all-channels")
+        result = await self._get("/status/all-channels")
         return result if isinstance(result, list) else []

-    def get_channel_details(self, channel_id: str) -> ChannelData:
-        """Fetch comprehensive data for a specific channel"""
+    async def get_channel_details(self, channel_id: str) -> ChannelData:
+        """Fetch comprehensive data for a specific channel using concurrent requests"""
         logger.info(f"Fetching data for channel {channel_id}")
-
-        basic_info = self._get(f"/channel/{channel_id}/") or {}
-        balance = self._get(f"/channel/{channel_id}/balance") or {}
-        policies = self._get(f"/channel/{channel_id}/policies") or {}
-        fee_report = self._get(f"/channel/{channel_id}/fee-report") or {}
-        flow_report = self._get(f"/channel/{channel_id}/flow-report") or {}
-        flow_report_7d = self._get(f"/channel/{channel_id}/flow-report/last-days/7") or {}
-        flow_report_30d = self._get(f"/channel/{channel_id}/flow-report/last-days/30") or {}
-        rating = self._get(f"/channel/{channel_id}/rating")
-        warnings = self._get(f"/channel/{channel_id}/warnings") or []
-
-        # Fetch rebalance data
-        rebalance_data = {
-            "source_costs": self._get(f"/channel/{channel_id}/rebalance-source-costs") or 0,
-            "source_amount": self._get(f"/channel/{channel_id}/rebalance-source-amount") or 0,
-            "target_costs": self._get(f"/channel/{channel_id}/rebalance-target-costs") or 0,
-            "target_amount": self._get(f"/channel/{channel_id}/rebalance-target-amount") or 0,
-            "support_as_source": self._get(f"/channel/{channel_id}/rebalance-support-as-source-amount") or 0,
-            "support_as_target": self._get(f"/channel/{channel_id}/rebalance-support-as-target-amount") or 0
+
+        # Fetch all data concurrently for better performance
+        tasks = {
+            'basic_info': self._get(f"/channel/{channel_id}/"),
+            'balance': self._get(f"/channel/{channel_id}/balance"),
+            'policies': self._get(f"/channel/{channel_id}/policies"),
+            'fee_report': self._get(f"/channel/{channel_id}/fee-report"),
+            'flow_report': self._get(f"/channel/{channel_id}/flow-report"),
+            'flow_report_7d': self._get(f"/channel/{channel_id}/flow-report/last-days/7"),
+            'flow_report_30d': self._get(f"/channel/{channel_id}/flow-report/last-days/30"),
+            'rating': self._get(f"/channel/{channel_id}/rating"),
+            'warnings': self._get(f"/channel/{channel_id}/warnings"),
+            'rebalance_source_costs': self._get(f"/channel/{channel_id}/rebalance-source-costs"),
+            'rebalance_source_amount': self._get(f"/channel/{channel_id}/rebalance-source-amount"),
+            'rebalance_target_costs': self._get(f"/channel/{channel_id}/rebalance-target-costs"),
+            'rebalance_target_amount': self._get(f"/channel/{channel_id}/rebalance-target-amount"),
+            'rebalance_support_source': self._get(f"/channel/{channel_id}/rebalance-support-as-source-amount"),
+            'rebalance_support_target': self._get(f"/channel/{channel_id}/rebalance-support-as-target-amount"),
         }
-
+
+        # Execute all requests concurrently
+        results = await asyncio.gather(*tasks.values(), return_exceptions=True)
+        data = dict(zip(tasks.keys(), results))
+
+        # Build rebalance data
+        rebalance_data = {
+            "source_costs": data.get('rebalance_source_costs') or 0,
+            "source_amount": data.get('rebalance_source_amount') or 0,
+            "target_costs": data.get('rebalance_target_costs') or 0,
+            "target_amount": data.get('rebalance_target_amount') or 0,
+            "support_as_source": data.get('rebalance_support_source') or 0,
+            "support_as_target": data.get('rebalance_support_target') or 0
+        }
+
         return ChannelData(
             channel_id=channel_id,
-            basic_info=basic_info,
-            balance=balance,
-            policies=policies,
-            fee_report=fee_report,
-            flow_report=flow_report,
-            flow_report_7d=flow_report_7d,
-            flow_report_30d=flow_report_30d,
-            rating=float(rating) if rating else None,
+            basic_info=data.get('basic_info') or {},
+            balance=data.get('balance') or {},
+            policies=data.get('policies') or {},
+            fee_report=data.get('fee_report') or {},
+            flow_report=data.get('flow_report') or {},
+            flow_report_7d=data.get('flow_report_7d') or {},
+            flow_report_30d=data.get('flow_report_30d') or {},
+            rating=float(data['rating']) if data.get('rating') else None,
             rebalance_data=rebalance_data,
-            warnings=warnings if isinstance(warnings, list) else []
+            warnings=data.get('warnings') if isinstance(data.get('warnings'), list) else []
         )

-    def get_node_data(self, pubkey: str) -> Dict[str, Any]:
-        """Fetch comprehensive data for a specific node"""
+    async def get_node_data(self, pubkey: str) -> Dict[str, Any]:
+        """Fetch comprehensive data for a specific node using concurrent requests"""
         logger.info(f"Fetching data for node {pubkey[:10]}...")
-
+
+        # Fetch all node data concurrently
+        tasks = {
+            "alias": self._get(f"/node/{pubkey}/alias"),
+            "open_channels": self._get(f"/node/{pubkey}/open-channels"),
+            "all_channels": self._get(f"/node/{pubkey}/all-channels"),
+            "balance": self._get(f"/node/{pubkey}/balance"),
+            "fee_report": self._get(f"/node/{pubkey}/fee-report"),
+            "fee_report_7d": self._get(f"/node/{pubkey}/fee-report/last-days/7"),
+            "fee_report_30d": self._get(f"/node/{pubkey}/fee-report/last-days/30"),
+            "flow_report": self._get(f"/node/{pubkey}/flow-report"),
+            "flow_report_7d": self._get(f"/node/{pubkey}/flow-report/last-days/7"),
+            "flow_report_30d": self._get(f"/node/{pubkey}/flow-report/last-days/30"),
+            "on_chain_costs": self._get(f"/node/{pubkey}/on-chain-costs"),
+            "rating": self._get(f"/node/{pubkey}/rating"),
+            "warnings": self._get(f"/node/{pubkey}/warnings")
+        }
+
+        # Execute all requests concurrently
+        results = await asyncio.gather(*tasks.values(), return_exceptions=True)
+        data = dict(zip(tasks.keys(), results))
+
         return {
             "pubkey": pubkey,
-            "alias": self._get(f"/node/{pubkey}/alias"),
-            "open_channels": self._get(f"/node/{pubkey}/open-channels") or [],
-            "all_channels": self._get(f"/node/{pubkey}/all-channels") or [],
-            "balance": self._get(f"/node/{pubkey}/balance") or {},
-            "fee_report": self._get(f"/node/{pubkey}/fee-report") or {},
-            "fee_report_7d": self._get(f"/node/{pubkey}/fee-report/last-days/7") or {},
-            "fee_report_30d": self._get(f"/node/{pubkey}/fee-report/last-days/30") or {},
-            "flow_report": self._get(f"/node/{pubkey}/flow-report") or {},
-            "flow_report_7d": self._get(f"/node/{pubkey}/flow-report/last-days/7") or {},
-            "flow_report_30d": self._get(f"/node/{pubkey}/flow-report/last-days/30") or {},
-            "on_chain_costs": self._get(f"/node/{pubkey}/on-chain-costs") or {},
-            "rating": self._get(f"/node/{pubkey}/rating"),
-            "warnings": self._get(f"/node/{pubkey}/warnings") or []
+            "alias": data.get('alias'),
+            "open_channels": data.get('open_channels') or [],
+            "all_channels": data.get('all_channels') or [],
+            "balance": data.get('balance') or {},
+            "fee_report": data.get('fee_report') or {},
+            "fee_report_7d": data.get('fee_report_7d') or {},
+            "fee_report_30d": data.get('fee_report_30d') or {},
+            "flow_report": data.get('flow_report') or {},
+            "flow_report_7d": data.get('flow_report_7d') or {},
+            "flow_report_30d": data.get('flow_report_30d') or {},
+            "on_chain_costs": data.get('on_chain_costs') or {},
+            "rating": data.get('rating'),
+            "warnings": data.get('warnings') or []
         }

-    def fetch_all_data(self) -> Dict[str, Any]:
-        """Fetch all channel and node data"""
+    async def fetch_all_data(self) -> Dict[str, Any]:
+        """Fetch all channel and node data with concurrency limiting"""
         logger.info("Starting comprehensive data fetch...")
-
+
         # Check sync status
-        if not self.check_sync_status():
+        if not await self.check_sync_status():
             logger.warning("Node is not synced to chain!")
-
+
         # Get basic info
-        block_height = self.get_block_height()
-        open_channels = self.get_open_channels()
-        all_channels = self.get_all_channels()
-
+        block_height = await self.get_block_height()
+        open_channels = await self.get_open_channels()
+        all_channels = await self.get_all_channels()
+
         logger.info(f"Block height: {block_height}")
         logger.info(f"Open channels: {len(open_channels)}")
         logger.info(f"Total channels: {len(all_channels)}")
-
-        # Fetch detailed channel data
-        channels_data = {}
-        for channel_id in open_channels:
-            try:
-                channels_data[channel_id] = self.get_channel_details(channel_id)
-            except Exception as e:
-                logger.error(f"Error fetching channel {channel_id}: {e}")
-
+
+        # Fetch detailed channel data with semaphore limiting
+        async def fetch_channel_limited(channel_id: str):
+            async with self._semaphore:
+                try:
+                    return channel_id, await self.get_channel_details(channel_id)
+                except Exception as e:
+                    logger.error(f"Error fetching channel {channel_id}: {e}")
+                    return channel_id, None
+
+        channel_tasks = [fetch_channel_limited(cid) for cid in open_channels]
+        channel_results = await asyncio.gather(*channel_tasks)
+        channels_data = {cid: data for cid, data in channel_results if data is not None}
+
         # Get unique node pubkeys from channel data
         node_pubkeys = set()
         for channel_data in channels_data.values():
             if 'remotePubkey' in channel_data.basic_info:
                 node_pubkeys.add(channel_data.basic_info['remotePubkey'])
-
-        # Fetch node data
-        nodes_data = {}
-        for pubkey in node_pubkeys:
-            try:
-                nodes_data[pubkey] = self.get_node_data(pubkey)
-            except Exception as e:
-                logger.error(f"Error fetching node {pubkey[:10]}...: {e}")
-
+
+        # Fetch node data with semaphore limiting
+        async def fetch_node_limited(pubkey: str):
+            async with self._semaphore:
+                try:
+                    return pubkey, await self.get_node_data(pubkey)
+                except Exception as e:
+                    logger.error(f"Error fetching node {pubkey[:10]}...: {e}")
+                    return pubkey, None
+
+        node_tasks = [fetch_node_limited(pubkey) for pubkey in node_pubkeys]
+        node_results = await asyncio.gather(*node_tasks)
+        nodes_data = {pubkey: data for pubkey, data in node_results if data is not None}
+
         return {
             "block_height": block_height,
             "open_channels": open_channels,
@@ -176,6 +241,9 @@ class LightningDataFetcher:
         logger.info(f"Data saved to {filename}")

 if __name__ == "__main__":
-    fetcher = LightningDataFetcher()
-    all_data = fetcher.fetch_all_data()
-    fetcher.save_data(all_data, "lightning-fee-optimizer/data/lightning_data.json")
\ No newline at end of file
+    async def main():
+        async with LightningDataFetcher() as fetcher:
+            all_data = await fetcher.fetch_all_data()
+            fetcher.save_data(all_data, "lightning_data.json")
+
+    asyncio.run(main())
\ No newline at end of file
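
One property of the gather/zip step in get_channel_details() and get_node_data() is worth keeping in mind: with return_exceptions=True a failed coroutine comes back as its exception object instead of raising. In this module _get() already catches request errors and returns None, so the flag is mostly a last-resort guard, but the mechanism itself looks like this (all names made up, illustrative only):

import asyncio

async def ok():
    return {"alias": "node-a"}

async def boom():
    raise ValueError("endpoint unreachable")

async def demo():
    tasks = {"alias": ok(), "balance": boom()}
    results = await asyncio.gather(*tasks.values(), return_exceptions=True)
    raw = dict(zip(tasks.keys(), results))
    # raw == {'alias': {'alias': 'node-a'}, 'balance': ValueError('endpoint unreachable')}
    # Normalising exceptions to None mirrors what _get() does for request errors.
    return {k: (None if isinstance(v, BaseException) else v) for k, v in raw.items()}

print(asyncio.run(demo()))
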
diff --git a/src/policy/manager.py b/src/policy/manager.py
index 5c69fd2..7bc2936 100644
--- a/src/policy/manager.py
+++ b/src/policy/manager.py
@@ -17,16 +17,18 @@ logger = logging.getLogger(__name__)

 class PolicyManager:
     """Manages policy-based fee optimization with inbound fee support"""
-
-    def __init__(self,
+
+    def __init__(self,
                  config_file: str,
                  lnd_manage_url: str,
                  lnd_rest_url: str = "https://localhost:8080",
                  lnd_grpc_host: str = "localhost:10009",
                  lnd_dir: str = "~/.lnd",
                  database_path: str = "experiment_data/policy.db",
-                 prefer_grpc: bool = True):
-
+                 prefer_grpc: bool = True,
+                 max_history_entries: int = 1000,
+                 history_ttl_hours: int = 168):  # 7 days default
+
         self.policy_engine = PolicyEngine(config_file)
         self.lnd_manage_url = lnd_manage_url
         self.lnd_rest_url = lnd_rest_url
@@ -34,13 +36,16 @@ class PolicyManager:
         self.lnd_dir = lnd_dir
         self.prefer_grpc = prefer_grpc
         self.db = ExperimentDatabase(database_path)
-
-        # Policy-specific tracking
+
+        # Policy-specific tracking with memory management
         self.policy_session_id = None
         self.last_fee_changes: Dict[str, Dict] = {}
         self.rollback_candidates: Dict[str, datetime] = {}
-
+        self.max_history_entries = max_history_entries
+        self.history_ttl_hours = history_ttl_hours
+
         logger.info(f"Policy manager initialized with {len(self.policy_engine.rules)} rules")
+        logger.info(f"Memory management: max {max_history_entries} entries, TTL {history_ttl_hours}h")

     async def start_policy_session(self, session_name: str = None) -> int:
         """Start a new policy management session"""
@@ -239,8 +244,56 @@ class PolicyManager:
             if len(results['errors']) > 5:
                 logger.warning(f"  ... and {len(results['errors']) - 5} more errors")

+        # Cleanup old entries to prevent memory growth
+        self._cleanup_old_entries()
+
         return results
-
+
+    def _cleanup_old_entries(self) -> None:
+        """Clean up old entries from tracking dictionaries to prevent unbounded memory growth"""
+        cutoff_time = datetime.utcnow() - timedelta(hours=self.history_ttl_hours)
+        initial_count = len(self.last_fee_changes)
+
+        # Remove entries older than TTL
+        expired_channels = []
+        for channel_id, change_info in self.last_fee_changes.items():
+            if change_info['timestamp'] < cutoff_time:
+                expired_channels.append(channel_id)
+
+        for channel_id in expired_channels:
+            del self.last_fee_changes[channel_id]
+
+        # If still over limit, remove oldest entries
+        if len(self.last_fee_changes) > self.max_history_entries:
+            # Sort by timestamp and keep only the most recent max_history_entries
+            sorted_changes = sorted(
+                self.last_fee_changes.items(),
+                key=lambda x: x[1]['timestamp'],
+                reverse=True
+            )
+            self.last_fee_changes = dict(sorted_changes[:self.max_history_entries])
+
+        # Cleanup rollback_candidates with similar logic
+        expired_candidates = [
+            cid for cid, ts in self.rollback_candidates.items()
+            if ts < cutoff_time
+        ]
+        for channel_id in expired_candidates:
+            del self.rollback_candidates[channel_id]
+
+        if len(self.rollback_candidates) > self.max_history_entries:
+            sorted_candidates = sorted(
+                self.rollback_candidates.items(),
+                key=lambda x: x[1],
+                reverse=True
+            )
+            self.rollback_candidates = dict(sorted_candidates[:self.max_history_entries])
+
+        cleaned_count = initial_count - len(self.last_fee_changes)
+        if cleaned_count > 0:
+            logger.info(f"Cleaned up {cleaned_count} old entries from memory "
+                        f"({len(self.last_fee_changes)} remaining)")
+
     async def _enrich_channel_data(self, channel_info: Dict[str, Any], lnd_manage: LndManageClient) -> Dict[str, Any]:
         """Enrich channel data with additional metrics for policy matching"""
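
The eviction scheme in _cleanup_old_entries() — drop entries older than a TTL, then cap the dictionary at the most recent N — is easier to see stripped of the surrounding class. A stand-alone sketch (function and argument names are illustrative; entries are assumed to carry a 'timestamp' value, as last_fee_changes does):

from datetime import datetime, timedelta

def prune(history: dict, ttl_hours: int = 168, max_entries: int = 1000) -> dict:
    cutoff = datetime.utcnow() - timedelta(hours=ttl_hours)
    # TTL pass: keep only entries newer than the cutoff
    kept = {k: v for k, v in history.items() if v["timestamp"] >= cutoff}
    # Size pass: if still too large, keep only the max_entries most recent entries
    if len(kept) > max_entries:
        newest = sorted(kept.items(), key=lambda kv: kv[1]["timestamp"], reverse=True)
        kept = dict(newest[:max_entries])
    return kept
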
diff --git a/src/utils/config.py b/src/utils/config.py
index 516c0df..0019c4a 100644
--- a/src/utils/config.py
+++ b/src/utils/config.py
@@ -14,23 +14,38 @@ class OptimizationConfig:
     # Fee rate limits (ppm)
     min_fee_rate: int = 1
     max_fee_rate: int = 5000
-
+
     # Flow thresholds (sats)
     high_flow_threshold: int = 10_000_000
    low_flow_threshold: int = 1_000_000
-
+
     # Balance thresholds (ratio)
     high_balance_threshold: float = 0.8
     low_balance_threshold: float = 0.2
-
+
     # Strategy parameters
     fee_increase_factor: float = 1.5
     flow_preservation_weight: float = 0.6
-
+
     # Minimum changes to recommend
     min_fee_change_ppm: int = 5
     min_earnings_improvement: float = 100  # sats

+    # Performance metric thresholds for scoring
+    excellent_monthly_profit_sats: int = 10_000  # 10k sats/month
+    excellent_monthly_flow_sats: int = 10_000_000  # 10M sats/month
+    excellent_earnings_per_million_ppm: int = 1000  # 1000 ppm
+    excellent_roi_ratio: float = 2.0  # 200% ROI
+
+    # Channel categorization thresholds
+    high_performance_score: float = 70.0
+    min_profitable_sats: int = 100
+    min_active_flow_sats: int = 1_000_000
+
+    # Capacity tier thresholds (sats)
+    high_capacity_threshold: int = 5_000_000
+    medium_capacity_threshold: int = 1_000_000
+

 @dataclass
 class APIConfig:
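
The new OptimizationConfig fields replace constants that were hard-coded in analyzer.py. As a worked example of the profitability formula they feed (_calculate_profitability_score above), with made-up inputs:

excellent_monthly_profit_sats = 10_000
excellent_roi_ratio = 2.0

net_profit = 2_500  # example: net monthly profit in sats
roi = 0.5           # example ROI ratio

profit_score = min(100, (net_profit / excellent_monthly_profit_sats) * 100)  # 25.0
roi_score = min(100, (roi / excellent_roi_ratio) * 100)                      # 25.0
profitability_score = (profit_score + roi_score) / 2                         # 25.0
print(profitability_score)
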
diff --git a/src/utils/database.py b/src/utils/database.py
index 9889d89..b59c716 100644
--- a/src/utils/database.py
+++ b/src/utils/database.py
@@ -140,13 +140,21 @@
                 )
             """)

-            # Create useful indexes
+            # Create useful indexes for performance optimization
             conn.execute("CREATE INDEX IF NOT EXISTS idx_data_points_channel_time ON data_points(channel_id, timestamp)")
             conn.execute("CREATE INDEX IF NOT EXISTS idx_data_points_parameter_set ON data_points(parameter_set, timestamp)")
             conn.execute("CREATE INDEX IF NOT EXISTS idx_fee_changes_channel_time ON fee_changes(channel_id, timestamp)")
-
+
+            # Additional indexes for improved query performance
+            conn.execute("CREATE INDEX IF NOT EXISTS idx_data_points_experiment_id ON data_points(experiment_id)")
+            conn.execute("CREATE INDEX IF NOT EXISTS idx_data_points_experiment_channel ON data_points(experiment_id, channel_id)")
+            conn.execute("CREATE INDEX IF NOT EXISTS idx_data_points_phase ON data_points(phase, timestamp)")
+            conn.execute("CREATE INDEX IF NOT EXISTS idx_channels_experiment ON channels(experiment_id)")
+            conn.execute("CREATE INDEX IF NOT EXISTS idx_channels_segment ON channels(segment)")
+            conn.execute("CREATE INDEX IF NOT EXISTS idx_fee_changes_experiment ON fee_changes(experiment_id)")
+
             conn.commit()
-            logger.info("Database initialized successfully")
+            logger.info("Database initialized successfully with optimized indexes")

     @contextmanager
     def _get_connection(self):
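
The new composite index idx_data_points_experiment_channel targets lookups that filter on experiment_id and channel_id together. A hypothetical query of that shape is shown below; the column names come from the index definitions above, while the database path, the values and the rest of the schema are assumptions.

import sqlite3

conn = sqlite3.connect("experiment_data/policy.db")  # default path used by PolicyManager
rows = conn.execute(
    """
    SELECT channel_id, timestamp
    FROM data_points
    WHERE experiment_id = ? AND channel_id = ?
    """,
    (1, "799999x1234x0"),
).fetchall()
# EXPLAIN QUERY PLAN on this statement should report a SEARCH on data_points
# using idx_data_points_experiment_channel rather than a full table scan.
conn.close()
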