feat: in-flight query coalescing with COALESCED path #20
@@ -10,7 +10,7 @@ keywords = ["dns", "dns-server", "ad-blocking", "reverse-proxy", "developer-tool
|
||||
categories = ["network-programming", "development-tools"]
|
||||
|
||||
[dependencies]
|
||||
tokio = { version = "1", features = ["rt-multi-thread", "macros", "net", "time"] }
|
||||
tokio = { version = "1", features = ["rt-multi-thread", "macros", "net", "time", "sync"] }
|
||||
axum = "0.8"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
|
||||
@@ -953,6 +953,7 @@ mod tests {
|
||||
upstream_mode: crate::config::UpstreamMode::Forward,
|
||||
root_hints: Vec::new(),
|
||||
srtt: RwLock::new(crate::srtt::SrttCache::new(true)),
|
||||
inflight: Mutex::new(std::collections::HashMap::new()),
|
||||
dnssec_enabled: false,
|
||||
dnssec_strict: false,
|
||||
})
|
||||
|
||||
103
src/ctx.rs
103
src/ctx.rs
@@ -1,3 +1,4 @@
|
||||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::{Mutex, RwLock};
|
||||
@@ -7,6 +8,7 @@ use arc_swap::ArcSwap;
|
||||
use log::{debug, error, info, warn};
|
||||
use rustls::ServerConfig;
|
||||
use tokio::net::UdpSocket;
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
use crate::blocklist::BlocklistStore;
|
||||
use crate::buffer::BytePacketBuffer;
|
||||
@@ -53,6 +55,7 @@ pub struct ServerCtx {
|
||||
pub upstream_mode: UpstreamMode,
|
||||
pub root_hints: Vec<SocketAddr>,
|
||||
pub srtt: RwLock<SrttCache>,
|
||||
pub inflight: Mutex<HashMap<(String, QueryType), broadcast::Sender<Option<DnsPacket>>>>,
|
||||
pub dnssec_enabled: bool,
|
||||
pub dnssec_strict: bool,
|
||||
}
|
||||
@@ -172,27 +175,76 @@ pub async fn handle_query(
|
||||
}
|
||||
(resp, QueryPath::Cached, cached_dnssec)
|
||||
} else if ctx.upstream_mode == UpstreamMode::Recursive {
|
||||
match crate::recursive::resolve_recursive(
|
||||
&qname,
|
||||
qtype,
|
||||
&ctx.cache,
|
||||
&query,
|
||||
&ctx.root_hints,
|
||||
&ctx.srtt,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(resp) => (resp, QueryPath::Recursive, DnssecStatus::Indeterminate),
|
||||
Err(e) => {
|
||||
error!(
|
||||
"{} | {:?} {} | RECURSIVE ERROR | {}",
|
||||
src_addr, qtype, qname, e
|
||||
);
|
||||
(
|
||||
DnsPacket::response_from(&query, ResultCode::SERVFAIL),
|
||||
QueryPath::UpstreamError,
|
||||
DnssecStatus::Indeterminate,
|
||||
let key = (qname.clone(), qtype);
|
||||
|
||||
enum Disposition {
|
||||
Leader(broadcast::Sender<Option<DnsPacket>>),
|
||||
Follower(broadcast::Receiver<Option<DnsPacket>>),
|
||||
}
|
||||
|
||||
let disposition = {
|
||||
let mut inflight = ctx.inflight.lock().unwrap();
|
||||
if let Some(tx) = inflight.get(&key) {
|
||||
Disposition::Follower(tx.subscribe())
|
||||
} else {
|
||||
let (tx, _) = broadcast::channel::<Option<DnsPacket>>(1);
|
||||
inflight.insert(key.clone(), tx.clone());
|
||||
Disposition::Leader(tx)
|
||||
}
|
||||
};
|
||||
|
||||
match disposition {
|
||||
Disposition::Follower(mut rx) => {
|
||||
debug!("{} | {:?} {} | COALESCED", src_addr, qtype, qname);
|
||||
match rx.recv().await {
|
||||
Ok(Some(mut resp)) => {
|
||||
resp.header.id = query.header.id;
|
||||
(resp, QueryPath::Recursive, DnssecStatus::Indeterminate)
|
||||
}
|
||||
_ => (
|
||||
DnsPacket::response_from(&query, ResultCode::SERVFAIL),
|
||||
QueryPath::UpstreamError,
|
||||
DnssecStatus::Indeterminate,
|
||||
),
|
||||
}
|
||||
}
|
||||
Disposition::Leader(tx) => {
|
||||
// Drop guard: remove inflight entry even on panic/cancellation
|
||||
let guard = InflightGuard {
|
||||
inflight: &ctx.inflight,
|
||||
key: key.clone(),
|
||||
};
|
||||
|
||||
let result = crate::recursive::resolve_recursive(
|
||||
&qname,
|
||||
qtype,
|
||||
&ctx.cache,
|
||||
&query,
|
||||
&ctx.root_hints,
|
||||
&ctx.srtt,
|
||||
)
|
||||
.await;
|
||||
|
||||
drop(guard);
|
||||
|
||||
match result {
|
||||
Ok(resp) => {
|
||||
let _ = tx.send(Some(resp.clone()));
|
||||
(resp, QueryPath::Recursive, DnssecStatus::Indeterminate)
|
||||
}
|
||||
Err(e) => {
|
||||
let _ = tx.send(None);
|
||||
error!(
|
||||
"{} | {:?} {} | RECURSIVE ERROR | {}",
|
||||
src_addr, qtype, qname, e
|
||||
);
|
||||
(
|
||||
DnsPacket::response_from(&query, ResultCode::SERVFAIL),
|
||||
QueryPath::UpstreamError,
|
||||
DnssecStatus::Indeterminate,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -377,6 +429,17 @@ fn is_special_use_domain(qname: &str) -> bool {
|
||||
qname == "local" || qname.ends_with(".local")
|
||||
}
|
||||
|
||||
struct InflightGuard<'a> {
|
||||
inflight: &'a Mutex<HashMap<(String, QueryType), broadcast::Sender<Option<DnsPacket>>>>,
|
||||
key: (String, QueryType),
|
||||
}
|
||||
|
||||
impl Drop for InflightGuard<'_> {
|
||||
fn drop(&mut self) {
|
||||
self.inflight.lock().unwrap().remove(&self.key);
|
||||
}
|
||||
}
|
||||
|
||||
fn special_use_response(query: &DnsPacket, qname: &str, qtype: QueryType) -> DnsPacket {
|
||||
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||
if qname == "ipv4only.arpa" {
|
||||
|
||||
@@ -202,6 +202,7 @@ async fn main() -> numa::Result<()> {
|
||||
upstream_mode: config.upstream.mode,
|
||||
root_hints: numa::recursive::parse_root_hints(&config.upstream.root_hints),
|
||||
srtt: std::sync::RwLock::new(numa::srtt::SrttCache::new(config.upstream.srtt)),
|
||||
inflight: std::sync::Mutex::new(std::collections::HashMap::new()),
|
||||
dnssec_enabled: config.dnssec.enabled,
|
||||
dnssec_strict: config.dnssec.strict,
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user