From 32cd8624b49b90b239b6f11e6897682f9299fff1 Mon Sep 17 00:00:00 2001 From: Razvan Dimescu Date: Sun, 29 Mar 2026 14:22:07 +0300 Subject: [PATCH] refactor: deduplicate query builders, record extraction, sinkhole records (#22) - Add DnsPacket::query(id, domain, qtype) constructor; replace mock_query, make_query, and 4 inline constructions across ctx/forward/recursive/api - Add record_to_addr() in recursive.rs; replace 4 identical A/AAAA match blocks with filter_map one-liners - Add sinkhole_record() in ctx.rs; consolidate localhost and blocklist A/AAAA branching into single calls - Remove now-unused DnsQuestion imports Co-authored-by: Claude Opus 4.6 (1M context) --- src/api.rs | 8 +--- src/ctx.rs | 105 ++++++++++++++++++++++------------------------- src/forward.rs | 9 +--- src/packet.rs | 9 ++++ src/recursive.rs | 76 +++++++++++----------------------- 5 files changed, 84 insertions(+), 123 deletions(-) diff --git a/src/api.rs b/src/api.rs index c3dc324..04a81bf 100644 --- a/src/api.rs +++ b/src/api.rs @@ -410,14 +410,8 @@ async fn forward_query_for_diagnose( timeout: std::time::Duration, ) -> (bool, String) { use crate::packet::DnsPacket; - use crate::question::DnsQuestion; - let mut query = DnsPacket::new(); - query.header.id = 0xBEEF; - query.header.recursion_desired = true; - query - .questions - .push(DnsQuestion::new(domain.to_string(), QueryType::A)); + let query = DnsPacket::query(0xBEEF, domain, QueryType::A); match forward_query(&query, upstream, timeout).await { Ok(resp) => ( diff --git a/src/ctx.rs b/src/ctx.rs index fbddb15..b21e20b 100644 --- a/src/ctx.rs +++ b/src/ctx.rs @@ -93,18 +93,13 @@ pub async fn handle_query( } else if qname == "localhost" || qname.ends_with(".localhost") { // RFC 6761: .localhost always resolves to loopback let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR); - match qtype { - QueryType::AAAA => resp.answers.push(DnsRecord::AAAA { - domain: qname.clone(), - addr: std::net::Ipv6Addr::LOCALHOST, - ttl: 300, - }), - _ => resp.answers.push(DnsRecord::A { - domain: qname.clone(), - addr: std::net::Ipv4Addr::LOCALHOST, - ttl: 300, - }), - } + resp.answers.push(sinkhole_record( + &qname, + qtype, + std::net::Ipv4Addr::LOCALHOST, + std::net::Ipv6Addr::LOCALHOST, + 300, + )); (resp, QueryPath::Local, DnssecStatus::Indeterminate) } else if is_special_use_domain(&qname) { // RFC 6761/8880: private PTR, DDR, NAT64 — answer locally @@ -130,38 +125,24 @@ pub async fn handle_query( .unwrap_or(std::net::Ipv4Addr::LOCALHOST) } }; + let v6 = if resolve_ip == std::net::Ipv4Addr::LOCALHOST { + std::net::Ipv6Addr::LOCALHOST + } else { + resolve_ip.to_ipv6_mapped() + }; let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR); - match qtype { - QueryType::AAAA => resp.answers.push(DnsRecord::AAAA { - domain: qname.clone(), - addr: if resolve_ip == std::net::Ipv4Addr::LOCALHOST { - std::net::Ipv6Addr::LOCALHOST - } else { - resolve_ip.to_ipv6_mapped() - }, - ttl: 300, - }), - _ => resp.answers.push(DnsRecord::A { - domain: qname.clone(), - addr: resolve_ip, - ttl: 300, - }), - } + resp.answers + .push(sinkhole_record(&qname, qtype, resolve_ip, v6, 300)); (resp, QueryPath::Local, DnssecStatus::Indeterminate) } else if ctx.blocklist.read().unwrap().is_blocked(&qname) { let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR); - match qtype { - QueryType::AAAA => resp.answers.push(DnsRecord::AAAA { - domain: qname.clone(), - addr: std::net::Ipv6Addr::UNSPECIFIED, - ttl: 60, - }), - _ => resp.answers.push(DnsRecord::A { - domain: qname.clone(), - addr: std::net::Ipv4Addr::UNSPECIFIED, - ttl: 60, - }), - } + resp.answers.push(sinkhole_record( + &qname, + qtype, + std::net::Ipv4Addr::UNSPECIFIED, + std::net::Ipv6Addr::UNSPECIFIED, + 60, + )); (resp, QueryPath::Blocked, DnssecStatus::Indeterminate) } else if let Some(records) = ctx.zone_map.get(qname.as_str()).and_then(|m| m.get(&qtype)) { let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR); @@ -383,6 +364,27 @@ fn is_special_use_domain(qname: &str) -> bool { qname == "local" || qname.ends_with(".local") } +fn sinkhole_record( + domain: &str, + qtype: QueryType, + v4: std::net::Ipv4Addr, + v6: std::net::Ipv6Addr, + ttl: u32, +) -> DnsRecord { + match qtype { + QueryType::AAAA => DnsRecord::AAAA { + domain: domain.to_string(), + addr: v6, + ttl, + }, + _ => DnsRecord::A { + domain: domain.to_string(), + addr: v4, + ttl, + }, + } +} + enum Disposition { Leader(broadcast::Sender>), Follower(broadcast::Receiver>), @@ -675,15 +677,6 @@ mod tests { // ---- Integration: resolve_coalesced with mock futures ---- - fn mock_query(id: u16, domain: &str, qtype: QueryType) -> DnsPacket { - let mut pkt = DnsPacket::new(); - pkt.header.id = id; - pkt.header.recursion_desired = true; - pkt.questions - .push(crate::question::DnsQuestion::new(domain.to_string(), qtype)); - pkt - } - fn mock_response(domain: &str) -> DnsPacket { let mut resp = DnsPacket::new(); resp.header.response = true; @@ -706,7 +699,7 @@ mod tests { let count = resolve_count.clone(); let inf = inflight.clone(); let key = ("coalesce.test".to_string(), QueryType::A); - let query = mock_query(100 + i, "coalesce.test", QueryType::A); + let query = DnsPacket::query(100 + i, "coalesce.test", QueryType::A); handles.push(tokio::spawn(async move { resolve_coalesced(&inf, key, &query, || async { count.fetch_add(1, std::sync::atomic::Ordering::Relaxed); @@ -744,8 +737,8 @@ mod tests { let count1 = resolve_count.clone(); let count2 = resolve_count.clone(); - let query_a = mock_query(200, "same.domain", QueryType::A); - let query_aaaa = mock_query(201, "same.domain", QueryType::AAAA); + let query_a = DnsPacket::query(200, "same.domain", QueryType::A); + let query_aaaa = DnsPacket::query(201, "same.domain", QueryType::AAAA); let h1 = tokio::spawn(async move { resolve_coalesced( @@ -788,7 +781,7 @@ mod tests { #[tokio::test] async fn inflight_map_cleaned_after_error() { let inflight: Mutex = Mutex::new(HashMap::new()); - let query = mock_query(300, "will-fail.test", QueryType::A); + let query = DnsPacket::query(300, "will-fail.test", QueryType::A); let (_, path, _) = resolve_coalesced( &inflight, @@ -809,7 +802,7 @@ mod tests { let mut handles = Vec::new(); for i in 0..3u16 { let inf = inflight.clone(); - let query = mock_query(400 + i, "fail.test", QueryType::A); + let query = DnsPacket::query(400 + i, "fail.test", QueryType::A); handles.push(tokio::spawn(async move { resolve_coalesced( &inf, @@ -849,7 +842,7 @@ mod tests { #[tokio::test] async fn servfail_leader_includes_question_section() { let inflight: Mutex = Mutex::new(HashMap::new()); - let query = mock_query(500, "question.test", QueryType::A); + let query = DnsPacket::query(500, "question.test", QueryType::A); let (resp, _, _) = resolve_coalesced( &inflight, @@ -873,7 +866,7 @@ mod tests { #[tokio::test] async fn leader_error_preserves_message() { let inflight: Mutex = Mutex::new(HashMap::new()); - let query = mock_query(700, "err-msg.test", QueryType::A); + let query = DnsPacket::query(700, "err-msg.test", QueryType::A); let (_, path, err) = resolve_coalesced( &inflight, diff --git a/src/forward.rs b/src/forward.rs index 1d4ed25..ea2b03e 100644 --- a/src/forward.rs +++ b/src/forward.rs @@ -141,7 +141,7 @@ mod tests { use std::future::IntoFuture; use crate::header::ResultCode; - use crate::question::{DnsQuestion, QueryType}; + use crate::question::QueryType; use crate::record::DnsRecord; #[test] @@ -160,12 +160,7 @@ mod tests { } fn make_query() -> DnsPacket { - let mut q = DnsPacket::new(); - q.header.id = 0xABCD; - q.header.recursion_desired = true; - q.questions - .push(DnsQuestion::new("example.com".to_string(), QueryType::A)); - q + DnsPacket::query(0xABCD, "example.com", QueryType::A) } fn make_response(query: &DnsPacket) -> DnsPacket { diff --git a/src/packet.rs b/src/packet.rs index e273ba8..bf89ea3 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -57,6 +57,15 @@ impl DnsPacket { } } + pub fn query(id: u16, domain: &str, qtype: crate::question::QueryType) -> DnsPacket { + let mut pkt = DnsPacket::new(); + pkt.header.id = id; + pkt.header.recursion_desired = true; + pkt.questions + .push(crate::question::DnsQuestion::new(domain.to_string(), qtype)); + pkt + } + pub fn response_from(query: &DnsPacket, rescode: crate::header::ResultCode) -> DnsPacket { let mut resp = DnsPacket::new(); resp.header.id = query.header.id; diff --git a/src/recursive.rs b/src/recursive.rs index 54a9625..82f9879 100644 --- a/src/recursive.rs +++ b/src/recursive.rs @@ -9,7 +9,7 @@ use crate::cache::DnsCache; use crate::forward::forward_udp; use crate::header::ResultCode; use crate::packet::DnsPacket; -use crate::question::{DnsQuestion, QueryType}; +use crate::question::QueryType; use crate::record::DnsRecord; use crate::srtt::SrttCache; @@ -32,6 +32,14 @@ fn dns_addr(ip: impl Into) -> SocketAddr { SocketAddr::new(ip.into(), 53) } +fn record_to_addr(rec: &DnsRecord) -> Option { + match rec { + DnsRecord::A { addr, .. } => Some(dns_addr(*addr)), + DnsRecord::AAAA { addr, .. } => Some(dns_addr(*addr)), + _ => None, + } +} + pub fn reset_udp_state() { UDP_DISABLED.store(false, Ordering::Release); UDP_FAILURES.store(0, Ordering::Release); @@ -46,11 +54,8 @@ pub async fn probe_udp(root_hints: &[SocketAddr]) { Some(h) => *h, None => return, }; - let mut probe = DnsPacket::new(); - probe.header.id = next_id(); - probe - .questions - .push(DnsQuestion::new(".".to_string(), QueryType::NS)); + let mut probe = DnsPacket::query(next_id(), ".", QueryType::NS); + probe.header.recursion_desired = false; if forward_udp(&probe, hint, Duration::from_millis(1500)) .await .is_ok() @@ -296,17 +301,8 @@ pub(crate) fn resolve_iterative<'a>( ) .await { - for rec in &ns_resp.answers { - match rec { - DnsRecord::A { addr, .. } => { - new_ns_addrs.push(dns_addr(*addr)); - } - DnsRecord::AAAA { addr, .. } => { - new_ns_addrs.push(dns_addr(*addr)); - } - _ => {} - } - } + new_ns_addrs + .extend(ns_resp.answers.iter().filter_map(record_to_addr)); } if !new_ns_addrs.is_empty() { break; @@ -360,13 +356,7 @@ fn find_closest_ns( if let DnsRecord::NS { host, .. } = ns_rec { for qt in [QueryType::A, QueryType::AAAA] { if let Some(resp) = guard.lookup(host, qt) { - for rec in &resp.answers { - match rec { - DnsRecord::A { addr, .. } => addrs.push(dns_addr(*addr)), - DnsRecord::AAAA { addr, .. } => addrs.push(dns_addr(*addr)), - _ => {} - } - } + addrs.extend(resp.answers.iter().filter_map(record_to_addr)); } } } @@ -452,13 +442,7 @@ fn addrs_from_cache(cache: &RwLock, name: &str) -> Vec { let mut addrs = Vec::new(); for qt in [QueryType::A, QueryType::AAAA] { if let Some(pkt) = guard.lookup(name, qt) { - for rec in &pkt.answers { - match rec { - DnsRecord::A { addr, .. } => addrs.push(dns_addr(*addr)), - DnsRecord::AAAA { addr, .. } => addrs.push(dns_addr(*addr)), - _ => {} - } - } + addrs.extend(pkt.answers.iter().filter_map(record_to_addr)); } } addrs @@ -468,15 +452,13 @@ fn glue_addrs_for(response: &DnsPacket, ns_name: &str) -> Vec { response .resources .iter() - .filter_map(|r| match r { - DnsRecord::A { domain, addr, .. } if domain.eq_ignore_ascii_case(ns_name) => { - Some(dns_addr(*addr)) + .filter(|r| match r { + DnsRecord::A { domain, .. } | DnsRecord::AAAA { domain, .. } => { + domain.eq_ignore_ascii_case(ns_name) } - DnsRecord::AAAA { domain, addr, .. } if domain.eq_ignore_ascii_case(ns_name) => { - Some(dns_addr(*addr)) - } - _ => None, + _ => false, }) + .filter_map(record_to_addr) .collect() } @@ -596,12 +578,8 @@ async fn send_query( server: SocketAddr, srtt: &RwLock, ) -> crate::Result { - let mut query = DnsPacket::new(); - query.header.id = next_id(); + let mut query = DnsPacket::query(next_id(), qname, qtype); query.header.recursion_desired = false; - query - .questions - .push(DnsQuestion::new(qname.to_string(), qtype)); query.edns = Some(crate::packet::EdnsOpt { do_bit: true, ..Default::default() @@ -1056,11 +1034,7 @@ mod tests { }) .await; - let mut query = DnsPacket::new(); - query.header.id = 0xBEEF; - query - .questions - .push(DnsQuestion::new("test.com".to_string(), QueryType::A)); + let query = DnsPacket::query(0xBEEF, "test.com", QueryType::A); let resp = crate::forward::forward_tcp(&query, server_addr, Duration::from_secs(2)) .await @@ -1120,11 +1094,7 @@ mod tests { .unwrap(); }); - let mut query = DnsPacket::new(); - query.header.id = 0xCAFE; - query - .questions - .push(DnsQuestion::new("strict.test".to_string(), QueryType::A)); + let query = DnsPacket::query(0xCAFE, "strict.test", QueryType::A); let resp = crate::forward::forward_tcp(&query, addr, Duration::from_secs(2)) .await