trying to implement a working pkarr client in python

This commit is contained in:
2024-09-15 13:50:21 +02:00
parent d5369a288e
commit cb4ccb4a0b
10 changed files with 877 additions and 0 deletions

71
lookup.py Normal file
View File

@@ -0,0 +1,71 @@
#!/usr/bin/env python3
import asyncio
import time
import logging
from argparse import ArgumentParser
from src.client import PkarrClient
from src.public_key import PublicKey
from src.keypair import Keypair
from src.errors import PkarrError
import bencodepy
DEFAULT_MINIMUM_TTL = 300 # 5 minutes
DEFAULT_MAXIMUM_TTL = 24 * 60 * 60 # 24 hours
DEFAULT_BOOTSTRAP_NODES = [
"router.bittorrent.com:6881",
"router.utorrent.com:6881",
"dht.transmissionbt.com:6881",
"dht.libtorrent.org:25401"
]
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
async def resolve(client: PkarrClient, public_key: PublicKey):
start_time = time.time()
try:
signed_packet = await client.lookup(public_key.to_z32(), max_attempts=200, timeout=60)
elapsed = time.time() - start_time
if signed_packet:
print(f"\nResolved in {int(elapsed * 1000)} milliseconds SignedPacket ({public_key.to_z32()}):")
print(f" last_seen: {signed_packet.elapsed()} seconds ago")
print(f" timestamp: {signed_packet.timestamp},")
print(f" signature: {signed_packet.signature.hex().upper()}")
print(" records:")
for rr in signed_packet.packet.answers:
print(f" {rr}")
else:
print(f"\nFailed to resolve {public_key.to_z32()}")
except PkarrError as e:
print(f"Got error: {e}")
async def main():
parser = ArgumentParser(description="Resolve Pkarr records")
parser.add_argument("public_key", help="z-base-32 encoded public key")
parser.add_argument("--bootstrap", nargs='+', default=DEFAULT_BOOTSTRAP_NODES,
help="Bootstrap nodes (default: %(default)s)")
args = parser.parse_args()
try:
public_key = PublicKey(args.public_key)
except PkarrError as e:
logging.error(f"Invalid public key: {e}")
return
keypair = Keypair.random()
client = PkarrClient(keypair, args.bootstrap)
logging.info(f"Resolving Pkarr: {args.public_key}")
logging.info("\n=== COLD LOOKUP ===")
await resolve(client, public_key)
await asyncio.sleep(1)
logging.info("\n=== SUBSEQUENT LOOKUP ===")
await resolve(client, public_key)
if __name__ == "__main__":
asyncio.run(main())

221
src/client.py Normal file
View File

