From 3f5dc35ab1c1cb2b901c0fc4adc493f1dea1ba81 Mon Sep 17 00:00:00 2001
From: Christian Decker
Date: Tue, 20 Oct 2020 21:34:18 +0200
Subject: [PATCH] historian: Add backup command for lnresearch

---
 historian/cli/backup.py |  60 +++++++++++++++++++
 historian/cli/common.py |  91 ++++++++++++++++++++++++++++
 historian/cli/db.py     |  56 ++++++++++++++++++
 historian/historian-cli | 128 ++++++++++++++++++++++++++++++++--------
 4 files changed, 313 insertions(+), 22 deletions(-)
 create mode 100644 historian/cli/backup.py
 create mode 100644 historian/cli/common.py
 create mode 100644 historian/cli/db.py

diff --git a/historian/cli/backup.py b/historian/cli/backup.py
new file mode 100644
index 0000000..af19c1f
--- /dev/null
+++ b/historian/cli/backup.py
@@ -0,0 +1,60 @@
+import click
+from .common import db_session, split_gossip
+from pyln.proto.primitives import varint_encode
+
+@click.group()
+def backup():
+    pass
+
+
+@backup.command()
+@click.argument('destination', type=click.File('wb'))
+@click.option('--db', type=str, default=None)
+def create(destination, db):
+    """Dump all gossip messages from the database into a snapshot file."""
+    with db_session(db) as session:
+        rows = session.execute("SELECT raw FROM channel_announcements")
+
+        # Write the header now that we know we'll be writing something.
+        destination.write(b"GSP\x01")
+
+        for r in rows:
+            varint_encode(len(r[0]), destination)
+            destination.write(r[0])
+
+        rows = session.execute("SELECT raw FROM channel_updates ORDER BY timestamp ASC")
+        for r in rows:
+            varint_encode(len(r[0]), destination)
+            destination.write(r[0])
+
+        rows = session.execute("SELECT raw FROM node_announcements ORDER BY timestamp ASC")
+        for r in rows:
+            varint_encode(len(r[0]), destination)
+            destination.write(r[0])
+
+        destination.close()
+
+@backup.command()
+@click.argument("source", type=click.File('rb'))
+def read(source):
+    """Load gossip messages from the specified source and print them to stdout.
+
+    Prints the hex-encoded raw gossip messages to stdout.
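+
+    The snapshot format consists of a 4-byte header, the ASCII tag "GSP"
+    followed by a single version byte (currently 1), and then the raw
+    gossip messages, each prefixed with its length encoded as a varint.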
+ """ + header = source.read(4) + if len(header) < 4: + raise ValueError("Could not read header") + + tag, version = header[0:3], header[3] + if tag != b'GSP': + raise ValueError(f"Header mismatch, expected GSP, got {repr(tag)}") + + if version != 1: + raise ValueError(f"Unsupported version {version}, only support up to version 1") + + for m in split_gossip(source): + print(m.hex()) diff --git a/historian/cli/common.py b/historian/cli/common.py new file mode 100644 index 0000000..4e695ad --- /dev/null +++ b/historian/cli/common.py @@ -0,0 +1,90 @@ +from sqlalchemy.orm import sessionmaker +from sqlalchemy import create_engine +from contextlib import contextmanager +import os +from common import Base +import io +from pyln.proto.primitives import varint_decode +from parser import parse +import click +import bz2 + +default_db = "sqlite:///$HOME/.lightning/bitcoin/historian.sqlite3" + + +@contextmanager +def db_session(dsn): + """Tiny contextmanager to facilitate sqlalchemy session management""" + if dsn is None: + dsn = default_db + dsn = os.path.expandvars(dsn) + engine = create_engine(dsn, echo=False) + Base.metadata.create_all(engine) + session_maker = sessionmaker(bind=engine) + session = session_maker() + try: + yield session + session.commit() + except: + session.rollback() + raise + finally: + session.close() + + +def split_gossip(reader: io.BytesIO): + while True: + length = varint_decode(reader) + if length is None: + break + + msg = reader.read(length) + if len(msg) != length: + raise ValueError("Incomplete read at end of file") + + yield msg + + +class GossipStream: + def __init__(self, file_stream, filename, decode=True): + self.stream = file_stream + self.decode = decode + self.filename = filename + + # Read header + header = self.stream.read(4) + assert len(header) == 4 + assert header[:3] == b"GSP" + assert header[3] == 1 + + def __iter__(self): + return self + + def __next__(self): + pos = self.stream.tell() + length = varint_decode(self.stream) + + if length is None: + raise StopIteration + + msg = self.stream.read(length) + if len(msg) != length: + raise ValueError( + "Error reading snapshot at {pos}: incomplete read of {length} bytes, only got {lmsg} bytes".format( + pos=pos, length=length, lmsg=len(msg) + ) + ) + if not self.decode: + return msg + + return parse(msg) + + +class GossipFile(click.File): + def __init__(self, decode=True): + click.File.__init__(self) + self.decode = decode + + def convert(self, value, param, ctx): + f = bz2.open(value, "rb") if value.endswith(".bz2") else open(value, "rb") + return GossipStream(f, value, self.decode) diff --git a/historian/cli/db.py b/historian/cli/db.py new file mode 100644 index 0000000..aeb3883 --- /dev/null +++ b/historian/cli/db.py @@ -0,0 +1,54 @@ +import click +from common import NodeAnnouncement, ChannelAnnouncement, ChannelUpdate +from tqdm import tqdm +from parser import parse +from cli.common import db_session, default_db + + +@click.group() +def db(): + pass + + +@db.command() +@click.argument('source', type=str) +@click.argument('destination', type=str, default=default_db) +def merge(source, destination): + """Merge two historian databases by copying from source to destination. + """ + + meta = { + 'channel_announcements': None, + 'channel_updates': None, + 'node_announcements': None, + } + + with db_session(source) as source, db_session(destination) as target: + # Not strictly necessary, but I like progress indicators and ETAs. 
+        for table in meta.keys():
+            rows = source.execute(f"SELECT count(*) FROM {table}")
+            count, = rows.fetchone()
+            meta[table] = count
+
+        for r, in tqdm(
+            source.execute("SELECT raw FROM channel_announcements"),
+            total=meta['channel_announcements'],
+        ):
+            msg = parse(r)
+            target.merge(ChannelAnnouncement.from_gossip(msg, r))
+
+        for r, in tqdm(
+            source.execute("SELECT raw FROM channel_updates ORDER BY timestamp ASC"),
+            total=meta['channel_updates'],
+        ):
+            msg = parse(r)
+            target.merge(ChannelUpdate.from_gossip(msg, r))
+
+        for r, in tqdm(
+            source.execute("SELECT raw FROM node_announcements ORDER BY timestamp ASC"),
+            total=meta['node_announcements'],
+        ):
+            msg = parse(r)
+            target.merge(NodeAnnouncement.from_gossip(msg, r))
+
+        target.commit()
diff --git a/historian/historian-cli b/historian/historian-cli
index d77d204..554c02e 100755
--- a/historian/historian-cli
+++ b/historian/historian-cli
@@ -1,7 +1,11 @@
 #!/usr/bin/env python3
+import struct
 from tqdm import tqdm
+import shlex
+import subprocess
 from contextlib import contextmanager
 from sqlalchemy import create_engine
+from cli import common
 from common import Base, ChannelAnnouncement, ChannelUpdate, NodeAnnouncement
 from sqlalchemy.orm import sessionmaker
 from sqlalchemy import func
@@ -16,15 +20,15 @@ import io
 import logging
 import socket
 from pyln.proto import wire
-
+from cli.backup import backup
+from cli.db import db
 
 default_db = "sqlite:///$HOME/.lightning/bitcoin/historian.sqlite3"
 
 
 @contextmanager
 def db_session(dsn):
-    """ Tiny contextmanager to facilitate sqlalchemy session management
-    """
+    """Tiny contextmanager to facilitate sqlalchemy session management"""
     if dsn is None:
         dsn = default_db
     dsn = os.path.expandvars(dsn)
@@ -47,6 +51,10 @@ def cli():
     pass
 
 
+cli.add_command(backup)
+cli.add_command(db)
+
+
 @cli.group()
 def snapshot():
     pass
@@ -57,13 +65,13 @@ default_since = datetime.utcnow() - timedelta(hours=1)
 
 
 @snapshot.command()
-@click.argument('destination', type=click.File('wb'))
+@click.argument("destination", type=click.File("wb"))
 @click.argument(
-    'since',
+    "since",
     type=click.DateTime(formats=[dt_fmt]),
-    default=default_since.strftime(dt_fmt)
+    default=default_since.strftime(dt_fmt),
 )
-@click.option('--db', type=str, default=default_db)
+@click.option("--db", type=str, default=default_db)
 def incremental(since, destination, db):
     with db_session(db) as session:
         # Several nested queries here because join was a bit too
@@ -71,7 +79,8 @@ def incremental(since, destination, db):
         # that had any updates in the desired timerange. The outer SELECT then
         # gets all the announcements and kicks off inner SELECTs that look for
         # the latest update for each direction.
-        rows = session.execute("""
+        rows = session.execute(
+            """
 SELECT
 	a.scid,
 	a.raw,
@@ -114,7 +123,10 @@ WHERE
 	a.scid IN (
 )
 ORDER BY
 	a.scid
-    """.format(since.strftime("%Y-%m-%d %H:%M:%S")))
+    """.format(
+                since.strftime("%Y-%m-%d %H:%M:%S")
+            )
+        )
 
         # Write the header now that we know we'll be writing something.
         destination.write(b"GSP\x01")
@@ -138,7 +150,10 @@ def incremental(since, destination, db):
         # Now get and return the node_announcements in the timerange. These
        # come after the channels since no node without a
        # channel_announcements and channel_update is allowed.
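+        # (Per BOLT #7, receivers ignore a node_announcement whose node_id
+        # was not seen in a prior channel_announcement.)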
-        rows = session.execute("""
+        rows = session.execute(
+            """
 SELECT
 	n.node_id,
 	n.timestamp,
 	n.raw
@@ -152,7 +167,10 @@ GROUP BY
 	n.node_id
 HAVING
 	n.timestamp = MAX(n.timestamp)
 ORDER BY timestamp DESC
-    """.format(since.strftime("%Y-%m-%d %H:%M:%S")))
+    """.format(
+                since.strftime("%Y-%m-%d %H:%M:%S")
+            )
+        )
         last_nid = None
         node_count = 0
         for nid, ts, nann in rows:
@@ -164,29 +182,29 @@ ORDER BY timestamp DESC
             destination.write(nann)
 
     click.echo(
-        f'Wrote {chan_count} channels and {node_count} nodes to {destination.name}',
-        err=True
+        f"Wrote {chan_count} channels and {node_count} nodes to {destination.name}",
+        err=True,
     )
 
 
 @snapshot.command()
-@click.argument('destination', type=click.File('wb'))
+@click.argument("destination", type=click.File("wb"))
 @click.pass_context
-@click.option('--db', type=str, default=default_db)
+@click.option("--db", type=str, default=default_db)
 def full(ctx, destination, db):
     since = datetime.utcnow() - timedelta(weeks=2)
     ctx.invoke(incremental, since=since, destination=destination, db=db)
 
 
 @snapshot.command()
-@click.argument('snapshot', type=click.File('rb'))
+@click.argument("snapshot", type=click.File("rb"))
 def read(snapshot):
     header = snapshot.read(4)
     if len(header) < 4:
         raise ValueError("Could not read header")
 
     tag, version = header[0:3], header[3]
-    if tag != b'GSP':
+    if tag != b"GSP":
         raise ValueError(f"Header mismatch, expected GSP, got {repr(tag)}")
 
     if version != 1:
@@ -204,7 +222,73 @@ def read(snapshot):
         print(msg.hex())
 
 
-LightningAddress = namedtuple('LightningAddress', ['node_id', 'host', 'port'])
+@snapshot.command()
+@click.argument("snapshot", type=common.GossipFile(decode=False))
+@click.argument("max_bytes", type=int)
+@click.option("-x", "--exec", type=str)
+def split(snapshot, max_bytes, exec):
+    def bundle(f: common.GossipFile):
+        bundle = None
+        for m in f:
+            (typ,) = struct.unpack_from("!H", m)
+
+            if typ == 257:
+                # NodeAnnouncements are always self-contained, so yield them
+                # individually
+                yield m,
+            elif typ == 256:
+                # ChannelAnnouncements indicate the start of a new bundle
+                if bundle is not None:
+                    yield tuple(bundle)
+                bundle = []
+                bundle.append(m)
+            else:
+                # ChannelUpdates belong to the bundle
+                bundle.append(m)
+        # If we have an unyielded bundle we need to flush it at the end.
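+        # (bundle is still None here if the snapshot contained no
+        # channel_announcements at all.)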
+        if bundle is not None:
+            yield tuple(bundle)
+
+    def serialize_bundle(b):
+        buff = io.BytesIO()
+        for m in b:
+            varint_encode(len(m), buff)
+            buff.write(m)
+        return buff.getvalue()
+
+    filenum = 0
+    prefix, extension = os.path.splitext(snapshot.filename)
+    filename = "{prefix}_{{filenum:04d}}{extension}".format(
+        prefix=prefix, extension=extension
+    )
+
+    def on_complete(filenum):
+        fname = filename.format(filenum=filenum)
+        if exec is not None:
+            cmd = shlex.split(exec.replace("{}", shlex.quote(fname)))
+            logging.debug("Exec:\n> {}".format(" ".join(cmd)))
+            subprocess.run(cmd)
+
+    f = open(filename.format(filenum=filenum), "wb")
+    f.write(b"GSP\x01")
+    for b in bundle(snapshot):
+        assert len(b) <= 3
+        m = serialize_bundle(b)
+
+        if f.tell() + len(m) > max_bytes:
+            f.close()
+            on_complete(filenum)
+            filenum += 1
+            f = open(filename.format(filenum=filenum), "wb")
+            f.write(b"GSP\x01")
+        f.write(m)
+    f.close()
+    on_complete(filenum)
+
+
+LightningAddress = namedtuple("LightningAddress", ["node_id", "host", "port"])
 
 
 class LightningAddressParam(click.ParamType):
@@ -213,7 +297,7 @@ class LightningAddressParam(click.ParamType):
         if m is None:
             self.fail(
                 f"{value} isn't a valid lightning connection string, "
-                "expected \"[node_id]@[host]:[port]\""
+                'expected "[node_id]@[host]:[port]"'
             )
             return
 
@@ -275,15 +359,15 @@ def split_gossip(reader: io.BytesIO):
 
 
 @snapshot.command()
-@click.argument('snapshot', type=click.File('rb'))
-@click.argument('destination', type=LightningAddressParam())
+@click.argument("snapshot", type=click.File("rb"))
+@click.argument("destination", type=LightningAddressParam())
 def load(snapshot, destination):
     header = snapshot.read(4)
     if len(header) < 4:
         raise ValueError("Could not read header")
 
     tag, version = header[0:3], header[3]
-    if tag != b'GSP':
+    if tag != b"GSP":
         raise ValueError(f"Header mismatch, expected GSP, got {repr(tag)}")
 
     if version != 1: