use std::net::{IpAddr, SocketAddr};
use std::sync::atomic::{AtomicU16, Ordering};
use std::sync::RwLock;
use std::time::{Duration, Instant};

use log::{debug, info};

use crate::cache::DnsCache;
use crate::forward::forward_udp;
use crate::header::ResultCode;
use crate::packet::DnsPacket;
use crate::question::QueryType;
use crate::record::DnsRecord;
use crate::srtt::SrttCache;

/// Upper bound on referral hops (and on nested glue-less NS lookups) per resolution.
const MAX_REFERRAL_DEPTH: u8 = 10;
/// Upper bound on the number of CNAME links followed for one query.
const MAX_CNAME_DEPTH: u8 = 8;
/// Timeout for a single UDP query to a nameserver.
const NS_QUERY_TIMEOUT: Duration = Duration::from_millis(800);
/// Timeout for a single TCP query to a nameserver.
const TCP_TIMEOUT: Duration = Duration::from_millis(1500);
/// Number of UDP failures in a row after which queries go TCP-first.
const UDP_FAIL_THRESHOLD: u8 = 3;

/// Source of outgoing query IDs; see [`next_id`].
static QUERY_ID: AtomicU16 = AtomicU16::new(1);
/// Running count of UDP failures; cleared when a UDP query succeeds.
static UDP_FAILURES: std::sync::atomic::AtomicU8 = std::sync::atomic::AtomicU8::new(0);
/// When `true`, queries skip UDP and go straight to TCP.
pub(crate) static UDP_DISABLED: std::sync::atomic::AtomicBool =
    std::sync::atomic::AtomicBool::new(false);

/// Returns the next query ID. IDs are sequential and wrap around on overflow.
fn next_id() -> u16 {
    QUERY_ID.fetch_add(1, Ordering::Relaxed)
}

/// Builds a socket address for `ip` on the DNS port (53).
fn dns_addr(ip: impl Into<IpAddr>) -> SocketAddr {
    SocketAddr::new(ip.into(), 53)
}

/// Converts an A or AAAA record into a port-53 socket address.
///
/// Returns `None` for every other record type.
fn record_to_addr(rec: &DnsRecord) -> Option<SocketAddr> {
    match rec {
        DnsRecord::A { addr, .. } => Some(dns_addr(*addr)),
        DnsRecord::AAAA { addr, .. } => Some(dns_addr(*addr)),
        _ => None,
    }
}

/// Re-enables UDP and clears the UDP failure counter.
pub fn reset_udp_state() {
    UDP_DISABLED.store(false, Ordering::Release);
    UDP_FAILURES.store(0, Ordering::Release);
}

/// Probe whether UDP works again. Called periodically from the network watch loop.
///
/// Does nothing unless UDP is currently disabled. Sends a non-recursive `. NS`
/// query to the first root hint over UDP and calls [`reset_udp_state`] if any
/// response comes back in time.
pub async fn probe_udp(root_hints: &[SocketAddr]) {
    if !UDP_DISABLED.load(Ordering::Acquire) {
        return;
    }
    let hint = match root_hints.first() {
        Some(h) => *h,
        None => return,
    };
    let mut probe = DnsPacket::query(next_id(), ".", QueryType::NS);
    probe.header.recursion_desired = false;
    if forward_udp(&probe, hint, Duration::from_millis(1500))
        .await
        .is_ok()
    {
        info!("UDP probe succeeded — re-enabling UDP");
        reset_udp_state();
    }
}

