From 1028c1a48d0fddb9e2afff942bfdd221e2f57851 Mon Sep 17 00:00:00 2001 From: Razvan Dimescu Date: Sun, 29 Mar 2026 08:54:04 +0300 Subject: [PATCH] feat: in-flight query coalescing for recursive resolver MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When multiple queries for the same (domain, qtype) arrive concurrently and all miss the cache, only the first triggers recursive resolution. Subsequent queries wait on a broadcast channel for the result. Prevents thundering herd where N concurrent cache misses each independently walk the full NS chain, compounding timeouts. Uses InflightGuard (Drop impl) to guarantee map cleanup on panic/cancellation — prevents permanent SERVFAIL poisoning. Co-Authored-By: Claude Opus 4.6 (1M context) --- Cargo.toml | 2 +- src/api.rs | 1 + src/ctx.rs | 103 ++++++++++++++++++++++++++++++++++++++++++---------- src/main.rs | 1 + 4 files changed, 86 insertions(+), 21 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 62322ff..89c6cfe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ keywords = ["dns", "dns-server", "ad-blocking", "reverse-proxy", "developer-tool categories = ["network-programming", "development-tools"] [dependencies] -tokio = { version = "1", features = ["rt-multi-thread", "macros", "net", "time"] } +tokio = { version = "1", features = ["rt-multi-thread", "macros", "net", "time", "sync"] } axum = "0.8" serde = { version = "1", features = ["derive"] } serde_json = "1" diff --git a/src/api.rs b/src/api.rs index 9826a3a..2fcb277 100644 --- a/src/api.rs +++ b/src/api.rs @@ -953,6 +953,7 @@ mod tests { upstream_mode: crate::config::UpstreamMode::Forward, root_hints: Vec::new(), srtt: RwLock::new(crate::srtt::SrttCache::new(true)), + inflight: Mutex::new(std::collections::HashMap::new()), dnssec_enabled: false, dnssec_strict: false, }) diff --git a/src/ctx.rs b/src/ctx.rs index 2a527ee..c102ab9 100644 --- a/src/ctx.rs +++ b/src/ctx.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::net::SocketAddr; use std::path::PathBuf; use std::sync::{Mutex, RwLock}; @@ -7,6 +8,7 @@ use arc_swap::ArcSwap; use log::{debug, error, info, warn}; use rustls::ServerConfig; use tokio::net::UdpSocket; +use tokio::sync::broadcast; use crate::blocklist::BlocklistStore; use crate::buffer::BytePacketBuffer; @@ -53,6 +55,7 @@ pub struct ServerCtx { pub upstream_mode: UpstreamMode, pub root_hints: Vec, pub srtt: RwLock, + pub inflight: Mutex>>>, pub dnssec_enabled: bool, pub dnssec_strict: bool, } @@ -172,27 +175,76 @@ pub async fn handle_query( } (resp, QueryPath::Cached, cached_dnssec) } else if ctx.upstream_mode == UpstreamMode::Recursive { - match crate::recursive::resolve_recursive( - &qname, - qtype, - &ctx.cache, - &query, - &ctx.root_hints, - &ctx.srtt, - ) - .await - { - Ok(resp) => (resp, QueryPath::Recursive, DnssecStatus::Indeterminate), - Err(e) => { - error!( - "{} | {:?} {} | RECURSIVE ERROR | {}", - src_addr, qtype, qname, e - ); - ( - DnsPacket::response_from(&query, ResultCode::SERVFAIL), - QueryPath::UpstreamError, - DnssecStatus::Indeterminate, + let key = (qname.clone(), qtype); + + enum Disposition { + Leader(broadcast::Sender>), + Follower(broadcast::Receiver>), + } + + let disposition = { + let mut inflight = ctx.inflight.lock().unwrap(); + if let Some(tx) = inflight.get(&key) { + Disposition::Follower(tx.subscribe()) + } else { + let (tx, _) = broadcast::channel::>(1); + inflight.insert(key.clone(), tx.clone()); + Disposition::Leader(tx) + } + }; + + match disposition { + Disposition::Follower(mut rx) => { + debug!("{} | {:?} {} | COALESCED", src_addr, qtype, qname); + match rx.recv().await { + Ok(Some(mut resp)) => { + resp.header.id = query.header.id; + (resp, QueryPath::Recursive, DnssecStatus::Indeterminate) + } + _ => ( + DnsPacket::response_from(&query, ResultCode::SERVFAIL), + QueryPath::UpstreamError, + DnssecStatus::Indeterminate, + ), + } + } + Disposition::Leader(tx) => { + // Drop guard: remove inflight entry even on panic/cancellation + let guard = InflightGuard { + inflight: &ctx.inflight, + key: key.clone(), + }; + + let result = crate::recursive::resolve_recursive( + &qname, + qtype, + &ctx.cache, + &query, + &ctx.root_hints, + &ctx.srtt, ) + .await; + + drop(guard); + + match result { + Ok(resp) => { + let _ = tx.send(Some(resp.clone())); + (resp, QueryPath::Recursive, DnssecStatus::Indeterminate) + } + Err(e) => { + let _ = tx.send(None); + error!( + "{} | {:?} {} | RECURSIVE ERROR | {}", + src_addr, qtype, qname, e + ); + ( + DnsPacket::response_from(&query, ResultCode::SERVFAIL), + QueryPath::UpstreamError, + DnssecStatus::Indeterminate, + ) + } + } } } } else { @@ -377,6 +429,17 @@ fn is_special_use_domain(qname: &str) -> bool { qname == "local" || qname.ends_with(".local") } +struct InflightGuard<'a> { + inflight: &'a Mutex>>>, + key: (String, QueryType), +} + +impl Drop for InflightGuard<'_> { + fn drop(&mut self) { + self.inflight.lock().unwrap().remove(&self.key); + } +} + fn special_use_response(query: &DnsPacket, qname: &str, qtype: QueryType) -> DnsPacket { use std::net::{Ipv4Addr, Ipv6Addr}; if qname == "ipv4only.arpa" { diff --git a/src/main.rs b/src/main.rs index 6ba3de2..3066fdd 100644 --- a/src/main.rs +++ b/src/main.rs @@ -202,6 +202,7 @@ async fn main() -> numa::Result<()> { upstream_mode: config.upstream.mode, root_hints: numa::recursive::parse_root_hints(&config.upstream.root_hints), srtt: std::sync::RwLock::new(numa::srtt::SrttCache::new(config.upstream.srtt)), + inflight: std::sync::Mutex::new(std::collections::HashMap::new()), dnssec_enabled: config.dnssec.enabled, dnssec_strict: config.dnssec.strict, });