feat: in-flight query coalescing with COALESCED path (#20)

* feat: in-flight query coalescing for recursive resolver

When multiple queries for the same (domain, qtype) arrive concurrently
and all miss the cache, only the first triggers recursive resolution.
Subsequent queries wait on a broadcast channel for the result.

Prevents thundering herd where N concurrent cache misses each
independently walk the full NS chain, compounding timeouts.

Uses InflightGuard (Drop impl) to guarantee map cleanup on
panic/cancellation — prevents permanent SERVFAIL poisoning.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* style: add InflightMap type alias for clippy

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* feat: add COALESCED query path and coalescing tests

Followers in the inflight coalescing path now log as COALESCED instead
of RECURSIVE, making it visible in the dashboard when queries were
deduplicated vs independently resolved. Adds 10 tests covering
InflightGuard cleanup, broadcast mechanics, and concurrent handle_query
coalescing through a mock TCP DNS server.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* style: cargo fmt

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* refactor: extract acquire_inflight, rewrite tests against real code

Move Disposition enum and inflight acquisition logic into a standalone
acquire_inflight() function. Rewrite 4 tests that were exercising tokio
primitives to call the real coalescing code path instead.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit was merged in pull request #20.
This commit is contained in:
Razvan Dimescu
2026-03-29 10:36:02 +03:00
committed by GitHub
parent 87c321f3d4
commit 7510c8e068
8 changed files with 586 additions and 23 deletions

View File

@@ -10,7 +10,7 @@ keywords = ["dns", "dns-server", "ad-blocking", "reverse-proxy", "developer-tool
categories = ["network-programming", "development-tools"] categories = ["network-programming", "development-tools"]
[dependencies] [dependencies]
tokio = { version = "1", features = ["rt-multi-thread", "macros", "net", "time"] } tokio = { version = "1", features = ["rt-multi-thread", "macros", "net", "time", "sync"] }
axum = "0.8" axum = "0.8"
serde = { version = "1", features = ["derive"] } serde = { version = "1", features = ["derive"] }
serde_json = "1" serde_json = "1"

View File

@@ -285,6 +285,7 @@ body {
.path-tag.OVERRIDE { background: rgba(82, 122, 82, 0.12); color: var(--emerald); } .path-tag.OVERRIDE { background: rgba(82, 122, 82, 0.12); color: var(--emerald); }
.path-tag.SERVFAIL { background: rgba(181, 68, 58, 0.12); color: var(--rose); } .path-tag.SERVFAIL { background: rgba(181, 68, 58, 0.12); color: var(--rose); }
.path-tag.BLOCKED { background: rgba(163, 152, 136, 0.15); color: var(--text-dim); } .path-tag.BLOCKED { background: rgba(163, 152, 136, 0.15); color: var(--text-dim); }
.path-tag.COALESCED { background: rgba(138, 104, 158, 0.12); color: var(--violet-dim); }
/* Sidebar panels */ /* Sidebar panels */
.sidebar { .sidebar {
@@ -547,6 +548,8 @@ body {
<select id="logFilterPath" onchange="applyLogFilter()" <select id="logFilterPath" onchange="applyLogFilter()"
style="font-family:var(--font-mono);font-size:0.7rem;padding:0.25rem 0.4rem;border:1px solid var(--border);border-radius:4px;background:var(--bg-surface);color:var(--text-secondary);outline:none;"> style="font-family:var(--font-mono);font-size:0.7rem;padding:0.25rem 0.4rem;border:1px solid var(--border);border-radius:4px;background:var(--bg-surface);color:var(--text-secondary);outline:none;">
<option value="">all paths</option> <option value="">all paths</option>
<option value="RECURSIVE">recursive</option>
<option value="COALESCED">coalesced</option>
<option value="FORWARD">forward</option> <option value="FORWARD">forward</option>
<option value="CACHED">cached</option> <option value="CACHED">cached</option>
<option value="BLOCKED">blocked</option> <option value="BLOCKED">blocked</option>

View File

@@ -182,6 +182,7 @@ struct QueriesStats {
total: u64, total: u64,
forwarded: u64, forwarded: u64,
recursive: u64, recursive: u64,
coalesced: u64,
cached: u64, cached: u64,
local: u64, local: u64,
overridden: u64, overridden: u64,
@@ -499,6 +500,7 @@ async fn stats(State(ctx): State<Arc<ServerCtx>>) -> Json<StatsResponse> {
total: snap.total, total: snap.total,
forwarded: snap.forwarded, forwarded: snap.forwarded,
recursive: snap.recursive, recursive: snap.recursive,
coalesced: snap.coalesced,
cached: snap.cached, cached: snap.cached,
local: snap.local, local: snap.local,
overridden: snap.overridden, overridden: snap.overridden,
@@ -953,6 +955,7 @@ mod tests {
upstream_mode: crate::config::UpstreamMode::Forward, upstream_mode: crate::config::UpstreamMode::Forward,
root_hints: Vec::new(), root_hints: Vec::new(),
srtt: RwLock::new(crate::srtt::SrttCache::new(true)), srtt: RwLock::new(crate::srtt::SrttCache::new(true)),
inflight: Mutex::new(std::collections::HashMap::new()),
dnssec_enabled: false, dnssec_enabled: false,
dnssec_strict: false, dnssec_strict: false,
}) })

View File

@@ -1,3 +1,4 @@
use std::collections::HashMap;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::{Mutex, RwLock}; use std::sync::{Mutex, RwLock};
@@ -7,6 +8,9 @@ use arc_swap::ArcSwap;
use log::{debug, error, info, warn}; use log::{debug, error, info, warn};
use rustls::ServerConfig; use rustls::ServerConfig;
use tokio::net::UdpSocket; use tokio::net::UdpSocket;
use tokio::sync::broadcast;
/// Recursive resolutions currently in flight, keyed by `(domain, qtype)`.
/// Each value is the broadcast sender that later callers for the same key
/// subscribe to in order to receive the first caller's result.
type InflightMap = HashMap<(String, QueryType), broadcast::Sender<Option<DnsPacket>>>;
use crate::blocklist::BlocklistStore; use crate::blocklist::BlocklistStore;
use crate::buffer::BytePacketBuffer; use crate::buffer::BytePacketBuffer;
@@ -53,6 +57,7 @@ pub struct ServerCtx {
pub upstream_mode: UpstreamMode, pub upstream_mode: UpstreamMode,
pub root_hints: Vec<SocketAddr>, pub root_hints: Vec<SocketAddr>,
pub srtt: RwLock<SrttCache>, pub srtt: RwLock<SrttCache>,
pub inflight: Mutex<InflightMap>,
pub dnssec_enabled: bool, pub dnssec_enabled: bool,
pub dnssec_strict: bool, pub dnssec_strict: bool,
} }
@@ -172,7 +177,32 @@ pub async fn handle_query(
} }
(resp, QueryPath::Cached, cached_dnssec) (resp, QueryPath::Cached, cached_dnssec)
} else if ctx.upstream_mode == UpstreamMode::Recursive { } else if ctx.upstream_mode == UpstreamMode::Recursive {
match crate::recursive::resolve_recursive( let key = (qname.clone(), qtype);
let disposition = acquire_inflight(&ctx.inflight, key.clone());
match disposition {
Disposition::Follower(mut rx) => {
debug!("{} | {:?} {} | COALESCED", src_addr, qtype, qname);
match rx.recv().await {
Ok(Some(mut resp)) => {
resp.header.id = query.header.id;
(resp, QueryPath::Coalesced, DnssecStatus::Indeterminate)
}
_ => (
DnsPacket::response_from(&query, ResultCode::SERVFAIL),
QueryPath::UpstreamError,
DnssecStatus::Indeterminate,
),
}
}
Disposition::Leader(tx) => {
// Drop guard: remove inflight entry even on panic/cancellation
let guard = InflightGuard {
inflight: &ctx.inflight,
key: key.clone(),
};
let result = crate::recursive::resolve_recursive(
&qname, &qname,
qtype, qtype,
&ctx.cache, &ctx.cache,
@@ -180,10 +210,17 @@ pub async fn handle_query(
&ctx.root_hints, &ctx.root_hints,
&ctx.srtt, &ctx.srtt,
) )
.await .await;
{
Ok(resp) => (resp, QueryPath::Recursive, DnssecStatus::Indeterminate), drop(guard);
match result {
Ok(resp) => {
let _ = tx.send(Some(resp.clone()));
(resp, QueryPath::Recursive, DnssecStatus::Indeterminate)
}
Err(e) => { Err(e) => {
let _ = tx.send(None);
error!( error!(
"{} | {:?} {} | RECURSIVE ERROR | {}", "{} | {:?} {} | RECURSIVE ERROR | {}",
src_addr, qtype, qname, e src_addr, qtype, qname, e
@@ -195,6 +232,8 @@ pub async fn handle_query(
) )
} }
} }
}
}
} else { } else {
let upstream = let upstream =
match crate::system_dns::match_forwarding_rule(&qname, &ctx.forwarding_rules) { match crate::system_dns::match_forwarding_rule(&qname, &ctx.forwarding_rules) {
@@ -377,6 +416,47 @@ fn is_special_use_domain(qname: &str) -> bool {
qname == "local" || qname.ends_with(".local") qname == "local" || qname.ends_with(".local")
} }
/// Role handed to a caller by `acquire_inflight` for one `(domain, qtype)` key.
enum Disposition {
/// No entry existed for the key: the caller owns the sender that was just
/// stored in the map and is expected to publish the result on it.
Leader(broadcast::Sender<Option<DnsPacket>>),
/// An entry already existed: the caller gets a receiver subscribed to it.
Follower(broadcast::Receiver<Option<DnsPacket>>),
}
fn acquire_inflight(inflight: &Mutex<InflightMap>, key: (String, QueryType)) -> Disposition {
let mut map = inflight.lock().unwrap();
if let Some(tx) = map.get(&key) {
Disposition::Follower(tx.subscribe())
} else {
let (tx, _) = broadcast::channel::<Option<DnsPacket>>(1);
map.insert(key, tx.clone());
Disposition::Leader(tx)
}
}
/// Guard tied to one entry of an in-flight map; its `Drop` impl removes `key`
/// from `inflight`.
struct InflightGuard<'a> {
/// Map the entry is removed from when the guard is dropped.
inflight: &'a Mutex<InflightMap>,
/// Key of the entry this guard is responsible for.
key: (String, QueryType),
}
impl Drop for InflightGuard<'_> {
    /// Remove this guard's key from the in-flight map.
    ///
    /// A poisoned mutex is recovered instead of unwrapped. This destructor can
    /// run while the thread is already unwinding from a panic; panicking again
    /// here would abort the process and leave the entry in the map.
    fn drop(&mut self) {
        let mut map = self
            .inflight
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        map.remove(&self.key);
    }
}
/// Build a wire-format DNS query for `domain` / `qtype` with transaction `id`.
///
/// The packet carries a single question and has the recursion-desired flag
/// set. It is written into a scratch buffer, and the returned buffer is
/// created from the bytes that were written.
///
/// # Panics
///
/// Panics if the packet cannot be written to the buffer.
#[cfg(test)]
fn build_wire_query(id: u16, domain: &str, qtype: QueryType) -> BytePacketBuffer {
    let question = crate::question::DnsQuestion::new(domain.to_string(), qtype);

    let mut packet = DnsPacket::new();
    packet.header.id = id;
    packet.header.recursion_desired = true;
    packet.header.questions = 1;
    packet.questions.push(question);

    let mut scratch = BytePacketBuffer::new();
    packet.write(&mut scratch).unwrap();
    BytePacketBuffer::from_bytes(scratch.filled())
}
fn special_use_response(query: &DnsPacket, qname: &str, qtype: QueryType) -> DnsPacket { fn special_use_response(query: &DnsPacket, qname: &str, qtype: QueryType) -> DnsPacket {
use std::net::{Ipv4Addr, Ipv6Addr}; use std::net::{Ipv4Addr, Ipv6Addr};
if qname == "ipv4only.arpa" { if qname == "ipv4only.arpa" {
@@ -410,3 +490,368 @@ fn special_use_response(query: &DnsPacket, qname: &str, qtype: QueryType) -> Dns
DnsPacket::response_from(query, ResultCode::NXDOMAIN) DnsPacket::response_from(query, ResultCode::NXDOMAIN)
} }
} }
#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashMap;
use std::net::{Ipv4Addr, SocketAddr};
use std::sync::{Arc, Mutex, RwLock};
use tokio::sync::broadcast;
// ---- InflightGuard unit tests ----
#[test]
fn inflight_guard_removes_key_on_drop() {
    let key = ("example.com".to_string(), QueryType::A);
    let (sender, _) = broadcast::channel::<Option<DnsPacket>>(1);

    let inflight: Mutex<InflightMap> = Mutex::new(HashMap::new());
    inflight.lock().unwrap().insert(key.clone(), sender);
    assert_eq!(inflight.lock().unwrap().len(), 1);

    // Dropping the guard is the only thing that touches the map from here on.
    drop(InflightGuard {
        inflight: &inflight,
        key,
    });

    assert!(inflight.lock().unwrap().is_empty());
}
#[test]
fn inflight_guard_only_removes_own_key() {
    let key_a = ("a.com".to_string(), QueryType::A);
    let key_b = ("b.com".to_string(), QueryType::A);

    let inflight: Mutex<InflightMap> = Mutex::new(HashMap::new());
    for key in [&key_a, &key_b] {
        let (sender, _) = broadcast::channel::<Option<DnsPacket>>(1);
        inflight.lock().unwrap().insert(key.clone(), sender);
    }

    // Guard for a.com only; b.com must survive its drop.
    drop(InflightGuard {
        inflight: &inflight,
        key: key_a,
    });

    let remaining = inflight.lock().unwrap();
    assert_eq!(remaining.len(), 1);
    assert!(remaining.contains_key(&key_b));
}
#[test]
fn inflight_guard_same_domain_different_qtype_independent() {
    let key_a = ("example.com".to_string(), QueryType::A);
    let key_aaaa = ("example.com".to_string(), QueryType::AAAA);

    let inflight: Mutex<InflightMap> = Mutex::new(HashMap::new());
    for key in [&key_a, &key_aaaa] {
        let (sender, _) = broadcast::channel::<Option<DnsPacket>>(1);
        inflight.lock().unwrap().insert(key.clone(), sender);
    }

    // Same domain, different record type: dropping the A guard must leave AAAA alone.
    drop(InflightGuard {
        inflight: &inflight,
        key: key_a,
    });

    let remaining = inflight.lock().unwrap();
    assert_eq!(remaining.len(), 1);
    assert!(remaining.contains_key(&key_aaaa));
}
// ---- Coalescing disposition tests (via acquire_inflight) ----
#[test]
fn first_caller_becomes_leader() {
    let inflight: Mutex<InflightMap> = Mutex::new(HashMap::new());

    let disposition = acquire_inflight(&inflight, ("test.com".to_string(), QueryType::A));

    // An empty map makes the caller the leader and records exactly one entry.
    assert!(matches!(disposition, Disposition::Leader(_)));
    assert_eq!(inflight.lock().unwrap().len(), 1);
}
#[test]
fn second_caller_becomes_follower() {
    let inflight: Mutex<InflightMap> = Mutex::new(HashMap::new());
    let key = ("test.com".to_string(), QueryType::A);

    // The first disposition is held for the whole test so the entry stays put.
    let _first = acquire_inflight(&inflight, key.clone());
    let second = acquire_inflight(&inflight, key);

    assert!(matches!(second, Disposition::Follower(_)));
    // A follower only subscribes; it must not add a second map entry.
    assert_eq!(inflight.lock().unwrap().len(), 1);
}
#[tokio::test]
async fn leader_broadcast_reaches_follower() {
    let inflight: Mutex<InflightMap> = Mutex::new(HashMap::new());
    let key = ("test.com".to_string(), QueryType::A);

    let Disposition::Leader(tx) = acquire_inflight(&inflight, key.clone()) else {
        panic!("expected leader");
    };
    let Disposition::Follower(mut rx) = acquire_inflight(&inflight, key) else {
        panic!("expected follower");
    };

    let mut answer = DnsPacket::new();
    answer.header.id = 42;
    answer.answers.push(DnsRecord::A {
        domain: "test.com".into(),
        addr: Ipv4Addr::new(1, 2, 3, 4),
        ttl: 300,
    });
    let _ = tx.send(Some(answer));

    // The follower must see the very packet the leader published.
    let delivered = rx.recv().await.unwrap().unwrap();
    assert_eq!(delivered.header.id, 42);
    assert_eq!(delivered.answers.len(), 1);
}
#[tokio::test]
async fn leader_none_signals_failure_to_follower() {
    let inflight: Mutex<InflightMap> = Mutex::new(HashMap::new());
    let key = ("test.com".to_string(), QueryType::A);

    let Disposition::Leader(tx) = acquire_inflight(&inflight, key.clone()) else {
        panic!("expected leader");
    };
    let Disposition::Follower(mut rx) = acquire_inflight(&inflight, key) else {
        panic!("expected follower");
    };

    // `None` is the leader's way of reporting a failed resolution.
    let _ = tx.send(None);

    let delivered = rx.recv().await.unwrap();
    assert!(delivered.is_none());
}
#[tokio::test]
async fn multiple_followers_all_receive_via_acquire() {
let map: Mutex<InflightMap> = Mutex::new(HashMap::new());
let key = ("multi.com".to_string(), QueryType::A);
let leader = acquire_inflight(&map, key.clone());
let f1 = acquire_inflight(&map, key.clone());
let f2 = acquire_inflight(&map, key.clone());
let f3 = acquire_inflight(&map, key);
let tx = match leader {
Disposition::Leader(tx) => tx,
_ => panic!("expected leader"),
};
let mut resp = DnsPacket::new();
resp.answers.push(DnsRecord::A {
domain: "multi.com".into(),
addr: Ipv4Addr::new(10, 0, 0, 1),
ttl: 60,
});
let _ = tx.send(Some(resp));
for f in [f1, f2, f3] {
let mut rx = match f {
Disposition::Follower(rx) => rx,
_ => panic!("expected follower"),
};
let r = rx.recv().await.unwrap().unwrap();
assert_eq!(r.answers.len(), 1);
}
}
// ---- Integration: concurrent handle_query coalescing ----
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
/// Spawn a mock TCP DNS server that waits `delay` before answering each query.
///
/// Returns `(addr, query_count)`: the address the listener is bound to and a
/// shared counter that is incremented once for every query the server
/// successfully parses (i.e. every query that actually reached the upstream).
///
/// Each accepted connection is served by its own task: it reads one
/// length-prefixed query, sleeps for `delay`, and replies with a NOERROR
/// response holding a single A record (`10.0.0.1`, TTL 300) for the first
/// question. Any read, parse or write failure silently ends that connection.
///
/// # Panics
///
/// Panics if the listener cannot be bound or its local address read.
async fn spawn_slow_dns_server(
    delay: Duration,
) -> (SocketAddr, Arc<std::sync::atomic::AtomicU32>) {
    use std::sync::atomic::{AtomicU32, Ordering};

    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();
    let count = Arc::new(AtomicU32::new(0));
    let accept_count = Arc::clone(&count);

    tokio::spawn(async move {
        // Stop accepting on the first accept error.
        while let Ok((mut stream, _)) = listener.accept().await {
            let count = Arc::clone(&accept_count);
            // `delay` is `Copy`, so every connection task captures its own copy.
            tokio::spawn(async move {
                // Each message is framed by a 2-byte big-endian length.
                let mut len_buf = [0u8; 2];
                if stream.read_exact(&mut len_buf).await.is_err() {
                    return;
                }
                let mut data = vec![0u8; usize::from(u16::from_be_bytes(len_buf))];
                if stream.read_exact(&mut data).await.is_err() {
                    return;
                }

                let mut buf = BytePacketBuffer::from_bytes(&data);
                let Ok(query) = DnsPacket::from_buffer(&mut buf) else {
                    return;
                };

                count.fetch_add(1, Ordering::Relaxed);

                // Deliberate delay to create a coalescing window.
                tokio::time::sleep(delay).await;

                let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR);
                resp.header.authoritative_answer = true;
                if let Some(q) = query.questions.first() {
                    resp.answers.push(DnsRecord::A {
                        domain: q.name.clone(),
                        addr: Ipv4Addr::new(10, 0, 0, 1),
                        ttl: 300,
                    });
                }

                let mut resp_buf = BytePacketBuffer::new();
                if resp.write(&mut resp_buf).is_err() {
                    return;
                }
                let resp_bytes = resp_buf.filled();
                // Refuse to send a frame whose length does not fit the prefix
                // instead of silently truncating it.
                let Ok(resp_len) = u16::try_from(resp_bytes.len()) else {
                    return;
                };

                let mut out = Vec::with_capacity(2 + resp_bytes.len());
                out.extend_from_slice(&resp_len.to_be_bytes());
                out.extend_from_slice(resp_bytes);
                let _ = stream.write_all(&out).await;
            });
        }
    });

    (addr, count)
}
/// Build a `ServerCtx` for tests that runs in recursive mode with `root_hint`
/// as its only root hint.
///
/// The context owns a UDP socket bound to an ephemeral localhost port, starts
/// with an empty in-flight map, and has DNSSEC disabled. All other fields are
/// filled with fixed test values.
async fn test_recursive_ctx(root_hint: SocketAddr) -> Arc<ServerCtx> {
let socket = tokio::net::UdpSocket::bind("127.0.0.1:0").await.unwrap();
Arc::new(ServerCtx {
socket,
zone_map: HashMap::new(),
cache: RwLock::new(crate::cache::DnsCache::new(100, 60, 86400)),
stats: Mutex::new(crate::stats::ServerStats::new()),
overrides: RwLock::new(crate::override_store::OverrideStore::new()),
blocklist: RwLock::new(crate::blocklist::BlocklistStore::new()),
query_log: Mutex::new(crate::query_log::QueryLog::new(100)),
services: Mutex::new(crate::service_store::ServiceStore::new()),
lan_peers: Mutex::new(crate::lan::PeerStore::new(90)),
forwarding_rules: Vec::new(),
upstream: Mutex::new(crate::forward::Upstream::Udp(
"127.0.0.1:53".parse().unwrap(),
)),
upstream_auto: false,
upstream_port: 53,
lan_ip: Mutex::new(Ipv4Addr::LOCALHOST),
timeout: Duration::from_secs(3),
proxy_tld: "numa".to_string(),
proxy_tld_suffix: ".numa".to_string(),
lan_enabled: false,
config_path: "/tmp/test-numa.toml".to_string(),
config_found: false,
config_dir: std::path::PathBuf::from("/tmp"),
data_dir: std::path::PathBuf::from("/tmp"),
tls_config: None,
// Recursive mode with the caller-supplied address as the sole root hint.
upstream_mode: crate::config::UpstreamMode::Recursive,
root_hints: vec![root_hint],
srtt: RwLock::new(crate::srtt::SrttCache::new(true)),
// No resolutions in flight at start.
inflight: Mutex::new(HashMap::new()),
dnssec_enabled: false,
dnssec_strict: false,
})
}
#[tokio::test]
async fn concurrent_queries_coalesce_to_single_resolution() {
    // The mock upstream only speaks TCP, so take UDP out of the picture.
    // NOTE(review): UDP_DISABLED is a process-wide static shared with other
    // tests — confirm that parallel test runs cannot interfere through it.
    crate::recursive::UDP_DISABLED.store(true, std::sync::atomic::Ordering::Release);

    let (upstream_addr, upstream_hits) = spawn_slow_dns_server(Duration::from_millis(200)).await;
    let ctx = test_recursive_ctx(upstream_addr).await;
    let src: SocketAddr = "127.0.0.1:9999".parse().unwrap();

    // Launch 5 queries for the same (domain, A) before awaiting any of them.
    let tasks: Vec<_> = (0..5u16)
        .map(|i| {
            let ctx = ctx.clone();
            let buf = build_wire_query(100 + i, "coalesce-test.example.com", QueryType::A);
            tokio::spawn(async move { handle_query(buf, src, &ctx).await })
        })
        .collect();
    for task in tasks {
        task.await.unwrap().unwrap();
    }

    // Exactly one resolution should have reached the upstream server.
    let actual = upstream_hits.load(std::sync::atomic::Ordering::Relaxed);
    assert_eq!(actual, 1, "expected 1 upstream query, got {}", actual);

    // Nothing may be left in flight once every query has completed.
    assert!(ctx.inflight.lock().unwrap().is_empty());

    crate::recursive::reset_udp_state();
}
#[tokio::test]
async fn different_qtypes_not_coalesced() {
    // NOTE(review): UDP_DISABLED is a process-wide static shared with other
    // tests — confirm that parallel test runs cannot interfere through it.
    crate::recursive::UDP_DISABLED.store(true, std::sync::atomic::Ordering::Release);

    let (upstream_addr, upstream_hits) = spawn_slow_dns_server(Duration::from_millis(100)).await;
    let ctx = test_recursive_ctx(upstream_addr).await;
    let src: SocketAddr = "127.0.0.1:9999".parse().unwrap();

    // A and AAAA for the same name use different keys, so both must resolve.
    let tasks: Vec<_> = [(200u16, QueryType::A), (201u16, QueryType::AAAA)]
        .into_iter()
        .map(|(id, qtype)| {
            let ctx = ctx.clone();
            let buf = build_wire_query(id, "different-qt.example.com", qtype);
            tokio::spawn(async move { handle_query(buf, src, &ctx).await })
        })
        .collect();
    for task in tasks {
        task.await.unwrap().unwrap();
    }

    let actual = upstream_hits.load(std::sync::atomic::Ordering::Relaxed);
    assert!(
        actual >= 2,
        "A and AAAA should resolve independently, got {}",
        actual
    );
    assert!(ctx.inflight.lock().unwrap().is_empty());

    crate::recursive::reset_udp_state();
}
#[tokio::test]
async fn inflight_map_cleaned_after_upstream_error() {
    // The only root hint is 127.0.0.1:1; this test starts no server for it.
    let unreachable: SocketAddr = "127.0.0.1:1".parse().unwrap();
    let ctx = test_recursive_ctx(unreachable).await;
    let src: SocketAddr = "127.0.0.1:9999".parse().unwrap();

    // The outcome of the query itself is deliberately ignored.
    let _ = handle_query(
        build_wire_query(300, "will-fail.example.com", QueryType::A),
        src,
        &ctx,
    )
    .await;

    // Whatever happened to the query, no in-flight entry may be left behind.
    assert!(ctx.inflight.lock().unwrap().is_empty());
}
}

View File

@@ -202,6 +202,7 @@ async fn main() -> numa::Result<()> {
upstream_mode: config.upstream.mode, upstream_mode: config.upstream.mode,
root_hints: numa::recursive::parse_root_hints(&config.upstream.root_hints), root_hints: numa::recursive::parse_root_hints(&config.upstream.root_hints),
srtt: std::sync::RwLock::new(numa::srtt::SrttCache::new(config.upstream.srtt)), srtt: std::sync::RwLock::new(numa::srtt::SrttCache::new(config.upstream.srtt)),
inflight: std::sync::Mutex::new(std::collections::HashMap::new()),
dnssec_enabled: config.dnssec.enabled, dnssec_enabled: config.dnssec.enabled,
dnssec_strict: config.dnssec.strict, dnssec_strict: config.dnssec.strict,
}); });