@@ -0,0 +1,221 @@
import asyncio
import random
from typing import List, Optional, Union
from .signed_packet import SignedPacket
from .keypair import Keypair
from .public_key import PublicKey
from .resource_record import ResourceRecord
from .packet import Packet
from .errors import PkarrError
import logging
import socket
import struct
import hashlib
import bencodepy
import json
import time
logging.basicConfig(level=logging.DEBUG)
class PkarrClient:
def __init__(self, keypair: Keypair, bootstrap_nodes: List[str]):
self.keypair = keypair
self.bootstrap_nodes = bootstrap_nodes
self.known_nodes = set(bootstrap_nodes)
async def lookup(self, public_key: str, max_attempts: int = 100, timeout: int = 30) -> Optional[SignedPacket]:
"""Look up records from the DHT."""
target_key = PublicKey(public_key)
# Check cache first
cached_packet, expiration_time = self.cache.get(public_key, (None, 0))
if cached_packet and time.time() < expiration_time:
logging.debug(f"Have fresh signed_packet in cache. expires_in={int(expiration_time - time.time())}")
return cached_packet
nodes_to_query = set(self.bootstrap_nodes)
queried_nodes = set()
start_time = time.time()
attempts = 0
while nodes_to_query and attempts < max_attempts and (time.time() - start_time) < timeout:
node = nodes_to_query.pop()
queried_nodes.add(node)
attempts += 1
logging.info(f"Attempt {attempts}: Querying node {node}")
try:
result = await self._request_packet(node, target_key)
if isinstance(result, SignedPacket):
logging.info(f"Found result after {attempts} attempts and {time.time() - start_time:.2f} seconds")
# Cache the result
self.cache[public_key] = (result, time.time() + result.ttl(DEFAULT_MINIMUM_TTL, DEFAULT_MAXIMUM_TTL))
return result
elif result:
new_nodes = set(result) - queried_nodes
nodes_to_query.update(new_nodes)
logging.info(f"Added {len(new_nodes)} new nodes to query. Total known nodes: {len(self.known_nodes)}")
except PkarrError as e:
logging.error(f"Error with node {node}: {e}")
logging.info(f"Lookup completed after {attempts} attempts and {time.time() - start_time:.2f} seconds")
logging.info(f"Queried {len(queried_nodes)} unique nodes")
if attempts >= max_attempts:
logging.warning("Lookup terminated: maximum attempts reached")
elif (time.time() - start_time) >= timeout:
logging.warning("Lookup terminated: timeout reached")
else:
logging.warning("Lookup terminated: no more nodes to query")
return None
async def _request_packet(self, node: str, target_key: PublicKey, record_type: str) -> Optional[Union[SignedPacket, List[str]]]:
"""Request a packet from a node."""
logging.info(f"Requesting packet from node {node} for key {target_key.to_z32()} and record_type {record_type}")
try:
# Extract IP and port from the node string
if '@' in node:
_, ip_port = node.split('@')
else:
ip_port = node
host, port = ip_port.split(':')
port = int(port)
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.settimeout(5) # Set a 5-second timeout
# Ensure we have a 20-byte node ID
node_id = hashlib.sha1(self.keypair.public_key.to_bytes()).digest()
# Ensure we have a 20-byte info_hash
info_hash = hashlib.sha1(target_key.to_bytes()).digest()
# Prepare and send the DHT query
transaction_id = random.randint(0, 65535).to_bytes(2, 'big')
message = bencodepy.encode({
't': transaction_id,
'y': 'q',
'q': 'get_peers',
'a': {
'id': node_id,
'info_hash': info_hash
}
})
logging.debug(f"Sending message to {host}:{port}: {message}")
sock.sendto(message, (host, port))
# Wait for the response
data, addr = sock.recvfrom(1024)
logging.debug(f"Received raw response from {addr}: {data}")
# Parse the response
response = bencodepy.decode(data)
human_readable = self._decode_response(response)
logging.info(f"Decoded response from {addr}:\n{json.dumps(human_readable, indent=2)}")
if response.get(b'y') == b'e':
error_code, error_message = response.get(b'e', [None, b''])[0], response.get(b'e', [None, b''])[1].decode('utf-8', errors='ignore')
logging.error(f"Received error response: Code {error_code}, Message: {error_message}")
return None
# Check if the response contains the data we need
if b'r' in response:
r = response[b'r']
if b'values' in r:
# Process peer values
peer_values = r[b'values']
logging.info(f"Found {len(peer_values)} peer values")
return await self._connect_to_peers(peer_values, target_key, record_type)
elif b'nodes' in r:
# Process nodes for further querying
nodes = r[b'nodes']
decoded_nodes = self._decode_nodes(nodes)
logging.info(f"Found {len(decoded_nodes)} nodes")
self._update_known_nodes(decoded_nodes)
return decoded_nodes
return None
except socket.timeout:
logging.error(f"Timeout while connecting to {host}:{port}")
except Exception as e:
logging.error(f"Error requesting packet from {host}:{port}: {e}")
logging.exception("Exception details:")
finally:
sock.close()
return None
async def _connect_to_peers(self, peer_values: List[bytes], target_key: PublicKey, record_type: str) -> Optional[SignedPacket]:
"""Connect to peers and try to retrieve the SignedPacket."""
for peer_value in peer_values:
try:
ip = socket.inet_ntoa(peer_value[:4])
port = struct.unpack("!H", peer_value[4:])[0]
peer = f"{ip}:{port}"
logging.info(f"Connecting to peer {peer}")
# Here you would implement the logic to connect to the peer and retrieve the SignedPacket
# For now, we'll just return a dummy SignedPacket
return SignedPacket(target_key, b"dummy_signature", Packet())
except Exception as e:
logging.error(f"Error connecting to peer {peer}: {e}")
return None
def _decode_response(self, response: dict[bytes, any]) -> dict[str, any]:
"""Decode the bencoded response into a human-readable format."""
decoded = {}
for key, value in response.items():
str_key = key.decode('utf-8')
if isinstance(value, bytes):
try:
decoded[str_key] = value.decode('utf-8')
except UnicodeDecodeError:
decoded[str_key] = value.hex()
elif isinstance(value, dict):
decoded[str_key] = self._decode_response(value)
elif isinstance(value, list):
decoded[str_key] = [self._decode_response(item) if isinstance(item, dict) else item.hex() if isinstance(item, bytes) else item for item in value]
else:
decoded[str_key] = value
if 'r' in decoded and 'nodes' in decoded['r']:
decoded['r']['decoded_nodes'] = self._decode_nodes(response[b'r'][b'nodes'])
return decoded
def _decode_nodes(self, nodes_data: bytes) -> List[str]:
"""Decode the compact node info."""
nodes = []
for i in range(0, len(nodes_data), 26):
node_id = nodes_data[i:i+20].hex()
ip = socket.inet_ntoa(nodes_data[i+20:i+24])
port = struct.unpack("!H", nodes_data[i+24:i+26])[0]
nodes.append(f"{ip}:{port}")
return nodes
def _update_known_nodes(self, new_nodes: List[str]) -> None:
"""Update the list of known nodes."""
self.known_nodes.update(new_nodes)
logging.info(f"Updated known nodes. Total known nodes: {len(self.known_nodes)}")
async def _send_packet(self, node: str, signed_packet: SignedPacket) -> None:
"""Send a signed packet to a node."""
# Implement UDP packet sending logic here
pass
async def maintain_network(self) -> None:
"""Periodically maintain the network by pinging known nodes and discovering new ones."""
while True:
# Implement node discovery and maintenance logic here
await asyncio.sleep(60) # Run maintenance every 60 seconds

