diff --git a/core/bolt11.py b/core/bolt11.py new file mode 100644 index 0000000..962581d --- /dev/null +++ b/core/bolt11.py @@ -0,0 +1,370 @@ +import hashlib +import re +import time +from binascii import unhexlify +from decimal import Decimal +from typing import List, NamedTuple, Optional + +import bitstring # type: ignore +import secp256k1 +from bech32 import CHARSET, bech32_decode, bech32_encode +from ecdsa import SECP256k1, VerifyingKey # type: ignore +from ecdsa.util import sigdecode_string # type: ignore + + +class Route(NamedTuple): + pubkey: str + short_channel_id: str + base_fee_msat: int + ppm_fee: int + cltv: int + + +class Invoice(object): + payment_hash: str + amount_msat: int = 0 + description: Optional[str] = None + description_hash: Optional[str] = None + payee: Optional[str] = None + date: int + expiry: int = 3600 + secret: Optional[str] = None + route_hints: List[Route] = [] + min_final_cltv_expiry: int = 18 + + +def decode(pr: str) -> Invoice: + """bolt11 decoder, + based on https://github.com/rustyrussell/lightning-payencode/blob/master/lnaddr.py + """ + + hrp, decoded_data = bech32_decode(pr) + if hrp is None or decoded_data is None: + raise ValueError("Bad bech32 checksum") + if not hrp.startswith("ln"): + raise ValueError("Does not start with ln") + + bitarray = _u5_to_bitarray(decoded_data) + + # final signature 65 bytes, split it off. + if len(bitarray) < 65 * 8: + raise ValueError("Too short to contain signature") + + # extract the signature + signature = bitarray[-65 * 8 :].tobytes() + + # the tagged fields as a bitstream + data = bitstring.ConstBitStream(bitarray[: -65 * 8]) + + # build the invoice object + invoice = Invoice() + + # decode the amount from the hrp + m = re.search(r"[^\d]+", hrp[2:]) + if m: + amountstr = hrp[2 + m.end() :] + if amountstr != "": + invoice.amount_msat = _unshorten_amount(amountstr) + + # pull out date + invoice.date = data.read(35).uint + + while data.pos != data.len: + tag, tagdata, data = _pull_tagged(data) + data_length = len(tagdata) / 5 + + if tag == "d": + invoice.description = _trim_to_bytes(tagdata).decode("utf-8") + elif tag == "h" and data_length == 52: + invoice.description_hash = _trim_to_bytes(tagdata).hex() + elif tag == "p" and data_length == 52: + invoice.payment_hash = _trim_to_bytes(tagdata).hex() + elif tag == "x": + invoice.expiry = tagdata.uint + elif tag == "n": + invoice.payee = _trim_to_bytes(tagdata).hex() + # this won't work in most cases, we must extract the payee + # from the signature + elif tag == "s": + invoice.secret = _trim_to_bytes(tagdata).hex() + elif tag == "r": + s = bitstring.ConstBitStream(tagdata) + while s.pos + 264 + 64 + 32 + 32 + 16 < s.len: + route = Route( + pubkey=s.read(264).tobytes().hex(), + short_channel_id=_readable_scid(s.read(64).intbe), + base_fee_msat=s.read(32).intbe, + ppm_fee=s.read(32).intbe, + cltv=s.read(16).intbe, + ) + invoice.route_hints.append(route) + + # BOLT #11: + # A reader MUST check that the `signature` is valid (see the `n` tagged + # field specified below). + # A reader MUST use the `n` field to validate the signature instead of + # performing signature recovery if a valid `n` field is provided. + message = bytearray([ord(c) for c in hrp]) + data.tobytes() + sig = signature[0:64] + if invoice.payee: + key = VerifyingKey.from_string(unhexlify(invoice.payee), curve=SECP256k1) + key.verify(sig, message, hashlib.sha256, sigdecode=sigdecode_string) + else: + keys = VerifyingKey.from_public_key_recovery( + sig, message, SECP256k1, hashlib.sha256 + ) + signaling_byte = signature[64] + key = keys[int(signaling_byte)] + invoice.payee = key.to_string("compressed").hex() + + return invoice + + +def encode(options): + """Convert options into LnAddr and pass it to the encoder""" + addr = LnAddr() + addr.currency = options["currency"] + addr.fallback = options["fallback"] if options["fallback"] else None + if options["amount"]: + addr.amount = options["amount"] + if options["timestamp"]: + addr.date = int(options["timestamp"]) + + addr.paymenthash = unhexlify(options["paymenthash"]) + + if options["description"]: + addr.tags.append(("d", options["description"])) + if options["description_hash"]: + addr.tags.append(("h", options["description_hash"])) + if options["expires"]: + addr.tags.append(("x", options["expires"])) + + if options["fallback"]: + addr.tags.append(("f", options["fallback"])) + if options["route"]: + for r in options["route"]: + splits = r.split("/") + route = [] + while len(splits) >= 5: + route.append( + ( + unhexlify(splits[0]), + unhexlify(splits[1]), + int(splits[2]), + int(splits[3]), + int(splits[4]), + ) + ) + splits = splits[5:] + assert len(splits) == 0 + addr.tags.append(("r", route)) + return lnencode(addr, options["privkey"]) + + +def lnencode(addr, privkey): + if addr.amount: + amount = Decimal(str(addr.amount)) + # We can only send down to millisatoshi. + if amount * 10**12 % 10: + raise ValueError( + "Cannot encode {}: too many decimal places".format(addr.amount) + ) + + amount = addr.currency + shorten_amount(amount) + else: + amount = addr.currency if addr.currency else "" + + hrp = "ln" + amount + "0n" + + # Start with the timestamp + data = bitstring.pack("uint:35", addr.date) + + # Payment hash + data += tagged_bytes("p", addr.paymenthash) + tags_set = set() + + for k, v in addr.tags: + + # BOLT #11: + # + # A writer MUST NOT include more than one `d`, `h`, `n` or `x` fields, + if k in ("d", "h", "n", "x"): + if k in tags_set: + raise ValueError("Duplicate '{}' tag".format(k)) + + if k == "r": + route = bitstring.BitArray() + for step in v: + pubkey, channel, feebase, feerate, cltv = step + route.append( + bitstring.BitArray(pubkey) + + bitstring.BitArray(channel) + + bitstring.pack("intbe:32", feebase) + + bitstring.pack("intbe:32", feerate) + + bitstring.pack("intbe:16", cltv) + ) + data += tagged("r", route) + elif k == "f": + data += encode_fallback(v, addr.currency) + elif k == "d": + data += tagged_bytes("d", v.encode()) + elif k == "x": + # Get minimal length by trimming leading 5 bits at a time. + expirybits = bitstring.pack("intbe:64", v)[4:64] + while expirybits.startswith("0b00000"): + expirybits = expirybits[5:] + data += tagged("x", expirybits) + elif k == "h": + data += tagged_bytes("h", v) + elif k == "n": + data += tagged_bytes("n", v) + else: + # FIXME: Support unknown tags? + raise ValueError("Unknown tag {}".format(k)) + + tags_set.add(k) + + # BOLT #11: + # + # A writer MUST include either a `d` or `h` field, and MUST NOT include + # both. + if "d" in tags_set and "h" in tags_set: + raise ValueError("Cannot include both 'd' and 'h'") + if not "d" in tags_set and not "h" in tags_set: + raise ValueError("Must include either 'd' or 'h'") + + # We actually sign the hrp, then data (padded to 8 bits with zeroes). + privkey = secp256k1.PrivateKey(bytes(unhexlify(privkey))) + sig = privkey.ecdsa_sign_recoverable( + bytearray([ord(c) for c in hrp]) + data.tobytes() + ) + # This doesn't actually serialize, but returns a pair of values :( + sig, recid = privkey.ecdsa_recoverable_serialize(sig) + data += bytes(sig) + bytes([recid]) + + return bech32_encode(hrp, bitarray_to_u5(data)) + + +class LnAddr(object): + def __init__( + self, paymenthash=None, amount=None, currency="bc", tags=None, date=None + ): + self.date = int(time.time()) if not date else int(date) + self.tags = [] if not tags else tags + self.unknown_tags = [] + self.paymenthash = paymenthash + self.signature = None + self.pubkey = None + self.currency = currency + self.amount = amount + + def __str__(self): + return "LnAddr[{}, amount={}{} tags=[{}]]".format( + hexlify(self.pubkey.serialize()).decode("utf-8"), + self.amount, + self.currency, + ", ".join([k + "=" + str(v) for k, v in self.tags]), + ) + + +def shorten_amount(amount): + """Given an amount in bitcoin, shorten it""" + # Convert to pico initially + amount = int(amount * 10**12) + units = ["p", "n", "u", "m", ""] + for unit in units: + if amount % 1000 == 0: + amount //= 1000 + else: + break + return str(amount) + unit + + +def _unshorten_amount(amount: str) -> int: + """Given a shortened amount, return millisatoshis""" + # BOLT #11: + # The following `multiplier` letters are defined: + # + # * `m` (milli): multiply by 0.001 + # * `u` (micro): multiply by 0.000001 + # * `n` (nano): multiply by 0.000000001 + # * `p` (pico): multiply by 0.000000000001 + units = {"p": 10**12, "n": 10**9, "u": 10**6, "m": 10**3} + unit = str(amount)[-1] + + # BOLT #11: + # A reader SHOULD fail if `amount` contains a non-digit, or is followed by + # anything except a `multiplier` in the table above. + if not re.fullmatch(r"\d+[pnum]?", str(amount)): + raise ValueError("Invalid amount '{}'".format(amount)) + + if unit in units: + return int(int(amount[:-1]) * 100_000_000_000 / units[unit]) + else: + return int(amount) * 100_000_000_000 + + +def _pull_tagged(stream): + tag = stream.read(5).uint + length = stream.read(5).uint * 32 + stream.read(5).uint + return (CHARSET[tag], stream.read(length * 5), stream) + + +def is_p2pkh(currency, prefix): + return prefix == base58_prefix_map[currency][0] + + +def is_p2sh(currency, prefix): + return prefix == base58_prefix_map[currency][1] + + +# Tagged field containing BitArray +def tagged(char, l): + # Tagged fields need to be zero-padded to 5 bits. + while l.len % 5 != 0: + l.append("0b0") + return ( + bitstring.pack( + "uint:5, uint:5, uint:5", + CHARSET.find(char), + (l.len / 5) / 32, + (l.len / 5) % 32, + ) + + l + ) + + +def tagged_bytes(char, l): + return tagged(char, bitstring.BitArray(l)) + + +def _trim_to_bytes(barr): + # Adds a byte if necessary. + b = barr.tobytes() + if barr.len % 8 != 0: + return b[:-1] + return b + + +def _readable_scid(short_channel_id: int) -> str: + return "{blockheight}x{transactionindex}x{outputindex}".format( + blockheight=((short_channel_id >> 40) & 0xFFFFFF), + transactionindex=((short_channel_id >> 16) & 0xFFFFFF), + outputindex=(short_channel_id & 0xFFFF), + ) + + +def _u5_to_bitarray(arr: List[int]) -> bitstring.BitArray: + ret = bitstring.BitArray() + for a in arr: + ret += bitstring.pack("uint:5", a) + return ret + + +def bitarray_to_u5(barr): + assert barr.len % 5 == 0 + ret = [] + s = bitstring.ConstBitStream(barr) + while s.pos != s.len: + ret.append(s.read(5).uint) + return ret