View File

@@ -21,7 +21,8 @@ const UDP_FAIL_THRESHOLD: u8 = 3;
static QUERY_ID: AtomicU16 = AtomicU16::new(1); static QUERY_ID: AtomicU16 = AtomicU16::new(1);
static UDP_FAILURES: std::sync::atomic::AtomicU8 = std::sync::atomic::AtomicU8::new(0); static UDP_FAILURES: std::sync::atomic::AtomicU8 = std::sync::atomic::AtomicU8::new(0);
static UDP_DISABLED: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false); pub(crate) static UDP_DISABLED: std::sync::atomic::AtomicBool =
std::sync::atomic::AtomicBool::new(false);
fn next_id() -> u16 { fn next_id() -> u16 {
QUERY_ID.fetch_add(1, Ordering::Relaxed) QUERY_ID.fetch_add(1, Ordering::Relaxed)

View File

@@ -108,6 +108,13 @@ impl SrttCache {
self.entries.is_empty() self.entries.is_empty()
} }
/// Test-only hook: overwrite the `updated_at` timestamp of the entry for `ip`.
/// Does nothing when there is no entry for `ip`.
#[cfg(test)]
fn set_updated_at(&mut self, ip: IpAddr, at: Instant) {
    let Some(entry) = self.entries.get_mut(&ip) else {
        return;
    };
    entry.updated_at = at;
}
fn maybe_evict(&mut self) { fn maybe_evict(&mut self) {
if self.entries.len() < MAX_ENTRIES { if self.entries.len() < MAX_ENTRIES {
return; return;
@@ -203,6 +210,99 @@ mod tests {
assert_eq!(addrs, original); assert_eq!(addrs, original);
} }
/// Return the instant that lies `secs` seconds before now.
///
/// # Panics
///
/// Panics if that point in time cannot be represented as an `Instant`.
fn age(secs: u64) -> Instant {
    let offset = std::time::Duration::from_secs(secs);
    let now = Instant::now();
    now - offset
}
/// Cache in which ip(1) has recorded `FAILURE_PENALTY_MS` thirty times in a
/// row, so its smoothed RTT sits at (or very near) the failure penalty.
fn saturated_penalty_cache() -> SrttCache {
    let mut cache = SrttCache::new(true);
    (0..30).for_each(|_| {
        cache.record_rtt(ip(1), FAILURE_PENALTY_MS, false);
    });
    cache
}
#[test]
fn no_decay_within_threshold() {
    let mut cache = SrttCache::new(true);
    cache.record_rtt(ip(1), 5000, false);
    cache.set_updated_at(ip(1), age(DECAY_AFTER_SECS));

    // At the threshold age the reported value must equal the stored one.
    let reported = cache.get(ip(1));
    assert_eq!(reported, cache.entries[&ip(1)].srtt_ms);
}
#[test]
fn one_decay_period() {
    let mut cache = saturated_penalty_cache();
    let stored = cache.entries[&ip(1)].srtt_ms;

    cache.set_updated_at(ip(1), age(DECAY_AFTER_SECS + 1));

    // One period moves the value halfway toward INITIAL_SRTT_MS.
    assert_eq!(cache.get(ip(1)), (stored + INITIAL_SRTT_MS) / 2);
}
#[test]
fn multiple_decay_periods() {
    let mut cache = saturated_penalty_cache();
    let stored = cache.entries[&ip(1)].srtt_ms;

    cache.set_updated_at(ip(1), age(DECAY_AFTER_SECS * 4 + 1));

    // Four periods: halve the distance to INITIAL_SRTT_MS four times over.
    let expected = (0..4).fold(stored, |srtt, _| (srtt + INITIAL_SRTT_MS) / 2);
    assert_eq!(cache.get(ip(1)), expected);
}
#[test]
fn decay_caps_at_8_periods() {
    // Reported SRTT of ip(1) in a saturated cache last updated `secs` ago.
    let reported_after = |secs: u64| {
        let mut cache = saturated_penalty_cache();
        cache.set_updated_at(ip(1), age(secs));
        cache.get(ip(1))
    };

    // 9 periods and 100 periods should produce the same result (capped at 8).
    let nine_periods = reported_after(DECAY_AFTER_SECS * 9 + 1);
    let hundred_periods = reported_after(DECAY_AFTER_SECS * 100);
    assert_eq!(nine_periods, hundred_periods);
}
#[test]
fn decay_converges_toward_initial() {
    let mut cache = saturated_penalty_cache();
    cache.set_updated_at(ip(1), age(DECAY_AFTER_SECS * 100));

    // After a very long idle time the value should be close to INITIAL_SRTT_MS.
    let decayed = cache.get(ip(1));
    let diff = decayed.abs_diff(INITIAL_SRTT_MS);
    assert!(
        diff < 25,
        "expected near INITIAL_SRTT_MS, got {} (diff={})",
        decayed,
        diff
    );
}
#[test]
fn record_rtt_applies_decay_before_ewma() {
    let mut cache = saturated_penalty_cache();
    cache.set_updated_at(ip(1), age(DECAY_AFTER_SECS * 8));

    // A fresh, fast sample arrives for an entry that has been idle for a long time.
    cache.record_rtt(ip(1), 50, false);

    // Without decay-before-EWMA, result would be ~(5000*7+50)/8 ≈ 4381
    let srtt = cache.get(ip(1));
    assert!(srtt < 500, "expected decay before EWMA, got srtt={}", srtt);
}
#[test]
fn decay_reranks_stale_failures() {
    let mut cache = saturated_penalty_cache();
    (0..30).for_each(|_| {
        cache.record_rtt(ip(2), 300, false);
    });

    // Fresh failures: server 2 (300ms) ranks ahead of the penalised server 1.
    let mut before_decay = vec![sock(1), sock(2)];
    cache.sort_by_rtt(&mut before_decay);
    assert_eq!(before_decay, vec![sock(2), sock(1)]);

    // Age server 1 so it decays toward INITIAL (200ms) — below server 2's 300ms
    cache.set_updated_at(ip(1), age(DECAY_AFTER_SECS * 100));

    let mut after_decay = vec![sock(1), sock(2)];
    cache.sort_by_rtt(&mut after_decay);
    assert_eq!(after_decay, vec![sock(1), sock(2)]);
}
#[test] #[test]
fn eviction_removes_oldest() { fn eviction_removes_oldest() {
let mut cache = SrttCache::new(true); let mut cache = SrttCache::new(true);

View File

@@ -4,6 +4,7 @@ pub struct ServerStats {
queries_total: u64, queries_total: u64,
queries_forwarded: u64, queries_forwarded: u64,
queries_recursive: u64, queries_recursive: u64,
queries_coalesced: u64,
queries_cached: u64, queries_cached: u64,
queries_blocked: u64, queries_blocked: u64,
queries_local: u64, queries_local: u64,
@@ -18,6 +19,7 @@ pub enum QueryPath {
Cached, Cached,
Forwarded, Forwarded,
Recursive, Recursive,
Coalesced,
Blocked, Blocked,
Overridden, Overridden,
UpstreamError, UpstreamError,
@@ -30,6 +32,7 @@ impl QueryPath {
QueryPath::Cached => "CACHED", QueryPath::Cached => "CACHED",
QueryPath::Forwarded => "FORWARD", QueryPath::Forwarded => "FORWARD",
QueryPath::Recursive => "RECURSIVE", QueryPath::Recursive => "RECURSIVE",
QueryPath::Coalesced => "COALESCED",
QueryPath::Blocked => "BLOCKED", QueryPath::Blocked => "BLOCKED",
QueryPath::Overridden => "OVERRIDE", QueryPath::Overridden => "OVERRIDE",
QueryPath::UpstreamError => "SERVFAIL", QueryPath::UpstreamError => "SERVFAIL",
@@ -45,6 +48,8 @@ impl QueryPath {
Some(QueryPath::Forwarded) Some(QueryPath::Forwarded)
} else if s.eq_ignore_ascii_case("RECURSIVE") { } else if s.eq_ignore_ascii_case("RECURSIVE") {
Some(QueryPath::Recursive) Some(QueryPath::Recursive)
} else if s.eq_ignore_ascii_case("COALESCED") {
Some(QueryPath::Coalesced)
} else if s.eq_ignore_ascii_case("BLOCKED") { } else if s.eq_ignore_ascii_case("BLOCKED") {
Some(QueryPath::Blocked) Some(QueryPath::Blocked)
} else if s.eq_ignore_ascii_case("OVERRIDE") { } else if s.eq_ignore_ascii_case("OVERRIDE") {
@@ -69,6 +74,7 @@ impl ServerStats {
queries_total: 0, queries_total: 0,
queries_forwarded: 0, queries_forwarded: 0,
queries_recursive: 0, queries_recursive: 0,
queries_coalesced: 0,
queries_cached: 0, queries_cached: 0,
queries_blocked: 0, queries_blocked: 0,
queries_local: 0, queries_local: 0,
@@ -85,6 +91,7 @@ impl ServerStats {
QueryPath::Cached => self.queries_cached += 1, QueryPath::Cached => self.queries_cached += 1,
QueryPath::Forwarded => self.queries_forwarded += 1, QueryPath::Forwarded => self.queries_forwarded += 1,
QueryPath::Recursive => self.queries_recursive += 1, QueryPath::Recursive => self.queries_recursive += 1,
QueryPath::Coalesced => self.queries_coalesced += 1,
QueryPath::Blocked => self.queries_blocked += 1, QueryPath::Blocked => self.queries_blocked += 1,
QueryPath::Overridden => self.queries_overridden += 1, QueryPath::Overridden => self.queries_overridden += 1,
QueryPath::UpstreamError => self.upstream_errors += 1, QueryPath::UpstreamError => self.upstream_errors += 1,
@@ -106,6 +113,7 @@ impl ServerStats {
total: self.queries_total, total: self.queries_total,
forwarded: self.queries_forwarded, forwarded: self.queries_forwarded,
recursive: self.queries_recursive, recursive: self.queries_recursive,
coalesced: self.queries_coalesced,
cached: self.queries_cached, cached: self.queries_cached,
local: self.queries_local, local: self.queries_local,
overridden: self.queries_overridden, overridden: self.queries_overridden,
@@ -121,11 +129,12 @@ impl ServerStats {
let secs = uptime.as_secs() % 60; let secs = uptime.as_secs() % 60;
log::info!( log::info!(
"STATS | uptime {}h{}m{}s | total {} | fwd {} | recursive {} | cached {} | local {} | override {} | blocked {} | errors {}", "STATS | uptime {}h{}m{}s | total {} | fwd {} | recursive {} | coalesced {} | cached {} | local {} | override {} | blocked {} | errors {}",
hours, mins, secs, hours, mins, secs,
self.queries_total, self.queries_total,
self.queries_forwarded, self.queries_forwarded,
self.queries_recursive, self.queries_recursive,
self.queries_coalesced,
self.queries_cached, self.queries_cached,
self.queries_local, self.queries_local,
self.queries_overridden, self.queries_overridden,
@@ -140,6 +149,7 @@ pub struct StatsSnapshot {
pub total: u64, pub total: u64,
pub forwarded: u64, pub forwarded: u64,
pub recursive: u64, pub recursive: u64,
pub coalesced: u64,
pub cached: u64, pub cached: u64,
pub local: u64, pub local: u64,
pub overridden: u64, pub overridden: u64,