mirror of
https://github.com/aljazceru/pypkarr.git
synced 2025-12-19 15:14:26 +01:00
Work in progress: implementing a working Pkarr client in Python.
This commit is contained in:
71
lookup.py
Normal file
71
lookup.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
from src.client import PkarrClient
|
||||||
|
from src.public_key import PublicKey
|
||||||
|
from src.keypair import Keypair
|
||||||
|
from src.errors import PkarrError
|
||||||
|
import bencodepy
|
||||||
|
|
||||||
|
|
||||||
|
# TTL clamp bounds applied to a resolved SignedPacket before it is cached.
DEFAULT_MINIMUM_TTL = 300 # 5 minutes
DEFAULT_MAXIMUM_TTL = 24 * 60 * 60 # 24 hours

# Well-known public BitTorrent DHT routers used to join the network.
DEFAULT_BOOTSTRAP_NODES = [
    "router.bittorrent.com:6881",
    "router.utorrent.com:6881",
    "dht.transmissionbt.com:6881",
    "dht.libtorrent.org:25401"
]

# Script-level logging setup (DEBUG is acceptable here: this file is a CLI tool).
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||||
|
|
||||||
|
async def resolve(client: PkarrClient, public_key: PublicKey):
    """Perform one lookup of *public_key* through *client* and pretty-print the result."""
    start_time = time.time()
    try:
        signed_packet = await client.lookup(public_key.to_z32(), max_attempts=200, timeout=60)
        elapsed = time.time() - start_time

        if not signed_packet:
            print(f"\nFailed to resolve {public_key.to_z32()}")
            return

        print(f"\nResolved in {int(elapsed * 1000)} milliseconds SignedPacket ({public_key.to_z32()}):")
        print(f" last_seen: {signed_packet.elapsed()} seconds ago")
        print(f" timestamp: {signed_packet.timestamp},")
        print(f" signature: {signed_packet.signature.hex().upper()}")
        print(" records:")
        for rr in signed_packet.packet.answers:
            print(f" {rr}")
    except PkarrError as e:
        print(f"Got error: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
    """CLI entry point: parse arguments, then run a cold and a warm lookup."""
    parser = ArgumentParser(description="Resolve Pkarr records")
    parser.add_argument("public_key", help="z-base-32 encoded public key")
    parser.add_argument("--bootstrap", nargs='+', default=DEFAULT_BOOTSTRAP_NODES,
                        help="Bootstrap nodes (default: %(default)s)")
    args = parser.parse_args()

    # Reject malformed keys up front rather than failing mid-lookup.
    try:
        public_key = PublicKey(args.public_key)
    except PkarrError as e:
        logging.error(f"Invalid public key: {e}")
        return

    client = PkarrClient(Keypair.random(), args.bootstrap)

    logging.info(f"Resolving Pkarr: {args.public_key}")
    logging.info("\n=== COLD LOOKUP ===")
    await resolve(client, public_key)

    await asyncio.sleep(1)

    # Second pass should be answered from the client's cache.
    logging.info("\n=== SUBSEQUENT LOOKUP ===")
    await resolve(client, public_key)
|
||||||
|
|
||||||
|
# Run the resolver CLI only when executed as a script, not on import.
if __name__ == "__main__":
    asyncio.run(main())
|
||||||
221
src/client.py
Normal file
221
src/client.py
Normal file
@@ -0,0 +1,221 @@
|
|||||||
|
import asyncio
|
||||||
|
import random
|
||||||
|
from typing import List, Optional, Union
|
||||||
|
from .signed_packet import SignedPacket
|
||||||
|
from .keypair import Keypair
|
||||||
|
from .public_key import PublicKey
|
||||||
|
from .resource_record import ResourceRecord
|
||||||
|
from .packet import Packet
|
||||||
|
from .errors import PkarrError
|
||||||
|
import logging
|
||||||
|
import socket
|
||||||
|
import struct
|
||||||
|
import hashlib
|
||||||
|
import bencodepy
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
|
||||||
|
# NOTE(review): calling basicConfig at import time from a library module
# overrides the host application's logging setup — consider moving this
# into the CLI entry point (lookup.py already configures logging).
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
|
||||||
|
class PkarrClient:
    """Minimal BitTorrent-DHT based Pkarr client.

    Resolves SignedPackets for a public key by iteratively sending
    ``get_peers`` queries to DHT nodes and following the closer nodes
    returned in each response. Successful results are cached with a
    clamped TTL.
    """

    # TTL clamp bounds for cached signed packets (mirrors the defaults in
    # lookup.py). Bug fix: lookup() previously referenced module-level
    # DEFAULT_MINIMUM_TTL / DEFAULT_MAXIMUM_TTL names that do not exist in
    # this module, which raised NameError on every cache insert.
    DEFAULT_MINIMUM_TTL = 300  # 5 minutes
    DEFAULT_MAXIMUM_TTL = 24 * 60 * 60  # 24 hours

    def __init__(self, keypair: "Keypair", bootstrap_nodes: List[str]):
        """Store the local identity and seed nodes, and create an empty cache."""
        self.keypair = keypair
        self.bootstrap_nodes = bootstrap_nodes
        self.known_nodes = set(bootstrap_nodes)
        # z-base-32 public key -> (SignedPacket, absolute expiry unix time).
        # Bug fix: lookup() reads self.cache but it was never initialized,
        # so every lookup raised AttributeError.
        self.cache = {}

    async def lookup(self, public_key: str, max_attempts: int = 100, timeout: int = 30) -> "Optional[SignedPacket]":
        """Look up records from the DHT.

        Args:
            public_key: z-base-32 encoded public key to resolve.
            max_attempts: maximum number of node queries before giving up.
            timeout: overall wall-clock budget in seconds.

        Returns:
            The resolved SignedPacket, or None when nothing was found.
        """
        target_key = PublicKey(public_key)

        # Check cache first
        cached_packet, expiration_time = self.cache.get(public_key, (None, 0))
        if cached_packet and time.time() < expiration_time:
            logging.debug(f"Have fresh signed_packet in cache. expires_in={int(expiration_time - time.time())}")
            return cached_packet

        nodes_to_query = set(self.bootstrap_nodes)
        queried_nodes = set()

        start_time = time.time()
        attempts = 0

        while nodes_to_query and attempts < max_attempts and (time.time() - start_time) < timeout:
            node = nodes_to_query.pop()
            queried_nodes.add(node)
            attempts += 1

            logging.info(f"Attempt {attempts}: Querying node {node}")

            try:
                result = await self._request_packet(node, target_key)
                if isinstance(result, SignedPacket):
                    logging.info(f"Found result after {attempts} attempts and {time.time() - start_time:.2f} seconds")
                    # Cache the result with its TTL clamped between the class bounds.
                    self.cache[public_key] = (result, time.time() + result.ttl(self.DEFAULT_MINIMUM_TTL, self.DEFAULT_MAXIMUM_TTL))
                    return result
                elif result:
                    # A list of "host:port" strings: closer nodes to query next.
                    new_nodes = set(result) - queried_nodes
                    nodes_to_query.update(new_nodes)
                    logging.info(f"Added {len(new_nodes)} new nodes to query. Total known nodes: {len(self.known_nodes)}")
            except PkarrError as e:
                logging.error(f"Error with node {node}: {e}")

        logging.info(f"Lookup completed after {attempts} attempts and {time.time() - start_time:.2f} seconds")
        logging.info(f"Queried {len(queried_nodes)} unique nodes")

        if attempts >= max_attempts:
            logging.warning("Lookup terminated: maximum attempts reached")
        elif (time.time() - start_time) >= timeout:
            logging.warning("Lookup terminated: timeout reached")
        else:
            logging.warning("Lookup terminated: no more nodes to query")

        return None

    async def _request_packet(self, node: str, target_key: "PublicKey", record_type: str = "") -> "Optional[Union[SignedPacket, List[str]]]":
        """Request a packet from a node.

        Bug fix: ``record_type`` now defaults to "" — lookup() calls this
        method with only (node, target_key), which previously raised
        TypeError on every attempt.

        NOTE(review): this uses a blocking UDP socket inside an async
        method, stalling the event loop for up to 5 s per query — consider
        asyncio datagram endpoints.
        """
        logging.info(f"Requesting packet from node {node} for key {target_key.to_z32()} and record_type {record_type}")

        # Bug fix: initialize these before the try so the except/finally
        # clauses never see unbound names when parsing or socket creation fails.
        sock = None
        host, port = None, None
        try:
            # Extract IP and port from the node string
            if '@' in node:
                _, ip_port = node.split('@')
            else:
                ip_port = node

            host, port = ip_port.split(':')
            port = int(port)

            sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            sock.settimeout(5)  # Set a 5-second timeout

            # Ensure we have a 20-byte node ID / info_hash (SHA1 of the keys).
            # NOTE(review): real pkarr uses BEP44 mutable items rather than
            # get_peers with a hashed key — confirm against the pkarr spec.
            node_id = hashlib.sha1(self.keypair.public_key.to_bytes()).digest()
            info_hash = hashlib.sha1(target_key.to_bytes()).digest()

            # Prepare and send the DHT query
            transaction_id = random.randint(0, 65535).to_bytes(2, 'big')
            message = bencodepy.encode({
                't': transaction_id,
                'y': 'q',
                'q': 'get_peers',
                'a': {
                    'id': node_id,
                    'info_hash': info_hash
                }
            })

            logging.debug(f"Sending message to {host}:{port}: {message}")
            sock.sendto(message, (host, port))

            # Wait for the response
            data, addr = sock.recvfrom(1024)
            logging.debug(f"Received raw response from {addr}: {data}")

            # Parse the response
            response = bencodepy.decode(data)
            human_readable = self._decode_response(response)
            logging.info(f"Decoded response from {addr}:\n{json.dumps(human_readable, indent=2)}")

            if response.get(b'y') == b'e':
                error_code, error_message = response.get(b'e', [None, b''])[0], response.get(b'e', [None, b''])[1].decode('utf-8', errors='ignore')
                logging.error(f"Received error response: Code {error_code}, Message: {error_message}")
                return None

            # Check if the response contains the data we need
            if b'r' in response:
                r = response[b'r']
                if b'values' in r:
                    # Process peer values
                    peer_values = r[b'values']
                    logging.info(f"Found {len(peer_values)} peer values")
                    return await self._connect_to_peers(peer_values, target_key, record_type)
                elif b'nodes' in r:
                    # Process nodes for further querying
                    nodes = r[b'nodes']
                    decoded_nodes = self._decode_nodes(nodes)
                    logging.info(f"Found {len(decoded_nodes)} nodes")
                    self._update_known_nodes(decoded_nodes)
                    return decoded_nodes

            return None

        except socket.timeout:
            logging.error(f"Timeout while connecting to {host}:{port}")
        except Exception as e:
            logging.error(f"Error requesting packet from {host}:{port}: {e}")
            logging.exception("Exception details:")
        finally:
            if sock is not None:
                sock.close()

        return None

    async def _connect_to_peers(self, peer_values: List[bytes], target_key: "PublicKey", record_type: str = "") -> "Optional[SignedPacket]":
        """Connect to peers and try to retrieve the SignedPacket."""
        for peer_value in peer_values:
            # Bug fix: define peer before parsing so the error log below
            # cannot raise NameError when inet_ntoa/unpack fails.
            peer = "<unparsed>"
            try:
                ip = socket.inet_ntoa(peer_value[:4])
                port = struct.unpack("!H", peer_value[4:])[0]
                peer = f"{ip}:{port}"

                logging.info(f"Connecting to peer {peer}")

                # Here you would implement the logic to connect to the peer and retrieve the SignedPacket
                # For now, we'll just return a dummy SignedPacket
                return SignedPacket(target_key, b"dummy_signature", Packet())

            except Exception as e:
                logging.error(f"Error connecting to peer {peer}: {e}")

        return None

    def _decode_response(self, response: "dict") -> "dict":
        """Decode the bencoded response into a human-readable format.

        Bytes values are decoded as UTF-8 when possible, otherwise hex;
        dicts and lists are converted recursively.
        """
        decoded = {}
        for key, value in response.items():
            str_key = key.decode('utf-8')
            if isinstance(value, bytes):
                try:
                    decoded[str_key] = value.decode('utf-8')
                except UnicodeDecodeError:
                    decoded[str_key] = value.hex()
            elif isinstance(value, dict):
                decoded[str_key] = self._decode_response(value)
            elif isinstance(value, list):
                decoded[str_key] = [self._decode_response(item) if isinstance(item, dict) else item.hex() if isinstance(item, bytes) else item for item in value]
            else:
                decoded[str_key] = value

        # Expand compact node info for readability in the logs.
        if 'r' in decoded and 'nodes' in decoded['r']:
            decoded['r']['decoded_nodes'] = self._decode_nodes(response[b'r'][b'nodes'])

        return decoded

    def _decode_nodes(self, nodes_data: bytes) -> List[str]:
        """Decode compact node info (26 bytes per node: 20-byte id, 4-byte IP, 2-byte port)."""
        nodes = []
        for i in range(0, len(nodes_data), 26):
            node_id = nodes_data[i:i+20].hex()
            ip = socket.inet_ntoa(nodes_data[i+20:i+24])
            port = struct.unpack("!H", nodes_data[i+24:i+26])[0]
            nodes.append(f"{ip}:{port}")
        return nodes

    def _update_known_nodes(self, new_nodes: List[str]) -> None:
        """Update the list of known nodes."""
        self.known_nodes.update(new_nodes)
        logging.info(f"Updated known nodes. Total known nodes: {len(self.known_nodes)}")

    async def _send_packet(self, node: str, signed_packet: "SignedPacket") -> None:
        """Send a signed packet to a node (not yet implemented)."""
        # Implement UDP packet sending logic here
        pass

    async def maintain_network(self) -> None:
        """Periodically maintain the network by pinging known nodes and discovering new ones."""
        while True:
            # Implement node discovery and maintenance logic here
            await asyncio.sleep(60)  # Run maintenance every 60 seconds
|
||||||
71
src/crypto.py
Normal file
71
src/crypto.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
import os
|
||||||
|
import hashlib
|
||||||
|
import ed25519
|
||||||
|
|
||||||
|
class Crypto:
    """Stateless cryptographic helpers: ed25519 signatures, hashing, z-base-32."""

    # z-base-32 alphabet (Zimmermann variant) used for pkarr key encoding.
    Z_BASE_32_ALPHABET = "ybndrfg8ejkmcpqxot1uwisza345h769"

    @staticmethod
    def generate_keypair():
        """Return a fresh (32-byte secret key, public key bytes) tuple."""
        private_key, public_key = ed25519.create_keypair()
        # The ed25519 library's "private key" bytes include the public half;
        # keep only the 32-byte seed.
        return private_key.to_bytes()[:32], public_key.to_bytes()

    @staticmethod
    def derive_public_key(secret_key):
        """Derive the public key bytes from a 32-byte secret key."""
        if len(secret_key) != 32:
            raise ValueError("Secret key must be 32 bytes long")
        signing_key = ed25519.SigningKey(secret_key)
        return signing_key.get_verifying_key().to_bytes()

    @staticmethod
    def sign(secret_key, message):
        """Sign *message* with the 32-byte secret key; returns the signature bytes."""
        if len(secret_key) != 32:
            raise ValueError("Secret key must be 32 bytes long")
        signing_key = ed25519.SigningKey(secret_key)
        return signing_key.sign(message)

    @staticmethod
    def verify(public_key, message, signature):
        """Return True iff *signature* is valid for *message* under *public_key*."""
        verifying_key = ed25519.VerifyingKey(public_key)
        try:
            verifying_key.verify(signature, message)
            return True
        except ed25519.BadSignatureError:
            return False

    @staticmethod
    def hash(data):
        """SHA-256 digest of *data*."""
        return hashlib.sha256(data).digest()

    @staticmethod
    def random_bytes(length):
        """Cryptographically secure random bytes."""
        return os.urandom(length)

    @staticmethod
    def z_base_32_encode(data):
        """Encode bytes to a z-base-32 string (5 bits per output character)."""
        alphabet = Crypto.Z_BASE_32_ALPHABET
        result = ""
        bits = 0
        value = 0
        for byte in data:
            value = (value << 8) | byte
            bits += 8
            while bits >= 5:
                bits -= 5
                result += alphabet[(value >> bits) & 31]
        if bits > 0:
            # Flush the remaining partial group, left-aligned.
            result += alphabet[(value << (5 - bits)) & 31]
        return result

    @staticmethod
    def z_base_32_decode(encoded):
        """Decode a z-base-32 string to bytes.

        Raises:
            ValueError: if *encoded* contains a character outside the alphabet.
        """
        alphabet = Crypto.Z_BASE_32_ALPHABET
        alphabet_map = {char: index for index, char in enumerate(alphabet)}
        result = bytearray()
        bits = 0
        value = 0
        for char in encoded:
            try:
                value = (value << 5) | alphabet_map[char]
            except KeyError:
                # Bug fix: previously a bare KeyError escaped here, but callers
                # (e.g. PublicKey.__init__) only catch ValueError.
                raise ValueError(f"Invalid z-base-32 character: {char!r}") from None
            bits += 5
            if bits >= 8:
                bits -= 8
                result.append((value >> bits) & 255)
        return bytes(result)
|
||||||
91
src/dns_utils.py
Normal file
91
src/dns_utils.py
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
import struct
|
||||||
|
from typing import List, Tuple
|
||||||
|
from dns import message, name, rdata, rdatatype, rdataclass
|
||||||
|
from .errors import DNSError
|
||||||
|
|
||||||
|
def create_dns_query(domain: str, record_type: str) -> bytes:
    """Build a wire-format DNS query for *domain* with the given record type."""
    try:
        query = message.make_query(name.from_text(domain), rdatatype.from_text(record_type))
        return query.to_wire()
    except Exception as e:
        raise DNSError(f"Failed to create DNS query: {str(e)}")
|
||||||
|
|
||||||
|
def parse_dns_response(response: bytes) -> List[Tuple[str, str, int, str]]:
    """Parse a DNS response and return a list of (name, type, ttl, data) tuples."""
    try:
        msg = message.from_wire(response)
        results = []
        for rrset in msg.answer:
            # Bug fix: this local was previously called `name`, shadowing the
            # imported dns.name module inside the function.
            owner = rrset.name.to_text()
            ttl = rrset.ttl
            for rr in rrset:
                rr_type = rdatatype.to_text(rr.rdtype)
                rr_data = rr.to_text()
                results.append((owner, rr_type, ttl, rr_data))
        return results
    except Exception as e:
        raise DNSError(f"Failed to parse DNS response: {str(e)}")
|
||||||
|
|
||||||
|
def compress_domain_name(domain: str) -> bytes:
    """Compress a domain name according to DNS name compression rules."""
    try:
        return name.from_text(domain).to_wire()
    except Exception as e:
        raise DNSError(f"Failed to compress domain name: {str(e)}")
|
||||||
|
|
||||||
|
def decompress_domain_name(compressed: bytes, offset: int = 0) -> Tuple[str, int]:
    """Decompress a domain name from DNS wire format; returns (text, new offset)."""
    try:
        decoded_name, new_offset = name.from_wire(compressed, offset)
        return decoded_name.to_text(), new_offset
    except Exception as e:
        raise DNSError(f"Failed to decompress domain name: {str(e)}")
|
||||||
|
|
||||||
|
def encode_resource_record(name: str, rr_type: str, rr_class: str, ttl: int, rdata: str) -> bytes:
    """Encode a resource record into DNS wire format.

    Bug fix: the parameters `name` and `rdata` shadow the imported
    dns.name and dns.rdata modules, so the previous body called
    `str.from_text` and always raised. Local aliases restore access to
    the modules without changing the public parameter names. Also fixed
    the dns.rdata.from_text argument order, which is (rdclass, rdtype,
    text), not (rdtype, rdclass, text).
    """
    from dns import name as dns_name, rdata as dns_rdata
    try:
        n = dns_name.from_text(name)
        rtype = rdatatype.from_text(rr_type)
        rclass = rdataclass.from_text(rr_class)
        rd = dns_rdata.from_text(rclass, rtype, rdata)
        rd_wire = rd.to_wire()
        return (n.to_wire() +
                struct.pack("!HHIH", rtype, rclass, ttl, len(rd_wire)) +
                rd_wire)
    except Exception as e:
        raise DNSError(f"Failed to encode resource record: {str(e)}")
|
||||||
|
|
||||||
|
def decode_resource_record(wire: bytes, offset: int = 0) -> Tuple[str, str, str, int, str, int]:
    """Decode a resource record from DNS wire format.

    Returns:
        (owner name, type text, class text, ttl, rdata text, new offset).
    """
    try:
        n, offset = name.from_wire(wire, offset)
        rr_type, rr_class, ttl, rdlen = struct.unpack_from("!HHIH", wire, offset)
        offset += 10  # 2 (type) + 2 (class) + 4 (ttl) + 2 (rdlength)
        # Bug fix: dns.rdata.from_wire expects (rdclass, rdtype, wire, current,
        # rdlen); the previous code passed a type-name string as the first arg.
        rd = rdata.from_wire(rr_class, rr_type, wire, offset, rdlen)
        offset += rdlen
        return (n.to_text(),
                rdatatype.to_text(rr_type),
                rdataclass.to_text(rr_class),
                ttl,
                rd.to_text(),
                offset)
    except Exception as e:
        raise DNSError(f"Failed to decode resource record: {str(e)}")
|
||||||
|
|
||||||
|
def is_valid_domain_name(domain: str) -> bool:
    """Check if a given string is a valid domain name."""
    try:
        name.from_text(domain)
    except Exception:
        return False
    return True
|
||||||
|
|
||||||
|
def normalize_domain_name(domain: str) -> str:
    """Normalize a domain name (convert to lowercase and ensure it ends with a dot)."""
    try:
        normalized = name.from_text(domain).to_text()
    except Exception as e:
        raise DNSError(f"Failed to normalize domain name: {str(e)}")
    return normalized.lower()
|
||||||
52
src/errors.py
Normal file
52
src/errors.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
class PkarrError(Exception):
    """Base class for all pkarr-related errors."""


class KeypairError(PkarrError):
    """Raised when there's an issue with keypair operations."""


class PublicKeyError(PkarrError):
    """Raised when there's an issue with public key operations."""


class SignatureError(PkarrError):
    """Raised when there's an issue with signature operations."""


class PacketError(PkarrError):
    """Raised when there's an issue with packet operations."""


class DNSError(PkarrError):
    """Raised when there's an issue with DNS operations."""


class DHTError(PkarrError):
    """Raised when there's an issue with DHT operations."""
|
||||||
|
|
||||||
|
class InvalidSignedPacketBytesLength(PacketError):
    """Raised when the SignedPacket bytes length is invalid."""

    def __init__(self, length: int):
        detail = f"Invalid SignedPacket bytes length, expected at least 104 bytes but got: {length}"
        super().__init__(detail)


class InvalidRelayPayloadSize(PacketError):
    """Raised when the relay payload size is invalid."""

    def __init__(self, size: int):
        detail = f"Invalid relay payload size, expected at least 72 bytes but got: {size}"
        super().__init__(detail)


class PacketTooLarge(PacketError):
    """Raised when the DNS packet is too large."""

    def __init__(self, size: int):
        detail = f"DNS Packet is too large, expected max 1000 bytes but got: {size}"
        super().__init__(detail)


class DHTIsShutdown(DHTError):
    """Raised when the DHT is shutdown."""

    def __init__(self):
        super().__init__("DHT is shutdown")


class PublishInflight(DHTError):
    """Raised when a publish query is already in flight for the same public key."""

    def __init__(self):
        super().__init__("Publish query is already in flight for the same public_key")
|
||||||
44
src/keypair.py
Normal file
44
src/keypair.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
from .public_key import PublicKey
|
||||||
|
from .crypto import Crypto
|
||||||
|
from .errors import KeypairError
|
||||||
|
|
||||||
|
class Keypair:
    """An ed25519 signing identity: 32-byte secret key plus derived public key."""

    def __init__(self, secret_key: bytes):
        """Validate the secret key length and derive the public half."""
        if len(secret_key) != 32:
            raise KeypairError(f"Secret key must be 32 bytes long, got {len(secret_key)} bytes")
        self.secret_key = secret_key
        self.public_key = PublicKey(Crypto.derive_public_key(secret_key))

    @classmethod
    def random(cls) -> 'Keypair':
        """Generate a new random keypair."""
        fresh_secret, _public = Crypto.generate_keypair()
        return cls(fresh_secret)

    @classmethod
    def from_secret_key(cls, secret_key: bytes) -> 'Keypair':
        """Create a Keypair from a secret key."""
        return cls(secret_key)

    @classmethod
    def from_bytes(cls, secret_key: bytes) -> 'Keypair':
        """Create a Keypair from bytes (equivalent to from_secret_key)."""
        return cls(secret_key)

    def sign(self, message: bytes) -> bytes:
        """Sign a message using this keypair."""
        return Crypto.sign(self.secret_key, message)

    def verify(self, message: bytes, signature: bytes) -> bool:
        """Verify a signature using this keypair's public key."""
        return self.public_key.verify(message, signature)

    def to_bytes(self) -> bytes:
        """Return the secret key bytes."""
        return self.secret_key

    def __str__(self):
        return f"Keypair(public_key={self.public_key})"

    def __repr__(self):
        return str(self)
|
||||||
95
src/packet.py
Normal file
95
src/packet.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import List, Optional
|
||||||
|
from dns import message, name, rdata, rdatatype, rdataclass
|
||||||
|
from .resource_record import ResourceRecord
|
||||||
|
from .errors import PacketError
|
||||||
|
|
||||||
|
@dataclass
class Packet:
    """In-memory representation of a DNS packet: header flags plus answer records.

    Wire encoding/decoding is delegated to dnspython (`dns.message`); the
    header flag bits follow the RFC 1035 section 4.1.1 layout.
    """

    # Answer-section resource records.
    answers: List[ResourceRecord] = field(default_factory=list)
    # DNS header fields.
    id: int = 0
    qr: bool = True  # True for response, False for query
    opcode: int = 0  # 0 for standard query
    aa: bool = True  # Authoritative Answer
    tc: bool = False  # TrunCation
    rd: bool = False  # Recursion Desired
    ra: bool = False  # Recursion Available
    z: int = 0  # Reserved for future use
    rcode: int = 0  # Response code

    @classmethod
    def new_reply(cls, id: int):
        """Create an empty reply packet carrying the given DNS message id."""
        return cls(answers=[], id=id)

    def add_answer(self, answer: ResourceRecord):
        """Append a resource record to the answer section."""
        self.answers.append(answer)

    def build_bytes_vec_compressed(self) -> bytes:
        """Build a compressed DNS wire format representation of the packet.

        Raises:
            PacketError: if dnspython rejects any field or record.
        """
        try:
            msg = message.Message(id=self.id)
            # Pack the header flag word bit-by-bit (QR | opcode | AA | TC | RD | RA | Z | RCODE).
            msg.flags = 0
            if self.qr:
                msg.flags |= 1 << 15
            msg.flags |= (self.opcode & 0xF) << 11
            if self.aa:
                msg.flags |= 1 << 10
            if self.tc:
                msg.flags |= 1 << 9
            if self.rd:
                msg.flags |= 1 << 8
            if self.ra:
                msg.flags |= 1 << 7
            msg.flags |= (self.z & 0x7) << 4
            msg.flags |= self.rcode & 0xF

            # Convert each ResourceRecord into dnspython objects.
            for rr in self.answers:
                rr_name = name.from_text(rr.name)
                rr_ttl = rr.ttl
                rr_rdataclass = rdataclass.from_text(rr.rclass)
                rr_rdatatype = rdatatype.from_text(rr.rtype)
                rr_rdata = rdata.from_text(rr_rdataclass, rr_rdatatype, rr.rdata)
                # NOTE(review): dns.message answer sections hold dns.rrset.RRset
                # objects; appending a bare (name, ttl, rdata) tuple likely
                # breaks msg.to_wire() — verify against dnspython and consider
                # dns.rrset.from_rdata(rr_name, rr_ttl, rr_rdata).
                msg.answer.append((rr_name, rr_ttl, rr_rdata))

            return msg.to_wire()
        except Exception as e:
            raise PacketError(f"Failed to build packet: {str(e)}")

    @classmethod
    def from_bytes(cls, data: bytes) -> 'Packet':
        """Create a Packet object from DNS wire format bytes.

        Raises:
            PacketError: if the bytes cannot be parsed as a DNS message.
        """
        try:
            msg = message.from_wire(data)
            # Unpack the header flag word into the individual fields.
            packet = cls(
                id=msg.id,
                qr=bool(msg.flags & (1 << 15)),
                opcode=(msg.flags >> 11) & 0xF,
                aa=bool(msg.flags & (1 << 10)),
                tc=bool(msg.flags & (1 << 9)),
                rd=bool(msg.flags & (1 << 8)),
                ra=bool(msg.flags & (1 << 7)),
                z=(msg.flags >> 4) & 0x7,
                rcode=msg.flags & 0xF
            )

            # Flatten each RRset into individual ResourceRecord entries.
            for rrset in msg.answer:
                for rr in rrset:
                    resource_record = ResourceRecord(
                        name=rrset.name.to_text(),
                        rclass=rdataclass.to_text(rr.rdclass),
                        ttl=rrset.ttl,
                        rtype=rdatatype.to_text(rr.rdtype),
                        rdata=rr.to_text()
                    )
                    packet.add_answer(resource_record)

            return packet
        except Exception as e:
            raise PacketError(f"Failed to parse packet: {str(e)}")

    def __str__(self):
        header = f"Packet ID: {self.id}, QR: {'Response' if self.qr else 'Query'}, " \
                 f"Opcode: {self.opcode}, AA: {self.aa}, TC: {self.tc}, RD: {self.rd}, " \
                 f"RA: {self.ra}, Z: {self.z}, RCODE: {self.rcode}"
        answers = "\n".join(f"  {rr}" for rr in self.answers)
        return f"{header}\nAnswers:\n{answers}"
|
||||||
39
src/public_key.py
Normal file
39
src/public_key.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
from .crypto import Crypto
|
||||||
|
from .errors import PublicKeyError
|
||||||
|
|
||||||
|
class PublicKey:
    """An ed25519 public key, constructible from 32 raw bytes or a z-base-32 string."""

    def __init__(self, key):
        """Accept either a z-base-32 string or 32 raw bytes.

        Raises:
            PublicKeyError: on any other type, undecodable string, or wrong length.
        """
        if isinstance(key, str):
            try:
                # KeyError is also tolerated for robustness: some versions of
                # Crypto.z_base_32_decode leak it for out-of-alphabet chars.
                decoded = Crypto.z_base_32_decode(key)
            except (ValueError, KeyError):
                raise PublicKeyError("Invalid z-base-32 encoded public key")
            # Robustness fix: a decodable string can still yield the wrong
            # length for an ed25519 key; validate like the bytes branch does.
            if len(decoded) != 32:
                raise PublicKeyError("Public key must be 32 bytes long")
            self.key = decoded
        elif isinstance(key, bytes):
            if len(key) != 32:
                raise PublicKeyError("Public key must be 32 bytes long")
            self.key = key
        else:
            raise PublicKeyError("Public key must be bytes or z-base-32 encoded string")

    def __str__(self):
        return self.to_z32()

    def __repr__(self):
        return f"PublicKey({self.to_z32()})"

    def __eq__(self, other):
        if isinstance(other, PublicKey):
            return self.key == other.key
        return False

    def __hash__(self):
        return hash(self.key)

    def to_z32(self):
        """Return the z-base-32 encoding of the key bytes."""
        return Crypto.z_base_32_encode(self.key)

    def to_bytes(self):
        """Return the raw 32-byte key."""
        return self.key

    def verify(self, message: bytes, signature: bytes) -> bool:
        """Return True iff *signature* is valid for *message* under this key."""
        return Crypto.verify(self.key, message, signature)
|
||||||
39
src/resource_record.py
Normal file
39
src/resource_record.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Union
|
||||||
|
import ipaddress
|
||||||
|
|
||||||
|
@dataclass
class ResourceRecord:
    """One DNS resource record; A/AAAA rdata is normalized to ipaddress objects."""

    name: str
    rclass: str
    ttl: int
    rtype: str
    rdata: Union[str, ipaddress.IPv4Address, ipaddress.IPv6Address]

    def __post_init__(self):
        """Canonicalize field casing and parse address-typed rdata."""
        self.name = self.name.lower()
        self.rclass = self.rclass.upper()
        self.rtype = self.rtype.upper()

        address_parser = {'A': ipaddress.IPv4Address, 'AAAA': ipaddress.IPv6Address}.get(self.rtype)
        if address_parser is not None:
            self.rdata = address_parser(self.rdata)

    def to_wire_format(self) -> bytes:
        # This is a placeholder. You'll need to implement the actual DNS wire format encoding.
        pass

    @classmethod
    def from_wire_format(cls, wire_data: bytes) -> 'ResourceRecord':
        # This is a placeholder. You'll need to implement the actual DNS wire format decoding.
        pass

    def __str__(self):
        return f"{self.name} {self.ttl} {self.rclass} {self.rtype} {self.rdata}"

    def is_expired(self, current_time: int) -> bool:
        """True once *current_time* has passed the ttl value."""
        return current_time > self.ttl

    def remaining_ttl(self, current_time: int) -> int:
        """Seconds of ttl left at *current_time*, floored at zero."""
        return max(0, self.ttl - current_time)
|
||||||
154
src/signed_packet.py
Normal file
154
src/signed_packet.py
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List, Optional
|
||||||
|
import ed25519
|
||||||
|
from dns import message, name, rdata, rdatatype, rdataclass
|
||||||
|
|
||||||
|
@dataclass
class PublicKey:
    """Stub ed25519 public key used locally by this module.

    NOTE(review): duplicates src/public_key.PublicKey; to_z32 is unimplemented
    and returns None, so callers in this module (e.g. resource_records) will
    misbehave until this is filled in or the duplicate is removed.
    """

    # Raw 32-byte key material.
    key: bytes

    def to_z32(self) -> str:
        # Implement z-base-32 encoding here
        pass
|
||||||
|
|
||||||
|
@dataclass
class ResourceRecord:
    """Minimal DNS resource record used by this module's stub Packet.

    NOTE(review): diverges from src/resource_record.ResourceRecord (no rtype
    field; rclass is an int code, rdata is raw bytes) — confirm which shape
    is canonical before unifying.
    """

    name: str    # owner name
    rclass: int  # DNS class code
    ttl: int     # time-to-live in seconds
    rdata: bytes  # raw record data
|
||||||
|
|
||||||
|
@dataclass
class Packet:
    """Stub DNS packet holding only the answer section.

    NOTE(review): duplicates src/packet.Packet; build_bytes_vec_compressed is
    unimplemented and returns None, which would make SignedPacket.from_packet
    fail at len() — remove the duplicate or implement it.
    """

    # Answer-section resource records.
    answers: List[ResourceRecord]

    @classmethod
    def new_reply(cls, id: int):
        # NOTE(review): the id argument is ignored here (unlike src/packet.Packet).
        return cls(answers=[])

    def build_bytes_vec_compressed(self) -> bytes:
        # Implement DNS packet compression here
        pass
|
||||||
|
|
||||||
|
@dataclass
class SignedPacket:
    """A DNS packet signed by its author's ed25519 key.

    Wire layout (as produced by as_bytes): 32-byte public key || 64-byte
    signature || 8-byte big-endian timestamp (microseconds) || encoded packet.
    """

    public_key: PublicKey
    signature: bytes
    # Creation time in microseconds since the epoch (set by the author).
    timestamp: int
    packet: Packet
    # Local receive time in microseconds since the epoch.
    last_seen: int

    @classmethod
    def from_packet(cls, keypair, packet: Packet):
        """Sign *packet* with *keypair* and wrap it as a SignedPacket.

        NOTE(review): this calls keypair.public_key() as a method, but
        src/keypair.Keypair exposes public_key as an attribute — confirm
        which keypair type is expected here.
        """
        timestamp = int(time.time() * 1_000_000)
        encoded_packet = packet.build_bytes_vec_compressed()

        # pkarr limits the encoded DNS packet to 1000 bytes.
        if len(encoded_packet) > 1000:
            raise ValueError("Packet too large")

        signature = keypair.sign(cls.signable(timestamp, encoded_packet))

        return cls(
            public_key=keypair.public_key(),
            signature=signature,
            timestamp=timestamp,
            packet=packet,
            last_seen=int(time.time() * 1_000_000)
        )

    @classmethod
    def from_bytes(cls, data: bytes):
        """Parse and verify the wire layout described on the class.

        Raises:
            ValueError: on bad length, oversized payload, or invalid signature.

        NOTE(review): relies on public_key.verify, which the stub PublicKey
        in this module does not define — works only with a full PublicKey
        implementation. The encoded packet is also not parsed yet (empty
        Packet placeholder).
        """
        # 32 (key) + 64 (signature) + 8 (timestamp) = 104-byte minimum header.
        if len(data) < 104:
            raise ValueError("Invalid SignedPacket bytes length")
        # 104-byte header + 1000-byte max packet.
        if len(data) > 1104:
            raise ValueError("Packet too large")

        public_key = PublicKey(data[:32])
        signature = data[32:96]
        timestamp = int.from_bytes(data[96:104], 'big')
        encoded_packet = data[104:]

        # Verify signature
        if not public_key.verify(cls.signable(timestamp, encoded_packet), signature):
            raise ValueError("Invalid signature")

        packet = Packet([])  # Parse encoded_packet into a Packet object here

        return cls(
            public_key=public_key,
            signature=signature,
            timestamp=timestamp,
            packet=packet,
            last_seen=int(time.time() * 1_000_000)
        )

    def as_bytes(self) -> bytes:
        """Serialize to the wire layout described on the class."""
        return (
            self.public_key.key +
            self.signature +
            self.timestamp.to_bytes(8, 'big') +
            self.packet.build_bytes_vec_compressed()
        )

    def to_relay_payload(self) -> bytes:
        """Relay payload: the wire bytes without the leading 32-byte public key."""
        return self.as_bytes()[32:]

    def resource_records(self, name: str):
        """All answer records whose owner matches *name* (normalized to this key's origin)."""
        origin = self.public_key.to_z32()
        normalized_name = self.normalize_name(origin, name)
        return [rr for rr in self.packet.answers if rr.name == normalized_name]

    def fresh_resource_records(self, name: str):
        """Matching answer records whose TTL has not yet elapsed since last_seen."""
        origin = self.public_key.to_z32()
        normalized_name = self.normalize_name(origin, name)
        current_time = int(time.time())
        return [
            rr for rr in self.packet.answers
            # last_seen // 1_000_000 converts microseconds to unix seconds, so
            # the parenthesized term is the seconds elapsed since last_seen.
            if rr.name == normalized_name and rr.ttl > (current_time - self.last_seen // 1_000_000)
        ]

    def expires_in(self, min_ttl: int, max_ttl: int) -> int:
        """Seconds until this packet expires under the clamped TTL, floored at zero."""
        ttl = self.ttl(min_ttl, max_ttl)
        elapsed = self.elapsed()
        return max(0, ttl - elapsed)

    def ttl(self, min_ttl: int, max_ttl: int) -> int:
        """Smallest record TTL, clamped into [min_ttl, max_ttl]; min_ttl when empty."""
        if not self.packet.answers:
            return min_ttl
        min_record_ttl = min(rr.ttl for rr in self.packet.answers)
        return max(min_ttl, min(max_ttl, min_record_ttl))

    def elapsed(self) -> int:
        """Whole seconds since this packet was last seen locally."""
        return (int(time.time() * 1_000_000) - self.last_seen) // 1_000_000

    @staticmethod
    def signable(timestamp: int, v: bytes) -> bytes:
        """Bytes that are signed: bencoded BEP44 mutable-item form "3:seqi{seq}e1:v{len}:{v}"."""
        return f"3:seqi{timestamp}e1:v{len(v)}:".encode() + v

    @staticmethod
    def normalize_name(origin: str, name: str) -> str:
        """Resolve *name* relative to *origin* ('@' or '' → origin; else qualify with origin)."""
        if name.endswith('.'):
            name = name[:-1]

        parts = name.split('.')
        last = parts[-1]

        # Already fully qualified under this key's origin.
        if last == origin:
            return name
        # Apex references.
        if last in ('@', ''):
            return origin
        # Relative name: qualify with the origin.
        return f"{name}.{origin}"

    def __str__(self):
        records = "\n".join(
            f"  {rr.name} IN {rr.ttl} {rr.rdata}"
            for rr in self.packet.answers
        )
        return f"""SignedPacket ({self.public_key.key.hex()}):
    last_seen: {self.elapsed()} seconds ago
    timestamp: {self.timestamp},
    signature: {self.signature.hex()}
    records:
{records}
"""
|
||||||
Reference in New Issue
Block a user