feat: in-flight query coalescing with COALESCED path #20

Merged
razvandimescu merged 5 commits from feat/query-coalescing into main 2026-03-29 15:36:02 +08:00
4 changed files with 86 additions and 21 deletions
Showing only changes of commit 1028c1a48d - Show all commits

View File

@@ -10,7 +10,7 @@ keywords = ["dns", "dns-server", "ad-blocking", "reverse-proxy", "developer-tool
categories = ["network-programming", "development-tools"] categories = ["network-programming", "development-tools"]
[dependencies] [dependencies]
tokio = { version = "1", features = ["rt-multi-thread", "macros", "net", "time"] } tokio = { version = "1", features = ["rt-multi-thread", "macros", "net", "time", "sync"] }
axum = "0.8" axum = "0.8"
serde = { version = "1", features = ["derive"] } serde = { version = "1", features = ["derive"] }
serde_json = "1" serde_json = "1"

View File

@@ -953,6 +953,7 @@ mod tests {
upstream_mode: crate::config::UpstreamMode::Forward, upstream_mode: crate::config::UpstreamMode::Forward,
root_hints: Vec::new(), root_hints: Vec::new(),
srtt: RwLock::new(crate::srtt::SrttCache::new(true)), srtt: RwLock::new(crate::srtt::SrttCache::new(true)),
inflight: Mutex::new(std::collections::HashMap::new()),
dnssec_enabled: false, dnssec_enabled: false,
dnssec_strict: false, dnssec_strict: false,
}) })

View File

@@ -1,3 +1,4 @@
use std::collections::HashMap;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::{Mutex, RwLock}; use std::sync::{Mutex, RwLock};
@@ -7,6 +8,7 @@ use arc_swap::ArcSwap;
use log::{debug, error, info, warn}; use log::{debug, error, info, warn};
use rustls::ServerConfig; use rustls::ServerConfig;
use tokio::net::UdpSocket; use tokio::net::UdpSocket;
use tokio::sync::broadcast;
use crate::blocklist::BlocklistStore; use crate::blocklist::BlocklistStore;
use crate::buffer::BytePacketBuffer; use crate::buffer::BytePacketBuffer;
@@ -53,6 +55,7 @@ pub struct ServerCtx {
pub upstream_mode: UpstreamMode, pub upstream_mode: UpstreamMode,
pub root_hints: Vec<SocketAddr>, pub root_hints: Vec<SocketAddr>,
pub srtt: RwLock<SrttCache>, pub srtt: RwLock<SrttCache>,
pub inflight: Mutex<HashMap<(String, QueryType), broadcast::Sender<Option<DnsPacket>>>>,
pub dnssec_enabled: bool, pub dnssec_enabled: bool,
pub dnssec_strict: bool, pub dnssec_strict: bool,
} }
@@ -172,27 +175,76 @@ pub async fn handle_query(
} }
(resp, QueryPath::Cached, cached_dnssec) (resp, QueryPath::Cached, cached_dnssec)
} else if ctx.upstream_mode == UpstreamMode::Recursive { } else if ctx.upstream_mode == UpstreamMode::Recursive {
match crate::recursive::resolve_recursive( let key = (qname.clone(), qtype);
&qname,
qtype, enum Disposition {
&ctx.cache, Leader(broadcast::Sender<Option<DnsPacket>>),
&query, Follower(broadcast::Receiver<Option<DnsPacket>>),
&ctx.root_hints, }
&ctx.srtt,
) let disposition = {
.await let mut inflight = ctx.inflight.lock().unwrap();
{ if let Some(tx) = inflight.get(&key) {
Ok(resp) => (resp, QueryPath::Recursive, DnssecStatus::Indeterminate), Disposition::Follower(tx.subscribe())
Err(e) => { } else {
error!( let (tx, _) = broadcast::channel::<Option<DnsPacket>>(1);
"{} | {:?} {} | RECURSIVE ERROR | {}", inflight.insert(key.clone(), tx.clone());
src_addr, qtype, qname, e Disposition::Leader(tx)
); }
( };
DnsPacket::response_from(&query, ResultCode::SERVFAIL),
QueryPath::UpstreamError, match disposition {
DnssecStatus::Indeterminate, Disposition::Follower(mut rx) => {
debug!("{} | {:?} {} | COALESCED", src_addr, qtype, qname);
match rx.recv().await {
Ok(Some(mut resp)) => {
resp.header.id = query.header.id;
(resp, QueryPath::Recursive, DnssecStatus::Indeterminate)
}
_ => (
DnsPacket::response_from(&query, ResultCode::SERVFAIL),
QueryPath::UpstreamError,
DnssecStatus::Indeterminate,
),
}
}
Disposition::Leader(tx) => {
// Drop guard: remove inflight entry even on panic/cancellation
let guard = InflightGuard {
inflight: &ctx.inflight,
key: key.clone(),
};
let result = crate::recursive::resolve_recursive(
&qname,
qtype,
&ctx.cache,
&query,
&ctx.root_hints,
&ctx.srtt,
) )
.await;
drop(guard);
match result {
Ok(resp) => {
let _ = tx.send(Some(resp.clone()));
(resp, QueryPath::Recursive, DnssecStatus::Indeterminate)
}
Err(e) => {
let _ = tx.send(None);
error!(
"{} | {:?} {} | RECURSIVE ERROR | {}",
src_addr, qtype, qname, e
);
(
DnsPacket::response_from(&query, ResultCode::SERVFAIL),
QueryPath::UpstreamError,
DnssecStatus::Indeterminate,
)
}
}
} }
} }
} else { } else {
@@ -377,6 +429,17 @@ fn is_special_use_domain(qname: &str) -> bool {
qname == "local" || qname.ends_with(".local") qname == "local" || qname.ends_with(".local")
} }
struct InflightGuard<'a> {
inflight: &'a Mutex<HashMap<(String, QueryType), broadcast::Sender<Option<DnsPacket>>>>,
key: (String, QueryType),
}
impl Drop for InflightGuard<'_> {
fn drop(&mut self) {
self.inflight.lock().unwrap().remove(&self.key);
}
}
fn special_use_response(query: &DnsPacket, qname: &str, qtype: QueryType) -> DnsPacket { fn special_use_response(query: &DnsPacket, qname: &str, qtype: QueryType) -> DnsPacket {
use std::net::{Ipv4Addr, Ipv6Addr}; use std::net::{Ipv4Addr, Ipv6Addr};
if qname == "ipv4only.arpa" { if qname == "ipv4only.arpa" {

View File

@@ -202,6 +202,7 @@ async fn main() -> numa::Result<()> {
upstream_mode: config.upstream.mode, upstream_mode: config.upstream.mode,
root_hints: numa::recursive::parse_root_hints(&config.upstream.root_hints), root_hints: numa::recursive::parse_root_hints(&config.upstream.root_hints),
srtt: std::sync::RwLock::new(numa::srtt::SrttCache::new(config.upstream.srtt)), srtt: std::sync::RwLock::new(numa::srtt::SrttCache::new(config.upstream.srtt)),
inflight: std::sync::Mutex::new(std::collections::HashMap::new()),
dnssec_enabled: config.dnssec.enabled, dnssec_enabled: config.dnssec.enabled,
dnssec_strict: config.dnssec.strict, dnssec_strict: config.dnssec.strict,
}); });