#!/usr/bin/env python3
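"""Tool to snapshot, inspect, and replay Lightning Network gossip.

Snapshot format, as implemented below: a 4-byte header consisting of the
magic bytes b"GSP" followed by a single version byte (currently 1), then a
sequence of raw gossip messages, each prefixed with its length encoded as a
varint (as implemented in pyln.proto.primitives).
"""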
import io
import logging
import os
import re
import socket
from collections import namedtuple
from contextlib import contextmanager
from datetime import datetime, timedelta

import click
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from tqdm import tqdm

from pyln.proto import wire
from pyln.proto.primitives import varint_encode, varint_decode

# Importing the models is required so that Base.metadata knows about their
# tables when we call create_all below.
from common import Base, ChannelAnnouncement, ChannelUpdate, NodeAnnouncement


default_db = "sqlite:///$HOME/.lightning/bitcoin/historian.sqlite3"


@contextmanager
def db_session(dsn):
    """Tiny contextmanager to facilitate sqlalchemy session management."""
    if dsn is None:
        dsn = default_db
    dsn = os.path.expandvars(dsn)
    engine = create_engine(dsn, echo=False)
    Base.metadata.create_all(engine)
    session_maker = sessionmaker(bind=engine)
    session = session_maker()
    try:
        yield session
        session.commit()
    except BaseException:
        session.rollback()
        raise
    finally:
        session.close()


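# A minimal usage sketch for db_session (assumes the default database
# location and that ChannelUpdate is one of the mapped models from common):
#
#   with db_session(None) as session:
#       print(session.query(ChannelUpdate).count())

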
@click.group()
def cli():
    pass


@cli.group()
def snapshot():
    pass


dt_fmt = "%Y-%m-%d %H:%M:%S"
default_since = datetime.utcnow() - timedelta(hours=1)


@snapshot.command()
@click.argument('destination', type=click.File('wb'))
@click.argument(
    'since',
    type=click.DateTime(formats=[dt_fmt]),
    default=default_since.strftime(dt_fmt)
)
@click.option('--db', type=str, default=default_db)
def incremental(since, destination, db):
    """Write all gossip for channels that were updated since SINCE."""
    with db_session(db) as session:
        # Several nested queries here because a join was a bit too
        # restrictive. The inner SELECT in the WHERE clause selects all
        # scids that had any updates in the desired timerange. The outer
        # SELECT then gets all the announcements and kicks off inner
        # SELECTs that look for the latest update in each direction.
        # `since` is a datetime parsed by click, so interpolating its
        # strftime output into the query is safe.
        rows = session.execute("""
        SELECT
          a.scid,
          a.raw,
          (
            SELECT u.raw
            FROM channel_updates u
            WHERE u.scid = a.scid AND direction = 0
            ORDER BY timestamp DESC
            LIMIT 1
          ) AS u0,
          (
            SELECT u.raw
            FROM channel_updates u
            WHERE u.scid = a.scid AND direction = 1
            ORDER BY timestamp DESC
            LIMIT 1
          ) AS u1
        FROM channel_announcements a
        WHERE a.scid IN (
          SELECT u.scid
          FROM channel_updates u
          WHERE u.timestamp >= DATETIME('{}')
          GROUP BY u.scid
        )
        ORDER BY a.scid
        """.format(since.strftime(dt_fmt)))

        # Write the header now that we know we'll be writing something.
        destination.write(b"GSP\x01")

        chan_count = 0
        last_scid = None
        for scid, cann, u0, u1 in rows:
            # The query is ordered by scid, so duplicate rows for the same
            # channel are adjacent and can be skipped.
            if scid == last_scid:
                continue
            last_scid = scid
            chan_count += 1
            varint_encode(len(cann), destination)
            destination.write(cann)
            if u0 is not None:
                varint_encode(len(u0), destination)
                destination.write(u0)
            if u1 is not None:
                varint_encode(len(u1), destination)
                destination.write(u1)

        # Now fetch and write the node_announcements in the timerange.
        # These come after the channels since a node without a
        # channel_announcement and channel_update is not allowed.
        rows = session.execute("""
        SELECT
          n.node_id,
          n.timestamp,
          n.raw
        FROM node_announcements n
        WHERE n.timestamp >= DATETIME('{}')
        GROUP BY n.node_id
        HAVING n.timestamp = MAX(n.timestamp)
        ORDER BY timestamp DESC
        """.format(since.strftime(dt_fmt)))

        last_nid = None
        node_count = 0
        for nid, ts, nann in rows:
            if nid == last_nid:
                continue
            last_nid = nid
            node_count += 1
            varint_encode(len(nann), destination)
            destination.write(nann)

        click.echo(
            f'Wrote {chan_count} channels and {node_count} nodes to {destination.name}',
            err=True
        )


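# The resulting per-channel layout in the snapshot is, with [] marking
# optional parts and varint as implemented by pyln.proto.primitives:
#
#   varint(len(channel_announcement)) || channel_announcement
#   [varint(len(channel_update direction 0)) || channel_update]
#   [varint(len(channel_update direction 1)) || channel_update]
#
# followed by one varint-prefixed node_announcement per node.

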
@snapshot.command()
@click.argument('destination', type=click.File('wb'))
@click.pass_context
@click.option('--db', type=str, default=default_db)
def full(ctx, destination, db):
    """Write a full snapshot, i.e., all gossip from the last two weeks."""
    # Two weeks matches the gossip prune interval: channels without a
    # newer channel_update are considered stale by the network anyway.
    since = datetime.utcnow() - timedelta(weeks=2)
    ctx.invoke(incremental, since=since, destination=destination, db=db)


@snapshot.command()
@click.argument('snapshot', type=click.File('rb'))
def read(snapshot):
    """Decode SNAPSHOT and print every gossip message as hex."""
    header = snapshot.read(4)
    if len(header) < 4:
        raise ValueError("Could not read header")

    tag, version = header[0:3], header[3]
    if tag != b'GSP':
        raise ValueError(f"Header mismatch, expected GSP, got {repr(tag)}")

    if version != 1:
        raise ValueError(f"Unsupported version {version}, only version 1 is supported")

    # Reuse the framing logic from split_gossip (defined below, resolved at
    # call time) rather than duplicating the varint loop here.
    for msg in split_gossip(snapshot):
        print(msg.hex())


LightningAddress = namedtuple('LightningAddress', ['node_id', 'host', 'port'])


class LightningAddressParam(click.ParamType):
    name = 'lightning-address'

    def convert(self, value, param, ctx):
        # A node_id is a 33-byte compressed public key, i.e., 66 hex digits
        # starting with 02 or 03. The ":port" suffix is optional.
        m = re.match(
            r"(0[23][a-fA-F0-9]{64})@([a-zA-Z0-9\.\-]+)(?::([0-9]+))?$",
            value
        )
        if m is None:
            self.fail(
                f"{value} isn't a valid lightning connection string, "
                "expected \"[node_id]@[host]:[port]\""
            )
            return

        # Fall back to the standard lightning port if none was given.
        if m[3] is None:
            return LightningAddress(m[1], m[2], 9735)
        else:
            return LightningAddress(m[1], m[2], int(m[3]))


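# For example, "03abc...@203.0.113.7:9736" (placeholder node_id) converts to
# LightningAddress(node_id='03abc...', host='203.0.113.7', port=9736), while
# omitting ":9736" falls back to the default port 9735.

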
class LightningPeer:
    def __init__(self, node_id: str, address: str, port: int = 9735):
        self.node_id = node_id
        self.address = address
        self.port = port
        self.connection = None
        # An ephemeral identity is sufficient, we only stream gossip.
        self.local_privkey = wire.PrivateKey(os.urandom(32))

    def connect(self):
        sock = socket.create_connection((self.address, self.port), timeout=30)
        self.connection = wire.LightningConnection(
            sock,
            remote_pubkey=wire.PublicKey(bytes.fromhex(self.node_id)),
            local_privkey=self.local_privkey,
            is_initiator=True,
        )
        self.connection.shake()

        # Send an init message, with no global features, and 0b10101010 as
        # local features: type 16 (0x0010), gflen=0, flen=1, features=0xaa.
        self.connection.send_message(b"\x00\x10\x00\x00\x00\x01\xaa")

    def send(self, packet: bytes) -> None:
        if self.connection is None:
            raise ValueError("Not connected to peer")

        logging.debug("Sending {}".format(packet.hex()))
        self.connection.send_message(packet)

    def send_all(self, packets) -> None:
        assert self.connection is not None
        for p in packets:
            self.send(p)

    def disconnect(self):
        self.connection.connection.close()


def split_gossip(reader: io.BytesIO):
    """Yield raw gossip messages from a varint-length-prefixed stream."""
    while True:
        length = varint_decode(reader)
        if length is None:
            # Clean end of stream.
            break

        msg = reader.read(length)
        if len(msg) != length:
            raise ValueError("Incomplete read at end of file")

        yield msg


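# A minimal sketch of consuming a snapshot programmatically ("snapshot.gsp"
# and handle() are hypothetical):
#
#   with open("snapshot.gsp", "rb") as f:
#       assert f.read(4) == b"GSP\x01"
#       for msg in split_gossip(f):
#           handle(msg)

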
@snapshot.command()
@click.argument('snapshot', type=click.File('rb'))
@click.argument('destination', type=LightningAddressParam())
def load(snapshot, destination):
    """Connect to DESTINATION and stream all messages in SNAPSHOT to it."""
    header = snapshot.read(4)
    if len(header) < 4:
        raise ValueError("Could not read header")

    tag, version = header[0:3], header[3]
    if tag != b'GSP':
        raise ValueError(f"Header mismatch, expected GSP, got {repr(tag)}")

    if version != 1:
        raise ValueError(f"Unsupported version {version}, only version 1 is supported")

    logging.debug(f"Connecting to {destination}")
    peer = LightningPeer(destination.node_id, destination.host, destination.port)
    peer.connect()
    logging.debug("Connected, streaming messages from snapshot")
    peer.send_all(tqdm(split_gossip(snapshot)))
    logging.debug("Done streaming messages, disconnecting")
    peer.disconnect()


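# Example invocations (assuming this script is saved as historian-cli, a
# hypothetical name; the address is a placeholder):
#
#   ./historian-cli snapshot full snapshot.gsp
#   ./historian-cli snapshot incremental snapshot.gsp "2021-01-01 00:00:00"
#   ./historian-cli snapshot read snapshot.gsp
#   ./historian-cli snapshot load snapshot.gsp <node_id>@<host>:<port>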
if __name__ == "__main__":
    cli()