feat: SRTT-based nameserver selection (#19)
* feat: SRTT-based nameserver selection for recursive resolver BIND-style Smoothed RTT (EWMA) tracking per NS IP address. The resolver learns which nameservers respond fastest and prefers them, eliminating cascading timeouts from slow/unreachable IPv6 servers. - New src/srtt.rs: SrttCache with record_rtt, record_failure, sort_by_rtt - EWMA formula: new = (old * 7 + sample) / 8, 5s failure penalty, 5min decay - TCP penalty (+100ms) lets SRTT naturally deprioritize IPv6-over-TCP - Enabled flag embedded in SrttCache (no-op when disabled) - Batch eviction (64 entries) for O(1) amortized writes at capacity - Configurable via [upstream] srtt = true/false (default: true) - Benchmark script: scripts/benchmark.sh (full, cold, warm, compare-all) - Benchmarks show 12x avg improvement, 0% queries >1s (was 58%) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * feat: show DNSSEC and SRTT status in dashboard + API Add dnssec and srtt boolean fields to /stats API response. Display on/off indicators in the dashboard footer. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix: apply SRTT decay before EWMA so recovered servers rehabilitate Without decay-before-EWMA, a server penalized at 5000ms stayed near that value even after recovery — the stale raw penalty was used as the EWMA base instead of the decayed estimate. Extract decayed_srtt() helper and call it in record_rtt() before the smoothing step. Also restores removed "why" comments in send_query / resolve_recursive. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * docs: add install/upgrade instructions, smarter benchmark priming README: document `numa install`, `numa service`, Homebrew upgrade, and `make deploy` workflows. Benchmark: replace fixed `sleep 4` with `wait_for_priming` that polls cache entry count for stability. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
27
README.md
27
README.md
@@ -30,11 +30,34 @@ dig @127.0.0.1 ads.google.com # ✗ blocked → 0.0.0.0
|
||||
|
||||
Open the dashboard: **http://numa.numa** (or `http://localhost:5380`)
|
||||
|
||||
Or build from source:
|
||||
### Set as system resolver
|
||||
|
||||
```bash
|
||||
# Point your system DNS to Numa (saves originals for uninstall)
|
||||
sudo numa install
|
||||
|
||||
# Run as a persistent service (auto-starts on boot, restarts if killed)
|
||||
sudo numa service start
|
||||
```
|
||||
|
||||
To uninstall: `sudo numa service stop` removes the service, `sudo numa uninstall` restores your original DNS.
|
||||
|
||||
### Upgrade
|
||||
|
||||
```bash
|
||||
# From Homebrew
|
||||
brew upgrade numa
|
||||
|
||||
# From source
|
||||
make deploy # builds release, copies binary, re-signs, restarts service
|
||||
```
|
||||
|
||||
### Build from source
|
||||
|
||||
```bash
|
||||
git clone https://github.com/razvandimescu/numa.git && cd numa
|
||||
cargo build --release
|
||||
sudo ./target/release/numa
|
||||
sudo cp target/release/numa /usr/local/bin/numa
|
||||
```
|
||||
|
||||
## Why Numa
|
||||
|
||||
306
scripts/benchmark.sh
Executable file
306
scripts/benchmark.sh
Executable file
@@ -0,0 +1,306 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
API="${NUMA_API:-http://127.0.0.1:5380}"
|
||||
DNS="${NUMA_DNS:-127.0.0.1}"
|
||||
NUMA_BIN="${NUMA_BIN:-/usr/local/bin/numa}"
|
||||
LAUNCHD_PLIST="/Library/LaunchDaemons/com.numa.dns.plist"
|
||||
|
||||
DOMAINS=(
|
||||
paypal.com ebay.com zoom.us slack.com discord.com
|
||||
microsoft.com apple.com meta.com oracle.com ibm.com
|
||||
docker.com kubernetes.io prometheus.io grafana.com terraform.io
|
||||
python.org nodejs.org golang.org wikipedia.org reddit.com
|
||||
stackoverflow.com stripe.com linear.app nytimes.com bbc.co.uk
|
||||
rust-lang.org fastly.com hetzner.com uber.com airbnb.com
|
||||
notion.so figma.com netflix.com spotify.com dropbox.com
|
||||
gitlab.com twitch.tv shopify.com vercel.app mozilla.org
|
||||
)
|
||||
|
||||
stats() {
|
||||
curl -s "$API/query-log" | python3 -c "
|
||||
import sys, json
|
||||
|
||||
data = json.load(sys.stdin)
|
||||
rec = [q for q in data if q['path'] == 'RECURSIVE']
|
||||
if not rec:
|
||||
print('No recursive queries in log.')
|
||||
sys.exit()
|
||||
|
||||
vals = sorted([q['latency_ms'] for q in rec])
|
||||
n = len(vals)
|
||||
|
||||
print(f'Recursive queries: {n}')
|
||||
print(f' Avg: {sum(vals)/n:.1f}ms')
|
||||
print(f' Median: {vals[n//2]:.1f}ms')
|
||||
print(f' P95: {vals[int(n*0.95)]:.1f}ms')
|
||||
print(f' P99: {vals[int(n*0.99)]:.1f}ms')
|
||||
print(f' Min: {min(vals):.1f}ms')
|
||||
print(f' Max: {max(vals):.1f}ms')
|
||||
print(f' <100ms: {sum(1 for v in vals if v < 100)}')
|
||||
print(f' <200ms: {sum(1 for v in vals if v < 200)}')
|
||||
print(f' <500ms: {sum(1 for v in vals if v < 500)}')
|
||||
print(f' >1s: {sum(1 for v in vals if v >= 1000)}')
|
||||
print()
|
||||
print('Slowest 5:')
|
||||
for q in sorted(rec, key=lambda q: q['latency_ms'], reverse=True)[:5]:
|
||||
print(f' {q[\"latency_ms\"]:>8.1f}ms {q[\"query_type\"]:5s} {q[\"domain\"]:35s} {q[\"rescode\"]}')
|
||||
print()
|
||||
print('Fastest 5:')
|
||||
for q in sorted(rec, key=lambda q: q['latency_ms'])[:5]:
|
||||
print(f' {q[\"latency_ms\"]:>8.1f}ms {q[\"query_type\"]:5s} {q[\"domain\"]:35s} {q[\"rescode\"]}')
|
||||
"
|
||||
}
|
||||
|
||||
query_all() {
|
||||
local label="$1"
|
||||
echo "=== $label ==="
|
||||
for d in "${DOMAINS[@]}"; do
|
||||
printf " %-25s " "$d"
|
||||
dig "@$DNS" "$d" A +noall +stats 2>/dev/null | grep "Query time"
|
||||
done
|
||||
echo
|
||||
}
|
||||
|
||||
flush_cache() {
|
||||
curl -s -X DELETE "$API/cache" > /dev/null
|
||||
echo "Cache flushed ($(curl -s "$API/stats" | python3 -c "import sys,json; print(json.load(sys.stdin)['cache']['entries'])" 2>/dev/null || echo '?') entries)."
|
||||
}
|
||||
|
||||
wait_for_api() {
|
||||
local attempts=0
|
||||
while ! curl -sf "$API/health" > /dev/null 2>&1; do
|
||||
attempts=$((attempts + 1))
|
||||
if [ $attempts -ge 20 ]; then
|
||||
echo "ERROR: API not reachable at $API after 10s" >&2
|
||||
exit 1
|
||||
fi
|
||||
sleep 0.5
|
||||
done
|
||||
}
|
||||
|
||||
wait_for_priming() {
|
||||
echo -n "Waiting for TLD priming..."
|
||||
local prev=0
|
||||
local stable=0
|
||||
for _ in $(seq 1 60); do
|
||||
local entries
|
||||
entries=$(curl -s "$API/stats" | python3 -c "import sys,json; print(json.load(sys.stdin)['cache']['entries'])" 2>/dev/null || echo 0)
|
||||
if [ "$entries" -gt 0 ] && [ "$entries" = "$prev" ]; then
|
||||
stable=$((stable + 1))
|
||||
if [ $stable -ge 3 ]; then
|
||||
echo " done ($entries cache entries)."
|
||||
return
|
||||
fi
|
||||
else
|
||||
stable=0
|
||||
fi
|
||||
prev="$entries"
|
||||
sleep 1
|
||||
done
|
||||
echo " timeout (cache: $prev entries)."
|
||||
}
|
||||
|
||||
# restart_numa <config_toml_body>
|
||||
# Writes config to a temp file, stops numa (launchd or manual), starts with that config.
|
||||
restart_numa() {
|
||||
local config_body="$1"
|
||||
local tmpconf
|
||||
tmpconf=$(mktemp /tmp/numa-bench-XXXXXX)
|
||||
mv "$tmpconf" "${tmpconf}.toml"
|
||||
tmpconf="${tmpconf}.toml"
|
||||
echo "$config_body" > "$tmpconf"
|
||||
|
||||
# Stop launchd-managed numa if active
|
||||
if sudo launchctl list com.numa.dns &>/dev/null; then
|
||||
sudo launchctl unload "$LAUNCHD_PLIST" 2>/dev/null || true
|
||||
sleep 1
|
||||
fi
|
||||
|
||||
# Kill any remaining
|
||||
sudo killall numa 2>/dev/null || true
|
||||
sleep 2
|
||||
|
||||
sudo "$NUMA_BIN" "$tmpconf" &
|
||||
wait_for_api
|
||||
wait_for_priming
|
||||
echo "numa ready (pid $(pgrep numa | head -1), config: $tmpconf)."
|
||||
}
|
||||
|
||||
# Restore the launchd service
|
||||
restore_launchd() {
|
||||
sudo killall numa 2>/dev/null || true
|
||||
sleep 1
|
||||
if [ -f "$LAUNCHD_PLIST" ]; then
|
||||
sudo launchctl load "$LAUNCHD_PLIST" 2>/dev/null || true
|
||||
echo "Restored launchd service."
|
||||
fi
|
||||
}
|
||||
|
||||
run_pass() {
|
||||
local label="$1"
|
||||
flush_cache
|
||||
sleep 0.5
|
||||
query_all "$label"
|
||||
echo "=== $label — stats ==="
|
||||
stats
|
||||
}
|
||||
|
||||
case "${1:-full}" in
|
||||
cold)
|
||||
echo "--- Cold cache benchmark ---"
|
||||
run_pass "Cold SRTT + Cold cache"
|
||||
;;
|
||||
warm)
|
||||
echo "--- Warm SRTT benchmark ---"
|
||||
echo "Priming SRTT..."
|
||||
for d in "${DOMAINS[@]}"; do dig "@$DNS" "$d" A +short > /dev/null 2>&1; done
|
||||
run_pass "Warm SRTT + Cold cache"
|
||||
;;
|
||||
stats)
|
||||
stats
|
||||
;;
|
||||
compare-srtt)
|
||||
echo "============================================"
|
||||
echo " A/B: SRTT OFF vs ON (dnssec off)"
|
||||
echo "============================================"
|
||||
echo
|
||||
|
||||
restart_numa "$(cat <<'TOML'
|
||||
[upstream]
|
||||
mode = "recursive"
|
||||
srtt = false
|
||||
TOML
|
||||
)"
|
||||
echo
|
||||
run_pass "SRTT OFF"
|
||||
|
||||
echo
|
||||
echo "--------------------------------------------"
|
||||
echo
|
||||
|
||||
restart_numa "$(cat <<'TOML'
|
||||
[upstream]
|
||||
mode = "recursive"
|
||||
srtt = true
|
||||
TOML
|
||||
)"
|
||||
echo
|
||||
run_pass "SRTT ON"
|
||||
|
||||
echo
|
||||
restore_launchd
|
||||
;;
|
||||
compare-dnssec)
|
||||
echo "============================================"
|
||||
echo " A/B: DNSSEC OFF vs ON (srtt on)"
|
||||
echo "============================================"
|
||||
echo
|
||||
|
||||
restart_numa "$(cat <<'TOML'
|
||||
[upstream]
|
||||
mode = "recursive"
|
||||
srtt = true
|
||||
|
||||
[dnssec]
|
||||
enabled = false
|
||||
TOML
|
||||
)"
|
||||
echo
|
||||
run_pass "DNSSEC OFF"
|
||||
|
||||
echo
|
||||
echo "--------------------------------------------"
|
||||
echo
|
||||
|
||||
restart_numa "$(cat <<'TOML'
|
||||
[upstream]
|
||||
mode = "recursive"
|
||||
srtt = true
|
||||
|
||||
[dnssec]
|
||||
enabled = true
|
||||
TOML
|
||||
)"
|
||||
echo
|
||||
run_pass "DNSSEC ON"
|
||||
|
||||
echo
|
||||
restore_launchd
|
||||
;;
|
||||
compare-all)
|
||||
echo "============================================"
|
||||
echo " Full A/B matrix"
|
||||
echo " 1. SRTT OFF + DNSSEC OFF (baseline)"
|
||||
echo " 2. SRTT ON + DNSSEC OFF"
|
||||
echo " 3. SRTT ON + DNSSEC ON"
|
||||
echo "============================================"
|
||||
echo
|
||||
|
||||
# --- 1. Baseline ---
|
||||
restart_numa "$(cat <<'TOML'
|
||||
[upstream]
|
||||
mode = "recursive"
|
||||
srtt = false
|
||||
|
||||
[dnssec]
|
||||
enabled = false
|
||||
TOML
|
||||
)"
|
||||
echo
|
||||
run_pass "SRTT OFF + DNSSEC OFF"
|
||||
|
||||
echo
|
||||
echo "--------------------------------------------"
|
||||
echo
|
||||
|
||||
# --- 2. SRTT only ---
|
||||
restart_numa "$(cat <<'TOML'
|
||||
[upstream]
|
||||
mode = "recursive"
|
||||
srtt = true
|
||||
|
||||
[dnssec]
|
||||
enabled = false
|
||||
TOML
|
||||
)"
|
||||
echo
|
||||
run_pass "SRTT ON + DNSSEC OFF"
|
||||
|
||||
echo
|
||||
echo "--------------------------------------------"
|
||||
echo
|
||||
|
||||
# --- 3. Both ---
|
||||
restart_numa "$(cat <<'TOML'
|
||||
[upstream]
|
||||
mode = "recursive"
|
||||
srtt = true
|
||||
|
||||
[dnssec]
|
||||
enabled = true
|
||||
TOML
|
||||
)"
|
||||
echo
|
||||
run_pass "SRTT ON + DNSSEC ON"
|
||||
|
||||
echo
|
||||
restore_launchd
|
||||
;;
|
||||
full|*)
|
||||
echo "--- Full benchmark (cold → warm → SRTT-only) ---"
|
||||
echo
|
||||
|
||||
wait_for_priming
|
||||
flush_cache
|
||||
sleep 0.5
|
||||
query_all "Pass 1: Cold SRTT + Cold cache"
|
||||
|
||||
flush_cache
|
||||
sleep 0.5
|
||||
query_all "Pass 2: Warm SRTT + Cold cache"
|
||||
|
||||
echo "=== Pass 2 stats (SRTT-warm) ==="
|
||||
stats
|
||||
;;
|
||||
esac
|
||||
@@ -879,6 +879,10 @@ async function refresh() {
|
||||
document.getElementById('footerUpstream').textContent = stats.upstream || '';
|
||||
document.getElementById('footerConfig').textContent = stats.config_path || '';
|
||||
document.getElementById('footerData').textContent = stats.data_dir || '';
|
||||
document.getElementById('footerDnssec').textContent = stats.dnssec ? 'on' : 'off';
|
||||
document.getElementById('footerDnssec').style.color = stats.dnssec ? 'var(--emerald)' : 'var(--text-dim)';
|
||||
document.getElementById('footerSrtt').textContent = stats.srtt ? 'on' : 'off';
|
||||
document.getElementById('footerSrtt').style.color = stats.srtt ? 'var(--emerald)' : 'var(--text-dim)';
|
||||
|
||||
// LAN status indicator
|
||||
const lanEl = document.getElementById('lanToggle');
|
||||
@@ -1229,6 +1233,8 @@ setInterval(refresh, 2000);
|
||||
Config: <span id="footerConfig" style="user-select:all;color:var(--emerald);"></span>
|
||||
· Data: <span id="footerData" style="user-select:all;color:var(--emerald);"></span>
|
||||
· Upstream: <span id="footerUpstream" style="user-select:all;color:var(--emerald);"></span>
|
||||
· DNSSEC: <span id="footerDnssec" style="color:var(--text-dim);">—</span>
|
||||
· SRTT: <span id="footerSrtt" style="color:var(--text-dim);">—</span>
|
||||
· Logs: <span style="user-select:all;color:var(--emerald);">macOS: /usr/local/var/log/numa.log · Linux: journalctl -u numa -f</span>
|
||||
· <a href="https://github.com/razvandimescu/numa" target="_blank" rel="noopener" style="color:var(--amber);text-decoration:none;">GitHub</a>
|
||||
</div>
|
||||
|
||||
@@ -162,6 +162,8 @@ struct StatsResponse {
|
||||
upstream: String,
|
||||
config_path: String,
|
||||
data_dir: String,
|
||||
dnssec: bool,
|
||||
srtt: bool,
|
||||
queries: QueriesStats,
|
||||
cache: CacheStats,
|
||||
overrides: OverrideStats,
|
||||
@@ -491,6 +493,8 @@ async fn stats(State(ctx): State<Arc<ServerCtx>>) -> Json<StatsResponse> {
|
||||
upstream,
|
||||
config_path: ctx.config_path.clone(),
|
||||
data_dir: ctx.data_dir.to_string_lossy().to_string(),
|
||||
dnssec: ctx.dnssec_enabled,
|
||||
srtt: ctx.srtt.read().unwrap().is_enabled(),
|
||||
queries: QueriesStats {
|
||||
total: snap.total,
|
||||
forwarded: snap.forwarded,
|
||||
@@ -948,6 +952,7 @@ mod tests {
|
||||
tls_config: None,
|
||||
upstream_mode: crate::config::UpstreamMode::Forward,
|
||||
root_hints: Vec::new(),
|
||||
srtt: RwLock::new(crate::srtt::SrttCache::new(true)),
|
||||
dnssec_enabled: false,
|
||||
dnssec_strict: false,
|
||||
})
|
||||
|
||||
@@ -85,6 +85,8 @@ pub struct UpstreamConfig {
|
||||
pub root_hints: Vec<String>,
|
||||
#[serde(default = "default_prime_tlds")]
|
||||
pub prime_tlds: Vec<String>,
|
||||
#[serde(default = "default_srtt")]
|
||||
pub srtt: bool,
|
||||
}
|
||||
|
||||
impl Default for UpstreamConfig {
|
||||
@@ -96,10 +98,15 @@ impl Default for UpstreamConfig {
|
||||
timeout_ms: default_timeout_ms(),
|
||||
root_hints: default_root_hints(),
|
||||
prime_tlds: default_prime_tlds(),
|
||||
srtt: default_srtt(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn default_srtt() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn default_prime_tlds() -> Vec<String> {
|
||||
vec![
|
||||
// gTLDs
|
||||
|
||||
@@ -21,6 +21,7 @@ use crate::query_log::{QueryLog, QueryLogEntry};
|
||||
use crate::question::QueryType;
|
||||
use crate::record::DnsRecord;
|
||||
use crate::service_store::ServiceStore;
|
||||
use crate::srtt::SrttCache;
|
||||
use crate::stats::{QueryPath, ServerStats};
|
||||
use crate::system_dns::ForwardingRule;
|
||||
|
||||
@@ -51,6 +52,7 @@ pub struct ServerCtx {
|
||||
pub tls_config: Option<ArcSwap<ServerConfig>>,
|
||||
pub upstream_mode: UpstreamMode,
|
||||
pub root_hints: Vec<SocketAddr>,
|
||||
pub srtt: RwLock<SrttCache>,
|
||||
pub dnssec_enabled: bool,
|
||||
pub dnssec_strict: bool,
|
||||
}
|
||||
@@ -176,6 +178,7 @@ pub async fn handle_query(
|
||||
&ctx.cache,
|
||||
&query,
|
||||
&ctx.root_hints,
|
||||
&ctx.srtt,
|
||||
)
|
||||
.await
|
||||
{
|
||||
@@ -226,7 +229,8 @@ pub async fn handle_query(
|
||||
let mut dnssec = dnssec;
|
||||
if ctx.dnssec_enabled && path == QueryPath::Recursive {
|
||||
let (status, vstats) =
|
||||
crate::dnssec::validate_response(&response, &ctx.cache, &ctx.root_hints).await;
|
||||
crate::dnssec::validate_response(&response, &ctx.cache, &ctx.root_hints, &ctx.srtt)
|
||||
.await;
|
||||
|
||||
debug!(
|
||||
"DNSSEC | {} | {:?} | {}ms | dnskey_hit={} dnskey_fetch={} ds_hit={} ds_fetch={}",
|
||||
|
||||
@@ -9,6 +9,7 @@ use crate::cache::{DnsCache, DnssecStatus};
|
||||
use crate::packet::DnsPacket;
|
||||
use crate::question::QueryType;
|
||||
use crate::record::DnsRecord;
|
||||
use crate::srtt::SrttCache;
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct ValidationStats {
|
||||
@@ -64,6 +65,7 @@ pub async fn validate_response(
|
||||
response: &DnsPacket,
|
||||
cache: &RwLock<DnsCache>,
|
||||
root_hints: &[std::net::SocketAddr],
|
||||
srtt: &RwLock<SrttCache>,
|
||||
) -> (DnssecStatus, ValidationStats) {
|
||||
let start = Instant::now();
|
||||
let stats = Mutex::new(ValidationStats::default());
|
||||
@@ -95,7 +97,7 @@ pub async fn validate_response(
|
||||
}
|
||||
}
|
||||
for zone in &signer_zones {
|
||||
fetch_dnskeys(zone, cache, root_hints, &stats).await;
|
||||
fetch_dnskeys(zone, cache, root_hints, srtt, &stats).await;
|
||||
}
|
||||
|
||||
// Group answer records into RRsets (by domain + type, excluding RRSIGs)
|
||||
@@ -132,7 +134,8 @@ pub async fn validate_response(
|
||||
..
|
||||
} = rrsig
|
||||
{
|
||||
let dnskey_response = fetch_dnskeys(signer_name, cache, root_hints, &stats).await;
|
||||
let dnskey_response =
|
||||
fetch_dnskeys(signer_name, cache, root_hints, srtt, &stats).await;
|
||||
let dnskeys: Vec<&DnsRecord> = dnskey_response
|
||||
.iter()
|
||||
.filter(|r| matches!(r, DnsRecord::DNSKEY { .. }))
|
||||
@@ -206,6 +209,7 @@ pub async fn validate_response(
|
||||
&dnskey_response,
|
||||
cache,
|
||||
root_hints,
|
||||
srtt,
|
||||
trust_anchors,
|
||||
0,
|
||||
&stats,
|
||||
@@ -276,11 +280,13 @@ pub async fn validate_response(
|
||||
|
||||
/// Walk the chain of trust from zone DNSKEY up to root trust anchor.
|
||||
/// `zone_records` contains both DNSKEY and RRSIG records from the DNSKEY response.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn validate_chain<'a>(
|
||||
zone: &'a str,
|
||||
zone_records: &'a [DnsRecord],
|
||||
cache: &'a RwLock<DnsCache>,
|
||||
root_hints: &'a [std::net::SocketAddr],
|
||||
srtt: &'a RwLock<SrttCache>,
|
||||
trust_anchors: &'a [DnsRecord],
|
||||
depth: u8,
|
||||
stats: &'a Mutex<ValidationStats>,
|
||||
@@ -343,7 +349,7 @@ fn validate_chain<'a>(
|
||||
return DnssecStatus::Indeterminate;
|
||||
}
|
||||
let parent = parent_zone(zone);
|
||||
let ds_records = fetch_ds(zone, cache, root_hints, stats).await;
|
||||
let ds_records = fetch_ds(zone, cache, root_hints, srtt, stats).await;
|
||||
|
||||
if ds_records.is_empty() {
|
||||
debug!("dnssec: no DS for zone '{}' at parent '{}'", zone, parent);
|
||||
@@ -377,7 +383,7 @@ fn validate_chain<'a>(
|
||||
|
||||
// Walk up: validate the parent's DNSKEY
|
||||
trace!("dnssec: fetching parent DNSKEY for '{}'", parent);
|
||||
let parent_records = fetch_dnskeys(&parent, cache, root_hints, stats).await;
|
||||
let parent_records = fetch_dnskeys(&parent, cache, root_hints, srtt, stats).await;
|
||||
if parent_records.is_empty() {
|
||||
debug!("dnssec: no parent DNSKEY for '{}' — Indeterminate", parent);
|
||||
return DnssecStatus::Indeterminate;
|
||||
@@ -388,6 +394,7 @@ fn validate_chain<'a>(
|
||||
&parent_records,
|
||||
cache,
|
||||
root_hints,
|
||||
srtt,
|
||||
trust_anchors,
|
||||
depth + 1,
|
||||
stats,
|
||||
@@ -460,6 +467,7 @@ async fn fetch_dnskeys(
|
||||
zone: &str,
|
||||
cache: &RwLock<DnsCache>,
|
||||
root_hints: &[std::net::SocketAddr],
|
||||
srtt: &RwLock<SrttCache>,
|
||||
stats: &Mutex<ValidationStats>,
|
||||
) -> Vec<DnsRecord> {
|
||||
if let Some(pkt) = cache.read().unwrap().lookup(zone, QueryType::DNSKEY) {
|
||||
@@ -475,7 +483,8 @@ async fn fetch_dnskeys(
|
||||
trace!("dnssec: fetch_dnskeys('{}') cache miss — resolving", zone);
|
||||
stats.lock().unwrap().dnskey_fetches += 1;
|
||||
if let Ok(pkt) =
|
||||
crate::recursive::resolve_iterative(zone, QueryType::DNSKEY, cache, root_hints, 0, 0).await
|
||||
crate::recursive::resolve_iterative(zone, QueryType::DNSKEY, cache, root_hints, srtt, 0, 0)
|
||||
.await
|
||||
{
|
||||
cache.write().unwrap().insert(zone, QueryType::DNSKEY, &pkt);
|
||||
return pkt.answers;
|
||||
@@ -488,6 +497,7 @@ async fn fetch_ds(
|
||||
child: &str,
|
||||
cache: &RwLock<DnsCache>,
|
||||
root_hints: &[std::net::SocketAddr],
|
||||
srtt: &RwLock<SrttCache>,
|
||||
stats: &Mutex<ValidationStats>,
|
||||
) -> Vec<DnsRecord> {
|
||||
if let Some(pkt) = cache.read().unwrap().lookup(child, QueryType::DS) {
|
||||
@@ -501,7 +511,8 @@ async fn fetch_ds(
|
||||
|
||||
stats.lock().unwrap().ds_fetches += 1;
|
||||
if let Ok(pkt) =
|
||||
crate::recursive::resolve_iterative(child, QueryType::DS, cache, root_hints, 0, 0).await
|
||||
crate::recursive::resolve_iterative(child, QueryType::DS, cache, root_hints, srtt, 0, 0)
|
||||
.await
|
||||
{
|
||||
cache.write().unwrap().insert(child, QueryType::DS, &pkt);
|
||||
return pkt
|
||||
|
||||
@@ -16,6 +16,7 @@ pub mod question;
|
||||
pub mod record;
|
||||
pub mod recursive;
|
||||
pub mod service_store;
|
||||
pub mod srtt;
|
||||
pub mod stats;
|
||||
pub mod system_dns;
|
||||
pub mod tls;
|
||||
|
||||
@@ -201,6 +201,7 @@ async fn main() -> numa::Result<()> {
|
||||
tls_config: initial_tls,
|
||||
upstream_mode: config.upstream.mode,
|
||||
root_hints: numa::recursive::parse_root_hints(&config.upstream.root_hints),
|
||||
srtt: std::sync::RwLock::new(numa::srtt::SrttCache::new(config.upstream.srtt)),
|
||||
dnssec_enabled: config.dnssec.enabled,
|
||||
dnssec_strict: config.dnssec.strict,
|
||||
});
|
||||
@@ -353,7 +354,12 @@ async fn main() -> numa::Result<()> {
|
||||
let prime_ctx = Arc::clone(&ctx);
|
||||
let prime_tlds = config.upstream.prime_tlds;
|
||||
tokio::spawn(async move {
|
||||
numa::recursive::prime_tld_cache(&prime_ctx.cache, &prime_ctx.root_hints, &prime_tlds)
|
||||
numa::recursive::prime_tld_cache(
|
||||
&prime_ctx.cache,
|
||||
&prime_ctx.root_hints,
|
||||
&prime_tlds,
|
||||
&prime_ctx.srtt,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
}
|
||||
|
||||
102
src/recursive.rs
102
src/recursive.rs
@@ -1,7 +1,7 @@
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::sync::atomic::{AtomicU16, Ordering};
|
||||
use std::sync::RwLock;
|
||||
use std::time::Duration;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use log::{debug, info};
|
||||
|
||||
@@ -11,6 +11,7 @@ use crate::header::ResultCode;
|
||||
use crate::packet::DnsPacket;
|
||||
use crate::question::{DnsQuestion, QueryType};
|
||||
use crate::record::DnsRecord;
|
||||
use crate::srtt::SrttCache;
|
||||
|
||||
const MAX_REFERRAL_DEPTH: u8 = 10;
|
||||
const MAX_CNAME_DEPTH: u8 = 8;
|
||||
@@ -58,7 +59,12 @@ pub async fn probe_udp(root_hints: &[SocketAddr]) {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn prime_tld_cache(cache: &RwLock<DnsCache>, root_hints: &[SocketAddr], tlds: &[String]) {
|
||||
pub async fn prime_tld_cache(
|
||||
cache: &RwLock<DnsCache>,
|
||||
root_hints: &[SocketAddr],
|
||||
tlds: &[String],
|
||||
srtt: &RwLock<SrttCache>,
|
||||
) {
|
||||
if root_hints.is_empty() || tlds.is_empty() {
|
||||
return;
|
||||
}
|
||||
@@ -66,7 +72,7 @@ pub async fn prime_tld_cache(cache: &RwLock<DnsCache>, root_hints: &[SocketAddr]
|
||||
let mut root_addr = root_hints[0];
|
||||
for hint in root_hints {
|
||||
info!("prime: probing root {}", hint);
|
||||
match send_query(".", QueryType::NS, *hint).await {
|
||||
match send_query(".", QueryType::NS, *hint, srtt).await {
|
||||
Ok(_) => {
|
||||
info!("prime: root {} reachable", hint);
|
||||
root_addr = *hint;
|
||||
@@ -79,7 +85,7 @@ pub async fn prime_tld_cache(cache: &RwLock<DnsCache>, root_hints: &[SocketAddr]
|
||||
}
|
||||
|
||||
// Fetch root DNSKEY (needed for DNSSEC chain-of-trust terminus)
|
||||
if let Ok(root_dnskey) = send_query(".", QueryType::DNSKEY, root_addr).await {
|
||||
if let Ok(root_dnskey) = send_query(".", QueryType::DNSKEY, root_addr, srtt).await {
|
||||
cache
|
||||
.write()
|
||||
.unwrap()
|
||||
@@ -91,7 +97,7 @@ pub async fn prime_tld_cache(cache: &RwLock<DnsCache>, root_hints: &[SocketAddr]
|
||||
|
||||
for tld in tlds {
|
||||
// Fetch NS referral (includes DS in authority section from root)
|
||||
let response = match send_query(tld, QueryType::NS, root_addr).await {
|
||||
let response = match send_query(tld, QueryType::NS, root_addr, srtt).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
debug!("prime: failed to query NS for .{}: {}", tld, e);
|
||||
@@ -108,7 +114,6 @@ pub async fn prime_tld_cache(cache: &RwLock<DnsCache>, root_hints: &[SocketAddr]
|
||||
let mut cache_w = cache.write().unwrap();
|
||||
cache_w.insert(tld, QueryType::NS, &response);
|
||||
cache_glue(&mut cache_w, &response, &ns_names);
|
||||
// Cache DS records from referral authority section
|
||||
cache_ds_from_authority(&mut cache_w, &response);
|
||||
}
|
||||
|
||||
@@ -116,7 +121,7 @@ pub async fn prime_tld_cache(cache: &RwLock<DnsCache>, root_hints: &[SocketAddr]
|
||||
let first_ns_name = ns_names.first().map(|s| s.as_str()).unwrap_or("");
|
||||
let first_ns = glue_addrs_for(&response, first_ns_name);
|
||||
if let Some(ns_addr) = first_ns.first() {
|
||||
if let Ok(dnskey_resp) = send_query(tld, QueryType::DNSKEY, *ns_addr).await {
|
||||
if let Ok(dnskey_resp) = send_query(tld, QueryType::DNSKEY, *ns_addr, srtt).await {
|
||||
cache
|
||||
.write()
|
||||
.unwrap()
|
||||
@@ -140,10 +145,11 @@ pub async fn resolve_recursive(
|
||||
cache: &RwLock<DnsCache>,
|
||||
original_query: &DnsPacket,
|
||||
root_hints: &[SocketAddr],
|
||||
srtt: &RwLock<SrttCache>,
|
||||
) -> crate::Result<DnsPacket> {
|
||||
// No overall timeout — each hop is bounded by NS_QUERY_TIMEOUT (UDP + TCP fallback),
|
||||
// and MAX_REFERRAL_DEPTH caps the chain length.
|
||||
let mut resp = resolve_iterative(qname, qtype, cache, root_hints, 0, 0).await?;
|
||||
let mut resp = resolve_iterative(qname, qtype, cache, root_hints, srtt, 0, 0).await?;
|
||||
|
||||
resp.header.id = original_query.header.id;
|
||||
resp.header.recursion_available = true;
|
||||
@@ -157,6 +163,7 @@ pub(crate) fn resolve_iterative<'a>(
|
||||
qtype: QueryType,
|
||||
cache: &'a RwLock<DnsCache>,
|
||||
root_hints: &'a [SocketAddr],
|
||||
srtt: &'a RwLock<SrttCache>,
|
||||
referral_depth: u8,
|
||||
cname_depth: u8,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = crate::Result<DnsPacket>> + Send + 'a>> {
|
||||
@@ -170,6 +177,7 @@ pub(crate) fn resolve_iterative<'a>(
|
||||
}
|
||||
|
||||
let (mut current_zone, mut ns_addrs) = find_closest_ns(qname, cache, root_hints);
|
||||
srtt.read().unwrap().sort_by_rtt(&mut ns_addrs);
|
||||
let mut ns_idx = 0;
|
||||
|
||||
for _ in 0..MAX_REFERRAL_DEPTH {
|
||||
@@ -185,7 +193,7 @@ pub(crate) fn resolve_iterative<'a>(
|
||||
ns_addr, q_type, q_name, current_zone, referral_depth
|
||||
);
|
||||
|
||||
let response = match send_query(q_name, q_type, ns_addr).await {
|
||||
let response = match send_query(q_name, q_type, ns_addr, srtt).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
debug!("recursive: NS {} failed: {}", ns_addr, e);
|
||||
@@ -194,7 +202,6 @@ pub(crate) fn resolve_iterative<'a>(
|
||||
}
|
||||
};
|
||||
|
||||
// Minimized query response — treat as referral, not final answer
|
||||
if (q_type != qtype || !q_name.eq_ignore_ascii_case(qname))
|
||||
&& (!response.authorities.is_empty() || !response.answers.is_empty())
|
||||
{
|
||||
@@ -205,8 +212,9 @@ pub(crate) fn resolve_iterative<'a>(
|
||||
if all_ns.is_empty() {
|
||||
all_ns = extract_ns_names(&response);
|
||||
}
|
||||
let new_addrs = resolve_ns_addrs_from_glue(&response, &all_ns, cache);
|
||||
let mut new_addrs = resolve_ns_addrs_from_glue(&response, &all_ns, cache);
|
||||
if !new_addrs.is_empty() {
|
||||
srtt.read().unwrap().sort_by_rtt(&mut new_addrs);
|
||||
ns_addrs = new_addrs;
|
||||
ns_idx = 0;
|
||||
continue;
|
||||
@@ -233,6 +241,7 @@ pub(crate) fn resolve_iterative<'a>(
|
||||
qtype,
|
||||
cache,
|
||||
root_hints,
|
||||
srtt,
|
||||
0,
|
||||
cname_depth + 1,
|
||||
)
|
||||
@@ -256,8 +265,6 @@ pub(crate) fn resolve_iterative<'a>(
|
||||
return Ok(response);
|
||||
}
|
||||
|
||||
// Referral — extract NS + glue, cache glue, resolve NS addresses
|
||||
// Update zone for query minimization
|
||||
if let Some(zone) = referral_zone(&response) {
|
||||
current_zone = zone;
|
||||
}
|
||||
@@ -276,13 +283,13 @@ pub(crate) fn resolve_iterative<'a>(
|
||||
for ns_name in &ns_names {
|
||||
if referral_depth < MAX_REFERRAL_DEPTH {
|
||||
debug!("recursive: resolving glue-less NS {}", ns_name);
|
||||
// Try A first, then AAAA
|
||||
for qt in [QueryType::A, QueryType::AAAA] {
|
||||
if let Ok(ns_resp) = resolve_iterative(
|
||||
ns_name,
|
||||
qt,
|
||||
cache,
|
||||
root_hints,
|
||||
srtt,
|
||||
referral_depth + 1,
|
||||
cname_depth,
|
||||
)
|
||||
@@ -316,6 +323,7 @@ pub(crate) fn resolve_iterative<'a>(
|
||||
return Err(format!("could not resolve any NS for {}", qname).into());
|
||||
}
|
||||
|
||||
srtt.read().unwrap().sort_by_rtt(&mut new_ns_addrs);
|
||||
ns_addrs = new_ns_addrs;
|
||||
ns_idx = 0;
|
||||
}
|
||||
@@ -561,7 +569,32 @@ fn make_glue_packet() -> DnsPacket {
|
||||
pkt
|
||||
}
|
||||
|
||||
async fn send_query(qname: &str, qtype: QueryType, server: SocketAddr) -> crate::Result<DnsPacket> {
|
||||
async fn tcp_with_srtt(
|
||||
query: &DnsPacket,
|
||||
server: SocketAddr,
|
||||
srtt: &RwLock<SrttCache>,
|
||||
start: Instant,
|
||||
) -> crate::Result<DnsPacket> {
|
||||
match crate::forward::forward_tcp(query, server, TCP_TIMEOUT).await {
|
||||
Ok(resp) => {
|
||||
srtt.write()
|
||||
.unwrap()
|
||||
.record_rtt(server.ip(), start.elapsed().as_millis() as u64, true);
|
||||
Ok(resp)
|
||||
}
|
||||
Err(e) => {
|
||||
srtt.write().unwrap().record_failure(server.ip());
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_query(
|
||||
qname: &str,
|
||||
qtype: QueryType,
|
||||
server: SocketAddr,
|
||||
srtt: &RwLock<SrttCache>,
|
||||
) -> crate::Result<DnsPacket> {
|
||||
let mut query = DnsPacket::new();
|
||||
query.header.id = next_id();
|
||||
query.header.recursion_desired = false;
|
||||
@@ -573,24 +606,30 @@ async fn send_query(qname: &str, qtype: QueryType, server: SocketAddr) -> crate:
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
// Skip IPv6 if the socket can't handle it (bound to 0.0.0.0)
|
||||
let start = Instant::now();
|
||||
|
||||
// IPv6 forced to TCP — our UDP socket is bound to 0.0.0.0
|
||||
if server.is_ipv6() {
|
||||
return crate::forward::forward_tcp(&query, server, TCP_TIMEOUT).await;
|
||||
return tcp_with_srtt(&query, server, srtt, start).await;
|
||||
}
|
||||
|
||||
// If UDP has been detected as blocked, go TCP-first
|
||||
// UDP detected as blocked — go TCP-first
|
||||
if UDP_DISABLED.load(Ordering::Acquire) {
|
||||
return crate::forward::forward_tcp(&query, server, TCP_TIMEOUT).await;
|
||||
return tcp_with_srtt(&query, server, srtt, start).await;
|
||||
}
|
||||
|
||||
match forward_udp(&query, server, NS_QUERY_TIMEOUT).await {
|
||||
Ok(resp) if resp.header.truncated_message => {
|
||||
debug!("send_query: truncated from {}, retrying TCP", server);
|
||||
crate::forward::forward_tcp(&query, server, TCP_TIMEOUT).await
|
||||
tcp_with_srtt(&query, server, srtt, start).await
|
||||
}
|
||||
Ok(resp) => {
|
||||
// UDP works — reset failure counter
|
||||
UDP_FAILURES.store(0, Ordering::Release);
|
||||
srtt.write().unwrap().record_rtt(
|
||||
server.ip(),
|
||||
start.elapsed().as_millis() as u64,
|
||||
false,
|
||||
);
|
||||
Ok(resp)
|
||||
}
|
||||
Err(e) => {
|
||||
@@ -603,7 +642,7 @@ async fn send_query(qname: &str, qtype: QueryType, server: SocketAddr) -> crate:
|
||||
);
|
||||
}
|
||||
debug!("send_query: UDP failed for {}: {}, trying TCP", server, e);
|
||||
crate::forward::forward_tcp(&query, server, TCP_TIMEOUT).await
|
||||
tcp_with_srtt(&query, server, srtt, start).await
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -894,7 +933,8 @@ mod tests {
|
||||
})
|
||||
.await;
|
||||
|
||||
let result = send_query("test.example.com", QueryType::A, server_addr).await;
|
||||
let srtt = RwLock::new(SrttCache::new(true));
|
||||
let result = send_query("test.example.com", QueryType::A, server_addr, &srtt).await;
|
||||
|
||||
let resp = result.expect("should resolve via TCP fallback");
|
||||
assert_eq!(resp.header.rescode, ResultCode::NOERROR);
|
||||
@@ -945,7 +985,8 @@ mod tests {
|
||||
})
|
||||
.await;
|
||||
|
||||
let result = send_query("hello.example.com", QueryType::A, server_addr).await;
|
||||
let srtt = RwLock::new(SrttCache::new(true));
|
||||
let result = send_query("hello.example.com", QueryType::A, server_addr, &srtt).await;
|
||||
let resp = result.expect("TCP-only send_query should work");
|
||||
assert_eq!(resp.header.rescode, ResultCode::NOERROR);
|
||||
match &resp.answers[0] {
|
||||
@@ -967,10 +1008,19 @@ mod tests {
|
||||
.await;
|
||||
|
||||
let cache = RwLock::new(DnsCache::new(100, 60, 86400));
|
||||
let srtt = RwLock::new(SrttCache::new(true));
|
||||
let root_hints = vec![server_addr];
|
||||
|
||||
let result =
|
||||
resolve_iterative("nonexistent.test", QueryType::A, &cache, &root_hints, 0, 0).await;
|
||||
let result = resolve_iterative(
|
||||
"nonexistent.test",
|
||||
QueryType::A,
|
||||
&cache,
|
||||
&root_hints,
|
||||
&srtt,
|
||||
0,
|
||||
0,
|
||||
)
|
||||
.await;
|
||||
|
||||
let resp = result.expect("NXDOMAIN should still return a response");
|
||||
assert_eq!(resp.header.rescode, ResultCode::NXDOMAIN);
|
||||
|
||||
227
src/srtt.rs
Normal file
227
src/srtt.rs
Normal file
@@ -0,0 +1,227 @@
|
||||
use std::collections::HashMap;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::time::Instant;
|
||||
|
||||
const INITIAL_SRTT_MS: u64 = 200;
|
||||
const FAILURE_PENALTY_MS: u64 = 5000;
|
||||
const TCP_PENALTY_MS: u64 = 100;
|
||||
const DECAY_AFTER_SECS: u64 = 300;
|
||||
const MAX_ENTRIES: usize = 4096;
|
||||
const EVICT_BATCH: usize = 64;
|
||||
|
||||
struct SrttEntry {
|
||||
srtt_ms: u64,
|
||||
updated_at: Instant,
|
||||
}
|
||||
|
||||
pub struct SrttCache {
|
||||
entries: HashMap<IpAddr, SrttEntry>,
|
||||
enabled: bool,
|
||||
}
|
||||
|
||||
impl Default for SrttCache {
|
||||
fn default() -> Self {
|
||||
Self::new(true)
|
||||
}
|
||||
}
|
||||
|
||||
impl SrttCache {
|
||||
pub fn new(enabled: bool) -> Self {
|
||||
Self {
|
||||
entries: HashMap::new(),
|
||||
enabled,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_enabled(&self) -> bool {
|
||||
self.enabled
|
||||
}
|
||||
|
||||
/// Get current SRTT for an IP, applying decay if stale. Returns INITIAL for unknown.
|
||||
pub fn get(&self, ip: IpAddr) -> u64 {
|
||||
match self.entries.get(&ip) {
|
||||
Some(entry) => Self::decayed_srtt(entry),
|
||||
None => INITIAL_SRTT_MS,
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply time-based decay: each DECAY_AFTER_SECS period halves distance to INITIAL.
|
||||
fn decayed_srtt(entry: &SrttEntry) -> u64 {
|
||||
let age_secs = entry.updated_at.elapsed().as_secs();
|
||||
if age_secs > DECAY_AFTER_SECS {
|
||||
let periods = (age_secs / DECAY_AFTER_SECS).min(8);
|
||||
let mut srtt = entry.srtt_ms;
|
||||
for _ in 0..periods {
|
||||
srtt = (srtt + INITIAL_SRTT_MS) / 2;
|
||||
}
|
||||
srtt
|
||||
} else {
|
||||
entry.srtt_ms
|
||||
}
|
||||
}
|
||||
|
||||
/// Record a successful query RTT. No-op when disabled.
|
||||
pub fn record_rtt(&mut self, ip: IpAddr, rtt_ms: u64, tcp: bool) {
|
||||
if !self.enabled {
|
||||
return;
|
||||
}
|
||||
let effective = if tcp { rtt_ms + TCP_PENALTY_MS } else { rtt_ms };
|
||||
self.maybe_evict();
|
||||
let entry = self.entries.entry(ip).or_insert(SrttEntry {
|
||||
srtt_ms: effective,
|
||||
updated_at: Instant::now(),
|
||||
});
|
||||
// Apply decay before EWMA so recovered servers aren't stuck at stale penalties
|
||||
let base = Self::decayed_srtt(entry);
|
||||
// BIND EWMA: new = (old * 7 + sample) / 8
|
||||
entry.srtt_ms = (base * 7 + effective) / 8;
|
||||
entry.updated_at = Instant::now();
|
||||
}
|
||||
|
||||
/// Record a failure (timeout or error). No-op when disabled.
|
||||
pub fn record_failure(&mut self, ip: IpAddr) {
|
||||
if !self.enabled {
|
||||
return;
|
||||
}
|
||||
self.maybe_evict();
|
||||
let entry = self.entries.entry(ip).or_insert(SrttEntry {
|
||||
srtt_ms: FAILURE_PENALTY_MS,
|
||||
updated_at: Instant::now(),
|
||||
});
|
||||
entry.srtt_ms = FAILURE_PENALTY_MS;
|
||||
entry.updated_at = Instant::now();
|
||||
}
|
||||
|
||||
/// Sort addresses by SRTT ascending (lowest/fastest first). No-op when disabled.
|
||||
pub fn sort_by_rtt(&self, addrs: &mut [SocketAddr]) {
|
||||
if !self.enabled {
|
||||
return;
|
||||
}
|
||||
addrs.sort_by_key(|a| self.get(a.ip()));
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.entries.len()
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.entries.is_empty()
|
||||
}
|
||||
|
||||
fn maybe_evict(&mut self) {
|
||||
if self.entries.len() < MAX_ENTRIES {
|
||||
return;
|
||||
}
|
||||
// Batch eviction: remove the oldest EVICT_BATCH entries at once
|
||||
let mut by_age: Vec<IpAddr> = self.entries.keys().copied().collect();
|
||||
by_age.sort_by_key(|ip| self.entries[ip].updated_at);
|
||||
for ip in by_age.into_iter().take(EVICT_BATCH) {
|
||||
self.entries.remove(&ip);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::net::Ipv4Addr;
|
||||
|
||||
fn ip(last: u8) -> IpAddr {
|
||||
IpAddr::V4(Ipv4Addr::new(192, 0, 2, last))
|
||||
}
|
||||
|
||||
fn sock(last: u8) -> SocketAddr {
|
||||
SocketAddr::new(ip(last), 53)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unknown_returns_initial() {
|
||||
let cache = SrttCache::new(true);
|
||||
assert_eq!(cache.get(ip(1)), INITIAL_SRTT_MS);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ewma_converges() {
|
||||
let mut cache = SrttCache::new(true);
|
||||
for _ in 0..20 {
|
||||
cache.record_rtt(ip(1), 100, false);
|
||||
}
|
||||
let srtt = cache.get(ip(1));
|
||||
assert!(srtt >= 98 && srtt <= 102, "srtt={}", srtt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn failure_sets_penalty() {
|
||||
let mut cache = SrttCache::new(true);
|
||||
cache.record_rtt(ip(1), 50, false);
|
||||
cache.record_failure(ip(1));
|
||||
assert_eq!(cache.get(ip(1)), FAILURE_PENALTY_MS);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tcp_penalty_added() {
|
||||
let mut cache = SrttCache::new(true);
|
||||
for _ in 0..20 {
|
||||
cache.record_rtt(ip(1), 50, true);
|
||||
}
|
||||
let srtt = cache.get(ip(1));
|
||||
assert!(srtt >= 148 && srtt <= 152, "srtt={}", srtt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sort_by_rtt_orders_correctly() {
|
||||
let mut cache = SrttCache::new(true);
|
||||
for _ in 0..20 {
|
||||
cache.record_rtt(ip(1), 500, false);
|
||||
cache.record_rtt(ip(2), 100, false);
|
||||
cache.record_rtt(ip(3), 10, false);
|
||||
}
|
||||
let mut addrs = vec![sock(1), sock(2), sock(3)];
|
||||
cache.sort_by_rtt(&mut addrs);
|
||||
assert_eq!(addrs, vec![sock(3), sock(2), sock(1)]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unknown_servers_sort_equal() {
|
||||
let cache = SrttCache::new(true);
|
||||
let mut addrs = vec![sock(1), sock(2), sock(3)];
|
||||
let original = addrs.clone();
|
||||
cache.sort_by_rtt(&mut addrs);
|
||||
assert_eq!(addrs, original);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn disabled_is_noop() {
|
||||
let mut cache = SrttCache::new(false);
|
||||
cache.record_rtt(ip(1), 50, false);
|
||||
cache.record_failure(ip(2));
|
||||
assert_eq!(cache.len(), 0);
|
||||
|
||||
let mut addrs = vec![sock(2), sock(1)];
|
||||
let original = addrs.clone();
|
||||
cache.sort_by_rtt(&mut addrs);
|
||||
assert_eq!(addrs, original);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn eviction_removes_oldest() {
|
||||
let mut cache = SrttCache::new(true);
|
||||
for i in 0..MAX_ENTRIES {
|
||||
let octets = [
|
||||
10,
|
||||
((i >> 16) & 0xFF) as u8,
|
||||
((i >> 8) & 0xFF) as u8,
|
||||
(i & 0xFF) as u8,
|
||||
];
|
||||
cache.record_rtt(
|
||||
IpAddr::V4(Ipv4Addr::new(octets[0], octets[1], octets[2], octets[3])),
|
||||
100,
|
||||
false,
|
||||
);
|
||||
}
|
||||
assert_eq!(cache.len(), MAX_ENTRIES);
|
||||
cache.record_rtt(ip(1), 100, false);
|
||||
// Batch eviction removes EVICT_BATCH entries
|
||||
assert!(cache.len() <= MAX_ENTRIES - EVICT_BATCH + 1);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user