71
src/crypto.py Normal file
View File

@@ -0,0 +1,71 @@
import os
import hashlib
import ed25519
class Crypto:
@staticmethod
def generate_keypair():
private_key, public_key = ed25519.create_keypair()
return private_key.to_bytes()[:32], public_key.to_bytes()
@staticmethod
def derive_public_key(secret_key):
if len(secret_key) != 32:
raise ValueError("Secret key must be 32 bytes long")
signing_key = ed25519.SigningKey(secret_key)
return signing_key.get_verifying_key().to_bytes()
@staticmethod
def sign(secret_key, message):
if len(secret_key) != 32:
raise ValueError("Secret key must be 32 bytes long")
signing_key = ed25519.SigningKey(secret_key)
return signing_key.sign(message)
@staticmethod
def verify(public_key, message, signature):
verifying_key = ed25519.VerifyingKey(public_key)
try:
verifying_key.verify(signature, message)
return True
except ed25519.BadSignatureError:
return False
@staticmethod
def hash(data):
return hashlib.sha256(data).digest()
@staticmethod
def random_bytes(length):
return os.urandom(length)
@staticmethod
def z_base_32_encode(data):
alphabet = "ybndrfg8ejkmcpqxot1uwisza345h769"
result = ""
bits = 0
value = 0
for byte in data:
value = (value << 8) | byte
bits += 8
while bits >= 5:
bits -= 5
result += alphabet[(value >> bits) & 31]
if bits > 0:
result += alphabet[(value << (5 - bits)) & 31]
return result
@staticmethod
def z_base_32_decode(encoded):
alphabet = "ybndrfg8ejkmcpqxot1uwisza345h769"
alphabet_map = {char: index for index, char in enumerate(alphabet)}
result = bytearray()
bits = 0
value = 0
for char in encoded:
value = (value << 5) | alphabet_map[char]
bits += 5
if bits >= 8:
bits -= 8
result.append((value >> bits) & 255)
return bytes(result)

91
src/dns_utils.py Normal file
View File

