feat: in-flight query coalescing with COALESCED path #20

Merged
razvandimescu merged 5 commits from feat/query-coalescing into main 2026-03-29 15:36:02 +08:00
4 changed files with 86 additions and 21 deletions
Showing only changes of commit 1028c1a48d - Show all commits

View File

@@ -10,7 +10,7 @@ keywords = ["dns", "dns-server", "ad-blocking", "reverse-proxy", "developer-tool
categories = ["network-programming", "development-tools"]
[dependencies]
tokio = { version = "1", features = ["rt-multi-thread", "macros", "net", "time"] }
tokio = { version = "1", features = ["rt-multi-thread", "macros", "net", "time", "sync"] }
axum = "0.8"
serde = { version = "1", features = ["derive"] }
serde_json = "1"

View File

@@ -953,6 +953,7 @@ mod tests {
upstream_mode: crate::config::UpstreamMode::Forward,
root_hints: Vec::new(),
srtt: RwLock::new(crate::srtt::SrttCache::new(true)),
inflight: Mutex::new(std::collections::HashMap::new()),
dnssec_enabled: false,
dnssec_strict: false,
})

View File

@@ -1,3 +1,4 @@
use std::collections::HashMap;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::{Mutex, RwLock};
@@ -7,6 +8,7 @@ use arc_swap::ArcSwap;
use log::{debug, error, info, warn};
use rustls::ServerConfig;
use tokio::net::UdpSocket;
use tokio::sync::broadcast;
use crate::blocklist::BlocklistStore;
use crate::buffer::BytePacketBuffer;
@@ -53,6 +55,7 @@ pub struct ServerCtx {
pub upstream_mode: UpstreamMode,
pub root_hints: Vec<SocketAddr>,
pub srtt: RwLock<SrttCache>,
pub inflight: Mutex<HashMap<(String, QueryType), broadcast::Sender<Option<DnsPacket>>>>,
pub dnssec_enabled: bool,
pub dnssec_strict: bool,
}
@@ -172,27 +175,76 @@ pub async fn handle_query(
}
(resp, QueryPath::Cached, cached_dnssec)
} else if ctx.upstream_mode == UpstreamMode::Recursive {
match crate::recursive::resolve_recursive(
&qname,
qtype,
&ctx.cache,
&query,
&ctx.root_hints,
&ctx.srtt,
)
.await
{
Ok(resp) => (resp, QueryPath::Recursive, DnssecStatus::Indeterminate),
Err(e) => {
error!(
"{} | {:?} {} | RECURSIVE ERROR | {}",
src_addr, qtype, qname, e
);
(
DnsPacket::response_from(&query, ResultCode::SERVFAIL),
QueryPath::UpstreamError,
DnssecStatus::Indeterminate,
let key = (qname.clone(), qtype);
enum Disposition {
Leader(broadcast::Sender<Option<DnsPacket>>),
Follower(broadcast::Receiver<Option<DnsPacket>>),
}
let disposition = {
let mut inflight = ctx.inflight.lock().unwrap();
if let Some(tx) = inflight.get(&key) {
Disposition::Follower(tx.subscribe())
} else {
let (tx, _) = broadcast::channel::<Option<DnsPacket>>(1);
inflight.insert(key.clone(), tx.clone());
Disposition::Leader(tx)
}
};
match disposition {
Disposition::Follower(mut rx) => {
debug!("{} | {:?} {} | COALESCED", src_addr, qtype, qname);
match rx.recv().await {
Ok(Some(mut resp)) => {
resp.header.id = query.header.id;
(resp, QueryPath::Recursive, DnssecStatus::Indeterminate)
}
_ => (
DnsPacket::response_from(&query, ResultCode::SERVFAIL),
QueryPath::UpstreamError,
DnssecStatus::Indeterminate,
),
}
}
Disposition::Leader(tx) => {
// Drop guard: remove inflight entry even on panic/cancellation
let guard = InflightGuard {
inflight: &ctx.inflight,
key: key.clone(),
};
let result = crate::recursive::resolve_recursive(
&qname,
qtype,
&ctx.cache,
&query,
&ctx.root_hints,
&ctx.srtt,
)
.await;
drop(guard);
match result {
Ok(resp) => {
let _ = tx.send(Some(resp.clone()));
(resp, QueryPath::Recursive, DnssecStatus::Indeterminate)
}
Err(e) => {
let _ = tx.send(None);
error!(
"{} | {:?} {} | RECURSIVE ERROR | {}",
src_addr, qtype, qname, e
);
(
DnsPacket::response_from(&query, ResultCode::SERVFAIL),
QueryPath::UpstreamError,
DnssecStatus::Indeterminate,
)
}
}
}
}
} else {
@@ -377,6 +429,17 @@ fn is_special_use_domain(qname: &str) -> bool {
qname == "local" || qname.ends_with(".local")
}
struct InflightGuard<'a> {
inflight: &'a Mutex<HashMap<(String, QueryType), broadcast::Sender<Option<DnsPacket>>>>,
key: (String, QueryType),
}
impl Drop for InflightGuard<'_> {
fn drop(&mut self) {
self.inflight.lock().unwrap().remove(&self.key);
}
}
fn special_use_response(query: &DnsPacket, qname: &str, qtype: QueryType) -> DnsPacket {
use std::net::{Ipv4Addr, Ipv6Addr};
if qname == "ipv4only.arpa" {

View File

@@ -202,6 +202,7 @@ async fn main() -> numa::Result<()> {
upstream_mode: config.upstream.mode,
root_hints: numa::recursive::parse_root_hints(&config.upstream.root_hints),
srtt: std::sync::RwLock::new(numa::srtt::SrttCache::new(config.upstream.srtt)),
inflight: std::sync::Mutex::new(std::collections::HashMap::new()),
dnssec_enabled: config.dnssec.enabled,
dnssec_strict: config.dnssec.strict,
});