#!/usr/bin/env python3
"""Export Lightning gossip from the historian database into snapshot files,
and read or replay those snapshots.

A snapshot is a 4-byte header (``b"GSP"`` followed by a one-byte format
version) and then a sequence of raw gossip messages, each prefixed with its
varint-encoded length.
"""
import io
import logging
import os
import re
import socket
from collections import namedtuple
from contextlib import contextmanager
from datetime import datetime, timedelta
from typing import BinaryIO, Iterator

import click
from pyln.proto import wire
from pyln.proto.primitives import varint_decode, varint_encode
from sqlalchemy import create_engine, func, text
from sqlalchemy.orm import load_only, sessionmaker
from tqdm import tqdm

from common import Base, ChannelAnnouncement, ChannelUpdate, NodeAnnouncement

default_db = "sqlite:///$HOME/.lightning/bitcoin/historian.sqlite3"

# Snapshot header written by ``incremental``: magic b"GSP" + version byte 1.
SNAPSHOT_HEADER = b"GSP\x01"


@contextmanager
def db_session(dsn):
    """Yield a SQLAlchemy session for ``dsn``; commit on success, roll back on error.

    ``dsn`` defaults to ``default_db`` when None and has environment variables
    (e.g. ``$HOME``) expanded. The schema from ``Base.metadata`` is created if
    missing. The session is always closed and the engine disposed on exit.
    """
    if dsn is None:
        dsn = default_db
    dsn = os.path.expandvars(dsn)
    engine = create_engine(dsn, echo=False)
    try:
        Base.metadata.create_all(engine)
        session = sessionmaker(bind=engine)()
        try:
            yield session
            session.commit()
        except BaseException:
            # Roll back on any failure (including KeyboardInterrupt), then re-raise.
            session.rollback()
            raise
        finally:
            session.close()
    finally:
        # Release pooled connections held by this short-lived engine.
        engine.dispose()


@click.group()
def cli():
    pass


@cli.group()
def snapshot():
    pass


dt_fmt = "%Y-%m-%d %H:%M:%S"
# Default lower bound for ``snapshot incremental``; evaluated once at import.
default_since = datetime.utcnow() - timedelta(hours=1)

# Every channel that had at least one channel_update at or after :since, with
# the newest update for each direction. Nested SELECTs are used because a JOIN
# was too restrictive: the SELECT in the WHERE clause picks all scids updated
# in the time range, and the correlated sub-SELECTs fetch the latest update
# per direction.
_CHANNELS_SINCE_SQL = text("""
    SELECT
      a.scid,
      a.raw,
      (
        SELECT u.raw
        FROM channel_updates u
        WHERE u.scid = a.scid
          AND direction = 0
        ORDER BY timestamp DESC
        LIMIT 1
      ) as u0,
      (
        SELECT u.raw
        FROM channel_updates u
        WHERE u.scid = a.scid
          AND direction = 1
        ORDER BY timestamp DESC
        LIMIT 1
      ) as u1
    FROM channel_announcements a
    WHERE a.scid IN (
      SELECT u.scid
      FROM channel_updates u
      WHERE u.timestamp >= DATETIME(:since)
      GROUP BY u.scid
    )
    ORDER BY a.scid
""")

# The latest node_announcement per node seen at or after :since.
_NODES_SINCE_SQL = text("""
    SELECT
      n.node_id,
      n.timestamp,
      n.raw
    FROM node_announcements n
    WHERE n.timestamp >= DATETIME(:since)
    GROUP BY n.node_id
    HAVING n.timestamp = MAX(n.timestamp)
    ORDER BY timestamp DESC
""")


def _write_record(destination, payload: bytes) -> None:
    """Write ``payload`` to ``destination`` prefixed with its varint length."""
    varint_encode(len(payload), destination)
    destination.write(payload)


@snapshot.command()
@click.argument('destination', type=click.File('wb'))
@click.argument(
    'since',
    type=click.DateTime(formats=[dt_fmt]),
    default=default_since.strftime(dt_fmt)
)
@click.option('--db', type=str, default=default_db)
def incremental(since, destination, db):
    """Write channels and nodes updated since SINCE to DESTINATION."""
    # Bound parameter instead of string formatting into the SQL text.
    params = {"since": since.strftime(dt_fmt)}
    with db_session(db) as session:
        channel_rows = session.execute(_CHANNELS_SINCE_SQL, params)

        # Only write the header once the channel query has succeeded. A
        # result with no rows still yields a valid, empty snapshot.
        destination.write(SNAPSHOT_HEADER)

        chan_count = 0
        last_scid = None
        for scid, cann, u0, u1 in channel_rows:
            # Skip any repeated row for the same channel.
            if scid == last_scid:
                continue
            last_scid = scid
            chan_count += 1
            _write_record(destination, cann)
            for update in (u0, u1):
                if update is not None:
                    _write_record(destination, update)

        # Node announcements come after the channels, because a node without
        # a channel_announcement and channel_update is not allowed.
        node_rows = session.execute(_NODES_SINCE_SQL, params)
        last_nid = None
        node_count = 0
        for nid, _ts, nann in node_rows:
            if nid == last_nid:
                continue
            last_nid = nid
            node_count += 1
            _write_record(destination, nann)

        click.echo(
            f'Wrote {chan_count} channels and {node_count} nodes to {destination.name}',
            err=True
        )