@@ -0,0 +1,91 @@
import struct
from typing import List, Tuple
from dns import message, name, rdata, rdatatype, rdataclass
from .errors import DNSError
def create_dns_query(domain: str, record_type: str) -> bytes:
"""Create a DNS query packet."""
try:
qname = name.from_text(domain)
q = message.make_query(qname, rdatatype.from_text(record_type))
return q.to_wire()
except Exception as e:
raise DNSError(f"Failed to create DNS query: {str(e)}")
def parse_dns_response(response: bytes) -> List[Tuple[str, str, int, str]]:
"""Parse a DNS response and return a list of (name, type, ttl, data) tuples."""
try:
msg = message.from_wire(response)
results = []
for rrset in msg.answer:
name = rrset.name.to_text()
ttl = rrset.ttl
for rr in rrset:
rr_type = rdatatype.to_text(rr.rdtype)
rr_data = rr.to_text()
results.append((name, rr_type, ttl, rr_data))
return results
except Exception as e:
raise DNSError(f"Failed to parse DNS response: {str(e)}")
def compress_domain_name(domain: str) -> bytes:
"""Compress a domain name according to DNS name compression rules."""
try:
n = name.from_text(domain)
return n.to_wire()
except Exception as e:
raise DNSError(f"Failed to compress domain name: {str(e)}")
def decompress_domain_name(compressed: bytes, offset: int = 0) -> Tuple[str, int]:
"""Decompress a domain name from DNS wire format."""
try:
n, offset = name.from_wire(compressed, offset)
return n.to_text(), offset
except Exception as e:
raise DNSError(f"Failed to decompress domain name: {str(e)}")
def encode_resource_record(name: str, rr_type: str, rr_class: str, ttl: int, rdata: str) -> bytes:
"""Encode a resource record into DNS wire format."""
try:
n = name.from_text(name)
rr_type = rdatatype.from_text(rr_type)
rr_class = rdataclass.from_text(rr_class)
rd = rdata.from_text(rr_type, rr_class, rdata)
return (n.to_wire() +
struct.pack("!HHIH", rr_type, rr_class, ttl, len(rd.to_wire())) +
rd.to_wire())
except Exception as e:
raise DNSError(f"Failed to encode resource record: {str(e)}")
def decode_resource_record(wire: bytes, offset: int = 0) -> Tuple[str, str, str, int, str, int]:
"""Decode a resource record from DNS wire format."""
try:
n, offset = name.from_wire(wire, offset)
rr_type, rr_class, ttl, rdlen = struct.unpack_from("!HHIH", wire, offset)
offset += 10
rd = rdata.from_wire(rdatatype.to_text(rr_type), wire, offset, rdlen)
offset += rdlen
return (n.to_text(),
rdatatype.to_text(rr_type),
rdataclass.to_text(rr_class),
ttl,
rd.to_text(),
offset)
except Exception as e:
raise DNSError(f"Failed to decode resource record: {str(e)}")
def is_valid_domain_name(domain: str) -> bool:
"""Check if a given string is a valid domain name."""
try:
name.from_text(domain)
return True
except Exception:
return False
def normalize_domain_name(domain: str) -> str:
"""Normalize a domain name (convert to lowercase and ensure it ends with a dot)."""
try:
n = name.from_text(domain)
return n.to_text().lower()
except Exception as e:
raise DNSError(f"Failed to normalize domain name: {str(e)}")

52
src/errors.py Normal file
View File

