feat: in-flight query coalescing with COALESCED path #20
@@ -10,7 +10,7 @@ keywords = ["dns", "dns-server", "ad-blocking", "reverse-proxy", "developer-tool
|
||||
categories = ["network-programming", "development-tools"]
|
||||
|
||||
[dependencies]
|
||||
tokio = { version = "1", features = ["rt-multi-thread", "macros", "net", "time"] }
|
||||
tokio = { version = "1", features = ["rt-multi-thread", "macros", "net", "time", "sync"] }
|
||||
axum = "0.8"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
|
||||
@@ -285,6 +285,7 @@ body {
|
||||
.path-tag.OVERRIDE { background: rgba(82, 122, 82, 0.12); color: var(--emerald); }
|
||||
.path-tag.SERVFAIL { background: rgba(181, 68, 58, 0.12); color: var(--rose); }
|
||||
.path-tag.BLOCKED { background: rgba(163, 152, 136, 0.15); color: var(--text-dim); }
|
||||
.path-tag.COALESCED { background: rgba(138, 104, 158, 0.12); color: var(--violet-dim); }
|
||||
|
||||
/* Sidebar panels */
|
||||
.sidebar {
|
||||
@@ -547,6 +548,8 @@ body {
|
||||
<select id="logFilterPath" onchange="applyLogFilter()"
|
||||
style="font-family:var(--font-mono);font-size:0.7rem;padding:0.25rem 0.4rem;border:1px solid var(--border);border-radius:4px;background:var(--bg-surface);color:var(--text-secondary);outline:none;">
|
||||
<option value="">all paths</option>
|
||||
<option value="RECURSIVE">recursive</option>
|
||||
<option value="COALESCED">coalesced</option>
|
||||
<option value="FORWARD">forward</option>
|
||||
<option value="CACHED">cached</option>
|
||||
<option value="BLOCKED">blocked</option>
|
||||
|
||||
@@ -182,6 +182,7 @@ struct QueriesStats {
|
||||
total: u64,
|
||||
forwarded: u64,
|
||||
recursive: u64,
|
||||
coalesced: u64,
|
||||
cached: u64,
|
||||
local: u64,
|
||||
overridden: u64,
|
||||
@@ -499,6 +500,7 @@ async fn stats(State(ctx): State<Arc<ServerCtx>>) -> Json<StatsResponse> {
|
||||
total: snap.total,
|
||||
forwarded: snap.forwarded,
|
||||
recursive: snap.recursive,
|
||||
coalesced: snap.coalesced,
|
||||
cached: snap.cached,
|
||||
local: snap.local,
|
||||
overridden: snap.overridden,
|
||||
@@ -953,6 +955,7 @@ mod tests {
|
||||
upstream_mode: crate::config::UpstreamMode::Forward,
|
||||
root_hints: Vec::new(),
|
||||
srtt: RwLock::new(crate::srtt::SrttCache::new(true)),
|
||||
inflight: Mutex::new(std::collections::HashMap::new()),
|
||||
dnssec_enabled: false,
|
||||
dnssec_strict: false,
|
||||
})
|
||||
|
||||
453
src/ctx.rs
453
src/ctx.rs
@@ -1,3 +1,4 @@
|
||||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::{Mutex, RwLock};
|
||||
@@ -7,6 +8,9 @@ use arc_swap::ArcSwap;
|
||||
use log::{debug, error, info, warn};
|
||||
use rustls::ServerConfig;
|
||||
use tokio::net::UdpSocket;
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
type InflightMap = HashMap<(String, QueryType), broadcast::Sender<Option<DnsPacket>>>;
|
||||
|
||||
use crate::blocklist::BlocklistStore;
|
||||
use crate::buffer::BytePacketBuffer;
|
||||
@@ -53,6 +57,7 @@ pub struct ServerCtx {
|
||||
pub upstream_mode: UpstreamMode,
|
||||
pub root_hints: Vec<SocketAddr>,
|
||||
pub srtt: RwLock<SrttCache>,
|
||||
pub inflight: Mutex<InflightMap>,
|
||||
pub dnssec_enabled: bool,
|
||||
pub dnssec_strict: bool,
|
||||
}
|
||||
@@ -172,7 +177,32 @@ pub async fn handle_query(
|
||||
}
|
||||
(resp, QueryPath::Cached, cached_dnssec)
|
||||
} else if ctx.upstream_mode == UpstreamMode::Recursive {
|
||||
match crate::recursive::resolve_recursive(
|
||||
let key = (qname.clone(), qtype);
|
||||
let disposition = acquire_inflight(&ctx.inflight, key.clone());
|
||||
|
||||
match disposition {
|
||||
Disposition::Follower(mut rx) => {
|
||||
debug!("{} | {:?} {} | COALESCED", src_addr, qtype, qname);
|
||||
match rx.recv().await {
|
||||
Ok(Some(mut resp)) => {
|
||||
resp.header.id = query.header.id;
|
||||
(resp, QueryPath::Coalesced, DnssecStatus::Indeterminate)
|
||||
}
|
||||
_ => (
|
||||
DnsPacket::response_from(&query, ResultCode::SERVFAIL),
|
||||
QueryPath::UpstreamError,
|
||||
DnssecStatus::Indeterminate,
|
||||
),
|
||||
}
|
||||
}
|
||||
Disposition::Leader(tx) => {
|
||||
// Drop guard: remove inflight entry even on panic/cancellation
|
||||
let guard = InflightGuard {
|
||||
inflight: &ctx.inflight,
|
||||
key: key.clone(),
|
||||
};
|
||||
|
||||
let result = crate::recursive::resolve_recursive(
|
||||
&qname,
|
||||
qtype,
|
||||
&ctx.cache,
|
||||
@@ -180,10 +210,17 @@ pub async fn handle_query(
|
||||
&ctx.root_hints,
|
||||
&ctx.srtt,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(resp) => (resp, QueryPath::Recursive, DnssecStatus::Indeterminate),
|
||||
.await;
|
||||
|
||||
drop(guard);
|
||||
|
||||
match result {
|
||||
Ok(resp) => {
|
||||
let _ = tx.send(Some(resp.clone()));
|
||||
(resp, QueryPath::Recursive, DnssecStatus::Indeterminate)
|
||||
}
|
||||
Err(e) => {
|
||||
let _ = tx.send(None);
|
||||
error!(
|
||||
"{} | {:?} {} | RECURSIVE ERROR | {}",
|
||||
src_addr, qtype, qname, e
|
||||
@@ -195,6 +232,8 @@ pub async fn handle_query(
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let upstream =
|
||||
match crate::system_dns::match_forwarding_rule(&qname, &ctx.forwarding_rules) {
|
||||
@@ -377,6 +416,47 @@ fn is_special_use_domain(qname: &str) -> bool {
|
||||
qname == "local" || qname.ends_with(".local")
|
||||
}
|
||||
|
||||
/// Role assigned to a caller entering the in-flight coalescing map for a
/// given (qname, qtype) key.
enum Disposition {
    /// First caller for the key: performs the actual resolution and must
    /// broadcast the outcome (`Some(packet)` on success, `None` on failure)
    /// to any followers via this sender.
    Leader(broadcast::Sender<Option<DnsPacket>>),
    /// Subsequent caller while the same query is already in flight: waits
    /// on the leader's broadcast instead of resolving independently.
    Follower(broadcast::Receiver<Option<DnsPacket>>),
}
|
||||
|
||||
fn acquire_inflight(inflight: &Mutex<InflightMap>, key: (String, QueryType)) -> Disposition {
|
||||
let mut map = inflight.lock().unwrap();
|
||||
if let Some(tx) = map.get(&key) {
|
||||
Disposition::Follower(tx.subscribe())
|
||||
} else {
|
||||
let (tx, _) = broadcast::channel::<Option<DnsPacket>>(1);
|
||||
map.insert(key, tx.clone());
|
||||
Disposition::Leader(tx)
|
||||
}
|
||||
}
|
||||
|
||||
struct InflightGuard<'a> {
|
||||
inflight: &'a Mutex<InflightMap>,
|
||||
key: (String, QueryType),
|
||||
}
|
||||
|
||||
impl Drop for InflightGuard<'_> {
|
||||
fn drop(&mut self) {
|
||||
self.inflight.lock().unwrap().remove(&self.key);
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a wire-format DNS query packet for the given domain and type.
#[cfg(test)]
fn build_wire_query(id: u16, domain: &str, qtype: QueryType) -> BytePacketBuffer {
    // Assemble a minimal recursion-desired query with a single question.
    let mut query = DnsPacket::new();
    query.header.id = id;
    query.header.recursion_desired = true;
    query.header.questions = 1;
    let question = crate::question::DnsQuestion::new(domain.to_string(), qtype);
    query.questions.push(question);

    // Serialize, then re-wrap the written bytes as a fresh read buffer.
    let mut wire = BytePacketBuffer::new();
    query.write(&mut wire).unwrap();
    BytePacketBuffer::from_bytes(wire.filled())
}
|
||||
|
||||
fn special_use_response(query: &DnsPacket, qname: &str, qtype: QueryType) -> DnsPacket {
|
||||
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||
if qname == "ipv4only.arpa" {
|
||||
@@ -410,3 +490,368 @@ fn special_use_response(query: &DnsPacket, qname: &str, qtype: QueryType) -> Dns
|
||||
DnsPacket::response_from(query, ResultCode::NXDOMAIN)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::HashMap;
|
||||
use std::net::{Ipv4Addr, SocketAddr};
|
||||
use std::sync::{Arc, Mutex, RwLock};
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
// ---- InflightGuard unit tests ----
|
||||
|
||||
#[test]
|
||||
fn inflight_guard_removes_key_on_drop() {
|
||||
let map: Mutex<InflightMap> = Mutex::new(HashMap::new());
|
||||
let key = ("example.com".to_string(), QueryType::A);
|
||||
let (tx, _) = broadcast::channel::<Option<DnsPacket>>(1);
|
||||
map.lock().unwrap().insert(key.clone(), tx);
|
||||
|
||||
assert_eq!(map.lock().unwrap().len(), 1);
|
||||
{
|
||||
let _guard = InflightGuard {
|
||||
inflight: &map,
|
||||
key: key.clone(),
|
||||
};
|
||||
} // guard dropped here
|
||||
assert!(map.lock().unwrap().is_empty());
|
||||
}
|
||||
|
||||
    /// The guard must remove only its own key, leaving unrelated
    /// in-flight entries untouched.
    #[test]
    fn inflight_guard_only_removes_own_key() {
        let map: Mutex<InflightMap> = Mutex::new(HashMap::new());
        let key_a = ("a.com".to_string(), QueryType::A);
        let key_b = ("b.com".to_string(), QueryType::A);
        let (tx_a, _) = broadcast::channel::<Option<DnsPacket>>(1);
        let (tx_b, _) = broadcast::channel::<Option<DnsPacket>>(1);
        map.lock().unwrap().insert(key_a.clone(), tx_a);
        map.lock().unwrap().insert(key_b.clone(), tx_b);

        // Guard for key_a only; dropped at end of scope.
        {
            let _guard = InflightGuard {
                inflight: &map,
                key: key_a,
            };
        }
        let m = map.lock().unwrap();
        // key_b must survive the drop of key_a's guard.
        assert_eq!(m.len(), 1);
        assert!(m.contains_key(&key_b));
    }
|
||||
|
||||
    /// Coalescing keys are (domain, qtype) pairs: an A and an AAAA query
    /// for the same domain are independent entries, and dropping one
    /// guard must not disturb the other.
    #[test]
    fn inflight_guard_same_domain_different_qtype_independent() {
        let map: Mutex<InflightMap> = Mutex::new(HashMap::new());
        let key_a = ("example.com".to_string(), QueryType::A);
        let key_aaaa = ("example.com".to_string(), QueryType::AAAA);
        let (tx_a, _) = broadcast::channel::<Option<DnsPacket>>(1);
        let (tx_aaaa, _) = broadcast::channel::<Option<DnsPacket>>(1);
        map.lock().unwrap().insert(key_a.clone(), tx_a);
        map.lock().unwrap().insert(key_aaaa.clone(), tx_aaaa);

        // Drop only the A-record guard.
        {
            let _guard = InflightGuard {
                inflight: &map,
                key: key_a,
            };
        }
        let m = map.lock().unwrap();
        // The AAAA entry for the same domain must remain.
        assert_eq!(m.len(), 1);
        assert!(m.contains_key(&key_aaaa));
    }
|
||||
|
||||
// ---- Coalescing disposition tests (via acquire_inflight) ----
|
||||
|
||||
#[test]
|
||||
fn first_caller_becomes_leader() {
|
||||
let map: Mutex<InflightMap> = Mutex::new(HashMap::new());
|
||||
let key = ("test.com".to_string(), QueryType::A);
|
||||
|
||||
let d = acquire_inflight(&map, key.clone());
|
||||
assert!(matches!(d, Disposition::Leader(_)));
|
||||
assert_eq!(map.lock().unwrap().len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn second_caller_becomes_follower() {
|
||||
let map: Mutex<InflightMap> = Mutex::new(HashMap::new());
|
||||
let key = ("test.com".to_string(), QueryType::A);
|
||||
|
||||
let _leader = acquire_inflight(&map, key.clone());
|
||||
let follower = acquire_inflight(&map, key);
|
||||
assert!(matches!(follower, Disposition::Follower(_)));
|
||||
// Map still has exactly 1 entry — follower subscribes, doesn't insert
|
||||
assert_eq!(map.lock().unwrap().len(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn leader_broadcast_reaches_follower() {
|
||||
let map: Mutex<InflightMap> = Mutex::new(HashMap::new());
|
||||
let key = ("test.com".to_string(), QueryType::A);
|
||||
|
||||
let leader = acquire_inflight(&map, key.clone());
|
||||
let follower = acquire_inflight(&map, key);
|
||||
|
||||
let tx = match leader {
|
||||
Disposition::Leader(tx) => tx,
|
||||
_ => panic!("expected leader"),
|
||||
};
|
||||
let mut rx = match follower {
|
||||
Disposition::Follower(rx) => rx,
|
||||
_ => panic!("expected follower"),
|
||||
};
|
||||
|
||||
let mut resp = DnsPacket::new();
|
||||
resp.header.id = 42;
|
||||
resp.answers.push(DnsRecord::A {
|
||||
domain: "test.com".into(),
|
||||
addr: Ipv4Addr::new(1, 2, 3, 4),
|
||||
ttl: 300,
|
||||
});
|
||||
let _ = tx.send(Some(resp));
|
||||
|
||||
let received = rx.recv().await.unwrap().unwrap();
|
||||
assert_eq!(received.header.id, 42);
|
||||
assert_eq!(received.answers.len(), 1);
|
||||
}
|
||||
|
||||
    /// A leader broadcasting `None` (resolution failure) must arrive at
    /// the follower as `Ok(None)`, which handle_query maps to SERVFAIL.
    #[tokio::test]
    async fn leader_none_signals_failure_to_follower() {
        let map: Mutex<InflightMap> = Mutex::new(HashMap::new());
        let key = ("test.com".to_string(), QueryType::A);

        let leader = acquire_inflight(&map, key.clone());
        let follower = acquire_inflight(&map, key);

        // Unpack both dispositions; order of acquisition guarantees roles.
        let tx = match leader {
            Disposition::Leader(tx) => tx,
            _ => panic!("expected leader"),
        };
        let mut rx = match follower {
            Disposition::Follower(rx) => rx,
            _ => panic!("expected follower"),
        };

        // `None` is the leader's failure sentinel (no response to share).
        let _ = tx.send(None);
        assert!(rx.recv().await.unwrap().is_none());
    }
|
||||
|
||||
#[tokio::test]
|
||||
async fn multiple_followers_all_receive_via_acquire() {
|
||||
let map: Mutex<InflightMap> = Mutex::new(HashMap::new());
|
||||
let key = ("multi.com".to_string(), QueryType::A);
|
||||
|
||||
let leader = acquire_inflight(&map, key.clone());
|
||||
let f1 = acquire_inflight(&map, key.clone());
|
||||
let f2 = acquire_inflight(&map, key.clone());
|
||||
let f3 = acquire_inflight(&map, key);
|
||||
|
||||
let tx = match leader {
|
||||
Disposition::Leader(tx) => tx,
|
||||
_ => panic!("expected leader"),
|
||||
};
|
||||
|
||||
let mut resp = DnsPacket::new();
|
||||
resp.answers.push(DnsRecord::A {
|
||||
domain: "multi.com".into(),
|
||||
addr: Ipv4Addr::new(10, 0, 0, 1),
|
||||
ttl: 60,
|
||||
});
|
||||
let _ = tx.send(Some(resp));
|
||||
|
||||
for f in [f1, f2, f3] {
|
||||
let mut rx = match f {
|
||||
Disposition::Follower(rx) => rx,
|
||||
_ => panic!("expected follower"),
|
||||
};
|
||||
let r = rx.recv().await.unwrap().unwrap();
|
||||
assert_eq!(r.answers.len(), 1);
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Integration: concurrent handle_query coalescing ----
|
||||
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
/// Spawn a slow TCP DNS server that delays `delay` before responding.
|
||||
/// Returns (addr, query_count) where query_count is an Arc<AtomicU32>
|
||||
/// tracking how many queries were actually resolved (not coalesced).
|
||||
async fn spawn_slow_dns_server(
|
||||
delay: Duration,
|
||||
) -> (SocketAddr, Arc<std::sync::atomic::AtomicU32>) {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
let count = Arc::new(std::sync::atomic::AtomicU32::new(0));
|
||||
let count_clone = count.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let (mut stream, _) = match listener.accept().await {
|
||||
Ok(c) => c,
|
||||
Err(_) => break,
|
||||
};
|
||||
let count = count_clone.clone();
|
||||
let delay = delay;
|
||||
tokio::spawn(async move {
|
||||
let mut len_buf = [0u8; 2];
|
||||
if stream.read_exact(&mut len_buf).await.is_err() {
|
||||
return;
|
||||
}
|
||||
let len = u16::from_be_bytes(len_buf) as usize;
|
||||
let mut data = vec![0u8; len];
|
||||
if stream.read_exact(&mut data).await.is_err() {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut buf = BytePacketBuffer::from_bytes(&data);
|
||||
let query = match DnsPacket::from_buffer(&mut buf) {
|
||||
Ok(q) => q,
|
||||
Err(_) => return,
|
||||
};
|
||||
|
||||
count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
|
||||
// Deliberate delay to create coalescing window
|
||||
tokio::time::sleep(delay).await;
|
||||
|
||||
let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR);
|
||||
resp.header.authoritative_answer = true;
|
||||
if let Some(q) = query.questions.first() {
|
||||
resp.answers.push(DnsRecord::A {
|
||||
domain: q.name.clone(),
|
||||
addr: Ipv4Addr::new(10, 0, 0, 1),
|
||||
ttl: 300,
|
||||
});
|
||||
}
|
||||
|
||||
let mut resp_buf = BytePacketBuffer::new();
|
||||
if resp.write(&mut resp_buf).is_err() {
|
||||
return;
|
||||
}
|
||||
let resp_bytes = resp_buf.filled();
|
||||
let mut out = Vec::with_capacity(2 + resp_bytes.len());
|
||||
out.extend_from_slice(&(resp_bytes.len() as u16).to_be_bytes());
|
||||
out.extend_from_slice(resp_bytes);
|
||||
let _ = stream.write_all(&out).await;
|
||||
});
|
||||
}
|
||||
});
|
||||
(addr, count)
|
||||
}
|
||||
|
||||
async fn test_recursive_ctx(root_hint: SocketAddr) -> Arc<ServerCtx> {
|
||||
let socket = tokio::net::UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
Arc::new(ServerCtx {
|
||||
socket,
|
||||
zone_map: HashMap::new(),
|
||||
cache: RwLock::new(crate::cache::DnsCache::new(100, 60, 86400)),
|
||||
stats: Mutex::new(crate::stats::ServerStats::new()),
|
||||
overrides: RwLock::new(crate::override_store::OverrideStore::new()),
|
||||
blocklist: RwLock::new(crate::blocklist::BlocklistStore::new()),
|
||||
query_log: Mutex::new(crate::query_log::QueryLog::new(100)),
|
||||
services: Mutex::new(crate::service_store::ServiceStore::new()),
|
||||
lan_peers: Mutex::new(crate::lan::PeerStore::new(90)),
|
||||
forwarding_rules: Vec::new(),
|
||||
upstream: Mutex::new(crate::forward::Upstream::Udp(
|
||||
"127.0.0.1:53".parse().unwrap(),
|
||||
)),
|
||||
upstream_auto: false,
|
||||
upstream_port: 53,
|
||||
lan_ip: Mutex::new(Ipv4Addr::LOCALHOST),
|
||||
timeout: Duration::from_secs(3),
|
||||
proxy_tld: "numa".to_string(),
|
||||
proxy_tld_suffix: ".numa".to_string(),
|
||||
lan_enabled: false,
|
||||
config_path: "/tmp/test-numa.toml".to_string(),
|
||||
config_found: false,
|
||||
config_dir: std::path::PathBuf::from("/tmp"),
|
||||
data_dir: std::path::PathBuf::from("/tmp"),
|
||||
tls_config: None,
|
||||
upstream_mode: crate::config::UpstreamMode::Recursive,
|
||||
root_hints: vec![root_hint],
|
||||
srtt: RwLock::new(crate::srtt::SrttCache::new(true)),
|
||||
inflight: Mutex::new(HashMap::new()),
|
||||
dnssec_enabled: false,
|
||||
dnssec_strict: false,
|
||||
})
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn concurrent_queries_coalesce_to_single_resolution() {
|
||||
// Force TCP-only so mock server works
|
||||
crate::recursive::UDP_DISABLED.store(true, std::sync::atomic::Ordering::Release);
|
||||
|
||||
let (server_addr, query_count) = spawn_slow_dns_server(Duration::from_millis(200)).await;
|
||||
let ctx = test_recursive_ctx(server_addr).await;
|
||||
let src: SocketAddr = "127.0.0.1:9999".parse().unwrap();
|
||||
|
||||
// Fire 5 concurrent queries for the same (domain, A)
|
||||
let mut handles = Vec::new();
|
||||
for i in 0..5u16 {
|
||||
let ctx = ctx.clone();
|
||||
let buf = build_wire_query(100 + i, "coalesce-test.example.com", QueryType::A);
|
||||
handles.push(tokio::spawn(
|
||||
async move { handle_query(buf, src, &ctx).await },
|
||||
));
|
||||
}
|
||||
|
||||
for h in handles {
|
||||
h.await.unwrap().unwrap();
|
||||
}
|
||||
|
||||
// Only 1 resolution should have reached the upstream server
|
||||
let actual = query_count.load(std::sync::atomic::Ordering::Relaxed);
|
||||
assert_eq!(actual, 1, "expected 1 upstream query, got {}", actual);
|
||||
|
||||
// Inflight map must be empty after all queries complete
|
||||
assert!(ctx.inflight.lock().unwrap().is_empty());
|
||||
|
||||
crate::recursive::reset_udp_state();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn different_qtypes_not_coalesced() {
|
||||
crate::recursive::UDP_DISABLED.store(true, std::sync::atomic::Ordering::Release);
|
||||
|
||||
let (server_addr, query_count) = spawn_slow_dns_server(Duration::from_millis(100)).await;
|
||||
let ctx = test_recursive_ctx(server_addr).await;
|
||||
let src: SocketAddr = "127.0.0.1:9999".parse().unwrap();
|
||||
|
||||
// Fire A and AAAA concurrently — should NOT coalesce
|
||||
let ctx_ref = ctx.clone();
|
||||
let ctx_ref2 = ctx.clone();
|
||||
let buf_a = build_wire_query(200, "different-qt.example.com", QueryType::A);
|
||||
let buf_aaaa = build_wire_query(201, "different-qt.example.com", QueryType::AAAA);
|
||||
|
||||
let h1 = tokio::spawn(async move { handle_query(buf_a, src, &ctx_ref).await });
|
||||
let h2 = tokio::spawn(async move { handle_query(buf_aaaa, src, &ctx_ref2).await });
|
||||
|
||||
h1.await.unwrap().unwrap();
|
||||
h2.await.unwrap().unwrap();
|
||||
|
||||
let actual = query_count.load(std::sync::atomic::Ordering::Relaxed);
|
||||
assert!(
|
||||
actual >= 2,
|
||||
"A and AAAA should resolve independently, got {}",
|
||||
actual
|
||||
);
|
||||
assert!(ctx.inflight.lock().unwrap().is_empty());
|
||||
|
||||
crate::recursive::reset_udp_state();
|
||||
}
|
||||
|
||||
    /// Even when resolution fails outright, the leader's drop guard must
    /// leave the in-flight map empty (no leaked coalescing entries).
    #[tokio::test]
    async fn inflight_map_cleaned_after_upstream_error() {
        // Server that rejects everything — no server running at all
        // (port 1 on loopback, so the recursive resolve errors quickly).
        let bogus_addr: SocketAddr = "127.0.0.1:1".parse().unwrap();
        let ctx = test_recursive_ctx(bogus_addr).await;
        let src: SocketAddr = "127.0.0.1:9999".parse().unwrap();

        let buf = build_wire_query(300, "will-fail.example.com", QueryType::A);
        let _ = handle_query(buf, src, &ctx).await;

        // Map must be clean even after error
        assert!(ctx.inflight.lock().unwrap().is_empty());
    }
|
||||
}
|
||||
|
||||
@@ -202,6 +202,7 @@ async fn main() -> numa::Result<()> {
|
||||
upstream_mode: config.upstream.mode,
|
||||
root_hints: numa::recursive::parse_root_hints(&config.upstream.root_hints),
|
||||
srtt: std::sync::RwLock::new(numa::srtt::SrttCache::new(config.upstream.srtt)),
|
||||
inflight: std::sync::Mutex::new(std::collections::HashMap::new()),
|
||||
dnssec_enabled: config.dnssec.enabled,
|
||||
dnssec_strict: config.dnssec.strict,
|
||||
});
|
||||
|
||||
@@ -21,7 +21,8 @@ const UDP_FAIL_THRESHOLD: u8 = 3;
|
||||
|
||||
static QUERY_ID: AtomicU16 = AtomicU16::new(1);
|
||||
static UDP_FAILURES: std::sync::atomic::AtomicU8 = std::sync::atomic::AtomicU8::new(0);
|
||||
static UDP_DISABLED: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);
|
||||
pub(crate) static UDP_DISABLED: std::sync::atomic::AtomicBool =
|
||||
std::sync::atomic::AtomicBool::new(false);
|
||||
|
||||
fn next_id() -> u16 {
|
||||
QUERY_ID.fetch_add(1, Ordering::Relaxed)
|
||||
|
||||
100
src/srtt.rs
100
src/srtt.rs
@@ -108,6 +108,13 @@ impl SrttCache {
|
||||
self.entries.is_empty()
|
||||
}
|
||||
|
||||
    /// Test-only helper: backdate an entry's `updated_at` timestamp so the
    /// decay logic can be exercised without real waiting. No-op when `ip`
    /// has no entry.
    #[cfg(test)]
    fn set_updated_at(&mut self, ip: IpAddr, at: Instant) {
        if let Some(entry) = self.entries.get_mut(&ip) {
            entry.updated_at = at;
        }
    }
|
||||
|
||||
fn maybe_evict(&mut self) {
|
||||
if self.entries.len() < MAX_ENTRIES {
|
||||
return;
|
||||
@@ -203,6 +210,99 @@ mod tests {
|
||||
assert_eq!(addrs, original);
|
||||
}
|
||||
|
||||
fn age(secs: u64) -> Instant {
|
||||
Instant::now() - std::time::Duration::from_secs(secs)
|
||||
}
|
||||
|
||||
    /// Cache with ip(1) saturated at FAILURE_PENALTY_MS
    /// (30 identical samples drive the EWMA essentially to the penalty value).
    fn saturated_penalty_cache() -> SrttCache {
        let mut cache = SrttCache::new(true);
        for _ in 0..30 {
            cache.record_rtt(ip(1), FAILURE_PENALTY_MS, false);
        }
        cache
    }
|
||||
|
||||
#[test]
|
||||
fn no_decay_within_threshold() {
|
||||
let mut cache = SrttCache::new(true);
|
||||
cache.record_rtt(ip(1), 5000, false);
|
||||
cache.set_updated_at(ip(1), age(DECAY_AFTER_SECS));
|
||||
assert_eq!(cache.get(ip(1)), cache.entries[&ip(1)].srtt_ms);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn one_decay_period() {
|
||||
let mut cache = saturated_penalty_cache();
|
||||
let raw = cache.entries[&ip(1)].srtt_ms;
|
||||
cache.set_updated_at(ip(1), age(DECAY_AFTER_SECS + 1));
|
||||
let expected = (raw + INITIAL_SRTT_MS) / 2;
|
||||
assert_eq!(cache.get(ip(1)), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn multiple_decay_periods() {
|
||||
let mut cache = saturated_penalty_cache();
|
||||
let raw = cache.entries[&ip(1)].srtt_ms;
|
||||
cache.set_updated_at(ip(1), age(DECAY_AFTER_SECS * 4 + 1));
|
||||
let mut expected = raw;
|
||||
for _ in 0..4 {
|
||||
expected = (expected + INITIAL_SRTT_MS) / 2;
|
||||
}
|
||||
assert_eq!(cache.get(ip(1)), expected);
|
||||
}
|
||||
|
||||
    /// Decay halving is capped: ages of 9 periods and 100 periods must
    /// decay to the identical value.
    #[test]
    fn decay_caps_at_8_periods() {
        // 9 periods and 100 periods should produce the same result (capped at 8)
        let mut cache_a = saturated_penalty_cache();
        let mut cache_b = saturated_penalty_cache();
        cache_a.set_updated_at(ip(1), age(DECAY_AFTER_SECS * 9 + 1));
        cache_b.set_updated_at(ip(1), age(DECAY_AFTER_SECS * 100));
        assert_eq!(cache_a.get(ip(1)), cache_b.get(ip(1)));
    }
|
||||
|
||||
#[test]
|
||||
fn decay_converges_toward_initial() {
|
||||
let mut cache = saturated_penalty_cache();
|
||||
cache.set_updated_at(ip(1), age(DECAY_AFTER_SECS * 100));
|
||||
let decayed = cache.get(ip(1));
|
||||
let diff = decayed.abs_diff(INITIAL_SRTT_MS);
|
||||
assert!(
|
||||
diff < 25,
|
||||
"expected near INITIAL_SRTT_MS, got {} (diff={})",
|
||||
decayed,
|
||||
diff
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn record_rtt_applies_decay_before_ewma() {
|
||||
let mut cache = saturated_penalty_cache();
|
||||
cache.set_updated_at(ip(1), age(DECAY_AFTER_SECS * 8));
|
||||
cache.record_rtt(ip(1), 50, false);
|
||||
let srtt = cache.get(ip(1));
|
||||
// Without decay-before-EWMA, result would be ~(5000*7+50)/8 ≈ 4381
|
||||
assert!(srtt < 500, "expected decay before EWMA, got srtt={}", srtt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decay_reranks_stale_failures() {
|
||||
let mut cache = saturated_penalty_cache();
|
||||
for _ in 0..30 {
|
||||
cache.record_rtt(ip(2), 300, false);
|
||||
}
|
||||
let mut addrs = vec![sock(1), sock(2)];
|
||||
cache.sort_by_rtt(&mut addrs);
|
||||
assert_eq!(addrs, vec![sock(2), sock(1)]);
|
||||
|
||||
// Age server 1 so it decays toward INITIAL (200ms) — below server 2's 300ms
|
||||
cache.set_updated_at(ip(1), age(DECAY_AFTER_SECS * 100));
|
||||
let mut addrs = vec![sock(1), sock(2)];
|
||||
cache.sort_by_rtt(&mut addrs);
|
||||
assert_eq!(addrs, vec![sock(1), sock(2)]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn eviction_removes_oldest() {
|
||||
let mut cache = SrttCache::new(true);
|
||||
|
||||
12
src/stats.rs
12
src/stats.rs
@@ -4,6 +4,7 @@ pub struct ServerStats {
|
||||
queries_total: u64,
|
||||
queries_forwarded: u64,
|
||||
queries_recursive: u64,
|
||||
queries_coalesced: u64,
|
||||
queries_cached: u64,
|
||||
queries_blocked: u64,
|
||||
queries_local: u64,
|
||||
@@ -18,6 +19,7 @@ pub enum QueryPath {
|
||||
Cached,
|
||||
Forwarded,
|
||||
Recursive,
|
||||
Coalesced,
|
||||
Blocked,
|
||||
Overridden,
|
||||
UpstreamError,
|
||||
@@ -30,6 +32,7 @@ impl QueryPath {
|
||||
QueryPath::Cached => "CACHED",
|
||||
QueryPath::Forwarded => "FORWARD",
|
||||
QueryPath::Recursive => "RECURSIVE",
|
||||
QueryPath::Coalesced => "COALESCED",
|
||||
QueryPath::Blocked => "BLOCKED",
|
||||
QueryPath::Overridden => "OVERRIDE",
|
||||
QueryPath::UpstreamError => "SERVFAIL",
|
||||
@@ -45,6 +48,8 @@ impl QueryPath {
|
||||
Some(QueryPath::Forwarded)
|
||||
} else if s.eq_ignore_ascii_case("RECURSIVE") {
|
||||
Some(QueryPath::Recursive)
|
||||
} else if s.eq_ignore_ascii_case("COALESCED") {
|
||||
Some(QueryPath::Coalesced)
|
||||
} else if s.eq_ignore_ascii_case("BLOCKED") {
|
||||
Some(QueryPath::Blocked)
|
||||
} else if s.eq_ignore_ascii_case("OVERRIDE") {
|
||||
@@ -69,6 +74,7 @@ impl ServerStats {
|
||||
queries_total: 0,
|
||||
queries_forwarded: 0,
|
||||
queries_recursive: 0,
|
||||
queries_coalesced: 0,
|
||||
queries_cached: 0,
|
||||
queries_blocked: 0,
|
||||
queries_local: 0,
|
||||
@@ -85,6 +91,7 @@ impl ServerStats {
|
||||
QueryPath::Cached => self.queries_cached += 1,
|
||||
QueryPath::Forwarded => self.queries_forwarded += 1,
|
||||
QueryPath::Recursive => self.queries_recursive += 1,
|
||||
QueryPath::Coalesced => self.queries_coalesced += 1,
|
||||
QueryPath::Blocked => self.queries_blocked += 1,
|
||||
QueryPath::Overridden => self.queries_overridden += 1,
|
||||
QueryPath::UpstreamError => self.upstream_errors += 1,
|
||||
@@ -106,6 +113,7 @@ impl ServerStats {
|
||||
total: self.queries_total,
|
||||
forwarded: self.queries_forwarded,
|
||||
recursive: self.queries_recursive,
|
||||
coalesced: self.queries_coalesced,
|
||||
cached: self.queries_cached,
|
||||
local: self.queries_local,
|
||||
overridden: self.queries_overridden,
|
||||
@@ -121,11 +129,12 @@ impl ServerStats {
|
||||
let secs = uptime.as_secs() % 60;
|
||||
|
||||
log::info!(
|
||||
"STATS | uptime {}h{}m{}s | total {} | fwd {} | recursive {} | cached {} | local {} | override {} | blocked {} | errors {}",
|
||||
"STATS | uptime {}h{}m{}s | total {} | fwd {} | recursive {} | coalesced {} | cached {} | local {} | override {} | blocked {} | errors {}",
|
||||
hours, mins, secs,
|
||||
self.queries_total,
|
||||
self.queries_forwarded,
|
||||
self.queries_recursive,
|
||||
self.queries_coalesced,
|
||||
self.queries_cached,
|
||||
self.queries_local,
|
||||
self.queries_overridden,
|
||||
@@ -140,6 +149,7 @@ pub struct StatsSnapshot {
|
||||
pub total: u64,
|
||||
pub forwarded: u64,
|
||||
pub recursive: u64,
|
||||
pub coalesced: u64,
|
||||
pub cached: u64,
|
||||
pub local: u64,
|
||||
pub overridden: u64,
|
||||
|
||||
Reference in New Issue
Block a user