feat: SRTT-based nameserver selection (#19)

* feat: SRTT-based nameserver selection for recursive resolver

BIND-style Smoothed RTT (EWMA) tracking per NS IP address. The resolver
learns which nameservers respond fastest and prefers them, eliminating
cascading timeouts from slow/unreachable IPv6 servers.

- New src/srtt.rs: SrttCache with record_rtt, record_failure, sort_by_rtt
- EWMA formula: new = (old * 7 + sample) / 8, 5s failure penalty, 5min decay
- TCP penalty (+100ms) lets SRTT naturally deprioritize IPv6-over-TCP
- Enabled flag embedded in SrttCache (no-op when disabled)
- Batch eviction (64 entries) for O(1) amortized writes at capacity
- Configurable via [upstream] srtt = true/false (default: true)
- Benchmark script: scripts/benchmark.sh (full, cold, warm, compare-all)
- Benchmarks show 12x avg improvement, 0% queries >1s (was 58%)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* feat: show DNSSEC and SRTT status in dashboard + API

Add dnssec and srtt boolean fields to /stats API response.
Display on/off indicators in the dashboard footer.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: apply SRTT decay before EWMA so recovered servers rehabilitate

Without decay-before-EWMA, a server penalized at 5000ms stayed near
that value even after recovery — the stale raw penalty was used as the
EWMA base instead of the decayed estimate. Extract decayed_srtt()
helper and call it in record_rtt() before the smoothing step.

Also restores removed "why" comments in send_query / resolve_recursive.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* docs: add install/upgrade instructions, smarter benchmark priming

README: document `numa install`, `numa service`, Homebrew upgrade,
and `make deploy` workflows. Benchmark: replace fixed `sleep 4` with
`wait_for_priming` that polls cache entry count for stability.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Razvan Dimescu
2026-03-28 23:22:31 +02:00
committed by GitHub
parent 71dbb138bc
commit 06d4e91cd2
11 changed files with 683 additions and 37 deletions

View File

@@ -30,11 +30,34 @@ dig @127.0.0.1 ads.google.com # ✗ blocked → 0.0.0.0
Open the dashboard: **http://numa.numa** (or `http://localhost:5380`)
Or build from source:
### Set as system resolver
```bash
# Point your system DNS to Numa (saves originals for uninstall)
sudo numa install
# Run as a persistent service (auto-starts on boot, restarts if killed)
sudo numa service start
```
To uninstall: `sudo numa service stop` removes the service, `sudo numa uninstall` restores your original DNS.
### Upgrade
```bash
# From Homebrew
brew upgrade numa
# From source
make deploy # builds release, copies binary, re-signs, restarts service
```
### Build from source
```bash
git clone https://github.com/razvandimescu/numa.git && cd numa
cargo build --release
sudo ./target/release/numa
sudo cp target/release/numa /usr/local/bin/numa
```
## Why Numa

306
scripts/benchmark.sh Executable file
View File

@@ -0,0 +1,306 @@
#!/usr/bin/env bash
set -euo pipefail
API="${NUMA_API:-http://127.0.0.1:5380}"
DNS="${NUMA_DNS:-127.0.0.1}"
NUMA_BIN="${NUMA_BIN:-/usr/local/bin/numa}"
LAUNCHD_PLIST="/Library/LaunchDaemons/com.numa.dns.plist"
DOMAINS=(
paypal.com ebay.com zoom.us slack.com discord.com
microsoft.com apple.com meta.com oracle.com ibm.com
docker.com kubernetes.io prometheus.io grafana.com terraform.io
python.org nodejs.org golang.org wikipedia.org reddit.com
stackoverflow.com stripe.com linear.app nytimes.com bbc.co.uk
rust-lang.org fastly.com hetzner.com uber.com airbnb.com
notion.so figma.com netflix.com spotify.com dropbox.com
gitlab.com twitch.tv shopify.com vercel.app mozilla.org
)
stats() {
curl -s "$API/query-log" | python3 -c "
import sys, json
data = json.load(sys.stdin)
rec = [q for q in data if q['path'] == 'RECURSIVE']
if not rec:
print('No recursive queries in log.')
sys.exit()
vals = sorted([q['latency_ms'] for q in rec])
n = len(vals)
print(f'Recursive queries: {n}')
print(f' Avg: {sum(vals)/n:.1f}ms')
print(f' Median: {vals[n//2]:.1f}ms')
print(f' P95: {vals[int(n*0.95)]:.1f}ms')
print(f' P99: {vals[int(n*0.99)]:.1f}ms')
print(f' Min: {min(vals):.1f}ms')
print(f' Max: {max(vals):.1f}ms')
print(f' <100ms: {sum(1 for v in vals if v < 100)}')
print(f' <200ms: {sum(1 for v in vals if v < 200)}')
print(f' <500ms: {sum(1 for v in vals if v < 500)}')
print(f' >1s: {sum(1 for v in vals if v >= 1000)}')
print()
print('Slowest 5:')
for q in sorted(rec, key=lambda q: q['latency_ms'], reverse=True)[:5]:
print(f' {q[\"latency_ms\"]:>8.1f}ms {q[\"query_type\"]:5s} {q[\"domain\"]:35s} {q[\"rescode\"]}')
print()
print('Fastest 5:')
for q in sorted(rec, key=lambda q: q['latency_ms'])[:5]:
print(f' {q[\"latency_ms\"]:>8.1f}ms {q[\"query_type\"]:5s} {q[\"domain\"]:35s} {q[\"rescode\"]}')
"
}
query_all() {
local label="$1"
echo "=== $label ==="
for d in "${DOMAINS[@]}"; do
printf " %-25s " "$d"
dig "@$DNS" "$d" A +noall +stats 2>/dev/null | grep "Query time"
done
echo
}
flush_cache() {
curl -s -X DELETE "$API/cache" > /dev/null
echo "Cache flushed ($(curl -s "$API/stats" | python3 -c "import sys,json; print(json.load(sys.stdin)['cache']['entries'])" 2>/dev/null || echo '?') entries)."
}
wait_for_api() {
local attempts=0
while ! curl -sf "$API/health" > /dev/null 2>&1; do
attempts=$((attempts + 1))
if [ $attempts -ge 20 ]; then
echo "ERROR: API not reachable at $API after 10s" >&2
exit 1
fi
sleep 0.5
done
}
wait_for_priming() {
echo -n "Waiting for TLD priming..."
local prev=0
local stable=0
for _ in $(seq 1 60); do
local entries
entries=$(curl -s "$API/stats" | python3 -c "import sys,json; print(json.load(sys.stdin)['cache']['entries'])" 2>/dev/null || echo 0)
if [ "$entries" -gt 0 ] && [ "$entries" = "$prev" ]; then
stable=$((stable + 1))
if [ $stable -ge 3 ]; then
echo " done ($entries cache entries)."
return
fi
else
stable=0
fi
prev="$entries"
sleep 1
done
echo " timeout (cache: $prev entries)."
}
# restart_numa <config_toml_body>
# Writes config to a temp file, stops numa (launchd or manual), starts with that config.
restart_numa() {
local config_body="$1"
local tmpconf
tmpconf=$(mktemp /tmp/numa-bench-XXXXXX)
mv "$tmpconf" "${tmpconf}.toml"
tmpconf="${tmpconf}.toml"
echo "$config_body" > "$tmpconf"
# Stop launchd-managed numa if active
if sudo launchctl list com.numa.dns &>/dev/null; then
sudo launchctl unload "$LAUNCHD_PLIST" 2>/dev/null || true
sleep 1
fi
# Kill any remaining
sudo killall numa 2>/dev/null || true
sleep 2
sudo "$NUMA_BIN" "$tmpconf" &
wait_for_api
wait_for_priming
echo "numa ready (pid $(pgrep numa | head -1), config: $tmpconf)."
}
# Restore the launchd service
restore_launchd() {
sudo killall numa 2>/dev/null || true
sleep 1
if [ -f "$LAUNCHD_PLIST" ]; then
sudo launchctl load "$LAUNCHD_PLIST" 2>/dev/null || true
echo "Restored launchd service."
fi
}
run_pass() {
local label="$1"
flush_cache
sleep 0.5
query_all "$label"
echo "=== $label — stats ==="
stats
}
case "${1:-full}" in
cold)
echo "--- Cold cache benchmark ---"
run_pass "Cold SRTT + Cold cache"
;;
warm)
echo "--- Warm SRTT benchmark ---"
echo "Priming SRTT..."
for d in "${DOMAINS[@]}"; do dig "@$DNS" "$d" A +short > /dev/null 2>&1; done
run_pass "Warm SRTT + Cold cache"
;;
stats)
stats
;;
compare-srtt)
echo "============================================"
echo " A/B: SRTT OFF vs ON (dnssec off)"
echo "============================================"
echo
restart_numa "$(cat <<'TOML'
[upstream]
mode = "recursive"
srtt = false
TOML
)"
echo
run_pass "SRTT OFF"
echo
echo "--------------------------------------------"
echo
restart_numa "$(cat <<'TOML'
[upstream]
mode = "recursive"
srtt = true
TOML
)"
echo
run_pass "SRTT ON"
echo
restore_launchd
;;
compare-dnssec)
echo "============================================"
echo " A/B: DNSSEC OFF vs ON (srtt on)"
echo "============================================"
echo
restart_numa "$(cat <<'TOML'
[upstream]
mode = "recursive"
srtt = true
[dnssec]
enabled = false
TOML
)"
echo
run_pass "DNSSEC OFF"
echo
echo "--------------------------------------------"
echo
restart_numa "$(cat <<'TOML'
[upstream]
mode = "recursive"
srtt = true
[dnssec]
enabled = true
TOML
)"
echo
run_pass "DNSSEC ON"
echo
restore_launchd
;;
compare-all)
echo "============================================"
echo " Full A/B matrix"
echo " 1. SRTT OFF + DNSSEC OFF (baseline)"
echo " 2. SRTT ON + DNSSEC OFF"
echo " 3. SRTT ON + DNSSEC ON"
echo "============================================"
echo
# --- 1. Baseline ---
restart_numa "$(cat <<'TOML'
[upstream]
mode = "recursive"
srtt = false
[dnssec]
enabled = false
TOML
)"
echo
run_pass "SRTT OFF + DNSSEC OFF"
echo
echo "--------------------------------------------"
echo
# --- 2. SRTT only ---
restart_numa "$(cat <<'TOML'
[upstream]
mode = "recursive"
srtt = true
[dnssec]
enabled = false
TOML
)"
echo
run_pass "SRTT ON + DNSSEC OFF"
echo
echo "--------------------------------------------"
echo
# --- 3. Both ---
restart_numa "$(cat <<'TOML'
[upstream]
mode = "recursive"
srtt = true
[dnssec]
enabled = true
TOML
)"
echo
run_pass "SRTT ON + DNSSEC ON"
echo
restore_launchd
;;
full|*)
echo "--- Full benchmark (cold → warm → SRTT-only) ---"
echo
wait_for_priming
flush_cache
sleep 0.5
query_all "Pass 1: Cold SRTT + Cold cache"
flush_cache
sleep 0.5
query_all "Pass 2: Warm SRTT + Cold cache"
echo "=== Pass 2 stats (SRTT-warm) ==="
stats
;;
esac

