refactor: deduplicate query/record/sinkhole helpers #22

Merged
razvandimescu merged 1 commits from refactor/dedup-helpers into main 2026-03-29 19:22:07 +08:00
5 changed files with 84 additions and 123 deletions

View File

@@ -410,14 +410,8 @@ async fn forward_query_for_diagnose(
timeout: std::time::Duration, timeout: std::time::Duration,
) -> (bool, String) { ) -> (bool, String) {
use crate::packet::DnsPacket; use crate::packet::DnsPacket;
use crate::question::DnsQuestion;
let mut query = DnsPacket::new(); let query = DnsPacket::query(0xBEEF, domain, QueryType::A);
query.header.id = 0xBEEF;
query.header.recursion_desired = true;
query
.questions
.push(DnsQuestion::new(domain.to_string(), QueryType::A));
match forward_query(&query, upstream, timeout).await { match forward_query(&query, upstream, timeout).await {
Ok(resp) => ( Ok(resp) => (

View File

@@ -93,18 +93,13 @@ pub async fn handle_query(
} else if qname == "localhost" || qname.ends_with(".localhost") { } else if qname == "localhost" || qname.ends_with(".localhost") {
// RFC 6761: .localhost always resolves to loopback // RFC 6761: .localhost always resolves to loopback
let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR); let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR);
match qtype { resp.answers.push(sinkhole_record(
QueryType::AAAA => resp.answers.push(DnsRecord::AAAA { &qname,
domain: qname.clone(), qtype,
addr: std::net::Ipv6Addr::LOCALHOST, std::net::Ipv4Addr::LOCALHOST,
ttl: 300, std::net::Ipv6Addr::LOCALHOST,
}), 300,
_ => resp.answers.push(DnsRecord::A { ));
domain: qname.clone(),
addr: std::net::Ipv4Addr::LOCALHOST,
ttl: 300,
}),
}
(resp, QueryPath::Local, DnssecStatus::Indeterminate) (resp, QueryPath::Local, DnssecStatus::Indeterminate)
} else if is_special_use_domain(&qname) { } else if is_special_use_domain(&qname) {
// RFC 6761/8880: private PTR, DDR, NAT64 — answer locally // RFC 6761/8880: private PTR, DDR, NAT64 — answer locally
@@ -130,38 +125,24 @@ pub async fn handle_query(
.unwrap_or(std::net::Ipv4Addr::LOCALHOST) .unwrap_or(std::net::Ipv4Addr::LOCALHOST)
} }
}; };
let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR); let v6 = if resolve_ip == std::net::Ipv4Addr::LOCALHOST {
match qtype {
QueryType::AAAA => resp.answers.push(DnsRecord::AAAA {
domain: qname.clone(),
addr: if resolve_ip == std::net::Ipv4Addr::LOCALHOST {
std::net::Ipv6Addr::LOCALHOST std::net::Ipv6Addr::LOCALHOST
} else { } else {
resolve_ip.to_ipv6_mapped() resolve_ip.to_ipv6_mapped()
}, };
ttl: 300, let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR);
}), resp.answers
_ => resp.answers.push(DnsRecord::A { .push(sinkhole_record(&qname, qtype, resolve_ip, v6, 300));
domain: qname.clone(),
addr: resolve_ip,
ttl: 300,
}),
}
(resp, QueryPath::Local, DnssecStatus::Indeterminate) (resp, QueryPath::Local, DnssecStatus::Indeterminate)
} else if ctx.blocklist.read().unwrap().is_blocked(&qname) { } else if ctx.blocklist.read().unwrap().is_blocked(&qname) {
let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR); let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR);
match qtype { resp.answers.push(sinkhole_record(
QueryType::AAAA => resp.answers.push(DnsRecord::AAAA { &qname,
domain: qname.clone(), qtype,
addr: std::net::Ipv6Addr::UNSPECIFIED, std::net::Ipv4Addr::UNSPECIFIED,
ttl: 60, std::net::Ipv6Addr::UNSPECIFIED,
}), 60,
_ => resp.answers.push(DnsRecord::A { ));
domain: qname.clone(),
addr: std::net::Ipv4Addr::UNSPECIFIED,
ttl: 60,
}),
}
(resp, QueryPath::Blocked, DnssecStatus::Indeterminate) (resp, QueryPath::Blocked, DnssecStatus::Indeterminate)
} else if let Some(records) = ctx.zone_map.get(qname.as_str()).and_then(|m| m.get(&qtype)) { } else if let Some(records) = ctx.zone_map.get(qname.as_str()).and_then(|m| m.get(&qtype)) {
let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR); let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR);
@@ -383,6 +364,27 @@ fn is_special_use_domain(qname: &str) -> bool {
qname == "local" || qname.ends_with(".local") qname == "local" || qname.ends_with(".local")
} }
fn sinkhole_record(
domain: &str,
qtype: QueryType,
v4: std::net::Ipv4Addr,
v6: std::net::Ipv6Addr,
ttl: u32,
) -> DnsRecord {
match qtype {
QueryType::AAAA => DnsRecord::AAAA {
domain: domain.to_string(),
addr: v6,
ttl,
},
_ => DnsRecord::A {
domain: domain.to_string(),
addr: v4,
ttl,
},
}
}
enum Disposition { enum Disposition {
Leader(broadcast::Sender<Option<DnsPacket>>), Leader(broadcast::Sender<Option<DnsPacket>>),
Follower(broadcast::Receiver<Option<DnsPacket>>), Follower(broadcast::Receiver<Option<DnsPacket>>),
@@ -675,15 +677,6 @@ mod tests {
// ---- Integration: resolve_coalesced with mock futures ---- // ---- Integration: resolve_coalesced with mock futures ----
fn mock_query(id: u16, domain: &str, qtype: QueryType) -> DnsPacket {
let mut pkt = DnsPacket::new();
pkt.header.id = id;
pkt.header.recursion_desired = true;
pkt.questions
.push(crate::question::DnsQuestion::new(domain.to_string(), qtype));
pkt
}
fn mock_response(domain: &str) -> DnsPacket { fn mock_response(domain: &str) -> DnsPacket {
let mut resp = DnsPacket::new(); let mut resp = DnsPacket::new();
resp.header.response = true; resp.header.response = true;
@@ -706,7 +699,7 @@ mod tests {
let count = resolve_count.clone(); let count = resolve_count.clone();
let inf = inflight.clone(); let inf = inflight.clone();
let key = ("coalesce.test".to_string(), QueryType::A); let key = ("coalesce.test".to_string(), QueryType::A);
let query = mock_query(100 + i, "coalesce.test", QueryType::A); let query = DnsPacket::query(100 + i, "coalesce.test", QueryType::A);
handles.push(tokio::spawn(async move { handles.push(tokio::spawn(async move {
resolve_coalesced(&inf, key, &query, || async { resolve_coalesced(&inf, key, &query, || async {
count.fetch_add(1, std::sync::atomic::Ordering::Relaxed); count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
@@ -744,8 +737,8 @@ mod tests {
let count1 = resolve_count.clone(); let count1 = resolve_count.clone();
let count2 = resolve_count.clone(); let count2 = resolve_count.clone();
let query_a = mock_query(200, "same.domain", QueryType::A); let query_a = DnsPacket::query(200, "same.domain", QueryType::A);
let query_aaaa = mock_query(201, "same.domain", QueryType::AAAA); let query_aaaa = DnsPacket::query(201, "same.domain", QueryType::AAAA);
let h1 = tokio::spawn(async move { let h1 = tokio::spawn(async move {
resolve_coalesced( resolve_coalesced(
@@ -788,7 +781,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn inflight_map_cleaned_after_error() { async fn inflight_map_cleaned_after_error() {
let inflight: Mutex<InflightMap> = Mutex::new(HashMap::new()); let inflight: Mutex<InflightMap> = Mutex::new(HashMap::new());
let query = mock_query(300, "will-fail.test", QueryType::A); let query = DnsPacket::query(300, "will-fail.test", QueryType::A);
let (_, path, _) = resolve_coalesced( let (_, path, _) = resolve_coalesced(
&inflight, &inflight,
@@ -809,7 +802,7 @@ mod tests {
let mut handles = Vec::new(); let mut handles = Vec::new();
for i in 0..3u16 { for i in 0..3u16 {
let inf = inflight.clone(); let inf = inflight.clone();
let query = mock_query(400 + i, "fail.test", QueryType::A); let query = DnsPacket::query(400 + i, "fail.test", QueryType::A);
handles.push(tokio::spawn(async move { handles.push(tokio::spawn(async move {
resolve_coalesced( resolve_coalesced(
&inf, &inf,
@@ -849,7 +842,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn servfail_leader_includes_question_section() { async fn servfail_leader_includes_question_section() {
let inflight: Mutex<InflightMap> = Mutex::new(HashMap::new()); let inflight: Mutex<InflightMap> = Mutex::new(HashMap::new());
let query = mock_query(500, "question.test", QueryType::A); let query = DnsPacket::query(500, "question.test", QueryType::A);
let (resp, _, _) = resolve_coalesced( let (resp, _, _) = resolve_coalesced(
&inflight, &inflight,
@@ -873,7 +866,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn leader_error_preserves_message() { async fn leader_error_preserves_message() {
let inflight: Mutex<InflightMap> = Mutex::new(HashMap::new()); let inflight: Mutex<InflightMap> = Mutex::new(HashMap::new());
let query = mock_query(700, "err-msg.test", QueryType::A); let query = DnsPacket::query(700, "err-msg.test", QueryType::A);
let (_, path, err) = resolve_coalesced( let (_, path, err) = resolve_coalesced(
&inflight, &inflight,

View File

@@ -141,7 +141,7 @@ mod tests {
use std::future::IntoFuture; use std::future::IntoFuture;
use crate::header::ResultCode; use crate::header::ResultCode;
use crate::question::{DnsQuestion, QueryType}; use crate::question::QueryType;
use crate::record::DnsRecord; use crate::record::DnsRecord;
#[test] #[test]
@@ -160,12 +160,7 @@ mod tests {
} }
fn make_query() -> DnsPacket { fn make_query() -> DnsPacket {
let mut q = DnsPacket::new(); DnsPacket::query(0xABCD, "example.com", QueryType::A)
q.header.id = 0xABCD;
q.header.recursion_desired = true;
q.questions
.push(DnsQuestion::new("example.com".to_string(), QueryType::A));
q
} }
fn make_response(query: &DnsPacket) -> DnsPacket { fn make_response(query: &DnsPacket) -> DnsPacket {

View File

@@ -57,6 +57,15 @@ impl DnsPacket {
} }
} }
pub fn query(id: u16, domain: &str, qtype: crate::question::QueryType) -> DnsPacket {
let mut pkt = DnsPacket::new();
pkt.header.id = id;
pkt.header.recursion_desired = true;
pkt.questions
.push(crate::question::DnsQuestion::new(domain.to_string(), qtype));
pkt
}
pub fn response_from(query: &DnsPacket, rescode: crate::header::ResultCode) -> DnsPacket { pub fn response_from(query: &DnsPacket, rescode: crate::header::ResultCode) -> DnsPacket {
let mut resp = DnsPacket::new(); let mut resp = DnsPacket::new();
resp.header.id = query.header.id; resp.header.id = query.header.id;

View File

@@ -9,7 +9,7 @@ use crate::cache::DnsCache;
use crate::forward::forward_udp; use crate::forward::forward_udp;
use crate::header::ResultCode; use crate::header::ResultCode;
use crate::packet::DnsPacket; use crate::packet::DnsPacket;
use crate::question::{DnsQuestion, QueryType}; use crate::question::QueryType;
use crate::record::DnsRecord; use crate::record::DnsRecord;
use crate::srtt::SrttCache; use crate::srtt::SrttCache;
@@ -32,6 +32,14 @@ fn dns_addr(ip: impl Into<IpAddr>) -> SocketAddr {
SocketAddr::new(ip.into(), 53) SocketAddr::new(ip.into(), 53)
} }
fn record_to_addr(rec: &DnsRecord) -> Option<SocketAddr> {
match rec {
DnsRecord::A { addr, .. } => Some(dns_addr(*addr)),
DnsRecord::AAAA { addr, .. } => Some(dns_addr(*addr)),
_ => None,
}
}
pub fn reset_udp_state() { pub fn reset_udp_state() {
UDP_DISABLED.store(false, Ordering::Release); UDP_DISABLED.store(false, Ordering::Release);
UDP_FAILURES.store(0, Ordering::Release); UDP_FAILURES.store(0, Ordering::Release);
@@ -46,11 +54,8 @@ pub async fn probe_udp(root_hints: &[SocketAddr]) {
Some(h) => *h, Some(h) => *h,
None => return, None => return,
}; };
let mut probe = DnsPacket::new(); let mut probe = DnsPacket::query(next_id(), ".", QueryType::NS);
probe.header.id = next_id(); probe.header.recursion_desired = false;
probe
.questions
.push(DnsQuestion::new(".".to_string(), QueryType::NS));
if forward_udp(&probe, hint, Duration::from_millis(1500)) if forward_udp(&probe, hint, Duration::from_millis(1500))
.await .await
.is_ok() .is_ok()
@@ -296,17 +301,8 @@ pub(crate) fn resolve_iterative<'a>(
) )
.await .await
{ {
for rec in &ns_resp.answers { new_ns_addrs
match rec { .extend(ns_resp.answers.iter().filter_map(record_to_addr));
DnsRecord::A { addr, .. } => {
new_ns_addrs.push(dns_addr(*addr));
}
DnsRecord::AAAA { addr, .. } => {
new_ns_addrs.push(dns_addr(*addr));
}
_ => {}
}
}
} }
if !new_ns_addrs.is_empty() { if !new_ns_addrs.is_empty() {
break; break;
@@ -360,13 +356,7 @@ fn find_closest_ns(
if let DnsRecord::NS { host, .. } = ns_rec { if let DnsRecord::NS { host, .. } = ns_rec {
for qt in [QueryType::A, QueryType::AAAA] { for qt in [QueryType::A, QueryType::AAAA] {
if let Some(resp) = guard.lookup(host, qt) { if let Some(resp) = guard.lookup(host, qt) {
for rec in &resp.answers { addrs.extend(resp.answers.iter().filter_map(record_to_addr));
match rec {
DnsRecord::A { addr, .. } => addrs.push(dns_addr(*addr)),
DnsRecord::AAAA { addr, .. } => addrs.push(dns_addr(*addr)),
_ => {}
}
}
} }
} }
} }
@@ -452,13 +442,7 @@ fn addrs_from_cache(cache: &RwLock<DnsCache>, name: &str) -> Vec<SocketAddr> {
let mut addrs = Vec::new(); let mut addrs = Vec::new();
for qt in [QueryType::A, QueryType::AAAA] { for qt in [QueryType::A, QueryType::AAAA] {
if let Some(pkt) = guard.lookup(name, qt) { if let Some(pkt) = guard.lookup(name, qt) {
for rec in &pkt.answers { addrs.extend(pkt.answers.iter().filter_map(record_to_addr));
match rec {
DnsRecord::A { addr, .. } => addrs.push(dns_addr(*addr)),
DnsRecord::AAAA { addr, .. } => addrs.push(dns_addr(*addr)),
_ => {}
}
}
} }
} }
addrs addrs
@@ -468,15 +452,13 @@ fn glue_addrs_for(response: &DnsPacket, ns_name: &str) -> Vec<SocketAddr> {
response response
.resources .resources
.iter() .iter()
.filter_map(|r| match r { .filter(|r| match r {
DnsRecord::A { domain, addr, .. } if domain.eq_ignore_ascii_case(ns_name) => { DnsRecord::A { domain, .. } | DnsRecord::AAAA { domain, .. } => {
Some(dns_addr(*addr)) domain.eq_ignore_ascii_case(ns_name)
} }
DnsRecord::AAAA { domain, addr, .. } if domain.eq_ignore_ascii_case(ns_name) => { _ => false,
Some(dns_addr(*addr))
}
_ => None,
}) })
.filter_map(record_to_addr)
.collect() .collect()
} }
@@ -596,12 +578,8 @@ async fn send_query(
server: SocketAddr, server: SocketAddr,
srtt: &RwLock<SrttCache>, srtt: &RwLock<SrttCache>,
) -> crate::Result<DnsPacket> { ) -> crate::Result<DnsPacket> {
let mut query = DnsPacket::new(); let mut query = DnsPacket::query(next_id(), qname, qtype);
query.header.id = next_id();
query.header.recursion_desired = false; query.header.recursion_desired = false;
query
.questions
.push(DnsQuestion::new(qname.to_string(), qtype));
query.edns = Some(crate::packet::EdnsOpt { query.edns = Some(crate::packet::EdnsOpt {
do_bit: true, do_bit: true,
..Default::default() ..Default::default()
@@ -1056,11 +1034,7 @@ mod tests {
}) })
.await; .await;
let mut query = DnsPacket::new(); let query = DnsPacket::query(0xBEEF, "test.com", QueryType::A);
query.header.id = 0xBEEF;
query
.questions
.push(DnsQuestion::new("test.com".to_string(), QueryType::A));
let resp = crate::forward::forward_tcp(&query, server_addr, Duration::from_secs(2)) let resp = crate::forward::forward_tcp(&query, server_addr, Duration::from_secs(2))
.await .await
@@ -1120,11 +1094,7 @@ mod tests {
.unwrap(); .unwrap();
}); });
let mut query = DnsPacket::new(); let query = DnsPacket::query(0xCAFE, "strict.test", QueryType::A);
query.header.id = 0xCAFE;
query
.questions
.push(DnsQuestion::new("strict.test".to_string(), QueryType::A));
let resp = crate::forward::forward_tcp(&query, addr, Duration::from_secs(2)) let resp = crate::forward::forward_tcp(&query, addr, Duration::from_secs(2))
.await .await