feat: in-flight query coalescing with COALESCED path #20

Merged
razvandimescu merged 5 commits from feat/query-coalescing into main 2026-03-29 15:36:02 +08:00
4 changed files with 86 additions and 21 deletions
Showing only changes of commit 1028c1a48d - Show all commits

View File

@@ -10,7 +10,7 @@ keywords = ["dns", "dns-server", "ad-blocking", "reverse-proxy", "developer-tool
categories = ["network-programming", "development-tools"] categories = ["network-programming", "development-tools"]
[dependencies] [dependencies]
tokio = { version = "1", features = ["rt-multi-thread", "macros", "net", "time"] } tokio = { version = "1", features = ["rt-multi-thread", "macros", "net", "time", "sync"] }
axum = "0.8" axum = "0.8"
serde = { version = "1", features = ["derive"] } serde = { version = "1", features = ["derive"] }
serde_json = "1" serde_json = "1"

View File

@@ -953,6 +953,7 @@ mod tests {
upstream_mode: crate::config::UpstreamMode::Forward, upstream_mode: crate::config::UpstreamMode::Forward,
root_hints: Vec::new(), root_hints: Vec::new(),
srtt: RwLock::new(crate::srtt::SrttCache::new(true)), srtt: RwLock::new(crate::srtt::SrttCache::new(true)),
inflight: Mutex::new(std::collections::HashMap::new()),
dnssec_enabled: false, dnssec_enabled: false,
dnssec_strict: false, dnssec_strict: false,
}) })

View File

@@ -1,3 +1,4 @@
use std::collections::HashMap;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::{Mutex, RwLock}; use std::sync::{Mutex, RwLock};
@@ -7,6 +8,7 @@ use arc_swap::ArcSwap;
use log::{debug, error, info, warn}; use log::{debug, error, info, warn};
use rustls::ServerConfig; use rustls::ServerConfig;
use tokio::net::UdpSocket; use tokio::net::UdpSocket;
use tokio::sync::broadcast;
use crate::blocklist::BlocklistStore; use crate::blocklist::BlocklistStore;
use crate::buffer::BytePacketBuffer; use crate::buffer::BytePacketBuffer;
@@ -53,6 +55,7 @@ pub struct ServerCtx {
pub upstream_mode: UpstreamMode, pub upstream_mode: UpstreamMode,
pub root_hints: Vec<SocketAddr>, pub root_hints: Vec<SocketAddr>,
pub srtt: RwLock<SrttCache>, pub srtt: RwLock<SrttCache>,
pub inflight: Mutex<HashMap<(String, QueryType), broadcast::Sender<Option<DnsPacket>>>>,
pub dnssec_enabled: bool, pub dnssec_enabled: bool,
pub dnssec_strict: bool, pub dnssec_strict: bool,
} }
@@ -172,7 +175,47 @@ pub async fn handle_query(
} }
(resp, QueryPath::Cached, cached_dnssec) (resp, QueryPath::Cached, cached_dnssec)
} else if ctx.upstream_mode == UpstreamMode::Recursive { } else if ctx.upstream_mode == UpstreamMode::Recursive {
match crate::recursive::resolve_recursive( let key = (qname.clone(), qtype);
enum Disposition {
Leader(broadcast::Sender<Option<DnsPacket>>),
Follower(broadcast::Receiver<Option<DnsPacket>>),
}
let disposition = {
let mut inflight = ctx.inflight.lock().unwrap();
if let Some(tx) = inflight.get(&key) {
Disposition::Follower(tx.subscribe())
} else {
let (tx, _) = broadcast::channel::<Option<DnsPacket>>(1);
inflight.insert(key.clone(), tx.clone());
Disposition::Leader(tx)
}
};
match disposition {
Disposition::Follower(mut rx) => {
debug!("{} | {:?} {} | COALESCED", src_addr, qtype, qname);
match rx.recv().await {
Ok(Some(mut resp)) => {
resp.header.id = query.header.id;
(resp, QueryPath::Recursive, DnssecStatus::Indeterminate)
}
_ => (
DnsPacket::response_from(&query, ResultCode::SERVFAIL),
QueryPath::UpstreamError,
DnssecStatus::Indeterminate,
),
}
}
Disposition::Leader(tx) => {
// Drop guard: remove inflight entry even on panic/cancellation
let guard = InflightGuard {
inflight: &ctx.inflight,
key: key.clone(),
};
let result = crate::recursive::resolve_recursive(
&qname, &qname,
qtype, qtype,
&ctx.cache, &ctx.cache,
@@ -180,10 +223,17 @@ pub async fn handle_query(
&ctx.root_hints, &ctx.root_hints,
&ctx.srtt, &ctx.srtt,
) )
.await .await;
{
Ok(resp) => (resp, QueryPath::Recursive, DnssecStatus::Indeterminate), drop(guard);
match result {
Ok(resp) => {
let _ = tx.send(Some(resp.clone()));
(resp, QueryPath::Recursive, DnssecStatus::Indeterminate)
}
Err(e) => { Err(e) => {
let _ = tx.send(None);
error!( error!(
"{} | {:?} {} | RECURSIVE ERROR | {}", "{} | {:?} {} | RECURSIVE ERROR | {}",
src_addr, qtype, qname, e src_addr, qtype, qname, e
@@ -195,6 +245,8 @@ pub async fn handle_query(
) )
} }
} }
}
}
} else { } else {
let upstream = let upstream =
match crate::system_dns::match_forwarding_rule(&qname, &ctx.forwarding_rules) { match crate::system_dns::match_forwarding_rule(&qname, &ctx.forwarding_rules) {
@@ -377,6 +429,17 @@ fn is_special_use_domain(qname: &str) -> bool {
qname == "local" || qname.ends_with(".local") qname == "local" || qname.ends_with(".local")
} }
struct InflightGuard<'a> {
inflight: &'a Mutex<HashMap<(String, QueryType), broadcast::Sender<Option<DnsPacket>>>>,
key: (String, QueryType),
}
impl Drop for InflightGuard<'_> {
fn drop(&mut self) {
self.inflight.lock().unwrap().remove(&self.key);
}
}
fn special_use_response(query: &DnsPacket, qname: &str, qtype: QueryType) -> DnsPacket { fn special_use_response(query: &DnsPacket, qname: &str, qtype: QueryType) -> DnsPacket {
use std::net::{Ipv4Addr, Ipv6Addr}; use std::net::{Ipv4Addr, Ipv6Addr};
if qname == "ipv4only.arpa" { if qname == "ipv4only.arpa" {

View File

@@ -202,6 +202,7 @@ async fn main() -> numa::Result<()> {
upstream_mode: config.upstream.mode, upstream_mode: config.upstream.mode,
root_hints: numa::recursive::parse_root_hints(&config.upstream.root_hints), root_hints: numa::recursive::parse_root_hints(&config.upstream.root_hints),
srtt: std::sync::RwLock::new(numa::srtt::SrttCache::new(config.upstream.srtt)), srtt: std::sync::RwLock::new(numa::srtt::SrttCache::new(config.upstream.srtt)),
inflight: std::sync::Mutex::new(std::collections::HashMap::new()),
dnssec_enabled: config.dnssec.enabled, dnssec_enabled: config.dnssec.enabled,
dnssec_strict: config.dnssec.strict, dnssec_strict: config.dnssec.strict,
}); });