#!/usr/bin/env python3
"""Command line tool to create, read, split and load gossip snapshots.

A snapshot file starts with the 4 byte header ``GSP\\x01`` (tag ``GSP``
followed by the version byte) and is followed by a sequence of messages,
each prefixed with its varint-encoded length.
"""
import io
import json
import logging
import os
import re
import shlex
import socket
import struct
import subprocess
from collections import namedtuple
from contextlib import contextmanager
from datetime import datetime, timedelta

import click
from pyln.proto import wire
from pyln.proto.primitives import varint_encode, varint_decode
from sqlalchemy import create_engine
from sqlalchemy import func
from sqlalchemy.orm import load_only
from sqlalchemy.orm import sessionmaker
from tqdm import tqdm

from cli import common
from common import Base, ChannelAnnouncement, ChannelUpdate, NodeAnnouncement
from cli.backup import backup
from cli.db import db
from common import db_session, default_db, stream_snapshot_since


@click.group()
def cli():
    pass


cli.add_command(backup)
cli.add_command(db)


@cli.group()
def snapshot():
    pass


dt_fmt = "%Y-%m-%d %H:%M:%S"
# NOTE: evaluated once at import time; used as the default for `since`.
default_since = datetime.utcnow() - timedelta(hours=1)


def _read_header(stream):
    """Consume and validate the 4 byte snapshot header from `stream`.

    Raises:
        ValueError: if the header is truncated, the tag is not ``GSP`` or
            the version is not 1.
    """
    header = stream.read(4)
    if len(header) < 4:
        raise ValueError("Could not read header")

    tag, version = header[0:3], header[3]
    if tag != b"GSP":
        raise ValueError(f"Header mismatch, expected GSP, got {repr(tag)}")

    if version != 1:
        raise ValueError(
            f"Unsupported version {version}, only support up to version 1"
        )


def split_gossip(reader: io.BytesIO):
    """Yield the length-prefixed messages contained in `reader`.

    Iteration stops when `varint_decode` returns None.

    Raises:
        ValueError: if fewer bytes than announced by the length prefix
            could be read.
    """
    while True:
        length = varint_decode(reader)
        if length is None:
            break

        msg = reader.read(length)
        if len(msg) != length:
            raise ValueError("Incomplete read at end of file")

        yield msg


@snapshot.command()
@click.argument("destination", type=click.File("wb"))
@click.argument(
    "since",
    type=click.DateTime(formats=[dt_fmt]),
    default=default_since.strftime(dt_fmt),
)
@click.option("--db", type=str, default=default_db)
def incremental(since, destination, db):
    """Write a snapshot of all messages streamed since SINCE."""
    # Write the header
    destination.write(b"GSP\x01")

    node_count, chan_count = 0, 0
    for msg in stream_snapshot_since(since, db):
        varint_encode(len(msg), destination)
        destination.write(msg)

        # The first two bytes are the big-endian message type:
        # 0x0101 (257) counts as a node, 0x0100 (256) as a channel.
        typ = msg[:2]
        if typ == b"\x01\x01":
            node_count += 1
        elif typ == b"\x01\x00":
            chan_count += 1

    click.echo(
        f"Wrote {chan_count} channels and {node_count} nodes to {destination.name}",
        err=True,
    )


@snapshot.command()
@click.argument("destination", type=click.File("wb"))
@click.pass_context
@click.option("--db", type=str, default=default_db)
def full(ctx, destination, db):
    """Write a snapshot covering the last two weeks."""
    since = datetime.utcnow() - timedelta(weeks=2)
    ctx.invoke(incremental, since=since, destination=destination, db=db)


@snapshot.command()
@click.argument("snapshot", type=click.File("rb"))
def read(snapshot):
    """Print every message in SNAPSHOT as a hex string, one per line."""
    _read_header(snapshot)
    for msg in split_gossip(snapshot):
        print(msg.hex())


@snapshot.command()
@click.argument("snapshot", type=common.GossipFile(decode=False))
@click.argument("max_bytes", type=int)
@click.option("-x", "--exec", type=str)
def split(snapshot, max_bytes, exec):
    """Split SNAPSHOT into numbered files of roughly MAX_BYTES each.

    If --exec is given it is run for every completed file, with "{}"
    replaced by the (shell-quoted) name of that file.
    """

    def bundle(f: common.GossipFile):
        """Group the messages of `f` into tuples that must stay together."""
        current = None
        for m in f:
            (typ,) = struct.unpack_from("!H", m)

            if typ == 257:
                # NodeAnnouncements are always self-contained, so yield them
                # individually
                yield (m,)
            elif typ == 256:
                # ChannelAnnouncements indicate the start of a new bundle
                if current is not None:
                    yield tuple(current)
                current = [m]
            elif current is None:
                # An update seen before any ChannelAnnouncement has no
                # bundle to join; emit it on its own instead of crashing.
                yield (m,)
            else:
                # ChannelUpdates belong to the bundle
                current.append(m)

        # If we have an unyielded bundle we need to flush it at the end.
        # There is none if the file contained no ChannelAnnouncement.
        if current is not None:
            yield tuple(current)

    def serialize_bundle(b):
        """Return the length-prefixed wire encoding of all messages in `b`."""
        buff = io.BytesIO()
        for m in b:
            varint_encode(len(m), buff)
            buff.write(m)
        return buff.getvalue()

    prefix, extension = os.path.splitext(snapshot.filename)

    def chunk_name(filenum):
        """Return the output filename for chunk number `filenum`."""
        return f"{prefix}_{filenum:04d}{extension}"

    def open_chunk(filenum):
        """Create chunk `filenum` and write the snapshot header to it."""
        chunk = open(chunk_name(filenum), "wb")
        chunk.write(b"GSP\x01")
        return chunk

    def on_complete(filenum):
        """Run the user supplied command, if any, on a finished chunk."""
        fname = chunk_name(filenum)
        if exec is not None:
            cmd = shlex.split(exec.replace("{}", shlex.quote(fname)))
            logging.debug("Exec:\n> {}".format(" ".join(cmd)))
            subprocess.run(cmd)

    filenum = 0
    f = open_chunk(filenum)
    try:
        for b in bundle(snapshot):
            assert len(b) <= 3
            m = serialize_bundle(b)

            # Rotate to a new file rather than exceeding max_bytes; a bundle
            # is never split across two files.
            if f.tell() + len(m) > max_bytes:
                f.close()
                on_complete(filenum)
                filenum += 1
                f = open_chunk(filenum)
            f.write(m)
    finally:
        # Closing an already closed file is a no-op, so this is safe even
        # if the failure happened between close() and open_chunk().
        f.close()
    on_complete(filenum)