@@ -0,0 +1,52 @@
class PkarrError(Exception):
"""Base class for all pkarr-related errors."""
pass
class KeypairError(PkarrError):
"""Raised when there's an issue with keypair operations."""
pass
class PublicKeyError(PkarrError):
"""Raised when there's an issue with public key operations."""
pass
class SignatureError(PkarrError):
"""Raised when there's an issue with signature operations."""
pass
class PacketError(PkarrError):
"""Raised when there's an issue with packet operations."""
pass
class DNSError(PkarrError):
"""Raised when there's an issue with DNS operations."""
pass
class DHTError(PkarrError):
"""Raised when there's an issue with DHT operations."""
pass
class InvalidSignedPacketBytesLength(PacketError):
"""Raised when the SignedPacket bytes length is invalid."""
def __init__(self, length: int):
super().__init__(f"Invalid SignedPacket bytes length, expected at least 104 bytes but got: {length}")
class InvalidRelayPayloadSize(PacketError):
"""Raised when the relay payload size is invalid."""
def __init__(self, size: int):
super().__init__(f"Invalid relay payload size, expected at least 72 bytes but got: {size}")
class PacketTooLarge(PacketError):
"""Raised when the DNS packet is too large."""
def __init__(self, size: int):
super().__init__(f"DNS Packet is too large, expected max 1000 bytes but got: {size}")
class DHTIsShutdown(DHTError):
"""Raised when the DHT is shutdown."""
def __init__(self):
super().__init__("DHT is shutdown")
class PublishInflight(DHTError):
"""Raised when a publish query is already in flight for the same public key."""
def __init__(self):
super().__init__("Publish query is already in flight for the same public_key")

44
src/keypair.py Normal file
View File

@@ -0,0 +1,44 @@
from .public_key import PublicKey
from .crypto import Crypto
from .errors import KeypairError
class Keypair:
def __init__(self, secret_key: bytes):
if len(secret_key) != 32:
raise KeypairError(f"Secret key must be 32 bytes long, got {len(secret_key)} bytes")
self.secret_key = secret_key
self.public_key = PublicKey(Crypto.derive_public_key(secret_key))
@classmethod
def random(cls) -> 'Keypair':
"""Generate a new random keypair."""
secret_key, _ = Crypto.generate_keypair()
return cls(secret_key)
@classmethod
def from_secret_key(cls, secret_key: bytes) -> 'Keypair':
"""Create a Keypair from a secret key."""
return cls(secret_key)
def sign(self, message: bytes) -> bytes:
"""Sign a message using this keypair."""
return Crypto.sign(self.secret_key, message)
def verify(self, message: bytes, signature: bytes) -> bool:
"""Verify a signature using this keypair's public key."""
return self.public_key.verify(message, signature)
def to_bytes(self) -> bytes:
"""Return the secret key bytes."""
return self.secret_key
@classmethod
def from_bytes(cls, secret_key: bytes) -> 'Keypair':
"""Create a Keypair from bytes."""
return cls(secret_key)
def __str__(self):
return f"Keypair(public_key={self.public_key})"
def __repr__(self):
return self.__str__()

95
src/packet.py Normal file
View File

