feat: in-flight query coalescing with COALESCED path #20
@@ -10,7 +10,7 @@ keywords = ["dns", "dns-server", "ad-blocking", "reverse-proxy", "developer-tool
|
|||||||
categories = ["network-programming", "development-tools"]
|
categories = ["network-programming", "development-tools"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
tokio = { version = "1", features = ["rt-multi-thread", "macros", "net", "time"] }
|
tokio = { version = "1", features = ["rt-multi-thread", "macros", "net", "time", "sync"] }
|
||||||
axum = "0.8"
|
axum = "0.8"
|
||||||
serde = { version = "1", features = ["derive"] }
|
serde = { version = "1", features = ["derive"] }
|
||||||
serde_json = "1"
|
serde_json = "1"
|
||||||
|
|||||||
@@ -285,6 +285,7 @@ body {
|
|||||||
.path-tag.OVERRIDE { background: rgba(82, 122, 82, 0.12); color: var(--emerald); }
|
.path-tag.OVERRIDE { background: rgba(82, 122, 82, 0.12); color: var(--emerald); }
|
||||||
.path-tag.SERVFAIL { background: rgba(181, 68, 58, 0.12); color: var(--rose); }
|
.path-tag.SERVFAIL { background: rgba(181, 68, 58, 0.12); color: var(--rose); }
|
||||||
.path-tag.BLOCKED { background: rgba(163, 152, 136, 0.15); color: var(--text-dim); }
|
.path-tag.BLOCKED { background: rgba(163, 152, 136, 0.15); color: var(--text-dim); }
|
||||||
|
.path-tag.COALESCED { background: rgba(138, 104, 158, 0.12); color: var(--violet-dim); }
|
||||||
|
|
||||||
/* Sidebar panels */
|
/* Sidebar panels */
|
||||||
.sidebar {
|
.sidebar {
|
||||||
@@ -547,6 +548,8 @@ body {
|
|||||||
<select id="logFilterPath" onchange="applyLogFilter()"
|
<select id="logFilterPath" onchange="applyLogFilter()"
|
||||||
style="font-family:var(--font-mono);font-size:0.7rem;padding:0.25rem 0.4rem;border:1px solid var(--border);border-radius:4px;background:var(--bg-surface);color:var(--text-secondary);outline:none;">
|
style="font-family:var(--font-mono);font-size:0.7rem;padding:0.25rem 0.4rem;border:1px solid var(--border);border-radius:4px;background:var(--bg-surface);color:var(--text-secondary);outline:none;">
|
||||||
<option value="">all paths</option>
|
<option value="">all paths</option>
|
||||||
|
<option value="RECURSIVE">recursive</option>
|
||||||
|
<option value="COALESCED">coalesced</option>
|
||||||
<option value="FORWARD">forward</option>
|
<option value="FORWARD">forward</option>
|
||||||
<option value="CACHED">cached</option>
|
<option value="CACHED">cached</option>
|
||||||
<option value="BLOCKED">blocked</option>
|
<option value="BLOCKED">blocked</option>
|
||||||
|
|||||||
@@ -182,6 +182,7 @@ struct QueriesStats {
|
|||||||
total: u64,
|
total: u64,
|
||||||
forwarded: u64,
|
forwarded: u64,
|
||||||
recursive: u64,
|
recursive: u64,
|
||||||
|
coalesced: u64,
|
||||||
cached: u64,
|
cached: u64,
|
||||||
local: u64,
|
local: u64,
|
||||||
overridden: u64,
|
overridden: u64,
|
||||||
@@ -499,6 +500,7 @@ async fn stats(State(ctx): State<Arc<ServerCtx>>) -> Json<StatsResponse> {
|
|||||||
total: snap.total,
|
total: snap.total,
|
||||||
forwarded: snap.forwarded,
|
forwarded: snap.forwarded,
|
||||||
recursive: snap.recursive,
|
recursive: snap.recursive,
|
||||||
|
coalesced: snap.coalesced,
|
||||||
cached: snap.cached,
|
cached: snap.cached,
|
||||||
local: snap.local,
|
local: snap.local,
|
||||||
overridden: snap.overridden,
|
overridden: snap.overridden,
|
||||||
@@ -953,6 +955,7 @@ mod tests {
|
|||||||
upstream_mode: crate::config::UpstreamMode::Forward,
|
upstream_mode: crate::config::UpstreamMode::Forward,
|
||||||
root_hints: Vec::new(),
|
root_hints: Vec::new(),
|
||||||
srtt: RwLock::new(crate::srtt::SrttCache::new(true)),
|
srtt: RwLock::new(crate::srtt::SrttCache::new(true)),
|
||||||
|
inflight: Mutex::new(std::collections::HashMap::new()),
|
||||||
dnssec_enabled: false,
|
dnssec_enabled: false,
|
||||||
dnssec_strict: false,
|
dnssec_strict: false,
|
||||||
})
|
})
|
||||||
|
|||||||
453
src/ctx.rs
453
src/ctx.rs
@@ -1,3 +1,4 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use std::sync::{Mutex, RwLock};
|
use std::sync::{Mutex, RwLock};
|
||||||
@@ -7,6 +8,9 @@ use arc_swap::ArcSwap;
|
|||||||
use log::{debug, error, info, warn};
|
use log::{debug, error, info, warn};
|
||||||
use rustls::ServerConfig;
|
use rustls::ServerConfig;
|
||||||
use tokio::net::UdpSocket;
|
use tokio::net::UdpSocket;
|
||||||
|
use tokio::sync::broadcast;
|
||||||
|
|
||||||
|
/// Registry of resolutions currently in flight, keyed by `(query name, query type)`.
/// Each value is the broadcast sender that later callers for the same key subscribe to
/// in order to receive the leader's outcome (`Some(packet)` on success, `None` on failure).
type InflightMap = HashMap<(String, QueryType), broadcast::Sender<Option<DnsPacket>>>;
|
||||||
|
|
||||||
use crate::blocklist::BlocklistStore;
|
use crate::blocklist::BlocklistStore;
|
||||||
use crate::buffer::BytePacketBuffer;
|
use crate::buffer::BytePacketBuffer;
|
||||||
@@ -53,6 +57,7 @@ pub struct ServerCtx {
|
|||||||
pub upstream_mode: UpstreamMode,
|
pub upstream_mode: UpstreamMode,
|
||||||
pub root_hints: Vec<SocketAddr>,
|
pub root_hints: Vec<SocketAddr>,
|
||||||
pub srtt: RwLock<SrttCache>,
|
pub srtt: RwLock<SrttCache>,
|
||||||
|
pub inflight: Mutex<InflightMap>,
|
||||||
pub dnssec_enabled: bool,
|
pub dnssec_enabled: bool,
|
||||||
pub dnssec_strict: bool,
|
pub dnssec_strict: bool,
|
||||||
}
|
}
|
||||||
@@ -172,7 +177,32 @@ pub async fn handle_query(
|
|||||||
}
|
}
|
||||||
(resp, QueryPath::Cached, cached_dnssec)
|
(resp, QueryPath::Cached, cached_dnssec)
|
||||||
} else if ctx.upstream_mode == UpstreamMode::Recursive {
|
} else if ctx.upstream_mode == UpstreamMode::Recursive {
|
||||||
match crate::recursive::resolve_recursive(
|
let key = (qname.clone(), qtype);
|
||||||
|
let disposition = acquire_inflight(&ctx.inflight, key.clone());
|
||||||
|
|
||||||
|
match disposition {
|
||||||
|
Disposition::Follower(mut rx) => {
|
||||||
|
debug!("{} | {:?} {} | COALESCED", src_addr, qtype, qname);
|
||||||
|
match rx.recv().await {
|
||||||
|
Ok(Some(mut resp)) => {
|
||||||
|
resp.header.id = query.header.id;
|
||||||
|
(resp, QueryPath::Coalesced, DnssecStatus::Indeterminate)
|
||||||
|
}
|
||||||
|
_ => (
|
||||||
|
DnsPacket::response_from(&query, ResultCode::SERVFAIL),
|
||||||
|
QueryPath::UpstreamError,
|
||||||
|
DnssecStatus::Indeterminate,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Disposition::Leader(tx) => {
|
||||||
|
// Drop guard: remove inflight entry even on panic/cancellation
|
||||||
|
let guard = InflightGuard {
|
||||||
|
inflight: &ctx.inflight,
|
||||||
|
key: key.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let result = crate::recursive::resolve_recursive(
|
||||||
&qname,
|
&qname,
|
||||||
qtype,
|
qtype,
|
||||||
&ctx.cache,
|
&ctx.cache,
|
||||||
@@ -180,10 +210,17 @@ pub async fn handle_query(
|
|||||||
&ctx.root_hints,
|
&ctx.root_hints,
|
||||||
&ctx.srtt,
|
&ctx.srtt,
|
||||||
)
|
)
|
||||||
.await
|
.await;
|
||||||
{
|
|
||||||
Ok(resp) => (resp, QueryPath::Recursive, DnssecStatus::Indeterminate),
|
drop(guard);
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(resp) => {
|
||||||
|
let _ = tx.send(Some(resp.clone()));
|
||||||
|
(resp, QueryPath::Recursive, DnssecStatus::Indeterminate)
|
||||||
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
|
let _ = tx.send(None);
|
||||||
error!(
|
error!(
|
||||||
"{} | {:?} {} | RECURSIVE ERROR | {}",
|
"{} | {:?} {} | RECURSIVE ERROR | {}",
|
||||||
src_addr, qtype, qname, e
|
src_addr, qtype, qname, e
|
||||||
@@ -195,6 +232,8 @@ pub async fn handle_query(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
let upstream =
|
let upstream =
|
||||||
match crate::system_dns::match_forwarding_rule(&qname, &ctx.forwarding_rules) {
|
match crate::system_dns::match_forwarding_rule(&qname, &ctx.forwarding_rules) {
|
||||||
@@ -377,6 +416,47 @@ fn is_special_use_domain(qname: &str) -> bool {
|
|||||||
qname == "local" || qname.ends_with(".local")
|
qname == "local" || qname.ends_with(".local")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
enum Disposition {
|
||||||
|
Leader(broadcast::Sender<Option<DnsPacket>>),
|
||||||
|
Follower(broadcast::Receiver<Option<DnsPacket>>),
|
||||||
|
}
|
||||||
|
|
||||||
|
fn acquire_inflight(inflight: &Mutex<InflightMap>, key: (String, QueryType)) -> Disposition {
|
||||||
|
let mut map = inflight.lock().unwrap();
|
||||||
|
if let Some(tx) = map.get(&key) {
|
||||||
|
Disposition::Follower(tx.subscribe())
|
||||||
|
} else {
|
||||||
|
let (tx, _) = broadcast::channel::<Option<DnsPacket>>(1);
|
||||||
|
map.insert(key, tx.clone());
|
||||||
|
Disposition::Leader(tx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct InflightGuard<'a> {
|
||||||
|
inflight: &'a Mutex<InflightMap>,
|
||||||
|
key: (String, QueryType),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for InflightGuard<'_> {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.inflight.lock().unwrap().remove(&self.key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Test-only helper: serialize a one-question DNS query (recursion desired) for `domain` /
/// `qtype` with transaction id `id`, and return a buffer holding the serialized bytes.
#[cfg(test)]
fn build_wire_query(id: u16, domain: &str, qtype: QueryType) -> BytePacketBuffer {
    let question = crate::question::DnsQuestion::new(domain.to_string(), qtype);

    let mut query = DnsPacket::new();
    query.header.id = id;
    query.header.recursion_desired = true;
    query.header.questions = 1;
    query.questions.push(question);

    let mut wire = BytePacketBuffer::new();
    query.write(&mut wire).unwrap();
    // Re-wrap only the filled bytes in a fresh buffer.
    // NOTE(review): presumably so the read position starts at 0 — confirm against BytePacketBuffer.
    BytePacketBuffer::from_bytes(wire.filled())
}
|
||||||
|
|
||||||
fn special_use_response(query: &DnsPacket, qname: &str, qtype: QueryType) -> DnsPacket {
|
fn special_use_response(query: &DnsPacket, qname: &str, qtype: QueryType) -> DnsPacket {
|
||||||
use std::net::{Ipv4Addr, Ipv6Addr};
|
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||||
if qname == "ipv4only.arpa" {
|
if qname == "ipv4only.arpa" {
|
||||||
@@ -410,3 +490,368 @@ fn special_use_response(query: &DnsPacket, qname: &str, qtype: QueryType) -> Dns
|
|||||||
DnsPacket::response_from(query, ResultCode::NXDOMAIN)
|
DnsPacket::response_from(query, ResultCode::NXDOMAIN)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::net::{Ipv4Addr, SocketAddr};
|
||||||
|
use std::sync::{Arc, Mutex, RwLock};
|
||||||
|
use tokio::sync::broadcast;
|
||||||
|
|
||||||
|
// ---- InflightGuard unit tests ----
|
||||||
|
|
||||||
|
/// Dropping an `InflightGuard` removes its key from the registry.
#[test]
fn inflight_guard_removes_key_on_drop() {
    let key = ("example.com".to_string(), QueryType::A);
    let (sender, _) = broadcast::channel::<Option<DnsPacket>>(1);

    let mut entries = InflightMap::new();
    entries.insert(key.clone(), sender);
    let registry: Mutex<InflightMap> = Mutex::new(entries);
    assert_eq!(registry.lock().unwrap().len(), 1);

    let guard = InflightGuard {
        inflight: &registry,
        key,
    };
    drop(guard);

    assert!(registry.lock().unwrap().is_empty());
}
|
||||||
|
|
||||||
|
/// A guard for one domain leaves the entry of another domain untouched.
#[test]
fn inflight_guard_only_removes_own_key() {
    let key_a = ("a.com".to_string(), QueryType::A);
    let key_b = ("b.com".to_string(), QueryType::A);
    let (tx_a, _) = broadcast::channel::<Option<DnsPacket>>(1);
    let (tx_b, _) = broadcast::channel::<Option<DnsPacket>>(1);

    let mut entries = InflightMap::new();
    entries.insert(key_a.clone(), tx_a);
    entries.insert(key_b.clone(), tx_b);
    let registry: Mutex<InflightMap> = Mutex::new(entries);

    drop(InflightGuard {
        inflight: &registry,
        key: key_a,
    });

    let remaining = registry.lock().unwrap();
    assert_eq!(remaining.len(), 1);
    assert!(remaining.contains_key(&key_b));
}
|
||||||
|
|
||||||
|
/// Entries for the same domain but different query types are independent: dropping the
/// guard for `A` keeps the `AAAA` entry.
#[test]
fn inflight_guard_same_domain_different_qtype_independent() {
    let key_a = ("example.com".to_string(), QueryType::A);
    let key_aaaa = ("example.com".to_string(), QueryType::AAAA);
    let (tx_a, _) = broadcast::channel::<Option<DnsPacket>>(1);
    let (tx_aaaa, _) = broadcast::channel::<Option<DnsPacket>>(1);

    let mut entries = InflightMap::new();
    entries.insert(key_a.clone(), tx_a);
    entries.insert(key_aaaa.clone(), tx_aaaa);
    let registry: Mutex<InflightMap> = Mutex::new(entries);

    drop(InflightGuard {
        inflight: &registry,
        key: key_a,
    });

    let remaining = registry.lock().unwrap();
    assert_eq!(remaining.len(), 1);
    assert!(remaining.contains_key(&key_aaaa));
}
|
||||||
|
|
||||||
|
// ---- Coalescing disposition tests (via acquire_inflight) ----
|
||||||
|
|
||||||
|
/// On an empty registry the first caller is made leader and exactly one entry is registered.
#[test]
fn first_caller_becomes_leader() {
    let registry: Mutex<InflightMap> = Mutex::new(HashMap::new());

    let disposition = acquire_inflight(&registry, ("test.com".to_string(), QueryType::A));

    assert!(matches!(disposition, Disposition::Leader(_)));
    assert_eq!(registry.lock().unwrap().len(), 1);
}
|
||||||
|
|
||||||
|
/// A second caller for an already-registered key is made follower and adds no new entry.
#[test]
fn second_caller_becomes_follower() {
    let registry: Mutex<InflightMap> = Mutex::new(HashMap::new());
    let key = ("test.com".to_string(), QueryType::A);

    let _leader = acquire_inflight(&registry, key.clone());
    let second = acquire_inflight(&registry, key);

    assert!(matches!(second, Disposition::Follower(_)));
    // Followers subscribe to the existing entry; they never insert one of their own.
    assert_eq!(registry.lock().unwrap().len(), 1);
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn leader_broadcast_reaches_follower() {
|
||||||
|
let map: Mutex<InflightMap> = Mutex::new(HashMap::new());
|
||||||
|
let key = ("test.com".to_string(), QueryType::A);
|
||||||
|
|
||||||
|
let leader = acquire_inflight(&map, key.clone());
|
||||||
|
let follower = acquire_inflight(&map, key);
|
||||||
|
|
||||||
|
let tx = match leader {
|
||||||
|
Disposition::Leader(tx) => tx,
|
||||||
|
_ => panic!("expected leader"),
|
||||||
|
};
|
||||||
|
let mut rx = match follower {
|
||||||
|
Disposition::Follower(rx) => rx,
|
||||||
|
_ => panic!("expected follower"),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut resp = DnsPacket::new();
|
||||||
|
resp.header.id = 42;
|
||||||
|
resp.answers.push(DnsRecord::A {
|
||||||
|
domain: "test.com".into(),
|
||||||
|
addr: Ipv4Addr::new(1, 2, 3, 4),
|
||||||
|
ttl: 300,
|
||||||
|
});
|
||||||
|
let _ = tx.send(Some(resp));
|
||||||
|
|
||||||
|
let received = rx.recv().await.unwrap().unwrap();
|
||||||
|
assert_eq!(received.header.id, 42);
|
||||||
|
assert_eq!(received.answers.len(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn leader_none_signals_failure_to_follower() {
|
||||||
|
let map: Mutex<InflightMap> = Mutex::new(HashMap::new());
|
||||||
|
let key = ("test.com".to_string(), QueryType::A);
|
||||||
|
|
||||||
|
let leader = acquire_inflight(&map, key.clone());
|
||||||
|
let follower = acquire_inflight(&map, key);
|
||||||
|
|
||||||
|
let tx = match leader {
|
||||||
|
Disposition::Leader(tx) => tx,
|
||||||
|
_ => panic!("expected leader"),
|
||||||
|
};
|
||||||
|
let mut rx = match follower {
|
||||||
|
Disposition::Follower(rx) => rx,
|
||||||
|
_ => panic!("expected follower"),
|
||||||
|
};
|
||||||
|
|
||||||
|
let _ = tx.send(None);
|
||||||
|
assert!(rx.recv().await.unwrap().is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn multiple_followers_all_receive_via_acquire() {
|
||||||
|
let map: Mutex<InflightMap> = Mutex::new(HashMap::new());
|
||||||
|
let key = ("multi.com".to_string(), QueryType::A);
|
||||||
|
|
||||||
|
let leader = acquire_inflight(&map, key.clone());
|
||||||
|
let f1 = acquire_inflight(&map, key.clone());
|
||||||
|
let f2 = acquire_inflight(&map, key.clone());
|
||||||
|
let f3 = acquire_inflight(&map, key);
|
||||||
|
|
||||||
|
let tx = match leader {
|
||||||
|
Disposition::Leader(tx) => tx,
|
||||||
|
_ => panic!("expected leader"),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut resp = DnsPacket::new();
|
||||||
|
resp.answers.push(DnsRecord::A {
|
||||||
|
domain: "multi.com".into(),
|
||||||
|
addr: Ipv4Addr::new(10, 0, 0, 1),
|
||||||
|
ttl: 60,
|
||||||
|
});
|
||||||
|
let _ = tx.send(Some(resp));
|
||||||
|
|
||||||
|
for f in [f1, f2, f3] {
|
||||||
|
let mut rx = match f {
|
||||||
|
Disposition::Follower(rx) => rx,
|
||||||
|
_ => panic!("expected follower"),
|
||||||
|
};
|
||||||
|
let r = rx.recv().await.unwrap().unwrap();
|
||||||
|
assert_eq!(r.answers.len(), 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- Integration: concurrent handle_query coalescing ----
|
||||||
|
|
||||||
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||||
|
use tokio::net::TcpListener;
|
||||||
|
|
||||||
|
/// Spawn a slow TCP DNS server that delays `delay` before responding.
|
||||||
|
/// Returns (addr, query_count) where query_count is an Arc<AtomicU32>
|
||||||
|
/// tracking how many queries were actually resolved (not coalesced).
|
||||||
|
async fn spawn_slow_dns_server(
|
||||||
|
delay: Duration,
|
||||||
|
) -> (SocketAddr, Arc<std::sync::atomic::AtomicU32>) {
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
let count = Arc::new(std::sync::atomic::AtomicU32::new(0));
|
||||||
|
let count_clone = count.clone();
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
loop {
|
||||||
|
let (mut stream, _) = match listener.accept().await {
|
||||||
|
Ok(c) => c,
|
||||||
|
Err(_) => break,
|
||||||
|
};
|
||||||
|
let count = count_clone.clone();
|
||||||
|
let delay = delay;
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let mut len_buf = [0u8; 2];
|
||||||
|
if stream.read_exact(&mut len_buf).await.is_err() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let len = u16::from_be_bytes(len_buf) as usize;
|
||||||
|
let mut data = vec![0u8; len];
|
||||||
|
if stream.read_exact(&mut data).await.is_err() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut buf = BytePacketBuffer::from_bytes(&data);
|
||||||
|
let query = match DnsPacket::from_buffer(&mut buf) {
|
||||||
|
Ok(q) => q,
|
||||||
|
Err(_) => return,
|
||||||
|
};
|
||||||
|
|
||||||
|
count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||||
|
|
||||||
|
// Deliberate delay to create coalescing window
|
||||||
|
tokio::time::sleep(delay).await;
|
||||||
|
|
||||||
|
let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR);
|
||||||
|
resp.header.authoritative_answer = true;
|
||||||
|
if let Some(q) = query.questions.first() {
|
||||||
|
resp.answers.push(DnsRecord::A {
|
||||||
|
domain: q.name.clone(),
|
||||||
|
addr: Ipv4Addr::new(10, 0, 0, 1),
|
||||||
|
ttl: 300,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut resp_buf = BytePacketBuffer::new();
|
||||||
|
if resp.write(&mut resp_buf).is_err() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let resp_bytes = resp_buf.filled();
|
||||||
|
let mut out = Vec::with_capacity(2 + resp_bytes.len());
|
||||||
|
out.extend_from_slice(&(resp_bytes.len() as u16).to_be_bytes());
|
||||||
|
out.extend_from_slice(resp_bytes);
|
||||||
|
let _ = stream.write_all(&out).await;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
(addr, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn test_recursive_ctx(root_hint: SocketAddr) -> Arc<ServerCtx> {
|
||||||
|
let socket = tokio::net::UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
Arc::new(ServerCtx {
|
||||||
|
socket,
|
||||||
|
zone_map: HashMap::new(),
|
||||||
|
cache: RwLock::new(crate::cache::DnsCache::new(100, 60, 86400)),
|
||||||
|
stats: Mutex::new(crate::stats::ServerStats::new()),
|
||||||
|
overrides: RwLock::new(crate::override_store::OverrideStore::new()),
|
||||||
|
blocklist: RwLock::new(crate::blocklist::BlocklistStore::new()),
|
||||||
|
query_log: Mutex::new(crate::query_log::QueryLog::new(100)),
|
||||||
|
services: Mutex::new(crate::service_store::ServiceStore::new()),
|
||||||
|
lan_peers: Mutex::new(crate::lan::PeerStore::new(90)),
|
||||||
|
forwarding_rules: Vec::new(),
|
||||||
|
upstream: Mutex::new(crate::forward::Upstream::Udp(
|
||||||
|
"127.0.0.1:53".parse().unwrap(),
|
||||||
|
)),
|
||||||
|
upstream_auto: false,
|
||||||
|
upstream_port: 53,
|
||||||
|
lan_ip: Mutex::new(Ipv4Addr::LOCALHOST),
|
||||||
|
timeout: Duration::from_secs(3),
|
||||||
|
proxy_tld: "numa".to_string(),
|
||||||
|
proxy_tld_suffix: ".numa".to_string(),
|
||||||
|
lan_enabled: false,
|
||||||
|
config_path: "/tmp/test-numa.toml".to_string(),
|
||||||
|
config_found: false,
|
||||||
|
config_dir: std::path::PathBuf::from("/tmp"),
|
||||||
|
data_dir: std::path::PathBuf::from("/tmp"),
|
||||||
|
tls_config: None,
|
||||||
|
upstream_mode: crate::config::UpstreamMode::Recursive,
|
||||||
|
root_hints: vec![root_hint],
|
||||||
|
srtt: RwLock::new(crate::srtt::SrttCache::new(true)),
|
||||||
|
inflight: Mutex::new(HashMap::new()),
|
||||||
|
dnssec_enabled: false,
|
||||||
|
dnssec_strict: false,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn concurrent_queries_coalesce_to_single_resolution() {
|
||||||
|
// Force TCP-only so mock server works
|
||||||
|
crate::recursive::UDP_DISABLED.store(true, std::sync::atomic::Ordering::Release);
|
||||||
|
|
||||||
|
let (server_addr, query_count) = spawn_slow_dns_server(Duration::from_millis(200)).await;
|
||||||
|
let ctx = test_recursive_ctx(server_addr).await;
|
||||||
|
let src: SocketAddr = "127.0.0.1:9999".parse().unwrap();
|
||||||
|
|
||||||
|
// Fire 5 concurrent queries for the same (domain, A)
|
||||||
|
let mut handles = Vec::new();
|
||||||
|
for i in 0..5u16 {
|
||||||
|
let ctx = ctx.clone();
|
||||||
|
let buf = build_wire_query(100 + i, "coalesce-test.example.com", QueryType::A);
|
||||||
|
handles.push(tokio::spawn(
|
||||||
|
async move { handle_query(buf, src, &ctx).await },
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
for h in handles {
|
||||||
|
h.await.unwrap().unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only 1 resolution should have reached the upstream server
|
||||||
|
let actual = query_count.load(std::sync::atomic::Ordering::Relaxed);
|
||||||
|
assert_eq!(actual, 1, "expected 1 upstream query, got {}", actual);
|
||||||
|
|
||||||
|
// Inflight map must be empty after all queries complete
|
||||||
|
assert!(ctx.inflight.lock().unwrap().is_empty());
|
||||||
|
|
||||||
|
crate::recursive::reset_udp_state();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn different_qtypes_not_coalesced() {
|
||||||
|
crate::recursive::UDP_DISABLED.store(true, std::sync::atomic::Ordering::Release);
|
||||||
|
|
||||||
|
let (server_addr, query_count) = spawn_slow_dns_server(Duration::from_millis(100)).await;
|
||||||
|
let ctx = test_recursive_ctx(server_addr).await;
|
||||||
|
let src: SocketAddr = "127.0.0.1:9999".parse().unwrap();
|
||||||
|
|
||||||
|
// Fire A and AAAA concurrently — should NOT coalesce
|
||||||
|
let ctx_ref = ctx.clone();
|
||||||
|
let ctx_ref2 = ctx.clone();
|
||||||
|
let buf_a = build_wire_query(200, "different-qt.example.com", QueryType::A);
|
||||||
|
let buf_aaaa = build_wire_query(201, "different-qt.example.com", QueryType::AAAA);
|
||||||
|
|
||||||
|
let h1 = tokio::spawn(async move { handle_query(buf_a, src, &ctx_ref).await });
|
||||||
|
let h2 = tokio::spawn(async move { handle_query(buf_aaaa, src, &ctx_ref2).await });
|
||||||
|
|
||||||
|
h1.await.unwrap().unwrap();
|
||||||
|
h2.await.unwrap().unwrap();
|
||||||
|
|
||||||
|
let actual = query_count.load(std::sync::atomic::Ordering::Relaxed);
|
||||||
|
assert!(
|
||||||
|
actual >= 2,
|
||||||
|
"A and AAAA should resolve independently, got {}",
|
||||||
|
actual
|
||||||
|
);
|
||||||
|
assert!(ctx.inflight.lock().unwrap().is_empty());
|
||||||
|
|
||||||
|
crate::recursive::reset_udp_state();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn inflight_map_cleaned_after_upstream_error() {
|
||||||
|
// Server that rejects everything — no server running at all
|
||||||
|
let bogus_addr: SocketAddr = "127.0.0.1:1".parse().unwrap();
|
||||||
|
let ctx = test_recursive_ctx(bogus_addr).await;
|
||||||
|
let src: SocketAddr = "127.0.0.1:9999".parse().unwrap();
|
||||||
|
|
||||||
|
let buf = build_wire_query(300, "will-fail.example.com", QueryType::A);
|
||||||
|
let _ = handle_query(buf, src, &ctx).await;
|
||||||
|
|
||||||
|
// Map must be clean even after error
|
||||||
|
assert!(ctx.inflight.lock().unwrap().is_empty());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -202,6 +202,7 @@ async fn main() -> numa::Result<()> {
|
|||||||
upstream_mode: config.upstream.mode,
|
upstream_mode: config.upstream.mode,
|
||||||
root_hints: numa::recursive::parse_root_hints(&config.upstream.root_hints),
|
root_hints: numa::recursive::parse_root_hints(&config.upstream.root_hints),
|
||||||
srtt: std::sync::RwLock::new(numa::srtt::SrttCache::new(config.upstream.srtt)),
|
srtt: std::sync::RwLock::new(numa::srtt::SrttCache::new(config.upstream.srtt)),
|
||||||
|
inflight: std::sync::Mutex::new(std::collections::HashMap::new()),
|
||||||
dnssec_enabled: config.dnssec.enabled,
|
dnssec_enabled: config.dnssec.enabled,
|
||||||
dnssec_strict: config.dnssec.strict,
|
dnssec_strict: config.dnssec.strict,
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -21,7 +21,8 @@ const UDP_FAIL_THRESHOLD: u8 = 3;
|
|||||||
|
|
||||||
static QUERY_ID: AtomicU16 = AtomicU16::new(1);
|
static QUERY_ID: AtomicU16 = AtomicU16::new(1);
|
||||||
static UDP_FAILURES: std::sync::atomic::AtomicU8 = std::sync::atomic::AtomicU8::new(0);
|
static UDP_FAILURES: std::sync::atomic::AtomicU8 = std::sync::atomic::AtomicU8::new(0);
|
||||||
static UDP_DISABLED: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);
|
/// Process-wide flag, initially `false`. Crate-visible so that tests in other modules can
/// store `true` to force TCP-only resolution against a TCP mock server.
pub(crate) static UDP_DISABLED: std::sync::atomic::AtomicBool =
    std::sync::atomic::AtomicBool::new(false);
|
||||||
|
|
||||||
fn next_id() -> u16 {
|
fn next_id() -> u16 {
|
||||||
QUERY_ID.fetch_add(1, Ordering::Relaxed)
|
QUERY_ID.fetch_add(1, Ordering::Relaxed)
|
||||||
|
|||||||
100
src/srtt.rs
100
src/srtt.rs
@@ -108,6 +108,13 @@ impl SrttCache {
|
|||||||
self.entries.is_empty()
|
self.entries.is_empty()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Test-only: overwrite the `updated_at` timestamp of the entry for `ip`.
/// Does nothing when no entry exists for `ip`.
#[cfg(test)]
fn set_updated_at(&mut self, ip: IpAddr, at: Instant) {
    let stamp = self.entries.get_mut(&ip).map(|entry| &mut entry.updated_at);
    if let Some(updated_at) = stamp {
        *updated_at = at;
    }
}
|
||||||
|
|
||||||
fn maybe_evict(&mut self) {
|
fn maybe_evict(&mut self) {
|
||||||
if self.entries.len() < MAX_ENTRIES {
|
if self.entries.len() < MAX_ENTRIES {
|
||||||
return;
|
return;
|
||||||
@@ -203,6 +210,99 @@ mod tests {
|
|||||||
assert_eq!(addrs, original);
|
assert_eq!(addrs, original);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Return the instant `secs` seconds before now.
///
/// Panics if that instant cannot be represented by the platform's `Instant`.
fn age(secs: u64) -> Instant {
    let now = Instant::now();
    let back = std::time::Duration::from_secs(secs);
    now - back
}
|
||||||
|
|
||||||
|
/// Cache with ip(1) saturated at FAILURE_PENALTY_MS
|
||||||
|
fn saturated_penalty_cache() -> SrttCache {
|
||||||
|
let mut cache = SrttCache::new(true);
|
||||||
|
for _ in 0..30 {
|
||||||
|
cache.record_rtt(ip(1), FAILURE_PENALTY_MS, false);
|
||||||
|
}
|
||||||
|
cache
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An entry last updated `DECAY_AFTER_SECS` ago must read back as its stored value.
#[test]
fn no_decay_within_threshold() {
    let target = ip(1);
    let mut cache = SrttCache::new(true);
    cache.record_rtt(target, 5000, false);
    cache.set_updated_at(target, age(DECAY_AFTER_SECS));

    let observed = cache.get(target);
    let stored = cache.entries[&target].srtt_ms;
    assert_eq!(observed, stored);
}
|
||||||
|
|
||||||
|
/// Just past one decay period, the reported value is the midpoint between the stored value
/// and `INITIAL_SRTT_MS`.
#[test]
fn one_decay_period() {
    let target = ip(1);
    let mut cache = saturated_penalty_cache();
    let stored = cache.entries[&target].srtt_ms;
    cache.set_updated_at(target, age(DECAY_AFTER_SECS + 1));

    let halfway = (stored + INITIAL_SRTT_MS) / 2;
    assert_eq!(cache.get(target), halfway);
}
|
||||||
|
|
||||||
|
/// Just past four decay periods, the reported value equals four successive halvings toward
/// `INITIAL_SRTT_MS`.
#[test]
fn multiple_decay_periods() {
    let target = ip(1);
    let mut cache = saturated_penalty_cache();
    let stored = cache.entries[&target].srtt_ms;
    cache.set_updated_at(target, age(DECAY_AFTER_SECS * 4 + 1));

    let expected = (0..4).fold(stored, |acc, _| (acc + INITIAL_SRTT_MS) / 2);
    assert_eq!(cache.get(target), expected);
}
|
||||||
|
|
||||||
|
/// Decay stops accumulating after 8 periods: an entry aged 9 periods and one aged 100
/// periods must report the same value.
#[test]
fn decay_caps_at_8_periods() {
    let target = ip(1);
    let mut nine_periods = saturated_penalty_cache();
    let mut hundred_periods = saturated_penalty_cache();
    nine_periods.set_updated_at(target, age(DECAY_AFTER_SECS * 9 + 1));
    hundred_periods.set_updated_at(target, age(DECAY_AFTER_SECS * 100));

    let after_nine = nine_periods.get(target);
    let after_hundred = hundred_periods.get(target);
    assert_eq!(after_nine, after_hundred);
}
|
||||||
|
|
||||||
|
/// After a very long idle time, a saturated entry must report within 25 ms of
/// `INITIAL_SRTT_MS`.
#[test]
fn decay_converges_toward_initial() {
    let target = ip(1);
    let mut cache = saturated_penalty_cache();
    cache.set_updated_at(target, age(DECAY_AFTER_SECS * 100));

    let decayed = cache.get(target);
    let diff = decayed.abs_diff(INITIAL_SRTT_MS);
    assert!(
        diff < 25,
        "expected near INITIAL_SRTT_MS, got {} (diff={})",
        decayed,
        diff
    );
}
|
||||||
|
|
||||||
|
/// Recording a fresh sample on a long-stale entry must decay the stale value first and
/// only then fold in the new sample.
#[test]
fn record_rtt_applies_decay_before_ewma() {
    let target = ip(1);
    let mut cache = saturated_penalty_cache();
    cache.set_updated_at(target, age(DECAY_AFTER_SECS * 8));
    cache.record_rtt(target, 50, false);

    // If the EWMA ran on the undecayed value, the result would be ~(5000*7+50)/8 ≈ 4381.
    let srtt = cache.get(target);
    assert!(srtt < 500, "expected decay before EWMA, got srtt={}", srtt);
}
|
||||||
|
|
||||||
|
/// A server with a saturated failure penalty sorts behind a 300 ms server while fresh, but
/// ahead of it once its penalty has gone stale and decayed.
#[test]
fn decay_reranks_stale_failures() {
    let mut cache = saturated_penalty_cache();
    (0..30).for_each(|_| {
        cache.record_rtt(ip(2), 300, false);
    });

    // Fresh penalty: server 2 (300 ms) ranks first.
    let mut ranked = vec![sock(1), sock(2)];
    cache.sort_by_rtt(&mut ranked);
    assert_eq!(ranked, vec![sock(2), sock(1)]);

    // Age server 1 so it decays toward INITIAL (200ms), below server 2's 300ms.
    cache.set_updated_at(ip(1), age(DECAY_AFTER_SECS * 100));
    let mut reranked = vec![sock(1), sock(2)];
    cache.sort_by_rtt(&mut reranked);
    assert_eq!(reranked, vec![sock(1), sock(2)]);
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn eviction_removes_oldest() {
|
fn eviction_removes_oldest() {
|
||||||
let mut cache = SrttCache::new(true);
|
let mut cache = SrttCache::new(true);
|
||||||
|
|||||||
12
src/stats.rs
12
src/stats.rs
@@ -4,6 +4,7 @@ pub struct ServerStats {
|
|||||||
queries_total: u64,
|
queries_total: u64,
|
||||||
queries_forwarded: u64,
|
queries_forwarded: u64,
|
||||||
queries_recursive: u64,
|
queries_recursive: u64,
|
||||||
|
queries_coalesced: u64,
|
||||||
queries_cached: u64,
|
queries_cached: u64,
|
||||||
queries_blocked: u64,
|
queries_blocked: u64,
|
||||||
queries_local: u64,
|
queries_local: u64,
|
||||||
@@ -18,6 +19,7 @@ pub enum QueryPath {
|
|||||||
Cached,
|
Cached,
|
||||||
Forwarded,
|
Forwarded,
|
||||||
Recursive,
|
Recursive,
|
||||||
|
Coalesced,
|
||||||
Blocked,
|
Blocked,
|
||||||
Overridden,
|
Overridden,
|
||||||
UpstreamError,
|
UpstreamError,
|
||||||
@@ -30,6 +32,7 @@ impl QueryPath {
|
|||||||
QueryPath::Cached => "CACHED",
|
QueryPath::Cached => "CACHED",
|
||||||
QueryPath::Forwarded => "FORWARD",
|
QueryPath::Forwarded => "FORWARD",
|
||||||
QueryPath::Recursive => "RECURSIVE",
|
QueryPath::Recursive => "RECURSIVE",
|
||||||
|
QueryPath::Coalesced => "COALESCED",
|
||||||
QueryPath::Blocked => "BLOCKED",
|
QueryPath::Blocked => "BLOCKED",
|
||||||
QueryPath::Overridden => "OVERRIDE",
|
QueryPath::Overridden => "OVERRIDE",
|
||||||
QueryPath::UpstreamError => "SERVFAIL",
|
QueryPath::UpstreamError => "SERVFAIL",
|
||||||
@@ -45,6 +48,8 @@ impl QueryPath {
|
|||||||
Some(QueryPath::Forwarded)
|
Some(QueryPath::Forwarded)
|
||||||
} else if s.eq_ignore_ascii_case("RECURSIVE") {
|
} else if s.eq_ignore_ascii_case("RECURSIVE") {
|
||||||
Some(QueryPath::Recursive)
|
Some(QueryPath::Recursive)
|
||||||
|
} else if s.eq_ignore_ascii_case("COALESCED") {
|
||||||
|
Some(QueryPath::Coalesced)
|
||||||
} else if s.eq_ignore_ascii_case("BLOCKED") {
|
} else if s.eq_ignore_ascii_case("BLOCKED") {
|
||||||
Some(QueryPath::Blocked)
|
Some(QueryPath::Blocked)
|
||||||
} else if s.eq_ignore_ascii_case("OVERRIDE") {
|
} else if s.eq_ignore_ascii_case("OVERRIDE") {
|
||||||
@@ -69,6 +74,7 @@ impl ServerStats {
|
|||||||
queries_total: 0,
|
queries_total: 0,
|
||||||
queries_forwarded: 0,
|
queries_forwarded: 0,
|
||||||
queries_recursive: 0,
|
queries_recursive: 0,
|
||||||
|
queries_coalesced: 0,
|
||||||
queries_cached: 0,
|
queries_cached: 0,
|
||||||
queries_blocked: 0,
|
queries_blocked: 0,
|
||||||
queries_local: 0,
|
queries_local: 0,
|
||||||
@@ -85,6 +91,7 @@ impl ServerStats {
|
|||||||
QueryPath::Cached => self.queries_cached += 1,
|
QueryPath::Cached => self.queries_cached += 1,
|
||||||
QueryPath::Forwarded => self.queries_forwarded += 1,
|
QueryPath::Forwarded => self.queries_forwarded += 1,
|
||||||
QueryPath::Recursive => self.queries_recursive += 1,
|
QueryPath::Recursive => self.queries_recursive += 1,
|
||||||
|
QueryPath::Coalesced => self.queries_coalesced += 1,
|
||||||
QueryPath::Blocked => self.queries_blocked += 1,
|
QueryPath::Blocked => self.queries_blocked += 1,
|
||||||
QueryPath::Overridden => self.queries_overridden += 1,
|
QueryPath::Overridden => self.queries_overridden += 1,
|
||||||
QueryPath::UpstreamError => self.upstream_errors += 1,
|
QueryPath::UpstreamError => self.upstream_errors += 1,
|
||||||
@@ -106,6 +113,7 @@ impl ServerStats {
|
|||||||
total: self.queries_total,
|
total: self.queries_total,
|
||||||
forwarded: self.queries_forwarded,
|
forwarded: self.queries_forwarded,
|
||||||
recursive: self.queries_recursive,
|
recursive: self.queries_recursive,
|
||||||
|
coalesced: self.queries_coalesced,
|
||||||
cached: self.queries_cached,
|
cached: self.queries_cached,
|
||||||
local: self.queries_local,
|
local: self.queries_local,
|
||||||
overridden: self.queries_overridden,
|
overridden: self.queries_overridden,
|
||||||
@@ -121,11 +129,12 @@ impl ServerStats {
|
|||||||
let secs = uptime.as_secs() % 60;
|
let secs = uptime.as_secs() % 60;
|
||||||
|
|
||||||
log::info!(
|
log::info!(
|
||||||
"STATS | uptime {}h{}m{}s | total {} | fwd {} | recursive {} | cached {} | local {} | override {} | blocked {} | errors {}",
|
"STATS | uptime {}h{}m{}s | total {} | fwd {} | recursive {} | coalesced {} | cached {} | local {} | override {} | blocked {} | errors {}",
|
||||||
hours, mins, secs,
|
hours, mins, secs,
|
||||||
self.queries_total,
|
self.queries_total,
|
||||||
self.queries_forwarded,
|
self.queries_forwarded,
|
||||||
self.queries_recursive,
|
self.queries_recursive,
|
||||||
|
self.queries_coalesced,
|
||||||
self.queries_cached,
|
self.queries_cached,
|
||||||
self.queries_local,
|
self.queries_local,
|
||||||
self.queries_overridden,
|
self.queries_overridden,
|
||||||
@@ -140,6 +149,7 @@ pub struct StatsSnapshot {
|
|||||||
pub total: u64,
|
pub total: u64,
|
||||||
pub forwarded: u64,
|
pub forwarded: u64,
|
||||||
pub recursive: u64,
|
pub recursive: u64,
|
||||||
|
pub coalesced: u64,
|
||||||
pub cached: u64,
|
pub cached: u64,
|
||||||
pub local: u64,
|
pub local: u64,
|
||||||
pub overridden: u64,
|
pub overridden: u64,
|
||||||
|
|||||||
Reference in New Issue
Block a user