LightningAddress = namedtuple("LightningAddress", ["node_id", "host", "port"])


class LightningAddressParam(click.ParamType):
    """Click parameter type parsing ``node_id@host[:port]`` strings."""

    name = "lightning_address"

    # The host is matched lazily because it may itself contain ':'; the
    # trailing ":port" (or a bare trailing ":") is optional.
    _pattern = re.compile(
        r"(0[23][a-fA-F0-9]+)@([a-zA-Z0-9\.:]+?)(?::([0-9]+)?)?"
    )

    def convert(self, value, param, ctx):
        """Return a `LightningAddress`, using port 9735 if none is given."""
        if isinstance(value, LightningAddress):
            return value

        m = self._pattern.fullmatch(value)
        if m is None:
            # self.fail raises, so nothing after it in this branch runs.
            self.fail(
                f"{value} isn't a valid lightning connection string, "
                'expected "[node_id]@[host]:[port]"'
            )

        port = 9735 if m[3] is None else int(m[3])
        return LightningAddress(m[1], m[2], port)


class LightningPeer:
    """A connection to a lightning node that messages can be sent to."""

    def __init__(self, node_id: str, address: str, port: int = 9735):
        self.node_id = node_id
        self.address = address
        self.port = port
        self.connection = None
        # A fresh random key is generated for every peer instance.
        self.local_privkey = wire.PrivateKey(os.urandom(32))

    def connect(self):
        """Open the socket, perform the handshake and send `init`."""
        sock = socket.create_connection((self.address, self.port), timeout=30)
        try:
            self.connection = wire.LightningConnection(
                sock,
                remote_pubkey=wire.PublicKey(bytes.fromhex(self.node_id)),
                local_privkey=self.local_privkey,
                is_initiator=True,
            )
            self.connection.shake()

            # Send an init message, with no global features, and 0b10101010
            # as local features.
            self.connection.send_message(b"\x00\x10\x00\x00\x00\x01\xaa")
        except BaseException:
            # Don't leak the socket if the handshake or init fails.
            self.connection = None
            sock.close()
            raise

    def send(self, packet: bytes) -> None:
        """Send a single message; raises ValueError if not connected."""
        if self.connection is None:
            raise ValueError("Not connected to peer")

        logging.debug("Sending {}".format(packet.hex()))
        self.connection.send_message(packet)

    def send_all(self, packets) -> None:
        """Send every message in `packets`, in order.

        Raises:
            ValueError: if the peer is not connected.
        """
        if self.connection is None:
            raise ValueError("Not connected to peer")
        for p in packets:
            self.send(p)

    def disconnect(self):
        """Close the underlying connection; a no-op if not connected."""
        if self.connection is None:
            return
        self.connection.connection.close()
        self.connection = None


@snapshot.command()
@click.argument("snapshot", type=click.File("rb"))
@click.argument("destination", type=LightningAddressParam(), required=False)
def load(snapshot, destination=None):
    """Stream the messages in SNAPSHOT to the node at DESTINATION.

    Without DESTINATION the local node is looked up via
    `lightning-cli getinfo`, using its first ipv4 binding.
    """
    if destination is None:
        logging.debug("No destination specified, attempting auto-discovery")
        info = json.loads(subprocess.check_output(["lightning-cli", "getinfo"]))
        node_id = info["id"]
        bindings = [
            (node_id, b["address"], b["port"])
            for b in info["binding"]
            if b["type"] == "ipv4"
        ]
        if len(bindings) < 1:
            raise ValueError(
                "Could not automatically determine the c-lightning"
                " address to connect to. Please provide a --destination"
            )
        binding = bindings[0]
        logging.debug("Discovered local node {}@{}:{}".format(*binding))
        destination = LightningAddress(*binding)

    _read_header(snapshot)

    logging.debug(f"Connecting to {destination}")
    peer = LightningPeer(destination.node_id, destination.host, destination.port)
    peer.connect()
    try:
        logging.debug("Connected, streaming messages from snapshot")
        peer.send_all(tqdm(split_gossip(snapshot)))
        logging.debug("Done streaming messages, disconnecting")
    finally:
        # Always release the connection, even if streaming failed.
        peer.disconnect()


if __name__ == "__main__":
    cli()