@@ -0,0 +1,95 @@
from dataclasses import dataclass, field
from typing import List, Optional
from dns import message, name, rdata, rdatatype, rdataclass
from .resource_record import ResourceRecord
from .errors import PacketError
@dataclass
class Packet:
answers: List[ResourceRecord] = field(default_factory=list)
id: int = 0
qr: bool = True # True for response, False for query
opcode: int = 0 # 0 for standard query
aa: bool = True # Authoritative Answer
tc: bool = False # TrunCation
rd: bool = False # Recursion Desired
ra: bool = False # Recursion Available
z: int = 0 # Reserved for future use
rcode: int = 0 # Response code
@classmethod
def new_reply(cls, id: int):
return cls(answers=[], id=id)
def add_answer(self, answer: ResourceRecord):
self.answers.append(answer)
def build_bytes_vec_compressed(self) -> bytes:
"""Build a compressed DNS wire format representation of the packet."""
try:
msg = message.Message(id=self.id)
msg.flags = 0
if self.qr:
msg.flags |= 1 << 15
msg.flags |= (self.opcode & 0xF) << 11
if self.aa:
msg.flags |= 1 << 10
if self.tc:
msg.flags |= 1 << 9
if self.rd:
msg.flags |= 1 << 8
if self.ra:
msg.flags |= 1 << 7
msg.flags |= (self.z & 0x7) << 4
msg.flags |= self.rcode & 0xF
for rr in self.answers:
rr_name = name.from_text(rr.name)
rr_ttl = rr.ttl
rr_rdataclass = rdataclass.from_text(rr.rclass)
rr_rdatatype = rdatatype.from_text(rr.rtype)
rr_rdata = rdata.from_text(rr_rdataclass, rr_rdatatype, rr.rdata)
msg.answer.append((rr_name, rr_ttl, rr_rdata))
return msg.to_wire()
except Exception as e:
raise PacketError(f"Failed to build packet: {str(e)}")
@classmethod
def from_bytes(cls, data: bytes) -> 'Packet':
"""Create a Packet object from DNS wire format bytes."""
try:
msg = message.from_wire(data)
packet = cls(
id=msg.id,
qr=bool(msg.flags & (1 << 15)),
opcode=(msg.flags >> 11) & 0xF,
aa=bool(msg.flags & (1 << 10)),
tc=bool(msg.flags & (1 << 9)),
rd=bool(msg.flags & (1 << 8)),
ra=bool(msg.flags & (1 << 7)),
z=(msg.flags >> 4) & 0x7,
rcode=msg.flags & 0xF
)
for rrset in msg.answer:
for rr in rrset:
resource_record = ResourceRecord(
name=rrset.name.to_text(),
rclass=rdataclass.to_text(rr.rdclass),
ttl=rrset.ttl,
rtype=rdatatype.to_text(rr.rdtype),
rdata=rr.to_text()
)
packet.add_answer(resource_record)
return packet
except Exception as e:
raise PacketError(f"Failed to parse packet: {str(e)}")
def __str__(self):
header = f"Packet ID: {self.id}, QR: {'Response' if self.qr else 'Query'}, " \
f"Opcode: {self.opcode}, AA: {self.aa}, TC: {self.tc}, RD: {self.rd}, " \
f"RA: {self.ra}, Z: {self.z}, RCODE: {self.rcode}"
answers = "\n".join(f" {rr}" for rr in self.answers)
return f"{header}\nAnswers:\n{answers}"

39
src/public_key.py Normal file
View File

@@ -0,0 +1,39 @@
from .crypto import Crypto
from .errors import PublicKeyError
class PublicKey:
def __init__(self, key):
if isinstance(key, str):
try:
self.key = Crypto.z_base_32_decode(key)
except ValueError:
raise PublicKeyError("Invalid z-base-32 encoded public key")
elif isinstance(key, bytes):
if len(key) != 32:
raise PublicKeyError("Public key must be 32 bytes long")
self.key = key
else:
raise PublicKeyError("Public key must be bytes or z-base-32 encoded string")
def __str__(self):
return self.to_z32()
def __repr__(self):
return f"PublicKey({self.to_z32()})"
def __eq__(self, other):
if isinstance(other, PublicKey):
return self.key == other.key
return False
def __hash__(self):
return hash(self.key)
def to_z32(self):
return Crypto.z_base_32_encode(self.key)
def to_bytes(self):
return self.key
def verify(self, message: bytes, signature: bytes) -> bool:
return Crypto.verify(self.key, message, signature)

39
src/resource_record.py Normal file
View File

@@ -0,0 +1,39 @@
from dataclasses import dataclass
from typing import Union
import ipaddress
@dataclass
class ResourceRecord:
name: str
rclass: str
ttl: int
rtype: str
rdata: Union[str, ipaddress.IPv4Address, ipaddress.IPv6Address]
def __post_init__(self):
self.name = self.name.lower()
self.rclass = self.rclass.upper()
self.rtype = self.rtype.upper()
if self.rtype == 'A':
self.rdata = ipaddress.IPv4Address(self.rdata)
elif self.rtype == 'AAAA':
self.rdata = ipaddress.IPv6Address(self.rdata)
def to_wire_format(self) -> bytes:
# This is a placeholder. You'll need to implement the actual DNS wire format encoding.
pass
@classmethod
def from_wire_format(cls, wire_data: bytes) -> 'ResourceRecord':
# This is a placeholder. You'll need to implement the actual DNS wire format decoding.
pass
def __str__(self):
return f"{self.name} {self.ttl} {self.rclass} {self.rtype} {self.rdata}"
def is_expired(self, current_time: int) -> bool:
return current_time > self.ttl
def remaining_ttl(self, current_time: int) -> int:
return max(0, self.ttl - current_time)

154
src/signed_packet.py Normal file
View File