/// Probe whether recursive resolution works by querying root servers.
/// Tries up to 3 hints before declaring failure.
pub async fn probe_recursive(root_hints: &[SocketAddr]) -> bool { let mut probe = DnsPacket::query(next_id(), ".", QueryType::NS); probe.header.recursion_desired = false; for hint in root_hints.iter().take(3) { if let Ok(resp) = forward_udp(&probe, *hint, Duration::from_secs(3)).await { if !resp.answers.is_empty() || !resp.authorities.is_empty() { return true; } } } false } pub async fn prime_tld_cache( cache: &RwLock, root_hints: &[SocketAddr], tlds: &[String], srtt: &RwLock, ) { if root_hints.is_empty() || tlds.is_empty() { return; } let mut root_addr = root_hints[0]; for hint in root_hints { info!("prime: probing root {}", hint); match send_query(".", QueryType::NS, *hint, srtt).await { Ok(_) => { info!("prime: root {} reachable", hint); root_addr = *hint; break; } Err(e) => { info!("prime: root {} failed: {}, trying next", hint, e); } } } // Fetch root DNSKEY (needed for DNSSEC chain-of-trust terminus) if let Ok(root_dnskey) = send_query(".", QueryType::DNSKEY, root_addr, srtt).await { cache .write() .unwrap() .insert(".", QueryType::DNSKEY, &root_dnskey); debug!("prime: cached root DNSKEY"); } let mut primed = 0u16; for tld in tlds { // Fetch NS referral (includes DS in authority section from root) let response = match send_query(tld, QueryType::NS, root_addr, srtt).await { Ok(r) => r, Err(e) => { debug!("prime: failed to query NS for .{}: {}", tld, e); continue; } }; let ns_names = extract_ns_names(&response); if ns_names.is_empty() { continue; } { let mut cache_w = cache.write().unwrap(); cache_w.insert(tld, QueryType::NS, &response); cache_glue(&mut cache_w, &response, &ns_names); cache_ds_from_authority(&mut cache_w, &response); } // Fetch DNSKEY for this TLD (needed for DNSSEC chain validation) let first_ns_name = ns_names.first().map(|s| s.as_str()).unwrap_or(""); let first_ns = glue_addrs_for(&response, first_ns_name); if let Some(ns_addr) = first_ns.first() { if let Ok(dnskey_resp) = send_query(tld, QueryType::DNSKEY, *ns_addr, srtt).await { cache 
.write() .unwrap() .insert(tld, QueryType::DNSKEY, &dnskey_resp); } } primed += 1; } info!( "primed {}/{} TLD caches (NS + glue + DS + DNSKEY)", primed, tlds.len() ); } pub async fn resolve_recursive( qname: &str, qtype: QueryType, cache: &RwLock, original_query: &DnsPacket, root_hints: &[SocketAddr], srtt: &RwLock, ) -> crate::Result { // No overall timeout — each hop is bounded by NS_QUERY_TIMEOUT (UDP + TCP fallback), // and MAX_REFERRAL_DEPTH caps the chain length. let mut resp = resolve_iterative(qname, qtype, cache, root_hints, srtt, 0, 0).await?; resp.header.id = original_query.header.id; resp.header.recursion_available = true; resp.header.recursion_desired = original_query.header.recursion_desired; resp.questions = original_query.questions.clone(); Ok(resp) } pub(crate) fn resolve_iterative<'a>( qname: &'a str, qtype: QueryType, cache: &'a RwLock, root_hints: &'a [SocketAddr], srtt: &'a RwLock, referral_depth: u8, cname_depth: u8, ) -> std::pin::Pin> + Send + 'a>> { Box::pin(async move { if referral_depth > MAX_REFERRAL_DEPTH { return Err("max referral depth exceeded".into()); } if let Some(cached) = cache.read().unwrap().lookup(qname, qtype) { return Ok(cached); } let (mut current_zone, mut ns_addrs) = find_closest_ns(qname, cache, root_hints); srtt.read().unwrap().sort_by_rtt(&mut ns_addrs); let mut ns_idx = 0; for _ in 0..MAX_REFERRAL_DEPTH { let ns_addr = match ns_addrs.get(ns_idx) { Some(addr) => *addr, None => return Err("no nameserver available".into()), }; let (q_name, q_type) = minimize_query(qname, qtype, ¤t_zone); debug!( "recursive: querying {} for {:?} {} (zone: {}, depth {})", ns_addr, q_type, q_name, current_zone, referral_depth ); let response = match send_query(q_name, q_type, ns_addr, srtt).await { Ok(r) => r, Err(e) => { debug!("recursive: NS {} failed: {}", ns_addr, e); ns_idx += 1; continue; } }; if (q_type != qtype || !q_name.eq_ignore_ascii_case(qname)) && (!response.authorities.is_empty() || !response.answers.is_empty()) { if let 
Some(zone) = referral_zone(&response) { current_zone = zone; } let mut all_ns = extract_ns_from_records(&response.answers); if all_ns.is_empty() { all_ns = extract_ns_names(&response); } let mut new_addrs = resolve_ns_addrs_from_glue(&response, &all_ns, cache); if !new_addrs.is_empty() { srtt.read().unwrap().sort_by_rtt(&mut new_addrs); ns_addrs = new_addrs; ns_idx = 0; continue; } ns_idx += 1; continue; } if !response.answers.is_empty() { let has_target = response.answers.iter().any(|r| r.query_type() == qtype); if has_target || qtype == QueryType::CNAME { cache.write().unwrap().insert(qname, qtype, &response); return Ok(response); } if let Some(cname_target) = extract_cname_target(&response, qname) { if cname_depth >= MAX_CNAME_DEPTH { return Err("max CNAME depth exceeded".into()); } debug!("recursive: chasing CNAME {} -> {}", qname, cname_target); let final_resp = resolve_iterative( &cname_target, qtype, cache, root_hints, srtt, 0, cname_depth + 1, ) .await?; let mut combined = response; combined.answers.extend(final_resp.answers); combined.header.rescode = final_resp.header.rescode; cache.write().unwrap().insert(qname, qtype, &combined); return Ok(combined); } cache.write().unwrap().insert(qname, qtype, &response); return Ok(response); } if response.header.rescode == ResultCode::NXDOMAIN || response.header.rescode == ResultCode::REFUSED { cache.write().unwrap().insert(qname, qtype, &response); return Ok(response); } if let Some(zone) = referral_zone(&response) { current_zone = zone; } let ns_names = extract_ns_names(&response); if ns_names.is_empty() { return Ok(response); } { let mut cache_w = cache.write().unwrap(); cache_ds_from_authority(&mut cache_w, &response); } let mut new_ns_addrs = resolve_ns_addrs_from_glue(&response, &ns_names, cache); if new_ns_addrs.is_empty() { for ns_name in &ns_names { if referral_depth < MAX_REFERRAL_DEPTH { debug!("recursive: resolving glue-less NS {}", ns_name); for qt in [QueryType::A, QueryType::AAAA] { if let Ok(ns_resp) 
= resolve_iterative( ns_name, qt, cache, root_hints, srtt, referral_depth + 1, cname_depth, ) .await { new_ns_addrs .extend(ns_resp.answers.iter().filter_map(record_to_addr)); } if !new_ns_addrs.is_empty() { break; } } } if !new_ns_addrs.is_empty() { break; } } } if new_ns_addrs.is_empty() { return Err(format!("could not resolve any NS for {}", qname).into()); } srtt.read().unwrap().sort_by_rtt(&mut new_ns_addrs); ns_addrs = new_ns_addrs; ns_idx = 0; } Err(format!("recursive resolution exhausted for {}", qname).into()) }) } /// Find the closest cached NS zone and its resolved addresses. /// Returns (zone_name, ns_addresses). Falls back to (".", root_hints). fn find_closest_ns( qname: &str, cache: &RwLock, root_hints: &[SocketAddr], ) -> (String, Vec) { let guard = cache.read().unwrap(); let mut pos = 0; loop { let zone = &qname[pos..]; if let Some(cached) = guard.lookup(zone, QueryType::NS) { let mut addrs = Vec::new(); let ns_records = if cached .answers .iter() .any(|r| matches!(r, DnsRecord::NS { .. })) { &cached.answers } else { &cached.authorities }; for ns_rec in ns_records { if let DnsRecord::NS { host, .. } = ns_rec { for qt in [QueryType::A, QueryType::AAAA] { if let Some(resp) = guard.lookup(host, qt) { addrs.extend(resp.answers.iter().filter_map(record_to_addr)); } } } } if !addrs.is_empty() { debug!("recursive: starting from cached NS for zone '{}'", zone); return (zone.to_string(), addrs); } } match qname[pos..].find('.') { Some(dot) => pos += dot + 1, None => break, } } drop(guard); debug!( "recursive: starting from root hints ({} servers)", root_hints.len() ); (".".to_string(), root_hints.to_vec()) } /// Extract NS hostnames from any record section (answers or authorities). fn extract_ns_from_records(records: &[DnsRecord]) -> Vec { records .iter() .filter_map(|r| match r { DnsRecord::NS { host, .. } => Some(host.clone()), _ => None, }) .collect() } /// Resolve NS addresses from glue records, then cache fallback. 
fn resolve_ns_addrs_from_glue( response: &DnsPacket, ns_names: &[String], cache: &RwLock, ) -> Vec { let mut addrs = Vec::new(); { let mut cache_w = cache.write().unwrap(); cache_glue(&mut cache_w, response, ns_names); } for ns_name in ns_names { addrs.extend_from_slice(&glue_addrs_for(response, ns_name)); } if addrs.is_empty() { for ns_name in ns_names { addrs.extend(addrs_from_cache(cache, ns_name)); } } addrs } fn referral_zone(response: &DnsPacket) -> Option { response.authorities.iter().find_map(|r| match r { DnsRecord::NS { domain, .. } => Some(domain.clone()), _ => None, }) } /// RFC 7816 query minimization (conservative): only minimize at root. fn minimize_query<'a>( qname: &'a str, qtype: QueryType, current_zone: &str, ) -> (&'a str, QueryType) { if current_zone != "." { return (qname, qtype); } // At root: extract TLD (last label) match qname.rfind('.') { Some(dot) if dot > 0 => (&qname[dot + 1..], QueryType::NS), _ => (qname, qtype), } } fn addrs_from_cache(cache: &RwLock, name: &str) -> Vec { let guard = cache.read().unwrap(); let mut addrs = Vec::new(); for qt in [QueryType::A, QueryType::AAAA] { if let Some(pkt) = guard.lookup(name, qt) { addrs.extend(pkt.answers.iter().filter_map(record_to_addr)); } } addrs } fn glue_addrs_for(response: &DnsPacket, ns_name: &str) -> Vec { response .resources .iter() .filter(|r| match r { DnsRecord::A { domain, .. } | DnsRecord::AAAA { domain, .. 
} => { domain.eq_ignore_ascii_case(ns_name) } _ => false, }) .filter_map(record_to_addr) .collect() } fn cache_glue(cache: &mut DnsCache, response: &DnsPacket, ns_names: &[String]) { for ns_name in ns_names { let mut a_pkt: Option = None; let mut aaaa_pkt: Option = None; for r in &response.resources { match r { DnsRecord::A { domain, addr, ttl } if domain.eq_ignore_ascii_case(ns_name) => { a_pkt .get_or_insert_with(make_glue_packet) .answers .push(DnsRecord::A { domain: ns_name.clone(), addr: *addr, ttl: *ttl, }); } DnsRecord::AAAA { domain, addr, ttl } if domain.eq_ignore_ascii_case(ns_name) => { aaaa_pkt .get_or_insert_with(make_glue_packet) .answers .push(DnsRecord::AAAA { domain: ns_name.clone(), addr: *addr, ttl: *ttl, }); } _ => {} } } if let Some(pkt) = a_pkt { cache.insert(ns_name, QueryType::A, &pkt); } if let Some(pkt) = aaaa_pkt { cache.insert(ns_name, QueryType::AAAA, &pkt); } } } /// Cache DS + DS-covering RRSIG records from referral authority sections. fn cache_ds_from_authority(cache: &mut DnsCache, response: &DnsPacket) { let mut ds_by_domain: Vec<(String, DnsPacket)> = Vec::new(); for r in &response.authorities { match r { DnsRecord::DS { domain, .. } => { let key = domain.to_lowercase(); let pkt = match ds_by_domain.iter_mut().find(|(d, _)| *d == key) { Some((_, pkt)) => pkt, None => { ds_by_domain.push((key, make_glue_packet())); &mut ds_by_domain.last_mut().unwrap().1 } }; pkt.answers.push(r.clone()); } DnsRecord::RRSIG { domain, type_covered, .. 
} if QueryType::from_num(*type_covered) == QueryType::DS => { let key = domain.to_lowercase(); let pkt = match ds_by_domain.iter_mut().find(|(d, _)| *d == key) { Some((_, pkt)) => pkt, None => { ds_by_domain.push((key, make_glue_packet())); &mut ds_by_domain.last_mut().unwrap().1 } }; pkt.answers.push(r.clone()); } _ => {} } } for (domain, pkt) in &ds_by_domain { if !pkt.answers.is_empty() { cache.insert(domain, QueryType::DS, pkt); } } } fn make_glue_packet() -> DnsPacket { let mut pkt = DnsPacket::new(); pkt.header.response = true; pkt.header.rescode = ResultCode::NOERROR; pkt } async fn tcp_with_srtt( query: &DnsPacket, server: SocketAddr, srtt: &RwLock, start: Instant, ) -> crate::Result { match crate::forward::forward_tcp(query, server, TCP_TIMEOUT).await { Ok(resp) => { srtt.write() .unwrap() .record_rtt(server.ip(), start.elapsed().as_millis() as u64, true); Ok(resp) } Err(e) => { srtt.write().unwrap().record_failure(server.ip()); Err(e) } } } async fn send_query( qname: &str, qtype: QueryType, server: SocketAddr, srtt: &RwLock, ) -> crate::Result { let mut query = DnsPacket::query(next_id(), qname, qtype); query.header.recursion_desired = false; query.edns = Some(crate::packet::EdnsOpt { do_bit: true, ..Default::default() }); let start = Instant::now(); // IPv6 forced to TCP — our UDP socket is bound to 0.0.0.0 if server.is_ipv6() { return tcp_with_srtt(&query, server, srtt, start).await; } // UDP detected as blocked — go TCP-first if UDP_DISABLED.load(Ordering::Acquire) { return tcp_with_srtt(&query, server, srtt, start).await; } match forward_udp(&query, server, NS_QUERY_TIMEOUT).await { Ok(resp) if resp.header.truncated_message => { debug!("send_query: truncated from {}, retrying TCP", server); tcp_with_srtt(&query, server, srtt, start).await } Ok(resp) => { UDP_FAILURES.store(0, Ordering::Release); srtt.write().unwrap().record_rtt( server.ip(), start.elapsed().as_millis() as u64, false, ); Ok(resp) } Err(e) => { let fails = UDP_FAILURES.fetch_add(1, 
Ordering::AcqRel) + 1; if fails >= UDP_FAIL_THRESHOLD && !UDP_DISABLED.load(Ordering::Acquire) { UDP_DISABLED.store(true, Ordering::Release); info!( "send_query: {} consecutive UDP failures — switching to TCP-first", fails ); } debug!("send_query: UDP failed for {}: {}, trying TCP", server, e); tcp_with_srtt(&query, server, srtt, start).await } } } fn extract_cname_target(response: &DnsPacket, qname: &str) -> Option { response.answers.iter().find_map(|r| match r { DnsRecord::CNAME { domain, host, .. } if domain.eq_ignore_ascii_case(qname) => { Some(host.clone()) } _ => None, }) } fn extract_ns_names(response: &DnsPacket) -> Vec { response .authorities .iter() .filter_map(|r| match r { DnsRecord::NS { host, .. } => Some(host.clone()), _ => None, }) .collect() } pub fn parse_root_hints(hints: &[String]) -> Vec { hints .iter() .filter_map(|s| { s.parse::() .map(|ip| SocketAddr::new(ip, 53)) .map_err(|e| log::warn!("invalid root hint '{}': {}", s, e)) .ok() }) .collect() } #[cfg(test)] mod tests { use super::*; use std::net::{Ipv4Addr, Ipv6Addr}; #[test] fn extract_ns_from_authority() { let mut pkt = DnsPacket::new(); pkt.authorities.push(DnsRecord::NS { domain: "example.com".into(), host: "ns1.example.com".into(), ttl: 3600, }); pkt.authorities.push(DnsRecord::NS { domain: "example.com".into(), host: "ns2.example.com".into(), ttl: 3600, }); let names = extract_ns_names(&pkt); assert_eq!(names, vec!["ns1.example.com", "ns2.example.com"]); } #[test] fn glue_extraction_a() { let mut pkt = DnsPacket::new(); pkt.resources.push(DnsRecord::A { domain: "ns1.example.com".into(), addr: Ipv4Addr::new(1, 2, 3, 4), ttl: 3600, }); let addrs = glue_addrs_for(&pkt, "ns1.example.com"); assert_eq!(addrs, vec![dns_addr(Ipv4Addr::new(1, 2, 3, 4))]); assert!(glue_addrs_for(&pkt, "ns3.example.com").is_empty()); } #[test] fn glue_extraction_aaaa() { let mut pkt = DnsPacket::new(); pkt.resources.push(DnsRecord::AAAA { domain: "ns1.example.com".into(), addr: "2001:db8::1".parse().unwrap(), 
ttl: 3600, }); pkt.resources.push(DnsRecord::A { domain: "ns1.example.com".into(), addr: Ipv4Addr::new(1, 2, 3, 4), ttl: 3600, }); let addrs = glue_addrs_for(&pkt, "ns1.example.com"); assert_eq!(addrs.len(), 2); // AAAA first (order matches resources), then A assert_eq!( addrs[0], dns_addr("2001:db8::1".parse::().unwrap()) ); assert_eq!(addrs[1], dns_addr(Ipv4Addr::new(1, 2, 3, 4))); } #[test] fn cname_extraction() { let mut pkt = DnsPacket::new(); pkt.answers.push(DnsRecord::CNAME { domain: "www.example.com".into(), host: "example.com".into(), ttl: 300, }); assert_eq!( extract_cname_target(&pkt, "www.example.com"), Some("example.com".into()) ); assert_eq!(extract_cname_target(&pkt, "other.com"), None); } #[test] fn parse_root_hints_valid() { let hints = vec!["198.41.0.4".into(), "199.9.14.201".into()]; let addrs = parse_root_hints(&hints); assert_eq!(addrs.len(), 2); assert_eq!(addrs[0], dns_addr(Ipv4Addr::new(198, 41, 0, 4))); } #[test] fn parse_root_hints_skips_invalid() { let hints = vec![ "198.41.0.4".into(), "not-an-ip".into(), "192.33.4.12".into(), ]; let addrs = parse_root_hints(&hints); assert_eq!(addrs.len(), 2); } #[test] fn find_closest_ns_falls_back_to_hints() { let cache = RwLock::new(DnsCache::new(100, 60, 86400)); let hints = vec![ dns_addr(Ipv4Addr::new(198, 41, 0, 4)), dns_addr(Ipv4Addr::new(199, 9, 14, 201)), ]; let (zone, addrs) = find_closest_ns("example.com", &cache, &hints); assert_eq!(zone, "."); assert_eq!(addrs, hints); } #[test] fn find_closest_ns_uses_authority_ns_records() { // Simulate what TLD priming does: cache a referral response where // NS records are in authorities (not answers), with glue in resources. 
let cache = RwLock::new(DnsCache::new(100, 60, 86400)); let hints = vec![dns_addr(Ipv4Addr::new(198, 41, 0, 4))]; // Build a referral-style response (NS in authorities, glue in resources) let mut referral = DnsPacket::new(); referral.header.response = true; referral.authorities.push(DnsRecord::NS { domain: "com".into(), host: "ns1.com".into(), ttl: 3600, }); referral.resources.push(DnsRecord::A { domain: "ns1.com".into(), addr: Ipv4Addr::new(192, 5, 6, 30), ttl: 3600, }); // Cache the referral under "com" NS (same as prime_tld_cache does) { let mut c = cache.write().unwrap(); c.insert("com", QueryType::NS, &referral); // Cache glue separately (as prime_tld_cache does) let mut glue_pkt = DnsPacket::new(); glue_pkt.header.response = true; glue_pkt.answers.push(DnsRecord::A { domain: "ns1.com".into(), addr: Ipv4Addr::new(192, 5, 6, 30), ttl: 3600, }); c.insert("ns1.com", QueryType::A, &glue_pkt); } // find_closest_ns should find "com" zone from authority NS records let (zone, addrs) = find_closest_ns("www.example.com", &cache, &hints); assert_eq!(zone, "com"); assert_eq!(addrs, vec![dns_addr(Ipv4Addr::new(192, 5, 6, 30))]); } #[test] fn minimize_query_from_root() { // At root, only reveal TLD let (name, qt) = minimize_query("www.example.com", QueryType::A, "."); assert_eq!(name, "com"); assert_eq!(qt, QueryType::NS); } #[test] fn minimize_query_beyond_root_sends_full() { // Beyond root, send full query (conservative minimization) let (name, qt) = minimize_query("www.example.com", QueryType::A, "com"); assert_eq!(name, "www.example.com"); assert_eq!(qt, QueryType::A); let (name, qt) = minimize_query("www.example.com", QueryType::A, "example.com"); assert_eq!(name, "www.example.com"); assert_eq!(qt, QueryType::A); } #[test] fn minimize_query_single_label() { // Single label (e.g., "com") from root — send as-is let (name, qt) = minimize_query("com", QueryType::NS, "."); assert_eq!(name, "com"); assert_eq!(qt, QueryType::NS); } // ---- Mock DNS server (TCP-only) for 
fallback tests ---- use crate::buffer::BytePacketBuffer; use crate::header::ResultCode; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpListener; /// Spawn a TCP-only DNS server on localhost. Returns the address. /// The handler receives each query and returns a response packet. async fn spawn_tcp_dns_server( handler: impl Fn(&DnsPacket) -> DnsPacket + Send + Sync + 'static, ) -> SocketAddr { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); let handler = std::sync::Arc::new(handler); tokio::spawn(async move { loop { let (mut stream, _) = match listener.accept().await { Ok(c) => c, Err(_) => break, }; let handler = handler.clone(); tokio::spawn(async move { let timeout = std::time::Duration::from_secs(5); // Read length-prefixed DNS query let mut len_buf = [0u8; 2]; if tokio::time::timeout(timeout, stream.read_exact(&mut len_buf)) .await .ok() .and_then(|r| r.ok()) .is_none() { return; } let len = u16::from_be_bytes(len_buf) as usize; let mut data = vec![0u8; len]; if tokio::time::timeout(timeout, stream.read_exact(&mut data)) .await .ok() .and_then(|r| r.ok()) .is_none() { return; } let mut buf = BytePacketBuffer::from_bytes(&data); let query = match DnsPacket::from_buffer(&mut buf) { Ok(q) => q, Err(_) => return, }; let response = handler(&query); let mut resp_buf = BytePacketBuffer::new(); if response.write(&mut resp_buf).is_err() { return; } let resp_bytes = resp_buf.filled(); let mut out = Vec::with_capacity(2 + resp_bytes.len()); out.extend_from_slice(&(resp_bytes.len() as u16).to_be_bytes()); out.extend_from_slice(resp_bytes); let _ = stream.write_all(&out).await; }); } }); addr } /// TCP-only server returns authoritative answer directly. /// Verifies: UDP fails → TCP fallback → resolves. 
#[tokio::test] async fn tcp_fallback_resolves_when_udp_blocked() { UDP_DISABLED.store(false, Ordering::Relaxed); UDP_FAILURES.store(0, Ordering::Release); let server_addr = spawn_tcp_dns_server(|query| { let mut resp = DnsPacket::response_from(query, ResultCode::NOERROR); resp.header.authoritative_answer = true; if let Some(q) = query.questions.first() { if q.qtype == QueryType::A || q.qtype == QueryType::NS { resp.answers.push(DnsRecord::A { domain: q.name.clone(), addr: Ipv4Addr::new(10, 0, 0, 1), ttl: 300, }); } } resp }) .await; let srtt = RwLock::new(SrttCache::new(true)); let result = send_query("test.example.com", QueryType::A, server_addr, &srtt).await; let resp = result.expect("should resolve via TCP fallback"); assert_eq!(resp.header.rescode, ResultCode::NOERROR); assert!(!resp.answers.is_empty()); match &resp.answers[0] { DnsRecord::A { addr, .. } => assert_eq!(*addr, Ipv4Addr::new(10, 0, 0, 1)), other => panic!("expected A record, got {:?}", other), } } /// Full iterative resolution through TCP-only mock: root referral → authoritative answer. /// The mock plays both roles (returns referral for NS queries, answer for A queries). 
#[tokio::test] async fn tcp_only_iterative_resolution() { UDP_DISABLED.store(true, Ordering::Release); // Skip UDP entirely for speed let server_addr = spawn_tcp_dns_server(|query| { let q = match query.questions.first() { Some(q) => q, None => return DnsPacket::response_from(query, ResultCode::SERVFAIL), }; if q.qtype == QueryType::NS || q.name == "com" { // Return referral — NS points back to ourselves (same IP, port 53 in glue // won't work, but cache will have our address from root_hints) let mut resp = DnsPacket::new(); resp.header.id = query.header.id; resp.header.response = true; resp.header.rescode = ResultCode::NOERROR; resp.questions = query.questions.clone(); resp.authorities.push(DnsRecord::NS { domain: "com".into(), host: "ns1.com".into(), ttl: 3600, }); resp } else { // Return authoritative answer let mut resp = DnsPacket::response_from(query, ResultCode::NOERROR); resp.header.authoritative_answer = true; resp.answers.push(DnsRecord::A { domain: q.name.clone(), addr: Ipv4Addr::new(10, 0, 0, 42), ttl: 300, }); resp } }) .await; let srtt = RwLock::new(SrttCache::new(true)); let result = send_query("hello.example.com", QueryType::A, server_addr, &srtt).await; let resp = result.expect("TCP-only send_query should work"); assert_eq!(resp.header.rescode, ResultCode::NOERROR); match &resp.answers[0] { DnsRecord::A { addr, .. 
} => assert_eq!(*addr, Ipv4Addr::new(10, 0, 0, 42)), other => panic!("expected A, got {:?}", other), } } #[tokio::test] async fn tcp_fallback_handles_nxdomain() { UDP_DISABLED.store(false, Ordering::Relaxed); UDP_FAILURES.store(0, Ordering::Release); let server_addr = spawn_tcp_dns_server(|query| { let mut resp = DnsPacket::response_from(query, ResultCode::NXDOMAIN); resp.header.authoritative_answer = true; resp }) .await; let cache = RwLock::new(DnsCache::new(100, 60, 86400)); let srtt = RwLock::new(SrttCache::new(true)); let root_hints = vec![server_addr]; let result = resolve_iterative( "nonexistent.test", QueryType::A, &cache, &root_hints, &srtt, 0, 0, ) .await; let resp = result.expect("NXDOMAIN should still return a response"); assert_eq!(resp.header.rescode, ResultCode::NXDOMAIN); assert!(resp.answers.is_empty()); } #[tokio::test] async fn udp_auto_disable_resets() { UDP_DISABLED.store(true, Ordering::Release); UDP_FAILURES.store(5, Ordering::Relaxed); reset_udp_state(); assert!(!UDP_DISABLED.load(Ordering::Acquire)); assert_eq!(UDP_FAILURES.load(Ordering::Relaxed), 0); } /// Test forward_tcp directly — verifies the length-prefixed wire format. #[tokio::test] async fn forward_tcp_wire_format() { let server_addr = spawn_tcp_dns_server(|query| { let mut resp = DnsPacket::response_from(query, ResultCode::NOERROR); resp.header.authoritative_answer = true; if let Some(q) = query.questions.first() { resp.answers.push(DnsRecord::A { domain: q.name.clone(), addr: Ipv4Addr::new(1, 2, 3, 4), ttl: 60, }); } resp }) .await; let query = DnsPacket::query(0xBEEF, "test.com", QueryType::A); let resp = crate::forward::forward_tcp(&query, server_addr, Duration::from_secs(2)) .await .expect("forward_tcp should succeed"); assert_eq!(resp.header.id, 0xBEEF); assert_eq!(resp.header.rescode, ResultCode::NOERROR); assert!(!resp.answers.is_empty()); } /// Strict server: reads with a single read() call, rejecting split writes. 
/// Simulates Microsoft Azure DNS behavior that caused the early-eof bug. #[tokio::test] async fn forward_tcp_single_segment_write() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); tokio::spawn(async move { let (mut stream, _) = listener.accept().await.unwrap(); // Single read — if length prefix arrives separately, this gets // only 2 bytes and the parse fails (simulating the Microsoft bug). let mut buf = vec![0u8; 4096]; let n = tokio::io::AsyncReadExt::read(&mut stream, &mut buf) .await .unwrap(); assert!( n >= 2 + 12, // length prefix + DNS header minimum "got only {} bytes in first read — split write bug", n ); let msg_len = u16::from_be_bytes([buf[0], buf[1]]) as usize; assert_eq!(msg_len, n - 2, "length prefix doesn't match payload"); // Parse and respond let mut pkt_buf = BytePacketBuffer::from_bytes(&buf[2..n]); let query = DnsPacket::from_buffer(&mut pkt_buf).unwrap(); let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR); resp.answers.push(DnsRecord::A { domain: query.questions[0].name.clone(), addr: Ipv4Addr::new(5, 6, 7, 8), ttl: 60, }); let mut resp_buf = BytePacketBuffer::new(); resp.write(&mut resp_buf).unwrap(); let resp_bytes = resp_buf.filled(); let mut out = Vec::with_capacity(2 + resp_bytes.len()); out.extend_from_slice(&(resp_bytes.len() as u16).to_be_bytes()); out.extend_from_slice(resp_bytes); tokio::io::AsyncWriteExt::write_all(&mut stream, &out) .await .unwrap(); }); let query = DnsPacket::query(0xCAFE, "strict.test", QueryType::A); let resp = crate::forward::forward_tcp(&query, addr, Duration::from_secs(2)) .await .expect("forward_tcp must send length+message in single segment"); assert_eq!(resp.header.id, 0xCAFE); match &resp.answers[0] { DnsRecord::A { addr, .. } => assert_eq!(*addr, Ipv4Addr::new(5, 6, 7, 8)), other => panic!("expected A, got {:?}", other), } } }