@snapshot.command()
@click.argument('destination', type=click.File('wb'))
@click.pass_context
@click.option('--db', type=str, default=default_db)
def full(ctx, destination, db):
    """Write a snapshot covering the last two weeks to DESTINATION."""
    since = datetime.utcnow() - timedelta(weeks=2)
    ctx.invoke(incremental, since=since, destination=destination, db=db)


def _read_snapshot_header(stream) -> int:
    """Consume and validate the 4-byte snapshot header; return its version.

    Raises:
        ValueError: if the header is truncated, has the wrong magic, or has an
            unsupported version.
    """
    header = stream.read(4)
    if len(header) < 4:
        raise ValueError("Could not read header")
    tag, version = header[0:3], header[3]
    if tag != b'GSP':
        raise ValueError(f"Header mismatch, expected GSP, got {repr(tag)}")
    if version != 1:
        raise ValueError(f"Unsupported version {version}, only support up to version 1")
    return version


def split_gossip(reader: BinaryIO) -> Iterator[bytes]:
    """Yield each varint-length-prefixed message from ``reader`` until EOF.

    Raises:
        ValueError: if the final record is shorter than its declared length.
    """
    while True:
        length = varint_decode(reader)
        if length is None:
            break
        msg = reader.read(length)
        if len(msg) != length:
            raise ValueError("Incomplete read at end of file")
        yield msg


@snapshot.command()
@click.argument('snapshot', type=click.File('rb'))
def read(snapshot):
    """Print every message in SNAPSHOT as hex, one per line."""
    _read_snapshot_header(snapshot)
    for msg in split_gossip(snapshot):
        print(msg.hex())


LightningAddress = namedtuple('LightningAddress', ['node_id', 'host', 'port'])

# node_id: 33-byte compressed pubkey (66 hex chars, 02/03 prefix); host may
# contain letters, digits, '.', '-' and ':'; the ":port" suffix is optional.
# The host group is non-greedy so that a trailing ":<digits>" is taken as the port.
_LIGHTNING_ADDRESS_RE = re.compile(
    r"(0[23][a-fA-F0-9]{64})@([a-zA-Z0-9.:-]+?)(?::([0-9]+))?"
)


class LightningAddressParam(click.ParamType):
    """Click parameter type that parses ``node_id@host[:port]`` (port defaults to 9735)."""

    name = "lightning_address"

    def convert(self, value, param, ctx):
        # Values that are already converted (e.g. defaults) pass through.
        if isinstance(value, LightningAddress):
            return value
        m = _LIGHTNING_ADDRESS_RE.fullmatch(value)
        if m is None:
            # self.fail raises click.BadParameter and does not return.
            self.fail(
                f"{value} isn't a valid lightning connection string, "
                "expected \"[node_id]@[host]:[port]\"",
                param,
                ctx,
            )
        node_id, host, port = m.groups()
        return LightningAddress(node_id, host, 9735 if port is None else int(port))


class LightningPeer:
    """Minimal outbound connection to a Lightning node for streaming messages."""

    def __init__(self, node_id: str, address: str, port: int = 9735):
        self.node_id = node_id
        self.address = address
        self.port = port
        self.connection = None
        # A fresh random identity key for each peer instance.
        self.local_privkey = wire.PrivateKey(os.urandom(32))

    def connect(self):
        """Open a TCP connection, perform the handshake and send ``init``.

        The socket is closed again if any step after opening it fails.
        """
        sock = socket.create_connection((self.address, self.port), timeout=30)
        try:
            connection = wire.LightningConnection(
                sock,
                remote_pubkey=wire.PublicKey(bytes.fromhex(self.node_id)),
                local_privkey=self.local_privkey,
                is_initiator=True,
            )
            connection.shake()
            # Send an init message, with no global features, and 0b10101010
            # as local features.
            connection.send_message(b"\x00\x10\x00\x00\x00\x01\xaa")
        except BaseException:
            sock.close()
            raise
        self.connection = connection

    def send(self, packet: bytes) -> None:
        """Send one raw message; raises ValueError when not connected."""
        if self.connection is None:
            raise ValueError("Not connected to peer")
        # Avoid hex-encoding every packet when debug logging is disabled.
        if logging.getLogger().isEnabledFor(logging.DEBUG):
            logging.debug("Sending {}".format(packet.hex()))
        self.connection.send_message(packet)

    def send_all(self, packets) -> None:
        """Send every message from the iterable ``packets`` in order."""
        if self.connection is None:
            raise ValueError("Not connected to peer")
        for p in packets:
            self.send(p)

    def disconnect(self):
        """Close the underlying connection; a no-op if not connected."""
        if self.connection is None:
            return
        try:
            self.connection.connection.close()
        finally:
            self.connection = None


@snapshot.command()
@click.argument('snapshot', type=click.File('rb'))
@click.argument('destination', type=LightningAddressParam())
def load(snapshot, destination):
    """Stream every message in SNAPSHOT to the peer at DESTINATION."""
    _read_snapshot_header(snapshot)

    logging.debug(f"Connecting to {destination}")
    peer = LightningPeer(destination.node_id, destination.host, destination.port)
    peer.connect()
    try:
        logging.debug("Connected, streaming messages from snapshot")
        peer.send_all(tqdm(split_gossip(snapshot)))
        logging.debug("Done streaming messages, disconnecting")
    finally:
        # Always release the connection, even if streaming fails midway.
        peer.disconnect()


if __name__ == "__main__":
    cli()