From bcaebc55d215c1583633c6d3d36bb5b81edffff3 Mon Sep 17 00:00:00 2001 From: nazeh Date: Fri, 20 Sep 2024 10:17:05 +0300 Subject: [PATCH] feat(pubky): Resolve through intermediate pkarr packet --- pubky/src/native/internals/endpoints.rs | 153 +++++++++++++++--------- 1 file changed, 95 insertions(+), 58 deletions(-) diff --git a/pubky/src/native/internals/endpoints.rs b/pubky/src/native/internals/endpoints.rs index 64352d4..b650e75 100644 --- a/pubky/src/native/internals/endpoints.rs +++ b/pubky/src/native/internals/endpoints.rs @@ -6,7 +6,7 @@ use reqwest::dns::{Addrs, Resolve}; use crate::error::{Error, Result}; -const MAX_ENDPOINT_RESOLUTION_RECURSION: u8 = 3; +const MAX_CHAIN_LENGTH: u8 = 3; #[derive(Debug, Clone)] pub struct PkarrResolver { @@ -27,6 +27,8 @@ impl PkarrResolver { let target = qname; // TODO: cache the result of this function? + let is_svcb = target.starts_with('_'); + let mut step = 0; let mut svcb: Option = None; @@ -34,19 +36,15 @@ impl PkarrResolver { let current = svcb.clone().map_or(target.to_string(), |s| s.target); if let Ok(tld) = PublicKey::try_from(current.clone()) { if let Ok(Some(signed_packet)) = self.pkarr.resolve(&tld).await { - if step >= MAX_ENDPOINT_RESOLUTION_RECURSION { + if step >= MAX_CHAIN_LENGTH { break; }; step += 1; // Choose most prior SVCB record - svcb = get_endpoint(&signed_packet, ¤t); + svcb = get_endpoint(&signed_packet, ¤t, is_svcb); // TODO: support wildcard? - - if step >= MAX_ENDPOINT_RESOLUTION_RECURSION { - break; - }; } else { break; } @@ -94,9 +92,7 @@ struct Endpoint { port: u16, } -fn get_endpoint(signed_packet: &SignedPacket, target: &str) -> Option { - let is_svcb = target.starts_with('_'); - +fn get_endpoint(signed_packet: &SignedPacket, target: &str, is_svcb: bool) -> Option { signed_packet .resource_records(target) .fold(None, |prev: Option, answer| { @@ -152,29 +148,6 @@ mod tests { use pkarr::PkarrClient; use pkarr::{mainline::Testnet, Keypair}; - async fn publish_packets( - client: &PkarrClientAsync, - tree: Vec)>>, - ) -> Vec { - let mut keypairs: Vec = Vec::with_capacity(tree.len()); - for node in tree { - let mut packet = dns::Packet::new_reply(0); - for record in node { - packet.answers.push(dns::ResourceRecord::new( - dns::Name::new(record.0).unwrap(), - dns::CLASS::IN, - 3600, - record.1, - )); - } - let keypair = Keypair::random(); - let signed_packet = SignedPacket::from_packet(&keypair, &packet).unwrap(); - keypairs.push(keypair); - client.publish(&signed_packet).await.unwrap(); - } - keypairs - } - #[tokio::test] async fn resolve_direct_endpoint() { let testnet = Testnet::new(3); @@ -184,34 +157,40 @@ mod tests { .unwrap() .as_async(); - let keypairs = publish_packets( - &pkarr, - vec![vec![ - ( - "foo", - RData::HTTPS(SVCB::new(0, "https.example.com".try_into().unwrap()).into()), - ), - // Make sure HTTPS only follows HTTPs - ( - "foo", - RData::SVCB(SVCB::new(0, "protocol.example.com".try_into().unwrap())), - ), - // Make sure SVCB only follows SVCB - ( - "foo", - RData::HTTPS(SVCB::new(0, "https.example.com".try_into().unwrap()).into()), - ), - ( - "_foo", - RData::SVCB(SVCB::new(0, "protocol.example.com".try_into().unwrap())), - ), - ]], - ) - .await; + let mut packet = dns::Packet::new_reply(0); + packet.answers.push(dns::ResourceRecord::new( + dns::Name::new("foo").unwrap(), + dns::CLASS::IN, + 3600, + RData::HTTPS(SVCB::new(0, "https.example.com".try_into().unwrap()).into()), + )); + // Make sure HTTPS only follows HTTPs + packet.answers.push(dns::ResourceRecord::new( + dns::Name::new("foo").unwrap(), + dns::CLASS::IN, + 3600, + RData::SVCB(SVCB::new(0, "protocol.example.com".try_into().unwrap())), + )); + // Make sure SVCB only follows SVCB + packet.answers.push(dns::ResourceRecord::new( + dns::Name::new("foo").unwrap(), + dns::CLASS::IN, + 3600, + RData::HTTPS(SVCB::new(0, "https.example.com".try_into().unwrap()).into()), + )); + packet.answers.push(dns::ResourceRecord::new( + dns::Name::new("_foo").unwrap(), + dns::CLASS::IN, + 3600, + RData::SVCB(SVCB::new(0, "protocol.example.com".try_into().unwrap())), + )); + let keypair = Keypair::random(); + let inter_signed_packet = SignedPacket::from_packet(&keypair, &packet).unwrap(); + pkarr.publish(&inter_signed_packet).await.unwrap(); let resolver = PkarrResolver { pkarr }; - let tld = keypairs.first().unwrap().public_key(); + let tld = keypair.public_key(); // Follow foo.tld HTTPS records let endpoint = resolver @@ -227,4 +206,62 @@ mod tests { .unwrap(); assert_eq!(endpoint.target, "protocol.example.com"); } + + #[tokio::test] + async fn resolve_endpoint_with_intermediate_pubky() { + let testnet = Testnet::new(3); + let pkarr = PkarrClient::builder() + .testnet(&testnet) + .build() + .unwrap() + .as_async(); + + // USER => Server Owner => Server + // pubky. => pubky-homeserver. => @. + + let mut packet = dns::Packet::new_reply(0); + packet.answers.push(dns::ResourceRecord::new( + dns::Name::new("@").unwrap(), + dns::CLASS::IN, + 3600, + RData::HTTPS(SVCB::new(0, "example.com".try_into().unwrap()).into()), + )); + let keypair = Keypair::random(); + let inter_signed_packet = SignedPacket::from_packet(&keypair, &packet).unwrap(); + pkarr.publish(&inter_signed_packet).await.unwrap(); + + let end_target = format!("{}", keypair.public_key()); + let mut packet = dns::Packet::new_reply(0); + packet.answers.push(dns::ResourceRecord::new( + dns::Name::new("pubky-homeserver.").unwrap(), + dns::CLASS::IN, + 3600, + RData::HTTPS(SVCB::new(0, end_target.as_str().try_into().unwrap()).into()), + )); + let keypair = Keypair::random(); + let inter_signed_packet = SignedPacket::from_packet(&keypair, &packet).unwrap(); + pkarr.publish(&inter_signed_packet).await.unwrap(); + + let inter_target = format!("pubky-homeserver.{}", keypair.public_key()); + let mut packet = dns::Packet::new_reply(0); + packet.answers.push(dns::ResourceRecord::new( + dns::Name::new("pubky.").unwrap(), + dns::CLASS::IN, + 3600, + RData::HTTPS(SVCB::new(0, inter_target.as_str().try_into().unwrap()).into()), + )); + let keypair = Keypair::random(); + let inter_signed_packet = SignedPacket::from_packet(&keypair, &packet).unwrap(); + pkarr.publish(&inter_signed_packet).await.unwrap(); + + let resolver = PkarrResolver { pkarr }; + + let tld = keypair.public_key(); + + let endpoint = resolver + .resolve_endpoint(&format!("pubky.{tld}")) + .await + .unwrap(); + assert_eq!(endpoint.target, "example.com"); + } }