Files
plugins/historian/historian-cli
2021-01-22 18:00:16 +01:00

303 lines
8.0 KiB
Python
Executable File

#!/usr/bin/env python3
from tqdm import tqdm
from contextlib import contextmanager
from sqlalchemy import create_engine
from common import Base, ChannelAnnouncement, ChannelUpdate, NodeAnnouncement
from sqlalchemy.orm import sessionmaker
from sqlalchemy import func
from datetime import datetime, timedelta
from collections import namedtuple
import click
from pyln.proto.primitives import varint_encode, varint_decode
import os
from sqlalchemy.orm import load_only
import re
import io
import logging
import socket
from pyln.proto import wire
# Default DSN for the historian database. The embedded $HOME is expanded
# (via os.path.expandvars) inside db_session, not here.
default_db = "sqlite:///$HOME/.lightning/bitcoin/historian.sqlite3"
@contextmanager
def db_session(dsn):
    """Tiny contextmanager to facilitate sqlalchemy session management.

    Yields a session bound to `dsn` (falling back to `default_db` when
    `dsn` is None). The session is committed if the `with` body finishes
    normally, rolled back and the exception re-raised otherwise, and
    closed in all cases.

    Args:
        dsn: database URL string or None. Environment variables in the
            DSN (e.g. $HOME) are expanded before connecting.
    """
    if dsn is None:
        dsn = default_db
    dsn = os.path.expandvars(dsn)
    engine = create_engine(dsn, echo=False)
    # Make sure the schema exists before handing out the session.
    Base.metadata.create_all(engine)
    session_maker = sessionmaker(bind=engine)
    session = session_maker()
    try:
        yield session
        session.commit()
    except BaseException:
        # Roll back on *any* failure, including KeyboardInterrupt /
        # SystemExit, before propagating it to the caller. (Explicit
        # BaseException replaces the original bare `except:` — same
        # semantics, clearer intent.)
        session.rollback()
        raise
    finally:
        session.close()
@click.group()
def cli():
    """Root command group for the historian CLI."""
@cli.group()
def snapshot():
    """Commands for creating, inspecting and replaying gossip snapshots."""
# Timestamp format shared by the CLI argument parser and the SQLite
# DATETIME comparisons below.
dt_fmt = "%Y-%m-%d %H:%M:%S"
# Default lower bound for incremental snapshots: one hour ago in UTC,
# evaluated once at import time.
default_since = datetime.utcnow() - timedelta(hours=1)
@snapshot.command()
@click.argument('destination', type=click.File('wb'))
@click.argument(
    'since',
    type=click.DateTime(formats=[dt_fmt]),
    default=default_since.strftime(dt_fmt)
)
@click.option('--db', type=str, default=default_db)
def incremental(since, destination, db):
    """Write an incremental gossip snapshot to DESTINATION.

    Emits every channel_announcement whose channel saw a channel_update
    at or after SINCE, each followed by the latest update per direction,
    then the newest node_announcement per node in the same range. The
    file starts with the 4-byte header b"GSP\\x01"; every message is
    prefixed with its varint-encoded length.
    """
    # Single canonical string for both queries; passed as a bound
    # parameter instead of being interpolated into the SQL text.
    since_str = since.strftime(dt_fmt)
    with db_session(db) as session:
        # Several nested queries here because join was a bit too
        # restrictive. The inner SELECT in the WHERE-clause selects all
        # scids that had any updates in the desired timerange. The outer
        # SELECT then gets all the announcements and kicks off inner
        # SELECTs that look for the latest update for each direction.
        rows = session.execute("""
        SELECT
          a.scid,
          a.raw,
          (
            SELECT u.raw
            FROM channel_updates u
            WHERE u.scid = a.scid AND direction = 0
            ORDER BY timestamp DESC
            LIMIT 1
          ) AS u0,
          (
            SELECT u.raw
            FROM channel_updates u
            WHERE u.scid = a.scid AND direction = 1
            ORDER BY timestamp DESC
            LIMIT 1
          ) AS u1
        FROM channel_announcements a
        WHERE a.scid IN (
          SELECT u.scid
          FROM channel_updates u
          WHERE u.timestamp >= DATETIME(:since)
          GROUP BY u.scid
        )
        ORDER BY a.scid
        """, {'since': since_str})

        # Write the header now that we know we'll be writing something.
        destination.write(b"GSP\x01")

        chan_count = 0
        last_scid = None
        for scid, cann, u0, u1 in rows:
            # Skip duplicate announcements for the same scid; rows are
            # ordered by scid, so tracking the previous one suffices.
            if scid == last_scid:
                continue
            last_scid = scid
            chan_count += 1

            varint_encode(len(cann), destination)
            destination.write(cann)
            if u0 is not None:
                varint_encode(len(u0), destination)
                destination.write(u0)
            if u1 is not None:
                varint_encode(len(u1), destination)
                destination.write(u1)

        # Now get and return the node_announcements in the timerange. These
        # come after the channels since no node without a
        # channel_announcements and channel_update is allowed.
        rows = session.execute("""
        SELECT
          n.node_id,
          n.timestamp,
          n.raw
        FROM node_announcements n
        WHERE n.timestamp >= DATETIME(:since)
        GROUP BY n.node_id
        HAVING n.timestamp = MAX(n.timestamp)
        ORDER BY timestamp DESC
        """, {'since': since_str})

        last_nid = None
        node_count = 0
        for nid, ts, nann in rows:
            # GROUP BY should already dedup, but keep the guard cheap.
            if nid == last_nid:
                continue
            last_nid = nid
            node_count += 1
            varint_encode(len(nann), destination)
            destination.write(nann)

        click.echo(
            f'Wrote {chan_count} channels and {node_count} nodes to {destination.name}',
            err=True
        )
