refactor: extract resolve_coalesced, rewrite tests against real code

Extract Disposition enum, acquire_inflight(), and resolve_coalesced()
from handle_query so coalescing logic is independently testable. Rewrite
integration tests to call resolve_coalesced directly with mock futures
instead of fighting the iterative resolver's NS chain. All 12 coalescing
tests now exercise production code paths, not tokio primitives.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Razvan Dimescu
2026-03-29 10:55:46 +03:00
parent 882508297e
commit e0b0c2bda9
2 changed files with 192 additions and 225 deletions

View File

@@ -178,62 +178,23 @@ pub async fn handle_query(
(resp, QueryPath::Cached, cached_dnssec) (resp, QueryPath::Cached, cached_dnssec)
} else if ctx.upstream_mode == UpstreamMode::Recursive { } else if ctx.upstream_mode == UpstreamMode::Recursive {
let key = (qname.clone(), qtype); let key = (qname.clone(), qtype);
let disposition = acquire_inflight(&ctx.inflight, key.clone()); let (resp, path) = resolve_coalesced(&ctx.inflight, key, query.header.id, || {
crate::recursive::resolve_recursive(
match disposition { &qname,
Disposition::Follower(mut rx) => { qtype,
debug!("{} | {:?} {} | COALESCED", src_addr, qtype, qname); &ctx.cache,
match rx.recv().await { &query,
Ok(Some(mut resp)) => { &ctx.root_hints,
resp.header.id = query.header.id; &ctx.srtt,
(resp, QueryPath::Coalesced, DnssecStatus::Indeterminate) )
} })
_ => ( .await;
DnsPacket::response_from(&query, ResultCode::SERVFAIL), if path == QueryPath::Coalesced {
QueryPath::UpstreamError, debug!("{} | {:?} {} | COALESCED", src_addr, qtype, qname);
DnssecStatus::Indeterminate, } else if path == QueryPath::UpstreamError {
), error!("{} | {:?} {} | RECURSIVE ERROR", src_addr, qtype, qname);
}
}
Disposition::Leader(tx) => {
// Drop guard: remove inflight entry even on panic/cancellation
let guard = InflightGuard {
inflight: &ctx.inflight,
key: key.clone(),
};
let result = crate::recursive::resolve_recursive(
&qname,
qtype,
&ctx.cache,
&query,
&ctx.root_hints,
&ctx.srtt,
)
.await;
drop(guard);
match result {
Ok(resp) => {
let _ = tx.send(Some(resp.clone()));
(resp, QueryPath::Recursive, DnssecStatus::Indeterminate)
}
Err(e) => {
let _ = tx.send(None);
error!(
"{} | {:?} {} | RECURSIVE ERROR | {}",
src_addr, qtype, qname, e
);
(
DnsPacket::response_from(&query, ResultCode::SERVFAIL),
QueryPath::UpstreamError,
DnssecStatus::Indeterminate,
)
}
}
}
} }
(resp, path, DnssecStatus::Indeterminate)
} else { } else {
let upstream = let upstream =
match crate::system_dns::match_forwarding_rule(&qname, &ctx.forwarding_rules) { match crate::system_dns::match_forwarding_rule(&qname, &ctx.forwarding_rules) {
@@ -432,6 +393,58 @@ fn acquire_inflight(inflight: &Mutex<InflightMap>, key: (String, QueryType)) ->
} }
} }
/// Run a resolve function with in-flight coalescing. Multiple concurrent calls
/// for the same key share a single resolution — the first caller (leader)
/// executes `resolve_fn`, and followers wait for the broadcast result.
async fn resolve_coalesced<F, Fut>(
inflight: &Mutex<InflightMap>,
key: (String, QueryType),
query_id: u16,
resolve_fn: F,
) -> (DnsPacket, QueryPath)
where
F: FnOnce() -> Fut,
Fut: std::future::Future<Output = crate::Result<DnsPacket>>,
{
let disposition = acquire_inflight(inflight, key.clone());
match disposition {
Disposition::Follower(mut rx) => match rx.recv().await {
Ok(Some(mut resp)) => {
resp.header.id = query_id;
(resp, QueryPath::Coalesced)
}
_ => {
let mut resp = DnsPacket::new();
resp.header.id = query_id;
resp.header.response = true;
resp.header.rescode = ResultCode::SERVFAIL;
(resp, QueryPath::UpstreamError)
}
},
Disposition::Leader(tx) => {
let guard = InflightGuard { inflight, key };
let result = resolve_fn().await;
drop(guard);
match result {
Ok(resp) => {
let _ = tx.send(Some(resp.clone()));
(resp, QueryPath::Recursive)
}
Err(_) => {
let _ = tx.send(None);
let mut resp = DnsPacket::new();
resp.header.id = query_id;
resp.header.response = true;
resp.header.rescode = ResultCode::SERVFAIL;
(resp, QueryPath::UpstreamError)
}
}
}
}
}
struct InflightGuard<'a> { struct InflightGuard<'a> {
inflight: &'a Mutex<InflightMap>, inflight: &'a Mutex<InflightMap>,
key: (String, QueryType), key: (String, QueryType),
@@ -443,20 +456,6 @@ impl Drop for InflightGuard<'_> {
} }
} }
/// Build a wire-format DNS query packet for the given domain and type.
#[cfg(test)]
fn build_wire_query(id: u16, domain: &str, qtype: QueryType) -> BytePacketBuffer {
let mut pkt = DnsPacket::new();
pkt.header.id = id;
pkt.header.recursion_desired = true;
pkt.header.questions = 1;
pkt.questions
.push(crate::question::DnsQuestion::new(domain.to_string(), qtype));
let mut buf = BytePacketBuffer::new();
pkt.write(&mut buf).unwrap();
BytePacketBuffer::from_bytes(buf.filled())
}
fn special_use_response(query: &DnsPacket, qname: &str, qtype: QueryType) -> DnsPacket { fn special_use_response(query: &DnsPacket, qname: &str, qtype: QueryType) -> DnsPacket {
use std::net::{Ipv4Addr, Ipv6Addr}; use std::net::{Ipv4Addr, Ipv6Addr};
if qname == "ipv4only.arpa" { if qname == "ipv4only.arpa" {
@@ -495,8 +494,8 @@ fn special_use_response(query: &DnsPacket, qname: &str, qtype: QueryType) -> Dns
mod tests { mod tests {
use super::*; use super::*;
use std::collections::HashMap; use std::collections::HashMap;
use std::net::{Ipv4Addr, SocketAddr}; use std::net::Ipv4Addr;
use std::sync::{Arc, Mutex, RwLock}; use std::sync::{Arc, Mutex};
use tokio::sync::broadcast; use tokio::sync::broadcast;
// ---- InflightGuard unit tests ---- // ---- InflightGuard unit tests ----
@@ -669,189 +668,157 @@ mod tests {
} }
} }
// ---- Integration: concurrent handle_query coalescing ---- // ---- Integration: resolve_coalesced with mock futures ----
use tokio::io::{AsyncReadExt, AsyncWriteExt}; // ---- Integration: resolve_coalesced with mock futures ----
use tokio::net::TcpListener;
/// Spawn a slow TCP DNS server that delays `delay` before responding. fn mock_response(domain: &str) -> DnsPacket {
/// Returns (addr, query_count) where query_count is an Arc<AtomicU32> let mut resp = DnsPacket::new();
/// tracking how many queries were actually resolved (not coalesced). resp.header.response = true;
async fn spawn_slow_dns_server( resp.header.rescode = ResultCode::NOERROR;
delay: Duration, resp.answers.push(DnsRecord::A {
) -> (SocketAddr, Arc<std::sync::atomic::AtomicU32>) { domain: domain.to_string(),
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); addr: Ipv4Addr::new(10, 0, 0, 1),
let addr = listener.local_addr().unwrap(); ttl: 300,
let count = Arc::new(std::sync::atomic::AtomicU32::new(0));
let count_clone = count.clone();
tokio::spawn(async move {
loop {
let (mut stream, _) = match listener.accept().await {
Ok(c) => c,
Err(_) => break,
};
let count = count_clone.clone();
let delay = delay;
tokio::spawn(async move {
let mut len_buf = [0u8; 2];
if stream.read_exact(&mut len_buf).await.is_err() {
return;
}
let len = u16::from_be_bytes(len_buf) as usize;
let mut data = vec![0u8; len];
if stream.read_exact(&mut data).await.is_err() {
return;
}
let mut buf = BytePacketBuffer::from_bytes(&data);
let query = match DnsPacket::from_buffer(&mut buf) {
Ok(q) => q,
Err(_) => return,
};
count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
// Deliberate delay to create coalescing window
tokio::time::sleep(delay).await;
let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR);
resp.header.authoritative_answer = true;
if let Some(q) = query.questions.first() {
resp.answers.push(DnsRecord::A {
domain: q.name.clone(),
addr: Ipv4Addr::new(10, 0, 0, 1),
ttl: 300,
});
}
let mut resp_buf = BytePacketBuffer::new();
if resp.write(&mut resp_buf).is_err() {
return;
}
let resp_bytes = resp_buf.filled();
let mut out = Vec::with_capacity(2 + resp_bytes.len());
out.extend_from_slice(&(resp_bytes.len() as u16).to_be_bytes());
out.extend_from_slice(resp_bytes);
let _ = stream.write_all(&out).await;
});
}
}); });
(addr, count) resp
}
async fn test_recursive_ctx(root_hint: SocketAddr) -> Arc<ServerCtx> {
let socket = tokio::net::UdpSocket::bind("127.0.0.1:0").await.unwrap();
Arc::new(ServerCtx {
socket,
zone_map: HashMap::new(),
cache: RwLock::new(crate::cache::DnsCache::new(100, 60, 86400)),
stats: Mutex::new(crate::stats::ServerStats::new()),
overrides: RwLock::new(crate::override_store::OverrideStore::new()),
blocklist: RwLock::new(crate::blocklist::BlocklistStore::new()),
query_log: Mutex::new(crate::query_log::QueryLog::new(100)),
services: Mutex::new(crate::service_store::ServiceStore::new()),
lan_peers: Mutex::new(crate::lan::PeerStore::new(90)),
forwarding_rules: Vec::new(),
upstream: Mutex::new(crate::forward::Upstream::Udp(
"127.0.0.1:53".parse().unwrap(),
)),
upstream_auto: false,
upstream_port: 53,
lan_ip: Mutex::new(Ipv4Addr::LOCALHOST),
timeout: Duration::from_secs(3),
proxy_tld: "numa".to_string(),
proxy_tld_suffix: ".numa".to_string(),
lan_enabled: false,
config_path: "/tmp/test-numa.toml".to_string(),
config_found: false,
config_dir: std::path::PathBuf::from("/tmp"),
data_dir: std::path::PathBuf::from("/tmp"),
tls_config: None,
upstream_mode: crate::config::UpstreamMode::Recursive,
root_hints: vec![root_hint],
srtt: RwLock::new(crate::srtt::SrttCache::new(true)),
inflight: Mutex::new(HashMap::new()),
dnssec_enabled: false,
dnssec_strict: false,
})
} }
#[tokio::test] #[tokio::test]
async fn concurrent_queries_coalesce_to_single_resolution() { async fn concurrent_queries_coalesce_to_single_resolution() {
// Force TCP-only so mock server works let inflight = Arc::new(Mutex::new(HashMap::new()));
crate::recursive::UDP_DISABLED.store(true, std::sync::atomic::Ordering::Release); let resolve_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
let (server_addr, query_count) = spawn_slow_dns_server(Duration::from_millis(200)).await;
let ctx = test_recursive_ctx(server_addr).await;
let src: SocketAddr = "127.0.0.1:9999".parse().unwrap();
// Fire 5 concurrent queries for the same (domain, A)
let mut handles = Vec::new(); let mut handles = Vec::new();
for i in 0..5u16 { for i in 0..5u16 {
let ctx = ctx.clone(); let count = resolve_count.clone();
let buf = build_wire_query(100 + i, "coalesce-test.example.com", QueryType::A); let inf = inflight.clone();
handles.push(tokio::spawn( let key = ("coalesce.test".to_string(), QueryType::A);
async move { handle_query(buf, src, &ctx).await }, handles.push(tokio::spawn(async move {
)); resolve_coalesced(&inf, key, 100 + i, || async {
count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
tokio::time::sleep(Duration::from_millis(200)).await;
Ok(mock_response("coalesce.test"))
})
.await
}));
} }
let mut paths = Vec::new();
for h in handles { for h in handles {
h.await.unwrap().unwrap(); let (_, path) = h.await.unwrap();
paths.push(path);
} }
// Only 1 resolution should have reached the upstream server let actual = resolve_count.load(std::sync::atomic::Ordering::Relaxed);
let actual = query_count.load(std::sync::atomic::Ordering::Relaxed); assert_eq!(actual, 1, "expected 1 resolution, got {}", actual);
assert_eq!(actual, 1, "expected 1 upstream query, got {}", actual);
// Inflight map must be empty after all queries complete let recursive = paths.iter().filter(|p| **p == QueryPath::Recursive).count();
assert!(ctx.inflight.lock().unwrap().is_empty()); let coalesced = paths.iter().filter(|p| **p == QueryPath::Coalesced).count();
assert_eq!(recursive, 1, "expected 1 RECURSIVE, got {}", recursive);
assert_eq!(coalesced, 4, "expected 4 COALESCED, got {}", coalesced);
crate::recursive::reset_udp_state(); assert!(inflight.lock().unwrap().is_empty());
} }
#[tokio::test] #[tokio::test]
async fn different_qtypes_not_coalesced() { async fn different_qtypes_not_coalesced() {
crate::recursive::UDP_DISABLED.store(true, std::sync::atomic::Ordering::Release); let inflight = Arc::new(Mutex::new(HashMap::new()));
let resolve_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
let (server_addr, query_count) = spawn_slow_dns_server(Duration::from_millis(100)).await; let inf1 = inflight.clone();
let ctx = test_recursive_ctx(server_addr).await; let inf2 = inflight.clone();
let src: SocketAddr = "127.0.0.1:9999".parse().unwrap(); let count1 = resolve_count.clone();
let count2 = resolve_count.clone();
// Fire A and AAAA concurrently — should NOT coalesce let h1 = tokio::spawn(async move {
let ctx_ref = ctx.clone(); resolve_coalesced(
let ctx_ref2 = ctx.clone(); &inf1,
let buf_a = build_wire_query(200, "different-qt.example.com", QueryType::A); ("same.domain".to_string(), QueryType::A),
let buf_aaaa = build_wire_query(201, "different-qt.example.com", QueryType::AAAA); 200,
|| async {
count1.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
tokio::time::sleep(Duration::from_millis(100)).await;
Ok(mock_response("same.domain"))
},
)
.await
});
let h2 = tokio::spawn(async move {
resolve_coalesced(
&inf2,
("same.domain".to_string(), QueryType::AAAA),
201,
|| async {
count2.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
tokio::time::sleep(Duration::from_millis(100)).await;
Ok(mock_response("same.domain"))
},
)
.await
});
let h1 = tokio::spawn(async move { handle_query(buf_a, src, &ctx_ref).await }); let (_, path1) = h1.await.unwrap();
let h2 = tokio::spawn(async move { handle_query(buf_aaaa, src, &ctx_ref2).await }); let (_, path2) = h2.await.unwrap();
h1.await.unwrap().unwrap(); let actual = resolve_count.load(std::sync::atomic::Ordering::Relaxed);
h2.await.unwrap().unwrap(); assert_eq!(actual, 2, "A and AAAA should each resolve, got {}", actual);
assert_eq!(path1, QueryPath::Recursive);
assert_eq!(path2, QueryPath::Recursive);
let actual = query_count.load(std::sync::atomic::Ordering::Relaxed); assert!(inflight.lock().unwrap().is_empty());
assert!(
actual >= 2,
"A and AAAA should resolve independently, got {}",
actual
);
assert!(ctx.inflight.lock().unwrap().is_empty());
crate::recursive::reset_udp_state();
} }
#[tokio::test] #[tokio::test]
async fn inflight_map_cleaned_after_upstream_error() { async fn inflight_map_cleaned_after_error() {
// Server that rejects everything — no server running at all let inflight: Mutex<InflightMap> = Mutex::new(HashMap::new());
let bogus_addr: SocketAddr = "127.0.0.1:1".parse().unwrap();
let ctx = test_recursive_ctx(bogus_addr).await;
let src: SocketAddr = "127.0.0.1:9999".parse().unwrap();
let buf = build_wire_query(300, "will-fail.example.com", QueryType::A); let (_, path) = resolve_coalesced(
let _ = handle_query(buf, src, &ctx).await; &inflight,
("will-fail.test".to_string(), QueryType::A),
300,
|| async { Err::<DnsPacket, _>("upstream timeout".into()) },
)
.await;
// Map must be clean even after error assert_eq!(path, QueryPath::UpstreamError);
assert!(ctx.inflight.lock().unwrap().is_empty()); assert!(inflight.lock().unwrap().is_empty());
}
#[tokio::test]
async fn follower_gets_servfail_when_leader_fails() {
let inflight = Arc::new(Mutex::new(HashMap::new()));
let mut handles = Vec::new();
for i in 0..3u16 {
let inf = inflight.clone();
handles.push(tokio::spawn(async move {
resolve_coalesced(
&inf,
("fail.test".to_string(), QueryType::A),
400 + i,
|| async {
tokio::time::sleep(Duration::from_millis(200)).await;
Err::<DnsPacket, _>("upstream error".into())
},
)
.await
}));
}
let mut paths = Vec::new();
for h in handles {
let (resp, path) = h.await.unwrap();
assert_eq!(resp.header.rescode, ResultCode::SERVFAIL);
paths.push(path);
}
let errors = paths
.iter()
.filter(|p| **p == QueryPath::UpstreamError)
.count();
assert_eq!(errors, 3, "all 3 should be UpstreamError, got {}", errors);
assert!(inflight.lock().unwrap().is_empty());
} }
} }

View File

@@ -13,7 +13,7 @@ pub struct ServerStats {
started_at: Instant, started_at: Instant,
} }
#[derive(Clone, Copy, PartialEq, Eq)] #[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum QueryPath { pub enum QueryPath {
Local, Local,
Cached, Cached,