From cb4ccb4a0bb251dd79e5d6b744c20fbbb78cf1ba Mon Sep 17 00:00:00 2001 From: Aljaz Ceru Date: Sun, 15 Sep 2024 13:50:21 +0200 Subject: [PATCH] trying to implement a working pkarr client in python --- lookup.py | 71 +++++++++++++ src/client.py | 221 +++++++++++++++++++++++++++++++++++++++++ src/crypto.py | 71 +++++++++++++ src/dns_utils.py | 91 +++++++++++++++++ src/errors.py | 52 ++++++++++ src/keypair.py | 44 ++++++++ src/packet.py | 95 ++++++++++++++++++ src/public_key.py | 39 ++++++++ src/resource_record.py | 39 ++++++++ src/signed_packet.py | 154 ++++++++++++++++++++++++++++ 10 files changed, 877 insertions(+) create mode 100644 lookup.py create mode 100644 src/client.py create mode 100644 src/crypto.py create mode 100644 src/dns_utils.py create mode 100644 src/errors.py create mode 100644 src/keypair.py create mode 100644 src/packet.py create mode 100644 src/public_key.py create mode 100644 src/resource_record.py create mode 100644 src/signed_packet.py diff --git a/lookup.py b/lookup.py new file mode 100644 index 0000000..67f15a5 --- /dev/null +++ b/lookup.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 + +import asyncio +import time +import logging +from argparse import ArgumentParser +from src.client import PkarrClient +from src.public_key import PublicKey +from src.keypair import Keypair +from src.errors import PkarrError +import bencodepy + + +DEFAULT_MINIMUM_TTL = 300 # 5 minutes +DEFAULT_MAXIMUM_TTL = 24 * 60 * 60 # 24 hours +DEFAULT_BOOTSTRAP_NODES = [ + "router.bittorrent.com:6881", + "router.utorrent.com:6881", + "dht.transmissionbt.com:6881", + "dht.libtorrent.org:25401" +] + +logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s') + +async def resolve(client: PkarrClient, public_key: PublicKey): + start_time = time.time() + try: + signed_packet = await client.lookup(public_key.to_z32(), max_attempts=200, timeout=60) + elapsed = time.time() - start_time + + if signed_packet: + print(f"\nResolved in {int(elapsed * 1000)} milliseconds SignedPacket ({public_key.to_z32()}):") + print(f" last_seen: {signed_packet.elapsed()} seconds ago") + print(f" timestamp: {signed_packet.timestamp},") + print(f" signature: {signed_packet.signature.hex().upper()}") + print(" records:") + for rr in signed_packet.packet.answers: + print(f" {rr}") + else: + print(f"\nFailed to resolve {public_key.to_z32()}") + except PkarrError as e: + print(f"Got error: {e}") + + +async def main(): + parser = ArgumentParser(description="Resolve Pkarr records") + parser.add_argument("public_key", help="z-base-32 encoded public key") + parser.add_argument("--bootstrap", nargs='+', default=DEFAULT_BOOTSTRAP_NODES, + help="Bootstrap nodes (default: %(default)s)") + args = parser.parse_args() + + try: + public_key = PublicKey(args.public_key) + except PkarrError as e: + logging.error(f"Invalid public key: {e}") + return + + keypair = Keypair.random() + client = PkarrClient(keypair, args.bootstrap) + + logging.info(f"Resolving Pkarr: {args.public_key}") + logging.info("\n=== COLD LOOKUP ===") + await resolve(client, public_key) + + await asyncio.sleep(1) + + logging.info("\n=== SUBSEQUENT LOOKUP ===") + await resolve(client, public_key) + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/src/client.py b/src/client.py new file mode 100644 index 0000000..aabb05d --- /dev/null +++ b/src/client.py @@ -0,0 +1,221 @@ +import asyncio +import random +from typing import List, Optional, Union +from .signed_packet import SignedPacket +from .keypair import Keypair +from .public_key import PublicKey +from .resource_record import ResourceRecord +from .packet import Packet +from .errors import PkarrError +import logging +import socket +import struct +import hashlib +import bencodepy +import json +import time + +logging.basicConfig(level=logging.DEBUG) + +class PkarrClient: + def __init__(self, keypair: Keypair, bootstrap_nodes: List[str]): + self.keypair = keypair + self.bootstrap_nodes = bootstrap_nodes + self.known_nodes = set(bootstrap_nodes) + + async def lookup(self, public_key: str, max_attempts: int = 100, timeout: int = 30) -> Optional[SignedPacket]: + """Look up records from the DHT.""" + target_key = PublicKey(public_key) + + # Check cache first + cached_packet, expiration_time = self.cache.get(public_key, (None, 0)) + if cached_packet and time.time() < expiration_time: + logging.debug(f"Have fresh signed_packet in cache. expires_in={int(expiration_time - time.time())}") + return cached_packet + + nodes_to_query = set(self.bootstrap_nodes) + queried_nodes = set() + + start_time = time.time() + attempts = 0 + + while nodes_to_query and attempts < max_attempts and (time.time() - start_time) < timeout: + node = nodes_to_query.pop() + queried_nodes.add(node) + attempts += 1 + + logging.info(f"Attempt {attempts}: Querying node {node}") + + try: + result = await self._request_packet(node, target_key) + if isinstance(result, SignedPacket): + logging.info(f"Found result after {attempts} attempts and {time.time() - start_time:.2f} seconds") + # Cache the result + self.cache[public_key] = (result, time.time() + result.ttl(DEFAULT_MINIMUM_TTL, DEFAULT_MAXIMUM_TTL)) + return result + elif result: + new_nodes = set(result) - queried_nodes + nodes_to_query.update(new_nodes) + logging.info(f"Added {len(new_nodes)} new nodes to query. Total known nodes: {len(self.known_nodes)}") + except PkarrError as e: + logging.error(f"Error with node {node}: {e}") + + logging.info(f"Lookup completed after {attempts} attempts and {time.time() - start_time:.2f} seconds") + logging.info(f"Queried {len(queried_nodes)} unique nodes") + + if attempts >= max_attempts: + logging.warning("Lookup terminated: maximum attempts reached") + elif (time.time() - start_time) >= timeout: + logging.warning("Lookup terminated: timeout reached") + else: + logging.warning("Lookup terminated: no more nodes to query") + + return None + + async def _request_packet(self, node: str, target_key: PublicKey, record_type: str) -> Optional[Union[SignedPacket, List[str]]]: + """Request a packet from a node.""" + logging.info(f"Requesting packet from node {node} for key {target_key.to_z32()} and record_type {record_type}") + + try: + # Extract IP and port from the node string + if '@' in node: + _, ip_port = node.split('@') + else: + ip_port = node + + host, port = ip_port.split(':') + port = int(port) + + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.settimeout(5) # Set a 5-second timeout + + # Ensure we have a 20-byte node ID + node_id = hashlib.sha1(self.keypair.public_key.to_bytes()).digest() + + # Ensure we have a 20-byte info_hash + info_hash = hashlib.sha1(target_key.to_bytes()).digest() + + # Prepare and send the DHT query + transaction_id = random.randint(0, 65535).to_bytes(2, 'big') + message = bencodepy.encode({ + 't': transaction_id, + 'y': 'q', + 'q': 'get_peers', + 'a': { + 'id': node_id, + 'info_hash': info_hash + } + }) + + logging.debug(f"Sending message to {host}:{port}: {message}") + sock.sendto(message, (host, port)) + + # Wait for the response + data, addr = sock.recvfrom(1024) + logging.debug(f"Received raw response from {addr}: {data}") + + # Parse the response + response = bencodepy.decode(data) + human_readable = self._decode_response(response) + logging.info(f"Decoded response from {addr}:\n{json.dumps(human_readable, indent=2)}") + + if response.get(b'y') == b'e': + error_code, error_message = response.get(b'e', [None, b''])[0], response.get(b'e', [None, b''])[1].decode('utf-8', errors='ignore') + logging.error(f"Received error response: Code {error_code}, Message: {error_message}") + return None + + # Check if the response contains the data we need + if b'r' in response: + r = response[b'r'] + if b'values' in r: + # Process peer values + peer_values = r[b'values'] + logging.info(f"Found {len(peer_values)} peer values") + return await self._connect_to_peers(peer_values, target_key, record_type) + elif b'nodes' in r: + # Process nodes for further querying + nodes = r[b'nodes'] + decoded_nodes = self._decode_nodes(nodes) + logging.info(f"Found {len(decoded_nodes)} nodes") + self._update_known_nodes(decoded_nodes) + return decoded_nodes + + return None + + except socket.timeout: + logging.error(f"Timeout while connecting to {host}:{port}") + except Exception as e: + logging.error(f"Error requesting packet from {host}:{port}: {e}") + logging.exception("Exception details:") + finally: + sock.close() + + return None + + async def _connect_to_peers(self, peer_values: List[bytes], target_key: PublicKey, record_type: str) -> Optional[SignedPacket]: + """Connect to peers and try to retrieve the SignedPacket.""" + for peer_value in peer_values: + try: + ip = socket.inet_ntoa(peer_value[:4]) + port = struct.unpack("!H", peer_value[4:])[0] + peer = f"{ip}:{port}" + + logging.info(f"Connecting to peer {peer}") + + # Here you would implement the logic to connect to the peer and retrieve the SignedPacket + # For now, we'll just return a dummy SignedPacket + return SignedPacket(target_key, b"dummy_signature", Packet()) + + except Exception as e: + logging.error(f"Error connecting to peer {peer}: {e}") + + return None + + + def _decode_response(self, response: dict[bytes, any]) -> dict[str, any]: + """Decode the bencoded response into a human-readable format.""" + decoded = {} + for key, value in response.items(): + str_key = key.decode('utf-8') + if isinstance(value, bytes): + try: + decoded[str_key] = value.decode('utf-8') + except UnicodeDecodeError: + decoded[str_key] = value.hex() + elif isinstance(value, dict): + decoded[str_key] = self._decode_response(value) + elif isinstance(value, list): + decoded[str_key] = [self._decode_response(item) if isinstance(item, dict) else item.hex() if isinstance(item, bytes) else item for item in value] + else: + decoded[str_key] = value + + if 'r' in decoded and 'nodes' in decoded['r']: + decoded['r']['decoded_nodes'] = self._decode_nodes(response[b'r'][b'nodes']) + + return decoded + + def _decode_nodes(self, nodes_data: bytes) -> List[str]: + """Decode the compact node info.""" + nodes = [] + for i in range(0, len(nodes_data), 26): + node_id = nodes_data[i:i+20].hex() + ip = socket.inet_ntoa(nodes_data[i+20:i+24]) + port = struct.unpack("!H", nodes_data[i+24:i+26])[0] + nodes.append(f"{ip}:{port}") + return nodes + + def _update_known_nodes(self, new_nodes: List[str]) -> None: + """Update the list of known nodes.""" + self.known_nodes.update(new_nodes) + logging.info(f"Updated known nodes. Total known nodes: {len(self.known_nodes)}") + + async def _send_packet(self, node: str, signed_packet: SignedPacket) -> None: + """Send a signed packet to a node.""" + # Implement UDP packet sending logic here + pass + + async def maintain_network(self) -> None: + """Periodically maintain the network by pinging known nodes and discovering new ones.""" + while True: + # Implement node discovery and maintenance logic here + await asyncio.sleep(60) # Run maintenance every 60 seconds \ No newline at end of file diff --git a/src/crypto.py b/src/crypto.py new file mode 100644 index 0000000..d997b69 --- /dev/null +++ b/src/crypto.py @@ -0,0 +1,71 @@ +import os +import hashlib +import ed25519 + +class Crypto: + @staticmethod + def generate_keypair(): + private_key, public_key = ed25519.create_keypair() + return private_key.to_bytes()[:32], public_key.to_bytes() + + @staticmethod + def derive_public_key(secret_key): + if len(secret_key) != 32: + raise ValueError("Secret key must be 32 bytes long") + signing_key = ed25519.SigningKey(secret_key) + return signing_key.get_verifying_key().to_bytes() + + @staticmethod + def sign(secret_key, message): + if len(secret_key) != 32: + raise ValueError("Secret key must be 32 bytes long") + signing_key = ed25519.SigningKey(secret_key) + return signing_key.sign(message) + + @staticmethod + def verify(public_key, message, signature): + verifying_key = ed25519.VerifyingKey(public_key) + try: + verifying_key.verify(signature, message) + return True + except ed25519.BadSignatureError: + return False + + @staticmethod + def hash(data): + return hashlib.sha256(data).digest() + + @staticmethod + def random_bytes(length): + return os.urandom(length) + + @staticmethod + def z_base_32_encode(data): + alphabet = "ybndrfg8ejkmcpqxot1uwisza345h769" + result = "" + bits = 0 + value = 0 + for byte in data: + value = (value << 8) | byte + bits += 8 + while bits >= 5: + bits -= 5 + result += alphabet[(value >> bits) & 31] + if bits > 0: + result += alphabet[(value << (5 - bits)) & 31] + return result + + @staticmethod + def z_base_32_decode(encoded): + alphabet = "ybndrfg8ejkmcpqxot1uwisza345h769" + alphabet_map = {char: index for index, char in enumerate(alphabet)} + result = bytearray() + bits = 0 + value = 0 + for char in encoded: + value = (value << 5) | alphabet_map[char] + bits += 5 + if bits >= 8: + bits -= 8 + result.append((value >> bits) & 255) + return bytes(result) \ No newline at end of file diff --git a/src/dns_utils.py b/src/dns_utils.py new file mode 100644 index 0000000..dbf1733 --- /dev/null +++ b/src/dns_utils.py @@ -0,0 +1,91 @@ +import struct +from typing import List, Tuple +from dns import message, name, rdata, rdatatype, rdataclass +from .errors import DNSError + +def create_dns_query(domain: str, record_type: str) -> bytes: + """Create a DNS query packet.""" + try: + qname = name.from_text(domain) + q = message.make_query(qname, rdatatype.from_text(record_type)) + return q.to_wire() + except Exception as e: + raise DNSError(f"Failed to create DNS query: {str(e)}") + +def parse_dns_response(response: bytes) -> List[Tuple[str, str, int, str]]: + """Parse a DNS response and return a list of (name, type, ttl, data) tuples.""" + try: + msg = message.from_wire(response) + results = [] + for rrset in msg.answer: + name = rrset.name.to_text() + ttl = rrset.ttl + for rr in rrset: + rr_type = rdatatype.to_text(rr.rdtype) + rr_data = rr.to_text() + results.append((name, rr_type, ttl, rr_data)) + return results + except Exception as e: + raise DNSError(f"Failed to parse DNS response: {str(e)}") + +def compress_domain_name(domain: str) -> bytes: + """Compress a domain name according to DNS name compression rules.""" + try: + n = name.from_text(domain) + return n.to_wire() + except Exception as e: + raise DNSError(f"Failed to compress domain name: {str(e)}") + +def decompress_domain_name(compressed: bytes, offset: int = 0) -> Tuple[str, int]: + """Decompress a domain name from DNS wire format.""" + try: + n, offset = name.from_wire(compressed, offset) + return n.to_text(), offset + except Exception as e: + raise DNSError(f"Failed to decompress domain name: {str(e)}") + +def encode_resource_record(name: str, rr_type: str, rr_class: str, ttl: int, rdata: str) -> bytes: + """Encode a resource record into DNS wire format.""" + try: + n = name.from_text(name) + rr_type = rdatatype.from_text(rr_type) + rr_class = rdataclass.from_text(rr_class) + rd = rdata.from_text(rr_type, rr_class, rdata) + return (n.to_wire() + + struct.pack("!HHIH", rr_type, rr_class, ttl, len(rd.to_wire())) + + rd.to_wire()) + except Exception as e: + raise DNSError(f"Failed to encode resource record: {str(e)}") + +def decode_resource_record(wire: bytes, offset: int = 0) -> Tuple[str, str, str, int, str, int]: + """Decode a resource record from DNS wire format.""" + try: + n, offset = name.from_wire(wire, offset) + rr_type, rr_class, ttl, rdlen = struct.unpack_from("!HHIH", wire, offset) + offset += 10 + rd = rdata.from_wire(rdatatype.to_text(rr_type), wire, offset, rdlen) + offset += rdlen + return (n.to_text(), + rdatatype.to_text(rr_type), + rdataclass.to_text(rr_class), + ttl, + rd.to_text(), + offset) + except Exception as e: + raise DNSError(f"Failed to decode resource record: {str(e)}") + +def is_valid_domain_name(domain: str) -> bool: + """Check if a given string is a valid domain name.""" + try: + name.from_text(domain) + return True + except Exception: + return False + +def normalize_domain_name(domain: str) -> str: + """Normalize a domain name (convert to lowercase and ensure it ends with a dot).""" + try: + n = name.from_text(domain) + return n.to_text().lower() + except Exception as e: + raise DNSError(f"Failed to normalize domain name: {str(e)}") \ No newline at end of file diff --git a/src/errors.py b/src/errors.py new file mode 100644 index 0000000..2d2efc5 --- /dev/null +++ b/src/errors.py @@ -0,0 +1,52 @@ +class PkarrError(Exception): + """Base class for all pkarr-related errors.""" + pass + +class KeypairError(PkarrError): + """Raised when there's an issue with keypair operations.""" + pass + +class PublicKeyError(PkarrError): + """Raised when there's an issue with public key operations.""" + pass + +class SignatureError(PkarrError): + """Raised when there's an issue with signature operations.""" + pass + +class PacketError(PkarrError): + """Raised when there's an issue with packet operations.""" + pass + +class DNSError(PkarrError): + """Raised when there's an issue with DNS operations.""" + pass + +class DHTError(PkarrError): + """Raised when there's an issue with DHT operations.""" + pass + +class InvalidSignedPacketBytesLength(PacketError): + """Raised when the SignedPacket bytes length is invalid.""" + def __init__(self, length: int): + super().__init__(f"Invalid SignedPacket bytes length, expected at least 104 bytes but got: {length}") + +class InvalidRelayPayloadSize(PacketError): + """Raised when the relay payload size is invalid.""" + def __init__(self, size: int): + super().__init__(f"Invalid relay payload size, expected at least 72 bytes but got: {size}") + +class PacketTooLarge(PacketError): + """Raised when the DNS packet is too large.""" + def __init__(self, size: int): + super().__init__(f"DNS Packet is too large, expected max 1000 bytes but got: {size}") + +class DHTIsShutdown(DHTError): + """Raised when the DHT is shutdown.""" + def __init__(self): + super().__init__("DHT is shutdown") + +class PublishInflight(DHTError): + """Raised when a publish query is already in flight for the same public key.""" + def __init__(self): + super().__init__("Publish query is already in flight for the same public_key") \ No newline at end of file diff --git a/src/keypair.py b/src/keypair.py new file mode 100644 index 0000000..7449496 --- /dev/null +++ b/src/keypair.py @@ -0,0 +1,44 @@ +from .public_key import PublicKey +from .crypto import Crypto +from .errors import KeypairError + +class Keypair: + def __init__(self, secret_key: bytes): + if len(secret_key) != 32: + raise KeypairError(f"Secret key must be 32 bytes long, got {len(secret_key)} bytes") + self.secret_key = secret_key + self.public_key = PublicKey(Crypto.derive_public_key(secret_key)) + + @classmethod + def random(cls) -> 'Keypair': + """Generate a new random keypair.""" + secret_key, _ = Crypto.generate_keypair() + return cls(secret_key) + + @classmethod + def from_secret_key(cls, secret_key: bytes) -> 'Keypair': + """Create a Keypair from a secret key.""" + return cls(secret_key) + + def sign(self, message: bytes) -> bytes: + """Sign a message using this keypair.""" + return Crypto.sign(self.secret_key, message) + + def verify(self, message: bytes, signature: bytes) -> bool: + """Verify a signature using this keypair's public key.""" + return self.public_key.verify(message, signature) + + def to_bytes(self) -> bytes: + """Return the secret key bytes.""" + return self.secret_key + + @classmethod + def from_bytes(cls, secret_key: bytes) -> 'Keypair': + """Create a Keypair from bytes.""" + return cls(secret_key) + + def __str__(self): + return f"Keypair(public_key={self.public_key})" + + def __repr__(self): + return self.__str__() \ No newline at end of file diff --git a/src/packet.py b/src/packet.py new file mode 100644 index 0000000..3fba692 --- /dev/null +++ b/src/packet.py @@ -0,0 +1,95 @@ +from dataclasses import dataclass, field +from typing import List, Optional +from dns import message, name, rdata, rdatatype, rdataclass +from .resource_record import ResourceRecord +from .errors import PacketError + +@dataclass +class Packet: + answers: List[ResourceRecord] = field(default_factory=list) + id: int = 0 + qr: bool = True # True for response, False for query + opcode: int = 0 # 0 for standard query + aa: bool = True # Authoritative Answer + tc: bool = False # TrunCation + rd: bool = False # Recursion Desired + ra: bool = False # Recursion Available + z: int = 0 # Reserved for future use + rcode: int = 0 # Response code + + @classmethod + def new_reply(cls, id: int): + return cls(answers=[], id=id) + + def add_answer(self, answer: ResourceRecord): + self.answers.append(answer) + + def build_bytes_vec_compressed(self) -> bytes: + """Build a compressed DNS wire format representation of the packet.""" + try: + msg = message.Message(id=self.id) + msg.flags = 0 + if self.qr: + msg.flags |= 1 << 15 + msg.flags |= (self.opcode & 0xF) << 11 + if self.aa: + msg.flags |= 1 << 10 + if self.tc: + msg.flags |= 1 << 9 + if self.rd: + msg.flags |= 1 << 8 + if self.ra: + msg.flags |= 1 << 7 + msg.flags |= (self.z & 0x7) << 4 + msg.flags |= self.rcode & 0xF + + for rr in self.answers: + rr_name = name.from_text(rr.name) + rr_ttl = rr.ttl + rr_rdataclass = rdataclass.from_text(rr.rclass) + rr_rdatatype = rdatatype.from_text(rr.rtype) + rr_rdata = rdata.from_text(rr_rdataclass, rr_rdatatype, rr.rdata) + msg.answer.append((rr_name, rr_ttl, rr_rdata)) + + return msg.to_wire() + except Exception as e: + raise PacketError(f"Failed to build packet: {str(e)}") + + @classmethod + def from_bytes(cls, data: bytes) -> 'Packet': + """Create a Packet object from DNS wire format bytes.""" + try: + msg = message.from_wire(data) + packet = cls( + id=msg.id, + qr=bool(msg.flags & (1 << 15)), + opcode=(msg.flags >> 11) & 0xF, + aa=bool(msg.flags & (1 << 10)), + tc=bool(msg.flags & (1 << 9)), + rd=bool(msg.flags & (1 << 8)), + ra=bool(msg.flags & (1 << 7)), + z=(msg.flags >> 4) & 0x7, + rcode=msg.flags & 0xF + ) + + for rrset in msg.answer: + for rr in rrset: + resource_record = ResourceRecord( + name=rrset.name.to_text(), + rclass=rdataclass.to_text(rr.rdclass), + ttl=rrset.ttl, + rtype=rdatatype.to_text(rr.rdtype), + rdata=rr.to_text() + ) + packet.add_answer(resource_record) + + return packet + except Exception as e: + raise PacketError(f"Failed to parse packet: {str(e)}") + + def __str__(self): + header = f"Packet ID: {self.id}, QR: {'Response' if self.qr else 'Query'}, " \ + f"Opcode: {self.opcode}, AA: {self.aa}, TC: {self.tc}, RD: {self.rd}, " \ + f"RA: {self.ra}, Z: {self.z}, RCODE: {self.rcode}" + answers = "\n".join(f" {rr}" for rr in self.answers) + return f"{header}\nAnswers:\n{answers}" \ No newline at end of file diff --git a/src/public_key.py b/src/public_key.py new file mode 100644 index 0000000..2354e11 --- /dev/null +++ b/src/public_key.py @@ -0,0 +1,39 @@ +from .crypto import Crypto +from .errors import PublicKeyError + +class PublicKey: + def __init__(self, key): + if isinstance(key, str): + try: + self.key = Crypto.z_base_32_decode(key) + except ValueError: + raise PublicKeyError("Invalid z-base-32 encoded public key") + elif isinstance(key, bytes): + if len(key) != 32: + raise PublicKeyError("Public key must be 32 bytes long") + self.key = key + else: + raise PublicKeyError("Public key must be bytes or z-base-32 encoded string") + + def __str__(self): + return self.to_z32() + + def __repr__(self): + return f"PublicKey({self.to_z32()})" + + def __eq__(self, other): + if isinstance(other, PublicKey): + return self.key == other.key + return False + + def __hash__(self): + return hash(self.key) + + def to_z32(self): + return Crypto.z_base_32_encode(self.key) + + def to_bytes(self): + return self.key + + def verify(self, message: bytes, signature: bytes) -> bool: + return Crypto.verify(self.key, message, signature) \ No newline at end of file diff --git a/src/resource_record.py b/src/resource_record.py new file mode 100644 index 0000000..9184b6a --- /dev/null +++ b/src/resource_record.py @@ -0,0 +1,39 @@ +from dataclasses import dataclass +from typing import Union +import ipaddress + +@dataclass +class ResourceRecord: + name: str + rclass: str + ttl: int + rtype: str + rdata: Union[str, ipaddress.IPv4Address, ipaddress.IPv6Address] + + def __post_init__(self): + self.name = self.name.lower() + self.rclass = self.rclass.upper() + self.rtype = self.rtype.upper() + + if self.rtype == 'A': + self.rdata = ipaddress.IPv4Address(self.rdata) + elif self.rtype == 'AAAA': + self.rdata = ipaddress.IPv6Address(self.rdata) + + def to_wire_format(self) -> bytes: + # This is a placeholder. You'll need to implement the actual DNS wire format encoding. + pass + + @classmethod + def from_wire_format(cls, wire_data: bytes) -> 'ResourceRecord': + # This is a placeholder. You'll need to implement the actual DNS wire format decoding. + pass + + def __str__(self): + return f"{self.name} {self.ttl} {self.rclass} {self.rtype} {self.rdata}" + + def is_expired(self, current_time: int) -> bool: + return current_time > self.ttl + + def remaining_ttl(self, current_time: int) -> int: + return max(0, self.ttl - current_time) \ No newline at end of file diff --git a/src/signed_packet.py b/src/signed_packet.py new file mode 100644 index 0000000..026de18 --- /dev/null +++ b/src/signed_packet.py @@ -0,0 +1,154 @@ +import time +from dataclasses import dataclass +from typing import List, Optional +import ed25519 +from dns import message, name, rdata, rdatatype, rdataclass + +@dataclass +class PublicKey: + key: bytes + + def to_z32(self) -> str: + # Implement z-base-32 encoding here + pass + +@dataclass +class ResourceRecord: + name: str + rclass: int + ttl: int + rdata: bytes + +@dataclass +class Packet: + answers: List[ResourceRecord] + + @classmethod + def new_reply(cls, id: int): + return cls(answers=[]) + + def build_bytes_vec_compressed(self) -> bytes: + # Implement DNS packet compression here + pass + +@dataclass +class SignedPacket: + public_key: PublicKey + signature: bytes + timestamp: int + packet: Packet + last_seen: int + + @classmethod + def from_packet(cls, keypair, packet: Packet): + timestamp = int(time.time() * 1_000_000) + encoded_packet = packet.build_bytes_vec_compressed() + + if len(encoded_packet) > 1000: + raise ValueError("Packet too large") + + signature = keypair.sign(cls.signable(timestamp, encoded_packet)) + + return cls( + public_key=keypair.public_key(), + signature=signature, + timestamp=timestamp, + packet=packet, + last_seen=int(time.time() * 1_000_000) + ) + + @classmethod + def from_bytes(cls, data: bytes): + if len(data) < 104: + raise ValueError("Invalid SignedPacket bytes length") + if len(data) > 1104: + raise ValueError("Packet too large") + + public_key = PublicKey(data[:32]) + signature = data[32:96] + timestamp = int.from_bytes(data[96:104], 'big') + encoded_packet = data[104:] + + # Verify signature + if not public_key.verify(cls.signable(timestamp, encoded_packet), signature): + raise ValueError("Invalid signature") + + packet = Packet([]) # Parse encoded_packet into a Packet object here + + return cls( + public_key=public_key, + signature=signature, + timestamp=timestamp, + packet=packet, + last_seen=int(time.time() * 1_000_000) + ) + + def as_bytes(self) -> bytes: + return ( + self.public_key.key + + self.signature + + self.timestamp.to_bytes(8, 'big') + + self.packet.build_bytes_vec_compressed() + ) + + def to_relay_payload(self) -> bytes: + return self.as_bytes()[32:] + + def resource_records(self, name: str): + origin = self.public_key.to_z32() + normalized_name = self.normalize_name(origin, name) + return [rr for rr in self.packet.answers if rr.name == normalized_name] + + def fresh_resource_records(self, name: str): + origin = self.public_key.to_z32() + normalized_name = self.normalize_name(origin, name) + current_time = int(time.time()) + return [ + rr for rr in self.packet.answers + if rr.name == normalized_name and rr.ttl > (current_time - self.last_seen // 1_000_000) + ] + + def expires_in(self, min_ttl: int, max_ttl: int) -> int: + ttl = self.ttl(min_ttl, max_ttl) + elapsed = self.elapsed() + return max(0, ttl - elapsed) + + def ttl(self, min_ttl: int, max_ttl: int) -> int: + if not self.packet.answers: + return min_ttl + min_record_ttl = min(rr.ttl for rr in self.packet.answers) + return max(min_ttl, min(max_ttl, min_record_ttl)) + + def elapsed(self) -> int: + return (int(time.time() * 1_000_000) - self.last_seen) // 1_000_000 + + @staticmethod + def signable(timestamp: int, v: bytes) -> bytes: + return f"3:seqi{timestamp}e1:v{len(v)}:".encode() + v + + @staticmethod + def normalize_name(origin: str, name: str) -> str: + if name.endswith('.'): + name = name[:-1] + + parts = name.split('.') + last = parts[-1] + + if last == origin: + return name + if last in ('@', ''): + return origin + return f"{name}.{origin}" + + def __str__(self): + records = "\n".join( + f" {rr.name} IN {rr.ttl} {rr.rdata}" + for rr in self.packet.answers + ) + return f"""SignedPacket ({self.public_key.key.hex()}): + last_seen: {self.elapsed()} seconds ago + timestamp: {self.timestamp}, + signature: {self.signature.hex()} + records: +{records} +""" \ No newline at end of file