@snapshot.command()
@click.argument('destination', type=click.File('wb'))
@click.pass_context
@click.option('--db', type=str, default=default_db)
def full(ctx, destination, db):
    """Write a "full" snapshot: everything from the last two weeks."""
    cutoff = datetime.utcnow() - timedelta(weeks=2)
    ctx.invoke(incremental, since=cutoff, destination=destination, db=db)
@snapshot.command()
@click.argument('snapshot', type=click.File('rb'))
def read(snapshot):
    """Hex-dump every gossip message contained in SNAPSHOT."""
    header = snapshot.read(4)
    if len(header) < 4:
        raise ValueError("Could not read header")
    tag, version = header[0:3], header[3]
    if tag != b'GSP':
        raise ValueError(f"Header mismatch, expected GSP, got {repr(tag)}")
    if version != 1:
        raise ValueError(f"Unsupported version {version}, only support up to version 1")
    # Each message is a varint length followed by the raw payload;
    # varint_decode returns None on a clean EOF.
    while (length := varint_decode(snapshot)) is not None:
        msg = snapshot.read(length)
        if len(msg) != length:
            raise ValueError("Incomplete read at end of file")
        print(msg.hex())
# (node_id, host, port) triple identifying a lightning peer endpoint.
LightningAddress = namedtuple('LightningAddress', ['node_id', 'host', 'port'])
class LightningAddressParam(click.ParamType):
    """Click parameter type parsing "node_id@host:port" connection strings."""

    def convert(self, value, param, ctx):
        """Parse `value` into a LightningAddress or fail the parameter.

        The port group is optional in the regex; a missing port defaults
        to the standard lightning port 9735.
        """
        m = re.match(r"(0[23][a-fA-F0-9]+)@([a-zA-Z0-9\.:]+):([0-9]+)?", value)
        if m is None:
            self.fail(
                f"{value} isn't a valid lightning connection string, "
                "expected \"[node_id]@[host]:[port]\""
            )
            return
        # BUGFIX: m.groups() always has length 3 (unmatched optional
        # groups are None), so the old `len(m.groups()) < 3` check was
        # always False and a missing port crashed on int(None). Test the
        # port group itself instead.
        if m[3] is None:
            return LightningAddress(m[1], m[2], 9735)
        return LightningAddress(m[1], m[2], int(m[3]))
class LightningPeer:
    """Minimal lightning peer: connect, handshake, and stream raw messages."""

    def __init__(self, node_id: str, address: str, port: int = 9735):
        self.node_id = node_id
        self.address = address
        self.port = port
        self.connection = None
        # Fresh ephemeral identity, generated once per peer instance.
        self.local_privkey = wire.PrivateKey(os.urandom(32))

    def connect(self):
        """Open a TCP connection, perform the handshake, and send init."""
        sock = socket.create_connection((self.address, self.port), timeout=30)
        self.connection = wire.LightningConnection(
            sock,
            remote_pubkey=wire.PublicKey(bytes.fromhex(self.node_id)),
            local_privkey=self.local_privkey,
            is_initiator=True,
        )
        self.connection.shake()
        # Send an init message, with no global features, and 0b10101010 as
        # local features.
        self.connection.send_message(b"\x00\x10\x00\x00\x00\x01\xaa")

    def send(self, packet: bytes) -> None:
        """Send one raw message; raises ValueError if not connected."""
        if self.connection is None:
            raise ValueError("Not connected to peer")
        logging.debug("Sending {}".format(packet.hex()))
        self.connection.send_message(packet)

    def send_all(self, packets) -> None:
        """Send every message from the iterable `packets`, in order."""
        assert self.connection is not None
        for packet in packets:
            self.send(packet)

    def disconnect(self):
        """Close the underlying TCP socket."""
        self.connection.connection.close()
def split_gossip(reader: io.BytesIO):
    """Yield each varint-length-prefixed message from `reader`.

    Stops cleanly when varint_decode signals EOF (returns None); raises
    ValueError on a truncated trailing message.
    """
    while (length := varint_decode(reader)) is not None:
        payload = reader.read(length)
        if len(payload) != length:
            raise ValueError("Incomplete read at end of file")
        yield payload
@snapshot.command()
@click.argument('snapshot', type=click.File('rb'))
@click.argument('destination', type=LightningAddressParam())
def load(snapshot, destination):
    """Stream every message in SNAPSHOT to the peer at DESTINATION."""
    header = snapshot.read(4)
    if len(header) < 4:
        raise ValueError("Could not read header")
    tag, version = header[0:3], header[3]
    if tag != b'GSP':
        raise ValueError(f"Header mismatch, expected GSP, got {repr(tag)}")
    if version != 1:
        raise ValueError(f"Unsupported version {version}, only support up to version 1")
    logging.debug(f"Connecting to {destination}")
    peer = LightningPeer(destination.node_id, destination.host, destination.port)
    peer.connect()
    logging.debug("Connected, streaming messages from snapshot")
    peer.send_all(tqdm(split_gossip(snapshot)))
    # BUGFIX: log before tearing the connection down — the original
    # emitted "disconnecting" after the socket was already closed.
    logging.debug("Done streaming messages, disconnecting")
    peer.disconnect()
# Script entry point: dispatch to the click command group.
if __name__ == "__main__":
    cli()