feat: SRTT-based nameserver selection (#19)
* feat: SRTT-based nameserver selection for recursive resolver BIND-style Smoothed RTT (EWMA) tracking per NS IP address. The resolver learns which nameservers respond fastest and prefers them, eliminating cascading timeouts from slow/unreachable IPv6 servers. - New src/srtt.rs: SrttCache with record_rtt, record_failure, sort_by_rtt - EWMA formula: new = (old * 7 + sample) / 8, 5s failure penalty, 5min decay - TCP penalty (+100ms) lets SRTT naturally deprioritize IPv6-over-TCP - Enabled flag embedded in SrttCache (no-op when disabled) - Batch eviction (64 entries) for O(1) amortized writes at capacity - Configurable via [upstream] srtt = true/false (default: true) - Benchmark script: scripts/benchmark.sh (full, cold, warm, compare-all) - Benchmarks show 12x avg improvement, 0% queries >1s (was 58%) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * feat: show DNSSEC and SRTT status in dashboard + API Add dnssec and srtt boolean fields to /stats API response. Display on/off indicators in the dashboard footer. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix: apply SRTT decay before EWMA so recovered servers rehabilitate Without decay-before-EWMA, a server penalized at 5000ms stayed near that value even after recovery — the stale raw penalty was used as the EWMA base instead of the decayed estimate. Extract decayed_srtt() helper and call it in record_rtt() before the smoothing step. Also restores removed "why" comments in send_query / resolve_recursive. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * docs: add install/upgrade instructions, smarter benchmark priming README: document `numa install`, `numa service`, Homebrew upgrade, and `make deploy` workflows. Benchmark: replace fixed `sleep 4` with `wait_for_priming` that polls cache entry count for stability. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -162,6 +162,8 @@ struct StatsResponse {
|
||||
upstream: String,
|
||||
config_path: String,
|
||||
data_dir: String,
|
||||
dnssec: bool,
|
||||
srtt: bool,
|
||||
queries: QueriesStats,
|
||||
cache: CacheStats,
|
||||
overrides: OverrideStats,
|
||||
@@ -491,6 +493,8 @@ async fn stats(State(ctx): State<Arc<ServerCtx>>) -> Json<StatsResponse> {
|
||||
upstream,
|
||||
config_path: ctx.config_path.clone(),
|
||||
data_dir: ctx.data_dir.to_string_lossy().to_string(),
|
||||
dnssec: ctx.dnssec_enabled,
|
||||
srtt: ctx.srtt.read().unwrap().is_enabled(),
|
||||
queries: QueriesStats {
|
||||
total: snap.total,
|
||||
forwarded: snap.forwarded,
|
||||
@@ -948,6 +952,7 @@ mod tests {
|
||||
tls_config: None,
|
||||
upstream_mode: crate::config::UpstreamMode::Forward,
|
||||
root_hints: Vec::new(),
|
||||
srtt: RwLock::new(crate::srtt::SrttCache::new(true)),
|
||||
dnssec_enabled: false,
|
||||
dnssec_strict: false,
|
||||
})
|
||||
|
||||
@@ -85,6 +85,8 @@ pub struct UpstreamConfig {
|
||||
pub root_hints: Vec<String>,
|
||||
#[serde(default = "default_prime_tlds")]
|
||||
pub prime_tlds: Vec<String>,
|
||||
#[serde(default = "default_srtt")]
|
||||
pub srtt: bool,
|
||||
}
|
||||
|
||||
impl Default for UpstreamConfig {
|
||||
@@ -96,10 +98,15 @@ impl Default for UpstreamConfig {
|
||||
timeout_ms: default_timeout_ms(),
|
||||
root_hints: default_root_hints(),
|
||||
prime_tlds: default_prime_tlds(),
|
||||
srtt: default_srtt(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn default_srtt() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn default_prime_tlds() -> Vec<String> {
|
||||
vec![
|
||||
// gTLDs
|
||||
|
||||
@@ -21,6 +21,7 @@ use crate::query_log::{QueryLog, QueryLogEntry};
|
||||
use crate::question::QueryType;
|
||||
use crate::record::DnsRecord;
|
||||
use crate::service_store::ServiceStore;
|
||||
use crate::srtt::SrttCache;
|
||||
use crate::stats::{QueryPath, ServerStats};
|
||||
use crate::system_dns::ForwardingRule;
|
||||
|
||||
@@ -51,6 +52,7 @@ pub struct ServerCtx {
|
||||
pub tls_config: Option<ArcSwap<ServerConfig>>,
|
||||
pub upstream_mode: UpstreamMode,
|
||||
pub root_hints: Vec<SocketAddr>,
|
||||
pub srtt: RwLock<SrttCache>,
|
||||
pub dnssec_enabled: bool,
|
||||
pub dnssec_strict: bool,
|
||||
}
|
||||
@@ -176,6 +178,7 @@ pub async fn handle_query(
|
||||
&ctx.cache,
|
||||
&query,
|
||||
&ctx.root_hints,
|
||||
&ctx.srtt,
|
||||
)
|
||||
.await
|
||||
{
|
||||
@@ -226,7 +229,8 @@ pub async fn handle_query(
|
||||
let mut dnssec = dnssec;
|
||||
if ctx.dnssec_enabled && path == QueryPath::Recursive {
|
||||
let (status, vstats) =
|
||||
crate::dnssec::validate_response(&response, &ctx.cache, &ctx.root_hints).await;
|
||||
crate::dnssec::validate_response(&response, &ctx.cache, &ctx.root_hints, &ctx.srtt)
|
||||
.await;
|
||||
|
||||
debug!(
|
||||
"DNSSEC | {} | {:?} | {}ms | dnskey_hit={} dnskey_fetch={} ds_hit={} ds_fetch={}",
|
||||
|
||||
@@ -9,6 +9,7 @@ use crate::cache::{DnsCache, DnssecStatus};
|
||||
use crate::packet::DnsPacket;
|
||||
use crate::question::QueryType;
|
||||
use crate::record::DnsRecord;
|
||||
use crate::srtt::SrttCache;
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct ValidationStats {
|
||||
@@ -64,6 +65,7 @@ pub async fn validate_response(
|
||||
response: &DnsPacket,
|
||||
cache: &RwLock<DnsCache>,
|
||||
root_hints: &[std::net::SocketAddr],
|
||||
srtt: &RwLock<SrttCache>,
|
||||
) -> (DnssecStatus, ValidationStats) {
|
||||
let start = Instant::now();
|
||||
let stats = Mutex::new(ValidationStats::default());
|
||||
@@ -95,7 +97,7 @@ pub async fn validate_response(
|
||||
}
|
||||
}
|
||||
for zone in &signer_zones {
|
||||
fetch_dnskeys(zone, cache, root_hints, &stats).await;
|
||||
fetch_dnskeys(zone, cache, root_hints, srtt, &stats).await;
|
||||
}
|
||||
|
||||
// Group answer records into RRsets (by domain + type, excluding RRSIGs)
|
||||
@@ -132,7 +134,8 @@ pub async fn validate_response(
|
||||
..
|
||||
} = rrsig
|
||||
{
|
||||
let dnskey_response = fetch_dnskeys(signer_name, cache, root_hints, &stats).await;
|
||||
let dnskey_response =
|
||||
fetch_dnskeys(signer_name, cache, root_hints, srtt, &stats).await;
|
||||
let dnskeys: Vec<&DnsRecord> = dnskey_response
|
||||
.iter()
|
||||
.filter(|r| matches!(r, DnsRecord::DNSKEY { .. }))
|
||||
@@ -206,6 +209,7 @@ pub async fn validate_response(
|
||||
&dnskey_response,
|
||||
cache,
|
||||
root_hints,
|
||||
srtt,
|
||||
trust_anchors,
|
||||
0,
|
||||
&stats,
|
||||
@@ -276,11 +280,13 @@ pub async fn validate_response(
|
||||
|
||||
/// Walk the chain of trust from zone DNSKEY up to root trust anchor.
|
||||
/// `zone_records` contains both DNSKEY and RRSIG records from the DNSKEY response.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn validate_chain<'a>(
|
||||
zone: &'a str,
|
||||
zone_records: &'a [DnsRecord],
|
||||
cache: &'a RwLock<DnsCache>,
|
||||
root_hints: &'a [std::net::SocketAddr],
|
||||
srtt: &'a RwLock<SrttCache>,
|
||||
trust_anchors: &'a [DnsRecord],
|
||||
depth: u8,
|
||||
stats: &'a Mutex<ValidationStats>,
|
||||
@@ -343,7 +349,7 @@ fn validate_chain<'a>(
|
||||
return DnssecStatus::Indeterminate;
|
||||
}
|
||||
let parent = parent_zone(zone);
|
||||
let ds_records = fetch_ds(zone, cache, root_hints, stats).await;
|
||||
let ds_records = fetch_ds(zone, cache, root_hints, srtt, stats).await;
|
||||
|
||||
if ds_records.is_empty() {
|
||||
debug!("dnssec: no DS for zone '{}' at parent '{}'", zone, parent);
|
||||
@@ -377,7 +383,7 @@ fn validate_chain<'a>(
|
||||
|
||||
// Walk up: validate the parent's DNSKEY
|
||||
trace!("dnssec: fetching parent DNSKEY for '{}'", parent);
|
||||
let parent_records = fetch_dnskeys(&parent, cache, root_hints, stats).await;
|
||||
let parent_records = fetch_dnskeys(&parent, cache, root_hints, srtt, stats).await;
|
||||
if parent_records.is_empty() {
|
||||
debug!("dnssec: no parent DNSKEY for '{}' — Indeterminate", parent);
|
||||
return DnssecStatus::Indeterminate;
|
||||
@@ -388,6 +394,7 @@ fn validate_chain<'a>(
|
||||
&parent_records,
|
||||
cache,
|
||||
root_hints,
|
||||
srtt,
|
||||
trust_anchors,
|
||||
depth + 1,
|
||||
stats,
|
||||
@@ -460,6 +467,7 @@ async fn fetch_dnskeys(
|
||||
zone: &str,
|
||||
cache: &RwLock<DnsCache>,
|
||||
root_hints: &[std::net::SocketAddr],
|
||||
srtt: &RwLock<SrttCache>,
|
||||
stats: &Mutex<ValidationStats>,
|
||||
) -> Vec<DnsRecord> {
|
||||
if let Some(pkt) = cache.read().unwrap().lookup(zone, QueryType::DNSKEY) {
|
||||
@@ -475,7 +483,8 @@ async fn fetch_dnskeys(
|
||||
trace!("dnssec: fetch_dnskeys('{}') cache miss — resolving", zone);
|
||||
stats.lock().unwrap().dnskey_fetches += 1;
|
||||
if let Ok(pkt) =
|
||||
crate::recursive::resolve_iterative(zone, QueryType::DNSKEY, cache, root_hints, 0, 0).await
|
||||
crate::recursive::resolve_iterative(zone, QueryType::DNSKEY, cache, root_hints, srtt, 0, 0)
|
||||
.await
|
||||
{
|
||||
cache.write().unwrap().insert(zone, QueryType::DNSKEY, &pkt);
|
||||
return pkt.answers;
|
||||
@@ -488,6 +497,7 @@ async fn fetch_ds(
|
||||
child: &str,
|
||||
cache: &RwLock<DnsCache>,
|
||||
root_hints: &[std::net::SocketAddr],
|
||||
srtt: &RwLock<SrttCache>,
|
||||
stats: &Mutex<ValidationStats>,
|
||||
) -> Vec<DnsRecord> {
|
||||
if let Some(pkt) = cache.read().unwrap().lookup(child, QueryType::DS) {
|
||||
@@ -501,7 +511,8 @@ async fn fetch_ds(
|
||||
|
||||
stats.lock().unwrap().ds_fetches += 1;
|
||||
if let Ok(pkt) =
|
||||
crate::recursive::resolve_iterative(child, QueryType::DS, cache, root_hints, 0, 0).await
|
||||
crate::recursive::resolve_iterative(child, QueryType::DS, cache, root_hints, srtt, 0, 0)
|
||||
.await
|
||||
{
|
||||
cache.write().unwrap().insert(child, QueryType::DS, &pkt);
|
||||
return pkt
|
||||
|
||||
@@ -16,6 +16,7 @@ pub mod question;
|
||||
pub mod record;
|
||||
pub mod recursive;
|
||||
pub mod service_store;
|
||||
pub mod srtt;
|
||||
pub mod stats;
|
||||
pub mod system_dns;
|
||||
pub mod tls;
|
||||
|
||||
10
src/main.rs
10
src/main.rs
@@ -201,6 +201,7 @@ async fn main() -> numa::Result<()> {
|
||||
tls_config: initial_tls,
|
||||
upstream_mode: config.upstream.mode,
|
||||
root_hints: numa::recursive::parse_root_hints(&config.upstream.root_hints),
|
||||
srtt: std::sync::RwLock::new(numa::srtt::SrttCache::new(config.upstream.srtt)),
|
||||
dnssec_enabled: config.dnssec.enabled,
|
||||
dnssec_strict: config.dnssec.strict,
|
||||
});
|
||||
@@ -353,8 +354,13 @@ async fn main() -> numa::Result<()> {
|
||||
let prime_ctx = Arc::clone(&ctx);
|
||||
let prime_tlds = config.upstream.prime_tlds;
|
||||
tokio::spawn(async move {
|
||||
numa::recursive::prime_tld_cache(&prime_ctx.cache, &prime_ctx.root_hints, &prime_tlds)
|
||||
.await;
|
||||
numa::recursive::prime_tld_cache(
|
||||
&prime_ctx.cache,
|
||||
&prime_ctx.root_hints,
|
||||
&prime_tlds,
|
||||
&prime_ctx.srtt,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
102
src/recursive.rs
102
src/recursive.rs
@@ -1,7 +1,7 @@
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::sync::atomic::{AtomicU16, Ordering};
|
||||
use std::sync::RwLock;
|
||||
use std::time::Duration;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use log::{debug, info};
|
||||
|
||||
@@ -11,6 +11,7 @@ use crate::header::ResultCode;
|
||||
use crate::packet::DnsPacket;
|
||||
use crate::question::{DnsQuestion, QueryType};
|
||||
use crate::record::DnsRecord;
|
||||
use crate::srtt::SrttCache;
|
||||
|
||||
const MAX_REFERRAL_DEPTH: u8 = 10;
|
||||
const MAX_CNAME_DEPTH: u8 = 8;
|
||||
@@ -58,7 +59,12 @@ pub async fn probe_udp(root_hints: &[SocketAddr]) {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn prime_tld_cache(cache: &RwLock<DnsCache>, root_hints: &[SocketAddr], tlds: &[String]) {
|
||||
pub async fn prime_tld_cache(
|
||||
cache: &RwLock<DnsCache>,
|
||||
root_hints: &[SocketAddr],
|
||||
tlds: &[String],
|
||||
srtt: &RwLock<SrttCache>,
|
||||
) {
|
||||
if root_hints.is_empty() || tlds.is_empty() {
|
||||
return;
|
||||
}
|
||||
@@ -66,7 +72,7 @@ pub async fn prime_tld_cache(cache: &RwLock<DnsCache>, root_hints: &[SocketAddr]
|
||||
let mut root_addr = root_hints[0];
|
||||
for hint in root_hints {
|
||||
info!("prime: probing root {}", hint);
|
||||
match send_query(".", QueryType::NS, *hint).await {
|
||||
match send_query(".", QueryType::NS, *hint, srtt).await {
|
||||
Ok(_) => {
|
||||
info!("prime: root {} reachable", hint);
|
||||
root_addr = *hint;
|
||||
@@ -79,7 +85,7 @@ pub async fn prime_tld_cache(cache: &RwLock<DnsCache>, root_hints: &[SocketAddr]
|
||||
}
|
||||
|
||||
// Fetch root DNSKEY (needed for DNSSEC chain-of-trust terminus)
|
||||
if let Ok(root_dnskey) = send_query(".", QueryType::DNSKEY, root_addr).await {
|
||||
if let Ok(root_dnskey) = send_query(".", QueryType::DNSKEY, root_addr, srtt).await {
|
||||
cache
|
||||
.write()
|
||||
.unwrap()
|
||||
@@ -91,7 +97,7 @@ pub async fn prime_tld_cache(cache: &RwLock<DnsCache>, root_hints: &[SocketAddr]
|
||||
|
||||
for tld in tlds {
|
||||
// Fetch NS referral (includes DS in authority section from root)
|
||||
let response = match send_query(tld, QueryType::NS, root_addr).await {
|
||||
let response = match send_query(tld, QueryType::NS, root_addr, srtt).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
debug!("prime: failed to query NS for .{}: {}", tld, e);
|
||||
@@ -108,7 +114,6 @@ pub async fn prime_tld_cache(cache: &RwLock<DnsCache>, root_hints: &[SocketAddr]
|
||||
let mut cache_w = cache.write().unwrap();
|
||||
cache_w.insert(tld, QueryType::NS, &response);
|
||||
cache_glue(&mut cache_w, &response, &ns_names);
|
||||
// Cache DS records from referral authority section
|
||||
cache_ds_from_authority(&mut cache_w, &response);
|
||||
}
|
||||
|
||||
@@ -116,7 +121,7 @@ pub async fn prime_tld_cache(cache: &RwLock<DnsCache>, root_hints: &[SocketAddr]
|
||||
let first_ns_name = ns_names.first().map(|s| s.as_str()).unwrap_or("");
|
||||
let first_ns = glue_addrs_for(&response, first_ns_name);
|
||||
if let Some(ns_addr) = first_ns.first() {
|
||||
if let Ok(dnskey_resp) = send_query(tld, QueryType::DNSKEY, *ns_addr).await {
|
||||
if let Ok(dnskey_resp) = send_query(tld, QueryType::DNSKEY, *ns_addr, srtt).await {
|
||||
cache
|
||||
.write()
|
||||
.unwrap()
|
||||
@@ -140,10 +145,11 @@ pub async fn resolve_recursive(
|
||||
cache: &RwLock<DnsCache>,
|
||||
original_query: &DnsPacket,
|
||||
root_hints: &[SocketAddr],
|
||||
srtt: &RwLock<SrttCache>,
|
||||
) -> crate::Result<DnsPacket> {
|
||||
// No overall timeout — each hop is bounded by NS_QUERY_TIMEOUT (UDP + TCP fallback),
|
||||
// and MAX_REFERRAL_DEPTH caps the chain length.
|
||||
let mut resp = resolve_iterative(qname, qtype, cache, root_hints, 0, 0).await?;
|
||||
let mut resp = resolve_iterative(qname, qtype, cache, root_hints, srtt, 0, 0).await?;
|
||||
|
||||
resp.header.id = original_query.header.id;
|
||||
resp.header.recursion_available = true;
|
||||
@@ -157,6 +163,7 @@ pub(crate) fn resolve_iterative<'a>(
|
||||
qtype: QueryType,
|
||||
cache: &'a RwLock<DnsCache>,
|
||||
root_hints: &'a [SocketAddr],
|
||||
srtt: &'a RwLock<SrttCache>,
|
||||
referral_depth: u8,
|
||||
cname_depth: u8,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = crate::Result<DnsPacket>> + Send + 'a>> {
|
||||
@@ -170,6 +177,7 @@ pub(crate) fn resolve_iterative<'a>(
|
||||
}
|
||||
|
||||
let (mut current_zone, mut ns_addrs) = find_closest_ns(qname, cache, root_hints);
|
||||
srtt.read().unwrap().sort_by_rtt(&mut ns_addrs);
|
||||
let mut ns_idx = 0;
|
||||
|
||||
for _ in 0..MAX_REFERRAL_DEPTH {
|
||||
@@ -185,7 +193,7 @@ pub(crate) fn resolve_iterative<'a>(
|
||||
ns_addr, q_type, q_name, current_zone, referral_depth
|
||||
);
|
||||
|
||||
let response = match send_query(q_name, q_type, ns_addr).await {
|
||||
let response = match send_query(q_name, q_type, ns_addr, srtt).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
debug!("recursive: NS {} failed: {}", ns_addr, e);
|
||||
@@ -194,7 +202,6 @@ pub(crate) fn resolve_iterative<'a>(
|
||||
}
|
||||
};
|
||||
|
||||
// Minimized query response — treat as referral, not final answer
|
||||
if (q_type != qtype || !q_name.eq_ignore_ascii_case(qname))
|
||||
&& (!response.authorities.is_empty() || !response.answers.is_empty())
|
||||
{
|
||||
@@ -205,8 +212,9 @@ pub(crate) fn resolve_iterative<'a>(
|
||||
if all_ns.is_empty() {
|
||||
all_ns = extract_ns_names(&response);
|
||||
}
|
||||
let new_addrs = resolve_ns_addrs_from_glue(&response, &all_ns, cache);
|
||||
let mut new_addrs = resolve_ns_addrs_from_glue(&response, &all_ns, cache);
|
||||
if !new_addrs.is_empty() {
|
||||
srtt.read().unwrap().sort_by_rtt(&mut new_addrs);
|
||||
ns_addrs = new_addrs;
|
||||
ns_idx = 0;
|
||||
continue;
|
||||
@@ -233,6 +241,7 @@ pub(crate) fn resolve_iterative<'a>(
|
||||
qtype,
|
||||
cache,
|
||||
root_hints,
|
||||
srtt,
|
||||
0,
|
||||
cname_depth + 1,
|
||||
)
|
||||
@@ -256,8 +265,6 @@ pub(crate) fn resolve_iterative<'a>(
|
||||
return Ok(response);
|
||||
}
|
||||
|
||||
// Referral — extract NS + glue, cache glue, resolve NS addresses
|
||||
// Update zone for query minimization
|
||||
if let Some(zone) = referral_zone(&response) {
|
||||
current_zone = zone;
|
||||
}
|
||||
@@ -276,13 +283,13 @@ pub(crate) fn resolve_iterative<'a>(
|
||||
for ns_name in &ns_names {
|
||||
if referral_depth < MAX_REFERRAL_DEPTH {
|
||||
debug!("recursive: resolving glue-less NS {}", ns_name);
|
||||
// Try A first, then AAAA
|
||||
for qt in [QueryType::A, QueryType::AAAA] {
|
||||
if let Ok(ns_resp) = resolve_iterative(
|
||||
ns_name,
|
||||
qt,
|
||||
cache,
|
||||
root_hints,
|
||||
srtt,
|
||||
referral_depth + 1,
|
||||
cname_depth,
|
||||
)
|
||||
@@ -316,6 +323,7 @@ pub(crate) fn resolve_iterative<'a>(
|
||||
return Err(format!("could not resolve any NS for {}", qname).into());
|
||||
}
|
||||
|
||||
srtt.read().unwrap().sort_by_rtt(&mut new_ns_addrs);
|
||||
ns_addrs = new_ns_addrs;
|
||||
ns_idx = 0;
|
||||
}
|
||||
@@ -561,7 +569,32 @@ fn make_glue_packet() -> DnsPacket {
|
||||
pkt
|
||||
}
|
||||
|
||||
async fn send_query(qname: &str, qtype: QueryType, server: SocketAddr) -> crate::Result<DnsPacket> {
|
||||
async fn tcp_with_srtt(
|
||||
query: &DnsPacket,
|
||||
server: SocketAddr,
|
||||
srtt: &RwLock<SrttCache>,
|
||||
start: Instant,
|
||||
) -> crate::Result<DnsPacket> {
|
||||
match crate::forward::forward_tcp(query, server, TCP_TIMEOUT).await {
|
||||
Ok(resp) => {
|
||||
srtt.write()
|
||||
.unwrap()
|
||||
.record_rtt(server.ip(), start.elapsed().as_millis() as u64, true);
|
||||
Ok(resp)
|
||||
}
|
||||
Err(e) => {
|
||||
srtt.write().unwrap().record_failure(server.ip());
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_query(
|
||||
qname: &str,
|
||||
qtype: QueryType,
|
||||
server: SocketAddr,
|
||||
srtt: &RwLock<SrttCache>,
|
||||
) -> crate::Result<DnsPacket> {
|
||||
let mut query = DnsPacket::new();
|
||||
query.header.id = next_id();
|
||||
query.header.recursion_desired = false;
|
||||
@@ -573,24 +606,30 @@ async fn send_query(qname: &str, qtype: QueryType, server: SocketAddr) -> crate:
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
// Skip IPv6 if the socket can't handle it (bound to 0.0.0.0)
|
||||
let start = Instant::now();
|
||||
|
||||
// IPv6 forced to TCP — our UDP socket is bound to 0.0.0.0
|
||||
if server.is_ipv6() {
|
||||
return crate::forward::forward_tcp(&query, server, TCP_TIMEOUT).await;
|
||||
return tcp_with_srtt(&query, server, srtt, start).await;
|
||||
}
|
||||
|
||||
// If UDP has been detected as blocked, go TCP-first
|
||||
// UDP detected as blocked — go TCP-first
|
||||
if UDP_DISABLED.load(Ordering::Acquire) {
|
||||
return crate::forward::forward_tcp(&query, server, TCP_TIMEOUT).await;
|
||||
return tcp_with_srtt(&query, server, srtt, start).await;
|
||||
}
|
||||
|
||||
match forward_udp(&query, server, NS_QUERY_TIMEOUT).await {
|
||||
Ok(resp) if resp.header.truncated_message => {
|
||||
debug!("send_query: truncated from {}, retrying TCP", server);
|
||||
crate::forward::forward_tcp(&query, server, TCP_TIMEOUT).await
|
||||
tcp_with_srtt(&query, server, srtt, start).await
|
||||
}
|
||||
Ok(resp) => {
|
||||
// UDP works — reset failure counter
|
||||
UDP_FAILURES.store(0, Ordering::Release);
|
||||
srtt.write().unwrap().record_rtt(
|
||||
server.ip(),
|
||||
start.elapsed().as_millis() as u64,
|
||||
false,
|
||||
);
|
||||
Ok(resp)
|
||||
}
|
||||
Err(e) => {
|
||||
@@ -603,7 +642,7 @@ async fn send_query(qname: &str, qtype: QueryType, server: SocketAddr) -> crate:
|
||||
);
|
||||
}
|
||||
debug!("send_query: UDP failed for {}: {}, trying TCP", server, e);
|
||||
crate::forward::forward_tcp(&query, server, TCP_TIMEOUT).await
|
||||
tcp_with_srtt(&query, server, srtt, start).await
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -894,7 +933,8 @@ mod tests {
|
||||
})
|
||||
.await;
|
||||
|
||||
let result = send_query("test.example.com", QueryType::A, server_addr).await;
|
||||
let srtt = RwLock::new(SrttCache::new(true));
|
||||
let result = send_query("test.example.com", QueryType::A, server_addr, &srtt).await;
|
||||
|
||||
let resp = result.expect("should resolve via TCP fallback");
|
||||
assert_eq!(resp.header.rescode, ResultCode::NOERROR);
|
||||
@@ -945,7 +985,8 @@ mod tests {
|
||||
})
|
||||
.await;
|
||||
|
||||
let result = send_query("hello.example.com", QueryType::A, server_addr).await;
|
||||
let srtt = RwLock::new(SrttCache::new(true));
|
||||
let result = send_query("hello.example.com", QueryType::A, server_addr, &srtt).await;
|
||||
let resp = result.expect("TCP-only send_query should work");
|
||||
assert_eq!(resp.header.rescode, ResultCode::NOERROR);
|
||||
match &resp.answers[0] {
|
||||
@@ -967,10 +1008,19 @@ mod tests {
|
||||
.await;
|
||||
|
||||
let cache = RwLock::new(DnsCache::new(100, 60, 86400));
|
||||
let srtt = RwLock::new(SrttCache::new(true));
|
||||
let root_hints = vec![server_addr];
|
||||
|
||||
let result =
|
||||
resolve_iterative("nonexistent.test", QueryType::A, &cache, &root_hints, 0, 0).await;
|
||||
let result = resolve_iterative(
|
||||
"nonexistent.test",
|
||||
QueryType::A,
|
||||
&cache,
|
||||
&root_hints,
|
||||
&srtt,
|
||||
0,
|
||||
0,
|
||||
)
|
||||
.await;
|
||||
|
||||
let resp = result.expect("NXDOMAIN should still return a response");
|
||||
assert_eq!(resp.header.rescode, ResultCode::NXDOMAIN);
|
||||
|
||||
227
src/srtt.rs
Normal file
227
src/srtt.rs
Normal file
@@ -0,0 +1,227 @@
|
||||
use std::collections::HashMap;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::time::Instant;
|
||||
|
||||
const INITIAL_SRTT_MS: u64 = 200;
|
||||
const FAILURE_PENALTY_MS: u64 = 5000;
|
||||
const TCP_PENALTY_MS: u64 = 100;
|
||||
const DECAY_AFTER_SECS: u64 = 300;
|
||||
const MAX_ENTRIES: usize = 4096;
|
||||
const EVICT_BATCH: usize = 64;
|
||||
|
||||
struct SrttEntry {
|
||||
srtt_ms: u64,
|
||||
updated_at: Instant,
|
||||
}
|
||||
|
||||
pub struct SrttCache {
|
||||
entries: HashMap<IpAddr, SrttEntry>,
|
||||
enabled: bool,
|
||||
}
|
||||
|
||||
impl Default for SrttCache {
|
||||
fn default() -> Self {
|
||||
Self::new(true)
|
||||
}
|
||||
}
|
||||
|
||||
impl SrttCache {
|
||||
pub fn new(enabled: bool) -> Self {
|
||||
Self {
|
||||
entries: HashMap::new(),
|
||||
enabled,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_enabled(&self) -> bool {
|
||||
self.enabled
|
||||
}
|
||||
|
||||
/// Get current SRTT for an IP, applying decay if stale. Returns INITIAL for unknown.
|
||||
pub fn get(&self, ip: IpAddr) -> u64 {
|
||||
match self.entries.get(&ip) {
|
||||
Some(entry) => Self::decayed_srtt(entry),
|
||||
None => INITIAL_SRTT_MS,
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply time-based decay: each DECAY_AFTER_SECS period halves distance to INITIAL.
|
||||
fn decayed_srtt(entry: &SrttEntry) -> u64 {
|
||||
let age_secs = entry.updated_at.elapsed().as_secs();
|
||||
if age_secs > DECAY_AFTER_SECS {
|
||||
let periods = (age_secs / DECAY_AFTER_SECS).min(8);
|
||||
let mut srtt = entry.srtt_ms;
|
||||
for _ in 0..periods {
|
||||
srtt = (srtt + INITIAL_SRTT_MS) / 2;
|
||||
}
|
||||
srtt
|
||||
} else {
|
||||
entry.srtt_ms
|
||||
}
|
||||
}
|
||||
|
||||
/// Record a successful query RTT. No-op when disabled.
|
||||
pub fn record_rtt(&mut self, ip: IpAddr, rtt_ms: u64, tcp: bool) {
|
||||
if !self.enabled {
|
||||
return;
|
||||
}
|
||||
let effective = if tcp { rtt_ms + TCP_PENALTY_MS } else { rtt_ms };
|
||||
self.maybe_evict();
|
||||
let entry = self.entries.entry(ip).or_insert(SrttEntry {
|
||||
srtt_ms: effective,
|
||||
updated_at: Instant::now(),
|
||||
});
|
||||
// Apply decay before EWMA so recovered servers aren't stuck at stale penalties
|
||||
let base = Self::decayed_srtt(entry);
|
||||
// BIND EWMA: new = (old * 7 + sample) / 8
|
||||
entry.srtt_ms = (base * 7 + effective) / 8;
|
||||
entry.updated_at = Instant::now();
|
||||
}
|
||||
|
||||
/// Record a failure (timeout or error). No-op when disabled.
|
||||
pub fn record_failure(&mut self, ip: IpAddr) {
|
||||
if !self.enabled {
|
||||
return;
|
||||
}
|
||||
self.maybe_evict();
|
||||
let entry = self.entries.entry(ip).or_insert(SrttEntry {
|
||||
srtt_ms: FAILURE_PENALTY_MS,
|
||||
updated_at: Instant::now(),
|
||||
});
|
||||
entry.srtt_ms = FAILURE_PENALTY_MS;
|
||||
entry.updated_at = Instant::now();
|
||||
}
|
||||
|
||||
/// Sort addresses by SRTT ascending (lowest/fastest first). No-op when disabled.
|
||||
pub fn sort_by_rtt(&self, addrs: &mut [SocketAddr]) {
|
||||
if !self.enabled {
|
||||
return;
|
||||
}
|
||||
addrs.sort_by_key(|a| self.get(a.ip()));
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.entries.len()
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.entries.is_empty()
|
||||
}
|
||||
|
||||
fn maybe_evict(&mut self) {
|
||||
if self.entries.len() < MAX_ENTRIES {
|
||||
return;
|
||||
}
|
||||
// Batch eviction: remove the oldest EVICT_BATCH entries at once
|
||||
let mut by_age: Vec<IpAddr> = self.entries.keys().copied().collect();
|
||||
by_age.sort_by_key(|ip| self.entries[ip].updated_at);
|
||||
for ip in by_age.into_iter().take(EVICT_BATCH) {
|
||||
self.entries.remove(&ip);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::net::Ipv4Addr;
|
||||
|
||||
fn ip(last: u8) -> IpAddr {
|
||||
IpAddr::V4(Ipv4Addr::new(192, 0, 2, last))
|
||||
}
|
||||
|
||||
fn sock(last: u8) -> SocketAddr {
|
||||
SocketAddr::new(ip(last), 53)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unknown_returns_initial() {
|
||||
let cache = SrttCache::new(true);
|
||||
assert_eq!(cache.get(ip(1)), INITIAL_SRTT_MS);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ewma_converges() {
|
||||
let mut cache = SrttCache::new(true);
|
||||
for _ in 0..20 {
|
||||
cache.record_rtt(ip(1), 100, false);
|
||||
}
|
||||
let srtt = cache.get(ip(1));
|
||||
assert!(srtt >= 98 && srtt <= 102, "srtt={}", srtt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn failure_sets_penalty() {
|
||||
let mut cache = SrttCache::new(true);
|
||||
cache.record_rtt(ip(1), 50, false);
|
||||
cache.record_failure(ip(1));
|
||||
assert_eq!(cache.get(ip(1)), FAILURE_PENALTY_MS);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tcp_penalty_added() {
|
||||
let mut cache = SrttCache::new(true);
|
||||
for _ in 0..20 {
|
||||
cache.record_rtt(ip(1), 50, true);
|
||||
}
|
||||
let srtt = cache.get(ip(1));
|
||||
assert!(srtt >= 148 && srtt <= 152, "srtt={}", srtt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sort_by_rtt_orders_correctly() {
|
||||
let mut cache = SrttCache::new(true);
|
||||
for _ in 0..20 {
|
||||
cache.record_rtt(ip(1), 500, false);
|
||||
cache.record_rtt(ip(2), 100, false);
|
||||
cache.record_rtt(ip(3), 10, false);
|
||||
}
|
||||
let mut addrs = vec![sock(1), sock(2), sock(3)];
|
||||
cache.sort_by_rtt(&mut addrs);
|
||||
assert_eq!(addrs, vec![sock(3), sock(2), sock(1)]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unknown_servers_sort_equal() {
|
||||
let cache = SrttCache::new(true);
|
||||
let mut addrs = vec![sock(1), sock(2), sock(3)];
|
||||
let original = addrs.clone();
|
||||
cache.sort_by_rtt(&mut addrs);
|
||||
assert_eq!(addrs, original);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn disabled_is_noop() {
|
||||
let mut cache = SrttCache::new(false);
|
||||
cache.record_rtt(ip(1), 50, false);
|
||||
cache.record_failure(ip(2));
|
||||
assert_eq!(cache.len(), 0);
|
||||
|
||||
let mut addrs = vec![sock(2), sock(1)];
|
||||
let original = addrs.clone();
|
||||
cache.sort_by_rtt(&mut addrs);
|
||||
assert_eq!(addrs, original);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn eviction_removes_oldest() {
|
||||
let mut cache = SrttCache::new(true);
|
||||
for i in 0..MAX_ENTRIES {
|
||||
let octets = [
|
||||
10,
|
||||
((i >> 16) & 0xFF) as u8,
|
||||
((i >> 8) & 0xFF) as u8,
|
||||
(i & 0xFF) as u8,
|
||||
];
|
||||
cache.record_rtt(
|
||||
IpAddr::V4(Ipv4Addr::new(octets[0], octets[1], octets[2], octets[3])),
|
||||
100,
|
||||
false,
|
||||
);
|
||||
}
|
||||
assert_eq!(cache.len(), MAX_ENTRIES);
|
||||
cache.record_rtt(ip(1), 100, false);
|
||||
// Batch eviction removes EVICT_BATCH entries
|
||||
assert!(cache.len() <= MAX_ENTRIES - EVICT_BATCH + 1);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user