refactor: deduplicate query builders, record extraction, sinkhole records
- Add DnsPacket::query(id, domain, qtype) constructor; replace mock_query, make_query, and 4 inline constructions across ctx/forward/recursive/api - Add record_to_addr() in recursive.rs; replace 4 identical A/AAAA match blocks with filter_map one-liners - Add sinkhole_record() in ctx.rs; consolidate localhost and blocklist A/AAAA branching into single calls - Remove now-unused DnsQuestion imports Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -410,14 +410,8 @@ async fn forward_query_for_diagnose(
|
||||
timeout: std::time::Duration,
|
||||
) -> (bool, String) {
|
||||
use crate::packet::DnsPacket;
|
||||
use crate::question::DnsQuestion;
|
||||
|
||||
let mut query = DnsPacket::new();
|
||||
query.header.id = 0xBEEF;
|
||||
query.header.recursion_desired = true;
|
||||
query
|
||||
.questions
|
||||
.push(DnsQuestion::new(domain.to_string(), QueryType::A));
|
||||
let query = DnsPacket::query(0xBEEF, domain, QueryType::A);
|
||||
|
||||
match forward_query(&query, upstream, timeout).await {
|
||||
Ok(resp) => (
|
||||
|
||||
101
src/ctx.rs
101
src/ctx.rs
@@ -93,18 +93,13 @@ pub async fn handle_query(
|
||||
} else if qname == "localhost" || qname.ends_with(".localhost") {
|
||||
// RFC 6761: .localhost always resolves to loopback
|
||||
let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR);
|
||||
match qtype {
|
||||
QueryType::AAAA => resp.answers.push(DnsRecord::AAAA {
|
||||
domain: qname.clone(),
|
||||
addr: std::net::Ipv6Addr::LOCALHOST,
|
||||
ttl: 300,
|
||||
}),
|
||||
_ => resp.answers.push(DnsRecord::A {
|
||||
domain: qname.clone(),
|
||||
addr: std::net::Ipv4Addr::LOCALHOST,
|
||||
ttl: 300,
|
||||
}),
|
||||
}
|
||||
resp.answers.push(sinkhole_record(
|
||||
&qname,
|
||||
qtype,
|
||||
std::net::Ipv4Addr::LOCALHOST,
|
||||
std::net::Ipv6Addr::LOCALHOST,
|
||||
300,
|
||||
));
|
||||
(resp, QueryPath::Local, DnssecStatus::Indeterminate)
|
||||
} else if is_special_use_domain(&qname) {
|
||||
// RFC 6761/8880: private PTR, DDR, NAT64 — answer locally
|
||||
@@ -130,38 +125,24 @@ pub async fn handle_query(
|
||||
.unwrap_or(std::net::Ipv4Addr::LOCALHOST)
|
||||
}
|
||||
};
|
||||
let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR);
|
||||
match qtype {
|
||||
QueryType::AAAA => resp.answers.push(DnsRecord::AAAA {
|
||||
domain: qname.clone(),
|
||||
addr: if resolve_ip == std::net::Ipv4Addr::LOCALHOST {
|
||||
let v6 = if resolve_ip == std::net::Ipv4Addr::LOCALHOST {
|
||||
std::net::Ipv6Addr::LOCALHOST
|
||||
} else {
|
||||
resolve_ip.to_ipv6_mapped()
|
||||
},
|
||||
ttl: 300,
|
||||
}),
|
||||
_ => resp.answers.push(DnsRecord::A {
|
||||
domain: qname.clone(),
|
||||
addr: resolve_ip,
|
||||
ttl: 300,
|
||||
}),
|
||||
}
|
||||
};
|
||||
let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR);
|
||||
resp.answers
|
||||
.push(sinkhole_record(&qname, qtype, resolve_ip, v6, 300));
|
||||
(resp, QueryPath::Local, DnssecStatus::Indeterminate)
|
||||
} else if ctx.blocklist.read().unwrap().is_blocked(&qname) {
|
||||
let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR);
|
||||
match qtype {
|
||||
QueryType::AAAA => resp.answers.push(DnsRecord::AAAA {
|
||||
domain: qname.clone(),
|
||||
addr: std::net::Ipv6Addr::UNSPECIFIED,
|
||||
ttl: 60,
|
||||
}),
|
||||
_ => resp.answers.push(DnsRecord::A {
|
||||
domain: qname.clone(),
|
||||
addr: std::net::Ipv4Addr::UNSPECIFIED,
|
||||
ttl: 60,
|
||||
}),
|
||||
}
|
||||
resp.answers.push(sinkhole_record(
|
||||
&qname,
|
||||
qtype,
|
||||
std::net::Ipv4Addr::UNSPECIFIED,
|
||||
std::net::Ipv6Addr::UNSPECIFIED,
|
||||
60,
|
||||
));
|
||||
(resp, QueryPath::Blocked, DnssecStatus::Indeterminate)
|
||||
} else if let Some(records) = ctx.zone_map.get(qname.as_str()).and_then(|m| m.get(&qtype)) {
|
||||
let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR);
|
||||
@@ -383,6 +364,27 @@ fn is_special_use_domain(qname: &str) -> bool {
|
||||
qname == "local" || qname.ends_with(".local")
|
||||
}
|
||||
|
||||
fn sinkhole_record(
|
||||
domain: &str,
|
||||
qtype: QueryType,
|
||||
v4: std::net::Ipv4Addr,
|
||||
v6: std::net::Ipv6Addr,
|
||||
ttl: u32,
|
||||
) -> DnsRecord {
|
||||
match qtype {
|
||||
QueryType::AAAA => DnsRecord::AAAA {
|
||||
domain: domain.to_string(),
|
||||
addr: v6,
|
||||
ttl,
|
||||
},
|
||||
_ => DnsRecord::A {
|
||||
domain: domain.to_string(),
|
||||
addr: v4,
|
||||
ttl,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
enum Disposition {
|
||||
Leader(broadcast::Sender<Option<DnsPacket>>),
|
||||
Follower(broadcast::Receiver<Option<DnsPacket>>),
|
||||
@@ -675,15 +677,6 @@ mod tests {
|
||||
|
||||
// ---- Integration: resolve_coalesced with mock futures ----
|
||||
|
||||
fn mock_query(id: u16, domain: &str, qtype: QueryType) -> DnsPacket {
|
||||
let mut pkt = DnsPacket::new();
|
||||
pkt.header.id = id;
|
||||
pkt.header.recursion_desired = true;
|
||||
pkt.questions
|
||||
.push(crate::question::DnsQuestion::new(domain.to_string(), qtype));
|
||||
pkt
|
||||
}
|
||||
|
||||
fn mock_response(domain: &str) -> DnsPacket {
|
||||
let mut resp = DnsPacket::new();
|
||||
resp.header.response = true;
|
||||
@@ -706,7 +699,7 @@ mod tests {
|
||||
let count = resolve_count.clone();
|
||||
let inf = inflight.clone();
|
||||
let key = ("coalesce.test".to_string(), QueryType::A);
|
||||
let query = mock_query(100 + i, "coalesce.test", QueryType::A);
|
||||
let query = DnsPacket::query(100 + i, "coalesce.test", QueryType::A);
|
||||
handles.push(tokio::spawn(async move {
|
||||
resolve_coalesced(&inf, key, &query, || async {
|
||||
count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
@@ -744,8 +737,8 @@ mod tests {
|
||||
let count1 = resolve_count.clone();
|
||||
let count2 = resolve_count.clone();
|
||||
|
||||
let query_a = mock_query(200, "same.domain", QueryType::A);
|
||||
let query_aaaa = mock_query(201, "same.domain", QueryType::AAAA);
|
||||
let query_a = DnsPacket::query(200, "same.domain", QueryType::A);
|
||||
let query_aaaa = DnsPacket::query(201, "same.domain", QueryType::AAAA);
|
||||
|
||||
let h1 = tokio::spawn(async move {
|
||||
resolve_coalesced(
|
||||
@@ -788,7 +781,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn inflight_map_cleaned_after_error() {
|
||||
let inflight: Mutex<InflightMap> = Mutex::new(HashMap::new());
|
||||
let query = mock_query(300, "will-fail.test", QueryType::A);
|
||||
let query = DnsPacket::query(300, "will-fail.test", QueryType::A);
|
||||
|
||||
let (_, path, _) = resolve_coalesced(
|
||||
&inflight,
|
||||
@@ -809,7 +802,7 @@ mod tests {
|
||||
let mut handles = Vec::new();
|
||||
for i in 0..3u16 {
|
||||
let inf = inflight.clone();
|
||||
let query = mock_query(400 + i, "fail.test", QueryType::A);
|
||||
let query = DnsPacket::query(400 + i, "fail.test", QueryType::A);
|
||||
handles.push(tokio::spawn(async move {
|
||||
resolve_coalesced(
|
||||
&inf,
|
||||
@@ -849,7 +842,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn servfail_leader_includes_question_section() {
|
||||
let inflight: Mutex<InflightMap> = Mutex::new(HashMap::new());
|
||||
let query = mock_query(500, "question.test", QueryType::A);
|
||||
let query = DnsPacket::query(500, "question.test", QueryType::A);
|
||||
|
||||
let (resp, _, _) = resolve_coalesced(
|
||||
&inflight,
|
||||
@@ -873,7 +866,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn leader_error_preserves_message() {
|
||||
let inflight: Mutex<InflightMap> = Mutex::new(HashMap::new());
|
||||
let query = mock_query(700, "err-msg.test", QueryType::A);
|
||||
let query = DnsPacket::query(700, "err-msg.test", QueryType::A);
|
||||
|
||||
let (_, path, err) = resolve_coalesced(
|
||||
&inflight,
|
||||
|
||||
@@ -141,7 +141,7 @@ mod tests {
|
||||
use std::future::IntoFuture;
|
||||
|
||||
use crate::header::ResultCode;
|
||||
use crate::question::{DnsQuestion, QueryType};
|
||||
use crate::question::QueryType;
|
||||
use crate::record::DnsRecord;
|
||||
|
||||
#[test]
|
||||
@@ -160,12 +160,7 @@ mod tests {
|
||||
}
|
||||
|
||||
fn make_query() -> DnsPacket {
|
||||
let mut q = DnsPacket::new();
|
||||
q.header.id = 0xABCD;
|
||||
q.header.recursion_desired = true;
|
||||
q.questions
|
||||
.push(DnsQuestion::new("example.com".to_string(), QueryType::A));
|
||||
q
|
||||
DnsPacket::query(0xABCD, "example.com", QueryType::A)
|
||||
}
|
||||
|
||||
fn make_response(query: &DnsPacket) -> DnsPacket {
|
||||
|
||||
@@ -57,6 +57,15 @@ impl DnsPacket {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn query(id: u16, domain: &str, qtype: crate::question::QueryType) -> DnsPacket {
|
||||
let mut pkt = DnsPacket::new();
|
||||
pkt.header.id = id;
|
||||
pkt.header.recursion_desired = true;
|
||||
pkt.questions
|
||||
.push(crate::question::DnsQuestion::new(domain.to_string(), qtype));
|
||||
pkt
|
||||
}
|
||||
|
||||
pub fn response_from(query: &DnsPacket, rescode: crate::header::ResultCode) -> DnsPacket {
|
||||
let mut resp = DnsPacket::new();
|
||||
resp.header.id = query.header.id;
|
||||
|
||||
@@ -9,7 +9,7 @@ use crate::cache::DnsCache;
|
||||
use crate::forward::forward_udp;
|
||||
use crate::header::ResultCode;
|
||||
use crate::packet::DnsPacket;
|
||||
use crate::question::{DnsQuestion, QueryType};
|
||||
use crate::question::QueryType;
|
||||
use crate::record::DnsRecord;
|
||||
use crate::srtt::SrttCache;
|
||||
|
||||
@@ -32,6 +32,14 @@ fn dns_addr(ip: impl Into<IpAddr>) -> SocketAddr {
|
||||
SocketAddr::new(ip.into(), 53)
|
||||
}
|
||||
|
||||
fn record_to_addr(rec: &DnsRecord) -> Option<SocketAddr> {
|
||||
match rec {
|
||||
DnsRecord::A { addr, .. } => Some(dns_addr(*addr)),
|
||||
DnsRecord::AAAA { addr, .. } => Some(dns_addr(*addr)),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn reset_udp_state() {
|
||||
UDP_DISABLED.store(false, Ordering::Release);
|
||||
UDP_FAILURES.store(0, Ordering::Release);
|
||||
@@ -46,11 +54,8 @@ pub async fn probe_udp(root_hints: &[SocketAddr]) {
|
||||
Some(h) => *h,
|
||||
None => return,
|
||||
};
|
||||
let mut probe = DnsPacket::new();
|
||||
probe.header.id = next_id();
|
||||
probe
|
||||
.questions
|
||||
.push(DnsQuestion::new(".".to_string(), QueryType::NS));
|
||||
let mut probe = DnsPacket::query(next_id(), ".", QueryType::NS);
|
||||
probe.header.recursion_desired = false;
|
||||
if forward_udp(&probe, hint, Duration::from_millis(1500))
|
||||
.await
|
||||
.is_ok()
|
||||
@@ -296,17 +301,8 @@ pub(crate) fn resolve_iterative<'a>(
|
||||
)
|
||||
.await
|
||||
{
|
||||
for rec in &ns_resp.answers {
|
||||
match rec {
|
||||
DnsRecord::A { addr, .. } => {
|
||||
new_ns_addrs.push(dns_addr(*addr));
|
||||
}
|
||||
DnsRecord::AAAA { addr, .. } => {
|
||||
new_ns_addrs.push(dns_addr(*addr));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
new_ns_addrs
|
||||
.extend(ns_resp.answers.iter().filter_map(record_to_addr));
|
||||
}
|
||||
if !new_ns_addrs.is_empty() {
|
||||
break;
|
||||
@@ -360,13 +356,7 @@ fn find_closest_ns(
|
||||
if let DnsRecord::NS { host, .. } = ns_rec {
|
||||
for qt in [QueryType::A, QueryType::AAAA] {
|
||||
if let Some(resp) = guard.lookup(host, qt) {
|
||||
for rec in &resp.answers {
|
||||
match rec {
|
||||
DnsRecord::A { addr, .. } => addrs.push(dns_addr(*addr)),
|
||||
DnsRecord::AAAA { addr, .. } => addrs.push(dns_addr(*addr)),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
addrs.extend(resp.answers.iter().filter_map(record_to_addr));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -452,13 +442,7 @@ fn addrs_from_cache(cache: &RwLock<DnsCache>, name: &str) -> Vec<SocketAddr> {
|
||||
let mut addrs = Vec::new();
|
||||
for qt in [QueryType::A, QueryType::AAAA] {
|
||||
if let Some(pkt) = guard.lookup(name, qt) {
|
||||
for rec in &pkt.answers {
|
||||
match rec {
|
||||
DnsRecord::A { addr, .. } => addrs.push(dns_addr(*addr)),
|
||||
DnsRecord::AAAA { addr, .. } => addrs.push(dns_addr(*addr)),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
addrs.extend(pkt.answers.iter().filter_map(record_to_addr));
|
||||
}
|
||||
}
|
||||
addrs
|
||||
@@ -468,15 +452,13 @@ fn glue_addrs_for(response: &DnsPacket, ns_name: &str) -> Vec<SocketAddr> {
|
||||
response
|
||||
.resources
|
||||
.iter()
|
||||
.filter_map(|r| match r {
|
||||
DnsRecord::A { domain, addr, .. } if domain.eq_ignore_ascii_case(ns_name) => {
|
||||
Some(dns_addr(*addr))
|
||||
.filter(|r| match r {
|
||||
DnsRecord::A { domain, .. } | DnsRecord::AAAA { domain, .. } => {
|
||||
domain.eq_ignore_ascii_case(ns_name)
|
||||
}
|
||||
DnsRecord::AAAA { domain, addr, .. } if domain.eq_ignore_ascii_case(ns_name) => {
|
||||
Some(dns_addr(*addr))
|
||||
}
|
||||
_ => None,
|
||||
_ => false,
|
||||
})
|
||||
.filter_map(record_to_addr)
|
||||
.collect()
|
||||
}
|
||||
|
||||
@@ -596,12 +578,8 @@ async fn send_query(
|
||||
server: SocketAddr,
|
||||
srtt: &RwLock<SrttCache>,
|
||||
) -> crate::Result<DnsPacket> {
|
||||
let mut query = DnsPacket::new();
|
||||
query.header.id = next_id();
|
||||
let mut query = DnsPacket::query(next_id(), qname, qtype);
|
||||
query.header.recursion_desired = false;
|
||||
query
|
||||
.questions
|
||||
.push(DnsQuestion::new(qname.to_string(), qtype));
|
||||
query.edns = Some(crate::packet::EdnsOpt {
|
||||
do_bit: true,
|
||||
..Default::default()
|
||||
@@ -1056,11 +1034,7 @@ mod tests {
|
||||
})
|
||||
.await;
|
||||
|
||||
let mut query = DnsPacket::new();
|
||||
query.header.id = 0xBEEF;
|
||||
query
|
||||
.questions
|
||||
.push(DnsQuestion::new("test.com".to_string(), QueryType::A));
|
||||
let query = DnsPacket::query(0xBEEF, "test.com", QueryType::A);
|
||||
|
||||
let resp = crate::forward::forward_tcp(&query, server_addr, Duration::from_secs(2))
|
||||
.await
|
||||
@@ -1120,11 +1094,7 @@ mod tests {
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let mut query = DnsPacket::new();
|
||||
query.header.id = 0xCAFE;
|
||||
query
|
||||
.questions
|
||||
.push(DnsQuestion::new("strict.test".to_string(), QueryType::A));
|
||||
let query = DnsPacket::query(0xCAFE, "strict.test", QueryType::A);
|
||||
|
||||
let resp = crate::forward::forward_tcp(&query, addr, Duration::from_secs(2))
|
||||
.await
|
||||
|
||||
Reference in New Issue
Block a user