mirror of
https://github.com/aljazceru/plugins.git
synced 2025-12-23 08:04:20 +01:00
noise: Use the onion classes and fix the length prefix issue
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
from pyln.client import Plugin, RpcError
|
from pyln.client import Plugin, RpcError
|
||||||
from pyln.proto.primitives import varint_decode, varint_encode
|
from pyln.proto.primitives import varint_decode, varint_encode
|
||||||
|
from onion import TlvPayload
|
||||||
from binascii import hexlify, unhexlify
|
from binascii import hexlify, unhexlify
|
||||||
import struct
|
import struct
|
||||||
import string
|
import string
|
||||||
@@ -35,7 +36,7 @@ class Message(object):
|
|||||||
def serialize_payload(n, blockheight):
|
def serialize_payload(n, blockheight):
|
||||||
block, tx, out = n['channel'].split('x')
|
block, tx, out = n['channel'].split('x')
|
||||||
payload = hexlify(struct.pack(
|
payload = hexlify(struct.pack(
|
||||||
"!QQL",
|
"!cQQL", b'\x00',
|
||||||
int(block) << 40 | int(tx) << 16 | int(out),
|
int(block) << 40 | int(tx) << 16 | int(out),
|
||||||
int(n['amount_msat']),
|
int(n['amount_msat']),
|
||||||
blockheight + n['delay'])).decode('ASCII')
|
blockheight + n['delay'])).decode('ASCII')
|
||||||
@@ -102,19 +103,15 @@ def deliver(node_id, payload, amt, max_attempts=5, payment_hash=None):
|
|||||||
|
|
||||||
@plugin.async_method('sendmsg')
|
@plugin.async_method('sendmsg')
|
||||||
def sendmsg(node_id, msg, plugin, request, amt=1000, **kwargs):
|
def sendmsg(node_id, msg, plugin, request, amt=1000, **kwargs):
|
||||||
payload = BytesIO()
|
payload = TlvPayload()
|
||||||
varint_encode(34349334, payload)
|
payload.add_field(34349334, msg.encode('UTF-8'))
|
||||||
varint_encode(len(msg), payload)
|
|
||||||
payload.write(msg.encode('UTF-8'))
|
|
||||||
|
|
||||||
# Sign the message:
|
# Sign the message:
|
||||||
sig = plugin.rpc.signmessage(msg)['signature']
|
sig = plugin.rpc.signmessage(msg)['signature']
|
||||||
sig = unhexlify(sig)
|
sig = unhexlify(sig)
|
||||||
varint_encode(34349336, payload)
|
payload.add_field(34349336, sig)
|
||||||
varint_encode(len(sig), payload)
|
|
||||||
payload.write(sig)
|
|
||||||
|
|
||||||
res = deliver(node_id, payload.getbuffer(), amt=amt)
|
res = deliver(node_id, payload.to_bytes(), amt=amt)
|
||||||
request.set_result(res)
|
request.set_result(res)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
236
noise/onion.py
Normal file
236
noise/onion.py
Normal file
@@ -0,0 +1,236 @@
|
|||||||
|
from primitives import varint_decode, varint_encode
|
||||||
|
from io import BytesIO, SEEK_CUR
|
||||||
|
from binascii import hexlify, unhexlify
|
||||||
|
import struct
|
||||||
|
|
||||||
|
|
||||||
|
class OnionPayload(object):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_bytes(cls, b):
|
||||||
|
if isinstance(b, bytes):
|
||||||
|
b = BytesIO(b)
|
||||||
|
|
||||||
|
realm = b.read(1)
|
||||||
|
b.seek(-1, SEEK_CUR)
|
||||||
|
|
||||||
|
if realm == b'\x00':
|
||||||
|
return LegacyOnionPayload.from_bytes(b)
|
||||||
|
elif realm != b'\x01':
|
||||||
|
return TlvPayload.from_bytes(b, skip_length=False)
|
||||||
|
else:
|
||||||
|
raise ValueError("Onion payloads with realm 0x01 are unsupported")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_hex(cls, s):
|
||||||
|
if isinstance(s, str):
|
||||||
|
s = s.encode('ASCII')
|
||||||
|
return cls.from_bytes(bytes(unhexlify(s)))
|
||||||
|
|
||||||
|
def to_bytes(self):
|
||||||
|
raise ValueError("OnionPayload is an abstract class, use "
|
||||||
|
"LegacyOnionPayload or TlvPayload instead")
|
||||||
|
|
||||||
|
def to_hex(self):
|
||||||
|
return hexlify(self.to_bytes()).decode('ASCII')
|
||||||
|
|
||||||
|
|
||||||
|
class LegacyOnionPayload(OnionPayload):
|
||||||
|
|
||||||
|
def __init__(self, amt_to_forward, outgoing_cltv_value,
|
||||||
|
short_channel_id=None, padding=None):
|
||||||
|
assert(padding is None or len(padding) == 12)
|
||||||
|
self.padding = b'\x00'*12 if padding is None else padding
|
||||||
|
|
||||||
|
if isinstance(amt_to_forward, str):
|
||||||
|
self.amt_to_forward = int(amt_to_forward)
|
||||||
|
else:
|
||||||
|
self.amt_to_forward = amt_to_forward
|
||||||
|
|
||||||
|
self.outgoing_cltv_value = outgoing_cltv_value
|
||||||
|
|
||||||
|
if isinstance(short_channel_id, str) and 'x' in short_channel_id:
|
||||||
|
# Convert the short_channel_id from its string representation to its numeric representation
|
||||||
|
block, tx, out = short_channel_id.split('x')
|
||||||
|
num_scid = int(block) << 40 | int(tx) << 16 | int(out)
|
||||||
|
self.short_channel_id = num_scid
|
||||||
|
elif isinstance(short_channel_id, int):
|
||||||
|
self.short_channel_id = short_channel_id
|
||||||
|
else:
|
||||||
|
raise ValueError("short_channel_id format cannot be recognized: {}".format(short_channel_id))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_bytes(cls, b):
|
||||||
|
if isinstance(b, bytes):
|
||||||
|
b = BytesIO(b)
|
||||||
|
|
||||||
|
assert(b.read(1) == b'\x00')
|
||||||
|
|
||||||
|
s, a, o = struct.unpack("!QQL", b.read(20))
|
||||||
|
padding = b.read(12)
|
||||||
|
return LegacyOnionPayload(a, o, s, padding)
|
||||||
|
|
||||||
|
def to_bytes(self, include_realm=True):
|
||||||
|
b = b''
|
||||||
|
if include_realm:
|
||||||
|
b += b'\x00'
|
||||||
|
|
||||||
|
b += struct.pack("!Q", self.short_channel_id)
|
||||||
|
b += struct.pack("!Q", self.amt_to_forward)
|
||||||
|
b += struct.pack("!L", self.outgoing_cltv_value)
|
||||||
|
b += self.padding
|
||||||
|
assert(len(b) == 32 + include_realm)
|
||||||
|
return b
|
||||||
|
|
||||||
|
def to_hex(self, include_realm=True):
|
||||||
|
return hexlify(self.to_bytes(include_realm)).decode('ASCII')
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return ("LegacyOnionPayload[scid={self.short_channel_id}, "
|
||||||
|
"amt_to_forward={self.amt_to_forward}, "
|
||||||
|
"outgoing_cltv={self.outgoing_cltv_value}]").format(self=self)
|
||||||
|
|
||||||
|
|
||||||
|
class TlvPayload(OnionPayload):
|
||||||
|
|
||||||
|
def __init__(self, fields=None):
|
||||||
|
self.fields = [] if fields is None else fields
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_bytes(cls, b, skip_length=False):
|
||||||
|
if isinstance(b, str):
|
||||||
|
b = b.encode('ASCII')
|
||||||
|
if isinstance(b, bytes):
|
||||||
|
b = BytesIO(b)
|
||||||
|
|
||||||
|
if skip_length:
|
||||||
|
# Consume the entire remainder of the buffer.
|
||||||
|
payload_length = len(b.getvalue()) - b.tell()
|
||||||
|
else:
|
||||||
|
payload_length = varint_decode(b)
|
||||||
|
|
||||||
|
instance = TlvPayload()
|
||||||
|
|
||||||
|
start = b.tell()
|
||||||
|
while b.tell() < start + payload_length:
|
||||||
|
typenum = varint_decode(b)
|
||||||
|
if typenum is None:
|
||||||
|
break
|
||||||
|
length = varint_decode(b)
|
||||||
|
if length is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Unable to read length at position {}".format(b.tell())
|
||||||
|
)
|
||||||
|
val = b.read(length)
|
||||||
|
|
||||||
|
# Get the subclass that is the correct interpretation of this
|
||||||
|
# field. Default to the binary field type.
|
||||||
|
c = tlv_types.get(typenum, (TlvField, "unknown"))
|
||||||
|
cls = c[0]
|
||||||
|
field = cls.from_bytes(typenum=typenum, b=val, description=c[1])
|
||||||
|
instance.fields.append(field)
|
||||||
|
|
||||||
|
return instance
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_hex(cls, h):
|
||||||
|
return cls.from_bytes(unhexlify(h))
|
||||||
|
|
||||||
|
def add_field(self, typenum, value):
|
||||||
|
self.fields.append(TlvField(typenum=typenum, value=value))
|
||||||
|
|
||||||
|
def get(self, key, default=None):
|
||||||
|
for f in self.fields:
|
||||||
|
if f.typenum == key:
|
||||||
|
return f
|
||||||
|
return default
|
||||||
|
|
||||||
|
def to_bytes(self):
|
||||||
|
ser = [f.to_bytes() for f in self.fields]
|
||||||
|
b = BytesIO()
|
||||||
|
varint_encode(sum([len(b) for b in ser]), b)
|
||||||
|
for f in ser:
|
||||||
|
b.write(f)
|
||||||
|
return b.getvalue()
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return "TlvPayload[" + ', '.join([str(f) for f in self.fields]) + "]"
|
||||||
|
|
||||||
|
|
||||||
|
class TlvField(object):
|
||||||
|
|
||||||
|
def __init__(self, typenum, value=None, description=None):
|
||||||
|
self.typenum = typenum
|
||||||
|
self.value = value
|
||||||
|
self.description = description
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_bytes(cls, typenum, b, description=None):
|
||||||
|
return TlvField(typenum=typenum, value=b, description=description)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return "TlvField[{description},{num}={hex}]".format(
|
||||||
|
description=self.description,
|
||||||
|
num=self.typenum,
|
||||||
|
hex=hexlify(self.value).decode('ASCII')
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_bytes(self):
|
||||||
|
b = BytesIO()
|
||||||
|
varint_encode(self.typenum, b)
|
||||||
|
varint_encode(len(self.value), b)
|
||||||
|
b.write(self.value)
|
||||||
|
return b.getvalue()
|
||||||
|
|
||||||
|
|
||||||
|
class Tu32Field(TlvField):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Tu64Field(TlvField):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ShortChannelIdField(TlvField):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TextField(TlvField):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_bytes(cls, typenum, b, description=None):
|
||||||
|
val = b.decode('UTF-8')
|
||||||
|
return TextField(typenum, value=val, description=description)
|
||||||
|
|
||||||
|
def to_bytes(self):
|
||||||
|
b = BytesIO()
|
||||||
|
val = self.value.encode('UTF-8')
|
||||||
|
varint_encode(self.typenum, b)
|
||||||
|
varint_encode(len(val), b)
|
||||||
|
b.write(val)
|
||||||
|
return b.getvalue()
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return "TextField[{description},{num}=\"{val}\"]".format(
|
||||||
|
description=self.description,
|
||||||
|
num=self.typenum,
|
||||||
|
val=self.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class HashField(TlvField):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SignatureField(TlvField):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# A mapping of known TLV types
|
||||||
|
tlv_types = {
|
||||||
|
2: (Tu64Field, 'amt_to_forward'),
|
||||||
|
4: (Tu32Field, 'outgoing_cltv_value'),
|
||||||
|
6: (ShortChannelIdField, 'short_channel_id'),
|
||||||
|
34349334: (TextField, 'noise_message_body'),
|
||||||
|
34349336: (SignatureField, 'noise_message_signature'),
|
||||||
|
}
|
||||||
74
noise/primitives.py
Normal file
74
noise/primitives.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
import struct
|
||||||
|
|
||||||
|
|
||||||
|
def varint_encode(i, w):
|
||||||
|
"""Encode an integer `i` into the writer `w`
|
||||||
|
"""
|
||||||
|
if i < 0xFD:
|
||||||
|
w.write(struct.pack("!B", i))
|
||||||
|
elif i <= 0xFFFF:
|
||||||
|
w.write(struct.pack("!BH", 0xFD, i))
|
||||||
|
elif i <= 0xFFFFFFFF:
|
||||||
|
w.write(struct.pack("!BL", 0xFE, i))
|
||||||
|
else:
|
||||||
|
raise ValueError("Integers beyond 0xFFFFFFFF are not allowed in TLVs")
|
||||||
|
|
||||||
|
|
||||||
|
def varint_decode(r):
|
||||||
|
"""Decode an integer from reader `r`
|
||||||
|
"""
|
||||||
|
raw = r.read(1)
|
||||||
|
if len(raw) != 1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
i, = struct.unpack("!B", raw)
|
||||||
|
if i < 0xFD:
|
||||||
|
return i
|
||||||
|
elif i == 0xFD:
|
||||||
|
return struct.unpack("!H", r.read(2))[0]
|
||||||
|
elif i == 0xFE:
|
||||||
|
return struct.unpack("!L", r.read(4))[0]
|
||||||
|
else:
|
||||||
|
raise ValueError("Attempted to unpack")
|
||||||
|
|
||||||
|
|
||||||
|
class ShortChannelId(object):
|
||||||
|
def __init__(self, block, txnum, outnum):
|
||||||
|
self.block = block
|
||||||
|
self.txnum = txnum
|
||||||
|
self.outnum = outnum
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_bytes(cls, b):
|
||||||
|
assert(len(b) == 8)
|
||||||
|
i, = struct.unpack("!Q", b)
|
||||||
|
return cls.from_int(i)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_int(cls, i):
|
||||||
|
block = (i >> 40) & 0xFFFFFF
|
||||||
|
txnum = (i >> 16) & 0xFFFFFF
|
||||||
|
outnum = (i >> 0) & 0xFFFF
|
||||||
|
return cls(block=block, txnum=txnum, outnum=outnum)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_str(self, s):
|
||||||
|
block, txnum, outnum = s.split('x')
|
||||||
|
return ShortChannelId(block=int(block), txnum=int(txnum),
|
||||||
|
outnum=int(outnum))
|
||||||
|
|
||||||
|
def to_int(self):
|
||||||
|
return self.block << 40 | self.txnum << 16 | self.outnum
|
||||||
|
|
||||||
|
def to_bytes(self):
|
||||||
|
return struct.pack("!Q", self.to_int())
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return "{self.block}x{self.txnum}x{self.outnum}".format(self=self)
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return (
|
||||||
|
self.block == other.block and
|
||||||
|
self.txnum == other.txnum and
|
||||||
|
self.outnum == other.outnum
|
||||||
|
)
|
||||||
@@ -5,7 +5,7 @@ from pprint import pprint
|
|||||||
plugin = os.path.join(os.path.dirname(__file__), 'noise.py')
|
plugin = os.path.join(os.path.dirname(__file__), 'noise.py')
|
||||||
|
|
||||||
|
|
||||||
def test_sendmsg(node_factory, executor):
|
def test_sendmsg_success(node_factory, executor):
|
||||||
opts = [{'plugin': plugin}, {}, {'plugin': plugin}]
|
opts = [{'plugin': plugin}, {}, {'plugin': plugin}]
|
||||||
l1, l2, l3 = node_factory.line_graph(3, wait_for_announce=True, opts=opts)
|
l1, l2, l3 = node_factory.line_graph(3, wait_for_announce=True, opts=opts)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user