View File

@@ -879,6 +879,10 @@ async function refresh() {
document.getElementById('footerUpstream').textContent = stats.upstream || '';
document.getElementById('footerConfig').textContent = stats.config_path || '';
document.getElementById('footerData').textContent = stats.data_dir || '';
document.getElementById('footerDnssec').textContent = stats.dnssec ? 'on' : 'off';
document.getElementById('footerDnssec').style.color = stats.dnssec ? 'var(--emerald)' : 'var(--text-dim)';
document.getElementById('footerSrtt').textContent = stats.srtt ? 'on' : 'off';
document.getElementById('footerSrtt').style.color = stats.srtt ? 'var(--emerald)' : 'var(--text-dim)';
// LAN status indicator
const lanEl = document.getElementById('lanToggle');
@@ -1229,6 +1233,8 @@ setInterval(refresh, 2000);
Config: <span id="footerConfig" style="user-select:all;color:var(--emerald);"></span>
· Data: <span id="footerData" style="user-select:all;color:var(--emerald);"></span>
· Upstream: <span id="footerUpstream" style="user-select:all;color:var(--emerald);"></span>
· DNSSEC: <span id="footerDnssec" style="color:var(--text-dim);"></span>
· SRTT: <span id="footerSrtt" style="color:var(--text-dim);"></span>
· Logs: <span style="user-select:all;color:var(--emerald);">macOS: /usr/local/var/log/numa.log · Linux: journalctl -u numa -f</span>
· <a href="https://github.com/razvandimescu/numa" target="_blank" rel="noopener" style="color:var(--amber);text-decoration:none;">GitHub</a>
</div>

View File

@@ -162,6 +162,8 @@ struct StatsResponse {
upstream: String,
config_path: String,
data_dir: String,
dnssec: bool,
srtt: bool,
queries: QueriesStats,
cache: CacheStats,
overrides: OverrideStats,
@@ -491,6 +493,8 @@ async fn stats(State(ctx): State<Arc<ServerCtx>>) -> Json<StatsResponse> {
upstream,
config_path: ctx.config_path.clone(),
data_dir: ctx.data_dir.to_string_lossy().to_string(),
dnssec: ctx.dnssec_enabled,
srtt: ctx.srtt.read().unwrap().is_enabled(),
queries: QueriesStats {
total: snap.total,
forwarded: snap.forwarded,
@@ -948,6 +952,7 @@ mod tests {
tls_config: None,
upstream_mode: crate::config::UpstreamMode::Forward,
root_hints: Vec::new(),
srtt: RwLock::new(crate::srtt::SrttCache::new(true)),
dnssec_enabled: false,
dnssec_strict: false,
})

View File

@@ -85,6 +85,8 @@ pub struct UpstreamConfig {
pub root_hints: Vec<String>,
#[serde(default = "default_prime_tlds")]
pub prime_tlds: Vec<String>,
#[serde(default = "default_srtt")]
pub srtt: bool,
}
impl Default for UpstreamConfig {
@@ -96,10 +98,15 @@ impl Default for UpstreamConfig {
timeout_ms: default_timeout_ms(),
root_hints: default_root_hints(),
prime_tlds: default_prime_tlds(),
srtt: default_srtt(),
}
}
}
fn default_srtt() -> bool {
true
}
fn default_prime_tlds() -> Vec<String> {
vec![
// gTLDs

View File

@@ -21,6 +21,7 @@ use crate::query_log::{QueryLog, QueryLogEntry};
use crate::question::QueryType;
use crate::record::DnsRecord;
use crate::service_store::ServiceStore;
use crate::srtt::SrttCache;
use crate::stats::{QueryPath, ServerStats};
use crate::system_dns::ForwardingRule;
@@ -51,6 +52,7 @@ pub struct ServerCtx {
pub tls_config: Option<ArcSwap<ServerConfig>>,
pub upstream_mode: UpstreamMode,
pub root_hints: Vec<SocketAddr>,
pub srtt: RwLock<SrttCache>,
pub dnssec_enabled: bool,
pub dnssec_strict: bool,
}
@@ -176,6 +178,7 @@ pub async fn handle_query(
&ctx.cache,
&query,
&ctx.root_hints,
&ctx.srtt,
)
.await
{
@@ -226,7 +229,8 @@ pub async fn handle_query(
let mut dnssec = dnssec;
if ctx.dnssec_enabled && path == QueryPath::Recursive {
let (status, vstats) =
crate::dnssec::validate_response(&response, &ctx.cache, &ctx.root_hints).await;
crate::dnssec::validate_response(&response, &ctx.cache, &ctx.root_hints, &ctx.srtt)
.await;
debug!(
"DNSSEC | {} | {:?} | {}ms | dnskey_hit={} dnskey_fetch={} ds_hit={} ds_fetch={}",

View File

@@ -9,6 +9,7 @@ use crate::cache::{DnsCache, DnssecStatus};
use crate::packet::DnsPacket;
use crate::question::QueryType;
use crate::record::DnsRecord;
use crate::srtt::SrttCache;
#[derive(Debug, Default)]
pub struct ValidationStats {
@@ -64,6 +65,7 @@ pub async fn validate_response(
response: &DnsPacket,
cache: &RwLock<DnsCache>,
root_hints: &[std::net::SocketAddr],
srtt: &RwLock<SrttCache>,
) -> (DnssecStatus, ValidationStats) {
let start = Instant::now();
let stats = Mutex::new(ValidationStats::default());
@@ -95,7 +97,7 @@ pub async fn validate_response(
}
}
for zone in &signer_zones {
fetch_dnskeys(zone, cache, root_hints, &stats).await;
fetch_dnskeys(zone, cache, root_hints, srtt, &stats).await;
}
// Group answer records into RRsets (by domain + type, excluding RRSIGs)
@@ -132,7 +134,8 @@ pub async fn validate_response(
..
} = rrsig
{
let dnskey_response = fetch_dnskeys(signer_name, cache, root_hints, &stats).await;
let dnskey_response =
fetch_dnskeys(signer_name, cache, root_hints, srtt, &stats).await;
let dnskeys: Vec<&DnsRecord> = dnskey_response
.iter()
.filter(|r| matches!(r, DnsRecord::DNSKEY { .. }))
@@ -206,6 +209,7 @@ pub async fn validate_response(
&dnskey_response,
cache,
root_hints,
srtt,
trust_anchors,
0,
&stats,
@@ -276,11 +280,13 @@ pub async fn validate_response(
/// Walk the chain of trust from zone DNSKEY up to root trust anchor.
/// `zone_records` contains both DNSKEY and RRSIG records from the DNSKEY response.
#[allow(clippy::too_many_arguments)]
fn validate_chain<'a>(
zone: &'a str,
zone_records: &'a [DnsRecord],
cache: &'a RwLock<DnsCache>,
root_hints: &'a [std::net::SocketAddr],
srtt: &'a RwLock<SrttCache>,
trust_anchors: &'a [DnsRecord],
depth: u8,
stats: &'a Mutex<ValidationStats>,
@@ -343,7 +349,7 @@ fn validate_chain<'a>(
return DnssecStatus::Indeterminate;
}
let parent = parent_zone(zone);
let ds_records = fetch_ds(zone, cache, root_hints, stats).await;
let ds_records = fetch_ds(zone, cache, root_hints, srtt, stats).await;
if ds_records.is_empty() {
debug!("dnssec: no DS for zone '{}' at parent '{}'", zone, parent);
@@ -377,7 +383,7 @@ fn validate_chain<'a>(
// Walk up: validate the parent's DNSKEY
trace!("dnssec: fetching parent DNSKEY for '{}'", parent);
let parent_records = fetch_dnskeys(&parent, cache, root_hints, stats).await;
let parent_records = fetch_dnskeys(&parent, cache, root_hints, srtt, stats).await;
if parent_records.is_empty() {
debug!("dnssec: no parent DNSKEY for '{}' — Indeterminate", parent);
return DnssecStatus::Indeterminate;
@@ -388,6 +394,7 @@ fn validate_chain<'a>(
&parent_records,
cache,
root_hints,
srtt,
trust_anchors,
depth + 1,
stats,
@@ -460,6 +467,7 @@ async fn fetch_dnskeys(
zone: &str,
cache: &RwLock<DnsCache>,
root_hints: &[std::net::SocketAddr],
srtt: &RwLock<SrttCache>,
stats: &Mutex<ValidationStats>,
) -> Vec<DnsRecord> {
if let Some(pkt) = cache.read().unwrap().lookup(zone, QueryType::DNSKEY) {
@@ -475,7 +483,8 @@ async fn fetch_dnskeys(
trace!("dnssec: fetch_dnskeys('{}') cache miss — resolving", zone);
stats.lock().unwrap().dnskey_fetches += 1;
if let Ok(pkt) =
crate::recursive::resolve_iterative(zone, QueryType::DNSKEY, cache, root_hints, 0, 0).await
crate::recursive::resolve_iterative(zone, QueryType::DNSKEY, cache, root_hints, srtt, 0, 0)
.await
{
cache.write().unwrap().insert(zone, QueryType::DNSKEY, &pkt);
return pkt.answers;
@@ -488,6 +497,7 @@ async fn fetch_ds(
child: &str,
cache: &RwLock<DnsCache>,
root_hints: &[std::net::SocketAddr],
srtt: &RwLock<SrttCache>,
stats: &Mutex<ValidationStats>,
) -> Vec<DnsRecord> {
if let Some(pkt) = cache.read().unwrap().lookup(child, QueryType::DS) {
@@ -501,7 +511,8 @@ async fn fetch_ds(
stats.lock().unwrap().ds_fetches += 1;
if let Ok(pkt) =
crate::recursive::resolve_iterative(child, QueryType::DS, cache, root_hints, 0, 0).await
crate::recursive::resolve_iterative(child, QueryType::DS, cache, root_hints, srtt, 0, 0)
.await
{
cache.write().unwrap().insert(child, QueryType::DS, &pkt);
return pkt

View File

@@ -16,6 +16,7 @@ pub mod question;
pub mod record;
pub mod recursive;
pub mod service_store;
pub mod srtt;
pub mod stats;
pub mod system_dns;
pub mod tls;

View File

@@ -201,6 +201,7 @@ async fn main() -> numa::Result<()> {
tls_config: initial_tls,
upstream_mode: config.upstream.mode,
root_hints: numa::recursive::parse_root_hints(&config.upstream.root_hints),
srtt: std::sync::RwLock::new(numa::srtt::SrttCache::new(config.upstream.srtt)),
dnssec_enabled: config.dnssec.enabled,
dnssec_strict: config.dnssec.strict,
});
@@ -353,8 +354,13 @@ async fn main() -> numa::Result<()> {
let prime_ctx = Arc::clone(&ctx);
let prime_tlds = config.upstream.prime_tlds;
tokio::spawn(async move {
numa::recursive::prime_tld_cache(&prime_ctx.cache, &prime_ctx.root_hints, &prime_tlds)
.await;
numa::recursive::prime_tld_cache(
&prime_ctx.cache,
&prime_ctx.root_hints,
&prime_tlds,
&prime_ctx.srtt,
)
.await;
});
}

View File

@@ -1,7 +1,7 @@
use std::net::{IpAddr, SocketAddr};
use std::sync::atomic::{AtomicU16, Ordering};
use std::sync::RwLock;
use std::time::Duration;
use std::time::{Duration, Instant};
use log::{debug, info};
@@ -11,6 +11,7 @@ use crate::header::ResultCode;
use crate::packet::DnsPacket;
use crate::question::{DnsQuestion, QueryType};
use crate::record::DnsRecord;
use crate::srtt::SrttCache;
const MAX_REFERRAL_DEPTH: u8 = 10;
const MAX_CNAME_DEPTH: u8 = 8;
@@ -58,7 +59,12 @@ pub async fn probe_udp(root_hints: &[SocketAddr]) {
}
}
pub async fn prime_tld_cache(cache: &RwLock<DnsCache>, root_hints: &[SocketAddr], tlds: &[String]) {
pub async fn prime_tld_cache(
cache: &RwLock<DnsCache>,
root_hints: &[SocketAddr],
tlds: &[String],
srtt: &RwLock<SrttCache>,
) {
if root_hints.is_empty() || tlds.is_empty() {
return;
}
@@ -66,7 +72,7 @@ pub async fn prime_tld_cache(cache: &RwLock<DnsCache>, root_hints: &[SocketAddr]
let mut root_addr = root_hints[0];
for hint in root_hints {
info!("prime: probing root {}", hint);
match send_query(".", QueryType::NS, *hint).await {
match send_query(".", QueryType::NS, *hint, srtt).await {
Ok(_) => {
info!("prime: root {} reachable", hint);
root_addr = *hint;
@@ -79,7 +85,7 @@ pub async fn prime_tld_cache(cache: &RwLock<DnsCache>, root_hints: &[SocketAddr]
}
// Fetch root DNSKEY (needed for DNSSEC chain-of-trust terminus)
if let Ok(root_dnskey) = send_query(".", QueryType::DNSKEY, root_addr).await {
if let Ok(root_dnskey) = send_query(".", QueryType::DNSKEY, root_addr, srtt).await {
cache
.write()
.unwrap()
@@ -91,7 +97,7 @@ pub async fn prime_tld_cache(cache: &RwLock<DnsCache>, root_hints: &[SocketAddr]
for tld in tlds {
// Fetch NS referral (includes DS in authority section from root)
let response = match send_query(tld, QueryType::NS, root_addr).await {
let response = match send_query(tld, QueryType::NS, root_addr, srtt).await {
Ok(r) => r,
Err(e) => {
debug!("prime: failed to query NS for .{}: {}", tld, e);
@@ -108,7 +114,6 @@ pub async fn prime_tld_cache(cache: &RwLock<DnsCache>, root_hints: &[SocketAddr]
let mut cache_w = cache.write().unwrap();
cache_w.insert(tld, QueryType::NS, &response);
cache_glue(&mut cache_w, &response, &ns_names);
// Cache DS records from referral authority section
cache_ds_from_authority(&mut cache_w, &response);
}
@@ -116,7 +121,7 @@ pub async fn prime_tld_cache(cache: &RwLock<DnsCache>, root_hints: &[SocketAddr]
let first_ns_name = ns_names.first().map(|s| s.as_str()).unwrap_or("");
let first_ns = glue_addrs_for(&response, first_ns_name);
if let Some(ns_addr) = first_ns.first() {
if let Ok(dnskey_resp) = send_query(tld, QueryType::DNSKEY, *ns_addr).await {
if let Ok(dnskey_resp) = send_query(tld, QueryType::DNSKEY, *ns_addr, srtt).await {
cache
.write()
.unwrap()
@@ -140,10 +145,11 @@ pub async fn resolve_recursive(
cache: &RwLock<DnsCache>,
original_query: &DnsPacket,
root_hints: &[SocketAddr],
srtt: &RwLock<SrttCache>,
) -> crate::Result<DnsPacket> {
// No overall timeout — each hop is bounded by NS_QUERY_TIMEOUT (UDP + TCP fallback),
// and MAX_REFERRAL_DEPTH caps the chain length.
let mut resp = resolve_iterative(qname, qtype, cache, root_hints, 0, 0).await?;
let mut resp = resolve_iterative(qname, qtype, cache, root_hints, srtt, 0, 0).await?;
resp.header.id = original_query.header.id;
resp.header.recursion_available = true;
@@ -157,6 +163,7 @@ pub(crate) fn resolve_iterative<'a>(
qtype: QueryType,
cache: &'a RwLock<DnsCache>,
root_hints: &'a [SocketAddr],
srtt: &'a RwLock<SrttCache>,
referral_depth: u8,
cname_depth: u8,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = crate::Result<DnsPacket>> + Send + 'a>> {
@@ -170,6 +177,7 @@ pub(crate) fn resolve_iterative<'a>(
}
let (mut current_zone, mut ns_addrs) = find_closest_ns(qname, cache, root_hints);
srtt.read().unwrap().sort_by_rtt(&mut ns_addrs);
let mut ns_idx = 0;
for _ in 0..MAX_REFERRAL_DEPTH {
@@ -185,7 +193,7 @@ pub(crate) fn resolve_iterative<'a>(
ns_addr, q_type, q_name, current_zone, referral_depth
);
let response = match send_query(q_name, q_type, ns_addr).await {
let response = match send_query(q_name, q_type, ns_addr, srtt).await {
Ok(r) => r,
Err(e) => {
debug!("recursive: NS {} failed: {}", ns_addr, e);
@@ -194,7 +202,6 @@ pub(crate) fn resolve_iterative<'a>(
}
};
// Minimized query response — treat as referral, not final answer
if (q_type != qtype || !q_name.eq_ignore_ascii_case(qname))
&& (!response.authorities.is_empty() || !response.answers.is_empty())
{
@@ -205,8 +212,9 @@ pub(crate) fn resolve_iterative<'a>(
if all_ns.is_empty() {
all_ns = extract_ns_names(&response);
}
let new_addrs = resolve_ns_addrs_from_glue(&response, &all_ns, cache);
let mut new_addrs = resolve_ns_addrs_from_glue(&response, &all_ns, cache);
if !new_addrs.is_empty() {
srtt.read().unwrap().sort_by_rtt(&mut new_addrs);
ns_addrs = new_addrs;
ns_idx = 0;
continue;
@@ -233,6 +241,7 @@ pub(crate) fn resolve_iterative<'a>(
qtype,
cache,
root_hints,
srtt,
0,
cname_depth + 1,
)
@@ -256,8 +265,6 @@ pub(crate) fn resolve_iterative<'a>(
return Ok(response);
}
// Referral — extract NS + glue, cache glue, resolve NS addresses
// Update zone for query minimization
if let Some(zone) = referral_zone(&response) {
current_zone = zone;
}
@@ -276,13 +283,13 @@ pub(crate) fn resolve_iterative<'a>(
for ns_name in &ns_names {
if referral_depth < MAX_REFERRAL_DEPTH {
debug!("recursive: resolving glue-less NS {}", ns_name);
// Try A first, then AAAA
for qt in [QueryType::A, QueryType::AAAA] {
if let Ok(ns_resp) = resolve_iterative(
ns_name,
qt,
cache,
root_hints,
srtt,
referral_depth + 1,
cname_depth,
)
@@ -316,6 +323,7 @@ pub(crate) fn resolve_iterative<'a>(
return Err(format!("could not resolve any NS for {}", qname).into());
}
srtt.read().unwrap().sort_by_rtt(&mut new_ns_addrs);
ns_addrs = new_ns_addrs;
ns_idx = 0;
}
@@ -561,7 +569,32 @@ fn make_glue_packet() -> DnsPacket {
pkt
}
async fn send_query(qname: &str, qtype: QueryType, server: SocketAddr) -> crate::Result<DnsPacket> {
async fn tcp_with_srtt(
query: &DnsPacket,
server: SocketAddr,
srtt: &RwLock<SrttCache>,
start: Instant,
) -> crate::Result<DnsPacket> {
match crate::forward::forward_tcp(query, server, TCP_TIMEOUT).await {
Ok(resp) => {
srtt.write()
.unwrap()
.record_rtt(server.ip(), start.elapsed().as_millis() as u64, true);
Ok(resp)
}
Err(e) => {
srtt.write().unwrap().record_failure(server.ip());
Err(e)
}
}
}
async fn send_query(
qname: &str,
qtype: QueryType,
server: SocketAddr,
srtt: &RwLock<SrttCache>,
) -> crate::Result<DnsPacket> {
let mut query = DnsPacket::new();
query.header.id = next_id();
query.header.recursion_desired = false;
@@ -573,24 +606,30 @@ async fn send_query(qname: &str, qtype: QueryType, server: SocketAddr) -> crate:
..Default::default()
});
// Skip IPv6 if the socket can't handle it (bound to 0.0.0.0)
let start = Instant::now();
// IPv6 forced to TCP — our UDP socket is bound to 0.0.0.0
if server.is_ipv6() {
return crate::forward::forward_tcp(&query, server, TCP_TIMEOUT).await;
return tcp_with_srtt(&query, server, srtt, start).await;
}
// If UDP has been detected as blocked, go TCP-first
// UDP detected as blocked go TCP-first
if UDP_DISABLED.load(Ordering::Acquire) {
return crate::forward::forward_tcp(&query, server, TCP_TIMEOUT).await;
return tcp_with_srtt(&query, server, srtt, start).await;
}
match forward_udp(&query, server, NS_QUERY_TIMEOUT).await {
Ok(resp) if resp.header.truncated_message => {
debug!("send_query: truncated from {}, retrying TCP", server);
crate::forward::forward_tcp(&query, server, TCP_TIMEOUT).await
tcp_with_srtt(&query, server, srtt, start).await
}
Ok(resp) => {
// UDP works — reset failure counter
UDP_FAILURES.store(0, Ordering::Release);
srtt.write().unwrap().record_rtt(
server.ip(),
start.elapsed().as_millis() as u64,
false,
);
Ok(resp)
}
Err(e) => {
@@ -603,7 +642,7 @@ async fn send_query(qname: &str, qtype: QueryType, server: SocketAddr) -> crate:
);
}
debug!("send_query: UDP failed for {}: {}, trying TCP", server, e);
crate::forward::forward_tcp(&query, server, TCP_TIMEOUT).await
tcp_with_srtt(&query, server, srtt, start).await
}
}
}
@@ -894,7 +933,8 @@ mod tests {
})
.await;
let result = send_query("test.example.com", QueryType::A, server_addr).await;
let srtt = RwLock::new(SrttCache::new(true));
let result = send_query("test.example.com", QueryType::A, server_addr, &srtt).await;
let resp = result.expect("should resolve via TCP fallback");
assert_eq!(resp.header.rescode, ResultCode::NOERROR);
@@ -945,7 +985,8 @@ mod tests {
})
.await;
let result = send_query("hello.example.com", QueryType::A, server_addr).await;
let srtt = RwLock::new(SrttCache::new(true));
let result = send_query("hello.example.com", QueryType::A, server_addr, &srtt).await;
let resp = result.expect("TCP-only send_query should work");
assert_eq!(resp.header.rescode, ResultCode::NOERROR);
match &resp.answers[0] {
@@ -967,10 +1008,19 @@ mod tests {
.await;
let cache = RwLock::new(DnsCache::new(100, 60, 86400));
let srtt = RwLock::new(SrttCache::new(true));
let root_hints = vec![server_addr];
let result =
resolve_iterative("nonexistent.test", QueryType::A, &cache, &root_hints, 0, 0).await;
let result = resolve_iterative(
"nonexistent.test",
QueryType::A,
&cache,
&root_hints,
&srtt,
0,
0,
)
.await;
let resp = result.expect("NXDOMAIN should still return a response");
assert_eq!(resp.header.rescode, ResultCode::NXDOMAIN);

227
src/srtt.rs Normal file
View File

@@ -0,0 +1,227 @@
use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
use std::time::Instant;
const INITIAL_SRTT_MS: u64 = 200;
const FAILURE_PENALTY_MS: u64 = 5000;
const TCP_PENALTY_MS: u64 = 100;
const DECAY_AFTER_SECS: u64 = 300;
const MAX_ENTRIES: usize = 4096;
const EVICT_BATCH: usize = 64;
struct SrttEntry {
srtt_ms: u64,
updated_at: Instant,
}
pub struct SrttCache {
entries: HashMap<IpAddr, SrttEntry>,
enabled: bool,
}
impl Default for SrttCache {
fn default() -> Self {
Self::new(true)
}
}
impl SrttCache {
pub fn new(enabled: bool) -> Self {
Self {
entries: HashMap::new(),
enabled,
}
}
pub fn is_enabled(&self) -> bool {
self.enabled
}
/// Get current SRTT for an IP, applying decay if stale. Returns INITIAL for unknown.
pub fn get(&self, ip: IpAddr) -> u64 {
match self.entries.get(&ip) {
Some(entry) => Self::decayed_srtt(entry),
None => INITIAL_SRTT_MS,
}
}
/// Apply time-based decay: each DECAY_AFTER_SECS period halves distance to INITIAL.
fn decayed_srtt(entry: &SrttEntry) -> u64 {
let age_secs = entry.updated_at.elapsed().as_secs();
if age_secs > DECAY_AFTER_SECS {
let periods = (age_secs / DECAY_AFTER_SECS).min(8);
let mut srtt = entry.srtt_ms;
for _ in 0..periods {
srtt = (srtt + INITIAL_SRTT_MS) / 2;
}
srtt
} else {
entry.srtt_ms
}
}
/// Record a successful query RTT. No-op when disabled.
pub fn record_rtt(&mut self, ip: IpAddr, rtt_ms: u64, tcp: bool) {
if !self.enabled {
return;
}
let effective = if tcp { rtt_ms + TCP_PENALTY_MS } else { rtt_ms };
self.maybe_evict();
let entry = self.entries.entry(ip).or_insert(SrttEntry {
srtt_ms: effective,
updated_at: Instant::now(),
});
// Apply decay before EWMA so recovered servers aren't stuck at stale penalties
let base = Self::decayed_srtt(entry);
// BIND EWMA: new = (old * 7 + sample) / 8
entry.srtt_ms = (base * 7 + effective) / 8;
entry.updated_at = Instant::now();
}
/// Record a failure (timeout or error). No-op when disabled.
pub fn record_failure(&mut self, ip: IpAddr) {
if !self.enabled {
return;
}
self.maybe_evict();
let entry = self.entries.entry(ip).or_insert(SrttEntry {
srtt_ms: FAILURE_PENALTY_MS,
updated_at: Instant::now(),
});
entry.srtt_ms = FAILURE_PENALTY_MS;
entry.updated_at = Instant::now();
}
/// Sort addresses by SRTT ascending (lowest/fastest first). No-op when disabled.
pub fn sort_by_rtt(&self, addrs: &mut [SocketAddr]) {
if !self.enabled {
return;
}
addrs.sort_by_key(|a| self.get(a.ip()));
}
pub fn len(&self) -> usize {
self.entries.len()
}
pub fn is_empty(&self) -> bool {
self.entries.is_empty()
}
fn maybe_evict(&mut self) {
if self.entries.len() < MAX_ENTRIES {
return;
}
// Batch eviction: remove the oldest EVICT_BATCH entries at once
let mut by_age: Vec<IpAddr> = self.entries.keys().copied().collect();
by_age.sort_by_key(|ip| self.entries[ip].updated_at);
for ip in by_age.into_iter().take(EVICT_BATCH) {
self.entries.remove(&ip);
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::net::Ipv4Addr;
fn ip(last: u8) -> IpAddr {
IpAddr::V4(Ipv4Addr::new(192, 0, 2, last))
}
fn sock(last: u8) -> SocketAddr {
SocketAddr::new(ip(last), 53)
}
#[test]
fn unknown_returns_initial() {
let cache = SrttCache::new(true);
assert_eq!(cache.get(ip(1)), INITIAL_SRTT_MS);
}
#[test]
fn ewma_converges() {
let mut cache = SrttCache::new(true);
for _ in 0..20 {
cache.record_rtt(ip(1), 100, false);
}
let srtt = cache.get(ip(1));
assert!(srtt >= 98 && srtt <= 102, "srtt={}", srtt);
}
#[test]
fn failure_sets_penalty() {
let mut cache = SrttCache::new(true);
cache.record_rtt(ip(1), 50, false);
cache.record_failure(ip(1));
assert_eq!(cache.get(ip(1)), FAILURE_PENALTY_MS);
}
#[test]
fn tcp_penalty_added() {
let mut cache = SrttCache::new(true);
for _ in 0..20 {
cache.record_rtt(ip(1), 50, true);
}
let srtt = cache.get(ip(1));
assert!(srtt >= 148 && srtt <= 152, "srtt={}", srtt);
}
#[test]
fn sort_by_rtt_orders_correctly() {
let mut cache = SrttCache::new(true);
for _ in 0..20 {
cache.record_rtt(ip(1), 500, false);
cache.record_rtt(ip(2), 100, false);
cache.record_rtt(ip(3), 10, false);
}
let mut addrs = vec![sock(1), sock(2), sock(3)];
cache.sort_by_rtt(&mut addrs);
assert_eq!(addrs, vec![sock(3), sock(2), sock(1)]);
}
#[test]
fn unknown_servers_sort_equal() {
let cache = SrttCache::new(true);
let mut addrs = vec![sock(1), sock(2), sock(3)];
let original = addrs.clone();
cache.sort_by_rtt(&mut addrs);
assert_eq!(addrs, original);
}
#[test]
fn disabled_is_noop() {
let mut cache = SrttCache::new(false);
cache.record_rtt(ip(1), 50, false);
cache.record_failure(ip(2));
assert_eq!(cache.len(), 0);
let mut addrs = vec![sock(2), sock(1)];
let original = addrs.clone();
cache.sort_by_rtt(&mut addrs);
assert_eq!(addrs, original);
}
#[test]
fn eviction_removes_oldest() {
let mut cache = SrttCache::new(true);
for i in 0..MAX_ENTRIES {
let octets = [
10,
((i >> 16) & 0xFF) as u8,
((i >> 8) & 0xFF) as u8,
(i & 0xFF) as u8,
];
cache.record_rtt(
IpAddr::V4(Ipv4Addr::new(octets[0], octets[1], octets[2], octets[3])),
100,
false,
);
}
assert_eq!(cache.len(), MAX_ENTRIES);
cache.record_rtt(ip(1), 100, false);
// Batch eviction removes EVICT_BATCH entries
assert!(cache.len() <= MAX_ENTRIES - EVICT_BATCH + 1);
}
}