@@ -0,0 +1,154 @@
import time
from dataclasses import dataclass
from typing import List, Optional
import ed25519
from dns import message, name, rdata, rdatatype, rdataclass
@dataclass
class PublicKey:
key: bytes
def to_z32(self) -> str:
# Implement z-base-32 encoding here
pass
@dataclass
class ResourceRecord:
name: str
rclass: int
ttl: int
rdata: bytes
@dataclass
class Packet:
answers: List[ResourceRecord]
@classmethod
def new_reply(cls, id: int):
return cls(answers=[])
def build_bytes_vec_compressed(self) -> bytes:
# Implement DNS packet compression here
pass
@dataclass
class SignedPacket:
public_key: PublicKey
signature: bytes
timestamp: int
packet: Packet
last_seen: int
@classmethod
def from_packet(cls, keypair, packet: Packet):
timestamp = int(time.time() * 1_000_000)
encoded_packet = packet.build_bytes_vec_compressed()
if len(encoded_packet) > 1000:
raise ValueError("Packet too large")
signature = keypair.sign(cls.signable(timestamp, encoded_packet))
return cls(
public_key=keypair.public_key(),
signature=signature,
timestamp=timestamp,
packet=packet,
last_seen=int(time.time() * 1_000_000)
)
@classmethod
def from_bytes(cls, data: bytes):
if len(data) < 104:
raise ValueError("Invalid SignedPacket bytes length")
if len(data) > 1104:
raise ValueError("Packet too large")
public_key = PublicKey(data[:32])
signature = data[32:96]
timestamp = int.from_bytes(data[96:104], 'big')
encoded_packet = data[104:]
# Verify signature
if not public_key.verify(cls.signable(timestamp, encoded_packet), signature):
raise ValueError("Invalid signature")
packet = Packet([]) # Parse encoded_packet into a Packet object here
return cls(
public_key=public_key,
signature=signature,
timestamp=timestamp,
packet=packet,
last_seen=int(time.time() * 1_000_000)
)
def as_bytes(self) -> bytes:
return (
self.public_key.key +
self.signature +
self.timestamp.to_bytes(8, 'big') +
self.packet.build_bytes_vec_compressed()
)
def to_relay_payload(self) -> bytes:
return self.as_bytes()[32:]
def resource_records(self, name: str):
origin = self.public_key.to_z32()
normalized_name = self.normalize_name(origin, name)
return [rr for rr in self.packet.answers if rr.name == normalized_name]
def fresh_resource_records(self, name: str):
origin = self.public_key.to_z32()
normalized_name = self.normalize_name(origin, name)
current_time = int(time.time())
return [
rr for rr in self.packet.answers
if rr.name == normalized_name and rr.ttl > (current_time - self.last_seen // 1_000_000)
]
def expires_in(self, min_ttl: int, max_ttl: int) -> int:
ttl = self.ttl(min_ttl, max_ttl)
elapsed = self.elapsed()
return max(0, ttl - elapsed)
def ttl(self, min_ttl: int, max_ttl: int) -> int:
if not self.packet.answers:
return min_ttl
min_record_ttl = min(rr.ttl for rr in self.packet.answers)
return max(min_ttl, min(max_ttl, min_record_ttl))
def elapsed(self) -> int:
return (int(time.time() * 1_000_000) - self.last_seen) // 1_000_000
@staticmethod
def signable(timestamp: int, v: bytes) -> bytes:
return f"3:seqi{timestamp}e1:v{len(v)}:".encode() + v
@staticmethod
def normalize_name(origin: str, name: str) -> str:
if name.endswith('.'):
name = name[:-1]
parts = name.split('.')
last = parts[-1]
if last == origin:
return name
if last in ('@', ''):
return origin
return f"{name}.{origin}"
def __str__(self):
records = "\n".join(
f" {rr.name} IN {rr.ttl} {rr.rdata}"
for rr in self.packet.answers
)
return f"""SignedPacket ({self.public_key.key.hex()}):
last_seen: {self.elapsed()} seconds ago
timestamp: {self.timestamp},
signature: {self.signature.hex()}
records:
{records}
"""