historian: Add backup command for lnresearch

Christian Decker
2020-10-20 21:34:18 +02:00
parent 474a2011d1
commit 3f5dc35ab1
4 changed files with 301 additions and 22 deletions

historian/cli/backup.py (new file, 56 additions)

@@ -0,0 +1,56 @@
import click
from .common import db_session, split_gossip
import os
from pyln.proto.primitives import varint_decode, varint_encode


@click.group()
def backup():
    pass


@backup.command()
@click.argument('destination', type=click.File('wb'))
@click.option('--db', type=str, default=None)
def create(destination, db):
    with db_session(db) as session:
        rows = session.execute("SELECT raw FROM channel_announcements")

        # Write the header now that we know we'll be writing something.
        destination.write(b"GSP\x01")

        for r in rows:
            varint_encode(len(r[0]), destination)
            destination.write(r[0])

        rows = session.execute("SELECT raw FROM channel_updates ORDER BY timestamp ASC")
        for r in rows:
            varint_encode(len(r[0]), destination)
            destination.write(r[0])

        rows = session.execute("SELECT raw FROM node_announcements ORDER BY timestamp ASC")
        for r in rows:
            varint_encode(len(r[0]), destination)
            destination.write(r[0])

    destination.close()
@backup.command()
@click.argument("source", type=click.File('rb'))
def read(source):
    """Load gossip messages from the specified source and print them to stdout.

    Prints the hex-encoded raw gossip messages to stdout.
    """
    header = source.read(4)
    if len(header) < 4:
        raise ValueError("Could not read header")

    tag, version = header[0:3], header[3]
    if tag != b'GSP':
        raise ValueError(f"Header mismatch, expected GSP, got {repr(tag)}")

    if version != 1:
        raise ValueError(f"Unsupported version {version}, only support up to version 1")

    for m in split_gossip(source):
        print(m.hex())
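
The file written by `backup create` and consumed by `backup read` is simply the 4-byte header b"GSP\x01" (3-byte tag plus format version 1) followed by each raw gossip message prefixed with its varint-encoded length. The following is a minimal sketch, not part of this commit, that builds and re-reads such a stream in memory with the same pyln varint helpers; the message payloads are made up for illustration.

    import io
    from pyln.proto.primitives import varint_decode, varint_encode

    # Hypothetical raw gossip messages (real ones are BOLT7-encoded bytes).
    messages = [b"\x01\x00" + b"chan_ann", b"\x01\x02" + b"chan_upd"]

    buf = io.BytesIO()
    buf.write(b"GSP\x01")              # 3-byte tag + 1-byte version
    for m in messages:
        varint_encode(len(m), buf)     # length prefix
        buf.write(m)

    # Reading it back mirrors what `backup read` does.
    buf.seek(0)
    assert buf.read(4) == b"GSP\x01"
    while True:
        length = varint_decode(buf)    # returns None at end of stream
        if length is None:
            break
        print(buf.read(length).hex())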

historian/cli/common.py (new file, 90 additions)

@@ -0,0 +1,90 @@
from sqlalchemy.orm import sessionmaker
from sqlalchemy import create_engine
from contextlib import contextmanager
import os
from common import Base
import io
from pyln.proto.primitives import varint_decode
from parser import parse
import click
import bz2

default_db = "sqlite:///$HOME/.lightning/bitcoin/historian.sqlite3"


@contextmanager
def db_session(dsn):
    """Tiny contextmanager to facilitate sqlalchemy session management"""
    if dsn is None:
        dsn = default_db
    dsn = os.path.expandvars(dsn)
    engine = create_engine(dsn, echo=False)
    Base.metadata.create_all(engine)
    session_maker = sessionmaker(bind=engine)
    session = session_maker()
    try:
        yield session
        session.commit()
    except:
        session.rollback()
        raise
    finally:
        session.close()


def split_gossip(reader: io.BytesIO):
    while True:
        length = varint_decode(reader)
        if length is None:
            break
        msg = reader.read(length)
        if len(msg) != length:
            raise ValueError("Incomplete read at end of file")
        yield msg


class GossipStream:
    def __init__(self, file_stream, filename, decode=True):
        self.stream = file_stream
        self.decode = decode
        self.filename = filename

        # Read header
        header = self.stream.read(4)
        assert len(header) == 4
        assert header[:3] == b"GSP"
        assert header[3] == 1

    def __iter__(self):
        return self

    def __next__(self):
        pos = self.stream.tell()
        length = varint_decode(self.stream)
        if length is None:
            raise StopIteration

        msg = self.stream.read(length)
        if len(msg) != length:
            raise ValueError(
                "Error reading snapshot at {pos}: incomplete read of {length} bytes, only got {lmsg} bytes".format(
                    pos=pos, length=length, lmsg=len(msg)
                )
            )

        if not self.decode:
            return msg

        return parse(msg)


class GossipFile(click.File):
    def __init__(self, decode=True):
        click.File.__init__(self)
        self.decode = decode

    def convert(self, value, param, ctx):
        f = bz2.open(value, "rb") if value.endswith(".bz2") else open(value, "rb")
        return GossipStream(f, value, self.decode)
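
A short sketch of how the helpers above can be used outside the CLI: it assumes a snapshot file exists (the path snapshot.gsp.bz2 below is hypothetical) and counts message types from the raw bytes, inspecting the 2-byte big-endian type field the same way the split command added later in this commit does.

    import bz2
    import struct
    from collections import Counter
    from cli.common import GossipStream

    path = "snapshot.gsp.bz2"  # hypothetical input file
    f = bz2.open(path, "rb") if path.endswith(".bz2") else open(path, "rb")

    counts = Counter()
    # decode=False yields the raw bytes instead of parsed messages.
    for msg in GossipStream(f, path, decode=False):
        (typ,) = struct.unpack_from("!H", msg)
        counts[typ] += 1

    # 256 = channel_announcement, 257 = node_announcement, 258 = channel_update
    print(counts)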

historian/cli/db.py (new file, 54 additions)

@@ -0,0 +1,54 @@
import click
from common import NodeAnnouncement, ChannelAnnouncement, ChannelUpdate
from tqdm import tqdm
from parser import parse
from cli.common import db_session, default_db


@click.group()
def db():
    pass


@db.command()
@click.argument('source', type=str)
@click.argument('destination', type=str, default=default_db)
def merge(source, destination):
    """Merge two historian databases by copying from source to destination.
    """
    meta = {
        'channel_announcements': None,
        'channel_updates': None,
        'node_announcements': None,
    }
    with db_session(source) as source, db_session(destination) as target:
        # Not strictly necessary, but I like progress indicators and ETAs.
        for table in meta.keys():
            rows = source.execute(f"SELECT count(*) FROM {table}")
            count, = rows.next()
            meta[table] = count

        for r, in tqdm(
            source.execute("SELECT raw FROM channel_announcements"),
            total=meta['channel_announcements'],
        ):
            msg = parse(r)
            target.merge(ChannelAnnouncement.from_gossip(msg, r))

        for r, in tqdm(
            source.execute("SELECT raw FROM channel_updates ORDER BY timestamp ASC"),
            total=meta['channel_updates'],
        ):
            msg = parse(r)
            target.merge(ChannelUpdate.from_gossip(msg, r))

        for r, in tqdm(
            source.execute("SELECT raw FROM node_announcements ORDER BY timestamp ASC"),
            total=meta['node_announcements'],
        ):
            msg = parse(r)
            target.merge(NodeAnnouncement.from_gossip(msg, r))

        target.commit()
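
A sketch, not part of this commit, of driving the new db merge subcommand from Python via click's test runner; the source DSN points at a hypothetical snapshot database, and omitting the destination falls back to default_db from cli/common.py.

    from click.testing import CliRunner
    from cli.db import db

    runner = CliRunner()
    # Copy everything from a downloaded snapshot DB (hypothetical path) into
    # the default local database.
    result = runner.invoke(db, ["merge", "sqlite:///snapshot.sqlite3"])
    print(result.output)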


@@ -1,7 +1,11 @@
#!/usr/bin/env python3
import struct
from tqdm import tqdm
import shlex
import subprocess
from contextlib import contextmanager
from sqlalchemy import create_engine
from cli import common
from common import Base, ChannelAnnouncement, ChannelUpdate, NodeAnnouncement
from sqlalchemy.orm import sessionmaker
from sqlalchemy import func
@@ -16,15 +20,15 @@ import io
import logging
import socket
from pyln.proto import wire
from cli.backup import backup
from cli.db import db
default_db = "sqlite:///$HOME/.lightning/bitcoin/historian.sqlite3"
@contextmanager
def db_session(dsn):
""" Tiny contextmanager to facilitate sqlalchemy session management
"""
"""Tiny contextmanager to facilitate sqlalchemy session management"""
if dsn is None:
dsn = default_db
dsn = os.path.expandvars(dsn)
@@ -47,6 +51,10 @@ def cli():
pass
cli.add_command(backup)
cli.add_command(db)
@cli.group()
def snapshot():
pass
@@ -57,13 +65,13 @@ default_since = datetime.utcnow() - timedelta(hours=1)
@snapshot.command()
@click.argument('destination', type=click.File('wb'))
@click.argument("destination", type=click.File("wb"))
@click.argument(
'since',
"since",
type=click.DateTime(formats=[dt_fmt]),
default=default_since.strftime(dt_fmt)
default=default_since.strftime(dt_fmt),
)
@click.option('--db', type=str, default=default_db)
@click.option("--db", type=str, default=default_db)
def incremental(since, destination, db):
with db_session(db) as session:
# Several nested queries here because join was a bit too
@@ -71,7 +79,8 @@ def incremental(since, destination, db):
# that had any updates in the desired timerange. The outer SELECT then
# gets all the announcements and kicks off inner SELECTs that look for
# the latest update for each direction.
rows = session.execute("""
rows = session.execute(
"""
SELECT
a.scid,
a.raw,
@@ -114,7 +123,10 @@ WHERE
)
ORDER BY
a.scid
""".format(since.strftime("%Y-%m-%d %H:%M:%S")))
""".format(
since.strftime("%Y-%m-%d %H:%M:%S")
)
)
# Write the header now that we know we'll be writing something.
destination.write(b"GSP\x01")
@@ -138,7 +150,8 @@ ORDER BY
# Now get and return the node_announcements in the timerange. These
# come after the channels since no node is allowed without a
# channel_announcement and channel_update.
rows = session.execute("""
rows = session.execute(
"""
SELECT
n.node_id,
n.timestamp,
@@ -152,7 +165,10 @@ GROUP BY
HAVING
n.timestamp = MAX(n.timestamp)
ORDER BY timestamp DESC
""".format(since.strftime("%Y-%m-%d %H:%M:%S")))
""".format(
since.strftime("%Y-%m-%d %H:%M:%S")
)
)
last_nid = None
node_count = 0
for nid, ts, nann in rows:
@@ -164,29 +180,29 @@ ORDER BY timestamp DESC
destination.write(nann)
click.echo(
f'Wrote {chan_count} channels and {node_count} nodes to {destination.name}',
err=True
f"Wrote {chan_count} channels and {node_count} nodes to {destination.name}",
err=True,
)
@snapshot.command()
@click.argument('destination', type=click.File('wb'))
@click.argument("destination", type=click.File("wb"))
@click.pass_context
@click.option('--db', type=str, default=default_db)
@click.option("--db", type=str, default=default_db)
def full(ctx, destination, db):
since = datetime.utcnow() - timedelta(weeks=2)
ctx.invoke(incremental, since=since, destination=destination, db=db)
@snapshot.command()
@click.argument('snapshot', type=click.File('rb'))
@click.argument("snapshot", type=click.File("rb"))
def read(snapshot):
header = snapshot.read(4)
if len(header) < 4:
raise ValueError("Could not read header")
tag, version = header[0:3], header[3]
if tag != b'GSP':
if tag != b"GSP":
raise ValueError(f"Header mismatch, expected GSP, got {repr(tag)}")
if version != 1:
@@ -204,7 +220,70 @@ def read(snapshot):
print(msg.hex())
LightningAddress = namedtuple('LightningAddress', ['node_id', 'host', 'port'])
@snapshot.command()
@click.argument("snapshot", type=common.GossipFile(decode=False))
@click.argument("max_bytes", type=int)
@click.option("-x", "--exec", type=str)
def split(snapshot, max_bytes, exec):
    def bundle(f: common.GossipFile):
        bundle = None
        for m in f:
            (typ,) = struct.unpack_from("!H", m)
            if typ == 257:
                # NodeAnnouncements are always self-contained, so yield them
                # individually
                yield m,
            elif typ == 256:
                # ChannelAnnouncements indicate the start of a new bundle
                if bundle is not None:
                    yield tuple(bundle)
                bundle = []
                bundle.append(m)
            else:
                # ChannelUpdates belong to the bundle
                bundle.append(m)

        # If we have an unyielded bundle we need to flush it at the end.
        yield tuple(bundle)

    def serialize_bundle(b):
        buff = io.BytesIO()
        for m in b:
            varint_encode(len(m), buff)
            buff.write(m)
        return buff.getvalue()

    filenum = 0
    prefix, extension = os.path.splitext(snapshot.filename)
    filename = "{prefix}_{{filenum:04d}}{extension}".format(
        prefix=prefix, extension=extension
    )

    def on_complete(filenum):
        fname = filename.format(filenum=filenum)
        if exec is not None:
            cmd = shlex.split(exec.replace("{}", shlex.quote(fname)))
            logging.debug("Exec:\n> {}".format(" ".join(cmd)))
            subprocess.run(cmd)

    f = open(filename.format(filenum=filenum), "wb")
    f.write(b"GSP\x01")
    for b in bundle(snapshot):
        assert len(b) <= 3
        m = serialize_bundle(b)
        if f.tell() + len(m) > max_bytes:
            f.close()
            on_complete(filenum)
            filenum += 1
            f = open(filename.format(filenum=filenum), "wb")
            f.write(b"GSP\x01")
        f.write(m)
    f.close()
    on_complete(filenum)
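
To make the bundling rule concrete: a channel_announcement (type 256) opens a bundle that collects its channel_updates (type 258), while node_announcements (type 257) are emitted on their own, so an announcement and its updates are never split across output files. The following is a standalone sketch, not the committed helper (which is nested inside split), run on made-up messages; unlike the helper it guards against a stream with no channel_announcements.

    import struct

    def bundles(messages):
        """Mirror of the bundling rule used by `split` (sketch only)."""
        bundle = None
        for m in messages:
            (typ,) = struct.unpack_from("!H", m)
            if typ == 257:
                # node_announcements are self-contained
                yield (m,)
            elif typ == 256:
                # a channel_announcement starts a new bundle
                if bundle is not None:
                    yield tuple(bundle)
                bundle = [m]
            else:
                # channel_updates (type 258) stay with their announcement
                bundle.append(m)
        if bundle is not None:  # guard: flush only if a bundle was started
            yield tuple(bundle)

    # Made-up messages: two channels with one update each, then one node.
    msgs = [
        struct.pack("!H", 256) + b"chan-A",
        struct.pack("!H", 258) + b"upd-A",
        struct.pack("!H", 256) + b"chan-B",
        struct.pack("!H", 258) + b"upd-B",
        struct.pack("!H", 257) + b"node-1",
    ]
    for b in bundles(msgs):
        assert len(b) <= 3  # announcement plus at most one update per direction
        print([m[2:].decode() for m in b])
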
LightningAddress = namedtuple("LightningAddress", ["node_id", "host", "port"])
class LightningAddressParam(click.ParamType):
@@ -213,7 +292,7 @@ class LightningAddressParam(click.ParamType):
if m is None:
self.fail(
f"{value} isn't a valid lightning connection string, "
"expected \"[node_id]@[host]:[port]\""
'expected "[node_id]@[host]:[port]"'
)
return
@@ -275,15 +354,15 @@ def split_gossip(reader: io.BytesIO):
@snapshot.command()
@click.argument('snapshot', type=click.File('rb'))
@click.argument('destination', type=LightningAddressParam())
@click.argument("snapshot", type=click.File("rb"))
@click.argument("destination", type=LightningAddressParam())
def load(snapshot, destination):
header = snapshot.read(4)
if len(header) < 4:
raise ValueError("Could not read header")
tag, version = header[0:3], header[3]
if tag != b'GSP':
if tag != b"GSP":
raise ValueError(f"Header mismatch, expected GSP, got {repr(tag)}")
if version != 1: