mirror of
https://github.com/aljazceru/pypkarr.git
synced 2025-12-18 14:44:21 +01:00
trying to implement a working pkarr client in python
This commit is contained in:
71
lookup.py
Normal file
71
lookup.py
Normal file
@@ -0,0 +1,71 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import logging
|
||||
from argparse import ArgumentParser
|
||||
from src.client import PkarrClient
|
||||
from src.public_key import PublicKey
|
||||
from src.keypair import Keypair
|
||||
from src.errors import PkarrError
|
||||
import bencodepy
|
||||
|
||||
|
||||
DEFAULT_MINIMUM_TTL = 300 # 5 minutes
|
||||
DEFAULT_MAXIMUM_TTL = 24 * 60 * 60 # 24 hours
|
||||
DEFAULT_BOOTSTRAP_NODES = [
|
||||
"router.bittorrent.com:6881",
|
||||
"router.utorrent.com:6881",
|
||||
"dht.transmissionbt.com:6881",
|
||||
"dht.libtorrent.org:25401"
|
||||
]
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
async def resolve(client: PkarrClient, public_key: PublicKey):
|
||||
start_time = time.time()
|
||||
try:
|
||||
signed_packet = await client.lookup(public_key.to_z32(), max_attempts=200, timeout=60)
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
if signed_packet:
|
||||
print(f"\nResolved in {int(elapsed * 1000)} milliseconds SignedPacket ({public_key.to_z32()}):")
|
||||
print(f" last_seen: {signed_packet.elapsed()} seconds ago")
|
||||
print(f" timestamp: {signed_packet.timestamp},")
|
||||
print(f" signature: {signed_packet.signature.hex().upper()}")
|
||||
print(" records:")
|
||||
for rr in signed_packet.packet.answers:
|
||||
print(f" {rr}")
|
||||
else:
|
||||
print(f"\nFailed to resolve {public_key.to_z32()}")
|
||||
except PkarrError as e:
|
||||
print(f"Got error: {e}")
|
||||
|
||||
|
||||
async def main():
|
||||
parser = ArgumentParser(description="Resolve Pkarr records")
|
||||
parser.add_argument("public_key", help="z-base-32 encoded public key")
|
||||
parser.add_argument("--bootstrap", nargs='+', default=DEFAULT_BOOTSTRAP_NODES,
|
||||
help="Bootstrap nodes (default: %(default)s)")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
public_key = PublicKey(args.public_key)
|
||||
except PkarrError as e:
|
||||
logging.error(f"Invalid public key: {e}")
|
||||
return
|
||||
|
||||
keypair = Keypair.random()
|
||||
client = PkarrClient(keypair, args.bootstrap)
|
||||
|
||||
logging.info(f"Resolving Pkarr: {args.public_key}")
|
||||
logging.info("\n=== COLD LOOKUP ===")
|
||||
await resolve(client, public_key)
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
logging.info("\n=== SUBSEQUENT LOOKUP ===")
|
||||
await resolve(client, public_key)
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
221
src/client.py
Normal file
221
src/client.py
Normal file
@@ -0,0 +1,221 @@
|
||||
import asyncio
|
||||
import random
|
||||
from typing import List, Optional, Union
|
||||
from .signed_packet import SignedPacket
|
||||
from .keypair import Keypair
|
||||
from .public_key import PublicKey
|
||||
from .resource_record import ResourceRecord
|
||||
from .packet import Packet
|
||||
from .errors import PkarrError
|
||||
import logging
|
||||
import socket
|
||||
import struct
|
||||
import hashlib
|
||||
import bencodepy
|
||||
import json
|
||||
import time
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
class PkarrClient:
|
||||
def __init__(self, keypair: Keypair, bootstrap_nodes: List[str]):
|
||||
self.keypair = keypair
|
||||
self.bootstrap_nodes = bootstrap_nodes
|
||||
self.known_nodes = set(bootstrap_nodes)
|
||||
|
||||
async def lookup(self, public_key: str, max_attempts: int = 100, timeout: int = 30) -> Optional[SignedPacket]:
|
||||
"""Look up records from the DHT."""
|
||||
target_key = PublicKey(public_key)
|
||||
|
||||
# Check cache first
|
||||
cached_packet, expiration_time = self.cache.get(public_key, (None, 0))
|
||||
if cached_packet and time.time() < expiration_time:
|
||||
logging.debug(f"Have fresh signed_packet in cache. expires_in={int(expiration_time - time.time())}")
|
||||
return cached_packet
|
||||
|
||||
nodes_to_query = set(self.bootstrap_nodes)
|
||||
queried_nodes = set()
|
||||
|
||||
start_time = time.time()
|
||||
attempts = 0
|
||||
|
||||
while nodes_to_query and attempts < max_attempts and (time.time() - start_time) < timeout:
|
||||
node = nodes_to_query.pop()
|
||||
queried_nodes.add(node)
|
||||
attempts += 1
|
||||
|
||||
logging.info(f"Attempt {attempts}: Querying node {node}")
|
||||
|
||||
try:
|
||||
result = await self._request_packet(node, target_key)
|
||||
if isinstance(result, SignedPacket):
|
||||
logging.info(f"Found result after {attempts} attempts and {time.time() - start_time:.2f} seconds")
|
||||
# Cache the result
|
||||
self.cache[public_key] = (result, time.time() + result.ttl(DEFAULT_MINIMUM_TTL, DEFAULT_MAXIMUM_TTL))
|
||||
return result
|
||||
elif result:
|
||||
new_nodes = set(result) - queried_nodes
|
||||
nodes_to_query.update(new_nodes)
|
||||
logging.info(f"Added {len(new_nodes)} new nodes to query. Total known nodes: {len(self.known_nodes)}")
|
||||
except PkarrError as e:
|
||||
logging.error(f"Error with node {node}: {e}")
|
||||
|
||||
logging.info(f"Lookup completed after {attempts} attempts and {time.time() - start_time:.2f} seconds")
|
||||
logging.info(f"Queried {len(queried_nodes)} unique nodes")
|
||||
|
||||
if attempts >= max_attempts:
|
||||
logging.warning("Lookup terminated: maximum attempts reached")
|
||||
elif (time.time() - start_time) >= timeout:
|
||||
logging.warning("Lookup terminated: timeout reached")
|
||||
else:
|
||||
logging.warning("Lookup terminated: no more nodes to query")
|
||||
|
||||
return None
|
||||
|
||||
async def _request_packet(self, node: str, target_key: PublicKey, record_type: str) -> Optional[Union[SignedPacket, List[str]]]:
|
||||
"""Request a packet from a node."""
|
||||
logging.info(f"Requesting packet from node {node} for key {target_key.to_z32()} and record_type {record_type}")
|
||||
|
||||
try:
|
||||
# Extract IP and port from the node string
|
||||
if '@' in node:
|
||||
_, ip_port = node.split('@')
|
||||
else:
|
||||
ip_port = node
|
||||
|
||||
host, port = ip_port.split(':')
|
||||
port = int(port)
|
||||
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
sock.settimeout(5) # Set a 5-second timeout
|
||||
|
||||
# Ensure we have a 20-byte node ID
|
||||
node_id = hashlib.sha1(self.keypair.public_key.to_bytes()).digest()
|
||||
|
||||
# Ensure we have a 20-byte info_hash
|
||||
info_hash = hashlib.sha1(target_key.to_bytes()).digest()
|
||||
|
||||
# Prepare and send the DHT query
|
||||
transaction_id = random.randint(0, 65535).to_bytes(2, 'big')
|
||||
message = bencodepy.encode({
|
||||
't': transaction_id,
|
||||
'y': 'q',
|
||||
'q': 'get_peers',
|
||||
'a': {
|
||||
'id': node_id,
|
||||
'info_hash': info_hash
|
||||
}
|
||||
})
|
||||
|
||||
logging.debug(f"Sending message to {host}:{port}: {message}")
|
||||
sock.sendto(message, (host, port))
|
||||
|
||||
# Wait for the response
|
||||
data, addr = sock.recvfrom(1024)
|
||||
logging.debug(f"Received raw response from {addr}: {data}")
|
||||
|
||||
# Parse the response
|
||||
response = bencodepy.decode(data)
|
||||
human_readable = self._decode_response(response)
|
||||
logging.info(f"Decoded response from {addr}:\n{json.dumps(human_readable, indent=2)}")
|
||||
|
||||
if response.get(b'y') == b'e':
|
||||
error_code, error_message = response.get(b'e', [None, b''])[0], response.get(b'e', [None, b''])[1].decode('utf-8', errors='ignore')
|
||||
logging.error(f"Received error response: Code {error_code}, Message: {error_message}")
|
||||
return None
|
||||
|
||||
# Check if the response contains the data we need
|
||||
if b'r' in response:
|
||||
r = response[b'r']
|
||||
if b'values' in r:
|
||||
# Process peer values
|
||||
peer_values = r[b'values']
|
||||
logging.info(f"Found {len(peer_values)} peer values")
|
||||
return await self._connect_to_peers(peer_values, target_key, record_type)
|
||||
elif b'nodes' in r:
|
||||
# Process nodes for further querying
|
||||
nodes = r[b'nodes']
|
||||
decoded_nodes = self._decode_nodes(nodes)
|
||||
logging.info(f"Found {len(decoded_nodes)} nodes")
|
||||
self._update_known_nodes(decoded_nodes)
|
||||
return decoded_nodes
|
||||
|
||||
return None
|
||||
|
||||
except socket.timeout:
|
||||
logging.error(f"Timeout while connecting to {host}:{port}")
|
||||
except Exception as e:
|
||||
logging.error(f"Error requesting packet from {host}:{port}: {e}")
|
||||
logging.exception("Exception details:")
|
||||
finally:
|
||||
sock.close()
|
||||
|
||||
return None
|
||||
|
||||
async def _connect_to_peers(self, peer_values: List[bytes], target_key: PublicKey, record_type: str) -> Optional[SignedPacket]:
|
||||
"""Connect to peers and try to retrieve the SignedPacket."""
|
||||
for peer_value in peer_values:
|
||||
try:
|
||||
ip = socket.inet_ntoa(peer_value[:4])
|
||||
port = struct.unpack("!H", peer_value[4:])[0]
|
||||
peer = f"{ip}:{port}"
|
||||
|
||||
logging.info(f"Connecting to peer {peer}")
|
||||
|
||||
# Here you would implement the logic to connect to the peer and retrieve the SignedPacket
|
||||
# For now, we'll just return a dummy SignedPacket
|
||||
return SignedPacket(target_key, b"dummy_signature", Packet())
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error connecting to peer {peer}: {e}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _decode_response(self, response: dict[bytes, any]) -> dict[str, any]:
|
||||
"""Decode the bencoded response into a human-readable format."""
|
||||
decoded = {}
|
||||
for key, value in response.items():
|
||||
str_key = key.decode('utf-8')
|
||||
if isinstance(value, bytes):
|
||||
try:
|
||||
decoded[str_key] = value.decode('utf-8')
|
||||
except UnicodeDecodeError:
|
||||
decoded[str_key] = value.hex()
|
||||
elif isinstance(value, dict):
|
||||
decoded[str_key] = self._decode_response(value)
|
||||
elif isinstance(value, list):
|
||||
decoded[str_key] = [self._decode_response(item) if isinstance(item, dict) else item.hex() if isinstance(item, bytes) else item for item in value]
|
||||
else:
|
||||
decoded[str_key] = value
|
||||
|
||||
if 'r' in decoded and 'nodes' in decoded['r']:
|
||||
decoded['r']['decoded_nodes'] = self._decode_nodes(response[b'r'][b'nodes'])
|
||||
|
||||
return decoded
|
||||
|
||||
def _decode_nodes(self, nodes_data: bytes) -> List[str]:
|
||||
"""Decode the compact node info."""
|
||||
nodes = []
|
||||
for i in range(0, len(nodes_data), 26):
|
||||
node_id = nodes_data[i:i+20].hex()
|
||||
ip = socket.inet_ntoa(nodes_data[i+20:i+24])
|
||||
port = struct.unpack("!H", nodes_data[i+24:i+26])[0]
|
||||
nodes.append(f"{ip}:{port}")
|
||||
return nodes
|
||||
|
||||
def _update_known_nodes(self, new_nodes: List[str]) -> None:
|
||||
"""Update the list of known nodes."""
|
||||
self.known_nodes.update(new_nodes)
|
||||
logging.info(f"Updated known nodes. Total known nodes: {len(self.known_nodes)}")
|
||||
|
||||
async def _send_packet(self, node: str, signed_packet: SignedPacket) -> None:
|
||||
"""Send a signed packet to a node."""
|
||||
# Implement UDP packet sending logic here
|
||||
pass
|
||||
|
||||
async def maintain_network(self) -> None:
|
||||
"""Periodically maintain the network by pinging known nodes and discovering new ones."""
|
||||
while True:
|
||||
# Implement node discovery and maintenance logic here
|
||||
await asyncio.sleep(60) # Run maintenance every 60 seconds
|
||||
71
src/crypto.py
Normal file
71
src/crypto.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import os
|
||||
import hashlib
|
||||
import ed25519
|
||||
|
||||
class Crypto:
|
||||
@staticmethod
|
||||
def generate_keypair():
|
||||
private_key, public_key = ed25519.create_keypair()
|
||||
return private_key.to_bytes()[:32], public_key.to_bytes()
|
||||
|
||||
@staticmethod
|
||||
def derive_public_key(secret_key):
|
||||
if len(secret_key) != 32:
|
||||
raise ValueError("Secret key must be 32 bytes long")
|
||||
signing_key = ed25519.SigningKey(secret_key)
|
||||
return signing_key.get_verifying_key().to_bytes()
|
||||
|
||||
@staticmethod
|
||||
def sign(secret_key, message):
|
||||
if len(secret_key) != 32:
|
||||
raise ValueError("Secret key must be 32 bytes long")
|
||||
signing_key = ed25519.SigningKey(secret_key)
|
||||
return signing_key.sign(message)
|
||||
|
||||
@staticmethod
|
||||
def verify(public_key, message, signature):
|
||||
verifying_key = ed25519.VerifyingKey(public_key)
|
||||
try:
|
||||
verifying_key.verify(signature, message)
|
||||
return True
|
||||
except ed25519.BadSignatureError:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def hash(data):
|
||||
return hashlib.sha256(data).digest()
|
||||
|
||||
@staticmethod
|
||||
def random_bytes(length):
|
||||
return os.urandom(length)
|
||||
|
||||
@staticmethod
|
||||
def z_base_32_encode(data):
|
||||
alphabet = "ybndrfg8ejkmcpqxot1uwisza345h769"
|
||||
result = ""
|
||||
bits = 0
|
||||
value = 0
|
||||
for byte in data:
|
||||
value = (value << 8) | byte
|
||||
bits += 8
|
||||
while bits >= 5:
|
||||
bits -= 5
|
||||
result += alphabet[(value >> bits) & 31]
|
||||
if bits > 0:
|
||||
result += alphabet[(value << (5 - bits)) & 31]
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def z_base_32_decode(encoded):
|
||||
alphabet = "ybndrfg8ejkmcpqxot1uwisza345h769"
|
||||
alphabet_map = {char: index for index, char in enumerate(alphabet)}
|
||||
result = bytearray()
|
||||
bits = 0
|
||||
value = 0
|
||||
for char in encoded:
|
||||
value = (value << 5) | alphabet_map[char]
|
||||
bits += 5
|
||||
if bits >= 8:
|
||||
bits -= 8
|
||||
result.append((value >> bits) & 255)
|
||||
return bytes(result)
|
||||
91
src/dns_utils.py
Normal file
91
src/dns_utils.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import struct
|
||||
from typing import List, Tuple
|
||||
from dns import message, name, rdata, rdatatype, rdataclass
|
||||
from .errors import DNSError
|
||||
|
||||
def create_dns_query(domain: str, record_type: str) -> bytes:
|
||||
"""Create a DNS query packet."""
|
||||
try:
|
||||
qname = name.from_text(domain)
|
||||
q = message.make_query(qname, rdatatype.from_text(record_type))
|
||||
return q.to_wire()
|
||||
except Exception as e:
|
||||
raise DNSError(f"Failed to create DNS query: {str(e)}")
|
||||
|
||||
def parse_dns_response(response: bytes) -> List[Tuple[str, str, int, str]]:
|
||||
"""Parse a DNS response and return a list of (name, type, ttl, data) tuples."""
|
||||
try:
|
||||
msg = message.from_wire(response)
|
||||
results = []
|
||||
for rrset in msg.answer:
|
||||
name = rrset.name.to_text()
|
||||
ttl = rrset.ttl
|
||||
for rr in rrset:
|
||||
rr_type = rdatatype.to_text(rr.rdtype)
|
||||
rr_data = rr.to_text()
|
||||
results.append((name, rr_type, ttl, rr_data))
|
||||
return results
|
||||
except Exception as e:
|
||||
raise DNSError(f"Failed to parse DNS response: {str(e)}")
|
||||
|
||||
def compress_domain_name(domain: str) -> bytes:
|
||||
"""Compress a domain name according to DNS name compression rules."""
|
||||
try:
|
||||
n = name.from_text(domain)
|
||||
return n.to_wire()
|
||||
except Exception as e:
|
||||
raise DNSError(f"Failed to compress domain name: {str(e)}")
|
||||
|
||||
def decompress_domain_name(compressed: bytes, offset: int = 0) -> Tuple[str, int]:
|
||||
"""Decompress a domain name from DNS wire format."""
|
||||
try:
|
||||
n, offset = name.from_wire(compressed, offset)
|
||||
return n.to_text(), offset
|
||||
except Exception as e:
|
||||
raise DNSError(f"Failed to decompress domain name: {str(e)}")
|
||||
|
||||
def encode_resource_record(name: str, rr_type: str, rr_class: str, ttl: int, rdata: str) -> bytes:
|
||||
"""Encode a resource record into DNS wire format."""
|
||||
try:
|
||||
n = name.from_text(name)
|
||||
rr_type = rdatatype.from_text(rr_type)
|
||||
rr_class = rdataclass.from_text(rr_class)
|
||||
rd = rdata.from_text(rr_type, rr_class, rdata)
|
||||
return (n.to_wire() +
|
||||
struct.pack("!HHIH", rr_type, rr_class, ttl, len(rd.to_wire())) +
|
||||
rd.to_wire())
|
||||
except Exception as e:
|
||||
raise DNSError(f"Failed to encode resource record: {str(e)}")
|
||||
|
||||
def decode_resource_record(wire: bytes, offset: int = 0) -> Tuple[str, str, str, int, str, int]:
|
||||
"""Decode a resource record from DNS wire format."""
|
||||
try:
|
||||
n, offset = name.from_wire(wire, offset)
|
||||
rr_type, rr_class, ttl, rdlen = struct.unpack_from("!HHIH", wire, offset)
|
||||
offset += 10
|
||||
rd = rdata.from_wire(rdatatype.to_text(rr_type), wire, offset, rdlen)
|
||||
offset += rdlen
|
||||
return (n.to_text(),
|
||||
rdatatype.to_text(rr_type),
|
||||
rdataclass.to_text(rr_class),
|
||||
ttl,
|
||||
rd.to_text(),
|
||||
offset)
|
||||
except Exception as e:
|
||||
raise DNSError(f"Failed to decode resource record: {str(e)}")
|
||||
|
||||
def is_valid_domain_name(domain: str) -> bool:
|
||||
"""Check if a given string is a valid domain name."""
|
||||
try:
|
||||
name.from_text(domain)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def normalize_domain_name(domain: str) -> str:
|
||||
"""Normalize a domain name (convert to lowercase and ensure it ends with a dot)."""
|
||||
try:
|
||||
n = name.from_text(domain)
|
||||
return n.to_text().lower()
|
||||
except Exception as e:
|
||||
raise DNSError(f"Failed to normalize domain name: {str(e)}")
|
||||
52
src/errors.py
Normal file
52
src/errors.py
Normal file
@@ -0,0 +1,52 @@
|
||||
class PkarrError(Exception):
|
||||
"""Base class for all pkarr-related errors."""
|
||||
pass
|
||||
|
||||
class KeypairError(PkarrError):
|
||||
"""Raised when there's an issue with keypair operations."""
|
||||
pass
|
||||
|
||||
class PublicKeyError(PkarrError):
|
||||
"""Raised when there's an issue with public key operations."""
|
||||
pass
|
||||
|
||||
class SignatureError(PkarrError):
|
||||
"""Raised when there's an issue with signature operations."""
|
||||
pass
|
||||
|
||||
class PacketError(PkarrError):
|
||||
"""Raised when there's an issue with packet operations."""
|
||||
pass
|
||||
|
||||
class DNSError(PkarrError):
|
||||
"""Raised when there's an issue with DNS operations."""
|
||||
pass
|
||||
|
||||
class DHTError(PkarrError):
|
||||
"""Raised when there's an issue with DHT operations."""
|
||||
pass
|
||||
|
||||
class InvalidSignedPacketBytesLength(PacketError):
|
||||
"""Raised when the SignedPacket bytes length is invalid."""
|
||||
def __init__(self, length: int):
|
||||
super().__init__(f"Invalid SignedPacket bytes length, expected at least 104 bytes but got: {length}")
|
||||
|
||||
class InvalidRelayPayloadSize(PacketError):
|
||||
"""Raised when the relay payload size is invalid."""
|
||||
def __init__(self, size: int):
|
||||
super().__init__(f"Invalid relay payload size, expected at least 72 bytes but got: {size}")
|
||||
|
||||
class PacketTooLarge(PacketError):
|
||||
"""Raised when the DNS packet is too large."""
|
||||
def __init__(self, size: int):
|
||||
super().__init__(f"DNS Packet is too large, expected max 1000 bytes but got: {size}")
|
||||
|
||||
class DHTIsShutdown(DHTError):
|
||||
"""Raised when the DHT is shutdown."""
|
||||
def __init__(self):
|
||||
super().__init__("DHT is shutdown")
|
||||
|
||||
class PublishInflight(DHTError):
|
||||
"""Raised when a publish query is already in flight for the same public key."""
|
||||
def __init__(self):
|
||||
super().__init__("Publish query is already in flight for the same public_key")
|
||||
44
src/keypair.py
Normal file
44
src/keypair.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from .public_key import PublicKey
|
||||
from .crypto import Crypto
|
||||
from .errors import KeypairError
|
||||
|
||||
class Keypair:
|
||||
def __init__(self, secret_key: bytes):
|
||||
if len(secret_key) != 32:
|
||||
raise KeypairError(f"Secret key must be 32 bytes long, got {len(secret_key)} bytes")
|
||||
self.secret_key = secret_key
|
||||
self.public_key = PublicKey(Crypto.derive_public_key(secret_key))
|
||||
|
||||
@classmethod
|
||||
def random(cls) -> 'Keypair':
|
||||
"""Generate a new random keypair."""
|
||||
secret_key, _ = Crypto.generate_keypair()
|
||||
return cls(secret_key)
|
||||
|
||||
@classmethod
|
||||
def from_secret_key(cls, secret_key: bytes) -> 'Keypair':
|
||||
"""Create a Keypair from a secret key."""
|
||||
return cls(secret_key)
|
||||
|
||||
def sign(self, message: bytes) -> bytes:
|
||||
"""Sign a message using this keypair."""
|
||||
return Crypto.sign(self.secret_key, message)
|
||||
|
||||
def verify(self, message: bytes, signature: bytes) -> bool:
|
||||
"""Verify a signature using this keypair's public key."""
|
||||
return self.public_key.verify(message, signature)
|
||||
|
||||
def to_bytes(self) -> bytes:
|
||||
"""Return the secret key bytes."""
|
||||
return self.secret_key
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, secret_key: bytes) -> 'Keypair':
|
||||
"""Create a Keypair from bytes."""
|
||||
return cls(secret_key)
|
||||
|
||||
def __str__(self):
|
||||
return f"Keypair(public_key={self.public_key})"
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
95
src/packet.py
Normal file
95
src/packet.py
Normal file
@@ -0,0 +1,95 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
from dns import message, name, rdata, rdatatype, rdataclass
|
||||
from .resource_record import ResourceRecord
|
||||
from .errors import PacketError
|
||||
|
||||
@dataclass
|
||||
class Packet:
|
||||
answers: List[ResourceRecord] = field(default_factory=list)
|
||||
id: int = 0
|
||||
qr: bool = True # True for response, False for query
|
||||
opcode: int = 0 # 0 for standard query
|
||||
aa: bool = True # Authoritative Answer
|
||||
tc: bool = False # TrunCation
|
||||
rd: bool = False # Recursion Desired
|
||||
ra: bool = False # Recursion Available
|
||||
z: int = 0 # Reserved for future use
|
||||
rcode: int = 0 # Response code
|
||||
|
||||
@classmethod
|
||||
def new_reply(cls, id: int):
|
||||
return cls(answers=[], id=id)
|
||||
|
||||
def add_answer(self, answer: ResourceRecord):
|
||||
self.answers.append(answer)
|
||||
|
||||
def build_bytes_vec_compressed(self) -> bytes:
|
||||
"""Build a compressed DNS wire format representation of the packet."""
|
||||
try:
|
||||
msg = message.Message(id=self.id)
|
||||
msg.flags = 0
|
||||
if self.qr:
|
||||
msg.flags |= 1 << 15
|
||||
msg.flags |= (self.opcode & 0xF) << 11
|
||||
if self.aa:
|
||||
msg.flags |= 1 << 10
|
||||
if self.tc:
|
||||
msg.flags |= 1 << 9
|
||||
if self.rd:
|
||||
msg.flags |= 1 << 8
|
||||
if self.ra:
|
||||
msg.flags |= 1 << 7
|
||||
msg.flags |= (self.z & 0x7) << 4
|
||||
msg.flags |= self.rcode & 0xF
|
||||
|
||||
for rr in self.answers:
|
||||
rr_name = name.from_text(rr.name)
|
||||
rr_ttl = rr.ttl
|
||||
rr_rdataclass = rdataclass.from_text(rr.rclass)
|
||||
rr_rdatatype = rdatatype.from_text(rr.rtype)
|
||||
rr_rdata = rdata.from_text(rr_rdataclass, rr_rdatatype, rr.rdata)
|
||||
msg.answer.append((rr_name, rr_ttl, rr_rdata))
|
||||
|
||||
return msg.to_wire()
|
||||
except Exception as e:
|
||||
raise PacketError(f"Failed to build packet: {str(e)}")
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> 'Packet':
|
||||
"""Create a Packet object from DNS wire format bytes."""
|
||||
try:
|
||||
msg = message.from_wire(data)
|
||||
packet = cls(
|
||||
id=msg.id,
|
||||
qr=bool(msg.flags & (1 << 15)),
|
||||
opcode=(msg.flags >> 11) & 0xF,
|
||||
aa=bool(msg.flags & (1 << 10)),
|
||||
tc=bool(msg.flags & (1 << 9)),
|
||||
rd=bool(msg.flags & (1 << 8)),
|
||||
ra=bool(msg.flags & (1 << 7)),
|
||||
z=(msg.flags >> 4) & 0x7,
|
||||
rcode=msg.flags & 0xF
|
||||
)
|
||||
|
||||
for rrset in msg.answer:
|
||||
for rr in rrset:
|
||||
resource_record = ResourceRecord(
|
||||
name=rrset.name.to_text(),
|
||||
rclass=rdataclass.to_text(rr.rdclass),
|
||||
ttl=rrset.ttl,
|
||||
rtype=rdatatype.to_text(rr.rdtype),
|
||||
rdata=rr.to_text()
|
||||
)
|
||||
packet.add_answer(resource_record)
|
||||
|
||||
return packet
|
||||
except Exception as e:
|
||||
raise PacketError(f"Failed to parse packet: {str(e)}")
|
||||
|
||||
def __str__(self):
|
||||
header = f"Packet ID: {self.id}, QR: {'Response' if self.qr else 'Query'}, " \
|
||||
f"Opcode: {self.opcode}, AA: {self.aa}, TC: {self.tc}, RD: {self.rd}, " \
|
||||
f"RA: {self.ra}, Z: {self.z}, RCODE: {self.rcode}"
|
||||
answers = "\n".join(f" {rr}" for rr in self.answers)
|
||||
return f"{header}\nAnswers:\n{answers}"
|
||||
39
src/public_key.py
Normal file
39
src/public_key.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from .crypto import Crypto
|
||||
from .errors import PublicKeyError
|
||||
|
||||
class PublicKey:
|
||||
def __init__(self, key):
|
||||
if isinstance(key, str):
|
||||
try:
|
||||
self.key = Crypto.z_base_32_decode(key)
|
||||
except ValueError:
|
||||
raise PublicKeyError("Invalid z-base-32 encoded public key")
|
||||
elif isinstance(key, bytes):
|
||||
if len(key) != 32:
|
||||
raise PublicKeyError("Public key must be 32 bytes long")
|
||||
self.key = key
|
||||
else:
|
||||
raise PublicKeyError("Public key must be bytes or z-base-32 encoded string")
|
||||
|
||||
def __str__(self):
|
||||
return self.to_z32()
|
||||
|
||||
def __repr__(self):
|
||||
return f"PublicKey({self.to_z32()})"
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, PublicKey):
|
||||
return self.key == other.key
|
||||
return False
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.key)
|
||||
|
||||
def to_z32(self):
|
||||
return Crypto.z_base_32_encode(self.key)
|
||||
|
||||
def to_bytes(self):
|
||||
return self.key
|
||||
|
||||
def verify(self, message: bytes, signature: bytes) -> bool:
|
||||
return Crypto.verify(self.key, message, signature)
|
||||
39
src/resource_record.py
Normal file
39
src/resource_record.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Union
|
||||
import ipaddress
|
||||
|
||||
@dataclass
|
||||
class ResourceRecord:
|
||||
name: str
|
||||
rclass: str
|
||||
ttl: int
|
||||
rtype: str
|
||||
rdata: Union[str, ipaddress.IPv4Address, ipaddress.IPv6Address]
|
||||
|
||||
def __post_init__(self):
|
||||
self.name = self.name.lower()
|
||||
self.rclass = self.rclass.upper()
|
||||
self.rtype = self.rtype.upper()
|
||||
|
||||
if self.rtype == 'A':
|
||||
self.rdata = ipaddress.IPv4Address(self.rdata)
|
||||
elif self.rtype == 'AAAA':
|
||||
self.rdata = ipaddress.IPv6Address(self.rdata)
|
||||
|
||||
def to_wire_format(self) -> bytes:
|
||||
# This is a placeholder. You'll need to implement the actual DNS wire format encoding.
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_wire_format(cls, wire_data: bytes) -> 'ResourceRecord':
|
||||
# This is a placeholder. You'll need to implement the actual DNS wire format decoding.
|
||||
pass
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name} {self.ttl} {self.rclass} {self.rtype} {self.rdata}"
|
||||
|
||||
def is_expired(self, current_time: int) -> bool:
|
||||
return current_time > self.ttl
|
||||
|
||||
def remaining_ttl(self, current_time: int) -> int:
|
||||
return max(0, self.ttl - current_time)
|
||||
154
src/signed_packet.py
Normal file
154
src/signed_packet.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
import ed25519
|
||||
from dns import message, name, rdata, rdatatype, rdataclass
|
||||
|
||||
@dataclass
|
||||
class PublicKey:
|
||||
key: bytes
|
||||
|
||||
def to_z32(self) -> str:
|
||||
# Implement z-base-32 encoding here
|
||||
pass
|
||||
|
||||
@dataclass
|
||||
class ResourceRecord:
|
||||
name: str
|
||||
rclass: int
|
||||
ttl: int
|
||||
rdata: bytes
|
||||
|
||||
@dataclass
|
||||
class Packet:
|
||||
answers: List[ResourceRecord]
|
||||
|
||||
@classmethod
|
||||
def new_reply(cls, id: int):
|
||||
return cls(answers=[])
|
||||
|
||||
def build_bytes_vec_compressed(self) -> bytes:
|
||||
# Implement DNS packet compression here
|
||||
pass
|
||||
|
||||
@dataclass
|
||||
class SignedPacket:
|
||||
public_key: PublicKey
|
||||
signature: bytes
|
||||
timestamp: int
|
||||
packet: Packet
|
||||
last_seen: int
|
||||
|
||||
@classmethod
|
||||
def from_packet(cls, keypair, packet: Packet):
|
||||
timestamp = int(time.time() * 1_000_000)
|
||||
encoded_packet = packet.build_bytes_vec_compressed()
|
||||
|
||||
if len(encoded_packet) > 1000:
|
||||
raise ValueError("Packet too large")
|
||||
|
||||
signature = keypair.sign(cls.signable(timestamp, encoded_packet))
|
||||
|
||||
return cls(
|
||||
public_key=keypair.public_key(),
|
||||
signature=signature,
|
||||
timestamp=timestamp,
|
||||
packet=packet,
|
||||
last_seen=int(time.time() * 1_000_000)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes):
|
||||
if len(data) < 104:
|
||||
raise ValueError("Invalid SignedPacket bytes length")
|
||||
if len(data) > 1104:
|
||||
raise ValueError("Packet too large")
|
||||
|
||||
public_key = PublicKey(data[:32])
|
||||
signature = data[32:96]
|
||||
timestamp = int.from_bytes(data[96:104], 'big')
|
||||
encoded_packet = data[104:]
|
||||
|
||||
# Verify signature
|
||||
if not public_key.verify(cls.signable(timestamp, encoded_packet), signature):
|
||||
raise ValueError("Invalid signature")
|
||||
|
||||
packet = Packet([]) # Parse encoded_packet into a Packet object here
|
||||
|
||||
return cls(
|
||||
public_key=public_key,
|
||||
signature=signature,
|
||||
timestamp=timestamp,
|
||||
packet=packet,
|
||||
last_seen=int(time.time() * 1_000_000)
|
||||
)
|
||||
|
||||
def as_bytes(self) -> bytes:
|
||||
return (
|
||||
self.public_key.key +
|
||||
self.signature +
|
||||
self.timestamp.to_bytes(8, 'big') +
|
||||
self.packet.build_bytes_vec_compressed()
|
||||
)
|
||||
|
||||
def to_relay_payload(self) -> bytes:
|
||||
return self.as_bytes()[32:]
|
||||
|
||||
def resource_records(self, name: str):
|
||||
origin = self.public_key.to_z32()
|
||||
normalized_name = self.normalize_name(origin, name)
|
||||
return [rr for rr in self.packet.answers if rr.name == normalized_name]
|
||||
|
||||
def fresh_resource_records(self, name: str):
|
||||
origin = self.public_key.to_z32()
|
||||
normalized_name = self.normalize_name(origin, name)
|
||||
current_time = int(time.time())
|
||||
return [
|
||||
rr for rr in self.packet.answers
|
||||
if rr.name == normalized_name and rr.ttl > (current_time - self.last_seen // 1_000_000)
|
||||
]
|
||||
|
||||
def expires_in(self, min_ttl: int, max_ttl: int) -> int:
|
||||
ttl = self.ttl(min_ttl, max_ttl)
|
||||
elapsed = self.elapsed()
|
||||
return max(0, ttl - elapsed)
|
||||
|
||||
def ttl(self, min_ttl: int, max_ttl: int) -> int:
|
||||
if not self.packet.answers:
|
||||
return min_ttl
|
||||
min_record_ttl = min(rr.ttl for rr in self.packet.answers)
|
||||
return max(min_ttl, min(max_ttl, min_record_ttl))
|
||||
|
||||
def elapsed(self) -> int:
|
||||
return (int(time.time() * 1_000_000) - self.last_seen) // 1_000_000
|
||||
|
||||
@staticmethod
|
||||
def signable(timestamp: int, v: bytes) -> bytes:
|
||||
return f"3:seqi{timestamp}e1:v{len(v)}:".encode() + v
|
||||
|
||||
@staticmethod
|
||||
def normalize_name(origin: str, name: str) -> str:
|
||||
if name.endswith('.'):
|
||||
name = name[:-1]
|
||||
|
||||
parts = name.split('.')
|
||||
last = parts[-1]
|
||||
|
||||
if last == origin:
|
||||
return name
|
||||
if last in ('@', ''):
|
||||
return origin
|
||||
return f"{name}.{origin}"
|
||||
|
||||
def __str__(self):
|
||||
records = "\n".join(
|
||||
f" {rr.name} IN {rr.ttl} {rr.rdata}"
|
||||
for rr in self.packet.answers
|
||||
)
|
||||
return f"""SignedPacket ({self.public_key.key.hex()}):
|
||||
last_seen: {self.elapsed()} seconds ago
|
||||
timestamp: {self.timestamp},
|
||||
signature: {self.signature.hex()}
|
||||
records:
|
||||
{records}
|
||||
"""
|
||||
Reference in New Issue
Block a user