diff --git a/src/ctx.rs b/src/ctx.rs index 5aee946..fbddb15 100644 --- a/src/ctx.rs +++ b/src/ctx.rs @@ -178,62 +178,29 @@ pub async fn handle_query( (resp, QueryPath::Cached, cached_dnssec) } else if ctx.upstream_mode == UpstreamMode::Recursive { let key = (qname.clone(), qtype); - let disposition = acquire_inflight(&ctx.inflight, key.clone()); - - match disposition { - Disposition::Follower(mut rx) => { - debug!("{} | {:?} {} | COALESCED", src_addr, qtype, qname); - match rx.recv().await { - Ok(Some(mut resp)) => { - resp.header.id = query.header.id; - (resp, QueryPath::Coalesced, DnssecStatus::Indeterminate) - } - _ => ( - DnsPacket::response_from(&query, ResultCode::SERVFAIL), - QueryPath::UpstreamError, - DnssecStatus::Indeterminate, - ), - } - } - Disposition::Leader(tx) => { - // Drop guard: remove inflight entry even on panic/cancellation - let guard = InflightGuard { - inflight: &ctx.inflight, - key: key.clone(), - }; - - let result = crate::recursive::resolve_recursive( - &qname, - qtype, - &ctx.cache, - &query, - &ctx.root_hints, - &ctx.srtt, - ) - .await; - - drop(guard); - - match result { - Ok(resp) => { - let _ = tx.send(Some(resp.clone())); - (resp, QueryPath::Recursive, DnssecStatus::Indeterminate) - } - Err(e) => { - let _ = tx.send(None); - error!( - "{} | {:?} {} | RECURSIVE ERROR | {}", - src_addr, qtype, qname, e - ); - ( - DnsPacket::response_from(&query, ResultCode::SERVFAIL), - QueryPath::UpstreamError, - DnssecStatus::Indeterminate, - ) - } - } - } + let (resp, path, err) = resolve_coalesced(&ctx.inflight, key, &query, || { + crate::recursive::resolve_recursive( + &qname, + qtype, + &ctx.cache, + &query, + &ctx.root_hints, + &ctx.srtt, + ) + }) + .await; + if path == QueryPath::Coalesced { + debug!("{} | {:?} {} | COALESCED", src_addr, qtype, qname); + } else if path == QueryPath::UpstreamError { + error!( + "{} | {:?} {} | RECURSIVE ERROR | {}", + src_addr, + qtype, + qname, + err.as_deref().unwrap_or("leader failed") + ); } + (resp, path, DnssecStatus::Indeterminate) } else { let upstream = match crate::system_dns::match_forwarding_rule(&qname, &ctx.forwarding_rules) { @@ -432,6 +399,57 @@ fn acquire_inflight(inflight: &Mutex, key: (String, QueryType)) -> } } +/// Run a resolve function with in-flight coalescing. Multiple concurrent calls +/// for the same key share a single resolution — the first caller (leader) +/// executes `resolve_fn`, and followers wait for the broadcast result. +async fn resolve_coalesced( + inflight: &Mutex, + key: (String, QueryType), + query: &DnsPacket, + resolve_fn: F, +) -> (DnsPacket, QueryPath, Option) +where + F: FnOnce() -> Fut, + Fut: std::future::Future>, +{ + let disposition = acquire_inflight(inflight, key.clone()); + + match disposition { + Disposition::Follower(mut rx) => match rx.recv().await { + Ok(Some(mut resp)) => { + resp.header.id = query.header.id; + (resp, QueryPath::Coalesced, None) + } + _ => ( + DnsPacket::response_from(query, ResultCode::SERVFAIL), + QueryPath::UpstreamError, + None, + ), + }, + Disposition::Leader(tx) => { + let guard = InflightGuard { inflight, key }; + let result = resolve_fn().await; + drop(guard); + + match result { + Ok(resp) => { + let _ = tx.send(Some(resp.clone())); + (resp, QueryPath::Recursive, None) + } + Err(e) => { + let _ = tx.send(None); + let err_msg = e.to_string(); + ( + DnsPacket::response_from(query, ResultCode::SERVFAIL), + QueryPath::UpstreamError, + Some(err_msg), + ) + } + } + } + } +} + struct InflightGuard<'a> { inflight: &'a Mutex, key: (String, QueryType), @@ -443,20 +461,6 @@ impl Drop for InflightGuard<'_> { } } -/// Build a wire-format DNS query packet for the given domain and type. -#[cfg(test)] -fn build_wire_query(id: u16, domain: &str, qtype: QueryType) -> BytePacketBuffer { - let mut pkt = DnsPacket::new(); - pkt.header.id = id; - pkt.header.recursion_desired = true; - pkt.header.questions = 1; - pkt.questions - .push(crate::question::DnsQuestion::new(domain.to_string(), qtype)); - let mut buf = BytePacketBuffer::new(); - pkt.write(&mut buf).unwrap(); - BytePacketBuffer::from_bytes(buf.filled()) -} - fn special_use_response(query: &DnsPacket, qname: &str, qtype: QueryType) -> DnsPacket { use std::net::{Ipv4Addr, Ipv6Addr}; if qname == "ipv4only.arpa" { @@ -495,8 +499,8 @@ fn special_use_response(query: &DnsPacket, qname: &str, qtype: QueryType) -> Dns mod tests { use super::*; use std::collections::HashMap; - use std::net::{Ipv4Addr, SocketAddr}; - use std::sync::{Arc, Mutex, RwLock}; + use std::net::Ipv4Addr; + use std::sync::{Arc, Mutex}; use tokio::sync::broadcast; // ---- InflightGuard unit tests ---- @@ -669,189 +673,221 @@ mod tests { } } - // ---- Integration: concurrent handle_query coalescing ---- + // ---- Integration: resolve_coalesced with mock futures ---- - use tokio::io::{AsyncReadExt, AsyncWriteExt}; - use tokio::net::TcpListener; - - /// Spawn a slow TCP DNS server that delays `delay` before responding. - /// Returns (addr, query_count) where query_count is an Arc - /// tracking how many queries were actually resolved (not coalesced). - async fn spawn_slow_dns_server( - delay: Duration, - ) -> (SocketAddr, Arc) { - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); - let count = Arc::new(std::sync::atomic::AtomicU32::new(0)); - let count_clone = count.clone(); - - tokio::spawn(async move { - loop { - let (mut stream, _) = match listener.accept().await { - Ok(c) => c, - Err(_) => break, - }; - let count = count_clone.clone(); - let delay = delay; - tokio::spawn(async move { - let mut len_buf = [0u8; 2]; - if stream.read_exact(&mut len_buf).await.is_err() { - return; - } - let len = u16::from_be_bytes(len_buf) as usize; - let mut data = vec![0u8; len]; - if stream.read_exact(&mut data).await.is_err() { - return; - } - - let mut buf = BytePacketBuffer::from_bytes(&data); - let query = match DnsPacket::from_buffer(&mut buf) { - Ok(q) => q, - Err(_) => return, - }; - - count.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - - // Deliberate delay to create coalescing window - tokio::time::sleep(delay).await; - - let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR); - resp.header.authoritative_answer = true; - if let Some(q) = query.questions.first() { - resp.answers.push(DnsRecord::A { - domain: q.name.clone(), - addr: Ipv4Addr::new(10, 0, 0, 1), - ttl: 300, - }); - } - - let mut resp_buf = BytePacketBuffer::new(); - if resp.write(&mut resp_buf).is_err() { - return; - } - let resp_bytes = resp_buf.filled(); - let mut out = Vec::with_capacity(2 + resp_bytes.len()); - out.extend_from_slice(&(resp_bytes.len() as u16).to_be_bytes()); - out.extend_from_slice(resp_bytes); - let _ = stream.write_all(&out).await; - }); - } - }); - (addr, count) + fn mock_query(id: u16, domain: &str, qtype: QueryType) -> DnsPacket { + let mut pkt = DnsPacket::new(); + pkt.header.id = id; + pkt.header.recursion_desired = true; + pkt.questions + .push(crate::question::DnsQuestion::new(domain.to_string(), qtype)); + pkt } - async fn test_recursive_ctx(root_hint: SocketAddr) -> Arc { - let socket = tokio::net::UdpSocket::bind("127.0.0.1:0").await.unwrap(); - Arc::new(ServerCtx { - socket, - zone_map: HashMap::new(), - cache: RwLock::new(crate::cache::DnsCache::new(100, 60, 86400)), - stats: Mutex::new(crate::stats::ServerStats::new()), - overrides: RwLock::new(crate::override_store::OverrideStore::new()), - blocklist: RwLock::new(crate::blocklist::BlocklistStore::new()), - query_log: Mutex::new(crate::query_log::QueryLog::new(100)), - services: Mutex::new(crate::service_store::ServiceStore::new()), - lan_peers: Mutex::new(crate::lan::PeerStore::new(90)), - forwarding_rules: Vec::new(), - upstream: Mutex::new(crate::forward::Upstream::Udp( - "127.0.0.1:53".parse().unwrap(), - )), - upstream_auto: false, - upstream_port: 53, - lan_ip: Mutex::new(Ipv4Addr::LOCALHOST), - timeout: Duration::from_secs(3), - proxy_tld: "numa".to_string(), - proxy_tld_suffix: ".numa".to_string(), - lan_enabled: false, - config_path: "/tmp/test-numa.toml".to_string(), - config_found: false, - config_dir: std::path::PathBuf::from("/tmp"), - data_dir: std::path::PathBuf::from("/tmp"), - tls_config: None, - upstream_mode: crate::config::UpstreamMode::Recursive, - root_hints: vec![root_hint], - srtt: RwLock::new(crate::srtt::SrttCache::new(true)), - inflight: Mutex::new(HashMap::new()), - dnssec_enabled: false, - dnssec_strict: false, - }) + fn mock_response(domain: &str) -> DnsPacket { + let mut resp = DnsPacket::new(); + resp.header.response = true; + resp.header.rescode = ResultCode::NOERROR; + resp.answers.push(DnsRecord::A { + domain: domain.to_string(), + addr: Ipv4Addr::new(10, 0, 0, 1), + ttl: 300, + }); + resp } #[tokio::test] async fn concurrent_queries_coalesce_to_single_resolution() { - // Force TCP-only so mock server works - crate::recursive::UDP_DISABLED.store(true, std::sync::atomic::Ordering::Release); + let inflight = Arc::new(Mutex::new(HashMap::new())); + let resolve_count = Arc::new(std::sync::atomic::AtomicU32::new(0)); - let (server_addr, query_count) = spawn_slow_dns_server(Duration::from_millis(200)).await; - let ctx = test_recursive_ctx(server_addr).await; - let src: SocketAddr = "127.0.0.1:9999".parse().unwrap(); - - // Fire 5 concurrent queries for the same (domain, A) let mut handles = Vec::new(); for i in 0..5u16 { - let ctx = ctx.clone(); - let buf = build_wire_query(100 + i, "coalesce-test.example.com", QueryType::A); - handles.push(tokio::spawn( - async move { handle_query(buf, src, &ctx).await }, - )); + let count = resolve_count.clone(); + let inf = inflight.clone(); + let key = ("coalesce.test".to_string(), QueryType::A); + let query = mock_query(100 + i, "coalesce.test", QueryType::A); + handles.push(tokio::spawn(async move { + resolve_coalesced(&inf, key, &query, || async { + count.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + tokio::time::sleep(Duration::from_millis(200)).await; + Ok(mock_response("coalesce.test")) + }) + .await + })); } + let mut paths = Vec::new(); for h in handles { - h.await.unwrap().unwrap(); + let (_, path, _) = h.await.unwrap(); + paths.push(path); } - // Only 1 resolution should have reached the upstream server - let actual = query_count.load(std::sync::atomic::Ordering::Relaxed); - assert_eq!(actual, 1, "expected 1 upstream query, got {}", actual); + let actual = resolve_count.load(std::sync::atomic::Ordering::Relaxed); + assert_eq!(actual, 1, "expected 1 resolution, got {}", actual); - // Inflight map must be empty after all queries complete - assert!(ctx.inflight.lock().unwrap().is_empty()); + let recursive = paths.iter().filter(|p| **p == QueryPath::Recursive).count(); + let coalesced = paths.iter().filter(|p| **p == QueryPath::Coalesced).count(); + assert_eq!(recursive, 1, "expected 1 RECURSIVE, got {}", recursive); + assert_eq!(coalesced, 4, "expected 4 COALESCED, got {}", coalesced); - crate::recursive::reset_udp_state(); + assert!(inflight.lock().unwrap().is_empty()); } #[tokio::test] async fn different_qtypes_not_coalesced() { - crate::recursive::UDP_DISABLED.store(true, std::sync::atomic::Ordering::Release); + let inflight = Arc::new(Mutex::new(HashMap::new())); + let resolve_count = Arc::new(std::sync::atomic::AtomicU32::new(0)); - let (server_addr, query_count) = spawn_slow_dns_server(Duration::from_millis(100)).await; - let ctx = test_recursive_ctx(server_addr).await; - let src: SocketAddr = "127.0.0.1:9999".parse().unwrap(); + let inf1 = inflight.clone(); + let inf2 = inflight.clone(); + let count1 = resolve_count.clone(); + let count2 = resolve_count.clone(); - // Fire A and AAAA concurrently — should NOT coalesce - let ctx_ref = ctx.clone(); - let ctx_ref2 = ctx.clone(); - let buf_a = build_wire_query(200, "different-qt.example.com", QueryType::A); - let buf_aaaa = build_wire_query(201, "different-qt.example.com", QueryType::AAAA); + let query_a = mock_query(200, "same.domain", QueryType::A); + let query_aaaa = mock_query(201, "same.domain", QueryType::AAAA); - let h1 = tokio::spawn(async move { handle_query(buf_a, src, &ctx_ref).await }); - let h2 = tokio::spawn(async move { handle_query(buf_aaaa, src, &ctx_ref2).await }); + let h1 = tokio::spawn(async move { + resolve_coalesced( + &inf1, + ("same.domain".to_string(), QueryType::A), + &query_a, + || async { + count1.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + tokio::time::sleep(Duration::from_millis(100)).await; + Ok(mock_response("same.domain")) + }, + ) + .await + }); + let h2 = tokio::spawn(async move { + resolve_coalesced( + &inf2, + ("same.domain".to_string(), QueryType::AAAA), + &query_aaaa, + || async { + count2.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + tokio::time::sleep(Duration::from_millis(100)).await; + Ok(mock_response("same.domain")) + }, + ) + .await + }); - h1.await.unwrap().unwrap(); - h2.await.unwrap().unwrap(); + let (_, path1, _) = h1.await.unwrap(); + let (_, path2, _) = h2.await.unwrap(); - let actual = query_count.load(std::sync::atomic::Ordering::Relaxed); - assert!( - actual >= 2, - "A and AAAA should resolve independently, got {}", - actual - ); - assert!(ctx.inflight.lock().unwrap().is_empty()); + let actual = resolve_count.load(std::sync::atomic::Ordering::Relaxed); + assert_eq!(actual, 2, "A and AAAA should each resolve, got {}", actual); + assert_eq!(path1, QueryPath::Recursive); + assert_eq!(path2, QueryPath::Recursive); - crate::recursive::reset_udp_state(); + assert!(inflight.lock().unwrap().is_empty()); } #[tokio::test] - async fn inflight_map_cleaned_after_upstream_error() { - // Server that rejects everything — no server running at all - let bogus_addr: SocketAddr = "127.0.0.1:1".parse().unwrap(); - let ctx = test_recursive_ctx(bogus_addr).await; - let src: SocketAddr = "127.0.0.1:9999".parse().unwrap(); + async fn inflight_map_cleaned_after_error() { + let inflight: Mutex = Mutex::new(HashMap::new()); + let query = mock_query(300, "will-fail.test", QueryType::A); - let buf = build_wire_query(300, "will-fail.example.com", QueryType::A); - let _ = handle_query(buf, src, &ctx).await; + let (_, path, _) = resolve_coalesced( + &inflight, + ("will-fail.test".to_string(), QueryType::A), + &query, + || async { Err::("upstream timeout".into()) }, + ) + .await; - // Map must be clean even after error - assert!(ctx.inflight.lock().unwrap().is_empty()); + assert_eq!(path, QueryPath::UpstreamError); + assert!(inflight.lock().unwrap().is_empty()); + } + + #[tokio::test] + async fn follower_gets_servfail_when_leader_fails() { + let inflight = Arc::new(Mutex::new(HashMap::new())); + + let mut handles = Vec::new(); + for i in 0..3u16 { + let inf = inflight.clone(); + let query = mock_query(400 + i, "fail.test", QueryType::A); + handles.push(tokio::spawn(async move { + resolve_coalesced( + &inf, + ("fail.test".to_string(), QueryType::A), + &query, + || async { + tokio::time::sleep(Duration::from_millis(200)).await; + Err::("upstream error".into()) + }, + ) + .await + })); + } + + let mut paths = Vec::new(); + for h in handles { + let (resp, path, _) = h.await.unwrap(); + assert_eq!(resp.header.rescode, ResultCode::SERVFAIL); + assert_eq!( + resp.questions.len(), + 1, + "SERVFAIL must echo question section" + ); + assert_eq!(resp.questions[0].name, "fail.test"); + paths.push(path); + } + + let errors = paths + .iter() + .filter(|p| **p == QueryPath::UpstreamError) + .count(); + assert_eq!(errors, 3, "all 3 should be UpstreamError, got {}", errors); + + assert!(inflight.lock().unwrap().is_empty()); + } + + #[tokio::test] + async fn servfail_leader_includes_question_section() { + let inflight: Mutex = Mutex::new(HashMap::new()); + let query = mock_query(500, "question.test", QueryType::A); + + let (resp, _, _) = resolve_coalesced( + &inflight, + ("question.test".to_string(), QueryType::A), + &query, + || async { Err::("fail".into()) }, + ) + .await; + + assert_eq!(resp.header.rescode, ResultCode::SERVFAIL); + assert_eq!( + resp.questions.len(), + 1, + "SERVFAIL must echo question section" + ); + assert_eq!(resp.questions[0].name, "question.test"); + assert_eq!(resp.questions[0].qtype, QueryType::A); + assert_eq!(resp.header.id, 500); + } + + #[tokio::test] + async fn leader_error_preserves_message() { + let inflight: Mutex = Mutex::new(HashMap::new()); + let query = mock_query(700, "err-msg.test", QueryType::A); + + let (_, path, err) = resolve_coalesced( + &inflight, + ("err-msg.test".to_string(), QueryType::A), + &query, + || async { Err::("connection refused by upstream".into()) }, + ) + .await; + + assert_eq!(path, QueryPath::UpstreamError); + assert_eq!( + err.as_deref(), + Some("connection refused by upstream"), + "error message must be preserved for logging" + ); } } diff --git a/src/stats.rs b/src/stats.rs index ea438e9..67ac56d 100644 --- a/src/stats.rs +++ b/src/stats.rs @@ -13,7 +13,7 @@ pub struct ServerStats { started_at: Instant, } -#[derive(Clone, Copy, PartialEq, Eq)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum QueryPath { Local, Cached,