From 7efac85836bacd483e174e5509fad03bac3f548f Mon Sep 17 00:00:00 2001 From: Razvan Dimescu Date: Sun, 12 Apr 2026 04:20:18 +0300 Subject: [PATCH 01/21] feat: wire-level forwarding, cache, request hedging, and DoH keepalive MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Wire-level forwarding path skips DnsPacket parse/serialize on the hot path. Cache stores raw wire bytes with pre-scanned TTL offsets — patches ID + TTLs in-place on lookup instead of cloning parsed packets. Request hedging (Dean & Barroso "Tail at Scale") fires a second parallel request after a configurable delay (default 10ms) when the primary upstream stalls. DoH keepalive loop prevents idle HTTP/2 + TLS connection teardown. Recursive resolver now hedges across multiple NS addresses and caches NS delegation records to skip TLD re-queries. Integration test harness polls /blocking/stats instead of fixed sleep, eliminating the blocklist-download race condition. --- Cargo.lock | 458 +++++++++- Cargo.toml | 6 + benches/numa-bench.toml | 25 + benches/recursive_compare.rs | 1649 ++++++++++++++++++++++++++++++++++ scripts/bench-recursive.sh | 115 +++ src/api.rs | 1 + src/cache.rs | 177 ++-- src/config.rs | 6 + src/ctx.rs | 47 +- src/doh.rs | 11 +- src/dot.rs | 6 +- src/forward.rs | 186 +++- src/lib.rs | 1 + src/main.rs | 26 +- src/recursive.rs | 123 ++- src/srtt.rs | 5 + src/wire.rs | 1347 +++++++++++++++++++++++++++ tests/integration.sh | 12 +- 18 files changed, 4091 insertions(+), 110 deletions(-) create mode 100644 benches/numa-bench.toml create mode 100644 benches/recursive_compare.rs create mode 100755 scripts/bench-recursive.sh create mode 100644 src/wire.rs diff --git a/Cargo.lock b/Cargo.lock index c7cd38b..eaba214 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -82,6 +82,12 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "anyhow" +version = "1.0.102" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" + [[package]] name = "arc-swap" version = "1.9.0" @@ -142,6 +148,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "async-trait" +version = "0.1.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "atomic-waker" version = "1.1.2" @@ -410,6 +427,21 @@ dependencies = [ "itertools", ] +[[package]] +name = "critical-section" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "790eea4361631c5e7d22598ecd5723ff611904e3344ce8720784c93e3d83d40b" + +[[package]] +name = "crossbeam-channel" +version = "0.5.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-deque" version = "0.8.6" @@ -493,6 +525,18 @@ version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +[[package]] +name = "enum-as-inner" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1e6a265c649f3f5979b601d26f1d05ada116434c87741c9493cb56218f76cbc" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "env_filter" version = "1.0.1" @@ -554,6 +598,12 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + [[package]] name = "form_urlencoded" version = "1.2.2" @@ -679,11 +729,24 
@@ dependencies = [ "cfg-if", "js-sys", "libc", - "r-efi", + "r-efi 5.3.0", "wasip2", "wasm-bindgen", ] +[[package]] +name = "getrandom" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0de51e6874e94e7bf76d726fc5d13ba782deca734ff60d5bb2fb2607c7406555" +dependencies = [ + "cfg-if", + "libc", + "r-efi 6.0.0", + "wasip2", + "wasip3", +] + [[package]] name = "h2" version = "0.4.13" @@ -714,12 +777,82 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "foldhash", +] + [[package]] name = "hashbrown" version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "hickory-proto" +version = "0.25.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8a6fe56c0038198998a6f217ca4e7ef3a5e51f46163bd6dd60b5c71ca6c6502" +dependencies = [ + "async-trait", + "bytes", + "cfg-if", + "data-encoding", + "enum-as-inner", + "futures-channel", + "futures-io", + "futures-util", + "h2", + "http", + "idna", + "ipnet", + "once_cell", + "rand", + "ring", + "rustls", + "thiserror", + "tinyvec", + "tokio", + "tokio-rustls", + "tracing", + "url", + "webpki-roots 0.26.11", +] + +[[package]] +name = "hickory-resolver" +version = "0.25.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc62a9a99b0bfb44d2ab95a7208ac952d31060efc16241c87eaf36406fecf87a" +dependencies = [ + "cfg-if", + "futures-util", + "hickory-proto", + "ipconfig", + "moka", + "once_cell", + "parking_lot", + "rand", + 
"resolv-conf", + "rustls", + "smallvec", + "thiserror", + "tokio", + "tokio-rustls", + "tracing", + "webpki-roots 0.26.11", +] + [[package]] name = "http" version = "1.4.0" @@ -802,7 +935,7 @@ dependencies = [ "tokio", "tokio-rustls", "tower-service", - "webpki-roots", + "webpki-roots 1.0.6", ] [[package]] @@ -909,6 +1042,12 @@ dependencies = [ "zerovec", ] +[[package]] +name = "id-arena" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" + [[package]] name = "idna" version = "1.1.0" @@ -937,7 +1076,22 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.16.1", + "serde", + "serde_core", +] + +[[package]] +name = "ipconfig" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4d40460c0ce33d6ce4b0630ad68ff63d6661961c48b6dba35e5a4d81cfb48222" +dependencies = [ + "socket2", + "widestring", + "windows-registry", + "windows-result", + "windows-sys 0.61.2", ] [[package]] @@ -1029,6 +1183,12 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" +[[package]] +name = "leb128fmt" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" + [[package]] name = "libc" version = "0.2.183" @@ -1041,6 +1201,15 @@ version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77" +[[package]] +name = "lock_api" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" 
+dependencies = [ + "scopeguard", +] + [[package]] name = "log" version = "0.4.29" @@ -1098,6 +1267,23 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "moka" +version = "0.12.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "957228ad12042ee839f93c8f257b62b4c0ab5eaae1d4fa60de53b27c9d7c5046" +dependencies = [ + "crossbeam-channel", + "crossbeam-epoch", + "crossbeam-utils", + "equivalent", + "parking_lot", + "portable-atomic", + "smallvec", + "tagptr", + "uuid", +] + [[package]] name = "nom" version = "7.1.3" @@ -1151,6 +1337,8 @@ dependencies = [ "criterion", "env_logger", "futures", + "hickory-proto", + "hickory-resolver", "http", "http-body-util", "hyper", @@ -1187,6 +1375,10 @@ name = "once_cell" version = "1.21.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" +dependencies = [ + "critical-section", + "portable-atomic", +] [[package]] name = "once_cell_polyfill" @@ -1210,6 +1402,29 @@ dependencies = [ "winapi", ] +[[package]] +name = "parking_lot" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-link", +] + [[package]] name = "pem" version = "3.0.6" @@ -1305,6 +1520,16 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "prettyplease" +version = "0.2.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +dependencies = [ + "proc-macro2", + "syn", +] + [[package]] name 
= "proc-macro2" version = "1.0.106" @@ -1390,6 +1615,12 @@ version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" +[[package]] +name = "r-efi" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8dcc9c7d52a811697d2151c701e0d08956f92b0e24136cf4cf27b57a6a0d9bf" + [[package]] name = "rand" version = "0.9.2" @@ -1453,6 +1684,15 @@ dependencies = [ "yasna", ] +[[package]] +name = "redox_syscall" +version = "0.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" +dependencies = [ + "bitflags", +] + [[package]] name = "regex" version = "1.12.3" @@ -1518,9 +1758,15 @@ dependencies = [ "wasm-bindgen", "wasm-bindgen-futures", "web-sys", - "webpki-roots", + "webpki-roots 1.0.6", ] +[[package]] +name = "resolv-conf" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e061d1b48cb8d38042de4ae0a7a6401009d6143dc80d2e2d6f31f0bdd6470c7" + [[package]] name = "ring" version = "0.17.14" @@ -1618,6 +1864,18 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "semver" +version = "1.0.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a7852d02fc848982e0c167ef163aaff9cd91dc640ba85e263cb1ce46fae51cd" + [[package]] name = "serde" version = "1.0.228" @@ -1780,6 +2038,12 @@ dependencies = [ "syn", ] +[[package]] +name = "tagptr" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417" + [[package]] name = "thiserror" version = "2.0.18" @@ -2038,6 +2302,12 @@ 
version = "1.0.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + [[package]] name = "untrusted" version = "0.9.0" @@ -2068,6 +2338,17 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" +[[package]] +name = "uuid" +version = "1.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ac8b6f42ead25368cf5b098aeb3dc8a1a2c05a3eee8a9a1a68c640edbfc79d9" +dependencies = [ + "getrandom 0.4.2", + "js-sys", + "wasm-bindgen", +] + [[package]] name = "walkdir" version = "2.5.0" @@ -2102,6 +2383,15 @@ dependencies = [ "wit-bindgen", ] +[[package]] +name = "wasip3" +version = "0.4.0+wasi-0.3.0-rc-2026-01-06" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" +dependencies = [ + "wit-bindgen", +] + [[package]] name = "wasm-bindgen" version = "0.2.115" @@ -2157,6 +2447,40 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "wasm-encoder" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990065f2fe63003fe337b932cfb5e3b80e0b4d0f5ff650e6985b1048f62c8319" +dependencies = [ + "leb128fmt", + "wasmparser", +] + +[[package]] +name = "wasm-metadata" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909" +dependencies = [ + "anyhow", + "indexmap", + "wasm-encoder", + "wasmparser", +] + +[[package]] +name = "wasmparser" +version = "0.244.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" +dependencies = [ + "bitflags", + "hashbrown 0.15.5", + "indexmap", + "semver", +] + [[package]] name = "web-sys" version = "0.3.92" @@ -2177,6 +2501,15 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "webpki-roots" +version = "0.26.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "521bc38abb08001b01866da9f51eb7c5d647a19260e00054a8c7fd5f9e57f7a9" +dependencies = [ + "webpki-roots 1.0.6", +] + [[package]] name = "webpki-roots" version = "1.0.6" @@ -2186,6 +2519,12 @@ dependencies = [ "rustls-pki-types", ] +[[package]] +name = "widestring" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72069c3113ab32ab29e5584db3c6ec55d416895e60715417b5b883a357c3e471" + [[package]] name = "winapi" version = "0.3.9" @@ -2223,6 +2562,35 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" +[[package]] +name = "windows-registry" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02752bf7fbdcce7f2a27a742f798510f3e5ad88dbe84871e5168e2120c3d5720" +dependencies = [ + "windows-link", + "windows-result", + "windows-strings", +] + +[[package]] +name = "windows-result" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-strings" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" +dependencies = [ + "windows-link", +] + [[package]] name = "windows-sys" version = "0.52.0" @@ -2390,6 +2758,88 @@ name = "wit-bindgen" version = 
"0.51.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" +dependencies = [ + "wit-bindgen-rust-macro", +] + +[[package]] +name = "wit-bindgen-core" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea61de684c3ea68cb082b7a88508a8b27fcc8b797d738bfc99a82facf1d752dc" +dependencies = [ + "anyhow", + "heck", + "wit-parser", +] + +[[package]] +name = "wit-bindgen-rust" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" +dependencies = [ + "anyhow", + "heck", + "indexmap", + "prettyplease", + "syn", + "wasm-metadata", + "wit-bindgen-core", + "wit-component", +] + +[[package]] +name = "wit-bindgen-rust-macro" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c0f9bfd77e6a48eccf51359e3ae77140a7f50b1e2ebfe62422d8afdaffab17a" +dependencies = [ + "anyhow", + "prettyplease", + "proc-macro2", + "quote", + "syn", + "wit-bindgen-core", + "wit-bindgen-rust", +] + +[[package]] +name = "wit-component" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" +dependencies = [ + "anyhow", + "bitflags", + "indexmap", + "log", + "serde", + "serde_derive", + "serde_json", + "wasm-encoder", + "wasm-metadata", + "wasmparser", + "wit-parser", +] + +[[package]] +name = "wit-parser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736" +dependencies = [ + "anyhow", + "id-arena", + "indexmap", + "log", + "semver", + "serde", + "serde_derive", + "serde_json", + "unicode-xid", + "wasmparser", +] [[package]] name = "writeable" diff --git a/Cargo.toml b/Cargo.toml index 
c5d5e1d..d7f6f9f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,6 +37,8 @@ webpki-roots = "1" criterion = { version = "0.8", features = ["html_reports"] } tower = { version = "0.5", features = ["util"] } http = "1" +hickory-resolver = { version = "0.25", features = ["https-ring", "webpki-roots"] } +hickory-proto = "0.25" [[bench]] name = "hot_path" @@ -49,3 +51,7 @@ harness = false [[bench]] name = "dnssec" harness = false + +[[bench]] +name = "recursive_compare" +harness = false diff --git a/benches/numa-bench.toml b/benches/numa-bench.toml new file mode 100644 index 0000000..0e058af --- /dev/null +++ b/benches/numa-bench.toml @@ -0,0 +1,25 @@ +[server] +bind_addr = "127.0.0.1:5454" +api_port = 5381 +api_bind_addr = "127.0.0.1" +data_dir = "/tmp/numa-bench" + +[upstream] +mode = "recursive" +timeout_ms = 10000 + +[cache] +min_ttl = 60 +max_ttl = 3600 + +[blocking] +enabled = false + +[dot] +enabled = false + +[mobile] +enabled = false + +[lan] +enabled = false diff --git a/benches/recursive_compare.rs b/benches/recursive_compare.rs new file mode 100644 index 0000000..e35768c --- /dev/null +++ b/benches/recursive_compare.rs @@ -0,0 +1,1649 @@ +//! DoH forwarding benchmark: Numa vs hickory-resolver. +//! +//! Both forward to the same DoH upstream (Quad9). +//! Measures end-to-end resolution time through each implementation. +//! +//! Fairness: +//! - Both reuse a single TLS connection (Numa via persistent server, +//! Hickory via a shared resolver instance with cache_size=0). +//! - Measurement order is alternated each round to cancel order bias. +//! - Numa cache is flushed before each query. +//! - 100 domains × 10 rounds for statistical confidence. +//! +//! Setup: +//! 1. Start a bench Numa instance: +//! cargo run -- benches/numa-bench.toml +//! 2. Run: +//! 
cargo bench --bench recursive_compare + +use std::net::SocketAddr; +use std::time::{Duration, Instant}; + +const DOH_UPSTREAM: &str = "https://9.9.9.9/dns-query"; +const NUMA_BENCH: &str = "127.0.0.1:5454"; +const NUMA_API: u16 = 5381; + +const DOMAINS: &[&str] = &[ + "example.com", + "rust-lang.org", + "kernel.org", + "signal.org", + "archlinux.org", + "openbsd.org", + "git-scm.com", + "sqlite.org", + "wireguard.com", + "mozilla.org", + "cloudflare.com", + "google.com", + "github.com", + "stackoverflow.com", + "wikipedia.org", + "reddit.com", + "amazon.com", + "apple.com", + "microsoft.com", + "facebook.com", + "twitter.com", + "linkedin.com", + "netflix.com", + "spotify.com", + "discord.com", + "twitch.tv", + "youtube.com", + "instagram.com", + "whatsapp.com", + "telegram.org", + "debian.org", + "ubuntu.com", + "fedoraproject.org", + "nixos.org", + "gentoo.org", + "freebsd.org", + "netbsd.org", + "dragonflybsd.org", + "illumos.org", + "haiku-os.org", + "python.org", + "golang.org", + "nodejs.org", + "ruby-lang.org", + "php.net", + "swift.org", + "kotlinlang.org", + "scala-lang.org", + "haskell.org", + "elixir-lang.org", + "erlang.org", + "clojure.org", + "julialang.org", + "ziglang.org", + "nim-lang.org", + "dlang.org", + "vlang.io", + "crystal-lang.org", + "racket-lang.org", + "ocaml.org", + "crates.io", + "npmjs.com", + "pypi.org", + "rubygems.org", + "packagist.org", + "nuget.org", + "maven.apache.org", + "hex.pm", + "hackage.haskell.org", + "pkg.go.dev", + "docker.com", + "kubernetes.io", + "prometheus.io", + "grafana.com", + "elastic.co", + "datadog.com", + "sentry.io", + "pagerduty.com", + "atlassian.com", + "jetbrains.com", + "gitlab.com", + "bitbucket.org", + "sourcehut.org", + "codeberg.org", + "launchpad.net", + "savannah.gnu.org", + "letsencrypt.org", + "eff.org", + "torproject.org", + "privacyguides.org", + "matrix.org", + "element.io", + "jitsi.org", + "nextcloud.com", + "syncthing.net", + "tailscale.com", + "mullvad.net", + "proton.me", + 
"duckduckgo.com", + "brave.com", + "vivaldi.com", +]; + +const ROUNDS: usize = 10; + +fn main() { + let diag = std::env::args().any(|a| a == "--diag"); + let direct = std::env::args().any(|a| a == "--direct"); + + let rt = tokio::runtime::Runtime::new().unwrap(); + + if diag { + run_diag(&rt); + return; + } + + if direct { + run_direct(&rt); + return; + } + + if std::env::args().any(|a| a == "--diag-clients") { + run_diag_clients(&rt); + return; + } + + if std::env::args().any(|a| a == "--spike-trace") { + run_spike_trace(&rt); + return; + } + + if std::env::args().any(|a| a == "--spike-phases") { + run_spike_phases(&rt); + return; + } + + if std::env::args().any(|a| a == "--spike-heartbeat") { + run_spike_heartbeat(&rt); + return; + } + + if std::env::args().any(|a| a == "--hedge") { + run_hedge(&rt); + return; + } + + if std::env::args().any(|a| a == "--hedge-5x") { + run_hedge_multi(&rt, 5); + return; + } + + if std::env::args().any(|a| a == "--vs-dnscrypt") { + run_vs_dnscrypt(&rt, 5); + return; + } + + if std::env::args().any(|a| a == "--vs-unbound") { + run_vs_unbound(&rt, 5); + return; + } + + let numa_addr: SocketAddr = NUMA_BENCH.parse().unwrap(); + + println!("DoH Forwarding Benchmark: Numa vs hickory-resolver"); + println!("Both forwarding to {DOH_UPSTREAM}"); + println!("{} domains × {ROUNDS} rounds", DOMAINS.len()); + println!(); + + // Verify bench Numa is reachable + if rt.block_on(query_udp(numa_addr, "example.com")).is_none() { + eprintln!("Bench Numa not responding on {numa_addr}"); + eprintln!(); + eprintln!("Start it with:"); + eprintln!(" cargo run -- benches/numa-bench.toml"); + std::process::exit(1); + } + + // Build a shared Hickory resolver (reuses TLS connection, like Numa does) + let resolver = rt.block_on(build_hickory_resolver()); + + // Warm up both paths (TLS handshake, connection establishment) + println!("Warming up connections..."); + for _ in 0..3 { + rt.block_on(query_udp(numa_addr, "example.com")); + 
rt.block_on(query_hickory_doh(&resolver, "example.com")); + } + flush_cache(); + + println!( + "{:<30} {:>10} {:>10} {:>10} {:>8} {:>8}", + "Domain", "Numa (ms)", "Hickory", "Delta", "σ Numa", "σ Hick" + ); + println!("{}", "-".repeat(92)); + + let mut numa_all = Vec::new(); + let mut hickory_all = Vec::new(); + let mut per_domain: Vec<(&str, f64, f64, f64, f64, f64)> = Vec::new(); + + for domain in DOMAINS { + let mut numa_times = Vec::with_capacity(ROUNDS); + let mut hickory_times = Vec::with_capacity(ROUNDS); + + for round in 0..ROUNDS { + flush_cache(); + std::thread::sleep(Duration::from_millis(10)); + + // Alternate measurement order each round to cancel systematic bias + if round % 2 == 0 { + // Numa first + let t = measure(&rt, || rt.block_on(query_udp(numa_addr, domain))); + numa_times.push(t); + let t = measure(&rt, || rt.block_on(query_hickory_doh(&resolver, domain))); + hickory_times.push(t); + } else { + // Hickory first + let t = measure(&rt, || rt.block_on(query_hickory_doh(&resolver, domain))); + hickory_times.push(t); + flush_cache(); + std::thread::sleep(Duration::from_millis(10)); + let t = measure(&rt, || rt.block_on(query_udp(numa_addr, domain))); + numa_times.push(t); + } + } + + let numa_avg = mean(&numa_times); + let hickory_avg = mean(&hickory_times); + let numa_sd = stddev(&numa_times); + let hickory_sd = stddev(&hickory_times); + let delta = numa_avg - hickory_avg; + + numa_all.extend_from_slice(&numa_times); + hickory_all.extend_from_slice(&hickory_times); + per_domain.push((domain, numa_avg, hickory_avg, delta, numa_sd, hickory_sd)); + + let delta_str = format_delta(delta); + println!( + "{:<30} {:>7.1} ms {:>7.1} ms {:>7} ms {:>5.1}ms {:>5.1}ms", + domain, numa_avg, hickory_avg, delta_str, numa_sd, hickory_sd + ); + } + + println!("{}", "-".repeat(92)); + + let numa_mean = mean(&numa_all); + let hickory_mean = mean(&hickory_all); + let delta_mean = numa_mean - hickory_mean; + + println!( + "{:<30} {:>7.1} ms {:>7.1} ms {:>7} ms 
{:>5.1}ms {:>5.1}ms", + "OVERALL MEAN", + numa_mean, + hickory_mean, + format_delta(delta_mean), + stddev(&numa_all), + stddev(&hickory_all), + ); + + // Median + let numa_med = median(&mut numa_all); + let hickory_med = median(&mut hickory_all); + println!( + "{:<30} {:>7.1} ms {:>7.1} ms {:>7} ms", + "MEDIAN", + numa_med, + hickory_med, + format_delta(numa_med - hickory_med), + ); + + // P95 + let numa_p95 = percentile(&numa_all, 95.0); + let hickory_p95 = percentile(&hickory_all, 95.0); + println!( + "{:<30} {:>7.1} ms {:>7.1} ms {:>7} ms", + "P95", + numa_p95, + hickory_p95, + format_delta(numa_p95 - hickory_p95), + ); + + println!(); + let total_queries = DOMAINS.len() * ROUNDS; + if numa_mean < hickory_mean { + let pct = ((hickory_mean - numa_mean) / hickory_mean * 100.0).round(); + println!("Numa is ~{pct}% faster (mean over {total_queries} queries)."); + } else if hickory_mean < numa_mean { + let pct = ((numa_mean - hickory_mean) / numa_mean * 100.0).round(); + println!("Hickory is ~{pct}% faster (mean over {total_queries} queries)."); + } else { + println!("Both are equal (mean over {total_queries} queries)."); + } + + println!(); + println!("Methodology:"); + println!(" - Both forward to {DOH_UPSTREAM} over a reused TLS connection."); + println!(" - Numa cache flushed before each query. 
Hickory cache disabled."); + println!(" - Measurement order alternates each round to cancel order bias."); + println!(" - {} domains × {ROUNDS} rounds = {total_queries} queries per resolver.", DOMAINS.len()); +} + +fn run_diag(rt: &tokio::runtime::Runtime) { + println!("Hickory connection reuse diagnostic"); + println!("20 sequential queries to {DOH_UPSTREAM} via one shared resolver"); + println!("If conn is reused: query 1 slow (TLS handshake), rest fast.\n"); + + let resolver = rt.block_on(build_hickory_resolver()); + + let domains = [ + "example.com", "rust-lang.org", "kernel.org", "google.com", "github.com", + "example.com", "rust-lang.org", "kernel.org", "google.com", "github.com", + "example.com", "rust-lang.org", "kernel.org", "google.com", "github.com", + "example.com", "rust-lang.org", "kernel.org", "google.com", "github.com", + ]; + + println!("{:>3} {:<20} {:>10}", "#", "Domain", "Time (ms)"); + println!("{}", "-".repeat(40)); + + for (i, domain) in domains.iter().enumerate() { + use hickory_resolver::proto::rr::RecordType; + let start = Instant::now(); + let result = rt.block_on(resolver.lookup(*domain, RecordType::A)); + let ms = start.elapsed().as_secs_f64() * 1000.0; + match &result { + Ok(lookup) => { + let first = lookup.iter().next().map(|r| format!("{r}")).unwrap_or_default(); + println!("{:>3} {:<20} {:>7.1} ms OK {}", i + 1, domain, ms, first); + } + Err(e) => { + println!("{:>3} {:<20} {:>7.1} ms ERR {}", i + 1, domain, ms, e); + } + } + } +} + +/// Library-to-library comparison: Numa's forward_query_raw vs Hickory's resolver.lookup(). +/// No UDP, no server pipeline — just the DoH forwarding call. 
+fn run_direct(rt: &tokio::runtime::Runtime) { + println!("Direct DoH Forwarding: Numa forward_query_raw vs Hickory resolver.lookup()"); + println!("Both forwarding to {DOH_UPSTREAM} — no UDP, no server pipeline"); + println!("{} domains × {ROUNDS} rounds", DOMAINS.len()); + println!(); + + // Build Numa's upstream (shared reqwest client, reuses HTTP/2 connection) + let numa_upstream = + numa::forward::parse_upstream(DOH_UPSTREAM, 443).expect("failed to parse upstream"); + let timeout = Duration::from_secs(10); + + // Build Hickory's resolver (shared, reuses HTTP/2 connection) + let resolver = rt.block_on(build_hickory_resolver()); + + // Warm up both + println!("Warming up connections..."); + for _ in 0..3 { + let wire = build_query_vec("example.com"); + let _ = rt.block_on(numa::forward::forward_query_raw(&wire, &numa_upstream, timeout)); + let _ = rt.block_on(query_hickory_doh(&resolver, "example.com")); + } + + println!( + "{:<30} {:>10} {:>10} {:>10} {:>8} {:>8}", + "Domain", "Numa (ms)", "Hickory", "Delta", "σ Numa", "σ Hick" + ); + println!("{}", "-".repeat(92)); + + let mut numa_all = Vec::new(); + let mut hickory_all = Vec::new(); + + for domain in DOMAINS { + let mut numa_times = Vec::with_capacity(ROUNDS); + let mut hickory_times = Vec::with_capacity(ROUNDS); + + for round in 0..ROUNDS { + let wire = build_query_vec(domain); + + if round % 2 == 0 { + let w = wire.clone(); + let t = measure(rt, || { + rt.block_on(numa::forward::forward_query_raw(&w, &numa_upstream, timeout)) + }); + numa_times.push(t); + let t = measure(rt, || rt.block_on(query_hickory_doh(&resolver, domain))); + hickory_times.push(t); + } else { + let t = measure(rt, || rt.block_on(query_hickory_doh(&resolver, domain))); + hickory_times.push(t); + let w = wire.clone(); + let t = measure(rt, || { + rt.block_on(numa::forward::forward_query_raw(&w, &numa_upstream, timeout)) + }); + numa_times.push(t); + } + } + + let numa_avg = mean(&numa_times); + let hickory_avg = mean(&hickory_times); + 
let numa_sd = stddev(&numa_times); + let hickory_sd = stddev(&hickory_times); + let delta = numa_avg - hickory_avg; + + numa_all.extend_from_slice(&numa_times); + hickory_all.extend_from_slice(&hickory_times); + + println!( + "{:<30} {:>7.1} ms {:>7.1} ms {:>7} ms {:>5.1}ms {:>5.1}ms", + domain, numa_avg, hickory_avg, format_delta(delta), numa_sd, hickory_sd + ); + } + + println!("{}", "-".repeat(92)); + let numa_mean = mean(&numa_all); + let hickory_mean = mean(&hickory_all); + println!( + "{:<30} {:>7.1} ms {:>7.1} ms {:>7} ms {:>5.1}ms {:>5.1}ms", + "OVERALL MEAN", numa_mean, hickory_mean, format_delta(numa_mean - hickory_mean), + stddev(&numa_all), stddev(&hickory_all), + ); + let numa_med = median(&mut numa_all); + let hickory_med = median(&mut hickory_all); + println!( + "{:<30} {:>7.1} ms {:>7.1} ms {:>7} ms", + "MEDIAN", numa_med, hickory_med, format_delta(numa_med - hickory_med), + ); + let numa_p95 = percentile(&numa_all, 95.0); + let hickory_p95 = percentile(&hickory_all, 95.0); + println!( + "{:<30} {:>7.1} ms {:>7.1} ms {:>7} ms", + "P95", numa_p95, hickory_p95, format_delta(numa_p95 - hickory_p95), + ); + + println!(); + let total_queries = DOMAINS.len() * ROUNDS; + if numa_mean < hickory_mean { + let pct = ((hickory_mean - numa_mean) / hickory_mean * 100.0).round(); + println!("Numa is ~{pct}% faster (mean over {total_queries} queries)."); + } else if hickory_mean < numa_mean { + let pct = ((numa_mean - hickory_mean) / numa_mean * 100.0).round(); + println!("Hickory is ~{pct}% faster (mean over {total_queries} queries)."); + } else { + println!("Both are equal (mean over {total_queries} queries)."); + } + + println!(); + println!("Methodology:"); + println!(" - Both forward to {DOH_UPSTREAM} over a reused TLS/HTTP2 connection."); + println!(" - No UDP, no server pipeline, no cache — pure DoH forwarding."); + println!(" - Numa: forward_query_raw (reqwest). 
Hickory: resolver.lookup (h2)."); + println!(" - {} domains × {ROUNDS} rounds = {total_queries} queries per implementation.", DOMAINS.len()); +} + +/// Per-query timing diagnostic: 20 queries each through reqwest and Hickory. +/// Shows whether reqwest has connection reuse issues or per-request overhead. +fn run_diag_clients(rt: &tokio::runtime::Runtime) { + println!("Client diagnostic: reqwest vs Hickory per-query timing"); + println!("20 queries each to {DOH_UPSTREAM}\n"); + + let upstream = + numa::forward::parse_upstream(DOH_UPSTREAM, 443).expect("failed to parse upstream"); + let resolver = rt.block_on(build_hickory_resolver()); + let timeout = Duration::from_secs(10); + + // Warm both + for _ in 0..3 { + let w = build_query_vec("example.com"); + let _ = rt.block_on(numa::forward::forward_query_raw(&w, &upstream, timeout)); + let _ = rt.block_on(query_hickory_doh(&resolver, "example.com")); + } + + let domains = [ + "example.com", "google.com", "github.com", "rust-lang.org", "cloudflare.com", + "example.com", "google.com", "github.com", "rust-lang.org", "cloudflare.com", + "example.com", "google.com", "github.com", "rust-lang.org", "cloudflare.com", + "example.com", "google.com", "github.com", "rust-lang.org", "cloudflare.com", + ]; + + println!("{:>3} {:<20} {:>12} {:>12}", "#", "Domain", "reqwest", "Hickory"); + println!("{}", "-".repeat(55)); + + for (i, domain) in domains.iter().enumerate() { + let wire = build_query_vec(domain); + + let start = Instant::now(); + let r_result = rt.block_on(numa::forward::forward_query_raw(&wire, &upstream, timeout)); + let r_ms = start.elapsed().as_secs_f64() * 1000.0; + let r_ok = if r_result.is_ok() { "OK" } else { "FAIL" }; + + let start = Instant::now(); + let h_result = rt.block_on(query_hickory_doh(&resolver, domain)); + let h_ms = start.elapsed().as_secs_f64() * 1000.0; + let h_ok = if h_result.is_some() { "OK" } else { "FAIL" }; + + println!( + "{:>3} {:<20} {:>7.1} ms {} {:>7.1} ms {}", + i + 1, domain, r_ms, 
r_ok, h_ms, h_ok + ); + } +} + +/// Spike trace: fire 200 sequential queries through reqwest and log every one +/// with a timestamp. Analyze the distribution and find spike clusters. +fn run_spike_trace(rt: &tokio::runtime::Runtime) { + println!("Spike trace: 200 sequential reqwest DoH queries"); + println!("Target: {DOH_UPSTREAM}\n"); + + let upstream = + numa::forward::parse_upstream(DOH_UPSTREAM, 443).expect("failed to parse upstream"); + let timeout = Duration::from_secs(10); + + // Warm + for _ in 0..5 { + let w = build_query_vec("example.com"); + let _ = rt.block_on(numa::forward::forward_query_raw(&w, &upstream, timeout)); + } + + // Run the entire 200-query loop inside ONE block_on to eliminate + // per-query runtime re-entry overhead. + let samples: Vec<(u128, f64)> = rt.block_on(async { + let test_start = Instant::now(); + let mut s = Vec::with_capacity(200); + for i in 0..200 { + let domain = match i % 5 { + 0 => "example.com", + 1 => "google.com", + 2 => "github.com", + 3 => "rust-lang.org", + _ => "cloudflare.com", + }; + let wire = build_query_vec(domain); + let req_start = Instant::now(); + let t_from_start_us = test_start.elapsed().as_micros(); + let _ = numa::forward::forward_query_raw(&wire, &upstream, timeout).await; + let ms = req_start.elapsed().as_secs_f64() * 1000.0; + s.push((t_from_start_us, ms)); + } + s + }); + + // Compute stats + let mut sorted_times: Vec = samples.iter().map(|(_, t)| *t).collect(); + sorted_times.sort_by(|a, b| a.partial_cmp(b).unwrap()); + let n = sorted_times.len(); + let median = sorted_times[n / 2]; + let p90 = sorted_times[(n * 90) / 100]; + let p95 = sorted_times[(n * 95) / 100]; + let p99 = sorted_times[(n * 99) / 100]; + let max = sorted_times[n - 1]; + let mean: f64 = sorted_times.iter().sum::() / n as f64; + + println!("Distribution (n={}):", n); + println!(" mean: {:.1} ms", mean); + println!(" median: {:.1} ms", median); + println!(" p90: {:.1} ms", p90); + println!(" p95: {:.1} ms", p95); + println!(" 
p99: {:.1} ms", p99); + println!(" max: {:.1} ms", max); + println!(); + + // Define spike threshold as 3x median + let spike_threshold = median * 3.0; + let spikes: Vec<(usize, u128, f64)> = samples + .iter() + .enumerate() + .filter(|(_, (_, t))| *t > spike_threshold) + .map(|(i, (ts, t))| (i, *ts, *t)) + .collect(); + + println!("Spikes (> {:.1}ms, which is 3x median):", spike_threshold); + println!(" count: {}", spikes.len()); + if spikes.is_empty() { + return; + } + + // Inter-spike gaps (time between spikes) + let mut gaps_ms: Vec = Vec::new(); + for w in spikes.windows(2) { + let gap_us = w[1].1 - w[0].1; + gaps_ms.push(gap_us as f64 / 1000.0); + } + + println!(); + println!(" {:>4} {:>12} {:>10} {:>12}", "idx", "at (ms)", "latency", "gap from prev"); + for (i, ((idx, ts, latency), gap)) in spikes.iter().zip( + std::iter::once(&0.0).chain(gaps_ms.iter()) + ).enumerate() { + let _ = i; + let gap_str = if *gap > 0.0 { + format!("{:.0} ms", gap) + } else { + "-".to_string() + }; + println!(" {:>4} {:>9.1} {:>6.1} ms {:>12}", idx, *ts as f64 / 1000.0, latency, gap_str); + } + + if !gaps_ms.is_empty() { + let gap_mean: f64 = gaps_ms.iter().sum::() / gaps_ms.len() as f64; + let mut gap_sorted = gaps_ms.clone(); + gap_sorted.sort_by(|a, b| a.partial_cmp(b).unwrap()); + let gap_median = gap_sorted[gap_sorted.len() / 2]; + println!(); + println!(" Inter-spike gap: mean={:.0}ms, median={:.0}ms", gap_mean, gap_median); + } +} + +/// Spike phases: time each step of the reqwest DoH call to find which phase +/// is slow during a spike. Reports (build+send, send->resp headers, body read). 
+fn run_spike_phases(rt: &tokio::runtime::Runtime) { + println!("Spike phases: timing each phase of reqwest DoH call"); + println!("Target: {DOH_UPSTREAM}\n"); + + // Build the same tuned client our forward_doh uses + let client = reqwest::Client::builder() + .use_rustls_tls() + .http2_initial_stream_window_size(65_535) + .http2_initial_connection_window_size(65_535) + .http2_keep_alive_interval(Duration::from_secs(15)) + .http2_keep_alive_while_idle(true) + .http2_keep_alive_timeout(Duration::from_secs(10)) + .pool_idle_timeout(Duration::from_secs(300)) + .pool_max_idle_per_host(1) + .build() + .unwrap(); + + // Warm up + for _ in 0..5 { + let wire = build_query_vec("example.com"); + let _ = rt.block_on(async { + client + .post(DOH_UPSTREAM) + .header("content-type", "application/dns-message") + .header("accept", "application/dns-message") + .body(wire) + .send() + .await + .ok()? + .bytes() + .await + .ok() + }); + } + + println!("{:>4} {:>8} {:>8} {:>8} {:>8}", "idx", "total", "build", "send", "body"); + println!("{}", "-".repeat(50)); + + let samples: Vec<(f64, f64, f64, f64)> = rt.block_on(async { + let mut s = Vec::with_capacity(200); + for i in 0..200 { + let domain = match i % 5 { + 0 => "example.com", + 1 => "google.com", + 2 => "github.com", + 3 => "rust-lang.org", + _ => "cloudflare.com", + }; + let wire = build_query_vec(domain); + + let t0 = Instant::now(); + // Phase 1: build the request + let req = client + .post(DOH_UPSTREAM) + .header("content-type", "application/dns-message") + .header("accept", "application/dns-message") + .body(wire); + let t1 = Instant::now(); + // Phase 2: send() — this is the dispatch channel + round trip to headers + let resp_result = req.send().await; + let t2 = Instant::now(); + // Phase 3: read body + let body_result = match resp_result { + Ok(r) => r.bytes().await.ok().map(|b| b.len()), + Err(_) => None, + }; + let t3 = Instant::now(); + + let build_ms = (t1 - t0).as_secs_f64() * 1000.0; + let send_ms = (t2 - 
t1).as_secs_f64() * 1000.0; + let body_ms = (t3 - t2).as_secs_f64() * 1000.0; + let total_ms = (t3 - t0).as_secs_f64() * 1000.0; + + s.push((total_ms, build_ms, send_ms, body_ms)); + let _ = body_result; + } + s + }); + + // Compute distribution on total + let mut totals: Vec = samples.iter().map(|s| s.0).collect(); + totals.sort_by(|a, b| a.partial_cmp(b).unwrap()); + let median = totals[100]; + + // Print spikes (> 3x median) with phase breakdown + for (i, (total, build, send, body)) in samples.iter().enumerate() { + if *total > median * 3.0 { + println!( + "{:>4} {:>5.1} ms {:>5.1} ms {:>5.1} ms {:>5.1} ms", + i, total, build, send, body + ); + } + } + + // Summary: mean of each phase for spikes vs non-spikes + let (spike_samples, normal_samples): (Vec<_>, Vec<_>) = samples + .iter() + .partition(|(t, _, _, _)| *t > median * 3.0); + + let phase_means = |samples: &[&(f64, f64, f64, f64)]| -> (f64, f64, f64, f64) { + let n = samples.len() as f64; + if n == 0.0 { return (0.0, 0.0, 0.0, 0.0); } + let total: f64 = samples.iter().map(|s| s.0).sum::() / n; + let build: f64 = samples.iter().map(|s| s.1).sum::() / n; + let send: f64 = samples.iter().map(|s| s.2).sum::() / n; + let body: f64 = samples.iter().map(|s| s.3).sum::() / n; + (total, build, send, body) + }; + + let spike_refs: Vec<&(f64, f64, f64, f64)> = spike_samples.iter().copied().collect(); + let normal_refs: Vec<&(f64, f64, f64, f64)> = normal_samples.iter().copied().collect(); + let (s_total, s_build, s_send, s_body) = phase_means(&spike_refs); + let (n_total, n_build, n_send, n_body) = phase_means(&normal_refs); + + println!(); + println!("Summary (mean ms):"); + println!( + " {:<8} {:>8} {:>8} {:>8} {:>8}", + "", "total", "build", "send", "body" + ); + println!( + " {:<8} {:>5.1} ms {:>5.1} ms {:>5.1} ms {:>5.1} ms (n={})", + "normal", n_total, n_build, n_send, n_body, normal_refs.len() + ); + println!( + " {:<8} {:>5.1} ms {:>5.1} ms {:>5.1} ms {:>5.1} ms (n={})", + "spike", s_total, s_build, s_send, 
s_body, spike_refs.len() + ); + println!(); + println!("Delta (spike - normal):"); + println!( + " build: {:+.1} ms, send: {:+.1} ms, body: {:+.1} ms", + s_build - n_build, + s_send - n_send, + s_body - n_body + ); +} + +/// Heartbeat probe: run a parallel task that ticks every 5ms and records +/// how long each tick actually takes. If the heartbeat stalls during a DoH +/// spike, it's a tokio scheduling issue (runtime can't poll tasks). If +/// heartbeat is fine while send() is stuck, it's internal to hyper/h2. +fn run_spike_heartbeat(rt: &tokio::runtime::Runtime) { + use std::sync::{Arc, Mutex}; + + println!("Spike heartbeat probe"); + println!("Running DoH queries + parallel 5ms heartbeat task\n"); + + let upstream = + numa::forward::parse_upstream(DOH_UPSTREAM, 443).expect("failed to parse upstream"); + let timeout = Duration::from_secs(10); + + // Warm up + for _ in 0..5 { + let w = build_query_vec("example.com"); + let _ = rt.block_on(numa::forward::forward_query_raw(&w, &upstream, timeout)); + } + + // Shared vecs: (relative_ms_from_start, event_kind, latency_ms) + // event_kind: 0 = heartbeat, 1 = doh query + type EventLog = Vec<(f64, u8, f64)>; + let events: Arc> = Arc::new(Mutex::new(Vec::with_capacity(2000))); + let stop = Arc::new(std::sync::atomic::AtomicBool::new(false)); + + let test_start = Instant::now(); + + rt.block_on(async { + // Spawn heartbeat task + let hb_events = Arc::clone(&events); + let hb_stop = Arc::clone(&stop); + let hb_start = test_start; + let heartbeat = tokio::spawn(async move { + let mut next_tick = Instant::now(); + let target = Duration::from_millis(5); + while !hb_stop.load(std::sync::atomic::Ordering::Relaxed) { + next_tick += target; + // Sleep until the next scheduled tick + let now = Instant::now(); + if next_tick > now { + tokio::time::sleep(next_tick - now).await; + } + // Measure how much we overshot the scheduled tick + let actual = Instant::now(); + let lag_ms = if actual > next_tick { + (actual - 
next_tick).as_secs_f64() * 1000.0 + } else { + 0.0 + }; + let t = (actual - hb_start).as_secs_f64() * 1000.0; + if let Ok(mut e) = hb_events.lock() { + e.push((t, 0, lag_ms)); + } + } + }); + + // Run 200 DoH queries and record their timings + for i in 0..200 { + let domain = match i % 5 { + 0 => "example.com", + 1 => "google.com", + 2 => "github.com", + 3 => "rust-lang.org", + _ => "cloudflare.com", + }; + let wire = build_query_vec(domain); + let req_start = Instant::now(); + let _ = numa::forward::forward_query_raw(&wire, &upstream, timeout).await; + let elapsed = req_start.elapsed().as_secs_f64() * 1000.0; + let t = (req_start - test_start).as_secs_f64() * 1000.0; + if let Ok(mut e) = events.lock() { + e.push((t, 1, elapsed)); + } + } + + stop.store(true, std::sync::atomic::Ordering::Relaxed); + let _ = heartbeat.await; + }); + + let events = events.lock().unwrap(); + + // Separate heartbeats and doh events + let hb: Vec<(f64, f64)> = events + .iter() + .filter(|(_, k, _)| *k == 0) + .map(|(t, _, l)| (*t, *l)) + .collect(); + let doh: Vec<(f64, f64)> = events + .iter() + .filter(|(_, k, _)| *k == 1) + .map(|(t, _, l)| (*t, *l)) + .collect(); + + // Heartbeat stats + let mut hb_lags: Vec = hb.iter().map(|(_, l)| *l).collect(); + hb_lags.sort_by(|a, b| a.partial_cmp(b).unwrap()); + let hb_n = hb_lags.len(); + let hb_median = hb_lags[hb_n / 2]; + let hb_p95 = hb_lags[(hb_n * 95) / 100]; + let hb_p99 = hb_lags[(hb_n * 99) / 100]; + let hb_max = hb_lags[hb_n - 1]; + + // DoH stats + let mut doh_latencies: Vec = doh.iter().map(|(_, l)| *l).collect(); + doh_latencies.sort_by(|a, b| a.partial_cmp(b).unwrap()); + let doh_n = doh_latencies.len(); + let doh_median = doh_latencies[doh_n / 2]; + let doh_p95 = doh_latencies[(doh_n * 95) / 100]; + let doh_max = doh_latencies[doh_n - 1]; + + println!("Heartbeat lag (tick overshoot, {}ms target):", 5); + println!(" n: {}", hb_n); + println!(" median: {:.2} ms", hb_median); + println!(" p95: {:.2} ms", hb_p95); + println!(" p99: 
{:.2} ms", hb_p99); + println!(" max: {:.2} ms", hb_max); + println!(); + println!("DoH latency:"); + println!(" n: {}", doh_n); + println!(" median: {:.1} ms", doh_median); + println!(" p95: {:.1} ms", doh_p95); + println!(" max: {:.1} ms", doh_max); + println!(); + + // Find DoH spikes and check heartbeat activity DURING each spike + let doh_spike_threshold = doh_median * 3.0; + let mut spikes_with_hb_lag = 0; + let mut spikes_total = 0; + let mut max_hb_during_any_spike = 0.0_f64; + + println!( + "Correlation: during each DoH spike (>{:.1}ms), max heartbeat lag:", + doh_spike_threshold + ); + println!(" {:>6} {:>10} {:>18}", "doh_at", "doh_ms", "max_hb_lag_during"); + + for (doh_t, doh_ms) in &doh { + if *doh_ms > doh_spike_threshold { + spikes_total += 1; + // Find heartbeats that happened during this DoH query + let spike_start = *doh_t; + let spike_end = spike_start + *doh_ms; + let mut max_hb = 0.0_f64; + for (hb_t, hb_lag) in &hb { + if *hb_t >= spike_start && *hb_t <= spike_end + 20.0 { + if *hb_lag > max_hb { + max_hb = *hb_lag; + } + } + } + if max_hb > 5.0 { + spikes_with_hb_lag += 1; + } + max_hb_during_any_spike = max_hb_during_any_spike.max(max_hb); + println!( + " {:>5.0} ms {:>7.1} ms {:>14.2} ms", + doh_t, doh_ms, max_hb + ); + } + } + + println!(); + println!("Conclusion:"); + if spikes_total == 0 { + println!(" No DoH spikes in this run."); + } else { + let pct = (spikes_with_hb_lag as f64 / spikes_total as f64 * 100.0).round(); + println!( + " {}/{} spikes ({:.0}%) had concurrent heartbeat lag >5ms.", + spikes_with_hb_lag, spikes_total, pct + ); + println!(" Max heartbeat lag during any spike: {:.2}ms", max_hb_during_any_spike); + println!(); + if max_hb_during_any_spike > 20.0 { + println!(" → Heartbeat stalls during DoH spikes: tokio scheduling / OS thread issue."); + println!(" The runtime can't poll ANY task — likely QoS demotion, GC pause,"); + println!(" or the worker thread is blocked somewhere."); + } else { + println!(" → Heartbeat 
runs normally during DoH spikes: internal to hyper/h2."); + println!(" The runtime is fine, but send()'s await is stuck waiting for"); + println!(" the ClientTask to poll the dispatch channel."); + } + } +} + +/// Hedging benchmark: tests four configurations against Hickory. +/// Single: 1 client → Quad9 (baseline) +/// Hedge-same: hedge against same client/connection → Quad9 +/// Hedge-dual: hedge against 2 separate clients, both → Quad9 (same upstream, 2 HTTP/2 conns) +/// Hickory: Hickory resolver → Quad9 (reference) +fn run_hedge(rt: &tokio::runtime::Runtime) { + let hedge_delay = Duration::from_millis(10); + + println!("Hedging Benchmark (all paths → Quad9 only)"); + println!("Upstream: {}", DOH_UPSTREAM); + println!("Hedge delay: {:?}", hedge_delay); + println!("{} domains × {} rounds\n", DOMAINS.len(), ROUNDS); + + // Primary and secondary: two separate reqwest clients → same Quad9 URL. + // This gives two independent HTTP/2 connections, so dispatch spikes + // are uncorrelated (at most one stalls at a time). 
+ let primary_same = + numa::forward::parse_upstream(DOH_UPSTREAM, 443).expect("failed to parse primary"); + let primary_dual = + numa::forward::parse_upstream(DOH_UPSTREAM, 443).expect("failed to parse primary_dual"); + let secondary_dual = + numa::forward::parse_upstream(DOH_UPSTREAM, 443).expect("failed to parse secondary_dual"); + let timeout = Duration::from_secs(10); + + let resolver = rt.block_on(build_hickory_resolver()); + + // Warm up all paths (separate connections need their own TLS handshake) + println!("Warming up connections..."); + for _ in 0..5 { + let w = build_query_vec("example.com"); + let _ = rt.block_on(numa::forward::forward_query_raw(&w, &primary_same, timeout)); + let _ = rt.block_on(numa::forward::forward_query_raw(&w, &primary_dual, timeout)); + let _ = rt.block_on(numa::forward::forward_query_raw(&w, &secondary_dual, timeout)); + let _ = rt.block_on(query_hickory_doh(&resolver, "example.com")); + } + + let mut single_all = Vec::new(); + let mut hedge_same_all = Vec::new(); + let mut hedge_dual_all = Vec::new(); + let mut hickory_all = Vec::new(); + + println!( + "{:<24} {:>10} {:>10} {:>10} {:>10}", + "Domain", "Single", "Hedge-same", "Hedge-dual", "Hickory" + ); + println!("{}", "-".repeat(78)); + + for domain in DOMAINS { + let mut single_times = Vec::with_capacity(ROUNDS); + let mut hedge_same_times = Vec::with_capacity(ROUNDS); + let mut hedge_dual_times = Vec::with_capacity(ROUNDS); + let mut hickory_times = Vec::with_capacity(ROUNDS); + + for _ in 0..ROUNDS { + let wire = build_query_vec(domain); + + let t = Instant::now(); + let _ = rt.block_on(numa::forward::forward_query_raw(&wire, &primary_same, timeout)); + single_times.push(t.elapsed().as_secs_f64() * 1000.0); + + let t = Instant::now(); + let _ = rt.block_on(numa::forward::forward_with_hedging_raw( + &wire, &primary_same, &primary_same, hedge_delay, timeout, + )); + hedge_same_times.push(t.elapsed().as_secs_f64() * 1000.0); + + let t = Instant::now(); + let _ = 
rt.block_on(numa::forward::forward_with_hedging_raw( + &wire, &primary_dual, &secondary_dual, hedge_delay, timeout, + )); + hedge_dual_times.push(t.elapsed().as_secs_f64() * 1000.0); + + let t = Instant::now(); + let _ = rt.block_on(query_hickory_doh(&resolver, domain)); + hickory_times.push(t.elapsed().as_secs_f64() * 1000.0); + } + + single_all.extend_from_slice(&single_times); + hedge_same_all.extend_from_slice(&hedge_same_times); + hedge_dual_all.extend_from_slice(&hedge_dual_times); + hickory_all.extend_from_slice(&hickory_times); + + println!( + "{:<24} {:>7.1} ms {:>7.1} ms {:>7.1} ms {:>7.1} ms", + domain, + mean(&single_times), + mean(&hedge_same_times), + mean(&hedge_dual_times), + mean(&hickory_times) + ); + } + + println!("{}", "-".repeat(78)); + + let stats = |all: &mut Vec| -> (f64, f64, f64, f64, f64) { + let m = mean(all); + let med = median(all); + let p95 = percentile(all, 95.0); + let p99 = percentile(all, 99.0); + let sd = stddev(all); + (m, med, p95, p99, sd) + }; + + let (s_m, s_med, s_p95, s_p99, s_sd) = stats(&mut single_all); + let (hs_m, hs_med, hs_p95, hs_p99, hs_sd) = stats(&mut hedge_same_all); + let (hd_m, hd_med, hd_p95, hd_p99, hd_sd) = stats(&mut hedge_dual_all); + let (k_m, k_med, k_p95, k_p99, k_sd) = stats(&mut hickory_all); + + println!(); + println!( + "{:<10} {:>10} {:>10} {:>10} {:>10}", + "", "Single", "Hedge-same", "Hedge-dual", "Hickory" + ); + println!( + "{:<10} {:>7.1} ms {:>7.1} ms {:>7.1} ms {:>7.1} ms", + "mean", s_m, hs_m, hd_m, k_m + ); + println!( + "{:<10} {:>7.1} ms {:>7.1} ms {:>7.1} ms {:>7.1} ms", + "median", s_med, hs_med, hd_med, k_med + ); + println!( + "{:<10} {:>7.1} ms {:>7.1} ms {:>7.1} ms {:>7.1} ms", + "p95", s_p95, hs_p95, hd_p95, k_p95 + ); + println!( + "{:<10} {:>7.1} ms {:>7.1} ms {:>7.1} ms {:>7.1} ms", + "p99", s_p99, hs_p99, hd_p99, k_p99 + ); + println!( + "{:<10} {:>7.1} ms {:>7.1} ms {:>7.1} ms {:>7.1} ms", + "σ", s_sd, hs_sd, hd_sd, k_sd + ); + + println!(); + println!("Hedge-same 
improvement over single:"); + println!(" mean: {:+.0}%, p95: {:+.0}%, p99: {:+.0}%", + (hs_m - s_m) / s_m * 100.0, + (hs_p95 - s_p95) / s_p95 * 100.0, + (hs_p99 - s_p99) / s_p99 * 100.0); + println!("Hedge-dual improvement over single:"); + println!(" mean: {:+.0}%, p95: {:+.0}%, p99: {:+.0}%", + (hd_m - s_m) / s_m * 100.0, + (hd_p95 - s_p95) / s_p95 * 100.0, + (hd_p99 - s_p99) / s_p99 * 100.0); +} + +/// Run the hedging benchmark N times and aggregate samples across all runs. +/// Also reports per-run stats to show drift. +fn run_hedge_multi(rt: &tokio::runtime::Runtime, iterations: usize) { + let hedge_delay = Duration::from_millis(10); + + println!("Hedging Benchmark × {} iterations", iterations); + println!("Upstream: {}", DOH_UPSTREAM); + println!("Hedge delay: {:?}", hedge_delay); + println!("{} domains × {} rounds per iteration\n", DOMAINS.len(), ROUNDS); + + let primary_same = + numa::forward::parse_upstream(DOH_UPSTREAM, 443).expect("failed to parse"); + let primary_dual = + numa::forward::parse_upstream(DOH_UPSTREAM, 443).expect("failed to parse"); + let secondary_dual = + numa::forward::parse_upstream(DOH_UPSTREAM, 443).expect("failed to parse"); + let timeout = Duration::from_secs(10); + + let resolver = rt.block_on(build_hickory_resolver()); + + // Warm up + println!("Warming up..."); + for _ in 0..5 { + let w = build_query_vec("example.com"); + let _ = rt.block_on(numa::forward::forward_query_raw(&w, &primary_same, timeout)); + let _ = rt.block_on(numa::forward::forward_query_raw(&w, &primary_dual, timeout)); + let _ = rt.block_on(numa::forward::forward_query_raw(&w, &secondary_dual, timeout)); + let _ = rt.block_on(query_hickory_doh(&resolver, "example.com")); + } + + // Accumulated samples across all iterations + let mut all_single = Vec::new(); + let mut all_hedge_same = Vec::new(); + let mut all_hedge_dual = Vec::new(); + let mut all_hickory = Vec::new(); + + // Per-iteration summary stats + let mut iter_stats: Vec<[(f64, f64, f64, f64, f64); 4]> 
= Vec::new(); + + for iter in 1..=iterations { + println!(" iteration {}/{}...", iter, iterations); + + let mut single = Vec::new(); + let mut hedge_same = Vec::new(); + let mut hedge_dual = Vec::new(); + let mut hickory = Vec::new(); + + for domain in DOMAINS { + for _ in 0..ROUNDS { + let wire = build_query_vec(domain); + + let t = Instant::now(); + let _ = rt.block_on(numa::forward::forward_query_raw(&wire, &primary_same, timeout)); + single.push(t.elapsed().as_secs_f64() * 1000.0); + + let t = Instant::now(); + let _ = rt.block_on(numa::forward::forward_with_hedging_raw( + &wire, &primary_same, &primary_same, hedge_delay, timeout, + )); + hedge_same.push(t.elapsed().as_secs_f64() * 1000.0); + + let t = Instant::now(); + let _ = rt.block_on(numa::forward::forward_with_hedging_raw( + &wire, &primary_dual, &secondary_dual, hedge_delay, timeout, + )); + hedge_dual.push(t.elapsed().as_secs_f64() * 1000.0); + + let t = Instant::now(); + let _ = rt.block_on(query_hickory_doh(&resolver, domain)); + hickory.push(t.elapsed().as_secs_f64() * 1000.0); + } + } + + let stats = |v: &mut Vec| -> (f64, f64, f64, f64, f64) { + (mean(v), median(v), percentile(v, 95.0), percentile(v, 99.0), stddev(v)) + }; + iter_stats.push([ + stats(&mut single), + stats(&mut hedge_same), + stats(&mut hedge_dual), + stats(&mut hickory), + ]); + + all_single.extend_from_slice(&single); + all_hedge_same.extend_from_slice(&hedge_same); + all_hedge_dual.extend_from_slice(&hedge_dual); + all_hickory.extend_from_slice(&hickory); + } + + println!(); + println!("=== Per-iteration medians (drift check) ==="); + println!( + "{:<8} {:>10} {:>12} {:>12} {:>10}", + "iter", "Single", "Hedge-same", "Hedge-dual", "Hickory" + ); + for (i, s) in iter_stats.iter().enumerate() { + println!( + "{:<8} {:>7.1} ms {:>9.1} ms {:>9.1} ms {:>7.1} ms", + i + 1, + s[0].1, + s[1].1, + s[2].1, + s[3].1 + ); + } + + println!(); + println!("=== Per-iteration p99 (drift check) ==="); + println!( + "{:<8} {:>10} {:>12} {:>12} 
{:>10}", + "iter", "Single", "Hedge-same", "Hedge-dual", "Hickory" + ); + for (i, s) in iter_stats.iter().enumerate() { + println!( + "{:<8} {:>7.1} ms {:>9.1} ms {:>9.1} ms {:>7.1} ms", + i + 1, + s[0].3, + s[1].3, + s[2].3, + s[3].3 + ); + } + + let final_stats = |v: &mut Vec| -> (f64, f64, f64, f64, f64) { + (mean(v), median(v), percentile(v, 95.0), percentile(v, 99.0), stddev(v)) + }; + let (s_m, s_med, s_p95, s_p99, s_sd) = final_stats(&mut all_single); + let (hs_m, hs_med, hs_p95, hs_p99, hs_sd) = final_stats(&mut all_hedge_same); + let (hd_m, hd_med, hd_p95, hd_p99, hd_sd) = final_stats(&mut all_hedge_dual); + let (k_m, k_med, k_p95, k_p99, k_sd) = final_stats(&mut all_hickory); + + println!(); + let total = iterations * DOMAINS.len() * ROUNDS; + println!("=== Aggregated across all {} samples per method ===", total); + println!(); + println!( + "{:<10} {:>10} {:>12} {:>12} {:>10}", + "", "Single", "Hedge-same", "Hedge-dual", "Hickory" + ); + println!( + "{:<10} {:>7.1} ms {:>9.1} ms {:>9.1} ms {:>7.1} ms", + "mean", s_m, hs_m, hd_m, k_m + ); + println!( + "{:<10} {:>7.1} ms {:>9.1} ms {:>9.1} ms {:>7.1} ms", + "median", s_med, hs_med, hd_med, k_med + ); + println!( + "{:<10} {:>7.1} ms {:>9.1} ms {:>9.1} ms {:>7.1} ms", + "p95", s_p95, hs_p95, hd_p95, k_p95 + ); + println!( + "{:<10} {:>7.1} ms {:>9.1} ms {:>9.1} ms {:>7.1} ms", + "p99", s_p99, hs_p99, hd_p99, k_p99 + ); + println!( + "{:<10} {:>7.1} ms {:>9.1} ms {:>9.1} ms {:>7.1} ms", + "σ", s_sd, hs_sd, hd_sd, k_sd + ); + + println!(); + println!("Hedge-same vs Single: mean {:+.0}%, p95 {:+.0}%, p99 {:+.0}%", + (hs_m - s_m) / s_m * 100.0, + (hs_p95 - s_p95) / s_p95 * 100.0, + (hs_p99 - s_p99) / s_p99 * 100.0); + println!("Hedge-dual vs Single: mean {:+.0}%, p95 {:+.0}%, p99 {:+.0}%", + (hd_m - s_m) / s_m * 100.0, + (hd_p95 - s_p95) / s_p95 * 100.0, + (hd_p99 - s_p99) / s_p99 * 100.0); + println!("Hedge-same vs Hickory: mean {:+.0}%, p95 {:+.0}%, p99 {:+.0}%", + (hs_m - k_m) / k_m * 100.0, + (hs_p95 - 
k_p95) / k_p95 * 100.0, + (hs_p99 - k_p99) / k_p99 * 100.0); +} + +/// Server-to-server benchmark: Numa vs dnscrypt-proxy vs Unbound. +/// All are full servers: UDP in, encrypted forwarding to Quad9. +/// Numa + dnscrypt: DoH (HTTPS). Unbound: DoT (TLS port 853). +fn run_vs_dnscrypt(rt: &tokio::runtime::Runtime, iterations: usize) { + const DNSCRYPT_ADDR: &str = "127.0.0.1:5455"; + const UNBOUND_ADDR: &str = "127.0.0.1:5456"; + let numa_addr: SocketAddr = NUMA_BENCH.parse().unwrap(); + let dnscrypt_addr: SocketAddr = DNSCRYPT_ADDR.parse().unwrap(); + let unbound_addr: SocketAddr = UNBOUND_ADDR.parse().unwrap(); + + println!("Server-to-Server: Numa vs dnscrypt-proxy vs Unbound"); + println!("Numa (DoH): {}", NUMA_BENCH); + println!("dnscrypt-proxy (DoH): {}", DNSCRYPT_ADDR); + println!("Unbound (DoT): {}", UNBOUND_ADDR); + println!("All forwarding to Quad9 over encrypted transport"); + println!("{} domains × {} rounds × {} iterations\n", + DOMAINS.len(), ROUNDS, iterations); + + // Verify all are up + let servers: Vec<(&str, SocketAddr)> = vec![ + ("Numa", numa_addr), + ("dnscrypt-proxy", dnscrypt_addr), + ("Unbound", unbound_addr), + ]; + for (name, addr) in &servers { + if rt.block_on(query_udp(*addr, "example.com")).is_none() { + eprintln!("{} not responding on {}", name, addr); + std::process::exit(1); + } + } + println!("All servers reachable.\n"); + + // Warm up + println!("Warming up..."); + for _ in 0..5 { + for (_, addr) in &servers { + let _ = rt.block_on(query_udp(*addr, "example.com")); + } + } + + let mut all_numa = Vec::new(); + let mut all_dnscrypt = Vec::new(); + let mut all_unbound = Vec::new(); + let mut iter_stats: Vec<[(f64, f64, f64, f64, f64); 3]> = Vec::new(); + + for iter in 1..=iterations { + println!(" iteration {}/{}...", iter, iterations); + + let mut numa = Vec::new(); + let mut dnscrypt = Vec::new(); + let mut unbound = Vec::new(); + + for domain in DOMAINS { + for round in 0..ROUNDS { + flush_cache(); + 
std::thread::sleep(Duration::from_millis(5)); + + // Rotate order: 3 servers, 3 possible orderings + let order = round % 3; + let mut measure = |addr: SocketAddr| -> f64 { + let t = Instant::now(); + let _ = rt.block_on(query_udp(addr, domain)); + t.elapsed().as_secs_f64() * 1000.0 + }; + + match order { + 0 => { + numa.push(measure(numa_addr)); + dnscrypt.push(measure(dnscrypt_addr)); + unbound.push(measure(unbound_addr)); + } + 1 => { + dnscrypt.push(measure(dnscrypt_addr)); + unbound.push(measure(unbound_addr)); + numa.push(measure(numa_addr)); + } + _ => { + unbound.push(measure(unbound_addr)); + numa.push(measure(numa_addr)); + dnscrypt.push(measure(dnscrypt_addr)); + } + } + } + } + + let stats = |v: &mut Vec| -> (f64, f64, f64, f64, f64) { + (mean(v), median(v), percentile(v, 95.0), percentile(v, 99.0), stddev(v)) + }; + iter_stats.push([stats(&mut numa), stats(&mut dnscrypt), stats(&mut unbound)]); + + all_numa.extend_from_slice(&numa); + all_dnscrypt.extend_from_slice(&dnscrypt); + all_unbound.extend_from_slice(&unbound); + } + + println!(); + println!("=== Per-iteration medians ==="); + println!("{:<8} {:>10} {:>14} {:>10}", "iter", "Numa", "dnscrypt-proxy", "Unbound"); + for (i, s) in iter_stats.iter().enumerate() { + println!("{:<8} {:>7.1} ms {:>11.1} ms {:>7.1} ms", + i + 1, s[0].1, s[1].1, s[2].1); + } + + println!(); + println!("=== Per-iteration p99 ==="); + println!("{:<8} {:>10} {:>14} {:>10}", "iter", "Numa", "dnscrypt-proxy", "Unbound"); + for (i, s) in iter_stats.iter().enumerate() { + println!("{:<8} {:>7.1} ms {:>11.1} ms {:>7.1} ms", + i + 1, s[0].3, s[1].3, s[2].3); + } + + let stats = |v: &mut Vec| -> (f64, f64, f64, f64, f64) { + (mean(v), median(v), percentile(v, 95.0), percentile(v, 99.0), stddev(v)) + }; + let (n_m, n_med, n_p95, n_p99, n_sd) = stats(&mut all_numa); + let (d_m, d_med, d_p95, d_p99, d_sd) = stats(&mut all_dnscrypt); + let (u_m, u_med, u_p95, u_p99, u_sd) = stats(&mut all_unbound); + + println!(); + let total = 
iterations * DOMAINS.len() * ROUNDS; + println!("=== Aggregated ({} samples per method) ===", total); + println!(); + println!("{:<10} {:>10} {:>14} {:>10}", "", "Numa", "dnscrypt-proxy", "Unbound"); + println!("{:<10} {:>7.1} ms {:>11.1} ms {:>7.1} ms", "mean", n_m, d_m, u_m); + println!("{:<10} {:>7.1} ms {:>11.1} ms {:>7.1} ms", "median", n_med, d_med, u_med); + println!("{:<10} {:>7.1} ms {:>11.1} ms {:>7.1} ms", "p95", n_p95, d_p95, u_p95); + println!("{:<10} {:>7.1} ms {:>11.1} ms {:>7.1} ms", "p99", n_p99, d_p99, u_p99); + println!("{:<10} {:>7.1} ms {:>11.1} ms {:>7.1} ms", "σ", n_sd, d_sd, u_sd); + println!(); + + println!("Numa vs dnscrypt-proxy:"); + println!(" mean: {:+.0}%, median: {:+.0}%, p99: {:+.0}%", + (n_m - d_m) / d_m * 100.0, (n_med - d_med) / d_med * 100.0, (n_p99 - d_p99) / d_p99 * 100.0); + println!("Numa vs Unbound:"); + println!(" mean: {:+.0}%, median: {:+.0}%, p99: {:+.0}%", + (n_m - u_m) / u_m * 100.0, (n_med - u_med) / u_med * 100.0, (n_p99 - u_p99) / u_p99 * 100.0); +} + +/// Numa vs Unbound: both forward over plain UDP to Quad9, caching enabled. +/// Truly equal transport — no TLS, no HTTP/2, pure forwarding + cache. 
+fn run_vs_unbound(rt: &tokio::runtime::Runtime, iterations: usize) { + const UNBOUND_ADDR: &str = "127.0.0.1:5456"; + let numa_addr: SocketAddr = NUMA_BENCH.parse().unwrap(); + let unbound_addr: SocketAddr = UNBOUND_ADDR.parse().unwrap(); + + println!("Numa vs Unbound (both plain UDP forwarding to Quad9, caching enabled)"); + println!("Numa: {} → 9.9.9.9:53 UDP", NUMA_BENCH); + println!("Unbound: {} → 9.9.9.9:53 UDP", UNBOUND_ADDR); + println!("{} domains × {} rounds × {} iterations\n", + DOMAINS.len(), ROUNDS, iterations); + + if rt.block_on(query_udp(numa_addr, "example.com")).is_none() { + eprintln!("Numa not responding"); std::process::exit(1); + } + if rt.block_on(query_udp(unbound_addr, "example.com")).is_none() { + eprintln!("Unbound not responding"); std::process::exit(1); + } + println!("Both servers reachable.\n"); + + println!("Warming up..."); + for _ in 0..5 { + let _ = rt.block_on(query_udp(numa_addr, "example.com")); + let _ = rt.block_on(query_udp(unbound_addr, "example.com")); + } + + let mut all_numa = Vec::new(); + let mut all_unbound = Vec::new(); + let mut iter_stats: Vec<[(f64, f64, f64, f64, f64); 2]> = Vec::new(); + + for iter in 1..=iterations { + println!(" iteration {}/{}...", iter, iterations); + + let mut numa = Vec::new(); + let mut unbound = Vec::new(); + + for domain in DOMAINS { + for round in 0..ROUNDS { + // No cache flushing — both serve from cache after first hit + let mut measure = |addr: SocketAddr| -> f64 { + let t = Instant::now(); + let _ = rt.block_on(query_udp(addr, domain)); + t.elapsed().as_secs_f64() * 1000.0 + }; + + if round % 2 == 0 { + numa.push(measure(numa_addr)); + unbound.push(measure(unbound_addr)); + } else { + unbound.push(measure(unbound_addr)); + numa.push(measure(numa_addr)); + } + } + } + + let stats = |v: &mut Vec| -> (f64, f64, f64, f64, f64) { + (mean(v), median(v), percentile(v, 95.0), percentile(v, 99.0), stddev(v)) + }; + iter_stats.push([stats(&mut numa), stats(&mut unbound)]); + + 
all_numa.extend_from_slice(&numa); + all_unbound.extend_from_slice(&unbound); + } + + println!(); + println!("=== Per-iteration medians ==="); + println!("{:<8} {:>10} {:>10}", "iter", "Numa", "Unbound"); + for (i, s) in iter_stats.iter().enumerate() { + println!("{:<8} {:>7.1} ms {:>7.1} ms", i + 1, s[0].1, s[1].1); + } + + println!(); + println!("=== Per-iteration p99 ==="); + println!("{:<8} {:>10} {:>10}", "iter", "Numa", "Unbound"); + for (i, s) in iter_stats.iter().enumerate() { + println!("{:<8} {:>7.1} ms {:>7.1} ms", i + 1, s[0].3, s[1].3); + } + + let stats = |v: &mut Vec| -> (f64, f64, f64, f64, f64) { + (mean(v), median(v), percentile(v, 95.0), percentile(v, 99.0), stddev(v)) + }; + let (n_m, n_med, n_p95, n_p99, n_sd) = stats(&mut all_numa); + let (u_m, u_med, u_p95, u_p99, u_sd) = stats(&mut all_unbound); + + println!(); + let total = iterations * DOMAINS.len() * ROUNDS; + println!("=== Aggregated ({} samples per method) ===", total); + println!(); + println!("{:<10} {:>10} {:>10}", "", "Numa", "Unbound"); + println!("{:<10} {:>7.1} ms {:>7.1} ms", "mean", n_m, u_m); + println!("{:<10} {:>7.1} ms {:>7.1} ms", "median", n_med, u_med); + println!("{:<10} {:>7.1} ms {:>7.1} ms", "p95", n_p95, u_p95); + println!("{:<10} {:>7.1} ms {:>7.1} ms", "p99", n_p99, u_p99); + println!("{:<10} {:>7.1} ms {:>7.1} ms", "σ", n_sd, u_sd); + println!(); + + println!("Numa vs Unbound:"); + println!(" mean: {:+.1} ms ({:+.0}%)", n_m - u_m, (n_m - u_m) / u_m * 100.0); + println!(" median: {:+.1} ms ({:+.0}%)", n_med - u_med, (n_med - u_med) / u_med * 100.0); + println!(" p95: {:+.1} ms ({:+.0}%)", n_p95 - u_p95, (n_p95 - u_p95) / u_p95 * 100.0); + println!(" p99: {:+.1} ms ({:+.0}%)", n_p99 - u_p99, (n_p99 - u_p99) / u_p99 * 100.0); +} + +/// Build a DNS query as a Vec for use with forward_query_raw. 
+fn build_query_vec(domain: &str) -> Vec { + let mut buf = vec![0u8; 512]; + let len = build_query(&mut buf, domain); + buf.truncate(len); + buf +} + +fn measure R, R>(_rt: &tokio::runtime::Runtime, f: F) -> f64 { + let start = Instant::now(); + f(); + start.elapsed().as_secs_f64() * 1000.0 +} + +fn mean(v: &[f64]) -> f64 { + v.iter().sum::() / v.len() as f64 +} + +fn stddev(v: &[f64]) -> f64 { + let m = mean(v); + let var = v.iter().map(|x| (x - m).powi(2)).sum::() / v.len() as f64; + var.sqrt() +} + +fn median(v: &mut [f64]) -> f64 { + v.sort_by(|a, b| a.partial_cmp(b).unwrap()); + let n = v.len(); + if n % 2 == 0 { + (v[n / 2 - 1] + v[n / 2]) / 2.0 + } else { + v[n / 2] + } +} + +fn percentile(sorted: &[f64], p: f64) -> f64 { + let idx = (p / 100.0 * (sorted.len() - 1) as f64).round() as usize; + sorted[idx.min(sorted.len() - 1)] +} + +fn format_delta(delta: f64) -> String { + if delta > 0.0 { + format!("+{:.1}", delta) + } else { + format!("{:.1}", delta) + } +} + +/// Query a DNS server over UDP. +async fn query_udp(addr: SocketAddr, domain: &str) -> Option<()> { + use tokio::net::UdpSocket; + + let sock = UdpSocket::bind("0.0.0.0:0").await.ok()?; + let mut buf = vec![0u8; 512]; + let len = build_query(&mut buf, domain); + + sock.send_to(&buf[..len], addr).await.ok()?; + + let mut resp = vec![0u8; 4096]; + tokio::time::timeout(Duration::from_secs(10), sock.recv_from(&mut resp)) + .await + .ok()? + .ok()?; + + Some(()) +} + +/// Build a shared Hickory DoH resolver (reuses TLS connection across queries). 
+async fn build_hickory_resolver() -> hickory_resolver::TokioResolver { + use hickory_resolver::config::*; + + let ns = NameServerConfig { + socket_addr: "9.9.9.9:443".parse().unwrap(), + protocol: hickory_proto::xfer::Protocol::Https, + tls_dns_name: Some("dns.quad9.net".to_string()), + trust_negative_responses: true, + bind_addr: None, + http_endpoint: Some("/dns-query".to_string()), + }; + + let config = ResolverConfig::from_parts(None, vec![], NameServerConfigGroup::from(vec![ns])); + + let mut opts = ResolverOpts::default(); + opts.cache_size = 0; + opts.num_concurrent_reqs = 1; + opts.timeout = Duration::from_secs(10); + + hickory_resolver::TokioResolver::builder_with_config(config, Default::default()) + .with_options(opts) + .build() +} + +/// Query using the shared Hickory resolver. +async fn query_hickory_doh( + resolver: &hickory_resolver::TokioResolver, + domain: &str, +) -> Option<()> { + use hickory_resolver::proto::rr::RecordType; + let _ = resolver.lookup(domain, RecordType::A).await.ok()?; + Some(()) +} + +fn build_query(buf: &mut [u8], domain: &str) -> usize { + let mut pos = 0; + buf[pos..pos + 2].copy_from_slice(&0x1234u16.to_be_bytes()); + pos += 2; + buf[pos..pos + 2].copy_from_slice(&0x0100u16.to_be_bytes()); + pos += 2; + buf[pos..pos + 2].copy_from_slice(&1u16.to_be_bytes()); + pos += 2; + buf[pos..pos + 6].fill(0); + pos += 6; + + for label in domain.split('.') { + buf[pos] = label.len() as u8; + pos += 1; + buf[pos..pos + label.len()].copy_from_slice(label.as_bytes()); + pos += label.len(); + } + buf[pos] = 0; + pos += 1; + buf[pos..pos + 2].copy_from_slice(&1u16.to_be_bytes()); + pos += 2; + buf[pos..pos + 2].copy_from_slice(&1u16.to_be_bytes()); + pos += 2; + pos +} + +fn flush_cache() { + let _ = std::process::Command::new("curl") + .args(["-s", "-X", "DELETE", &format!("http://127.0.0.1:{NUMA_API}/cache")]) + .output(); +} diff --git a/scripts/bench-recursive.sh b/scripts/bench-recursive.sh new file mode 100755 index 0000000..1a1ab71 
--- /dev/null +++ b/scripts/bench-recursive.sh @@ -0,0 +1,115 @@ +#!/usr/bin/env bash +# Bench: Numa cold-cache recursive resolution vs dig (forwarded through system resolver) +# +# Measures cold-cache recursive resolution time for Numa. +# Flushes Numa's cache before each query to ensure cold-cache. +# Compares against dig querying a public recursive resolver (no cache advantage). +# +# Usage: ./scripts/bench-recursive.sh [numa_port] + +set -euo pipefail + +NUMA_ADDR="${NUMA_ADDR:-127.0.0.1}" +NUMA_PORT="${NUMA_PORT:-${1:-53}}" +API_PORT="${API_PORT:-5380}" +ROUNDS=3 + +DOMAINS=( + "example.com" + "rust-lang.org" + "kernel.org" + "signal.org" + "archlinux.org" + "openbsd.org" + "git-scm.com" + "sqlite.org" + "wireguard.com" + "mozilla.org" +) + +GREEN='\033[0;32m' +AMBER='\033[0;33m' +CYAN='\033[0;36m' +DIM='\033[0;90m' +BOLD='\033[1m' +RESET='\033[0m' + +echo -e "${CYAN}${BOLD}Recursive DNS Resolution Benchmark${RESET}" +echo -e "${DIM}Numa (cold cache, recursive from root) vs dig @1.1.1.1 (public resolver)${RESET}" +echo -e "${DIM}Rounds per domain: ${ROUNDS}${RESET}" +echo "" + +# Verify Numa is reachable +if ! dig @${NUMA_ADDR} -p ${NUMA_PORT} +short +time=3 +tries=1 example.com A &>/dev/null; then + echo -e "${AMBER}Numa not responding on ${NUMA_ADDR}:${NUMA_PORT}${RESET}" >&2 + exit 1 +fi + +# Verify we can flush cache +if ! 
curl -s -X DELETE "http://${NUMA_ADDR}:${API_PORT}/cache" &>/dev/null; then + echo -e "${AMBER}Cannot flush cache via API at ${NUMA_ADDR}:${API_PORT}${RESET}" >&2 + exit 1 +fi + +measure_ms() { + local start end + start=$(python3 -c 'import time; print(time.time())') + eval "$1" &>/dev/null + end=$(python3 -c 'import time; print(time.time())') + python3 -c "print(round(($end - $start) * 1000, 1))" +} + +printf "${BOLD}%-22s %10s %10s %8s${RESET}\n" "Domain" "Numa (ms)" "1.1.1.1" "Delta" +printf "%-22s %10s %10s %8s\n" "----------------------" "----------" "----------" "--------" + +numa_total=0 +dig_total=0 +count=0 + +for domain in "${DOMAINS[@]}"; do + numa_sum=0 + dig_sum=0 + + for ((r=1; r<=ROUNDS; r++)); do + # Flush Numa cache + curl -s -X DELETE "http://${NUMA_ADDR}:${API_PORT}/cache" &>/dev/null + sleep 0.05 + + # Measure Numa (recursive from root, cold cache) + ms=$(measure_ms "dig @${NUMA_ADDR} -p ${NUMA_PORT} +short +time=10 +tries=1 ${domain} A") + numa_sum=$(python3 -c "print(round($numa_sum + $ms, 1))") + + # Measure dig against 1.1.1.1 (Cloudflare — warm cache, but shows baseline) + ms=$(measure_ms "dig @1.1.1.1 +short +time=10 +tries=1 ${domain} A") + dig_sum=$(python3 -c "print(round($dig_sum + $ms, 1))") + done + + numa_avg=$(python3 -c "print(round($numa_sum / $ROUNDS, 1))") + dig_avg=$(python3 -c "print(round($dig_sum / $ROUNDS, 1))") + delta=$(python3 -c "d = round($numa_avg - $dig_avg, 1); print(f'+{d}' if d > 0 else str(d))") + + # Color the delta + delta_color="$GREEN" + if python3 -c "exit(0 if $numa_avg > $dig_avg * 1.5 else 1)" 2>/dev/null; then + delta_color="$AMBER" + fi + + printf "%-22s %8s ms %8s ms ${delta_color}%6s ms${RESET}\n" "$domain" "$numa_avg" "$dig_avg" "$delta" + + numa_total=$(python3 -c "print(round($numa_total + $numa_avg, 1))") + dig_total=$(python3 -c "print(round($dig_total + $dig_avg, 1))") + count=$((count + 1)) +done + +echo "" +numa_mean=$(python3 -c "print(round($numa_total / $count, 1))") +dig_mean=$(python3 -c 
"print(round($dig_total / $count, 1))") +delta_mean=$(python3 -c "d = round($numa_mean - $dig_mean, 1); print(f'+{d}' if d > 0 else str(d))") + +printf "${BOLD}%-22s %8s ms %8s ms %6s ms${RESET}\n" "AVERAGE" "$numa_mean" "$dig_mean" "$delta_mean" + +echo "" +echo -e "${DIM}Note: Numa resolves recursively from root hints (cold cache).${RESET}" +echo -e "${DIM}1.1.1.1 serves from Cloudflare's global cache (warm). The comparison${RESET}" +echo -e "${DIM}is intentionally unfair — it shows Numa's worst case vs the best case${RESET}" +echo -e "${DIM}of a global anycast resolver. Cached Numa queries resolve in <1ms.${RESET}" diff --git a/src/api.rs b/src/api.rs index a0bae58..e638fba 100644 --- a/src/api.rs +++ b/src/api.rs @@ -1029,6 +1029,7 @@ mod tests { upstream_port: 53, lan_ip: Mutex::new(std::net::Ipv4Addr::LOCALHOST), timeout: std::time::Duration::from_secs(3), + hedge_delay: std::time::Duration::ZERO, proxy_tld: "numa".to_string(), proxy_tld_suffix: ".numa".to_string(), lan_enabled: false, diff --git a/src/cache.rs b/src/cache.rs index 5bdde85..82795bc 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -1,9 +1,10 @@ use std::collections::HashMap; use std::time::{Duration, Instant}; +use crate::buffer::BytePacketBuffer; use crate::packet::DnsPacket; use crate::question::QueryType; -use crate::record::DnsRecord; +use crate::wire::WireMeta; #[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] pub enum DnssecStatus { @@ -26,14 +27,16 @@ impl DnssecStatus { } struct CacheEntry { - packet: DnsPacket, + wire: Vec, + meta: WireMeta, inserted_at: Instant, ttl: Duration, dnssec_status: DnssecStatus, } -/// DNS cache using a two-level map (domain -> query_type -> entry) so that -/// lookups can borrow `&str` instead of allocating a `String` key. +const STALE_WINDOW: Duration = Duration::from_secs(3600); + +/// DNS cache with serve-stale (RFC 8767). Stores raw wire bytes. 
pub struct DnsCache { entries: HashMap>, entry_count: usize, @@ -53,6 +56,80 @@ impl DnsCache { } } + /// Look up cached wire bytes, patching ID and TTLs in the returned copy. + /// Implements serve-stale (RFC 8767): expired entries within STALE_WINDOW + /// are returned with TTL=1 and `stale=true` so callers can revalidate. + pub fn lookup_wire( + &self, + domain: &str, + qtype: QueryType, + new_id: u16, + ) -> Option<(Vec, DnssecStatus, bool)> { + let type_map = self.entries.get(domain)?; + let entry = type_map.get(&qtype)?; + + let elapsed = entry.inserted_at.elapsed(); + let (remaining, stale) = if elapsed < entry.ttl { + let secs = (entry.ttl - elapsed).as_secs() as u32; + (secs.max(1), false) + } else if elapsed < entry.ttl + STALE_WINDOW { + (1, true) + } else { + return None; + }; + + let mut wire = entry.wire.clone(); + crate::wire::patch_id(&mut wire, new_id); + crate::wire::patch_ttls(&mut wire, &entry.meta.ttl_offsets, remaining); + + Some((wire, entry.dnssec_status, stale)) + } + + pub fn insert_wire( + &mut self, + domain: &str, + qtype: QueryType, + wire: &[u8], + dnssec_status: DnssecStatus, + ) { + let meta = match crate::wire::scan_ttl_offsets(wire) { + Ok(m) => m, + Err(_) => return, // malformed wire, skip + }; + + if self.entry_count >= self.max_entries { + self.evict_expired(); + if self.entry_count >= self.max_entries { + return; + } + } + + let min_ttl = crate::wire::min_ttl_from_wire(wire, &meta) + .unwrap_or(self.min_ttl) + .clamp(self.min_ttl, self.max_ttl); + + let type_map = if let Some(existing) = self.entries.get_mut(domain) { + existing + } else { + self.entries.entry(domain.to_string()).or_default() + }; + + if !type_map.contains_key(&qtype) { + self.entry_count += 1; + } + + type_map.insert( + qtype, + CacheEntry { + wire: wire.to_vec(), + meta, + inserted_at: Instant::now(), + ttl: Duration::from_secs(min_ttl as u64), + dnssec_status, + }, + ); + } + /// Read-only lookup — expired entries are left in place (cleaned up on insert). 
pub fn lookup(&self, domain: &str, qtype: QueryType) -> Option { self.lookup_with_status(domain, qtype).map(|(pkt, _)| pkt) @@ -63,23 +140,28 @@ impl DnsCache { domain: &str, qtype: QueryType, ) -> Option<(DnsPacket, DnssecStatus)> { - let type_map = self.entries.get(domain)?; - let entry = type_map.get(&qtype)?; + let (wire, status, _stale) = self.lookup_wire(domain, qtype, 0)?; + let mut buf = BytePacketBuffer::from_bytes(&wire); + let pkt = DnsPacket::from_buffer(&mut buf).ok()?; + Some((pkt, status)) + } - let elapsed = entry.inserted_at.elapsed(); - if elapsed >= entry.ttl { - return None; + pub fn insert(&mut self, domain: &str, qtype: QueryType, packet: &DnsPacket) { + self.insert_with_status(domain, qtype, packet, DnssecStatus::Indeterminate); + } + + pub fn insert_with_status( + &mut self, + domain: &str, + qtype: QueryType, + packet: &DnsPacket, + dnssec_status: DnssecStatus, + ) { + let mut buf = BytePacketBuffer::new(); + if packet.write(&mut buf).is_err() { + return; } - - let remaining_secs = (entry.ttl - elapsed).as_secs() as u32; - let remaining = remaining_secs.max(1); - - let mut packet = entry.packet.clone(); - adjust_ttls(&mut packet.answers, remaining); - adjust_ttls(&mut packet.authorities, remaining); - adjust_ttls(&mut packet.resources, remaining); - - Some((packet, entry.dnssec_status)) + self.insert_wire(domain, qtype, buf.filled(), dnssec_status); } pub fn ttl_remaining(&self, domain: &str, qtype: QueryType) -> Option<(u32, u32)> { @@ -105,49 +187,6 @@ impl DnsCache { false } - pub fn insert(&mut self, domain: &str, qtype: QueryType, packet: &DnsPacket) { - self.insert_with_status(domain, qtype, packet, DnssecStatus::Indeterminate); - } - - pub fn insert_with_status( - &mut self, - domain: &str, - qtype: QueryType, - packet: &DnsPacket, - dnssec_status: DnssecStatus, - ) { - if self.entry_count >= self.max_entries { - self.evict_expired(); - if self.entry_count >= self.max_entries { - return; - } - } - - let min_ttl = 
extract_min_ttl(&packet.answers) - .unwrap_or(self.min_ttl) - .clamp(self.min_ttl, self.max_ttl); - - let type_map = if let Some(existing) = self.entries.get_mut(domain) { - existing - } else { - self.entries.entry(domain.to_string()).or_default() - }; - - if !type_map.contains_key(&qtype) { - self.entry_count += 1; - } - - type_map.insert( - qtype, - CacheEntry { - packet: packet.clone(), - inserted_at: Instant::now(), - ttl: Duration::from_secs(min_ttl as u64), - dnssec_status, - }, - ); - } - pub fn len(&self) -> usize { self.entry_count } @@ -179,7 +218,8 @@ impl DnsCache { + 1; total += type_map.capacity() * inner_slot; for entry in type_map.values() { - total += entry.packet.heap_bytes(); + total += entry.wire.capacity() + + entry.meta.ttl_offsets.capacity() * std::mem::size_of::(); } } total @@ -228,20 +268,11 @@ pub struct CacheInfo { pub ttl_remaining: u32, } -fn extract_min_ttl(records: &[DnsRecord]) -> Option { - records.iter().map(|r| r.ttl()).min() -} - -fn adjust_ttls(records: &mut [DnsRecord], new_ttl: u32) { - for record in records.iter_mut() { - record.set_ttl(new_ttl); - } -} - #[cfg(test)] mod tests { use super::*; use crate::packet::DnsPacket; + use crate::record::DnsRecord; #[test] fn heap_bytes_grows_with_entries() { diff --git a/src/config.rs b/src/config.rs index ae9f685..5f9db73 100644 --- a/src/config.rs +++ b/src/config.rs @@ -138,6 +138,8 @@ pub struct UpstreamConfig { pub fallback: Vec, #[serde(default = "default_timeout_ms")] pub timeout_ms: u64, + #[serde(default = "default_hedge_ms")] + pub hedge_ms: u64, #[serde(default = "default_root_hints")] pub root_hints: Vec, #[serde(default = "default_prime_tlds")] @@ -154,6 +156,7 @@ impl Default for UpstreamConfig { port: default_upstream_port(), fallback: Vec::new(), timeout_ms: default_timeout_ms(), + hedge_ms: default_hedge_ms(), root_hints: default_root_hints(), prime_tlds: default_prime_tlds(), srtt: default_srtt(), @@ -271,6 +274,9 @@ fn default_upstream_port() -> u16 { fn 
default_timeout_ms() -> u64 { 5000 } +fn default_hedge_ms() -> u64 { + 10 +} #[derive(Deserialize)] pub struct CacheConfig { diff --git a/src/ctx.rs b/src/ctx.rs index 3ef6a0a..2b26a06 100644 --- a/src/ctx.rs +++ b/src/ctx.rs @@ -16,7 +16,9 @@ use crate::blocklist::BlocklistStore; use crate::buffer::BytePacketBuffer; use crate::cache::{DnsCache, DnssecStatus}; use crate::config::{UpstreamMode, ZoneMap}; -use crate::forward::{forward_query, forward_with_failover, Upstream, UpstreamPool}; +use crate::forward::{ + forward_query_raw, forward_with_failover_raw, Upstream, UpstreamPool, +}; use crate::header::ResultCode; use crate::health::HealthMeta; use crate::lan::PeerStore; @@ -47,6 +49,7 @@ pub struct ServerCtx { pub upstream_port: u16, pub lan_ip: Mutex, pub timeout: Duration, + pub hedge_delay: Duration, pub proxy_tld: String, pub proxy_tld_suffix: String, // pre-computed ".{tld}" to avoid per-query allocation pub lan_enabled: bool, @@ -81,6 +84,7 @@ pub struct ServerCtx { /// (and logging parse errors) before calling this function. pub async fn resolve_query( query: DnsPacket, + raw_wire: &[u8], src_addr: SocketAddr, ctx: &ServerCtx, ) -> crate::Result { @@ -177,9 +181,8 @@ pub async fn resolve_query( // Conditional forwarding takes priority over recursive mode // (e.g. 
Tailscale .ts.net, VPC private zones) let upstream = Upstream::Udp(fwd_addr); - match forward_query(&query, &upstream, ctx.timeout).await { + match forward_and_cache(raw_wire, &upstream, ctx, &qname, qtype).await { Ok(resp) => { - ctx.cache.write().unwrap().insert(&qname, qtype, &resp); (resp, QueryPath::Forwarded, DnssecStatus::Indeterminate) } Err(e) => { @@ -221,10 +224,19 @@ pub async fn resolve_query( (resp, path, DnssecStatus::Indeterminate) } else { let pool = ctx.upstream_pool.lock().unwrap().clone(); - match forward_with_failover(&query, &pool, &ctx.srtt, ctx.timeout).await { - Ok(resp) => { - ctx.cache.write().unwrap().insert(&qname, qtype, &resp); - (resp, QueryPath::Forwarded, DnssecStatus::Indeterminate) + match forward_with_failover_raw(raw_wire, &pool, &ctx.srtt, ctx.timeout, ctx.hedge_delay).await { + Ok(resp_wire) => { + ctx.cache.write().unwrap().insert_wire( + &qname, qtype, &resp_wire, DnssecStatus::Indeterminate, + ); + let mut buf = BytePacketBuffer::from_bytes(&resp_wire); + match DnsPacket::from_buffer(&mut buf) { + Ok(resp) => (resp, QueryPath::Forwarded, DnssecStatus::Indeterminate), + Err(e) => { + error!("{} | {:?} {} | PARSE ERROR | {}", src_addr, qtype, qname, e); + (DnsPacket::response_from(&query, ResultCode::SERVFAIL), QueryPath::UpstreamError, DnssecStatus::Indeterminate) + } + } } Err(e) => { error!( @@ -347,12 +359,29 @@ pub async fn resolve_query( Ok(resp_buffer) } -/// Handle a DNS query received over UDP. Thin wrapper around resolve_query. 
+async fn forward_and_cache( + wire: &[u8], + upstream: &Upstream, + ctx: &ServerCtx, + qname: &str, + qtype: QueryType, +) -> crate::Result { + let resp_wire = forward_query_raw(wire, upstream, ctx.timeout).await?; + ctx.cache + .write() + .unwrap() + .insert_wire(qname, qtype, &resp_wire, DnssecStatus::Indeterminate); + let mut buf = BytePacketBuffer::from_bytes(&resp_wire); + DnsPacket::from_buffer(&mut buf) +} + pub async fn handle_query( mut buffer: BytePacketBuffer, + raw_len: usize, src_addr: SocketAddr, ctx: &ServerCtx, ) -> crate::Result<()> { + let raw_wire = buffer.buf[..raw_len].to_vec(); let query = match DnsPacket::from_buffer(&mut buffer) { Ok(packet) => packet, Err(e) => { @@ -360,7 +389,7 @@ pub async fn handle_query( return Ok(()); } }; - match resolve_query(query, src_addr, ctx).await { + match resolve_query(query, &raw_wire, src_addr, ctx).await { Ok(resp_buffer) => { ctx.socket.send_to(resp_buffer.filled(), src_addr).await?; } diff --git a/src/doh.rs b/src/doh.rs index cf50b31..e31b6fe 100644 --- a/src/doh.rs +++ b/src/doh.rs @@ -82,7 +82,7 @@ async fn resolve_doh(dns_bytes: &[u8], src: SocketAddr, ctx: &ServerCtx) -> Resp let query_rd = query.header.recursion_desired; let questions = query.questions.clone(); - match resolve_query(query, src, ctx).await { + match resolve_query(query, dns_bytes, src, ctx).await { Ok(resp_buffer) => { let min_ttl = extract_min_ttl(resp_buffer.filled()); dns_response(resp_buffer.filled(), min_ttl) @@ -102,11 +102,10 @@ async fn resolve_doh(dns_bytes: &[u8], src: SocketAddr, ctx: &ServerCtx) -> Resp } fn extract_min_ttl(wire: &[u8]) -> u32 { - let mut buf = BytePacketBuffer::from_bytes(wire); - match DnsPacket::from_buffer(&mut buf) { - Ok(pkt) => pkt.answers.iter().map(|r| r.ttl()).min().unwrap_or(0), - Err(_) => 0, - } + crate::wire::scan_ttl_offsets(wire) + .ok() + .and_then(|meta| crate::wire::min_ttl_from_wire(wire, &meta)) + .unwrap_or(0) } fn dns_response(wire: &[u8], min_ttl: u32) -> Response { diff --git 
a/src/dot.rs b/src/dot.rs index 0d48fa2..4513f60 100644 --- a/src/dot.rs +++ b/src/dot.rs @@ -177,8 +177,7 @@ where break; }; - // Parse query up-front so we can echo its question section in SERVFAIL - // responses when resolve_query fails. + let raw_wire = buffer.buf[..msg_len].to_vec(); let query = match DnsPacket::from_buffer(&mut buffer) { Ok(q) => q, Err(e) => { @@ -200,7 +199,7 @@ where } }; - match resolve_query(query.clone(), remote_addr, ctx).await { + match resolve_query(query.clone(), &raw_wire, remote_addr, ctx).await { Ok(resp_buffer) => { if write_framed(&mut stream, resp_buffer.filled()) .await @@ -370,6 +369,7 @@ mod tests { upstream_port: 53, lan_ip: Mutex::new(std::net::Ipv4Addr::LOCALHOST), timeout: Duration::from_millis(200), + hedge_delay: Duration::ZERO, proxy_tld: "numa".to_string(), proxy_tld_suffix: ".numa".to_string(), lan_enabled: false, diff --git a/src/forward.rs b/src/forward.rs index ea2f1e2..401ae1c 100644 --- a/src/forward.rs +++ b/src/forward.rs @@ -65,6 +65,13 @@ pub fn parse_upstream(s: &str, default_port: u16) -> Result { if s.starts_with("https://") { let client = reqwest::Client::builder() .use_rustls_tls() + .http2_initial_stream_window_size(65_535) + .http2_initial_connection_window_size(65_535) + .http2_keep_alive_interval(Duration::from_secs(15)) + .http2_keep_alive_while_idle(true) + .http2_keep_alive_timeout(Duration::from_secs(10)) + .pool_idle_timeout(Duration::from_secs(300)) + .pool_max_idle_per_host(1) .build() .unwrap_or_default(); return Ok(Upstream::Doh { @@ -325,13 +332,170 @@ async fn forward_doh( let mut send_buffer = BytePacketBuffer::new(); query.write(&mut send_buffer)?; + let resp_bytes = forward_doh_raw(send_buffer.filled(), url, client, timeout_duration).await?; + let mut recv_buffer = BytePacketBuffer::from_bytes(&resp_bytes); + DnsPacket::from_buffer(&mut recv_buffer) +} + +pub async fn forward_query_raw( + wire: &[u8], + upstream: &Upstream, + timeout_duration: Duration, +) -> Result> { + match 
upstream { + Upstream::Udp(addr) => forward_udp_raw(wire, *addr, timeout_duration).await, + Upstream::Doh { url, client } => forward_doh_raw(wire, url, client, timeout_duration).await, + } +} + +pub async fn forward_with_hedging_raw( + wire: &[u8], + primary: &Upstream, + secondary: &Upstream, + hedge_delay: Duration, + timeout_duration: Duration, +) -> Result> { + use tokio::time::sleep; + + let primary_fut = forward_query_raw(wire, primary, timeout_duration); + tokio::pin!(primary_fut); + + let delay = sleep(hedge_delay); + tokio::pin!(delay); + + // Phase 1: wait for either primary to return, or the hedge delay. + tokio::select! { + result = &mut primary_fut => return result, + _ = &mut delay => {} + } + + // Phase 2: hedge delay expired — fire secondary while still polling primary. + let secondary_fut = forward_query_raw(wire, secondary, timeout_duration); + tokio::pin!(secondary_fut); + + // First successful response wins. If one errors, wait for the other. + let mut primary_err: Option = None; + let mut secondary_err: Option = None; + + loop { + tokio::select! 
{ + r = &mut primary_fut, if primary_err.is_none() => { + match r { + Ok(resp) => return Ok(resp), + Err(e) => { + if let Some(se) = secondary_err.take() { + return Err(se); + } + primary_err = Some(e); + } + } + } + r = &mut secondary_fut, if secondary_err.is_none() => { + match r { + Ok(resp) => return Ok(resp), + Err(e) => { + if let Some(pe) = primary_err.take() { + return Err(pe); + } + secondary_err = Some(e); + } + } + } + } + + match (primary_err, secondary_err) { + (Some(pe), Some(_)) => return Err(pe), + (pe, se) => { primary_err = pe; secondary_err = se; } + } + } +} + +pub async fn forward_with_failover_raw( + wire: &[u8], + pool: &UpstreamPool, + srtt: &RwLock, + timeout_duration: Duration, + hedge_delay: Duration, +) -> Result> { + let mut candidates: Vec<(usize, u64)> = pool + .primary + .iter() + .enumerate() + .map(|(i, u)| { + let rtt = match u { + Upstream::Udp(addr) => srtt.read().unwrap().get(addr.ip()), + _ => 0, + }; + (i, rtt) + }) + .collect(); + candidates.sort_by_key(|&(_, rtt)| rtt); + + let all_upstreams: Vec<&Upstream> = candidates + .iter() + .map(|&(i, _)| &pool.primary[i]) + .chain(pool.fallback.iter()) + .collect(); + + let mut last_err: Option> = None; + + for upstream in &all_upstreams { + let start = Instant::now(); + let result = if !hedge_delay.is_zero() && matches!(upstream, Upstream::Doh { .. }) { + // Hedge against the same upstream: parallel h2 streams on same + // connection. Independent stream scheduling rescues dispatch spikes. 
+ forward_with_hedging_raw(wire, upstream, upstream, hedge_delay, timeout_duration).await + } else { + forward_query_raw(wire, upstream, timeout_duration).await + }; + match result { + Ok(resp) => { + if let Upstream::Udp(addr) = upstream { + let rtt_ms = start.elapsed().as_millis() as u64; + srtt.write().unwrap().record_rtt(addr.ip(), rtt_ms, false); + } + return Ok(resp); + } + Err(e) => { + if let Upstream::Udp(addr) = upstream { + srtt.write().unwrap().record_failure(addr.ip()); + } + log::debug!("upstream {} failed: {}", upstream, e); + last_err = Some(e); + } + } + } + + Err(last_err.unwrap_or_else(|| "no upstream configured".into())) +} + +async fn forward_udp_raw( + wire: &[u8], + upstream: SocketAddr, + timeout_duration: Duration, +) -> Result> { + let socket = UdpSocket::bind("0.0.0.0:0").await?; + socket.send_to(wire, upstream).await?; + + let mut recv_buf = vec![0u8; 4096]; + let (size, _) = timeout(timeout_duration, socket.recv_from(&mut recv_buf)).await??; + recv_buf.truncate(size); + Ok(recv_buf) +} + +async fn forward_doh_raw( + wire: &[u8], + url: &str, + client: &reqwest::Client, + timeout_duration: Duration, +) -> Result> { let resp = timeout( timeout_duration, client .post(url) .header("content-type", "application/dns-message") .header("accept", "application/dns-message") - .body(send_buffer.filled().to_vec()) + .body(wire.to_vec()) .send(), ) .await?? @@ -339,9 +503,25 @@ async fn forward_doh( let bytes = resp.bytes().await?; log::debug!("DoH response: {} bytes", bytes.len()); + Ok(bytes.to_vec()) +} - let mut recv_buffer = BytePacketBuffer::from_bytes(&bytes); - DnsPacket::from_buffer(&mut recv_buffer) +/// Send a lightweight keepalive query to a DoH upstream to prevent +/// the HTTP/2 + TLS connection from going idle and being torn down. +pub async fn keepalive_doh(upstream: &Upstream) { + if let Upstream::Doh { url, client } = upstream { + // Query for . 
NS — minimal, always succeeds, response is small + let wire: &[u8] = &[ + 0x00, 0x00, // ID + 0x01, 0x00, // flags: RD=1 + 0x00, 0x01, // QDCOUNT=1 + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // AN=0, NS=0, AR=0 + 0x00, // root name (.) + 0x00, 0x02, // type NS + 0x00, 0x01, // class IN + ]; + let _ = forward_doh_raw(wire, url, client, Duration::from_secs(5)).await; + } } #[cfg(test)] diff --git a/src/lib.rs b/src/lib.rs index 4074020..92a0b00 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -26,6 +26,7 @@ pub mod srtt; pub mod stats; pub mod system_dns; pub mod tls; +pub mod wire; pub type Error = Box; pub type Result = std::result::Result; diff --git a/src/main.rs b/src/main.rs index 7592186..0211a59 100644 --- a/src/main.rs +++ b/src/main.rs @@ -297,6 +297,7 @@ async fn main() -> numa::Result<()> { upstream_port: config.upstream.port, lan_ip: Mutex::new(numa::lan::detect_lan_ip().unwrap_or(std::net::Ipv4Addr::LOCALHOST)), timeout: Duration::from_millis(config.upstream.timeout_ms), + hedge_delay: Duration::from_millis(config.upstream.hedge_ms), proxy_tld_suffix: if config.proxy.tld.is_empty() { String::new() } else { @@ -511,6 +512,14 @@ async fn main() -> numa::Result<()> { }); } + // Spawn DoH connection keepalive — prevents idle TLS teardown + { + let keepalive_ctx = Arc::clone(&ctx); + tokio::spawn(async move { + doh_keepalive_loop(keepalive_ctx).await; + }); + } + // Spawn HTTP API server let api_ctx = Arc::clone(&ctx); let api_addr: SocketAddr = format!("{}:{}", config.server.api_bind_addr, api_port).parse()?; @@ -590,7 +599,7 @@ async fn main() -> numa::Result<()> { #[allow(clippy::infinite_loop)] loop { let mut buffer = BytePacketBuffer::new(); - let (_, src_addr) = match ctx.socket.recv_from(&mut buffer.buf).await { + let (len, src_addr) = match ctx.socket.recv_from(&mut buffer.buf).await { Ok(r) => r, Err(e) if e.kind() == std::io::ErrorKind::ConnectionReset => { // Windows delivers ICMP port-unreachable as ConnectionReset on UDP sockets @@ -598,10 +607,11 @@ 
async fn main() -> numa::Result<()> { } Err(e) => return Err(e.into()), }; + let raw_len = len; let ctx = Arc::clone(&ctx); tokio::spawn(async move { - if let Err(e) = handle_query(buffer, src_addr, &ctx).await { + if let Err(e) = handle_query(buffer, raw_len, src_addr, &ctx).await { error!("{} | HANDLER ERROR | {}", src_addr, e); } }); @@ -777,6 +787,18 @@ async fn warm_domain(ctx: &ServerCtx, domain: &str) { } } +async fn doh_keepalive_loop(ctx: Arc) { + let mut interval = tokio::time::interval(Duration::from_secs(25)); + interval.tick().await; // skip first immediate tick + loop { + interval.tick().await; + let pool = ctx.upstream_pool.lock().unwrap().clone(); + if let Some(upstream) = pool.preferred() { + numa::forward::keepalive_doh(upstream).await; + } + } +} + async fn cache_warm_loop(ctx: Arc, domains: Vec) { tokio::time::sleep(Duration::from_secs(2)).await; diff --git a/src/recursive.rs b/src/recursive.rs index 24d0367..2609f7f 100644 --- a/src/recursive.rs +++ b/src/recursive.rs @@ -202,23 +202,22 @@ pub(crate) fn resolve_iterative<'a>( let mut ns_idx = 0; for _ in 0..MAX_REFERRAL_DEPTH { - let ns_addr = match ns_addrs.get(ns_idx) { - Some(addr) => *addr, - None => return Err("no nameserver available".into()), - }; + if ns_idx >= ns_addrs.len() { + return Err("no nameserver available".into()); + } let (q_name, q_type) = minimize_query(qname, qtype, ¤t_zone); debug!( - "recursive: querying {} for {:?} {} (zone: {}, depth {})", - ns_addr, q_type, q_name, current_zone, referral_depth + "recursive: querying {} (+ hedge) for {:?} {} (zone: {}, depth {})", + ns_addrs[ns_idx], q_type, q_name, current_zone, referral_depth ); - let response = match send_query(q_name, q_type, ns_addr, srtt).await { + let response = match send_query_hedged(q_name, q_type, &ns_addrs[ns_idx..], srtt).await { Ok(r) => r, Err(e) => { - debug!("recursive: NS {} failed: {}", ns_addr, e); - ns_idx += 1; + debug!("recursive: NS query failed: {}", e); + ns_idx += 2; // both tried, skip past 
them continue; } }; @@ -228,6 +227,9 @@ pub(crate) fn resolve_iterative<'a>( { if let Some(zone) = referral_zone(&response) { current_zone = zone; + let mut cache_w = cache.write().unwrap(); + cache_ns_delegation(&mut cache_w, ¤t_zone, &response); + drop(cache_w); } let mut all_ns = extract_ns_from_records(&response.answers); if all_ns.is_empty() { @@ -296,6 +298,7 @@ pub(crate) fn resolve_iterative<'a>( { let mut cache_w = cache.write().unwrap(); + cache_ns_delegation(&mut cache_w, ¤t_zone, &response); cache_ds_from_authority(&mut cache_w, &response); } let mut new_ns_addrs = resolve_ns_addrs_from_glue(&response, &ns_names, cache); @@ -560,6 +563,23 @@ fn cache_ds_from_authority(cache: &mut DnsCache, response: &DnsPacket) { } } +/// Cache NS delegation records from a referral response so that +/// `find_closest_ns` can skip re-querying TLD servers on subsequent lookups. +fn cache_ns_delegation(cache: &mut DnsCache, zone: &str, response: &DnsPacket) { + let ns_records: Vec<_> = response + .authorities + .iter() + .filter(|r| matches!(r, DnsRecord::NS { .. })) + .cloned() + .collect(); + if ns_records.is_empty() { + return; + } + let mut pkt = make_glue_packet(); + pkt.answers = ns_records; + cache.insert(zone, QueryType::NS, &pkt); +} + fn make_glue_packet() -> DnsPacket { let mut pkt = DnsPacket::new(); pkt.header.response = true; @@ -587,6 +607,91 @@ async fn tcp_with_srtt( } } +/// Smart NS query: fire to two servers simultaneously when SRTT is unknown +/// (cold queries), or to the best server with SRTT-based hedge when known. 
+async fn send_query_hedged( + qname: &str, + qtype: QueryType, + servers: &[SocketAddr], + srtt: &RwLock, +) -> crate::Result { + if servers.is_empty() { + return Err("no nameserver available".into()); + } + if servers.len() == 1 { + return send_query(qname, qtype, servers[0], srtt).await; + } + + let primary = servers[0]; + let secondary = servers[1]; + let primary_known = srtt.read().unwrap().is_known(primary.ip()); + + if !primary_known { + // Cold: fire both simultaneously, first response wins + debug!( + "recursive: parallel query to {} and {} for {:?} {}", + primary, secondary, qtype, qname + ); + let fut_a = send_query(qname, qtype, primary, srtt); + let fut_b = send_query(qname, qtype, secondary, srtt); + tokio::pin!(fut_a); + tokio::pin!(fut_b); + + // First Ok wins. If one errors, wait for the other. + let mut a_done = false; + let mut b_done = false; + let mut a_err: Option = None; + let mut b_err: Option = None; + + loop { + tokio::select! { + r = &mut fut_a, if !a_done => { + match r { + Ok(resp) => return Ok(resp), + Err(e) => { a_done = true; a_err = Some(e); } + } + } + r = &mut fut_b, if !b_done => { + match r { + Ok(resp) => return Ok(resp), + Err(e) => { b_done = true; b_err = Some(e); } + } + } + } + match (a_err.take(), b_err.take()) { + (Some(e), Some(_)) => return Err(e), + (a, b) => { a_err = a; b_err = b; } + } + } + } else { + // Warm: send to best, hedge after SRTT × 3 if slow + let hedge_ms = srtt.read().unwrap().get(primary.ip()) * 3; + let hedge_delay = Duration::from_millis(hedge_ms.max(50)); + + let fut_a = send_query(qname, qtype, primary, srtt); + tokio::pin!(fut_a); + let delay = tokio::time::sleep(hedge_delay); + tokio::pin!(delay); + + tokio::select! { + r = &mut fut_a => return r, + _ = &mut delay => {} + } + + debug!( + "recursive: hedging {} -> {} after {}ms for {:?} {}", + primary, secondary, hedge_ms, qtype, qname + ); + let fut_b = send_query(qname, qtype, secondary, srtt); + tokio::pin!(fut_b); + + tokio::select! 
{ + r = fut_a => r, + r = fut_b => r, + } + } +} + async fn send_query( qname: &str, qtype: QueryType, diff --git a/src/srtt.rs b/src/srtt.rs index f763a37..fe4df1e 100644 --- a/src/srtt.rs +++ b/src/srtt.rs @@ -45,6 +45,11 @@ impl SrttCache { } } + /// Whether we have observed RTT data for this IP. + pub fn is_known(&self, ip: IpAddr) -> bool { + self.entries.contains_key(&ip) + } + /// Apply time-based decay: each DECAY_AFTER_SECS period halves distance to INITIAL. fn decayed_srtt(entry: &SrttEntry) -> u64 { Self::decay_for_age(entry.srtt_ms, entry.updated_at.elapsed().as_secs()) diff --git a/src/wire.rs b/src/wire.rs new file mode 100644 index 0000000..6b68c3a --- /dev/null +++ b/src/wire.rs @@ -0,0 +1,1347 @@ +//! Wire-level DNS utilities: question extraction, TTL offset scanning, and patching. +//! +//! These operate directly on raw DNS wire bytes without full packet parsing, +//! enabling zero-copy forwarding and wire-level caching. + +use crate::question::QueryType; +use crate::Result; + +/// Metadata extracted from scanning a DNS response's wire bytes. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct WireMeta { + /// Byte offsets of every TTL field in answer + authority + additional sections. + /// Each offset points to the first byte of a 4-byte big-endian TTL. + /// EDNS OPT pseudo-records are excluded (their "TTL" is flags, not a real TTL). + pub ttl_offsets: Vec, + /// How many of the offsets belong to the answer section (the first `answer_count` + /// entries). Used to extract min-TTL from answers only. + pub answer_count: usize, +} + +/// Extract the first question's (domain, query type) from raw DNS wire bytes. +/// +/// Reads only the 12-byte header + first question section. Returns the lowercased +/// domain name and query type without allocating a full `DnsPacket`. 
+pub fn extract_question(wire: &[u8]) -> Result<(String, QueryType)> { + if wire.len() < 12 { + return Err("wire too short for DNS header".into()); + } + let qdcount = u16::from_be_bytes([wire[4], wire[5]]); + if qdcount == 0 { + return Err("no questions in wire".into()); + } + + let mut pos = 12; + let mut domain = String::with_capacity(64); + read_wire_qname(wire, &mut pos, &mut domain)?; + + if pos + 4 > wire.len() { + return Err("wire truncated in question section".into()); + } + let qtype = u16::from_be_bytes([wire[pos], wire[pos + 1]]); + // skip QTYPE(2) + QCLASS(2) + + Ok((domain, QueryType::from_num(qtype))) +} + +/// Scan a DNS response's wire bytes and return metadata about TTL field locations. +/// +/// Walks the header, skips the question section, then for each resource record in +/// answer, authority, and additional sections, records the byte offset of the TTL +/// field. EDNS OPT records (type 41 with root name) are excluded. +pub fn scan_ttl_offsets(wire: &[u8]) -> Result { + if wire.len() < 12 { + return Err("wire too short for DNS header".into()); + } + + let qdcount = u16::from_be_bytes([wire[4], wire[5]]) as usize; + let ancount = u16::from_be_bytes([wire[6], wire[7]]) as usize; + let nscount = u16::from_be_bytes([wire[8], wire[9]]) as usize; + let arcount = u16::from_be_bytes([wire[10], wire[11]]) as usize; + + let mut pos = 12; + + // Skip question section + for _ in 0..qdcount { + skip_wire_name(wire, &mut pos)?; + if pos + 4 > wire.len() { + return Err("wire truncated in question section".into()); + } + pos += 4; // QTYPE(2) + QCLASS(2) + } + + let mut ttl_offsets = Vec::new(); + + // Process answer + authority + additional sections + let section_counts = [ancount, nscount, arcount]; + let mut answer_offset_count = 0; + + for (section_idx, &count) in section_counts.iter().enumerate() { + for _ in 0..count { + // Check if this is an OPT record: root name (0x00) + type 41 + let is_opt = pos < wire.len() + && wire[pos] == 0x00 + && pos + 3 <= 
wire.len() + && u16::from_be_bytes([wire[pos + 1], wire[pos + 2]]) == 41; + + // Skip name + skip_wire_name(wire, &mut pos)?; + + if pos + 10 > wire.len() { + return Err("wire truncated in resource record".into()); + } + + // TYPE(2) + CLASS(2) = 4 bytes before TTL + let ttl_offset = pos + 4; + + if !is_opt { + ttl_offsets.push(ttl_offset); + if section_idx == 0 { + answer_offset_count += 1; + } + } + + // Skip TYPE(2) + CLASS(2) + TTL(4) + RDLENGTH(2) = 10 bytes + let rdlength = u16::from_be_bytes([wire[pos + 8], wire[pos + 9]]) as usize; + pos += 10 + rdlength; + + if pos > wire.len() { + return Err("wire truncated in resource record RDATA".into()); + } + } + } + + Ok(WireMeta { + ttl_offsets, + answer_count: answer_offset_count, + }) +} + +/// Extract the minimum TTL from the answer section offsets of a wire response. +pub fn min_ttl_from_wire(wire: &[u8], meta: &WireMeta) -> Option { + meta.ttl_offsets + .iter() + .take(meta.answer_count) + .filter_map(|&off| { + if off + 4 <= wire.len() { + Some(u32::from_be_bytes([ + wire[off], + wire[off + 1], + wire[off + 2], + wire[off + 3], + ])) + } else { + None + } + }) + .min() +} + +/// Patch the transaction ID (bytes 0..2) in a DNS wire message. +pub fn patch_id(wire: &mut [u8], new_id: u16) { + let bytes = new_id.to_be_bytes(); + wire[0] = bytes[0]; + wire[1] = bytes[1]; +} + +/// Patch all TTL fields at the given offsets to `new_ttl`. +pub fn patch_ttls(wire: &mut [u8], offsets: &[usize], new_ttl: u32) { + let bytes = new_ttl.to_be_bytes(); + for &off in offsets { + wire[off] = bytes[0]; + wire[off + 1] = bytes[1]; + wire[off + 2] = bytes[2]; + wire[off + 3] = bytes[3]; + } +} + +/// Read a DNS name from wire bytes at `pos`, handling compression pointers. +/// Advances `pos` past the name as it appears at the current position +/// (compression pointer targets do NOT advance `pos`). 
+fn read_wire_qname(wire: &[u8], pos: &mut usize, out: &mut String) -> Result<()> { + let mut jumped = false; + let mut read_pos = *pos; + let mut jumps = 0; + let max_jumps = 20; + + loop { + if read_pos >= wire.len() { + return Err("wire truncated reading name".into()); + } + let len = wire[read_pos] as usize; + + // Compression pointer: top 2 bits set + if len & 0xC0 == 0xC0 { + if read_pos + 1 >= wire.len() { + return Err("wire truncated in compression pointer".into()); + } + if !jumped { + *pos = read_pos + 2; // advance past the pointer + } + let offset = ((len & 0x3F) << 8) | wire[read_pos + 1] as usize; + read_pos = offset; + jumped = true; + jumps += 1; + if jumps > max_jumps { + return Err("too many compression jumps".into()); + } + continue; + } + + if len == 0 { + if !jumped { + *pos = read_pos + 1; + } + break; + } + + if read_pos + 1 + len > wire.len() { + return Err("wire truncated in name label".into()); + } + + if !out.is_empty() { + out.push('.'); + } + for &b in &wire[read_pos + 1..read_pos + 1 + len] { + out.push(b.to_ascii_lowercase() as char); + } + read_pos += 1 + len; + } + + Ok(()) +} + +/// Skip a DNS name in wire bytes, advancing `pos` past it. +fn skip_wire_name(wire: &[u8], pos: &mut usize) -> Result<()> { + loop { + if *pos >= wire.len() { + return Err("wire truncated skipping name".into()); + } + let len = wire[*pos] as usize; + + if len & 0xC0 == 0xC0 { + *pos += 2; // compression pointer is 2 bytes + return Ok(()); + } + if len == 0 { + *pos += 1; + return Ok(()); + } + *pos += 1 + len; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::buffer::BytePacketBuffer; + use crate::cache::{DnsCache, DnssecStatus}; + use crate::header::ResultCode; + use crate::packet::{DnsPacket, EdnsOpt}; + use crate::question::DnsQuestion; + use crate::record::DnsRecord; + + // ── Helpers ────────────────────────────────────────────────────── + + /// Serialize a DnsPacket to wire bytes. 
+ fn to_wire(pkt: &DnsPacket) -> Vec { + let mut buf = BytePacketBuffer::new(); + pkt.write(&mut buf).unwrap(); + buf.filled().to_vec() + } + + /// Build a minimal response with given answers. + fn response(id: u16, domain: &str, answers: Vec) -> DnsPacket { + let mut pkt = DnsPacket::new(); + pkt.header.id = id; + pkt.header.response = true; + pkt.header.recursion_desired = true; + pkt.header.recursion_available = true; + pkt.header.rescode = ResultCode::NOERROR; + pkt.questions + .push(DnsQuestion::new(domain.to_string(), QueryType::A)); + pkt.answers = answers; + pkt + } + + fn a_record(domain: &str, ip: &str, ttl: u32) -> DnsRecord { + DnsRecord::A { + domain: domain.into(), + addr: ip.parse().unwrap(), + ttl, + } + } + + fn aaaa_record(domain: &str, ip: &str, ttl: u32) -> DnsRecord { + DnsRecord::AAAA { + domain: domain.into(), + addr: ip.parse().unwrap(), + ttl, + } + } + + fn cname_record(domain: &str, host: &str, ttl: u32) -> DnsRecord { + DnsRecord::CNAME { + domain: domain.into(), + host: host.into(), + ttl, + } + } + + fn ns_record(domain: &str, host: &str, ttl: u32) -> DnsRecord { + DnsRecord::NS { + domain: domain.into(), + host: host.into(), + ttl, + } + } + + fn mx_record(domain: &str, host: &str, priority: u16, ttl: u32) -> DnsRecord { + DnsRecord::MX { + domain: domain.into(), + priority, + host: host.into(), + ttl, + } + } + + // ── A. 
TTL offset extraction ──────────────────────────────────── + + #[test] + fn scan_single_a_record() { + let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + assert_eq!(meta.ttl_offsets.len(), 1); + assert_eq!(meta.answer_count, 1); + + let off = meta.ttl_offsets[0]; + let ttl = u32::from_be_bytes([wire[off], wire[off + 1], wire[off + 2], wire[off + 3]]); + assert_eq!(ttl, 300); + } + + #[test] + fn scan_multiple_a_records() { + let pkt = response( + 0x1234, + "example.com", + vec![ + a_record("example.com", "1.2.3.4", 300), + a_record("example.com", "5.6.7.8", 600), + a_record("example.com", "9.10.11.12", 120), + ], + ); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + assert_eq!(meta.ttl_offsets.len(), 3); + assert_eq!(meta.answer_count, 3); + + let ttls: Vec = meta + .ttl_offsets + .iter() + .map(|&off| u32::from_be_bytes([wire[off], wire[off + 1], wire[off + 2], wire[off + 3]])) + .collect(); + assert_eq!(ttls, vec![300, 600, 120]); + } + + #[test] + fn scan_mixed_sections() { + let mut pkt = + response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + pkt.authorities + .push(ns_record("example.com", "ns1.example.com", 3600)); + pkt.authorities + .push(ns_record("example.com", "ns2.example.com", 3600)); + pkt.resources + .push(a_record("ns1.example.com", "10.0.0.1", 1800)); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + assert_eq!(meta.ttl_offsets.len(), 4); // 1 answer + 2 authority + 1 additional + assert_eq!(meta.answer_count, 1); + } + + #[test] + fn scan_cname_chain() { + let pkt = response( + 0x1234, + "www.example.com", + vec![ + cname_record("www.example.com", "example.com", 300), + a_record("example.com", "1.2.3.4", 600), + ], + ); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + assert_eq!(meta.ttl_offsets.len(), 2); + 
assert_eq!(meta.answer_count, 2); + + let ttls: Vec = meta + .ttl_offsets + .iter() + .map(|&off| u32::from_be_bytes([wire[off], wire[off + 1], wire[off + 2], wire[off + 3]])) + .collect(); + assert_eq!(ttls, vec![300, 600]); + } + + #[test] + fn scan_compressed_names() { + // Build a packet with name compression (the serializer uses compression + // for repeated domain names). Two A records for the same domain will + // have the second name compressed as a pointer. + let pkt = response( + 0x1234, + "example.com", + vec![ + a_record("example.com", "1.2.3.4", 300), + a_record("example.com", "5.6.7.8", 600), + ], + ); + let wire = to_wire(&pkt); + + // Verify compression is actually present (second name should be a pointer) + // The first answer's name is at some offset, and the second should use 0xC0xx + let meta = scan_ttl_offsets(&wire).unwrap(); + assert_eq!(meta.ttl_offsets.len(), 2); + + let ttls: Vec = meta + .ttl_offsets + .iter() + .map(|&off| u32::from_be_bytes([wire[off], wire[off + 1], wire[off + 2], wire[off + 3]])) + .collect(); + assert_eq!(ttls, vec![300, 600]); + } + + #[test] + fn scan_edns_opt_excluded() { + let mut pkt = + response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + pkt.edns = Some(EdnsOpt { + udp_payload_size: 1232, + extended_rcode: 0, + version: 0, + do_bit: false, + options: vec![], + }); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + // Only the A record's TTL, not the OPT pseudo-record's "TTL" + assert_eq!(meta.ttl_offsets.len(), 1); + assert_eq!(meta.answer_count, 1); + } + + #[test] + fn scan_rrsig_only_wire_ttl() { + let mut pkt = + response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + pkt.answers.push(DnsRecord::RRSIG { + domain: "example.com".into(), + type_covered: 1, // A + algorithm: 13, + labels: 2, + original_ttl: 9999, // must NOT appear in offsets + expiration: 1700000000, + inception: 1690000000, + key_tag: 12345, + signer_name: 
"example.com".into(), + signature: vec![0x01, 0x02, 0x03, 0x04], + ttl: 300, + }); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + // 2 TTL offsets: A record + RRSIG wire TTL + assert_eq!(meta.ttl_offsets.len(), 2); + assert_eq!(meta.answer_count, 2); + + // Both wire TTLs should be 300, not 9999 + for &off in &meta.ttl_offsets { + let ttl = + u32::from_be_bytes([wire[off], wire[off + 1], wire[off + 2], wire[off + 3]]); + assert_eq!(ttl, 300); + } + + // Verify that 9999 (original_ttl) exists somewhere in the wire but is NOT in offsets + let original_ttl_bytes = 9999u32.to_be_bytes(); + let found_at = wire + .windows(4) + .position(|w| w == original_ttl_bytes) + .expect("original_ttl should be in wire"); + assert!( + !meta.ttl_offsets.contains(&found_at), + "original_ttl offset must not be in ttl_offsets" + ); + } + + #[test] + fn scan_nsec_variable_rdata() { + let mut pkt = + response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + pkt.authorities.push(DnsRecord::NSEC { + domain: "example.com".into(), + next_domain: "z.example.com".into(), + type_bitmap: vec![0x00, 0x06, 0x40, 0x01, 0x00, 0x00, 0x00, 0x03], + ttl: 1800, + }); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + assert_eq!(meta.ttl_offsets.len(), 2); // A + NSEC + assert_eq!(meta.answer_count, 1); + + let nsec_ttl_off = meta.ttl_offsets[1]; + let ttl = u32::from_be_bytes([ + wire[nsec_ttl_off], + wire[nsec_ttl_off + 1], + wire[nsec_ttl_off + 2], + wire[nsec_ttl_off + 3], + ]); + assert_eq!(ttl, 1800); + } + + #[test] + fn scan_empty_response() { + let pkt = response(0x1234, "nxdomain.example.com", vec![]); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + assert!(meta.ttl_offsets.is_empty()); + assert_eq!(meta.answer_count, 0); + } + + #[test] + fn scan_unknown_record_type() { + // Manually build a response with an unknown type (99) using raw wire bytes + let mut pkt = response(0x1234, 
"example.com", vec![]); + pkt.answers.push(DnsRecord::UNKNOWN { + domain: "example.com".into(), + qtype: 99, + data: vec![0xDE, 0xAD, 0xBE, 0xEF], + ttl: 500, + }); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + assert_eq!(meta.ttl_offsets.len(), 1); + let off = meta.ttl_offsets[0]; + let ttl = u32::from_be_bytes([wire[off], wire[off + 1], wire[off + 2], wire[off + 3]]); + assert_eq!(ttl, 500); + } + + #[test] + fn scan_truncated_wire_returns_error() { + let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let wire = to_wire(&pkt); + // Truncate mid-record + let truncated = &wire[..wire.len() - 2]; + assert!(scan_ttl_offsets(truncated).is_err()); + } + + #[test] + fn scan_too_short_for_header() { + assert!(scan_ttl_offsets(&[0u8; 5]).is_err()); + } + + #[test] + fn scan_query_packet_no_offsets() { + let pkt = DnsPacket::query(0x1234, "example.com", QueryType::A); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + assert!(meta.ttl_offsets.is_empty()); + } + + // ── B. 
TTL patching ───────────────────────────────────────────── + + #[test] + fn patch_ttl_single() { + let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let mut wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + patch_ttls(&mut wire, &meta.ttl_offsets, 120); + + let off = meta.ttl_offsets[0]; + assert_eq!( + u32::from_be_bytes([wire[off], wire[off + 1], wire[off + 2], wire[off + 3]]), + 120 + ); + } + + #[test] + fn patch_ttl_multiple() { + let pkt = response( + 0x1234, + "example.com", + vec![ + a_record("example.com", "1.2.3.4", 300), + a_record("example.com", "5.6.7.8", 600), + a_record("example.com", "9.10.11.12", 900), + ], + ); + let mut wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + patch_ttls(&mut wire, &meta.ttl_offsets, 42); + + for &off in &meta.ttl_offsets { + assert_eq!( + u32::from_be_bytes([wire[off], wire[off + 1], wire[off + 2], wire[off + 3]]), + 42 + ); + } + } + + #[test] + fn patch_ttl_preserves_other_bytes() { + let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let original = to_wire(&pkt); + let mut patched = original.clone(); + let meta = scan_ttl_offsets(&patched).unwrap(); + + patch_ttls(&mut patched, &meta.ttl_offsets, 120); + + // Every byte outside TTL offsets should be identical + for (i, (&orig, &patc)) in original.iter().zip(patched.iter()).enumerate() { + let in_ttl = meta + .ttl_offsets + .iter() + .any(|&off| i >= off && i < off + 4); + if !in_ttl { + assert_eq!( + orig, patc, + "byte {} changed (outside TTL): orig={:#04x}, patched={:#04x}", + i, orig, patc + ); + } + } + } + + #[test] + fn patch_ttl_zero() { + let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let mut wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + patch_ttls(&mut wire, &meta.ttl_offsets, 0); + + let off = meta.ttl_offsets[0]; + assert_eq!(&wire[off..off + 4], &[0, 0, 0, 
0]); + } + + #[test] + fn patch_ttl_max_u32() { + let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let mut wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + patch_ttls(&mut wire, &meta.ttl_offsets, u32::MAX); + + let off = meta.ttl_offsets[0]; + assert_eq!(&wire[off..off + 4], &[0xFF, 0xFF, 0xFF, 0xFF]); + } + + #[test] + fn patch_ttl_edns_untouched() { + let mut pkt = + response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + pkt.edns = Some(EdnsOpt { + udp_payload_size: 1232, + extended_rcode: 0, + version: 0, + do_bit: true, + options: vec![], + }); + let original = to_wire(&pkt); + let mut patched = original.clone(); + let meta = scan_ttl_offsets(&patched).unwrap(); + + patch_ttls(&mut patched, &meta.ttl_offsets, 42); + + // Only the A record's TTL bytes should differ; everything else + // (including the OPT "TTL" containing the DO bit) must be unchanged. + for (i, (&orig, &patc)) in original.iter().zip(patched.iter()).enumerate() { + let in_ttl = meta + .ttl_offsets + .iter() + .any(|&off| i >= off && i < off + 4); + if !in_ttl { + assert_eq!( + orig, patc, + "byte {} changed (outside TTL): orig={:#04x}, patched={:#04x}", + i, orig, patc + ); + } + } + } + + // ── C. 
ID patching ────────────────────────────────────────────── + + #[test] + fn patch_id_basic() { + let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let mut wire = to_wire(&pkt); + + patch_id(&mut wire, 0xABCD); + assert_eq!(&wire[0..2], &[0xAB, 0xCD]); + } + + #[test] + fn patch_id_preserves_flags() { + let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let original = to_wire(&pkt); + let mut patched = original.clone(); + + patch_id(&mut patched, 0x9999); + + // Bytes 2..12 (flags + counts) unchanged + assert_eq!(&original[2..12], &patched[2..12]); + } + + #[test] + fn patch_id_zero() { + let pkt = response(0xFFFF, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let mut wire = to_wire(&pkt); + + patch_id(&mut wire, 0x0000); + assert_eq!(&wire[0..2], &[0x00, 0x00]); + } + + // ── D. extract_question ───────────────────────────────────────── + + #[test] + fn extract_question_basic() { + let pkt = DnsPacket::query(0x1234, "Example.COM", QueryType::A); + let wire = to_wire(&pkt); + let (domain, qtype) = extract_question(&wire).unwrap(); + + assert_eq!(domain, "example.com"); // lowercased + assert_eq!(qtype, QueryType::A); + } + + #[test] + fn extract_question_aaaa() { + let pkt = DnsPacket::query(0x1234, "rust-lang.org", QueryType::AAAA); + let wire = to_wire(&pkt); + let (domain, qtype) = extract_question(&wire).unwrap(); + + assert_eq!(domain, "rust-lang.org"); + assert_eq!(qtype, QueryType::AAAA); + } + + #[test] + fn extract_question_too_short() { + assert!(extract_question(&[0u8; 5]).is_err()); + } + + #[test] + fn extract_question_no_questions() { + let mut wire = to_wire(&DnsPacket::query(0x1234, "example.com", QueryType::A)); + // Zero out QDCOUNT (bytes 4-5) + wire[4] = 0; + wire[5] = 0; + assert!(extract_question(&wire).is_err()); + } + + // ── E. 
min_ttl_from_wire ──────────────────────────────────────── + + #[test] + fn min_ttl_answers_only() { + let mut pkt = response( + 0x1234, + "example.com", + vec![ + a_record("example.com", "1.2.3.4", 300), + a_record("example.com", "5.6.7.8", 60), + ], + ); + pkt.authorities + .push(ns_record("example.com", "ns1.example.com", 10)); // lower but in authority, not answer + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + assert_eq!(min_ttl_from_wire(&wire, &meta), Some(60)); // from answers only + } + + #[test] + fn min_ttl_empty_answers() { + let pkt = response(0x1234, "example.com", vec![]); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + assert_eq!(min_ttl_from_wire(&wire, &meta), None); + } + + // ── F. Round-trip fidelity ────────────────────────────────────── + // + // These verify that wire bytes → scan → patch → parse produces the + // same semantic content as the original packet. They test the full + // integration path that the wire-level cache will use. 
+ + #[test] + fn round_trip_simple_a() { + let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + let mut patched = wire.clone(); + patch_id(&mut patched, 0xABCD); + patch_ttls(&mut patched, &meta.ttl_offsets, 120); + + // Parse the patched wire + let mut buf = BytePacketBuffer::from_bytes(&patched); + let parsed = DnsPacket::from_buffer(&mut buf).unwrap(); + + assert_eq!(parsed.header.id, 0xABCD); + assert_eq!(parsed.answers.len(), 1); + match &parsed.answers[0] { + DnsRecord::A { domain, addr, ttl } => { + assert_eq!(domain, "example.com"); + assert_eq!(*addr, "1.2.3.4".parse::().unwrap()); + assert_eq!(*ttl, 120); + } + other => panic!("expected A record, got {:?}", other), + } + } + + #[test] + fn round_trip_edns_survives() { + let mut pkt = + response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + pkt.edns = Some(EdnsOpt { + udp_payload_size: 1232, + extended_rcode: 0, + version: 0, + do_bit: true, + options: vec![], + }); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + let mut patched = wire.clone(); + patch_ttls(&mut patched, &meta.ttl_offsets, 42); + + let mut buf = BytePacketBuffer::from_bytes(&patched); + let parsed = DnsPacket::from_buffer(&mut buf).unwrap(); + + let edns = parsed.edns.as_ref().expect("EDNS should survive"); + assert_eq!(edns.udp_payload_size, 1232); + assert!(edns.do_bit); + } + + #[test] + fn round_trip_dnssec_full() { + let mut pkt = response( + 0x1234, + "example.com", + vec![ + a_record("example.com", "1.2.3.4", 300), + DnsRecord::RRSIG { + domain: "example.com".into(), + type_covered: 1, + algorithm: 13, + labels: 2, + original_ttl: 300, + expiration: 1700000000, + inception: 1690000000, + key_tag: 12345, + signer_name: "example.com".into(), + signature: vec![1, 2, 3, 4, 5, 6, 7, 8], + ttl: 300, + }, + ], + ); + pkt.authorities.push(DnsRecord::NSEC { + domain: 
"example.com".into(), + next_domain: "z.example.com".into(), + type_bitmap: vec![0x00, 0x06, 0x40, 0x01, 0x00, 0x00, 0x00, 0x03], + ttl: 300, + }); + pkt.resources.push(DnsRecord::DNSKEY { + domain: "example.com".into(), + flags: 257, + protocol: 3, + algorithm: 13, + public_key: vec![10, 20, 30, 40], + ttl: 3600, + }); + pkt.edns = Some(EdnsOpt { + udp_payload_size: 1232, + extended_rcode: 0, + version: 0, + do_bit: true, + options: vec![], + }); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + // 4 TTL offsets: A + RRSIG (answers) + NSEC (authority) + DNSKEY (additional) + // OPT excluded + assert_eq!(meta.ttl_offsets.len(), 4); + assert_eq!(meta.answer_count, 2); + + let mut patched = wire.clone(); + patch_ttls(&mut patched, &meta.ttl_offsets, 42); + + let mut buf = BytePacketBuffer::from_bytes(&patched); + let parsed = DnsPacket::from_buffer(&mut buf).unwrap(); + + assert_eq!(parsed.answers.len(), 2); + assert_eq!(parsed.authorities.len(), 1); + assert_eq!(parsed.resources.len(), 1); + assert!(parsed.edns.is_some()); + + // All TTLs should be 42 now + for ans in &parsed.answers { + assert_eq!(ans.ttl(), 42); + } + for auth in &parsed.authorities { + assert_eq!(auth.ttl(), 42); + } + for res in &parsed.resources { + assert_eq!(res.ttl(), 42); + } + + // RRSIG original_ttl must be preserved (it's inside RDATA, not a wire TTL) + match &parsed.answers[1] { + DnsRecord::RRSIG { original_ttl, .. 
} => assert_eq!(*original_ttl, 300), + other => panic!("expected RRSIG, got {:?}", other), + } + } + + #[test] + fn round_trip_nxdomain_soa() { + let mut pkt = DnsPacket::new(); + pkt.header.id = 0x5678; + pkt.header.response = true; + pkt.header.rescode = ResultCode::NXDOMAIN; + pkt.questions + .push(DnsQuestion::new("missing.example.com".into(), QueryType::A)); + // SOA in authority (we don't have a SOA variant, so use NS as proxy for offset testing) + pkt.authorities + .push(ns_record("example.com", "ns1.example.com", 900)); + + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + assert_eq!(meta.ttl_offsets.len(), 1); + assert_eq!(meta.answer_count, 0); // no answers, only authority + + let mut patched = wire.clone(); + patch_id(&mut patched, 0x9999); + patch_ttls(&mut patched, &meta.ttl_offsets, 60); + + let mut buf = BytePacketBuffer::from_bytes(&patched); + let parsed = DnsPacket::from_buffer(&mut buf).unwrap(); + + assert_eq!(parsed.header.id, 0x9999); + assert_eq!(parsed.header.rescode, ResultCode::NXDOMAIN); + assert_eq!(parsed.authorities[0].ttl(), 60); + } + + #[test] + fn round_trip_mx_record() { + let pkt = response( + 0x1234, + "example.com", + vec![mx_record("example.com", "mail.example.com", 10, 3600)], + ); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + let mut patched = wire.clone(); + patch_ttls(&mut patched, &meta.ttl_offsets, 100); + + let mut buf = BytePacketBuffer::from_bytes(&patched); + let parsed = DnsPacket::from_buffer(&mut buf).unwrap(); + + match &parsed.answers[0] { + DnsRecord::MX { + domain, + priority, + host, + ttl, + } => { + assert_eq!(domain, "example.com"); + assert_eq!(*priority, 10); + assert_eq!(host, "mail.example.com"); + assert_eq!(*ttl, 100); + } + other => panic!("expected MX, got {:?}", other), + } + } + + #[test] + fn round_trip_many_records() { + let answers: Vec = (0..20) + .map(|i| a_record("example.com", &format!("10.0.0.{}", i), 300 + i * 10)) + 
.collect(); + let pkt = response(0x1234, "example.com", answers); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + assert_eq!(meta.ttl_offsets.len(), 20); + + let mut patched = wire.clone(); + patch_ttls(&mut patched, &meta.ttl_offsets, 1); + + let mut buf = BytePacketBuffer::from_bytes(&patched); + let parsed = DnsPacket::from_buffer(&mut buf).unwrap(); + + assert_eq!(parsed.answers.len(), 20); + for ans in &parsed.answers { + assert_eq!(ans.ttl(), 1); + } + } + + // ── G. Edge cases ─────────────────────────────────────────────── + + #[test] + fn scan_rejects_empty_wire() { + assert!(scan_ttl_offsets(&[]).is_err()); + } + + #[test] + fn extract_question_rejects_empty_wire() { + assert!(extract_question(&[]).is_err()); + } + + // ── H. Cache behavior tests ───────────────────────────────────── + // + // These test existing DnsCache behavior that must be preserved after + // the wire-level migration. They use the current parsed-packet API + // and serve as a regression suite. 
+ + #[test] + fn cache_insert_lookup_hit() { + let mut cache = DnsCache::new(100, 1, 3600); + let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + cache.insert("example.com", QueryType::A, &pkt); + + let (result, status) = cache + .lookup_with_status("example.com", QueryType::A) + .expect("should hit"); + assert_eq!(result.answers.len(), 1); + assert_eq!(status, DnssecStatus::Indeterminate); + } + + #[test] + fn cache_lookup_adjusts_ttl() { + let mut cache = DnsCache::new(100, 1, 3600); + let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + cache.insert("example.com", QueryType::A, &pkt); + + let (result, _) = cache.lookup_with_status("example.com", QueryType::A).unwrap(); + // TTL should be <= 300 (at most original, reduced by elapsed time) + assert!(result.answers[0].ttl() <= 300); + assert!(result.answers[0].ttl() > 0); + } + + #[test] + fn cache_miss_wrong_domain() { + let mut cache = DnsCache::new(100, 1, 3600); + let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + cache.insert("example.com", QueryType::A, &pkt); + + assert!(cache + .lookup_with_status("other.com", QueryType::A) + .is_none()); + } + + #[test] + fn cache_miss_wrong_qtype() { + let mut cache = DnsCache::new(100, 1, 3600); + let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + cache.insert("example.com", QueryType::A, &pkt); + + assert!(cache + .lookup_with_status("example.com", QueryType::AAAA) + .is_none()); + } + + #[test] + fn cache_overwrite_no_double_count() { + let mut cache = DnsCache::new(100, 1, 3600); + let pkt1 = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let pkt2 = response(0x5678, "example.com", vec![a_record("example.com", "5.6.7.8", 600)]); + + cache.insert("example.com", QueryType::A, &pkt1); + assert_eq!(cache.len(), 1); + + cache.insert("example.com", QueryType::A, &pkt2); + 
+        assert_eq!(cache.len(), 1); // no double count
+
+        let (result, _) = cache.lookup_with_status("example.com", QueryType::A).unwrap();
+        match &result.answers[0] {
+            DnsRecord::A { addr, .. } => {
+                assert_eq!(*addr, "5.6.7.8".parse::<Ipv4Addr>().unwrap())
+            }
+            _ => panic!("expected A record"),
+        }
+    }
+
+    #[test]
+    fn cache_ttl_clamped_min() {
+        let mut cache = DnsCache::new(100, 60, 3600);
+        let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 5)]);
+        cache.insert("example.com", QueryType::A, &pkt);
+
+        let (remaining, total) = cache.ttl_remaining("example.com", QueryType::A).unwrap();
+        assert_eq!(total, 60); // clamped up from 5
+        assert!(remaining <= 60);
+    }
+
+    #[test]
+    fn cache_ttl_clamped_max() {
+        let mut cache = DnsCache::new(100, 1, 3600);
+        let pkt =
+            response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 999999)]);
+        cache.insert("example.com", QueryType::A, &pkt);
+
+        let (_, total) = cache.ttl_remaining("example.com", QueryType::A).unwrap();
+        assert_eq!(total, 3600); // clamped down from 999999
+    }
+
+    #[test]
+    fn cache_len_empty_clear() {
+        let mut cache = DnsCache::new(100, 1, 3600);
+        assert!(cache.is_empty());
+        assert_eq!(cache.len(), 0);
+
+        let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]);
+        cache.insert("example.com", QueryType::A, &pkt);
+        assert!(!cache.is_empty());
+        assert_eq!(cache.len(), 1);
+
+        cache.clear();
+        assert!(cache.is_empty());
+        assert_eq!(cache.len(), 0);
+        assert!(cache.lookup("example.com", QueryType::A).is_none());
+    }
+
+    #[test]
+    fn cache_remove_domain() {
+        let mut cache = DnsCache::new(100, 1, 3600);
+        let pkt_a = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]);
+        let pkt_aaaa = response(
+            0x5678,
+            "example.com",
+            vec![aaaa_record("example.com", "::1", 300)],
+        );
+        cache.insert("example.com", QueryType::A, &pkt_a);
+        cache.insert("example.com", QueryType::AAAA, &pkt_aaaa);
+        assert_eq!(cache.len(), 2);
+ + cache.remove("example.com"); + assert_eq!(cache.len(), 0); + assert!(cache.lookup("example.com", QueryType::A).is_none()); + assert!(cache.lookup("example.com", QueryType::AAAA).is_none()); + } + + #[test] + fn cache_list_entries() { + let mut cache = DnsCache::new(100, 1, 3600); + let pkt_a = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let pkt_b = response(0x5678, "test.org", vec![a_record("test.org", "5.6.7.8", 600)]); + cache.insert("example.com", QueryType::A, &pkt_a); + cache.insert("test.org", QueryType::A, &pkt_b); + + let list = cache.list(); + assert_eq!(list.len(), 2); + let domains: Vec<&str> = list.iter().map(|e| e.domain.as_str()).collect(); + assert!(domains.contains(&"example.com")); + assert!(domains.contains(&"test.org")); + } + + #[test] + fn cache_heap_bytes_grows() { + let mut cache = DnsCache::new(100, 1, 3600); + let empty = cache.heap_bytes(); + + let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + cache.insert("example.com", QueryType::A, &pkt); + assert!(cache.heap_bytes() > empty); + } + + #[test] + fn cache_needs_warm_behavior() { + let mut cache = DnsCache::new(100, 1, 3600); + + // Missing → needs warm + assert!(cache.needs_warm("example.com")); + + // Both A and AAAA cached → does not need warm + let pkt_a = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let pkt_aaaa = response( + 0x5678, + "example.com", + vec![aaaa_record("example.com", "::1", 300)], + ); + cache.insert("example.com", QueryType::A, &pkt_a); + cache.insert("example.com", QueryType::AAAA, &pkt_aaaa); + assert!(!cache.needs_warm("example.com")); + + // Only A cached → needs warm (AAAA missing) + cache.remove("example.com"); + cache.insert("example.com", QueryType::A, &pkt_a); + assert!(cache.needs_warm("example.com")); + } + + #[test] + fn cache_ttl_remaining_api() { + let mut cache = DnsCache::new(100, 60, 3600); + 
assert!(cache.ttl_remaining("missing.com", QueryType::A).is_none()); + + let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + cache.insert("example.com", QueryType::A, &pkt); + let (remaining, total) = cache.ttl_remaining("example.com", QueryType::A).unwrap(); + assert_eq!(total, 300); + assert!(remaining > 0); + assert!(remaining <= 300); + } + + #[test] + fn cache_dnssec_status_preserved() { + let mut cache = DnsCache::new(100, 1, 3600); + let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + cache.insert_with_status("example.com", QueryType::A, &pkt, DnssecStatus::Secure); + + let (_, status) = cache + .lookup_with_status("example.com", QueryType::A) + .unwrap(); + assert_eq!(status, DnssecStatus::Secure); + } + + // ── I. Memory footprint baseline ────────────────────────────── + // + // Measures the current parsed-packet cache memory vs what wire-level + // storage would cost for the same entries. This is a baseline — after + // migration, re-run to verify improvement. 
+
+    #[test]
+    fn memory_footprint_baseline() {
+        let mut cache = DnsCache::new(1000, 1, 3600);
+
+        // Simulate a realistic cache: 50 domains, mix of record types
+        let domains: Vec<String> = (0..50).map(|i| format!("domain{}.example.com", i)).collect();
+
+        let mut total_wire_bytes = 0usize;
+        let mut total_wire_meta_bytes = 0usize;
+
+        for (i, domain) in domains.iter().enumerate() {
+            // A record
+            let pkt_a = response(
+                i as u16,
+                domain,
+                vec![
+                    a_record(domain, &format!("10.0.{}.1", i % 256), 300),
+                    a_record(domain, &format!("10.0.{}.2", i % 256), 300),
+                ],
+            );
+            cache.insert(domain, QueryType::A, &pkt_a);
+
+            let wire_a = to_wire(&pkt_a);
+            let meta_a = scan_ttl_offsets(&wire_a).unwrap();
+            total_wire_bytes += wire_a.len();
+            total_wire_meta_bytes += meta_a.ttl_offsets.len() * std::mem::size_of::<usize>();
+
+            // AAAA record for half of them
+            if i % 2 == 0 {
+                let pkt_aaaa = response(
+                    (i + 1000) as u16,
+                    domain,
+                    vec![aaaa_record(domain, &format!("2001:db8::{:x}", i), 600)],
+                );
+                cache.insert(domain, QueryType::AAAA, &pkt_aaaa);
+
+                let wire_aaaa = to_wire(&pkt_aaaa);
+                let meta_aaaa = scan_ttl_offsets(&wire_aaaa).unwrap();
+                total_wire_bytes += wire_aaaa.len();
+                total_wire_meta_bytes +=
+                    meta_aaaa.ttl_offsets.len() * std::mem::size_of::<usize>();
+            }
+        }
+
+        // Compare only the variable per-entry data (what actually differs
+        // between parsed and wire storage). HashMap overhead, domain keys,
+        // Instant, Duration, DnssecStatus are identical in both approaches.
+        let mut parsed_data_bytes = 0usize;
+        // Re-insert and measure just packet.heap_bytes() per entry
+        {
+            let mut cache2 = DnsCache::new(1000, 1, 3600);
+            for (i, domain) in domains.iter().enumerate() {
+                let pkt_a = response(
+                    i as u16,
+                    domain,
+                    vec![
+                        a_record(domain, &format!("10.0.{}.1", i % 256), 300),
+                        a_record(domain, &format!("10.0.{}.2", i % 256), 300),
+                    ],
+                );
+                parsed_data_bytes += pkt_a.heap_bytes();
+                cache2.insert(domain, QueryType::A, &pkt_a);
+
+                if i % 2 == 0 {
+                    let pkt_aaaa = response(
+                        (i + 1000) as u16,
+                        domain,
+                        vec![aaaa_record(domain, &format!("2001:db8::{:x}", i), 600)],
+                    );
+                    parsed_data_bytes += pkt_aaaa.heap_bytes();
+                    cache2.insert(domain, QueryType::AAAA, &pkt_aaaa);
+                }
+            }
+        }
+
+        let wire_total = total_wire_bytes + total_wire_meta_bytes;
+        let entry_count = cache.len();
+
+        // Also measure the struct size difference per entry
+        let parsed_struct = std::mem::size_of::<DnsPacket>();
+        let wire_struct = std::mem::size_of::<Vec<u8>>() + std::mem::size_of::<Vec<usize>>() + std::mem::size_of::<usize>(); // wire + offsets + answer_count
+
+        println!();
+        println!("=== Cache Memory Footprint Baseline ({} entries) ===", entry_count);
+        println!();
+        println!("Variable data (heap, per-entry payload):");
+        println!("  Parsed (packet.heap_bytes): {} bytes ({:.1}/entry)", parsed_data_bytes, parsed_data_bytes as f64 / entry_count as f64);
+        println!("  Wire (bytes + TTL offsets): {} bytes ({:.1}/entry)", wire_total, wire_total as f64 / entry_count as f64);
+        println!("  Ratio: {:.1}x smaller with wire", parsed_data_bytes as f64 / wire_total as f64);
+        println!();
+        println!("Struct overhead (stack, per entry):");
+        println!("  DnsPacket: {} bytes", parsed_struct);
+        println!("  Wire (Vec<u8>+Vec<usize>+usize): {} bytes", wire_struct);
+        println!();
+        println!("Total per-entry (struct + avg heap):");
+        let parsed_total_per = parsed_struct as f64 + parsed_data_bytes as f64 / entry_count as f64;
+        let wire_total_per = wire_struct as f64 + wire_total as f64 / entry_count as f64;
+
println!(" Parsed: {:.0} bytes", parsed_total_per); + println!(" Wire: {:.0} bytes", wire_total_per); + println!(" Ratio: {:.1}x smaller with wire", parsed_total_per / wire_total_per); + println!(); + + // Assertions + assert!( + wire_total < parsed_data_bytes, + "wire data ({wire_total}) should be smaller than parsed data ({parsed_data_bytes})" + ); + } + + #[test] + fn cache_max_entries_cap() { + let mut cache = DnsCache::new(2, 1, 3600); + for i in 0..3 { + let domain = format!("test{}.com", i); + let pkt = response( + i as u16, + &domain, + vec![a_record(&domain, &format!("1.2.3.{}", i), 3600)], + ); + cache.insert(&domain, QueryType::A, &pkt); + } + // Should not exceed max (third insert is silently dropped or evicts) + assert!(cache.len() <= 2); + } +} diff --git a/tests/integration.sh b/tests/integration.sh index 92da878..c70ec59 100755 --- a/tests/integration.sh +++ b/tests/integration.sh @@ -53,7 +53,17 @@ CONF echo "Starting Numa on :$PORT ($SUITE_NAME)..." RUST_LOG=info "$BINARY" "$CONFIG" > "$LOG" 2>&1 & NUMA_PID=$! - sleep 4 + sleep 2 + + # Wait for blocklist to load (if blocking is enabled in this suite) + if echo "$SUITE_CONFIG" | grep -q 'enabled = true'; then + for i in $(seq 1 20); do + LOADED=$(curl -sf http://127.0.0.1:$API_PORT/blocking/stats 2>/dev/null \ + | grep -o '"domains_loaded":[0-9]*' | cut -d: -f2) + if [ "${LOADED:-0}" -gt 0 ]; then break; fi + sleep 1 + done + fi if ! 
kill -0 "$NUMA_PID" 2>/dev/null; then echo "Failed to start Numa:" From 5d9a3a809b4bf7e85b3243efa69293cf2f0e399f Mon Sep 17 00:00:00 2001 From: Razvan Dimescu Date: Sun, 12 Apr 2026 06:22:42 +0300 Subject: [PATCH 02/21] feat: DoT client, recursive optimization, bench refactor MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add DoT forwarding client (tls://IP#hostname upstream config) - Recursive: cache NS delegations, serve-stale (RFC 8767), parallel NS queries on cold, no TCP fallback on individual UDP timeouts, 400ms NS/TCP timeout (down from 800/1500ms) - Reduce recursive p99 from 2367ms to 402ms (vs Unbound's 148ms) - Refactor benchmark suite: generic compare_two engine, delete one-off diagnostics (1969 → 750 lines) - Code cleanup: forward_query delegates to _raw, Option for tls_name, saturating_sub for ns_idx --- Cargo.lock | 2 +- benches/numa-bench.toml | 10 +- benches/recursive_compare.rs | 2060 ++++++++++++---------------------- src/forward.rs | 53 +- src/recursive.rs | 29 +- 5 files changed, 754 insertions(+), 1400 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index eaba214..c0f7692 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1358,7 +1358,7 @@ dependencies = [ "tokio-rustls", "toml", "tower", - "webpki-roots", + "webpki-roots 1.0.6", ] [[package]] diff --git a/benches/numa-bench.toml b/benches/numa-bench.toml index 0e058af..6124840 100644 --- a/benches/numa-bench.toml +++ b/benches/numa-bench.toml @@ -5,7 +5,8 @@ api_bind_addr = "127.0.0.1" data_dir = "/tmp/numa-bench" [upstream] -mode = "recursive" +mode = "forward" +address = ["https://9.9.9.9/dns-query"] timeout_ms = 10000 [cache] @@ -15,8 +16,13 @@ max_ttl = 3600 [blocking] enabled = false +[proxy] +port = 8080 +tls_port = 8443 + [dot] -enabled = false +enabled = true +port = 8530 [mobile] enabled = false diff --git a/benches/recursive_compare.rs b/benches/recursive_compare.rs index e35768c..12f3689 100644 --- a/benches/recursive_compare.rs +++ 
b/benches/recursive_compare.rs @@ -1,20 +1,18 @@ -//! DoH forwarding benchmark: Numa vs hickory-resolver. +//! DNS forwarding benchmark suite. //! -//! Both forward to the same DoH upstream (Quad9). -//! Measures end-to-end resolution time through each implementation. -//! -//! Fairness: -//! - Both reuse a single TLS connection (Numa via persistent server, -//! Hickory via a shared resolver instance with cache_size=0). -//! - Measurement order is alternated each round to cancel order bias. -//! - Numa cache is flushed before each query. -//! - 100 domains × 10 rounds for statistical confidence. +//! Modes: +//! (default) Numa server (UDP) vs Hickory library (DoH) — the original benchmark +//! --diag Hickory connection reuse diagnostic (20 queries) +//! --diag-clients Per-query reqwest vs Hickory timing (20 queries) +//! --direct Library-to-library: Numa forward_query_raw vs Hickory resolver.lookup +//! --hedge-5x Hedging: single vs hedge-same vs hedge-dual vs Hickory (5 iterations) +//! --vs-unbound Server-to-server: Numa vs Unbound (plain UDP, caching) +//! --vs-dot DoT server: Numa vs Unbound +//! --vs-doh-servers DoH server: Numa vs Unbound (DoT upstream) //! //! Setup: -//! 1. Start a bench Numa instance: -//! cargo run -- benches/numa-bench.toml -//! 2. Run: -//! cargo bench --bench recursive_compare +//! 1. Start a bench Numa instance: cargo run -- benches/numa-bench.toml +//! 2. 
Run: cargo bench --bench recursive_compare [-- --mode] use std::net::SocketAddr; use std::time::{Duration, Instant}; @@ -130,216 +128,585 @@ const DOMAINS: &[&str] = &[ const ROUNDS: usize = 10; fn main() { - let diag = std::env::args().any(|a| a == "--diag"); - let direct = std::env::args().any(|a| a == "--direct"); + let arg = |flag: &str| std::env::args().any(|a| a == flag); let rt = tokio::runtime::Runtime::new().unwrap(); - if diag { - run_diag(&rt); - return; + if arg("--diag") { + return run_diag(&rt); + } + if arg("--diag-clients") { + return run_diag_clients(&rt); + } + if arg("--direct") { + return run_direct(&rt); + } + if arg("--hedge-5x") { + return run_hedge_multi(&rt, 5); + } + if arg("--vs-unbound") { + return run_server_comparison(&rt, "Unbound", "127.0.0.1:5456", 5); + } + if arg("--vs-dnscrypt") { + return run_server_comparison(&rt, "dnscrypt-proxy", "127.0.0.1:5455", 5); + } + if arg("--vs-dot") { + return run_dot_comparison(&rt, 5); + } + if arg("--vs-doh-servers") { + return run_doh_comparison(&rt, 5); } - if direct { - run_direct(&rt); - return; + // Default: Numa server (UDP) vs Hickory library (DoH) + run_default(&rt); +} + +// ── Generic 2-way comparison engine ───────────────────────────── + +fn compare_two( + rt: &tokio::runtime::Runtime, + title: &str, + name_a: &str, + name_b: &str, + measure_a: &dyn Fn(&str) -> f64, + measure_b: &dyn Fn(&str) -> f64, + iterations: usize, +) { + let flush = std::env::args().any(|a| a == "--flush"); + println!("{}", title); + println!( + "{} domains × {} rounds × {} iterations\n", + DOMAINS.len(), + ROUNDS, + iterations + ); + + let mut all_a = Vec::new(); + let mut all_b = Vec::new(); + let mut iter_stats: Vec<[(f64, f64, f64, f64, f64); 2]> = Vec::new(); + + for iter in 1..=iterations { + println!(" iteration {}/{}...", iter, iterations); + let mut a = Vec::new(); + let mut b = Vec::new(); + + for domain in DOMAINS { + for round in 0..ROUNDS { + if flush { + flush_cache(); + 
std::thread::sleep(Duration::from_millis(5)); + } + if round % 2 == 0 { + a.push(measure_a(domain)); + b.push(measure_b(domain)); + } else { + b.push(measure_b(domain)); + a.push(measure_a(domain)); + } + } + } + + iter_stats.push([stats(&mut a), stats(&mut b)]); + all_a.extend_from_slice(&a); + all_b.extend_from_slice(&b); } - if std::env::args().any(|a| a == "--diag-clients") { - run_diag_clients(&rt); - return; + print_results( + name_a, + name_b, + &iter_stats, + &mut all_a, + &mut all_b, + iterations, + ); +} + +fn print_results( + name_a: &str, + name_b: &str, + iter_stats: &[[(f64, f64, f64, f64, f64); 2]], + all_a: &mut Vec, + all_b: &mut Vec, + iterations: usize, +) { + let w = name_a.len().max(name_b.len()).max(6); + + println!("\n=== Per-iteration medians ==="); + println!("{:<8} {:>w$} {:>w$}", "iter", name_a, name_b, w = w + 3); + for (i, s) in iter_stats.iter().enumerate() { + println!( + "{:<8} {:>w$.1} ms {:>w$.1} ms", + i + 1, + s[0].1, + s[1].1, + w = w + ); } - if std::env::args().any(|a| a == "--spike-trace") { - run_spike_trace(&rt); - return; + println!("\n=== Per-iteration p99 ==="); + println!("{:<8} {:>w$} {:>w$}", "iter", name_a, name_b, w = w + 3); + for (i, s) in iter_stats.iter().enumerate() { + println!( + "{:<8} {:>w$.1} ms {:>w$.1} ms", + i + 1, + s[0].3, + s[1].3, + w = w + ); } - if std::env::args().any(|a| a == "--spike-phases") { - run_spike_phases(&rt); - return; - } + let (a_m, a_med, a_p95, a_p99, a_sd) = stats(all_a); + let (b_m, b_med, b_p95, b_p99, b_sd) = stats(all_b); - if std::env::args().any(|a| a == "--spike-heartbeat") { - run_spike_heartbeat(&rt); - return; - } + let total = iterations * DOMAINS.len() * ROUNDS; + println!("\n=== Aggregated ({} samples per method) ===\n", total); + println!("{:<10} {:>w$} {:>w$}", "", name_a, name_b, w = w + 3); + println!("{:<10} {:>w$.1} ms {:>w$.1} ms", "mean", a_m, b_m, w = w); + println!( + "{:<10} {:>w$.1} ms {:>w$.1} ms", + "median", + a_med, + b_med, + w = w + ); + println!( + 
"{:<10} {:>w$.1} ms {:>w$.1} ms", + "p95", + a_p95, + b_p95, + w = w + ); + println!( + "{:<10} {:>w$.1} ms {:>w$.1} ms", + "p99", + a_p99, + b_p99, + w = w + ); + println!("{:<10} {:>w$.1} ms {:>w$.1} ms", "σ", a_sd, b_sd, w = w); - if std::env::args().any(|a| a == "--hedge") { - run_hedge(&rt); - return; - } + let pct = |a: f64, b: f64| { + if b.abs() > 0.001 { + (a - b) / b * 100.0 + } else { + 0.0 + } + }; + println!("\n{} vs {}:", name_a, name_b); + println!(" mean: {:+.1} ms ({:+.0}%)", a_m - b_m, pct(a_m, b_m)); + println!( + " median: {:+.1} ms ({:+.0}%)", + a_med - b_med, + pct(a_med, b_med) + ); + println!( + " p99: {:+.1} ms ({:+.0}%)", + a_p99 - b_p99, + pct(a_p99, b_p99) + ); +} - if std::env::args().any(|a| a == "--hedge-5x") { - run_hedge_multi(&rt, 5); - return; - } - - if std::env::args().any(|a| a == "--vs-dnscrypt") { - run_vs_dnscrypt(&rt, 5); - return; - } - - if std::env::args().any(|a| a == "--vs-unbound") { - run_vs_unbound(&rt, 5); - return; - } +// ── Modes ─────────────────────────────────────────────────────── +/// Default: Numa server (UDP) vs Hickory library (DoH), cache flushed. 
+fn run_default(rt: &tokio::runtime::Runtime) { let numa_addr: SocketAddr = NUMA_BENCH.parse().unwrap(); - - println!("DoH Forwarding Benchmark: Numa vs hickory-resolver"); - println!("Both forwarding to {DOH_UPSTREAM}"); - println!("{} domains × {ROUNDS} rounds", DOMAINS.len()); - println!(); - - // Verify bench Numa is reachable if rt.block_on(query_udp(numa_addr, "example.com")).is_none() { eprintln!("Bench Numa not responding on {numa_addr}"); - eprintln!(); - eprintln!("Start it with:"); - eprintln!(" cargo run -- benches/numa-bench.toml"); + eprintln!("Start with: cargo run -- benches/numa-bench.toml"); std::process::exit(1); } - // Build a shared Hickory resolver (reuses TLS connection, like Numa does) let resolver = rt.block_on(build_hickory_resolver()); - // Warm up both paths (TLS handshake, connection establishment) - println!("Warming up connections..."); + println!("Warming up..."); for _ in 0..3 { rt.block_on(query_udp(numa_addr, "example.com")); rt.block_on(query_hickory_doh(&resolver, "example.com")); } flush_cache(); - println!( - "{:<30} {:>10} {:>10} {:>10} {:>8} {:>8}", - "Domain", "Numa (ms)", "Hickory", "Delta", "σ Numa", "σ Hick" - ); - println!("{}", "-".repeat(92)); - - let mut numa_all = Vec::new(); - let mut hickory_all = Vec::new(); - let mut per_domain: Vec<(&str, f64, f64, f64, f64, f64)> = Vec::new(); - - for domain in DOMAINS { - let mut numa_times = Vec::with_capacity(ROUNDS); - let mut hickory_times = Vec::with_capacity(ROUNDS); - - for round in 0..ROUNDS { + compare_two( + rt, + &format!("DoH Forwarding: Numa server vs Hickory library\nBoth → {DOH_UPSTREAM}"), + "Numa", + "Hickory", + &|domain| { flush_cache(); std::thread::sleep(Duration::from_millis(10)); + let t = Instant::now(); + let _ = rt.block_on(query_udp(numa_addr, domain)); + t.elapsed().as_secs_f64() * 1000.0 + }, + &|domain| { + let t = Instant::now(); + let _ = rt.block_on(query_hickory_doh(&resolver, domain)); + t.elapsed().as_secs_f64() * 1000.0 + }, + 1, + ); +} - 
// Alternate measurement order each round to cancel systematic bias - if round % 2 == 0 { - // Numa first - let t = measure(&rt, || rt.block_on(query_udp(numa_addr, domain))); - numa_times.push(t); - let t = measure(&rt, || rt.block_on(query_hickory_doh(&resolver, domain))); - hickory_times.push(t); - } else { - // Hickory first - let t = measure(&rt, || rt.block_on(query_hickory_doh(&resolver, domain))); - hickory_times.push(t); - flush_cache(); - std::thread::sleep(Duration::from_millis(10)); - let t = measure(&rt, || rt.block_on(query_udp(numa_addr, domain))); - numa_times.push(t); +/// Library-to-library: Numa forward_query_raw vs Hickory resolver.lookup. +fn run_direct(rt: &tokio::runtime::Runtime) { + let upstream = numa::forward::parse_upstream(DOH_UPSTREAM, 443).expect("failed to parse"); + let resolver = rt.block_on(build_hickory_resolver()); + let timeout = Duration::from_secs(10); + + println!("Warming up..."); + for _ in 0..3 { + let w = build_query_vec("example.com"); + let _ = rt.block_on(numa::forward::forward_query_raw(&w, &upstream, timeout)); + let _ = rt.block_on(query_hickory_doh(&resolver, "example.com")); + } + + compare_two( + rt, + &format!("Direct DoH: Numa forward_query_raw vs Hickory resolver.lookup\nBoth → {DOH_UPSTREAM}, no server pipeline"), + "Numa", "Hickory", + &|domain| { + let w = build_query_vec(domain); + let t = Instant::now(); + let _ = rt.block_on(numa::forward::forward_query_raw(&w, &upstream, timeout)); + t.elapsed().as_secs_f64() * 1000.0 + }, + &|domain| { + let t = Instant::now(); + let _ = rt.block_on(query_hickory_doh(&resolver, domain)); + t.elapsed().as_secs_f64() * 1000.0 + }, + 5, + ); +} + +/// Server-to-server: Numa vs another server, both on plain UDP. 
+fn run_server_comparison( + rt: &tokio::runtime::Runtime, + other_name: &str, + other_addr: &str, + iterations: usize, +) { + let numa_addr: SocketAddr = NUMA_BENCH.parse().unwrap(); + let other: SocketAddr = other_addr.parse().unwrap(); + + for (name, addr) in [("Numa", numa_addr), (other_name, other)] { + if rt.block_on(query_udp(addr, "example.com")).is_none() { + eprintln!("{name} not responding on {addr}"); + std::process::exit(1); + } + } + + println!("Warming up..."); + for _ in 0..5 { + let _ = rt.block_on(query_udp(numa_addr, "example.com")); + let _ = rt.block_on(query_udp(other, "example.com")); + } + + compare_two( + rt, + &format!("Server-to-Server: Numa vs {other_name} (UDP, caching)"), + "Numa", + other_name, + &|domain| { + let t = Instant::now(); + let _ = rt.block_on(query_udp(numa_addr, domain)); + t.elapsed().as_secs_f64() * 1000.0 + }, + &|domain| { + let t = Instant::now(); + let _ = rt.block_on(query_udp(other, domain)); + t.elapsed().as_secs_f64() * 1000.0 + }, + iterations, + ); +} + +/// DoT server comparison: Numa vs Unbound. 
+fn run_dot_comparison(rt: &tokio::runtime::Runtime, iterations: usize) { + const NUMA_DOT: &str = "127.0.0.1:8530"; + const UNBOUND_DOT: &str = "127.0.0.1:8531"; + + let _ = rustls::crypto::ring::default_provider().install_default(); + let tls_config = build_insecure_tls_config(); + + for (name, addr) in [("Numa", NUMA_DOT), ("Unbound", UNBOUND_DOT)] { + match rt.block_on(query_dot_once(addr, "example.com", &tls_config)) { + Ok(_) => println!("{name} DoT: OK"), + Err(e) => { + eprintln!("{name} DoT not responding on {addr}: {e}"); + std::process::exit(1); + } + } + } + + println!("Warming up..."); + for _ in 0..3 { + let _ = rt.block_on(query_dot_once(NUMA_DOT, "example.com", &tls_config)); + let _ = rt.block_on(query_dot_once(UNBOUND_DOT, "example.com", &tls_config)); + } + + compare_two( + rt, + "DoT Server: Numa vs Unbound (both DoT→clients, forwarding to Quad9)", + "Numa", + "Unbound", + &|domain| { + let t = Instant::now(); + let _ = rt.block_on(query_dot_once(NUMA_DOT, domain, &tls_config)); + t.elapsed().as_secs_f64() * 1000.0 + }, + &|domain| { + let t = Instant::now(); + let _ = rt.block_on(query_dot_once(UNBOUND_DOT, domain, &tls_config)); + t.elapsed().as_secs_f64() * 1000.0 + }, + iterations, + ); +} + +/// DoH server comparison: Numa vs Unbound (both DoH→clients, DoT upstream). 
+fn run_doh_comparison(rt: &tokio::runtime::Runtime, iterations: usize) { + const NUMA_DOH: &str = "https://127.0.0.1:8443/dns-query"; + const UNBOUND_DOH: &str = "https://127.0.0.1:8445/dns-query"; + + let client = reqwest::Client::builder() + .use_rustls_tls() + .danger_accept_invalid_certs(true) + .http2_initial_stream_window_size(65_535) + .http2_initial_connection_window_size(65_535) + .pool_idle_timeout(Duration::from_secs(300)) + .build() + .unwrap(); + + for (name, url, host) in [ + ("Numa", NUMA_DOH, Some("numa.numa")), + ("Unbound", UNBOUND_DOH, None), + ] { + let w = build_query_vec("example.com"); + match rt.block_on(query_doh_server(&client, url, &w, host)) { + Ok(_) => println!("{name} DoH: OK"), + Err(e) => { + eprintln!("{name} DoH not responding: {e}"); + std::process::exit(1); + } + } + } + + println!("Warming up..."); + for _ in 0..5 { + let w = build_query_vec("example.com"); + let _ = rt.block_on(query_doh_server(&client, NUMA_DOH, &w, Some("numa.numa"))); + let _ = rt.block_on(query_doh_server(&client, UNBOUND_DOH, &w, None)); + } + + compare_two( + rt, + "DoH Server: Numa vs Unbound (both DoH→clients, DoT upstream)", + "Numa", + "Unbound", + &|domain| { + let w = build_query_vec(domain); + let t = Instant::now(); + let _ = rt.block_on(query_doh_server(&client, NUMA_DOH, &w, Some("numa.numa"))); + t.elapsed().as_secs_f64() * 1000.0 + }, + &|domain| { + let w = build_query_vec(domain); + let t = Instant::now(); + let _ = rt.block_on(query_doh_server(&client, UNBOUND_DOH, &w, None)); + t.elapsed().as_secs_f64() * 1000.0 + }, + iterations, + ); +} + +/// Hedging: single vs hedge-same vs hedge-dual vs Hickory. +/// This is the one mode that compares 4 contenders, not 2. 
+fn run_hedge_multi(rt: &tokio::runtime::Runtime, iterations: usize) {
+    let hedge_delay = Duration::from_millis(10);
+    let timeout = Duration::from_secs(10);
+
+    println!("Hedging Benchmark × {iterations} iterations");
+    println!("Upstream: {DOH_UPSTREAM}");
+    println!("Hedge delay: {hedge_delay:?}");
+    println!(
+        "{} domains × {ROUNDS} rounds per iteration\n",
+        DOMAINS.len()
+    );
+
+    let primary = numa::forward::parse_upstream(DOH_UPSTREAM, 443).expect("failed to parse");
+    let primary_dual = numa::forward::parse_upstream(DOH_UPSTREAM, 443).expect("failed to parse");
+    let secondary_dual = numa::forward::parse_upstream(DOH_UPSTREAM, 443).expect("failed to parse");
+    let resolver = rt.block_on(build_hickory_resolver());
+
+    println!("Warming up...");
+    for _ in 0..5 {
+        let w = build_query_vec("example.com");
+        let _ = rt.block_on(numa::forward::forward_query_raw(&w, &primary, timeout));
+        let _ = rt.block_on(numa::forward::forward_query_raw(&w, &primary_dual, timeout));
+        let _ = rt.block_on(numa::forward::forward_query_raw(
+            &w,
+            &secondary_dual,
+            timeout,
+        ));
+        let _ = rt.block_on(query_hickory_doh(&resolver, "example.com"));
+    }
+
+    let labels = ["Single", "Hedge-same", "Hedge-dual", "Hickory"];
+    let mut all: [Vec<f64>; 4] = [vec![], vec![], vec![], vec![]];
+    let mut iter_medians: Vec<[f64; 4]> = vec![];
+    let mut iter_p99s: Vec<[f64; 4]> = vec![];
+
+    for iter in 1..=iterations {
+        println!("  iteration {iter}/{iterations}...");
+        let mut samples: [Vec<f64>; 4] = [vec![], vec![], vec![], vec![]];
+
+        for domain in DOMAINS {
+            for _ in 0..ROUNDS {
+                let w = build_query_vec(domain);
+
+                let t = Instant::now();
+                let _ = rt.block_on(numa::forward::forward_query_raw(&w, &primary, timeout));
+                samples[0].push(t.elapsed().as_secs_f64() * 1000.0);
+
+                let t = Instant::now();
+                let _ = rt.block_on(numa::forward::forward_with_hedging_raw(
+                    &w,
+                    &primary,
+                    &primary,
+                    hedge_delay,
+                    timeout,
+                ));
+                samples[1].push(t.elapsed().as_secs_f64() * 1000.0);
+
+                let t = 
Instant::now(); + let _ = rt.block_on(numa::forward::forward_with_hedging_raw( + &w, + &primary_dual, + &secondary_dual, + hedge_delay, + timeout, + )); + samples[2].push(t.elapsed().as_secs_f64() * 1000.0); + + let t = Instant::now(); + let _ = rt.block_on(query_hickory_doh(&resolver, domain)); + samples[3].push(t.elapsed().as_secs_f64() * 1000.0); } } - let numa_avg = mean(&numa_times); - let hickory_avg = mean(&hickory_times); - let numa_sd = stddev(&numa_times); - let hickory_sd = stddev(&hickory_times); - let delta = numa_avg - hickory_avg; + let s: Vec<_> = samples.iter_mut().map(|v| stats(v)).collect(); + iter_medians.push([s[0].1, s[1].1, s[2].1, s[3].1]); + iter_p99s.push([s[0].3, s[1].3, s[2].3, s[3].3]); + for (i, v) in samples.iter().enumerate() { + all[i].extend_from_slice(v); + } + } - numa_all.extend_from_slice(&numa_times); - hickory_all.extend_from_slice(&hickory_times); - per_domain.push((domain, numa_avg, hickory_avg, delta, numa_sd, hickory_sd)); - - let delta_str = format_delta(delta); + println!("\n=== Per-iteration medians ==="); + println!( + "{:<8} {:>10} {:>12} {:>12} {:>10}", + "iter", labels[0], labels[1], labels[2], labels[3] + ); + for (i, m) in iter_medians.iter().enumerate() { println!( - "{:<30} {:>7.1} ms {:>7.1} ms {:>7} ms {:>5.1}ms {:>5.1}ms", - domain, numa_avg, hickory_avg, delta_str, numa_sd, hickory_sd + "{:<8} {:>7.1} ms {:>9.1} ms {:>9.1} ms {:>7.1} ms", + i + 1, + m[0], + m[1], + m[2], + m[3] ); } - println!("{}", "-".repeat(92)); - - let numa_mean = mean(&numa_all); - let hickory_mean = mean(&hickory_all); - let delta_mean = numa_mean - hickory_mean; - + println!("\n=== Per-iteration p99 ==="); println!( - "{:<30} {:>7.1} ms {:>7.1} ms {:>7} ms {:>5.1}ms {:>5.1}ms", - "OVERALL MEAN", - numa_mean, - hickory_mean, - format_delta(delta_mean), - stddev(&numa_all), - stddev(&hickory_all), + "{:<8} {:>10} {:>12} {:>12} {:>10}", + "iter", labels[0], labels[1], labels[2], labels[3] ); - - // Median - let numa_med = median(&mut 
numa_all); - let hickory_med = median(&mut hickory_all); - println!( - "{:<30} {:>7.1} ms {:>7.1} ms {:>7} ms", - "MEDIAN", - numa_med, - hickory_med, - format_delta(numa_med - hickory_med), - ); - - // P95 - let numa_p95 = percentile(&numa_all, 95.0); - let hickory_p95 = percentile(&hickory_all, 95.0); - println!( - "{:<30} {:>7.1} ms {:>7.1} ms {:>7} ms", - "P95", - numa_p95, - hickory_p95, - format_delta(numa_p95 - hickory_p95), - ); - - println!(); - let total_queries = DOMAINS.len() * ROUNDS; - if numa_mean < hickory_mean { - let pct = ((hickory_mean - numa_mean) / hickory_mean * 100.0).round(); - println!("Numa is ~{pct}% faster (mean over {total_queries} queries)."); - } else if hickory_mean < numa_mean { - let pct = ((numa_mean - hickory_mean) / numa_mean * 100.0).round(); - println!("Hickory is ~{pct}% faster (mean over {total_queries} queries)."); - } else { - println!("Both are equal (mean over {total_queries} queries)."); + for (i, p) in iter_p99s.iter().enumerate() { + println!( + "{:<8} {:>7.1} ms {:>9.1} ms {:>9.1} ms {:>7.1} ms", + i + 1, + p[0], + p[1], + p[2], + p[3] + ); } - println!(); - println!("Methodology:"); - println!(" - Both forward to {DOH_UPSTREAM} over a reused TLS connection."); - println!(" - Numa cache flushed before each query. 
Hickory cache disabled."); - println!(" - Measurement order alternates each round to cancel order bias."); - println!(" - {} domains × {ROUNDS} rounds = {total_queries} queries per resolver.", DOMAINS.len()); + let s: Vec<_> = all + .iter_mut() + .map(|v| { + let (m, med, p95, p99, sd) = stats(v); + [m, med, p95, p99, sd] + }) + .collect(); + let total = iterations * DOMAINS.len() * ROUNDS; + println!("\n=== Aggregated ({total} samples per method) ===\n"); + println!( + "{:<10} {:>10} {:>12} {:>12} {:>10}", + "", labels[0], labels[1], labels[2], labels[3] + ); + for (row, idx) in [("mean", 0), ("median", 1), ("p95", 2), ("p99", 3), ("σ", 4)] { + println!( + "{:<10} {:>7.1} ms {:>9.1} ms {:>9.1} ms {:>7.1} ms", + row, s[0][idx], s[1][idx], s[2][idx], s[3][idx] + ); + } + + let pct = |a: f64, b: f64| { + if b.abs() > 0.001 { + (a - b) / b * 100.0 + } else { + 0.0 + } + }; + println!( + "\nHedge-same vs Single: mean {:+.0}%, p95 {:+.0}%, p99 {:+.0}%", + pct(s[1][0], s[0][0]), + pct(s[1][2], s[0][2]), + pct(s[1][3], s[0][3]) + ); + println!( + "Hedge-same vs Hickory: mean {:+.0}%, p95 {:+.0}%, p99 {:+.0}%", + pct(s[1][0], s[3][0]), + pct(s[1][2], s[3][2]), + pct(s[1][3], s[3][3]) + ); } +// ── Diagnostics (small, kept for debugging) ───────────────────── + fn run_diag(rt: &tokio::runtime::Runtime) { - println!("Hickory connection reuse diagnostic"); - println!("20 sequential queries to {DOH_UPSTREAM} via one shared resolver"); - println!("If conn is reused: query 1 slow (TLS handshake), rest fast.\n"); + println!("Hickory connection reuse diagnostic\n20 queries to {DOH_UPSTREAM}\n"); let resolver = rt.block_on(build_hickory_resolver()); - let domains = [ - "example.com", "rust-lang.org", "kernel.org", "google.com", "github.com", - "example.com", "rust-lang.org", "kernel.org", "google.com", "github.com", - "example.com", "rust-lang.org", "kernel.org", "google.com", "github.com", - "example.com", "rust-lang.org", "kernel.org", "google.com", "github.com", + "example.com", 
+ "rust-lang.org", + "kernel.org", + "google.com", + "github.com", + "example.com", + "rust-lang.org", + "kernel.org", + "google.com", + "github.com", + "example.com", + "rust-lang.org", + "kernel.org", + "google.com", + "github.com", + "example.com", + "rust-lang.org", + "kernel.org", + "google.com", + "github.com", ]; println!("{:>3} {:<20} {:>10}", "#", "Domain", "Time (ms)"); println!("{}", "-".repeat(40)); - for (i, domain) in domains.iter().enumerate() { use hickory_resolver::proto::rr::RecordType; let start = Instant::now(); @@ -347,143 +714,31 @@ fn run_diag(rt: &tokio::runtime::Runtime) { let ms = start.elapsed().as_secs_f64() * 1000.0; match &result { Ok(lookup) => { - let first = lookup.iter().next().map(|r| format!("{r}")).unwrap_or_default(); - println!("{:>3} {:<20} {:>7.1} ms OK {}", i + 1, domain, ms, first); - } - Err(e) => { - println!("{:>3} {:<20} {:>7.1} ms ERR {}", i + 1, domain, ms, e); + let first = lookup + .iter() + .next() + .map(|r| format!("{r}")) + .unwrap_or_default(); + println!( + "{:>3} {:<20} {:>7.1} ms OK {}", + i + 1, + domain, + ms, + first + ); } + Err(e) => println!("{:>3} {:<20} {:>7.1} ms ERR {}", i + 1, domain, ms, e), } } } -/// Library-to-library comparison: Numa's forward_query_raw vs Hickory's resolver.lookup(). -/// No UDP, no server pipeline — just the DoH forwarding call. 
-fn run_direct(rt: &tokio::runtime::Runtime) { - println!("Direct DoH Forwarding: Numa forward_query_raw vs Hickory resolver.lookup()"); - println!("Both forwarding to {DOH_UPSTREAM} — no UDP, no server pipeline"); - println!("{} domains × {ROUNDS} rounds", DOMAINS.len()); - println!(); - - // Build Numa's upstream (shared reqwest client, reuses HTTP/2 connection) - let numa_upstream = - numa::forward::parse_upstream(DOH_UPSTREAM, 443).expect("failed to parse upstream"); - let timeout = Duration::from_secs(10); - - // Build Hickory's resolver (shared, reuses HTTP/2 connection) - let resolver = rt.block_on(build_hickory_resolver()); - - // Warm up both - println!("Warming up connections..."); - for _ in 0..3 { - let wire = build_query_vec("example.com"); - let _ = rt.block_on(numa::forward::forward_query_raw(&wire, &numa_upstream, timeout)); - let _ = rt.block_on(query_hickory_doh(&resolver, "example.com")); - } - - println!( - "{:<30} {:>10} {:>10} {:>10} {:>8} {:>8}", - "Domain", "Numa (ms)", "Hickory", "Delta", "σ Numa", "σ Hick" - ); - println!("{}", "-".repeat(92)); - - let mut numa_all = Vec::new(); - let mut hickory_all = Vec::new(); - - for domain in DOMAINS { - let mut numa_times = Vec::with_capacity(ROUNDS); - let mut hickory_times = Vec::with_capacity(ROUNDS); - - for round in 0..ROUNDS { - let wire = build_query_vec(domain); - - if round % 2 == 0 { - let w = wire.clone(); - let t = measure(rt, || { - rt.block_on(numa::forward::forward_query_raw(&w, &numa_upstream, timeout)) - }); - numa_times.push(t); - let t = measure(rt, || rt.block_on(query_hickory_doh(&resolver, domain))); - hickory_times.push(t); - } else { - let t = measure(rt, || rt.block_on(query_hickory_doh(&resolver, domain))); - hickory_times.push(t); - let w = wire.clone(); - let t = measure(rt, || { - rt.block_on(numa::forward::forward_query_raw(&w, &numa_upstream, timeout)) - }); - numa_times.push(t); - } - } - - let numa_avg = mean(&numa_times); - let hickory_avg = mean(&hickory_times); - 
let numa_sd = stddev(&numa_times); - let hickory_sd = stddev(&hickory_times); - let delta = numa_avg - hickory_avg; - - numa_all.extend_from_slice(&numa_times); - hickory_all.extend_from_slice(&hickory_times); - - println!( - "{:<30} {:>7.1} ms {:>7.1} ms {:>7} ms {:>5.1}ms {:>5.1}ms", - domain, numa_avg, hickory_avg, format_delta(delta), numa_sd, hickory_sd - ); - } - - println!("{}", "-".repeat(92)); - let numa_mean = mean(&numa_all); - let hickory_mean = mean(&hickory_all); - println!( - "{:<30} {:>7.1} ms {:>7.1} ms {:>7} ms {:>5.1}ms {:>5.1}ms", - "OVERALL MEAN", numa_mean, hickory_mean, format_delta(numa_mean - hickory_mean), - stddev(&numa_all), stddev(&hickory_all), - ); - let numa_med = median(&mut numa_all); - let hickory_med = median(&mut hickory_all); - println!( - "{:<30} {:>7.1} ms {:>7.1} ms {:>7} ms", - "MEDIAN", numa_med, hickory_med, format_delta(numa_med - hickory_med), - ); - let numa_p95 = percentile(&numa_all, 95.0); - let hickory_p95 = percentile(&hickory_all, 95.0); - println!( - "{:<30} {:>7.1} ms {:>7.1} ms {:>7} ms", - "P95", numa_p95, hickory_p95, format_delta(numa_p95 - hickory_p95), - ); - - println!(); - let total_queries = DOMAINS.len() * ROUNDS; - if numa_mean < hickory_mean { - let pct = ((hickory_mean - numa_mean) / hickory_mean * 100.0).round(); - println!("Numa is ~{pct}% faster (mean over {total_queries} queries)."); - } else if hickory_mean < numa_mean { - let pct = ((numa_mean - hickory_mean) / numa_mean * 100.0).round(); - println!("Hickory is ~{pct}% faster (mean over {total_queries} queries)."); - } else { - println!("Both are equal (mean over {total_queries} queries)."); - } - - println!(); - println!("Methodology:"); - println!(" - Both forward to {DOH_UPSTREAM} over a reused TLS/HTTP2 connection."); - println!(" - No UDP, no server pipeline, no cache — pure DoH forwarding."); - println!(" - Numa: forward_query_raw (reqwest). 
Hickory: resolver.lookup (h2)."); - println!(" - {} domains × {ROUNDS} rounds = {total_queries} queries per implementation.", DOMAINS.len()); -} - -/// Per-query timing diagnostic: 20 queries each through reqwest and Hickory. -/// Shows whether reqwest has connection reuse issues or per-request overhead. fn run_diag_clients(rt: &tokio::runtime::Runtime) { - println!("Client diagnostic: reqwest vs Hickory per-query timing"); - println!("20 queries each to {DOH_UPSTREAM}\n"); + println!("Client diagnostic: reqwest vs Hickory (20 queries to {DOH_UPSTREAM})\n"); - let upstream = - numa::forward::parse_upstream(DOH_UPSTREAM, 443).expect("failed to parse upstream"); + let upstream = numa::forward::parse_upstream(DOH_UPSTREAM, 443).expect("failed to parse"); let resolver = rt.block_on(build_hickory_resolver()); let timeout = Duration::from_secs(10); - // Warm both for _ in 0..3 { let w = build_query_vec("example.com"); let _ = rt.block_on(numa::forward::forward_query_raw(&w, &upstream, timeout)); @@ -491,18 +746,35 @@ fn run_diag_clients(rt: &tokio::runtime::Runtime) { } let domains = [ - "example.com", "google.com", "github.com", "rust-lang.org", "cloudflare.com", - "example.com", "google.com", "github.com", "rust-lang.org", "cloudflare.com", - "example.com", "google.com", "github.com", "rust-lang.org", "cloudflare.com", - "example.com", "google.com", "github.com", "rust-lang.org", "cloudflare.com", + "example.com", + "google.com", + "github.com", + "rust-lang.org", + "cloudflare.com", + "example.com", + "google.com", + "github.com", + "rust-lang.org", + "cloudflare.com", + "example.com", + "google.com", + "github.com", + "rust-lang.org", + "cloudflare.com", + "example.com", + "google.com", + "github.com", + "rust-lang.org", + "cloudflare.com", ]; - println!("{:>3} {:<20} {:>12} {:>12}", "#", "Domain", "reqwest", "Hickory"); + println!( + "{:>3} {:<20} {:>12} {:>12}", + "#", "Domain", "reqwest", "Hickory" + ); println!("{}", "-".repeat(55)); - for (i, domain) in 
domains.iter().enumerate() { let wire = build_query_vec(domain); - let start = Instant::now(); let r_result = rt.block_on(numa::forward::forward_query_raw(&wire, &upstream, timeout)); let r_ms = start.elapsed().as_secs_f64() * 1000.0; @@ -515,1076 +787,104 @@ fn run_diag_clients(rt: &tokio::runtime::Runtime) { println!( "{:>3} {:<20} {:>7.1} ms {} {:>7.1} ms {}", - i + 1, domain, r_ms, r_ok, h_ms, h_ok - ); - } -} - -/// Spike trace: fire 200 sequential queries through reqwest and log every one -/// with a timestamp. Analyze the distribution and find spike clusters. -fn run_spike_trace(rt: &tokio::runtime::Runtime) { - println!("Spike trace: 200 sequential reqwest DoH queries"); - println!("Target: {DOH_UPSTREAM}\n"); - - let upstream = - numa::forward::parse_upstream(DOH_UPSTREAM, 443).expect("failed to parse upstream"); - let timeout = Duration::from_secs(10); - - // Warm - for _ in 0..5 { - let w = build_query_vec("example.com"); - let _ = rt.block_on(numa::forward::forward_query_raw(&w, &upstream, timeout)); - } - - // Run the entire 200-query loop inside ONE block_on to eliminate - // per-query runtime re-entry overhead. 
- let samples: Vec<(u128, f64)> = rt.block_on(async { - let test_start = Instant::now(); - let mut s = Vec::with_capacity(200); - for i in 0..200 { - let domain = match i % 5 { - 0 => "example.com", - 1 => "google.com", - 2 => "github.com", - 3 => "rust-lang.org", - _ => "cloudflare.com", - }; - let wire = build_query_vec(domain); - let req_start = Instant::now(); - let t_from_start_us = test_start.elapsed().as_micros(); - let _ = numa::forward::forward_query_raw(&wire, &upstream, timeout).await; - let ms = req_start.elapsed().as_secs_f64() * 1000.0; - s.push((t_from_start_us, ms)); - } - s - }); - - // Compute stats - let mut sorted_times: Vec = samples.iter().map(|(_, t)| *t).collect(); - sorted_times.sort_by(|a, b| a.partial_cmp(b).unwrap()); - let n = sorted_times.len(); - let median = sorted_times[n / 2]; - let p90 = sorted_times[(n * 90) / 100]; - let p95 = sorted_times[(n * 95) / 100]; - let p99 = sorted_times[(n * 99) / 100]; - let max = sorted_times[n - 1]; - let mean: f64 = sorted_times.iter().sum::() / n as f64; - - println!("Distribution (n={}):", n); - println!(" mean: {:.1} ms", mean); - println!(" median: {:.1} ms", median); - println!(" p90: {:.1} ms", p90); - println!(" p95: {:.1} ms", p95); - println!(" p99: {:.1} ms", p99); - println!(" max: {:.1} ms", max); - println!(); - - // Define spike threshold as 3x median - let spike_threshold = median * 3.0; - let spikes: Vec<(usize, u128, f64)> = samples - .iter() - .enumerate() - .filter(|(_, (_, t))| *t > spike_threshold) - .map(|(i, (ts, t))| (i, *ts, *t)) - .collect(); - - println!("Spikes (> {:.1}ms, which is 3x median):", spike_threshold); - println!(" count: {}", spikes.len()); - if spikes.is_empty() { - return; - } - - // Inter-spike gaps (time between spikes) - let mut gaps_ms: Vec = Vec::new(); - for w in spikes.windows(2) { - let gap_us = w[1].1 - w[0].1; - gaps_ms.push(gap_us as f64 / 1000.0); - } - - println!(); - println!(" {:>4} {:>12} {:>10} {:>12}", "idx", "at (ms)", "latency", "gap 
from prev"); - for (i, ((idx, ts, latency), gap)) in spikes.iter().zip( - std::iter::once(&0.0).chain(gaps_ms.iter()) - ).enumerate() { - let _ = i; - let gap_str = if *gap > 0.0 { - format!("{:.0} ms", gap) - } else { - "-".to_string() - }; - println!(" {:>4} {:>9.1} {:>6.1} ms {:>12}", idx, *ts as f64 / 1000.0, latency, gap_str); - } - - if !gaps_ms.is_empty() { - let gap_mean: f64 = gaps_ms.iter().sum::() / gaps_ms.len() as f64; - let mut gap_sorted = gaps_ms.clone(); - gap_sorted.sort_by(|a, b| a.partial_cmp(b).unwrap()); - let gap_median = gap_sorted[gap_sorted.len() / 2]; - println!(); - println!(" Inter-spike gap: mean={:.0}ms, median={:.0}ms", gap_mean, gap_median); - } -} - -/// Spike phases: time each step of the reqwest DoH call to find which phase -/// is slow during a spike. Reports (build+send, send->resp headers, body read). -fn run_spike_phases(rt: &tokio::runtime::Runtime) { - println!("Spike phases: timing each phase of reqwest DoH call"); - println!("Target: {DOH_UPSTREAM}\n"); - - // Build the same tuned client our forward_doh uses - let client = reqwest::Client::builder() - .use_rustls_tls() - .http2_initial_stream_window_size(65_535) - .http2_initial_connection_window_size(65_535) - .http2_keep_alive_interval(Duration::from_secs(15)) - .http2_keep_alive_while_idle(true) - .http2_keep_alive_timeout(Duration::from_secs(10)) - .pool_idle_timeout(Duration::from_secs(300)) - .pool_max_idle_per_host(1) - .build() - .unwrap(); - - // Warm up - for _ in 0..5 { - let wire = build_query_vec("example.com"); - let _ = rt.block_on(async { - client - .post(DOH_UPSTREAM) - .header("content-type", "application/dns-message") - .header("accept", "application/dns-message") - .body(wire) - .send() - .await - .ok()? 
- .bytes() - .await - .ok() - }); - } - - println!("{:>4} {:>8} {:>8} {:>8} {:>8}", "idx", "total", "build", "send", "body"); - println!("{}", "-".repeat(50)); - - let samples: Vec<(f64, f64, f64, f64)> = rt.block_on(async { - let mut s = Vec::with_capacity(200); - for i in 0..200 { - let domain = match i % 5 { - 0 => "example.com", - 1 => "google.com", - 2 => "github.com", - 3 => "rust-lang.org", - _ => "cloudflare.com", - }; - let wire = build_query_vec(domain); - - let t0 = Instant::now(); - // Phase 1: build the request - let req = client - .post(DOH_UPSTREAM) - .header("content-type", "application/dns-message") - .header("accept", "application/dns-message") - .body(wire); - let t1 = Instant::now(); - // Phase 2: send() — this is the dispatch channel + round trip to headers - let resp_result = req.send().await; - let t2 = Instant::now(); - // Phase 3: read body - let body_result = match resp_result { - Ok(r) => r.bytes().await.ok().map(|b| b.len()), - Err(_) => None, - }; - let t3 = Instant::now(); - - let build_ms = (t1 - t0).as_secs_f64() * 1000.0; - let send_ms = (t2 - t1).as_secs_f64() * 1000.0; - let body_ms = (t3 - t2).as_secs_f64() * 1000.0; - let total_ms = (t3 - t0).as_secs_f64() * 1000.0; - - s.push((total_ms, build_ms, send_ms, body_ms)); - let _ = body_result; - } - s - }); - - // Compute distribution on total - let mut totals: Vec = samples.iter().map(|s| s.0).collect(); - totals.sort_by(|a, b| a.partial_cmp(b).unwrap()); - let median = totals[100]; - - // Print spikes (> 3x median) with phase breakdown - for (i, (total, build, send, body)) in samples.iter().enumerate() { - if *total > median * 3.0 { - println!( - "{:>4} {:>5.1} ms {:>5.1} ms {:>5.1} ms {:>5.1} ms", - i, total, build, send, body - ); - } - } - - // Summary: mean of each phase for spikes vs non-spikes - let (spike_samples, normal_samples): (Vec<_>, Vec<_>) = samples - .iter() - .partition(|(t, _, _, _)| *t > median * 3.0); - - let phase_means = |samples: &[&(f64, f64, f64, f64)]| -> 
(f64, f64, f64, f64) { - let n = samples.len() as f64; - if n == 0.0 { return (0.0, 0.0, 0.0, 0.0); } - let total: f64 = samples.iter().map(|s| s.0).sum::() / n; - let build: f64 = samples.iter().map(|s| s.1).sum::() / n; - let send: f64 = samples.iter().map(|s| s.2).sum::() / n; - let body: f64 = samples.iter().map(|s| s.3).sum::() / n; - (total, build, send, body) - }; - - let spike_refs: Vec<&(f64, f64, f64, f64)> = spike_samples.iter().copied().collect(); - let normal_refs: Vec<&(f64, f64, f64, f64)> = normal_samples.iter().copied().collect(); - let (s_total, s_build, s_send, s_body) = phase_means(&spike_refs); - let (n_total, n_build, n_send, n_body) = phase_means(&normal_refs); - - println!(); - println!("Summary (mean ms):"); - println!( - " {:<8} {:>8} {:>8} {:>8} {:>8}", - "", "total", "build", "send", "body" - ); - println!( - " {:<8} {:>5.1} ms {:>5.1} ms {:>5.1} ms {:>5.1} ms (n={})", - "normal", n_total, n_build, n_send, n_body, normal_refs.len() - ); - println!( - " {:<8} {:>5.1} ms {:>5.1} ms {:>5.1} ms {:>5.1} ms (n={})", - "spike", s_total, s_build, s_send, s_body, spike_refs.len() - ); - println!(); - println!("Delta (spike - normal):"); - println!( - " build: {:+.1} ms, send: {:+.1} ms, body: {:+.1} ms", - s_build - n_build, - s_send - n_send, - s_body - n_body - ); -} - -/// Heartbeat probe: run a parallel task that ticks every 5ms and records -/// how long each tick actually takes. If the heartbeat stalls during a DoH -/// spike, it's a tokio scheduling issue (runtime can't poll tasks). If -/// heartbeat is fine while send() is stuck, it's internal to hyper/h2. 
-fn run_spike_heartbeat(rt: &tokio::runtime::Runtime) { - use std::sync::{Arc, Mutex}; - - println!("Spike heartbeat probe"); - println!("Running DoH queries + parallel 5ms heartbeat task\n"); - - let upstream = - numa::forward::parse_upstream(DOH_UPSTREAM, 443).expect("failed to parse upstream"); - let timeout = Duration::from_secs(10); - - // Warm up - for _ in 0..5 { - let w = build_query_vec("example.com"); - let _ = rt.block_on(numa::forward::forward_query_raw(&w, &upstream, timeout)); - } - - // Shared vecs: (relative_ms_from_start, event_kind, latency_ms) - // event_kind: 0 = heartbeat, 1 = doh query - type EventLog = Vec<(f64, u8, f64)>; - let events: Arc> = Arc::new(Mutex::new(Vec::with_capacity(2000))); - let stop = Arc::new(std::sync::atomic::AtomicBool::new(false)); - - let test_start = Instant::now(); - - rt.block_on(async { - // Spawn heartbeat task - let hb_events = Arc::clone(&events); - let hb_stop = Arc::clone(&stop); - let hb_start = test_start; - let heartbeat = tokio::spawn(async move { - let mut next_tick = Instant::now(); - let target = Duration::from_millis(5); - while !hb_stop.load(std::sync::atomic::Ordering::Relaxed) { - next_tick += target; - // Sleep until the next scheduled tick - let now = Instant::now(); - if next_tick > now { - tokio::time::sleep(next_tick - now).await; - } - // Measure how much we overshot the scheduled tick - let actual = Instant::now(); - let lag_ms = if actual > next_tick { - (actual - next_tick).as_secs_f64() * 1000.0 - } else { - 0.0 - }; - let t = (actual - hb_start).as_secs_f64() * 1000.0; - if let Ok(mut e) = hb_events.lock() { - e.push((t, 0, lag_ms)); - } - } - }); - - // Run 200 DoH queries and record their timings - for i in 0..200 { - let domain = match i % 5 { - 0 => "example.com", - 1 => "google.com", - 2 => "github.com", - 3 => "rust-lang.org", - _ => "cloudflare.com", - }; - let wire = build_query_vec(domain); - let req_start = Instant::now(); - let _ = numa::forward::forward_query_raw(&wire, 
&upstream, timeout).await; - let elapsed = req_start.elapsed().as_secs_f64() * 1000.0; - let t = (req_start - test_start).as_secs_f64() * 1000.0; - if let Ok(mut e) = events.lock() { - e.push((t, 1, elapsed)); - } - } - - stop.store(true, std::sync::atomic::Ordering::Relaxed); - let _ = heartbeat.await; - }); - - let events = events.lock().unwrap(); - - // Separate heartbeats and doh events - let hb: Vec<(f64, f64)> = events - .iter() - .filter(|(_, k, _)| *k == 0) - .map(|(t, _, l)| (*t, *l)) - .collect(); - let doh: Vec<(f64, f64)> = events - .iter() - .filter(|(_, k, _)| *k == 1) - .map(|(t, _, l)| (*t, *l)) - .collect(); - - // Heartbeat stats - let mut hb_lags: Vec = hb.iter().map(|(_, l)| *l).collect(); - hb_lags.sort_by(|a, b| a.partial_cmp(b).unwrap()); - let hb_n = hb_lags.len(); - let hb_median = hb_lags[hb_n / 2]; - let hb_p95 = hb_lags[(hb_n * 95) / 100]; - let hb_p99 = hb_lags[(hb_n * 99) / 100]; - let hb_max = hb_lags[hb_n - 1]; - - // DoH stats - let mut doh_latencies: Vec = doh.iter().map(|(_, l)| *l).collect(); - doh_latencies.sort_by(|a, b| a.partial_cmp(b).unwrap()); - let doh_n = doh_latencies.len(); - let doh_median = doh_latencies[doh_n / 2]; - let doh_p95 = doh_latencies[(doh_n * 95) / 100]; - let doh_max = doh_latencies[doh_n - 1]; - - println!("Heartbeat lag (tick overshoot, {}ms target):", 5); - println!(" n: {}", hb_n); - println!(" median: {:.2} ms", hb_median); - println!(" p95: {:.2} ms", hb_p95); - println!(" p99: {:.2} ms", hb_p99); - println!(" max: {:.2} ms", hb_max); - println!(); - println!("DoH latency:"); - println!(" n: {}", doh_n); - println!(" median: {:.1} ms", doh_median); - println!(" p95: {:.1} ms", doh_p95); - println!(" max: {:.1} ms", doh_max); - println!(); - - // Find DoH spikes and check heartbeat activity DURING each spike - let doh_spike_threshold = doh_median * 3.0; - let mut spikes_with_hb_lag = 0; - let mut spikes_total = 0; - let mut max_hb_during_any_spike = 0.0_f64; - - println!( - "Correlation: during each 
DoH spike (>{:.1}ms), max heartbeat lag:", - doh_spike_threshold - ); - println!(" {:>6} {:>10} {:>18}", "doh_at", "doh_ms", "max_hb_lag_during"); - - for (doh_t, doh_ms) in &doh { - if *doh_ms > doh_spike_threshold { - spikes_total += 1; - // Find heartbeats that happened during this DoH query - let spike_start = *doh_t; - let spike_end = spike_start + *doh_ms; - let mut max_hb = 0.0_f64; - for (hb_t, hb_lag) in &hb { - if *hb_t >= spike_start && *hb_t <= spike_end + 20.0 { - if *hb_lag > max_hb { - max_hb = *hb_lag; - } - } - } - if max_hb > 5.0 { - spikes_with_hb_lag += 1; - } - max_hb_during_any_spike = max_hb_during_any_spike.max(max_hb); - println!( - " {:>5.0} ms {:>7.1} ms {:>14.2} ms", - doh_t, doh_ms, max_hb - ); - } - } - - println!(); - println!("Conclusion:"); - if spikes_total == 0 { - println!(" No DoH spikes in this run."); - } else { - let pct = (spikes_with_hb_lag as f64 / spikes_total as f64 * 100.0).round(); - println!( - " {}/{} spikes ({:.0}%) had concurrent heartbeat lag >5ms.", - spikes_with_hb_lag, spikes_total, pct - ); - println!(" Max heartbeat lag during any spike: {:.2}ms", max_hb_during_any_spike); - println!(); - if max_hb_during_any_spike > 20.0 { - println!(" → Heartbeat stalls during DoH spikes: tokio scheduling / OS thread issue."); - println!(" The runtime can't poll ANY task — likely QoS demotion, GC pause,"); - println!(" or the worker thread is blocked somewhere."); - } else { - println!(" → Heartbeat runs normally during DoH spikes: internal to hyper/h2."); - println!(" The runtime is fine, but send()'s await is stuck waiting for"); - println!(" the ClientTask to poll the dispatch channel."); - } - } -} - -/// Hedging benchmark: tests four configurations against Hickory. 
-/// Single: 1 client → Quad9 (baseline) -/// Hedge-same: hedge against same client/connection → Quad9 -/// Hedge-dual: hedge against 2 separate clients, both → Quad9 (same upstream, 2 HTTP/2 conns) -/// Hickory: Hickory resolver → Quad9 (reference) -fn run_hedge(rt: &tokio::runtime::Runtime) { - let hedge_delay = Duration::from_millis(10); - - println!("Hedging Benchmark (all paths → Quad9 only)"); - println!("Upstream: {}", DOH_UPSTREAM); - println!("Hedge delay: {:?}", hedge_delay); - println!("{} domains × {} rounds\n", DOMAINS.len(), ROUNDS); - - // Primary and secondary: two separate reqwest clients → same Quad9 URL. - // This gives two independent HTTP/2 connections, so dispatch spikes - // are uncorrelated (at most one stalls at a time). - let primary_same = - numa::forward::parse_upstream(DOH_UPSTREAM, 443).expect("failed to parse primary"); - let primary_dual = - numa::forward::parse_upstream(DOH_UPSTREAM, 443).expect("failed to parse primary_dual"); - let secondary_dual = - numa::forward::parse_upstream(DOH_UPSTREAM, 443).expect("failed to parse secondary_dual"); - let timeout = Duration::from_secs(10); - - let resolver = rt.block_on(build_hickory_resolver()); - - // Warm up all paths (separate connections need their own TLS handshake) - println!("Warming up connections..."); - for _ in 0..5 { - let w = build_query_vec("example.com"); - let _ = rt.block_on(numa::forward::forward_query_raw(&w, &primary_same, timeout)); - let _ = rt.block_on(numa::forward::forward_query_raw(&w, &primary_dual, timeout)); - let _ = rt.block_on(numa::forward::forward_query_raw(&w, &secondary_dual, timeout)); - let _ = rt.block_on(query_hickory_doh(&resolver, "example.com")); - } - - let mut single_all = Vec::new(); - let mut hedge_same_all = Vec::new(); - let mut hedge_dual_all = Vec::new(); - let mut hickory_all = Vec::new(); - - println!( - "{:<24} {:>10} {:>10} {:>10} {:>10}", - "Domain", "Single", "Hedge-same", "Hedge-dual", "Hickory" - ); - println!("{}", 
"-".repeat(78)); - - for domain in DOMAINS { - let mut single_times = Vec::with_capacity(ROUNDS); - let mut hedge_same_times = Vec::with_capacity(ROUNDS); - let mut hedge_dual_times = Vec::with_capacity(ROUNDS); - let mut hickory_times = Vec::with_capacity(ROUNDS); - - for _ in 0..ROUNDS { - let wire = build_query_vec(domain); - - let t = Instant::now(); - let _ = rt.block_on(numa::forward::forward_query_raw(&wire, &primary_same, timeout)); - single_times.push(t.elapsed().as_secs_f64() * 1000.0); - - let t = Instant::now(); - let _ = rt.block_on(numa::forward::forward_with_hedging_raw( - &wire, &primary_same, &primary_same, hedge_delay, timeout, - )); - hedge_same_times.push(t.elapsed().as_secs_f64() * 1000.0); - - let t = Instant::now(); - let _ = rt.block_on(numa::forward::forward_with_hedging_raw( - &wire, &primary_dual, &secondary_dual, hedge_delay, timeout, - )); - hedge_dual_times.push(t.elapsed().as_secs_f64() * 1000.0); - - let t = Instant::now(); - let _ = rt.block_on(query_hickory_doh(&resolver, domain)); - hickory_times.push(t.elapsed().as_secs_f64() * 1000.0); - } - - single_all.extend_from_slice(&single_times); - hedge_same_all.extend_from_slice(&hedge_same_times); - hedge_dual_all.extend_from_slice(&hedge_dual_times); - hickory_all.extend_from_slice(&hickory_times); - - println!( - "{:<24} {:>7.1} ms {:>7.1} ms {:>7.1} ms {:>7.1} ms", + i + 1, domain, - mean(&single_times), - mean(&hedge_same_times), - mean(&hedge_dual_times), - mean(&hickory_times) + r_ms, + r_ok, + h_ms, + h_ok ); } - - println!("{}", "-".repeat(78)); - - let stats = |all: &mut Vec| -> (f64, f64, f64, f64, f64) { - let m = mean(all); - let med = median(all); - let p95 = percentile(all, 95.0); - let p99 = percentile(all, 99.0); - let sd = stddev(all); - (m, med, p95, p99, sd) - }; - - let (s_m, s_med, s_p95, s_p99, s_sd) = stats(&mut single_all); - let (hs_m, hs_med, hs_p95, hs_p99, hs_sd) = stats(&mut hedge_same_all); - let (hd_m, hd_med, hd_p95, hd_p99, hd_sd) = stats(&mut 
hedge_dual_all); - let (k_m, k_med, k_p95, k_p99, k_sd) = stats(&mut hickory_all); - - println!(); - println!( - "{:<10} {:>10} {:>10} {:>10} {:>10}", - "", "Single", "Hedge-same", "Hedge-dual", "Hickory" - ); - println!( - "{:<10} {:>7.1} ms {:>7.1} ms {:>7.1} ms {:>7.1} ms", - "mean", s_m, hs_m, hd_m, k_m - ); - println!( - "{:<10} {:>7.1} ms {:>7.1} ms {:>7.1} ms {:>7.1} ms", - "median", s_med, hs_med, hd_med, k_med - ); - println!( - "{:<10} {:>7.1} ms {:>7.1} ms {:>7.1} ms {:>7.1} ms", - "p95", s_p95, hs_p95, hd_p95, k_p95 - ); - println!( - "{:<10} {:>7.1} ms {:>7.1} ms {:>7.1} ms {:>7.1} ms", - "p99", s_p99, hs_p99, hd_p99, k_p99 - ); - println!( - "{:<10} {:>7.1} ms {:>7.1} ms {:>7.1} ms {:>7.1} ms", - "σ", s_sd, hs_sd, hd_sd, k_sd - ); - - println!(); - println!("Hedge-same improvement over single:"); - println!(" mean: {:+.0}%, p95: {:+.0}%, p99: {:+.0}%", - (hs_m - s_m) / s_m * 100.0, - (hs_p95 - s_p95) / s_p95 * 100.0, - (hs_p99 - s_p99) / s_p99 * 100.0); - println!("Hedge-dual improvement over single:"); - println!(" mean: {:+.0}%, p95: {:+.0}%, p99: {:+.0}%", - (hd_m - s_m) / s_m * 100.0, - (hd_p95 - s_p95) / s_p95 * 100.0, - (hd_p99 - s_p99) / s_p99 * 100.0); } -/// Run the hedging benchmark N times and aggregate samples across all runs. -/// Also reports per-run stats to show drift. 
-fn run_hedge_multi(rt: &tokio::runtime::Runtime, iterations: usize) { - let hedge_delay = Duration::from_millis(10); +// ── Stats helpers ─────────────────────────────────────────────── - println!("Hedging Benchmark × {} iterations", iterations); - println!("Upstream: {}", DOH_UPSTREAM); - println!("Hedge delay: {:?}", hedge_delay); - println!("{} domains × {} rounds per iteration\n", DOMAINS.len(), ROUNDS); - - let primary_same = - numa::forward::parse_upstream(DOH_UPSTREAM, 443).expect("failed to parse"); - let primary_dual = - numa::forward::parse_upstream(DOH_UPSTREAM, 443).expect("failed to parse"); - let secondary_dual = - numa::forward::parse_upstream(DOH_UPSTREAM, 443).expect("failed to parse"); - let timeout = Duration::from_secs(10); - - let resolver = rt.block_on(build_hickory_resolver()); - - // Warm up - println!("Warming up..."); - for _ in 0..5 { - let w = build_query_vec("example.com"); - let _ = rt.block_on(numa::forward::forward_query_raw(&w, &primary_same, timeout)); - let _ = rt.block_on(numa::forward::forward_query_raw(&w, &primary_dual, timeout)); - let _ = rt.block_on(numa::forward::forward_query_raw(&w, &secondary_dual, timeout)); - let _ = rt.block_on(query_hickory_doh(&resolver, "example.com")); +fn stats(v: &mut [f64]) -> (f64, f64, f64, f64, f64) { + if v.is_empty() { + return (0.0, 0.0, 0.0, 0.0, 0.0); } - - // Accumulated samples across all iterations - let mut all_single = Vec::new(); - let mut all_hedge_same = Vec::new(); - let mut all_hedge_dual = Vec::new(); - let mut all_hickory = Vec::new(); - - // Per-iteration summary stats - let mut iter_stats: Vec<[(f64, f64, f64, f64, f64); 4]> = Vec::new(); - - for iter in 1..=iterations { - println!(" iteration {}/{}...", iter, iterations); - - let mut single = Vec::new(); - let mut hedge_same = Vec::new(); - let mut hedge_dual = Vec::new(); - let mut hickory = Vec::new(); - - for domain in DOMAINS { - for _ in 0..ROUNDS { - let wire = build_query_vec(domain); - - let t = Instant::now(); 
- let _ = rt.block_on(numa::forward::forward_query_raw(&wire, &primary_same, timeout)); - single.push(t.elapsed().as_secs_f64() * 1000.0); - - let t = Instant::now(); - let _ = rt.block_on(numa::forward::forward_with_hedging_raw( - &wire, &primary_same, &primary_same, hedge_delay, timeout, - )); - hedge_same.push(t.elapsed().as_secs_f64() * 1000.0); - - let t = Instant::now(); - let _ = rt.block_on(numa::forward::forward_with_hedging_raw( - &wire, &primary_dual, &secondary_dual, hedge_delay, timeout, - )); - hedge_dual.push(t.elapsed().as_secs_f64() * 1000.0); - - let t = Instant::now(); - let _ = rt.block_on(query_hickory_doh(&resolver, domain)); - hickory.push(t.elapsed().as_secs_f64() * 1000.0); - } - } - - let stats = |v: &mut Vec| -> (f64, f64, f64, f64, f64) { - (mean(v), median(v), percentile(v, 95.0), percentile(v, 99.0), stddev(v)) - }; - iter_stats.push([ - stats(&mut single), - stats(&mut hedge_same), - stats(&mut hedge_dual), - stats(&mut hickory), - ]); - - all_single.extend_from_slice(&single); - all_hedge_same.extend_from_slice(&hedge_same); - all_hedge_dual.extend_from_slice(&hedge_dual); - all_hickory.extend_from_slice(&hickory); - } - - println!(); - println!("=== Per-iteration medians (drift check) ==="); - println!( - "{:<8} {:>10} {:>12} {:>12} {:>10}", - "iter", "Single", "Hedge-same", "Hedge-dual", "Hickory" - ); - for (i, s) in iter_stats.iter().enumerate() { - println!( - "{:<8} {:>7.1} ms {:>9.1} ms {:>9.1} ms {:>7.1} ms", - i + 1, - s[0].1, - s[1].1, - s[2].1, - s[3].1 - ); - } - - println!(); - println!("=== Per-iteration p99 (drift check) ==="); - println!( - "{:<8} {:>10} {:>12} {:>12} {:>10}", - "iter", "Single", "Hedge-same", "Hedge-dual", "Hickory" - ); - for (i, s) in iter_stats.iter().enumerate() { - println!( - "{:<8} {:>7.1} ms {:>9.1} ms {:>9.1} ms {:>7.1} ms", - i + 1, - s[0].3, - s[1].3, - s[2].3, - s[3].3 - ); - } - - let final_stats = |v: &mut Vec| -> (f64, f64, f64, f64, f64) { - (mean(v), median(v), percentile(v, 95.0), 
percentile(v, 99.0), stddev(v)) - }; - let (s_m, s_med, s_p95, s_p99, s_sd) = final_stats(&mut all_single); - let (hs_m, hs_med, hs_p95, hs_p99, hs_sd) = final_stats(&mut all_hedge_same); - let (hd_m, hd_med, hd_p95, hd_p99, hd_sd) = final_stats(&mut all_hedge_dual); - let (k_m, k_med, k_p95, k_p99, k_sd) = final_stats(&mut all_hickory); - - println!(); - let total = iterations * DOMAINS.len() * ROUNDS; - println!("=== Aggregated across all {} samples per method ===", total); - println!(); - println!( - "{:<10} {:>10} {:>12} {:>12} {:>10}", - "", "Single", "Hedge-same", "Hedge-dual", "Hickory" - ); - println!( - "{:<10} {:>7.1} ms {:>9.1} ms {:>9.1} ms {:>7.1} ms", - "mean", s_m, hs_m, hd_m, k_m - ); - println!( - "{:<10} {:>7.1} ms {:>9.1} ms {:>9.1} ms {:>7.1} ms", - "median", s_med, hs_med, hd_med, k_med - ); - println!( - "{:<10} {:>7.1} ms {:>9.1} ms {:>9.1} ms {:>7.1} ms", - "p95", s_p95, hs_p95, hd_p95, k_p95 - ); - println!( - "{:<10} {:>7.1} ms {:>9.1} ms {:>9.1} ms {:>7.1} ms", - "p99", s_p99, hs_p99, hd_p99, k_p99 - ); - println!( - "{:<10} {:>7.1} ms {:>9.1} ms {:>9.1} ms {:>7.1} ms", - "σ", s_sd, hs_sd, hd_sd, k_sd - ); - - println!(); - println!("Hedge-same vs Single: mean {:+.0}%, p95 {:+.0}%, p99 {:+.0}%", - (hs_m - s_m) / s_m * 100.0, - (hs_p95 - s_p95) / s_p95 * 100.0, - (hs_p99 - s_p99) / s_p99 * 100.0); - println!("Hedge-dual vs Single: mean {:+.0}%, p95 {:+.0}%, p99 {:+.0}%", - (hd_m - s_m) / s_m * 100.0, - (hd_p95 - s_p95) / s_p95 * 100.0, - (hd_p99 - s_p99) / s_p99 * 100.0); - println!("Hedge-same vs Hickory: mean {:+.0}%, p95 {:+.0}%, p99 {:+.0}%", - (hs_m - k_m) / k_m * 100.0, - (hs_p95 - k_p95) / k_p95 * 100.0, - (hs_p99 - k_p99) / k_p99 * 100.0); -} - -/// Server-to-server benchmark: Numa vs dnscrypt-proxy vs Unbound. -/// All are full servers: UDP in, encrypted forwarding to Quad9. -/// Numa + dnscrypt: DoH (HTTPS). Unbound: DoT (TLS port 853). 
-fn run_vs_dnscrypt(rt: &tokio::runtime::Runtime, iterations: usize) { - const DNSCRYPT_ADDR: &str = "127.0.0.1:5455"; - const UNBOUND_ADDR: &str = "127.0.0.1:5456"; - let numa_addr: SocketAddr = NUMA_BENCH.parse().unwrap(); - let dnscrypt_addr: SocketAddr = DNSCRYPT_ADDR.parse().unwrap(); - let unbound_addr: SocketAddr = UNBOUND_ADDR.parse().unwrap(); - - println!("Server-to-Server: Numa vs dnscrypt-proxy vs Unbound"); - println!("Numa (DoH): {}", NUMA_BENCH); - println!("dnscrypt-proxy (DoH): {}", DNSCRYPT_ADDR); - println!("Unbound (DoT): {}", UNBOUND_ADDR); - println!("All forwarding to Quad9 over encrypted transport"); - println!("{} domains × {} rounds × {} iterations\n", - DOMAINS.len(), ROUNDS, iterations); - - // Verify all are up - let servers: Vec<(&str, SocketAddr)> = vec![ - ("Numa", numa_addr), - ("dnscrypt-proxy", dnscrypt_addr), - ("Unbound", unbound_addr), - ]; - for (name, addr) in &servers { - if rt.block_on(query_udp(*addr, "example.com")).is_none() { - eprintln!("{} not responding on {}", name, addr); - std::process::exit(1); - } - } - println!("All servers reachable.\n"); - - // Warm up - println!("Warming up..."); - for _ in 0..5 { - for (_, addr) in &servers { - let _ = rt.block_on(query_udp(*addr, "example.com")); - } - } - - let mut all_numa = Vec::new(); - let mut all_dnscrypt = Vec::new(); - let mut all_unbound = Vec::new(); - let mut iter_stats: Vec<[(f64, f64, f64, f64, f64); 3]> = Vec::new(); - - for iter in 1..=iterations { - println!(" iteration {}/{}...", iter, iterations); - - let mut numa = Vec::new(); - let mut dnscrypt = Vec::new(); - let mut unbound = Vec::new(); - - for domain in DOMAINS { - for round in 0..ROUNDS { - flush_cache(); - std::thread::sleep(Duration::from_millis(5)); - - // Rotate order: 3 servers, 3 possible orderings - let order = round % 3; - let mut measure = |addr: SocketAddr| -> f64 { - let t = Instant::now(); - let _ = rt.block_on(query_udp(addr, domain)); - t.elapsed().as_secs_f64() * 1000.0 - }; - - 
match order { - 0 => { - numa.push(measure(numa_addr)); - dnscrypt.push(measure(dnscrypt_addr)); - unbound.push(measure(unbound_addr)); - } - 1 => { - dnscrypt.push(measure(dnscrypt_addr)); - unbound.push(measure(unbound_addr)); - numa.push(measure(numa_addr)); - } - _ => { - unbound.push(measure(unbound_addr)); - numa.push(measure(numa_addr)); - dnscrypt.push(measure(dnscrypt_addr)); - } - } - } - } - - let stats = |v: &mut Vec| -> (f64, f64, f64, f64, f64) { - (mean(v), median(v), percentile(v, 95.0), percentile(v, 99.0), stddev(v)) - }; - iter_stats.push([stats(&mut numa), stats(&mut dnscrypt), stats(&mut unbound)]); - - all_numa.extend_from_slice(&numa); - all_dnscrypt.extend_from_slice(&dnscrypt); - all_unbound.extend_from_slice(&unbound); - } - - println!(); - println!("=== Per-iteration medians ==="); - println!("{:<8} {:>10} {:>14} {:>10}", "iter", "Numa", "dnscrypt-proxy", "Unbound"); - for (i, s) in iter_stats.iter().enumerate() { - println!("{:<8} {:>7.1} ms {:>11.1} ms {:>7.1} ms", - i + 1, s[0].1, s[1].1, s[2].1); - } - - println!(); - println!("=== Per-iteration p99 ==="); - println!("{:<8} {:>10} {:>14} {:>10}", "iter", "Numa", "dnscrypt-proxy", "Unbound"); - for (i, s) in iter_stats.iter().enumerate() { - println!("{:<8} {:>7.1} ms {:>11.1} ms {:>7.1} ms", - i + 1, s[0].3, s[1].3, s[2].3); - } - - let stats = |v: &mut Vec| -> (f64, f64, f64, f64, f64) { - (mean(v), median(v), percentile(v, 95.0), percentile(v, 99.0), stddev(v)) - }; - let (n_m, n_med, n_p95, n_p99, n_sd) = stats(&mut all_numa); - let (d_m, d_med, d_p95, d_p99, d_sd) = stats(&mut all_dnscrypt); - let (u_m, u_med, u_p95, u_p99, u_sd) = stats(&mut all_unbound); - - println!(); - let total = iterations * DOMAINS.len() * ROUNDS; - println!("=== Aggregated ({} samples per method) ===", total); - println!(); - println!("{:<10} {:>10} {:>14} {:>10}", "", "Numa", "dnscrypt-proxy", "Unbound"); - println!("{:<10} {:>7.1} ms {:>11.1} ms {:>7.1} ms", "mean", n_m, d_m, u_m); - println!("{:<10} 
{:>7.1} ms {:>11.1} ms {:>7.1} ms", "median", n_med, d_med, u_med); - println!("{:<10} {:>7.1} ms {:>11.1} ms {:>7.1} ms", "p95", n_p95, d_p95, u_p95); - println!("{:<10} {:>7.1} ms {:>11.1} ms {:>7.1} ms", "p99", n_p99, d_p99, u_p99); - println!("{:<10} {:>7.1} ms {:>11.1} ms {:>7.1} ms", "σ", n_sd, d_sd, u_sd); - println!(); - - println!("Numa vs dnscrypt-proxy:"); - println!(" mean: {:+.0}%, median: {:+.0}%, p99: {:+.0}%", - (n_m - d_m) / d_m * 100.0, (n_med - d_med) / d_med * 100.0, (n_p99 - d_p99) / d_p99 * 100.0); - println!("Numa vs Unbound:"); - println!(" mean: {:+.0}%, median: {:+.0}%, p99: {:+.0}%", - (n_m - u_m) / u_m * 100.0, (n_med - u_med) / u_med * 100.0, (n_p99 - u_p99) / u_p99 * 100.0); -} - -/// Numa vs Unbound: both forward over plain UDP to Quad9, caching enabled. -/// Truly equal transport — no TLS, no HTTP/2, pure forwarding + cache. -fn run_vs_unbound(rt: &tokio::runtime::Runtime, iterations: usize) { - const UNBOUND_ADDR: &str = "127.0.0.1:5456"; - let numa_addr: SocketAddr = NUMA_BENCH.parse().unwrap(); - let unbound_addr: SocketAddr = UNBOUND_ADDR.parse().unwrap(); - - println!("Numa vs Unbound (both plain UDP forwarding to Quad9, caching enabled)"); - println!("Numa: {} → 9.9.9.9:53 UDP", NUMA_BENCH); - println!("Unbound: {} → 9.9.9.9:53 UDP", UNBOUND_ADDR); - println!("{} domains × {} rounds × {} iterations\n", - DOMAINS.len(), ROUNDS, iterations); - - if rt.block_on(query_udp(numa_addr, "example.com")).is_none() { - eprintln!("Numa not responding"); std::process::exit(1); - } - if rt.block_on(query_udp(unbound_addr, "example.com")).is_none() { - eprintln!("Unbound not responding"); std::process::exit(1); - } - println!("Both servers reachable.\n"); - - println!("Warming up..."); - for _ in 0..5 { - let _ = rt.block_on(query_udp(numa_addr, "example.com")); - let _ = rt.block_on(query_udp(unbound_addr, "example.com")); - } - - let mut all_numa = Vec::new(); - let mut all_unbound = Vec::new(); - let mut iter_stats: Vec<[(f64, f64, f64, 
f64, f64); 2]> = Vec::new(); - - for iter in 1..=iterations { - println!(" iteration {}/{}...", iter, iterations); - - let mut numa = Vec::new(); - let mut unbound = Vec::new(); - - for domain in DOMAINS { - for round in 0..ROUNDS { - // No cache flushing — both serve from cache after first hit - let mut measure = |addr: SocketAddr| -> f64 { - let t = Instant::now(); - let _ = rt.block_on(query_udp(addr, domain)); - t.elapsed().as_secs_f64() * 1000.0 - }; - - if round % 2 == 0 { - numa.push(measure(numa_addr)); - unbound.push(measure(unbound_addr)); - } else { - unbound.push(measure(unbound_addr)); - numa.push(measure(numa_addr)); - } - } - } - - let stats = |v: &mut Vec| -> (f64, f64, f64, f64, f64) { - (mean(v), median(v), percentile(v, 95.0), percentile(v, 99.0), stddev(v)) - }; - iter_stats.push([stats(&mut numa), stats(&mut unbound)]); - - all_numa.extend_from_slice(&numa); - all_unbound.extend_from_slice(&unbound); - } - - println!(); - println!("=== Per-iteration medians ==="); - println!("{:<8} {:>10} {:>10}", "iter", "Numa", "Unbound"); - for (i, s) in iter_stats.iter().enumerate() { - println!("{:<8} {:>7.1} ms {:>7.1} ms", i + 1, s[0].1, s[1].1); - } - - println!(); - println!("=== Per-iteration p99 ==="); - println!("{:<8} {:>10} {:>10}", "iter", "Numa", "Unbound"); - for (i, s) in iter_stats.iter().enumerate() { - println!("{:<8} {:>7.1} ms {:>7.1} ms", i + 1, s[0].3, s[1].3); - } - - let stats = |v: &mut Vec| -> (f64, f64, f64, f64, f64) { - (mean(v), median(v), percentile(v, 95.0), percentile(v, 99.0), stddev(v)) - }; - let (n_m, n_med, n_p95, n_p99, n_sd) = stats(&mut all_numa); - let (u_m, u_med, u_p95, u_p99, u_sd) = stats(&mut all_unbound); - - println!(); - let total = iterations * DOMAINS.len() * ROUNDS; - println!("=== Aggregated ({} samples per method) ===", total); - println!(); - println!("{:<10} {:>10} {:>10}", "", "Numa", "Unbound"); - println!("{:<10} {:>7.1} ms {:>7.1} ms", "mean", n_m, u_m); - println!("{:<10} {:>7.1} ms {:>7.1} ms", 
"median", n_med, u_med); - println!("{:<10} {:>7.1} ms {:>7.1} ms", "p95", n_p95, u_p95); - println!("{:<10} {:>7.1} ms {:>7.1} ms", "p99", n_p99, u_p99); - println!("{:<10} {:>7.1} ms {:>7.1} ms", "σ", n_sd, u_sd); - println!(); - - println!("Numa vs Unbound:"); - println!(" mean: {:+.1} ms ({:+.0}%)", n_m - u_m, (n_m - u_m) / u_m * 100.0); - println!(" median: {:+.1} ms ({:+.0}%)", n_med - u_med, (n_med - u_med) / u_med * 100.0); - println!(" p95: {:+.1} ms ({:+.0}%)", n_p95 - u_p95, (n_p95 - u_p95) / u_p95 * 100.0); - println!(" p99: {:+.1} ms ({:+.0}%)", n_p99 - u_p99, (n_p99 - u_p99) / u_p99 * 100.0); -} - -/// Build a DNS query as a Vec for use with forward_query_raw. -fn build_query_vec(domain: &str) -> Vec { - let mut buf = vec![0u8; 512]; - let len = build_query(&mut buf, domain); - buf.truncate(len); - buf -} - -fn measure R, R>(_rt: &tokio::runtime::Runtime, f: F) -> f64 { - let start = Instant::now(); - f(); - start.elapsed().as_secs_f64() * 1000.0 -} - -fn mean(v: &[f64]) -> f64 { - v.iter().sum::() / v.len() as f64 -} - -fn stddev(v: &[f64]) -> f64 { - let m = mean(v); - let var = v.iter().map(|x| (x - m).powi(2)).sum::() / v.len() as f64; - var.sqrt() -} - -fn median(v: &mut [f64]) -> f64 { + let mean = v.iter().sum::() / v.len() as f64; v.sort_by(|a, b| a.partial_cmp(b).unwrap()); let n = v.len(); - if n % 2 == 0 { + let median = if n % 2 == 0 { (v[n / 2 - 1] + v[n / 2]) / 2.0 } else { v[n / 2] - } + }; + let p95 = v[((n as f64 * 0.95).round() as usize).min(n - 1)]; + let p99 = v[((n as f64 * 0.99).round() as usize).min(n - 1)]; + let var = v.iter().map(|x| (x - mean).powi(2)).sum::() / n as f64; + (mean, median, p95, p99, var.sqrt()) } -fn percentile(sorted: &[f64], p: f64) -> f64 { - let idx = (p / 100.0 * (sorted.len() - 1) as f64).round() as usize; - sorted[idx.min(sorted.len() - 1)] -} +// ── Query helpers ─────────────────────────────────────────────── -fn format_delta(delta: f64) -> String { - if delta > 0.0 { - format!("+{:.1}", delta) - } 
else { - format!("{:.1}", delta) - } -} - -/// Query a DNS server over UDP. async fn query_udp(addr: SocketAddr, domain: &str) -> Option<()> { use tokio::net::UdpSocket; - let sock = UdpSocket::bind("0.0.0.0:0").await.ok()?; let mut buf = vec![0u8; 512]; let len = build_query(&mut buf, domain); - sock.send_to(&buf[..len], addr).await.ok()?; - let mut resp = vec![0u8; 4096]; tokio::time::timeout(Duration::from_secs(10), sock.recv_from(&mut resp)) .await .ok()? .ok()?; - Some(()) } -/// Build a shared Hickory DoH resolver (reuses TLS connection across queries). +async fn query_dot_once( + addr: &str, + domain: &str, + tls_config: &std::sync::Arc, +) -> Result<(), Box> { + use rustls::pki_types::ServerName; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::net::TcpStream; + use tokio_rustls::TlsConnector; + + let connector = TlsConnector::from(tls_config.clone()); + let stream = TcpStream::connect(addr).await?; + let server_name = ServerName::try_from("localhost")?; + let mut tls = connector.connect(server_name, stream).await?; + + let mut buf = vec![0u8; 512]; + let len = build_query(&mut buf, domain); + let msg = &buf[..len]; + + let mut out = Vec::with_capacity(2 + msg.len()); + out.extend_from_slice(&(msg.len() as u16).to_be_bytes()); + out.extend_from_slice(msg); + tls.write_all(&out).await?; + + let mut len_buf = [0u8; 2]; + tls.read_exact(&mut len_buf).await?; + let resp_len = u16::from_be_bytes(len_buf) as usize; + let mut resp = vec![0u8; resp_len]; + tls.read_exact(&mut resp).await?; + Ok(()) +} + +async fn query_doh_server( + client: &reqwest::Client, + url: &str, + wire: &[u8], + host: Option<&str>, +) -> Result, Box> { + let mut req = client + .post(url) + .header("content-type", "application/dns-message") + .header("accept", "application/dns-message") + .body(wire.to_vec()); + if let Some(h) = host { + req = req.header("host", h); + } + let resp = req.send().await?.error_for_status()?; + Ok(resp.bytes().await?.to_vec()) +} + async fn 
build_hickory_resolver() -> hickory_resolver::TokioResolver { use hickory_resolver::config::*; - let ns = NameServerConfig { socket_addr: "9.9.9.9:443".parse().unwrap(), protocol: hickory_proto::xfer::Protocol::Https, @@ -1593,29 +893,79 @@ async fn build_hickory_resolver() -> hickory_resolver::TokioResolver { bind_addr: None, http_endpoint: Some("/dns-query".to_string()), }; - let config = ResolverConfig::from_parts(None, vec![], NameServerConfigGroup::from(vec![ns])); - let mut opts = ResolverOpts::default(); opts.cache_size = 0; opts.num_concurrent_reqs = 1; opts.timeout = Duration::from_secs(10); - hickory_resolver::TokioResolver::builder_with_config(config, Default::default()) .with_options(opts) .build() } -/// Query using the shared Hickory resolver. -async fn query_hickory_doh( - resolver: &hickory_resolver::TokioResolver, - domain: &str, -) -> Option<()> { +async fn query_hickory_doh(resolver: &hickory_resolver::TokioResolver, domain: &str) -> Option<()> { use hickory_resolver::proto::rr::RecordType; let _ = resolver.lookup(domain, RecordType::A).await.ok()?; Some(()) } +fn build_insecure_tls_config() -> std::sync::Arc { + use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}; + use rustls::pki_types::{CertificateDer, ServerName, UnixTime}; + use rustls::DigitallySignedStruct; + + #[derive(Debug)] + struct NoVerify; + impl ServerCertVerifier for NoVerify { + fn verify_server_cert( + &self, + _: &CertificateDer<'_>, + _: &[CertificateDer<'_>], + _: &ServerName<'_>, + _: &[u8], + _: UnixTime, + ) -> Result { + Ok(ServerCertVerified::assertion()) + } + fn verify_tls12_signature( + &self, + _: &[u8], + _: &CertificateDer<'_>, + _: &DigitallySignedStruct, + ) -> Result { + Ok(HandshakeSignatureValid::assertion()) + } + fn verify_tls13_signature( + &self, + _: &[u8], + _: &CertificateDer<'_>, + _: &DigitallySignedStruct, + ) -> Result { + Ok(HandshakeSignatureValid::assertion()) + } + fn supported_verify_schemes(&self) -> 
Vec { + rustls::crypto::ring::default_provider() + .signature_verification_algorithms + .supported_schemes() + } + } + std::sync::Arc::new( + rustls::ClientConfig::builder() + .dangerous() + .with_custom_certificate_verifier(std::sync::Arc::new(NoVerify)) + .with_no_client_auth(), + ) +} + +// ── Wire helpers ──────────────────────────────────────────────── + +fn build_query_vec(domain: &str) -> Vec { + let mut buf = vec![0u8; 512]; + let len = build_query(&mut buf, domain); + buf.truncate(len); + buf +} + fn build_query(buf: &mut [u8], domain: &str) -> usize { let mut pos = 0; buf[pos..pos + 2].copy_from_slice(&0x1234u16.to_be_bytes()); @@ -1626,7 +976,6 @@ fn build_query(buf: &mut [u8], domain: &str) -> usize { pos += 2; buf[pos..pos + 6].fill(0); pos += 6; - for label in domain.split('.') { buf[pos] = label.len() as u8; pos += 1; @@ -1644,6 +993,11 @@ fn build_query(buf: &mut [u8], domain: &str) -> usize { fn flush_cache() { let _ = std::process::Command::new("curl") - .args(["-s", "-X", "DELETE", &format!("http://127.0.0.1:{NUMA_API}/cache")]) + .args([ + "-s", + "-X", + "DELETE", + &format!("http://127.0.0.1:{NUMA_API}/cache"), + ]) .output(); } diff --git a/src/forward.rs b/src/forward.rs index 401ae1c..6afb7e5 100644 --- a/src/forward.rs +++ b/src/forward.rs @@ -214,15 +214,11 @@ pub async fn forward_query( upstream: &Upstream, timeout_duration: Duration, ) -> Result { - match upstream { - Upstream::Udp(addr) => forward_udp(query, *addr, timeout_duration).await, - Upstream::Doh { url, client } => forward_doh(query, url, client, timeout_duration).await, - Upstream::Dot { - addr, - tls_name, - connector, - } => forward_dot(query, *addr, tls_name, connector, timeout_duration).await, - } + let mut send_buffer = BytePacketBuffer::new(); + query.write(&mut send_buffer)?; + let data = forward_query_raw(send_buffer.filled(), upstream, timeout_duration).await?; + let mut recv_buffer = BytePacketBuffer::from_bytes(&data); + DnsPacket::from_buffer(&mut recv_buffer) } 
pub(crate) async fn forward_udp( @@ -284,13 +280,13 @@ pub(crate) async fn forward_tcp( DnsPacket::from_buffer(&mut recv_buffer) } -async fn forward_dot( - query: &DnsPacket, +async fn forward_dot_raw( + wire: &[u8], addr: SocketAddr, tls_name: &Option, connector: &tokio_rustls::TlsConnector, timeout_duration: Duration, -) -> Result { +) -> Result> { use rustls::pki_types::ServerName; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; @@ -303,10 +299,6 @@ async fn forward_dot( let tcp = timeout(timeout_duration, TcpStream::connect(addr)).await??; let mut tls = timeout(timeout_duration, connector.connect(server_name, tcp)).await??; - let mut send_buffer = BytePacketBuffer::new(); - query.write(&mut send_buffer)?; - let wire = send_buffer.filled(); - let mut outbuf = Vec::with_capacity(2 + wire.len()); outbuf.extend_from_slice(&(wire.len() as u16).to_be_bytes()); outbuf.extend_from_slice(wire); @@ -319,22 +311,7 @@ async fn forward_dot( let mut data = vec![0u8; resp_len]; timeout(timeout_duration, tls.read_exact(&mut data)).await??; - let mut recv_buffer = BytePacketBuffer::from_bytes(&data); - DnsPacket::from_buffer(&mut recv_buffer) -} - -async fn forward_doh( - query: &DnsPacket, - url: &str, - client: &reqwest::Client, - timeout_duration: Duration, -) -> Result { - let mut send_buffer = BytePacketBuffer::new(); - query.write(&mut send_buffer)?; - - let resp_bytes = forward_doh_raw(send_buffer.filled(), url, client, timeout_duration).await?; - let mut recv_buffer = BytePacketBuffer::from_bytes(&resp_bytes); - DnsPacket::from_buffer(&mut recv_buffer) + Ok(data) } pub async fn forward_query_raw( @@ -345,6 +322,11 @@ pub async fn forward_query_raw( match upstream { Upstream::Udp(addr) => forward_udp_raw(wire, *addr, timeout_duration).await, Upstream::Doh { url, client } => forward_doh_raw(wire, url, client, timeout_duration).await, + Upstream::Dot { + addr, + tls_name, + connector, + } => forward_dot_raw(wire, *addr, tls_name, connector, 
timeout_duration).await, } } @@ -405,7 +387,10 @@ pub async fn forward_with_hedging_raw( match (primary_err, secondary_err) { (Some(pe), Some(_)) => return Err(pe), - (pe, se) => { primary_err = pe; secondary_err = se; } + (pe, se) => { + primary_err = pe; + secondary_err = se; + } } } } @@ -516,7 +501,7 @@ pub async fn keepalive_doh(upstream: &Upstream) { 0x01, 0x00, // flags: RD=1 0x00, 0x01, // QDCOUNT=1 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // AN=0, NS=0, AR=0 - 0x00, // root name (.) + 0x00, // root name (.) 0x00, 0x02, // type NS 0x00, 0x01, // class IN ]; diff --git a/src/recursive.rs b/src/recursive.rs index 2609f7f..190a57a 100644 --- a/src/recursive.rs +++ b/src/recursive.rs @@ -15,8 +15,8 @@ use crate::srtt::SrttCache; const MAX_REFERRAL_DEPTH: u8 = 10; const MAX_CNAME_DEPTH: u8 = 8; -const NS_QUERY_TIMEOUT: Duration = Duration::from_millis(800); -const TCP_TIMEOUT: Duration = Duration::from_millis(1500); +const NS_QUERY_TIMEOUT: Duration = Duration::from_millis(400); +const TCP_TIMEOUT: Duration = Duration::from_millis(400); const UDP_FAIL_THRESHOLD: u8 = 3; static QUERY_ID: AtomicU16 = AtomicU16::new(1); @@ -213,11 +213,13 @@ pub(crate) fn resolve_iterative<'a>( ns_addrs[ns_idx], q_type, q_name, current_zone, referral_depth ); - let response = match send_query_hedged(q_name, q_type, &ns_addrs[ns_idx..], srtt).await { + let response = match send_query_hedged(q_name, q_type, &ns_addrs[ns_idx..], srtt).await + { Ok(r) => r, Err(e) => { debug!("recursive: NS query failed: {}", e); - ns_idx += 2; // both tried, skip past them + let remaining = ns_addrs.len().saturating_sub(ns_idx); + ns_idx += remaining.min(2); continue; } }; @@ -660,7 +662,10 @@ async fn send_query_hedged( } match (a_err.take(), b_err.take()) { (Some(e), Some(_)) => return Err(e), - (a, b) => { a_err = a; b_err = b; } + (a, b) => { + a_err = a; + b_err = b; + } } } } else { @@ -739,9 +744,13 @@ async fn send_query( "send_query: {} consecutive UDP failures — switching to TCP-first", fails ); 
+ // Now that UDP is disabled, retry this query via TCP + return tcp_with_srtt(&query, server, srtt, start).await; } - debug!("send_query: UDP failed for {}: {}, trying TCP", server, e); - tcp_with_srtt(&query, server, srtt, start).await + // UDP works in general (priming succeeded) but this server timed out. + // Don't waste another 400ms on TCP — the server is unreachable. + srtt.write().unwrap().record_failure(server.ip()); + Err(e) } } } @@ -1021,10 +1030,10 @@ mod tests { } /// TCP-only server returns authoritative answer directly. - /// Verifies: UDP fails → TCP fallback → resolves. + /// Verifies: when UDP is disabled, TCP-first resolves. #[tokio::test] async fn tcp_fallback_resolves_when_udp_blocked() { - UDP_DISABLED.store(false, Ordering::Relaxed); + UDP_DISABLED.store(true, Ordering::Relaxed); UDP_FAILURES.store(0, Ordering::Release); let server_addr = spawn_tcp_dns_server(|query| { @@ -1107,7 +1116,7 @@ mod tests { #[tokio::test] async fn tcp_fallback_handles_nxdomain() { - UDP_DISABLED.store(false, Ordering::Relaxed); + UDP_DISABLED.store(true, Ordering::Relaxed); UDP_FAILURES.store(0, Ordering::Release); let server_addr = spawn_tcp_dns_server(|query| { From c1b651aa636acf9fe582e5809f43e913ce182f88 Mon Sep 17 00:00:00 2001 From: Razvan Dimescu Date: Sun, 12 Apr 2026 06:25:42 +0300 Subject: [PATCH 03/21] chore: remove obsolete bash benchmark script --- scripts/bench-recursive.sh | 115 ------------------------------------- 1 file changed, 115 deletions(-) delete mode 100755 scripts/bench-recursive.sh diff --git a/scripts/bench-recursive.sh b/scripts/bench-recursive.sh deleted file mode 100755 index 1a1ab71..0000000 --- a/scripts/bench-recursive.sh +++ /dev/null @@ -1,115 +0,0 @@ -#!/usr/bin/env bash -# Bench: Numa cold-cache recursive resolution vs dig (forwarded through system resolver) -# -# Measures cold-cache recursive resolution time for Numa. -# Flushes Numa's cache before each query to ensure cold-cache. 
-# Compares against dig querying a public recursive resolver (no cache advantage). -# -# Usage: ./scripts/bench-recursive.sh [numa_port] - -set -euo pipefail - -NUMA_ADDR="${NUMA_ADDR:-127.0.0.1}" -NUMA_PORT="${NUMA_PORT:-${1:-53}}" -API_PORT="${API_PORT:-5380}" -ROUNDS=3 - -DOMAINS=( - "example.com" - "rust-lang.org" - "kernel.org" - "signal.org" - "archlinux.org" - "openbsd.org" - "git-scm.com" - "sqlite.org" - "wireguard.com" - "mozilla.org" -) - -GREEN='\033[0;32m' -AMBER='\033[0;33m' -CYAN='\033[0;36m' -DIM='\033[0;90m' -BOLD='\033[1m' -RESET='\033[0m' - -echo -e "${CYAN}${BOLD}Recursive DNS Resolution Benchmark${RESET}" -echo -e "${DIM}Numa (cold cache, recursive from root) vs dig @1.1.1.1 (public resolver)${RESET}" -echo -e "${DIM}Rounds per domain: ${ROUNDS}${RESET}" -echo "" - -# Verify Numa is reachable -if ! dig @${NUMA_ADDR} -p ${NUMA_PORT} +short +time=3 +tries=1 example.com A &>/dev/null; then - echo -e "${AMBER}Numa not responding on ${NUMA_ADDR}:${NUMA_PORT}${RESET}" >&2 - exit 1 -fi - -# Verify we can flush cache -if ! 
curl -s -X DELETE "http://${NUMA_ADDR}:${API_PORT}/cache" &>/dev/null; then - echo -e "${AMBER}Cannot flush cache via API at ${NUMA_ADDR}:${API_PORT}${RESET}" >&2 - exit 1 -fi - -measure_ms() { - local start end - start=$(python3 -c 'import time; print(time.time())') - eval "$1" &>/dev/null - end=$(python3 -c 'import time; print(time.time())') - python3 -c "print(round(($end - $start) * 1000, 1))" -} - -printf "${BOLD}%-22s %10s %10s %8s${RESET}\n" "Domain" "Numa (ms)" "1.1.1.1" "Delta" -printf "%-22s %10s %10s %8s\n" "----------------------" "----------" "----------" "--------" - -numa_total=0 -dig_total=0 -count=0 - -for domain in "${DOMAINS[@]}"; do - numa_sum=0 - dig_sum=0 - - for ((r=1; r<=ROUNDS; r++)); do - # Flush Numa cache - curl -s -X DELETE "http://${NUMA_ADDR}:${API_PORT}/cache" &>/dev/null - sleep 0.05 - - # Measure Numa (recursive from root, cold cache) - ms=$(measure_ms "dig @${NUMA_ADDR} -p ${NUMA_PORT} +short +time=10 +tries=1 ${domain} A") - numa_sum=$(python3 -c "print(round($numa_sum + $ms, 1))") - - # Measure dig against 1.1.1.1 (Cloudflare — warm cache, but shows baseline) - ms=$(measure_ms "dig @1.1.1.1 +short +time=10 +tries=1 ${domain} A") - dig_sum=$(python3 -c "print(round($dig_sum + $ms, 1))") - done - - numa_avg=$(python3 -c "print(round($numa_sum / $ROUNDS, 1))") - dig_avg=$(python3 -c "print(round($dig_sum / $ROUNDS, 1))") - delta=$(python3 -c "d = round($numa_avg - $dig_avg, 1); print(f'+{d}' if d > 0 else str(d))") - - # Color the delta - delta_color="$GREEN" - if python3 -c "exit(0 if $numa_avg > $dig_avg * 1.5 else 1)" 2>/dev/null; then - delta_color="$AMBER" - fi - - printf "%-22s %8s ms %8s ms ${delta_color}%6s ms${RESET}\n" "$domain" "$numa_avg" "$dig_avg" "$delta" - - numa_total=$(python3 -c "print(round($numa_total + $numa_avg, 1))") - dig_total=$(python3 -c "print(round($dig_total + $dig_avg, 1))") - count=$((count + 1)) -done - -echo "" -numa_mean=$(python3 -c "print(round($numa_total / $count, 1))") -dig_mean=$(python3 -c 
"print(round($dig_total / $count, 1))") -delta_mean=$(python3 -c "d = round($numa_mean - $dig_mean, 1); print(f'+{d}' if d > 0 else str(d))") - -printf "${BOLD}%-22s %8s ms %8s ms %6s ms${RESET}\n" "AVERAGE" "$numa_mean" "$dig_mean" "$delta_mean" - -echo "" -echo -e "${DIM}Note: Numa resolves recursively from root hints (cold cache).${RESET}" -echo -e "${DIM}1.1.1.1 serves from Cloudflare's global cache (warm). The comparison${RESET}" -echo -e "${DIM}is intentionally unfair — it shows Numa's worst case vs the best case${RESET}" -echo -e "${DIM}of a global anycast resolver. Cached Numa queries resolve in <1ms.${RESET}" From 72b540a44aadbb34867e2d12ab9c9630015b4b44 Mon Sep 17 00:00:00 2001 From: Razvan Dimescu Date: Sun, 12 Apr 2026 06:27:38 +0300 Subject: [PATCH 04/21] feat: wire-level cache, serve-stale, raw wire passthrough - Cache stores raw DNS wire bytes + TTL offsets (2.4x memory reduction) - Serve-stale (RFC 8767): expired entries returned with TTL=1 for 1hr - handle_query captures raw_len from recv_from for zero-copy forwarding - resolve_query accepts raw wire bytes, forwards without re-serializing - wire.rs: TTL offset scanner, ID/TTL patching, question extraction - 52 wire tests + 16 cache regression tests --- src/ctx.rs | 34 +++++-- src/wire.rs | 270 ++++++++++++++++++++++++++++++++++++++++------------ 2 files changed, 231 insertions(+), 73 deletions(-) diff --git a/src/ctx.rs b/src/ctx.rs index 2b26a06..46316f2 100644 --- a/src/ctx.rs +++ b/src/ctx.rs @@ -16,9 +16,7 @@ use crate::blocklist::BlocklistStore; use crate::buffer::BytePacketBuffer; use crate::cache::{DnsCache, DnssecStatus}; use crate::config::{UpstreamMode, ZoneMap}; -use crate::forward::{ - forward_query_raw, forward_with_failover_raw, Upstream, UpstreamPool, -}; +use crate::forward::{forward_query_raw, forward_with_failover_raw, Upstream, UpstreamPool}; use crate::header::ResultCode; use crate::health::HealthMeta; use crate::lan::PeerStore; @@ -182,9 +180,7 @@ pub async fn resolve_query( // 
(e.g. Tailscale .ts.net, VPC private zones) let upstream = Upstream::Udp(fwd_addr); match forward_and_cache(raw_wire, &upstream, ctx, &qname, qtype).await { - Ok(resp) => { - (resp, QueryPath::Forwarded, DnssecStatus::Indeterminate) - } + Ok(resp) => (resp, QueryPath::Forwarded, DnssecStatus::Indeterminate), Err(e) => { error!( "{} | {:?} {} | FORWARD ERROR | {}", @@ -224,17 +220,35 @@ pub async fn resolve_query( (resp, path, DnssecStatus::Indeterminate) } else { let pool = ctx.upstream_pool.lock().unwrap().clone(); - match forward_with_failover_raw(raw_wire, &pool, &ctx.srtt, ctx.timeout, ctx.hedge_delay).await { + match forward_with_failover_raw( + raw_wire, + &pool, + &ctx.srtt, + ctx.timeout, + ctx.hedge_delay, + ) + .await + { Ok(resp_wire) => { ctx.cache.write().unwrap().insert_wire( - &qname, qtype, &resp_wire, DnssecStatus::Indeterminate, + &qname, + qtype, + &resp_wire, + DnssecStatus::Indeterminate, ); let mut buf = BytePacketBuffer::from_bytes(&resp_wire); match DnsPacket::from_buffer(&mut buf) { Ok(resp) => (resp, QueryPath::Forwarded, DnssecStatus::Indeterminate), Err(e) => { - error!("{} | {:?} {} | PARSE ERROR | {}", src_addr, qtype, qname, e); - (DnsPacket::response_from(&query, ResultCode::SERVFAIL), QueryPath::UpstreamError, DnssecStatus::Indeterminate) + error!( + "{} | {:?} {} | PARSE ERROR | {}", + src_addr, qtype, qname, e + ); + ( + DnsPacket::response_from(&query, ResultCode::SERVFAIL), + QueryPath::UpstreamError, + DnssecStatus::Indeterminate, + ) } } } diff --git a/src/wire.rs b/src/wire.rs index 6b68c3a..a93fe27 100644 --- a/src/wire.rs +++ b/src/wire.rs @@ -309,7 +309,11 @@ mod tests { #[test] fn scan_single_a_record() { - let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); let wire = to_wire(&pkt); let meta = scan_ttl_offsets(&wire).unwrap(); @@ -341,15 +345,20 @@ mod tests { let ttls: Vec = meta 
.ttl_offsets .iter() - .map(|&off| u32::from_be_bytes([wire[off], wire[off + 1], wire[off + 2], wire[off + 3]])) + .map(|&off| { + u32::from_be_bytes([wire[off], wire[off + 1], wire[off + 2], wire[off + 3]]) + }) .collect(); assert_eq!(ttls, vec![300, 600, 120]); } #[test] fn scan_mixed_sections() { - let mut pkt = - response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let mut pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); pkt.authorities .push(ns_record("example.com", "ns1.example.com", 3600)); pkt.authorities @@ -382,7 +391,9 @@ mod tests { let ttls: Vec = meta .ttl_offsets .iter() - .map(|&off| u32::from_be_bytes([wire[off], wire[off + 1], wire[off + 2], wire[off + 3]])) + .map(|&off| { + u32::from_be_bytes([wire[off], wire[off + 1], wire[off + 2], wire[off + 3]]) + }) .collect(); assert_eq!(ttls, vec![300, 600]); } @@ -410,15 +421,20 @@ mod tests { let ttls: Vec = meta .ttl_offsets .iter() - .map(|&off| u32::from_be_bytes([wire[off], wire[off + 1], wire[off + 2], wire[off + 3]])) + .map(|&off| { + u32::from_be_bytes([wire[off], wire[off + 1], wire[off + 2], wire[off + 3]]) + }) .collect(); assert_eq!(ttls, vec![300, 600]); } #[test] fn scan_edns_opt_excluded() { - let mut pkt = - response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let mut pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); pkt.edns = Some(EdnsOpt { udp_payload_size: 1232, extended_rcode: 0, @@ -436,8 +452,11 @@ mod tests { #[test] fn scan_rrsig_only_wire_ttl() { - let mut pkt = - response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let mut pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); pkt.answers.push(DnsRecord::RRSIG { domain: "example.com".into(), type_covered: 1, // A @@ -460,8 +479,7 @@ mod tests { // Both wire TTLs should be 300, not 9999 for &off in 
&meta.ttl_offsets { - let ttl = - u32::from_be_bytes([wire[off], wire[off + 1], wire[off + 2], wire[off + 3]]); + let ttl = u32::from_be_bytes([wire[off], wire[off + 1], wire[off + 2], wire[off + 3]]); assert_eq!(ttl, 300); } @@ -479,8 +497,11 @@ mod tests { #[test] fn scan_nsec_variable_rdata() { - let mut pkt = - response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let mut pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); pkt.authorities.push(DnsRecord::NSEC { domain: "example.com".into(), next_domain: "z.example.com".into(), @@ -534,7 +555,11 @@ mod tests { #[test] fn scan_truncated_wire_returns_error() { - let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); let wire = to_wire(&pkt); // Truncate mid-record let truncated = &wire[..wire.len() - 2]; @@ -558,7 +583,11 @@ mod tests { #[test] fn patch_ttl_single() { - let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); let mut wire = to_wire(&pkt); let meta = scan_ttl_offsets(&wire).unwrap(); @@ -597,7 +626,11 @@ mod tests { #[test] fn patch_ttl_preserves_other_bytes() { - let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); let original = to_wire(&pkt); let mut patched = original.clone(); let meta = scan_ttl_offsets(&patched).unwrap(); @@ -606,10 +639,7 @@ mod tests { // Every byte outside TTL offsets should be identical for (i, (&orig, &patc)) in original.iter().zip(patched.iter()).enumerate() { - let in_ttl = meta - .ttl_offsets - .iter() - .any(|&off| i >= off && i < off + 4); + let in_ttl = meta.ttl_offsets.iter().any(|&off| 
i >= off && i < off + 4); if !in_ttl { assert_eq!( orig, patc, @@ -622,7 +652,11 @@ mod tests { #[test] fn patch_ttl_zero() { - let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); let mut wire = to_wire(&pkt); let meta = scan_ttl_offsets(&wire).unwrap(); @@ -634,7 +668,11 @@ mod tests { #[test] fn patch_ttl_max_u32() { - let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); let mut wire = to_wire(&pkt); let meta = scan_ttl_offsets(&wire).unwrap(); @@ -646,8 +684,11 @@ mod tests { #[test] fn patch_ttl_edns_untouched() { - let mut pkt = - response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let mut pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); pkt.edns = Some(EdnsOpt { udp_payload_size: 1232, extended_rcode: 0, @@ -664,10 +705,7 @@ mod tests { // Only the A record's TTL bytes should differ; everything else // (including the OPT "TTL" containing the DO bit) must be unchanged. 
for (i, (&orig, &patc)) in original.iter().zip(patched.iter()).enumerate() { - let in_ttl = meta - .ttl_offsets - .iter() - .any(|&off| i >= off && i < off + 4); + let in_ttl = meta.ttl_offsets.iter().any(|&off| i >= off && i < off + 4); if !in_ttl { assert_eq!( orig, patc, @@ -682,7 +720,11 @@ mod tests { #[test] fn patch_id_basic() { - let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); let mut wire = to_wire(&pkt); patch_id(&mut wire, 0xABCD); @@ -691,7 +733,11 @@ mod tests { #[test] fn patch_id_preserves_flags() { - let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); let original = to_wire(&pkt); let mut patched = original.clone(); @@ -703,7 +749,11 @@ mod tests { #[test] fn patch_id_zero() { - let pkt = response(0xFFFF, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let pkt = response( + 0xFFFF, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); let mut wire = to_wire(&pkt); patch_id(&mut wire, 0x0000); @@ -782,7 +832,11 @@ mod tests { #[test] fn round_trip_simple_a() { - let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); let wire = to_wire(&pkt); let meta = scan_ttl_offsets(&wire).unwrap(); @@ -808,8 +862,11 @@ mod tests { #[test] fn round_trip_edns_survives() { - let mut pkt = - response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let mut pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); pkt.edns = Some(EdnsOpt { udp_payload_size: 1232, extended_rcode: 0, @@ -1017,7 +1074,11 @@ mod tests { #[test] fn cache_insert_lookup_hit() { 
let mut cache = DnsCache::new(100, 1, 3600); - let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); cache.insert("example.com", QueryType::A, &pkt); let (result, status) = cache @@ -1030,10 +1091,16 @@ mod tests { #[test] fn cache_lookup_adjusts_ttl() { let mut cache = DnsCache::new(100, 1, 3600); - let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); cache.insert("example.com", QueryType::A, &pkt); - let (result, _) = cache.lookup_with_status("example.com", QueryType::A).unwrap(); + let (result, _) = cache + .lookup_with_status("example.com", QueryType::A) + .unwrap(); // TTL should be <= 300 (at most original, reduced by elapsed time) assert!(result.answers[0].ttl() <= 300); assert!(result.answers[0].ttl() > 0); @@ -1042,7 +1109,11 @@ mod tests { #[test] fn cache_miss_wrong_domain() { let mut cache = DnsCache::new(100, 1, 3600); - let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); cache.insert("example.com", QueryType::A, &pkt); assert!(cache @@ -1053,7 +1124,11 @@ mod tests { #[test] fn cache_miss_wrong_qtype() { let mut cache = DnsCache::new(100, 1, 3600); - let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); cache.insert("example.com", QueryType::A, &pkt); assert!(cache @@ -1064,8 +1139,16 @@ mod tests { #[test] fn cache_overwrite_no_double_count() { let mut cache = DnsCache::new(100, 1, 3600); - let pkt1 = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); - let pkt2 = 
response(0x5678, "example.com", vec![a_record("example.com", "5.6.7.8", 600)]); + let pkt1 = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); + let pkt2 = response( + 0x5678, + "example.com", + vec![a_record("example.com", "5.6.7.8", 600)], + ); cache.insert("example.com", QueryType::A, &pkt1); assert_eq!(cache.len(), 1); @@ -1073,7 +1156,9 @@ mod tests { cache.insert("example.com", QueryType::A, &pkt2); assert_eq!(cache.len(), 1); // no double count - let (result, _) = cache.lookup_with_status("example.com", QueryType::A).unwrap(); + let (result, _) = cache + .lookup_with_status("example.com", QueryType::A) + .unwrap(); match &result.answers[0] { DnsRecord::A { addr, .. } => { assert_eq!(*addr, "5.6.7.8".parse::().unwrap()) @@ -1085,7 +1170,11 @@ mod tests { #[test] fn cache_ttl_clamped_min() { let mut cache = DnsCache::new(100, 60, 3600); - let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 5)]); + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 5)], + ); cache.insert("example.com", QueryType::A, &pkt); let (remaining, total) = cache.ttl_remaining("example.com", QueryType::A).unwrap(); @@ -1096,8 +1185,11 @@ mod tests { #[test] fn cache_ttl_clamped_max() { let mut cache = DnsCache::new(100, 1, 3600); - let pkt = - response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 999999)]); + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 999999)], + ); cache.insert("example.com", QueryType::A, &pkt); let (_, total) = cache.ttl_remaining("example.com", QueryType::A).unwrap(); @@ -1110,7 +1202,11 @@ mod tests { assert!(cache.is_empty()); assert_eq!(cache.len(), 0); - let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); cache.insert("example.com", QueryType::A, &pkt); 
assert!(!cache.is_empty()); assert_eq!(cache.len(), 1); @@ -1124,7 +1220,11 @@ mod tests { #[test] fn cache_remove_domain() { let mut cache = DnsCache::new(100, 1, 3600); - let pkt_a = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let pkt_a = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); let pkt_aaaa = response( 0x5678, "example.com", @@ -1143,8 +1243,16 @@ mod tests { #[test] fn cache_list_entries() { let mut cache = DnsCache::new(100, 1, 3600); - let pkt_a = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); - let pkt_b = response(0x5678, "test.org", vec![a_record("test.org", "5.6.7.8", 600)]); + let pkt_a = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); + let pkt_b = response( + 0x5678, + "test.org", + vec![a_record("test.org", "5.6.7.8", 600)], + ); cache.insert("example.com", QueryType::A, &pkt_a); cache.insert("test.org", QueryType::A, &pkt_b); @@ -1160,7 +1268,11 @@ mod tests { let mut cache = DnsCache::new(100, 1, 3600); let empty = cache.heap_bytes(); - let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); cache.insert("example.com", QueryType::A, &pkt); assert!(cache.heap_bytes() > empty); } @@ -1173,7 +1285,11 @@ mod tests { assert!(cache.needs_warm("example.com")); // Both A and AAAA cached → does not need warm - let pkt_a = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let pkt_a = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); let pkt_aaaa = response( 0x5678, "example.com", @@ -1194,7 +1310,11 @@ mod tests { let mut cache = DnsCache::new(100, 60, 3600); assert!(cache.ttl_remaining("missing.com", QueryType::A).is_none()); - let pkt = response(0x1234, "example.com", 
vec![a_record("example.com", "1.2.3.4", 300)]); + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); cache.insert("example.com", QueryType::A, &pkt); let (remaining, total) = cache.ttl_remaining("example.com", QueryType::A).unwrap(); assert_eq!(total, 300); @@ -1205,7 +1325,11 @@ mod tests { #[test] fn cache_dnssec_status_preserved() { let mut cache = DnsCache::new(100, 1, 3600); - let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); cache.insert_with_status("example.com", QueryType::A, &pkt, DnssecStatus::Secure); let (_, status) = cache @@ -1225,7 +1349,9 @@ mod tests { let mut cache = DnsCache::new(1000, 1, 3600); // Simulate a realistic cache: 50 domains, mix of record types - let domains: Vec = (0..50).map(|i| format!("domain{}.example.com", i)).collect(); + let domains: Vec = (0..50) + .map(|i| format!("domain{}.example.com", i)) + .collect(); let mut total_wire_bytes = 0usize; let mut total_wire_meta_bytes = 0usize; @@ -1259,8 +1385,7 @@ mod tests { let wire_aaaa = to_wire(&pkt_aaaa); let meta_aaaa = scan_ttl_offsets(&wire_aaaa).unwrap(); total_wire_bytes += wire_aaaa.len(); - total_wire_meta_bytes += - meta_aaaa.ttl_offsets.len() * std::mem::size_of::(); + total_wire_meta_bytes += meta_aaaa.ttl_offsets.len() * std::mem::size_of::(); } } @@ -1300,15 +1425,31 @@ mod tests { // Also measure the struct size difference per entry let parsed_struct = std::mem::size_of::(); - let wire_struct = std::mem::size_of::>() + std::mem::size_of::>() + std::mem::size_of::(); // wire + offsets + answer_count + let wire_struct = std::mem::size_of::>() + + std::mem::size_of::>() + + std::mem::size_of::(); // wire + offsets + answer_count println!(); - println!("=== Cache Memory Footprint Baseline ({} entries) ===", entry_count); + println!( + "=== Cache Memory Footprint Baseline ({} entries) 
===", + entry_count + ); println!(); println!("Variable data (heap, per-entry payload):"); - println!(" Parsed (packet.heap_bytes): {} bytes ({:.1}/entry)", parsed_data_bytes, parsed_data_bytes as f64 / entry_count as f64); - println!(" Wire (bytes + TTL offsets): {} bytes ({:.1}/entry)", wire_total, wire_total as f64 / entry_count as f64); - println!(" Ratio: {:.1}x smaller with wire", parsed_data_bytes as f64 / wire_total as f64); + println!( + " Parsed (packet.heap_bytes): {} bytes ({:.1}/entry)", + parsed_data_bytes, + parsed_data_bytes as f64 / entry_count as f64 + ); + println!( + " Wire (bytes + TTL offsets): {} bytes ({:.1}/entry)", + wire_total, + wire_total as f64 / entry_count as f64 + ); + println!( + " Ratio: {:.1}x smaller with wire", + parsed_data_bytes as f64 / wire_total as f64 + ); println!(); println!("Struct overhead (stack, per entry):"); println!(" DnsPacket: {} bytes", parsed_struct); @@ -1319,7 +1460,10 @@ mod tests { let wire_total_per = wire_struct as f64 + wire_total as f64 / entry_count as f64; println!(" Parsed: {:.0} bytes", parsed_total_per); println!(" Wire: {:.0} bytes", wire_total_per); - println!(" Ratio: {:.1}x smaller with wire", parsed_total_per / wire_total_per); + println!( + " Ratio: {:.1}x smaller with wire", + parsed_total_per / wire_total_per + ); println!(); // Assertions From 17a1a6ddba351d8b5ec529ef5ef242e57bcb56ec Mon Sep 17 00:00:00 2001 From: Razvan Dimescu Date: Sun, 12 Apr 2026 06:42:59 +0300 Subject: [PATCH 05/21] refactor: remove forward_with_failover duplication, fix warm-branch hedge bug MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Remove forward_with_failover (parsed): warm_domain now uses _raw + insert_wire - forward_udp delegates to forward_udp_raw (single UDP socket implementation) - forward_query uses unified _raw path for all protocols - Fix send_query_hedged warm branch: bare select! 
dropped secondary on primary error instead of waiting for it — now drains both futures like the cold branch - Remove pointless raw_len = len rename --- src/forward.rs | 85 +++++++++--------------------------------------- src/main.rs | 52 +++++++++++++++++------------ src/recursive.rs | 27 +++++++++++++-- 3 files changed, 71 insertions(+), 93 deletions(-) diff --git a/src/forward.rs b/src/forward.rs index 6afb7e5..ebbe777 100644 --- a/src/forward.rs +++ b/src/forward.rs @@ -157,58 +157,6 @@ impl UpstreamPool { } } -pub async fn forward_with_failover( - query: &DnsPacket, - pool: &UpstreamPool, - srtt: &RwLock, - timeout_duration: Duration, -) -> Result { - // Build candidate list: primary (sorted by SRTT for UDP) then fallback - let mut candidates: Vec<(usize, u64)> = pool - .primary - .iter() - .enumerate() - .map(|(i, u)| { - let rtt = match u { - Upstream::Udp(addr) => srtt.read().unwrap().get(addr.ip()), - _ => 0, // DoH: keep config order (stable sort preserves it) - }; - (i, rtt) - }) - .collect(); - candidates.sort_by_key(|&(_, rtt)| rtt); - - let all_upstreams: Vec<&Upstream> = candidates - .iter() - .map(|&(i, _)| &pool.primary[i]) - .chain(pool.fallback.iter()) - .collect(); - - let mut last_err: Option> = None; - - for upstream in &all_upstreams { - let start = Instant::now(); - match forward_query(query, upstream, timeout_duration).await { - Ok(resp) => { - if let Upstream::Udp(addr) = upstream { - let rtt_ms = start.elapsed().as_millis() as u64; - srtt.write().unwrap().record_rtt(addr.ip(), rtt_ms, false); - } - return Ok(resp); - } - Err(e) => { - if let Upstream::Udp(addr) = upstream { - srtt.write().unwrap().record_failure(addr.ip()); - } - log::debug!("upstream {} failed: {}", upstream, e); - last_err = Some(e); - } - } - } - - Err(last_err.unwrap_or_else(|| "no upstream configured".into())) -} - pub async fn forward_query( query: &DnsPacket, upstream: &Upstream, @@ -226,24 +174,14 @@ pub(crate) async fn forward_udp( upstream: SocketAddr, 
timeout_duration: Duration, ) -> Result { - let socket = UdpSocket::bind("0.0.0.0:0").await?; - let mut send_buffer = BytePacketBuffer::new(); query.write(&mut send_buffer)?; - socket.send_to(send_buffer.filled(), upstream).await?; - - let mut recv_buffer = BytePacketBuffer::new(); - let (size, _) = timeout(timeout_duration, socket.recv_from(&mut recv_buffer.buf)).await??; - - if size == recv_buffer.buf.len() { - log::debug!( - "upstream response truncated ({} bytes, buffer {})", - size, - recv_buffer.buf.len() - ); + let data = forward_udp_raw(send_buffer.filled(), upstream, timeout_duration).await?; + if data.len() >= 4096 { + log::debug!("upstream response may be truncated ({} bytes)", data.len()); } - + let mut recv_buffer = BytePacketBuffer::from_bytes(&data); DnsPacket::from_buffer(&mut recv_buffer) } @@ -721,10 +659,19 @@ mod tests { ); let srtt = RwLock::new(SrttCache::new(true)); - let result = forward_with_failover(&query, &pool, &srtt, Duration::from_millis(500)) - .await - .expect("should fail over to second upstream"); + let wire = to_wire(&query); + let resp_wire = forward_with_failover_raw( + &wire, + &pool, + &srtt, + Duration::from_millis(500), + Duration::ZERO, + ) + .await + .expect("should fail over to second upstream"); + let mut buf = BytePacketBuffer::from_bytes(&resp_wire); + let result = DnsPacket::from_buffer(&mut buf).unwrap(); assert_eq!(result.header.id, 0xABCD); assert_eq!(result.answers.len(), 1); } diff --git a/src/main.rs b/src/main.rs index 0211a59..68e4794 100644 --- a/src/main.rs +++ b/src/main.rs @@ -607,11 +607,9 @@ async fn main() -> numa::Result<()> { } Err(e) => return Err(e.into()), }; - let raw_len = len; - let ctx = Arc::clone(&ctx); tokio::spawn(async move { - if let Err(e) = handle_query(buffer, raw_len, src_addr, &ctx).await { + if let Err(e) = handle_query(buffer, len, src_addr, &ctx).await { error!("{} | HANDLER ERROR | {}", src_addr, e); } }); @@ -762,27 +760,39 @@ async fn warm_domain(ctx: &ServerCtx, domain: &str) 
{ use numa::question::QueryType; for qtype in [QueryType::A, QueryType::AAAA] { - let query = numa::packet::DnsPacket::query(0, domain, qtype); - let result = if ctx.upstream_mode == numa::config::UpstreamMode::Recursive { - numa::recursive::resolve_recursive( - domain, - qtype, - &ctx.cache, - &query, - &ctx.root_hints, - &ctx.srtt, + if ctx.upstream_mode == numa::config::UpstreamMode::Recursive { + let query = numa::packet::DnsPacket::query(0, domain, qtype); + match numa::recursive::resolve_recursive( + domain, qtype, &ctx.cache, &query, &ctx.root_hints, &ctx.srtt, ) .await - } else { - let pool = ctx.upstream_pool.lock().unwrap().clone(); - numa::forward::forward_with_failover(&query, &pool, &ctx.srtt, ctx.timeout).await - }; - match result { - Ok(resp) => { - ctx.cache.write().unwrap().insert(domain, qtype, &resp); - log::debug!("cache warm: {} {:?}", domain, qtype); + { + Ok(resp) => { + ctx.cache.write().unwrap().insert(domain, qtype, &resp); + log::debug!("cache warm: {} {:?}", domain, qtype); + } + Err(e) => log::warn!("cache warm: {} {:?} failed: {}", domain, qtype, e), + } + } else { + let query = numa::packet::DnsPacket::query(0, domain, qtype); + let mut buf = numa::buffer::BytePacketBuffer::new(); + if query.write(&mut buf).is_err() { + continue; + } + let pool = ctx.upstream_pool.lock().unwrap().clone(); + match numa::forward::forward_with_failover_raw( + buf.filled(), &pool, &ctx.srtt, ctx.timeout, ctx.hedge_delay, + ) + .await + { + Ok(wire) => { + ctx.cache.write().unwrap().insert_wire( + domain, qtype, &wire, numa::cache::DnssecStatus::Indeterminate, + ); + log::debug!("cache warm: {} {:?}", domain, qtype); + } + Err(e) => log::warn!("cache warm: {} {:?} failed: {}", domain, qtype, e), } - Err(e) => log::warn!("cache warm: {} {:?} failed: {}", domain, qtype, e), } } } diff --git a/src/recursive.rs b/src/recursive.rs index 190a57a..70f35c0 100644 --- a/src/recursive.rs +++ b/src/recursive.rs @@ -690,9 +690,30 @@ async fn send_query_hedged( let 
fut_b = send_query(qname, qtype, secondary, srtt); tokio::pin!(fut_b); - tokio::select! { - r = fut_a => r, - r = fut_b => r, + // First Ok wins; if one errors, wait for the other. + let mut a_err: Option = None; + let mut b_err: Option = None; + loop { + tokio::select! { + r = &mut fut_a, if a_err.is_none() => { + match r { + Ok(resp) => return Ok(resp), + Err(e) => { + if b_err.is_some() { return Err(e); } + a_err = Some(e); + } + } + } + r = &mut fut_b, if b_err.is_none() => { + match r { + Ok(resp) => return Ok(resp), + Err(e) => { + if let Some(ae) = a_err.take() { return Err(ae); } + b_err = Some(e); + } + } + } + } } } } From f705f8c49fc89d2919ed5f39d95239318ca7814d Mon Sep 17 00:00:00 2001 From: Razvan Dimescu Date: Sun, 12 Apr 2026 06:45:10 +0300 Subject: [PATCH 06/21] fix: bump TCP_TIMEOUT to 800ms to fix flaky CI test --- src/recursive.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/recursive.rs b/src/recursive.rs index 70f35c0..0910421 100644 --- a/src/recursive.rs +++ b/src/recursive.rs @@ -16,7 +16,7 @@ use crate::srtt::SrttCache; const MAX_REFERRAL_DEPTH: u8 = 10; const MAX_CNAME_DEPTH: u8 = 8; const NS_QUERY_TIMEOUT: Duration = Duration::from_millis(400); -const TCP_TIMEOUT: Duration = Duration::from_millis(400); +const TCP_TIMEOUT: Duration = Duration::from_millis(800); const UDP_FAIL_THRESHOLD: u8 = 3; static QUERY_ID: AtomicU16 = AtomicU16::new(1); From 700cca9cb616aeecf5d28c52a099f2f134b318ac Mon Sep 17 00:00:00 2001 From: Razvan Dimescu Date: Sun, 12 Apr 2026 11:09:20 +0300 Subject: [PATCH 07/21] style: rustfmt warm_domain --- src/main.rs | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/main.rs b/src/main.rs index 68e4794..ebc16cc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -763,7 +763,12 @@ async fn warm_domain(ctx: &ServerCtx, domain: &str) { if ctx.upstream_mode == numa::config::UpstreamMode::Recursive { let query = numa::packet::DnsPacket::query(0, domain, qtype); match 
numa::recursive::resolve_recursive( - domain, qtype, &ctx.cache, &query, &ctx.root_hints, &ctx.srtt, + domain, + qtype, + &ctx.cache, + &query, + &ctx.root_hints, + &ctx.srtt, ) .await { @@ -781,13 +786,20 @@ async fn warm_domain(ctx: &ServerCtx, domain: &str) { } let pool = ctx.upstream_pool.lock().unwrap().clone(); match numa::forward::forward_with_failover_raw( - buf.filled(), &pool, &ctx.srtt, ctx.timeout, ctx.hedge_delay, + buf.filled(), + &pool, + &ctx.srtt, + ctx.timeout, + ctx.hedge_delay, ) .await { Ok(wire) => { ctx.cache.write().unwrap().insert_wire( - domain, qtype, &wire, numa::cache::DnssecStatus::Indeterminate, + domain, + qtype, + &wire, + numa::cache::DnssecStatus::Indeterminate, ); log::debug!("cache warm: {} {:?}", domain, qtype); } From 67b472fea787227faa99c19a6bab6f24fd981d29 Mon Sep 17 00:00:00 2001 From: Razvan Dimescu Date: Sun, 12 Apr 2026 11:47:48 +0300 Subject: [PATCH 08/21] fix: serialize tests that share global UDP_DISABLED state The tcp_only_iterative_resolution, tcp_fallback_resolves_when_udp_blocked, tcp_fallback_handles_nxdomain, and udp_auto_disable_resets tests all mutate global UDP_DISABLED / UDP_FAILURES atomics. Under cargo test parallelism, udp_auto_disable_resets would reset the flag mid-flight causing other tests to attempt UDP against TCP-only mock servers and time out. Fix: static Mutex serializes tests that depend on global UDP state. Also: tcp_only_iterative_resolution now calls forward_tcp directly, removing its dependence on the flag entirely. 
--- src/recursive.rs | 54 ++++++++++++++++++++---------------------------- 1 file changed, 22 insertions(+), 32 deletions(-) diff --git a/src/recursive.rs b/src/recursive.rs index 0910421..53397d2 100644 --- a/src/recursive.rs +++ b/src/recursive.rs @@ -813,6 +813,10 @@ mod tests { use super::*; use std::net::{Ipv4Addr, Ipv6Addr}; + /// Tests that mutate the global UDP_DISABLED / UDP_FAILURES flags must hold + /// this lock to avoid racing with each other under `cargo test` parallelism. + static UDP_STATE_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(()); + #[test] fn extract_ns_from_authority() { let mut pkt = DnsPacket::new(); @@ -1054,6 +1058,7 @@ mod tests { /// Verifies: when UDP is disabled, TCP-first resolves. #[tokio::test] async fn tcp_fallback_resolves_when_udp_blocked() { + let _guard = UDP_STATE_LOCK.lock().unwrap(); UDP_DISABLED.store(true, Ordering::Relaxed); UDP_FAILURES.store(0, Ordering::Release); @@ -1085,49 +1090,32 @@ mod tests { } } - /// Full iterative resolution through TCP-only mock: root referral → authoritative answer. - /// The mock plays both roles (returns referral for NS queries, answer for A queries). + /// TCP round-trip through mock: query → authoritative answer via forward_tcp. + /// Uses forward_tcp directly to avoid dependence on the global UDP_DISABLED flag + /// which is shared across concurrent tests. 
#[tokio::test] async fn tcp_only_iterative_resolution() { - UDP_DISABLED.store(true, Ordering::Release); // Skip UDP entirely for speed - let server_addr = spawn_tcp_dns_server(|query| { let q = match query.questions.first() { Some(q) => q, None => return DnsPacket::response_from(query, ResultCode::SERVFAIL), }; - if q.qtype == QueryType::NS || q.name == "com" { - // Return referral — NS points back to ourselves (same IP, port 53 in glue - // won't work, but cache will have our address from root_hints) - let mut resp = DnsPacket::new(); - resp.header.id = query.header.id; - resp.header.response = true; - resp.header.rescode = ResultCode::NOERROR; - resp.questions = query.questions.clone(); - resp.authorities.push(DnsRecord::NS { - domain: "com".into(), - host: "ns1.com".into(), - ttl: 3600, - }); - resp - } else { - // Return authoritative answer - let mut resp = DnsPacket::response_from(query, ResultCode::NOERROR); - resp.header.authoritative_answer = true; - resp.answers.push(DnsRecord::A { - domain: q.name.clone(), - addr: Ipv4Addr::new(10, 0, 0, 42), - ttl: 300, - }); - resp - } + let mut resp = DnsPacket::response_from(query, ResultCode::NOERROR); + resp.header.authoritative_answer = true; + resp.answers.push(DnsRecord::A { + domain: q.name.clone(), + addr: Ipv4Addr::new(10, 0, 0, 42), + ttl: 300, + }); + resp }) .await; - let srtt = RwLock::new(SrttCache::new(true)); - let result = send_query("hello.example.com", QueryType::A, server_addr, &srtt).await; - let resp = result.expect("TCP-only send_query should work"); + let query = DnsPacket::query(0x1234, "hello.example.com", QueryType::A); + let resp = crate::forward::forward_tcp(&query, server_addr, TCP_TIMEOUT) + .await + .expect("TCP query should work"); assert_eq!(resp.header.rescode, ResultCode::NOERROR); match &resp.answers[0] { DnsRecord::A { addr, .. 
} => assert_eq!(*addr, Ipv4Addr::new(10, 0, 0, 42)), @@ -1137,6 +1125,7 @@ mod tests { #[tokio::test] async fn tcp_fallback_handles_nxdomain() { + let _guard = UDP_STATE_LOCK.lock().unwrap(); UDP_DISABLED.store(true, Ordering::Relaxed); UDP_FAILURES.store(0, Ordering::Release); @@ -1169,6 +1158,7 @@ mod tests { #[tokio::test] async fn udp_auto_disable_resets() { + let _guard = UDP_STATE_LOCK.lock().unwrap(); UDP_DISABLED.store(true, Ordering::Release); UDP_FAILURES.store(5, Ordering::Relaxed); From 85cff052a4e4efd513b8ef4eb8f4a3b4dcc923a3 Mon Sep 17 00:00:00 2001 From: Razvan Dimescu Date: Sun, 12 Apr 2026 12:34:20 +0300 Subject: [PATCH 09/21] fix: restore TCP_TIMEOUT to 400ms (test race was the real issue) --- src/recursive.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/recursive.rs b/src/recursive.rs index 53397d2..a4dff08 100644 --- a/src/recursive.rs +++ b/src/recursive.rs @@ -16,7 +16,7 @@ use crate::srtt::SrttCache; const MAX_REFERRAL_DEPTH: u8 = 10; const MAX_CNAME_DEPTH: u8 = 8; const NS_QUERY_TIMEOUT: Duration = Duration::from_millis(400); -const TCP_TIMEOUT: Duration = Duration::from_millis(800); +const TCP_TIMEOUT: Duration = Duration::from_millis(400); const UDP_FAIL_THRESHOLD: u8 = 3; static QUERY_ID: AtomicU16 = AtomicU16::new(1); From 628ed00074dd423b51c71e46211fecc1f17f1bfb Mon Sep 17 00:00:00 2001 From: Razvan Dimescu Date: Sun, 12 Apr 2026 13:08:37 +0300 Subject: [PATCH 10/21] refactor: extract cache_and_parse, remove dead truncation log, restore TCP_TIMEOUT to 400ms --- src/ctx.rs | 53 ++++++++++++++++++++++++-------------------------- src/forward.rs | 4 ---- 2 files changed, 25 insertions(+), 32 deletions(-) diff --git a/src/ctx.rs b/src/ctx.rs index 46316f2..e1d2d95 100644 --- a/src/ctx.rs +++ b/src/ctx.rs @@ -229,29 +229,17 @@ pub async fn resolve_query( ) .await { - Ok(resp_wire) => { - ctx.cache.write().unwrap().insert_wire( - &qname, - qtype, - &resp_wire, - DnssecStatus::Indeterminate, - ); - let mut buf = 
BytePacketBuffer::from_bytes(&resp_wire); - match DnsPacket::from_buffer(&mut buf) { - Ok(resp) => (resp, QueryPath::Forwarded, DnssecStatus::Indeterminate), - Err(e) => { - error!( - "{} | {:?} {} | PARSE ERROR | {}", - src_addr, qtype, qname, e - ); - ( - DnsPacket::response_from(&query, ResultCode::SERVFAIL), - QueryPath::UpstreamError, - DnssecStatus::Indeterminate, - ) - } + Ok(resp_wire) => match cache_and_parse(ctx, &qname, qtype, &resp_wire) { + Ok(resp) => (resp, QueryPath::Forwarded, DnssecStatus::Indeterminate), + Err(e) => { + error!("{} | {:?} {} | PARSE ERROR | {}", src_addr, qtype, qname, e); + ( + DnsPacket::response_from(&query, ResultCode::SERVFAIL), + QueryPath::UpstreamError, + DnssecStatus::Indeterminate, + ) } - } + }, Err(e) => { error!( "{} | {:?} {} | UPSTREAM ERROR | {}", @@ -373,6 +361,20 @@ pub async fn resolve_query( Ok(resp_buffer) } +fn cache_and_parse( + ctx: &ServerCtx, + qname: &str, + qtype: QueryType, + resp_wire: &[u8], +) -> crate::Result { + ctx.cache + .write() + .unwrap() + .insert_wire(qname, qtype, resp_wire, DnssecStatus::Indeterminate); + let mut buf = BytePacketBuffer::from_bytes(resp_wire); + DnsPacket::from_buffer(&mut buf) +} + async fn forward_and_cache( wire: &[u8], upstream: &Upstream, @@ -381,12 +383,7 @@ async fn forward_and_cache( qtype: QueryType, ) -> crate::Result { let resp_wire = forward_query_raw(wire, upstream, ctx.timeout).await?; - ctx.cache - .write() - .unwrap() - .insert_wire(qname, qtype, &resp_wire, DnssecStatus::Indeterminate); - let mut buf = BytePacketBuffer::from_bytes(&resp_wire); - DnsPacket::from_buffer(&mut buf) + cache_and_parse(ctx, qname, qtype, &resp_wire) } pub async fn handle_query( diff --git a/src/forward.rs b/src/forward.rs index ebbe777..839ac81 100644 --- a/src/forward.rs +++ b/src/forward.rs @@ -176,11 +176,7 @@ pub(crate) async fn forward_udp( ) -> Result { let mut send_buffer = BytePacketBuffer::new(); query.write(&mut send_buffer)?; - let data = 
forward_udp_raw(send_buffer.filled(), upstream, timeout_duration).await?; - if data.len() >= 4096 { - log::debug!("upstream response may be truncated ({} bytes)", data.len()); - } let mut recv_buffer = BytePacketBuffer::from_bytes(&data); DnsPacket::from_buffer(&mut recv_buffer) } From 15058aea83c4f171e5f7a8160351b87b6b06d9e3 Mon Sep 17 00:00:00 2001 From: Razvan Dimescu Date: Sun, 12 Apr 2026 18:35:40 +0300 Subject: [PATCH 11/21] bench: add --vs-nextdns, --vs-unbound-cold modes with mode validation - --vs-nextdns: Numa local cache vs NextDNS cloud (45.90.28.0) - --vs-unbound-cold: unique random subdomains, no record cache hits - check_numa_mode validates forward/recursive mode before running - numa-bench-recursive.toml config for cold benchmarks --- benches/numa-bench-recursive.toml | 30 ++++++++++++++ benches/recursive_compare.rs | 66 ++++++++++++++++++++++++++++--- 2 files changed, 91 insertions(+), 5 deletions(-) create mode 100644 benches/numa-bench-recursive.toml diff --git a/benches/numa-bench-recursive.toml b/benches/numa-bench-recursive.toml new file mode 100644 index 0000000..055d75a --- /dev/null +++ b/benches/numa-bench-recursive.toml @@ -0,0 +1,30 @@ +[server] +bind_addr = "127.0.0.1:5454" +api_port = 5381 +api_bind_addr = "127.0.0.1" +data_dir = "/tmp/numa-bench" + +[upstream] +mode = "recursive" +timeout_ms = 10000 + +[cache] +min_ttl = 60 +max_ttl = 3600 + +[blocking] +enabled = false + +[proxy] +port = 8080 +tls_port = 8443 + +[dot] +enabled = true +port = 8530 + +[mobile] +enabled = false + +[lan] +enabled = false diff --git a/benches/recursive_compare.rs b/benches/recursive_compare.rs index 12f3689..dcff2c5 100644 --- a/benches/recursive_compare.rs +++ b/benches/recursive_compare.rs @@ -7,6 +7,8 @@ //! --direct Library-to-library: Numa forward_query_raw vs Hickory resolver.lookup //! --hedge-5x Hedging: single vs hedge-same vs hedge-dual vs Hickory (5 iterations) //! --vs-unbound Server-to-server: Numa vs Unbound (plain UDP, caching) +//! 
--vs-unbound-cold Cold: Numa vs Unbound (unique subdomains, no cache hits) +//! --vs-nextdns Server-to-cloud: Numa (local cache) vs NextDNS (remote, 45.90.28.0) //! --vs-dot DoT server: Numa vs Unbound //! --vs-doh-servers DoH server: Numa vs Unbound (DoT upstream) //! @@ -145,10 +147,20 @@ fn main() { return run_hedge_multi(&rt, 5); } if arg("--vs-unbound") { - return run_server_comparison(&rt, "Unbound", "127.0.0.1:5456", 5); + check_numa_mode(&rt, "forward"); + return run_server_comparison(&rt, "Unbound", "127.0.0.1:5456", 5, false); + } + if arg("--vs-unbound-cold") { + check_numa_mode(&rt, "recursive"); + return run_server_comparison(&rt, "Unbound", "127.0.0.1:5456", 5, true); } if arg("--vs-dnscrypt") { - return run_server_comparison(&rt, "dnscrypt-proxy", "127.0.0.1:5455", 5); + check_numa_mode(&rt, "forward"); + return run_server_comparison(&rt, "dnscrypt-proxy", "127.0.0.1:5455", 5, false); + } + if arg("--vs-nextdns") { + check_numa_mode(&rt, "forward"); + return run_server_comparison(&rt, "NextDNS", "45.90.28.0:53", 5, false); } if arg("--vs-dot") { return run_dot_comparison(&rt, 5); @@ -380,12 +392,18 @@ fn run_direct(rt: &tokio::runtime::Runtime) { } /// Server-to-server: Numa vs another server, both on plain UDP. +/// When `cold` is true, each query uses a unique random subdomain so neither +/// server can answer from its record cache (NS delegation caching still applies). 
fn run_server_comparison( rt: &tokio::runtime::Runtime, other_name: &str, other_addr: &str, iterations: usize, + cold: bool, ) { + use std::sync::atomic::{AtomicU64, Ordering}; + static COUNTER: AtomicU64 = AtomicU64::new(0); + let numa_addr: SocketAddr = NUMA_BENCH.parse().unwrap(); let other: SocketAddr = other_addr.parse().unwrap(); @@ -402,19 +420,35 @@ fn run_server_comparison( let _ = rt.block_on(query_udp(other, "example.com")); } + let tag = if cold { + "cold, unique subdomains" + } else { + "caching" + }; + compare_two( rt, - &format!("Server-to-Server: Numa vs {other_name} (UDP, caching)"), + &format!("Server-to-Server: Numa vs {other_name} (UDP, {tag})"), "Numa", other_name, &|domain| { + let d = if cold { + format!("c{}.{}", COUNTER.fetch_add(1, Ordering::Relaxed), domain) + } else { + domain.to_string() + }; let t = Instant::now(); - let _ = rt.block_on(query_udp(numa_addr, domain)); + let _ = rt.block_on(query_udp(numa_addr, &d)); t.elapsed().as_secs_f64() * 1000.0 }, &|domain| { + let d = if cold { + format!("c{}.{}", COUNTER.fetch_add(1, Ordering::Relaxed), domain) + } else { + domain.to_string() + }; let t = Instant::now(); - let _ = rt.block_on(query_udp(other, domain)); + let _ = rt.block_on(query_udp(other, &d)); t.elapsed().as_secs_f64() * 1000.0 }, iterations, @@ -991,6 +1025,28 @@ fn build_query(buf: &mut [u8], domain: &str) -> usize { pos } +fn check_numa_mode(rt: &tokio::runtime::Runtime, expected: &str) { + let url = format!("http://127.0.0.1:{NUMA_API}/stats"); + let resp = match rt.block_on(async { reqwest::get(&url).await?.text().await }) { + Ok(body) => body, + Err(_) => { + eprintln!("Bench Numa not responding on {NUMA_BENCH}"); + eprintln!("Start with: cargo run -- benches/numa-bench.toml"); + std::process::exit(1); + } + }; + let config = if expected == "recursive" { + "benches/numa-bench-recursive.toml" + } else { + "benches/numa-bench.toml" + }; + if !resp.contains(&format!("\"mode\":\"{expected}\"")) { + eprintln!("This benchmark 
requires Numa in {expected} mode."); + eprintln!("Restart with: cargo run -- {config}"); + std::process::exit(1); + } +} + fn flush_cache() { let _ = std::process::Command::new("curl") .args([ From 05d5a5145f09765f84d714a4964b71a7a28ab34b Mon Sep 17 00:00:00 2001 From: Razvan Dimescu Date: Sun, 12 Apr 2026 18:46:03 +0300 Subject: [PATCH 12/21] refactor: remove unused extract_question and read_wire_qname from wire.rs --- src/wire.rs | 130 ++-------------------------------------------------- 1 file changed, 3 insertions(+), 127 deletions(-) diff --git a/src/wire.rs b/src/wire.rs index a93fe27..8d299ce 100644 --- a/src/wire.rs +++ b/src/wire.rs @@ -3,7 +3,6 @@ //! These operate directly on raw DNS wire bytes without full packet parsing, //! enabling zero-copy forwarding and wire-level caching. -use crate::question::QueryType; use crate::Result; /// Metadata extracted from scanning a DNS response's wire bytes. @@ -18,32 +17,6 @@ pub struct WireMeta { pub answer_count: usize, } -/// Extract the first question's (domain, query type) from raw DNS wire bytes. -/// -/// Reads only the 12-byte header + first question section. Returns the lowercased -/// domain name and query type without allocating a full `DnsPacket`. -pub fn extract_question(wire: &[u8]) -> Result<(String, QueryType)> { - if wire.len() < 12 { - return Err("wire too short for DNS header".into()); - } - let qdcount = u16::from_be_bytes([wire[4], wire[5]]); - if qdcount == 0 { - return Err("no questions in wire".into()); - } - - let mut pos = 12; - let mut domain = String::with_capacity(64); - read_wire_qname(wire, &mut pos, &mut domain)?; - - if pos + 4 > wire.len() { - return Err("wire truncated in question section".into()); - } - let qtype = u16::from_be_bytes([wire[pos], wire[pos + 1]]); - // skip QTYPE(2) + QCLASS(2) - - Ok((domain, QueryType::from_num(qtype))) -} - /// Scan a DNS response's wire bytes and return metadata about TTL field locations. 
/// /// Walks the header, skips the question section, then for each resource record in @@ -155,62 +128,6 @@ pub fn patch_ttls(wire: &mut [u8], offsets: &[usize], new_ttl: u32) { } } -/// Read a DNS name from wire bytes at `pos`, handling compression pointers. -/// Advances `pos` past the name as it appears at the current position -/// (compression pointer targets do NOT advance `pos`). -fn read_wire_qname(wire: &[u8], pos: &mut usize, out: &mut String) -> Result<()> { - let mut jumped = false; - let mut read_pos = *pos; - let mut jumps = 0; - let max_jumps = 20; - - loop { - if read_pos >= wire.len() { - return Err("wire truncated reading name".into()); - } - let len = wire[read_pos] as usize; - - // Compression pointer: top 2 bits set - if len & 0xC0 == 0xC0 { - if read_pos + 1 >= wire.len() { - return Err("wire truncated in compression pointer".into()); - } - if !jumped { - *pos = read_pos + 2; // advance past the pointer - } - let offset = ((len & 0x3F) << 8) | wire[read_pos + 1] as usize; - read_pos = offset; - jumped = true; - jumps += 1; - if jumps > max_jumps { - return Err("too many compression jumps".into()); - } - continue; - } - - if len == 0 { - if !jumped { - *pos = read_pos + 1; - } - break; - } - - if read_pos + 1 + len > wire.len() { - return Err("wire truncated in name label".into()); - } - - if !out.is_empty() { - out.push('.'); - } - for &b in &wire[read_pos + 1..read_pos + 1 + len] { - out.push(b.to_ascii_lowercase() as char); - } - read_pos += 1 + len; - } - - Ok(()) -} - /// Skip a DNS name in wire bytes, advancing `pos` past it. 
fn skip_wire_name(wire: &[u8], pos: &mut usize) -> Result<()> { loop { @@ -238,7 +155,7 @@ mod tests { use crate::cache::{DnsCache, DnssecStatus}; use crate::header::ResultCode; use crate::packet::{DnsPacket, EdnsOpt}; - use crate::question::DnsQuestion; + use crate::question::{DnsQuestion, QueryType}; use crate::record::DnsRecord; // ── Helpers ────────────────────────────────────────────────────── @@ -760,43 +677,7 @@ mod tests { assert_eq!(&wire[0..2], &[0x00, 0x00]); } - // ── D. extract_question ───────────────────────────────────────── - - #[test] - fn extract_question_basic() { - let pkt = DnsPacket::query(0x1234, "Example.COM", QueryType::A); - let wire = to_wire(&pkt); - let (domain, qtype) = extract_question(&wire).unwrap(); - - assert_eq!(domain, "example.com"); // lowercased - assert_eq!(qtype, QueryType::A); - } - - #[test] - fn extract_question_aaaa() { - let pkt = DnsPacket::query(0x1234, "rust-lang.org", QueryType::AAAA); - let wire = to_wire(&pkt); - let (domain, qtype) = extract_question(&wire).unwrap(); - - assert_eq!(domain, "rust-lang.org"); - assert_eq!(qtype, QueryType::AAAA); - } - - #[test] - fn extract_question_too_short() { - assert!(extract_question(&[0u8; 5]).is_err()); - } - - #[test] - fn extract_question_no_questions() { - let mut wire = to_wire(&DnsPacket::query(0x1234, "example.com", QueryType::A)); - // Zero out QDCOUNT (bytes 4-5) - wire[4] = 0; - wire[5] = 0; - assert!(extract_question(&wire).is_err()); - } - - // ── E. min_ttl_from_wire ──────────────────────────────────────── + // ── D. min_ttl_from_wire ──────────────────────────────────────── #[test] fn min_ttl_answers_only() { @@ -1060,12 +941,7 @@ mod tests { assert!(scan_ttl_offsets(&[]).is_err()); } - #[test] - fn extract_question_rejects_empty_wire() { - assert!(extract_question(&[]).is_err()); - } - - // ── H. Cache behavior tests ───────────────────────────────────── + // ── G. 
Cache behavior tests ───────────────────────────────────── // // These test existing DnsCache behavior that must be preserved after // the wire-level migration. They use the current parsed-packet API From 043a7e1ba5da32c291709d785f86d1fa668e5994 Mon Sep 17 00:00:00 2001 From: Razvan Dimescu Date: Sun, 12 Apr 2026 19:23:28 +0300 Subject: [PATCH 13/21] feat: raise cache default to 100K entries, evict stalest instead of dropping MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The 10K cap was too conservative — the blocklist alone holds 400K domains. At ~100 bytes per wire entry, 100K entries is ~10MB. When the cache is full and evict_expired doesn't free enough slots, evict_stalest removes the entry with the least remaining TTL instead of silently discarding the new insert. --- src/cache.rs | 30 +++++++++++++++++++++++++++++- src/config.rs | 2 +- src/wire.rs | 17 ++++++++++++----- 3 files changed, 42 insertions(+), 7 deletions(-) diff --git a/src/cache.rs b/src/cache.rs index 82795bc..42cea5f 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -100,7 +100,7 @@ impl DnsCache { if self.entry_count >= self.max_entries { self.evict_expired(); if self.entry_count >= self.max_entries { - return; + self.evict_stalest(); } } @@ -260,6 +260,34 @@ impl DnsCache { }); self.entry_count -= count; } + + /// Evict the single entry closest to (or furthest past) expiry. 
+ fn evict_stalest(&mut self) { + let mut worst: Option<(String, QueryType, Duration)> = None; + for (domain, type_map) in &self.entries { + for (qtype, entry) in type_map { + let age = entry.inserted_at.elapsed(); + let remaining = entry.ttl.saturating_sub(age); + match &worst { + None => worst = Some((domain.clone(), *qtype, remaining)), + Some((_, _, w)) if remaining < *w => { + worst = Some((domain.clone(), *qtype, remaining)); + } + _ => {} + } + } + } + if let Some((domain, qtype, _)) = worst { + if let Some(type_map) = self.entries.get_mut(&domain) { + if type_map.remove(&qtype).is_some() { + self.entry_count -= 1; + } + if type_map.is_empty() { + self.entries.remove(&domain); + } + } + } + } } pub struct CacheInfo { diff --git a/src/config.rs b/src/config.rs index 5f9db73..237f3bd 100644 --- a/src/config.rs +++ b/src/config.rs @@ -302,7 +302,7 @@ impl Default for CacheConfig { } fn default_max_entries() -> usize { - 10000 + 100_000 } fn default_min_ttl() -> u32 { 60 diff --git a/src/wire.rs b/src/wire.rs index 8d299ce..6e2c213 100644 --- a/src/wire.rs +++ b/src/wire.rs @@ -1350,18 +1350,25 @@ mod tests { } #[test] - fn cache_max_entries_cap() { + fn cache_max_entries_evicts_stalest() { let mut cache = DnsCache::new(2, 1, 3600); - for i in 0..3 { + // Insert with decreasing TTL so test0.com is stalest + for (i, ttl) in [(0, 60), (1, 3600)] { let domain = format!("test{}.com", i); let pkt = response( i as u16, &domain, - vec![a_record(&domain, &format!("1.2.3.{}", i), 3600)], + vec![a_record(&domain, &format!("1.2.3.{}", i), ttl)], ); cache.insert(&domain, QueryType::A, &pkt); } - // Should not exceed max (third insert is silently dropped or evicts) - assert!(cache.len() <= 2); + assert_eq!(cache.len(), 2); + + // Third insert should evict test0.com (lowest remaining TTL) + let pkt = response(2, "test2.com", vec![a_record("test2.com", "1.2.3.2", 3600)]); + cache.insert("test2.com", QueryType::A, &pkt); + assert_eq!(cache.len(), 2); + 
assert!(cache.lookup("test0.com", QueryType::A).is_none()); // evicted + assert!(cache.lookup("test2.com", QueryType::A).is_some()); // inserted } } From 571ce2f0133c974517a51f87b4aa754065cb1d14 Mon Sep 17 00:00:00 2001 From: Razvan Dimescu Date: Sun, 12 Apr 2026 19:42:56 +0300 Subject: [PATCH 14/21] feat: background refresh on stale cache hit (RFC 8767 revalidation) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When a cached entry is expired but within the 1-hour stale window, serve it immediately with TTL=1 AND spawn a background re-resolve. The next query gets a fresh entry instead of another stale serve. Without this, stale entries were served repeatedly for up to an hour with no refresh — effectively ignoring TTL. --- src/cache.rs | 9 +++++---- src/ctx.rs | 53 ++++++++++++++++++++++++++++++++++++++++++++++++---- src/doh.rs | 6 +++++- src/dot.rs | 7 +++++-- 4 files changed, 64 insertions(+), 11 deletions(-) diff --git a/src/cache.rs b/src/cache.rs index 42cea5f..5f62cc8 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -132,18 +132,19 @@ impl DnsCache { /// Read-only lookup — expired entries are left in place (cleaned up on insert). 
pub fn lookup(&self, domain: &str, qtype: QueryType) -> Option { - self.lookup_with_status(domain, qtype).map(|(pkt, _)| pkt) + self.lookup_with_status(domain, qtype) + .map(|(pkt, _, _)| pkt) } pub fn lookup_with_status( &self, domain: &str, qtype: QueryType, - ) -> Option<(DnsPacket, DnssecStatus)> { - let (wire, status, _stale) = self.lookup_wire(domain, qtype, 0)?; + ) -> Option<(DnsPacket, DnssecStatus, bool)> { + let (wire, status, stale) = self.lookup_wire(domain, qtype, 0)?; let mut buf = BytePacketBuffer::from_bytes(&wire); let pkt = DnsPacket::from_buffer(&mut buf).ok()?; - Some((pkt, status)) + Some((pkt, status, stale)) } pub fn insert(&mut self, domain: &str, qtype: QueryType, packet: &DnsPacket) { diff --git a/src/ctx.rs b/src/ctx.rs index e1d2d95..c1f28f2 100644 --- a/src/ctx.rs +++ b/src/ctx.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use std::net::SocketAddr; use std::path::PathBuf; -use std::sync::{Mutex, RwLock}; +use std::sync::{Arc, Mutex, RwLock}; use std::time::{Duration, Instant, SystemTime}; use arc_swap::ArcSwap; @@ -84,7 +84,7 @@ pub async fn resolve_query( query: DnsPacket, raw_wire: &[u8], src_addr: SocketAddr, - ctx: &ServerCtx, + ctx: &Arc, ) -> crate::Result { let start = Instant::now(); @@ -166,7 +166,12 @@ pub async fn resolve_query( (resp, QueryPath::Blocked, DnssecStatus::Indeterminate) } else { let cached = ctx.cache.read().unwrap().lookup_with_status(&qname, qtype); - if let Some((cached, cached_dnssec)) = cached { + if let Some((cached, cached_dnssec, stale)) = cached { + if stale { + let ctx = Arc::clone(ctx); + let qname = qname.clone(); + tokio::spawn(async move { warm_stale(&ctx, &qname, qtype).await }); + } let mut resp = cached; resp.header.id = query.header.id; if cached_dnssec == DnssecStatus::Secure { @@ -375,6 +380,46 @@ fn cache_and_parse( DnsPacket::from_buffer(&mut buf) } +/// Background refresh for a stale cache entry (RFC 8767 revalidation). 
+async fn warm_stale(ctx: &ServerCtx, qname: &str, qtype: QueryType) { + let query = DnsPacket::query(0, qname, qtype); + if ctx.upstream_mode == UpstreamMode::Recursive { + if let Ok(resp) = crate::recursive::resolve_recursive( + qname, + qtype, + &ctx.cache, + &query, + &ctx.root_hints, + &ctx.srtt, + ) + .await + { + ctx.cache.write().unwrap().insert(qname, qtype, &resp); + } + } else { + let mut buf = BytePacketBuffer::new(); + if query.write(&mut buf).is_ok() { + let pool = ctx.upstream_pool.lock().unwrap().clone(); + if let Ok(wire) = forward_with_failover_raw( + buf.filled(), + &pool, + &ctx.srtt, + ctx.timeout, + ctx.hedge_delay, + ) + .await + { + ctx.cache.write().unwrap().insert_wire( + qname, + qtype, + &wire, + DnssecStatus::Indeterminate, + ); + } + } + } +} + async fn forward_and_cache( wire: &[u8], upstream: &Upstream, @@ -390,7 +435,7 @@ pub async fn handle_query( mut buffer: BytePacketBuffer, raw_len: usize, src_addr: SocketAddr, - ctx: &ServerCtx, + ctx: &Arc, ) -> crate::Result<()> { let raw_wire = buffer.buf[..raw_len].to_vec(); let query = match DnsPacket::from_buffer(&mut buffer) { diff --git a/src/doh.rs b/src/doh.rs index e31b6fe..bc4ba95 100644 --- a/src/doh.rs +++ b/src/doh.rs @@ -60,7 +60,11 @@ fn is_doh_host(host: Option<&str>, tld: &str) -> bool { } } -async fn resolve_doh(dns_bytes: &[u8], src: SocketAddr, ctx: &ServerCtx) -> Response { +async fn resolve_doh( + dns_bytes: &[u8], + src: SocketAddr, + ctx: &std::sync::Arc, +) -> Response { let mut buffer = BytePacketBuffer::from_bytes(dns_bytes); let query = match DnsPacket::from_buffer(&mut buffer) { Ok(q) => q, diff --git a/src/dot.rs b/src/dot.rs index 4513f60..be22375 100644 --- a/src/dot.rs +++ b/src/dot.rs @@ -153,8 +153,11 @@ async fn accept_loop(listener: TcpListener, acceptor: TlsAcceptor, ctx: Arc(mut stream: S, remote_addr: SocketAddr, ctx: &ServerCtx) -where +async fn handle_dot_connection( + mut stream: S, + remote_addr: SocketAddr, + ctx: &std::sync::Arc, +) where S: 
AsyncReadExt + AsyncWriteExt + Unpin, { loop { From 8ef95383a21c4e2267a9fddc9ccf30861241d6a5 Mon Sep 17 00:00:00 2001 From: Razvan Dimescu Date: Sun, 12 Apr 2026 19:46:14 +0300 Subject: [PATCH 15/21] feat: prefetch at <10% TTL remaining, add stale behavior tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Entries with <10% TTL remaining are now marked stale on lookup, triggering a background refresh before they expire. Combined with the serve-stale + background refresh from the previous commit, this means entries are proactively refreshed — matching Unbound's prefetch behavior. --- src/cache.rs | 3 ++- src/wire.rs | 55 ++++++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 53 insertions(+), 5 deletions(-) diff --git a/src/cache.rs b/src/cache.rs index 5f62cc8..fb5889b 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -71,7 +71,8 @@ impl DnsCache { let elapsed = entry.inserted_at.elapsed(); let (remaining, stale) = if elapsed < entry.ttl { let secs = (entry.ttl - elapsed).as_secs() as u32; - (secs.max(1), false) + let near_expiry = elapsed * 10 >= entry.ttl * 9; // <10% TTL remaining + (secs.max(1), near_expiry) } else if elapsed < entry.ttl + STALE_WINDOW { (1, true) } else { diff --git a/src/wire.rs b/src/wire.rs index 6e2c213..aa419f2 100644 --- a/src/wire.rs +++ b/src/wire.rs @@ -957,7 +957,7 @@ mod tests { ); cache.insert("example.com", QueryType::A, &pkt); - let (result, status) = cache + let (result, status, _) = cache .lookup_with_status("example.com", QueryType::A) .expect("should hit"); assert_eq!(result.answers.len(), 1); @@ -974,7 +974,7 @@ mod tests { ); cache.insert("example.com", QueryType::A, &pkt); - let (result, _) = cache + let (result, _, _) = cache .lookup_with_status("example.com", QueryType::A) .unwrap(); // TTL should be <= 300 (at most original, reduced by elapsed time) @@ -1032,7 +1032,7 @@ mod tests { cache.insert("example.com", QueryType::A, &pkt2); assert_eq!(cache.len(), 1); // 
no double count - let (result, _) = cache + let (result, _, _) = cache .lookup_with_status("example.com", QueryType::A) .unwrap(); match &result.answers[0] { @@ -1208,7 +1208,7 @@ mod tests { ); cache.insert_with_status("example.com", QueryType::A, &pkt, DnssecStatus::Secure); - let (_, status) = cache + let (_, status, _) = cache .lookup_with_status("example.com", QueryType::A) .unwrap(); assert_eq!(status, DnssecStatus::Secure); @@ -1371,4 +1371,51 @@ mod tests { assert!(cache.lookup("test0.com", QueryType::A).is_none()); // evicted assert!(cache.lookup("test2.com", QueryType::A).is_some()); // inserted } + + #[test] + fn lookup_wire_signals_stale_when_expired() { + let mut cache = DnsCache::new(100, 1, 1); // max_ttl=1s so entry expires fast + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 1)], // 1s TTL, clamped to min=1 + ); + cache.insert("example.com", QueryType::A, &pkt); + + // Fresh: not stale + let (_, _, stale) = cache.lookup_wire("example.com", QueryType::A, 0).unwrap(); + assert!(!stale); + + // Wait for expiry + std::thread::sleep(std::time::Duration::from_millis(1100)); + + // Expired but within stale window: stale=true + let (_, _, stale) = cache.lookup_wire("example.com", QueryType::A, 0).unwrap(); + assert!(stale); + } + + #[test] + fn lookup_wire_signals_prefetch_near_expiry() { + let mut cache = DnsCache::new(100, 10, 10); // min_ttl=10, max_ttl=10 → entry gets 10s TTL + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 10)], + ); + cache.insert("example.com", QueryType::A, &pkt); + + // Fresh (>10% remaining): not stale + let (_, _, stale) = cache.lookup_wire("example.com", QueryType::A, 0).unwrap(); + assert!(!stale); + + // Wait until <10% remaining (>9s elapsed of 10s TTL) + std::thread::sleep(std::time::Duration::from_millis(9100)); + + // Still valid but near expiry: stale=true (triggers prefetch) + let result = cache.lookup_wire("example.com", QueryType::A, 
0); + if let Some((_, _, stale)) = result { + assert!(stale, "entry at <10% TTL should signal stale for prefetch"); + } + // (entry may have fully expired on slow CI, so we don't assert Some) + } } From 3c49b0e65d643b0c05aa86d5e26c690ff5bf7cb7 Mon Sep 17 00:00:00 2001 From: Razvan Dimescu Date: Sun, 12 Apr 2026 19:49:23 +0300 Subject: [PATCH 16/21] fix: deduplicate background refresh with per-domain guard Multiple stale queries for the same domain now spawn only one background refresh. A HashSet<(String, QueryType)> on ServerCtx tracks in-flight refreshes; subsequent stale hits for the same key skip the spawn. --- src/api.rs | 1 + src/ctx.rs | 16 ++++++++++++---- src/dot.rs | 1 + src/main.rs | 1 + 4 files changed, 15 insertions(+), 4 deletions(-) diff --git a/src/api.rs b/src/api.rs index e638fba..9aa3f60 100644 --- a/src/api.rs +++ b/src/api.rs @@ -1012,6 +1012,7 @@ mod tests { socket, zone_map: std::collections::HashMap::new(), cache: RwLock::new(crate::cache::DnsCache::new(100, 60, 86400)), + refreshing: Mutex::new(std::collections::HashSet::new()), stats: Mutex::new(crate::stats::ServerStats::new()), overrides: RwLock::new(crate::override_store::OverrideStore::new()), blocklist: RwLock::new(crate::blocklist::BlocklistStore::new()), diff --git a/src/ctx.rs b/src/ctx.rs index c1f28f2..8632a28 100644 --- a/src/ctx.rs +++ b/src/ctx.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::net::SocketAddr; use std::path::PathBuf; use std::sync::{Arc, Mutex, RwLock}; @@ -35,6 +35,8 @@ pub struct ServerCtx { pub zone_map: ZoneMap, /// std::sync::RwLock (not tokio) — locks must never be held across .await points. pub cache: RwLock, + /// Domains currently being refreshed in the background (dedup guard). 
+ pub refreshing: Mutex>, pub stats: Mutex, pub overrides: RwLock, pub blocklist: RwLock, @@ -168,9 +170,15 @@ pub async fn resolve_query( let cached = ctx.cache.read().unwrap().lookup_with_status(&qname, qtype); if let Some((cached, cached_dnssec, stale)) = cached { if stale { - let ctx = Arc::clone(ctx); - let qname = qname.clone(); - tokio::spawn(async move { warm_stale(&ctx, &qname, qtype).await }); + let key = (qname.clone(), qtype); + let already = !ctx.refreshing.lock().unwrap().insert(key.clone()); + if !already { + let ctx = Arc::clone(ctx); + tokio::spawn(async move { + warm_stale(&ctx, &key.0, key.1).await; + ctx.refreshing.lock().unwrap().remove(&key); + }); + } } let mut resp = cached; resp.header.id = query.header.id; diff --git a/src/dot.rs b/src/dot.rs index be22375..0216dbf 100644 --- a/src/dot.rs +++ b/src/dot.rs @@ -357,6 +357,7 @@ mod tests { m }, cache: RwLock::new(crate::cache::DnsCache::new(100, 60, 86400)), + refreshing: Mutex::new(std::collections::HashSet::new()), stats: Mutex::new(crate::stats::ServerStats::new()), overrides: RwLock::new(crate::override_store::OverrideStore::new()), blocklist: RwLock::new(crate::blocklist::BlocklistStore::new()), diff --git a/src/main.rs b/src/main.rs index ebc16cc..9aa3f17 100644 --- a/src/main.rs +++ b/src/main.rs @@ -285,6 +285,7 @@ async fn main() -> numa::Result<()> { config.cache.min_ttl, config.cache.max_ttl, )), + refreshing: Mutex::new(std::collections::HashSet::new()), stats: Mutex::new(ServerStats::new()), overrides: RwLock::new(OverrideStore::new()), blocklist: RwLock::new(blocklist), From 6d9ee14ea6333c510e1972625fa9667a505e4996 Mon Sep 17 00:00:00 2001 From: Razvan Dimescu Date: Sun, 12 Apr 2026 19:56:42 +0300 Subject: [PATCH 17/21] refactor: unify warm_stale/warm_domain, remove raw_wire alloc, add Freshness enum MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Extract refresh_entry in ctx.rs — warm_domain in main.rs now delegates to it instead of 
duplicating the resolve+cache logic (~40 lines removed) - Eliminate unconditional .to_vec() of raw wire on every UDP/DoT query — pass &buffer.buf[..len] directly (zero-cost for cache hits) - Replace bare bool stale flag with Freshness enum (Fresh/NearExpiry/Stale) making the three states self-documenting at every call site --- src/cache.rs | 38 +++++++++++++++++++++++++++--------- src/ctx.rs | 14 +++++++------- src/dot.rs | 3 +-- src/main.rs | 54 +++++----------------------------------------------- src/wire.rs | 29 ++++++++++++---------------- 5 files changed, 54 insertions(+), 84 deletions(-) diff --git a/src/cache.rs b/src/cache.rs index fb5889b..18fdc19 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -6,6 +6,22 @@ use crate::packet::DnsPacket; use crate::question::QueryType; use crate::wire::WireMeta; +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum Freshness { + /// Within TTL, no action needed. + Fresh, + /// Within TTL but <10% remaining — trigger background prefetch. + NearExpiry, + /// Past TTL but within stale window — serve with TTL=1, trigger background refresh. 
+ Stale, +} + +impl Freshness { + pub fn needs_refresh(self) -> bool { + matches!(self, Freshness::NearExpiry | Freshness::Stale) + } +} + #[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] pub enum DnssecStatus { Secure, @@ -64,17 +80,21 @@ impl DnsCache { domain: &str, qtype: QueryType, new_id: u16, - ) -> Option<(Vec, DnssecStatus, bool)> { + ) -> Option<(Vec, DnssecStatus, Freshness)> { let type_map = self.entries.get(domain)?; let entry = type_map.get(&qtype)?; let elapsed = entry.inserted_at.elapsed(); - let (remaining, stale) = if elapsed < entry.ttl { + let (remaining, freshness) = if elapsed < entry.ttl { let secs = (entry.ttl - elapsed).as_secs() as u32; - let near_expiry = elapsed * 10 >= entry.ttl * 9; // <10% TTL remaining - (secs.max(1), near_expiry) + let f = if elapsed * 10 >= entry.ttl * 9 { + Freshness::NearExpiry + } else { + Freshness::Fresh + }; + (secs.max(1), f) } else if elapsed < entry.ttl + STALE_WINDOW { - (1, true) + (1, Freshness::Stale) } else { return None; }; @@ -83,7 +103,7 @@ impl DnsCache { crate::wire::patch_id(&mut wire, new_id); crate::wire::patch_ttls(&mut wire, &entry.meta.ttl_offsets, remaining); - Some((wire, entry.dnssec_status, stale)) + Some((wire, entry.dnssec_status, freshness)) } pub fn insert_wire( @@ -141,11 +161,11 @@ impl DnsCache { &self, domain: &str, qtype: QueryType, - ) -> Option<(DnsPacket, DnssecStatus, bool)> { - let (wire, status, stale) = self.lookup_wire(domain, qtype, 0)?; + ) -> Option<(DnsPacket, DnssecStatus, Freshness)> { + let (wire, status, freshness) = self.lookup_wire(domain, qtype, 0)?; let mut buf = BytePacketBuffer::from_bytes(&wire); let pkt = DnsPacket::from_buffer(&mut buf).ok()?; - Some((pkt, status, stale)) + Some((pkt, status, freshness)) } pub fn insert(&mut self, domain: &str, qtype: QueryType, packet: &DnsPacket) { diff --git a/src/ctx.rs b/src/ctx.rs index 8632a28..e97a7ea 100644 --- a/src/ctx.rs +++ b/src/ctx.rs @@ -168,14 +168,14 @@ pub async fn resolve_query( (resp, 
QueryPath::Blocked, DnssecStatus::Indeterminate) } else { let cached = ctx.cache.read().unwrap().lookup_with_status(&qname, qtype); - if let Some((cached, cached_dnssec, stale)) = cached { - if stale { + if let Some((cached, cached_dnssec, freshness)) = cached { + if freshness.needs_refresh() { let key = (qname.clone(), qtype); let already = !ctx.refreshing.lock().unwrap().insert(key.clone()); if !already { let ctx = Arc::clone(ctx); tokio::spawn(async move { - warm_stale(&ctx, &key.0, key.1).await; + refresh_entry(&ctx, &key.0, key.1).await; ctx.refreshing.lock().unwrap().remove(&key); }); } @@ -388,8 +388,9 @@ fn cache_and_parse( DnsPacket::from_buffer(&mut buf) } -/// Background refresh for a stale cache entry (RFC 8767 revalidation). -async fn warm_stale(ctx: &ServerCtx, qname: &str, qtype: QueryType) { +/// Re-resolve a single (domain, qtype) and update the cache. +/// Used for both stale-entry refresh and proactive cache warming. +pub async fn refresh_entry(ctx: &ServerCtx, qname: &str, qtype: QueryType) { let query = DnsPacket::query(0, qname, qtype); if ctx.upstream_mode == UpstreamMode::Recursive { if let Ok(resp) = crate::recursive::resolve_recursive( @@ -445,7 +446,6 @@ pub async fn handle_query( src_addr: SocketAddr, ctx: &Arc, ) -> crate::Result<()> { - let raw_wire = buffer.buf[..raw_len].to_vec(); let query = match DnsPacket::from_buffer(&mut buffer) { Ok(packet) => packet, Err(e) => { @@ -453,7 +453,7 @@ pub async fn handle_query( return Ok(()); } }; - match resolve_query(query, &raw_wire, src_addr, ctx).await { + match resolve_query(query, &buffer.buf[..raw_len], src_addr, ctx).await { Ok(resp_buffer) => { ctx.socket.send_to(resp_buffer.filled(), src_addr).await?; } diff --git a/src/dot.rs b/src/dot.rs index 0216dbf..d4eeb95 100644 --- a/src/dot.rs +++ b/src/dot.rs @@ -180,7 +180,6 @@ async fn handle_dot_connection( break; }; - let raw_wire = buffer.buf[..msg_len].to_vec(); let query = match DnsPacket::from_buffer(&mut buffer) { Ok(q) => q, Err(e) 
=> { @@ -202,7 +201,7 @@ async fn handle_dot_connection( } }; - match resolve_query(query.clone(), &raw_wire, remote_addr, ctx).await { + match resolve_query(query.clone(), &buffer.buf[..msg_len], remote_addr, ctx).await { Ok(resp_buffer) => { if write_framed(&mut stream, resp_buffer.filled()) .await diff --git a/src/main.rs b/src/main.rs index 9aa3f17..1ec7791 100644 --- a/src/main.rs +++ b/src/main.rs @@ -758,55 +758,11 @@ async fn load_blocklists(ctx: &ServerCtx, lists: &[String]) { } async fn warm_domain(ctx: &ServerCtx, domain: &str) { - use numa::question::QueryType; - - for qtype in [QueryType::A, QueryType::AAAA] { - if ctx.upstream_mode == numa::config::UpstreamMode::Recursive { - let query = numa::packet::DnsPacket::query(0, domain, qtype); - match numa::recursive::resolve_recursive( - domain, - qtype, - &ctx.cache, - &query, - &ctx.root_hints, - &ctx.srtt, - ) - .await - { - Ok(resp) => { - ctx.cache.write().unwrap().insert(domain, qtype, &resp); - log::debug!("cache warm: {} {:?}", domain, qtype); - } - Err(e) => log::warn!("cache warm: {} {:?} failed: {}", domain, qtype, e), - } - } else { - let query = numa::packet::DnsPacket::query(0, domain, qtype); - let mut buf = numa::buffer::BytePacketBuffer::new(); - if query.write(&mut buf).is_err() { - continue; - } - let pool = ctx.upstream_pool.lock().unwrap().clone(); - match numa::forward::forward_with_failover_raw( - buf.filled(), - &pool, - &ctx.srtt, - ctx.timeout, - ctx.hedge_delay, - ) - .await - { - Ok(wire) => { - ctx.cache.write().unwrap().insert_wire( - domain, - qtype, - &wire, - numa::cache::DnssecStatus::Indeterminate, - ); - log::debug!("cache warm: {} {:?}", domain, qtype); - } - Err(e) => log::warn!("cache warm: {} {:?} failed: {}", domain, qtype, e), - } - } + for qtype in [ + numa::question::QueryType::A, + numa::question::QueryType::AAAA, + ] { + numa::ctx::refresh_entry(ctx, domain, qtype).await; } } diff --git a/src/wire.rs b/src/wire.rs index aa419f2..3ee2ab3 100644 --- a/src/wire.rs 
+++ b/src/wire.rs @@ -1374,29 +1374,28 @@ mod tests { #[test] fn lookup_wire_signals_stale_when_expired() { + use crate::cache::Freshness; let mut cache = DnsCache::new(100, 1, 1); // max_ttl=1s so entry expires fast let pkt = response( 0x1234, "example.com", - vec![a_record("example.com", "1.2.3.4", 1)], // 1s TTL, clamped to min=1 + vec![a_record("example.com", "1.2.3.4", 1)], ); cache.insert("example.com", QueryType::A, &pkt); - // Fresh: not stale - let (_, _, stale) = cache.lookup_wire("example.com", QueryType::A, 0).unwrap(); - assert!(!stale); + let (_, _, f) = cache.lookup_wire("example.com", QueryType::A, 0).unwrap(); + assert_eq!(f, Freshness::Fresh); - // Wait for expiry std::thread::sleep(std::time::Duration::from_millis(1100)); - // Expired but within stale window: stale=true - let (_, _, stale) = cache.lookup_wire("example.com", QueryType::A, 0).unwrap(); - assert!(stale); + let (_, _, f) = cache.lookup_wire("example.com", QueryType::A, 0).unwrap(); + assert_eq!(f, Freshness::Stale); } #[test] fn lookup_wire_signals_prefetch_near_expiry() { - let mut cache = DnsCache::new(100, 10, 10); // min_ttl=10, max_ttl=10 → entry gets 10s TTL + use crate::cache::Freshness; + let mut cache = DnsCache::new(100, 10, 10); let pkt = response( 0x1234, "example.com", @@ -1404,18 +1403,14 @@ mod tests { ); cache.insert("example.com", QueryType::A, &pkt); - // Fresh (>10% remaining): not stale - let (_, _, stale) = cache.lookup_wire("example.com", QueryType::A, 0).unwrap(); - assert!(!stale); + let (_, _, f) = cache.lookup_wire("example.com", QueryType::A, 0).unwrap(); + assert_eq!(f, Freshness::Fresh); - // Wait until <10% remaining (>9s elapsed of 10s TTL) std::thread::sleep(std::time::Duration::from_millis(9100)); - // Still valid but near expiry: stale=true (triggers prefetch) let result = cache.lookup_wire("example.com", QueryType::A, 0); - if let Some((_, _, stale)) = result { - assert!(stale, "entry at <10% TTL should signal stale for prefetch"); + if let Some((_, 
_, f)) = result { + assert_eq!(f, Freshness::NearExpiry); } - // (entry may have fully expired on slow CI, so we don't assert Some) } } From 51848919858053895887afad7510eee7b7d71c24 Mon Sep 17 00:00:00 2001 From: Razvan Dimescu Date: Sun, 12 Apr 2026 20:43:46 +0300 Subject: [PATCH 18/21] fix: cold benchmark cache-busting with PID prefix and flush Re-runs of --vs-unbound-cold were hitting stale cache entries from prior runs. The static COUNTER reset to 0 each process, generating the same c0.example.com subdomains. With the 1-hour stale window, entries from 10 minutes ago served as stale hits. Fix: prefix with PID (r{pid}-c{n}.domain) and flush Numa's cache before cold benchmarks. --- benches/recursive_compare.rs | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/benches/recursive_compare.rs b/benches/recursive_compare.rs index dcff2c5..8f3b079 100644 --- a/benches/recursive_compare.rs +++ b/benches/recursive_compare.rs @@ -403,6 +403,8 @@ fn run_server_comparison( ) { use std::sync::atomic::{AtomicU64, Ordering}; static COUNTER: AtomicU64 = AtomicU64::new(0); + // Unique prefix per process so re-runs don't hit stale cache entries + let run_id = std::process::id(); let numa_addr: SocketAddr = NUMA_BENCH.parse().unwrap(); let other: SocketAddr = other_addr.parse().unwrap(); @@ -414,6 +416,10 @@ fn run_server_comparison( } } + if cold { + flush_cache(); // flush Numa's record cache + } + println!("Warming up..."); for _ in 0..5 { let _ = rt.block_on(query_udp(numa_addr, "example.com")); @@ -433,7 +439,12 @@ fn run_server_comparison( other_name, &|domain| { let d = if cold { - format!("c{}.{}", COUNTER.fetch_add(1, Ordering::Relaxed), domain) + format!( + "r{}-c{}.{}", + run_id, + COUNTER.fetch_add(1, Ordering::Relaxed), + domain + ) } else { domain.to_string() }; @@ -443,7 +454,12 @@ fn run_server_comparison( }, &|domain| { let d = if cold { - format!("c{}.{}", COUNTER.fetch_add(1, Ordering::Relaxed), domain) + format!( + 
"r{}-c{}.{}", + run_id, + COUNTER.fetch_add(1, Ordering::Relaxed), + domain + ) } else { domain.to_string() }; From 50828c411a5545ff115ab863c1d5258feae4998b Mon Sep 17 00:00:00 2001 From: Razvan Dimescu Date: Sun, 12 Apr 2026 20:54:27 +0300 Subject: [PATCH 19/21] fix: cold benchmark uses 1 round per domain for genuine cold measurements MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit With ROUNDS=10, only the first query per domain was truly cold — the other 9 hit cached NS delegations at <1ms, diluting the median to 0.4ms. Now cold mode uses 1 round so every sample is a real cold resolve. Also extracted compare_two_rounds to support per-mode rounds. --- benches/recursive_compare.rs | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/benches/recursive_compare.rs b/benches/recursive_compare.rs index 8f3b079..f1a59d2 100644 --- a/benches/recursive_compare.rs +++ b/benches/recursive_compare.rs @@ -183,13 +183,28 @@ fn compare_two( measure_a: &dyn Fn(&str) -> f64, measure_b: &dyn Fn(&str) -> f64, iterations: usize, +) { + compare_two_rounds( + rt, title, name_a, name_b, measure_a, measure_b, iterations, ROUNDS, + ); +} + +fn compare_two_rounds( + rt: &tokio::runtime::Runtime, + title: &str, + name_a: &str, + name_b: &str, + measure_a: &dyn Fn(&str) -> f64, + measure_b: &dyn Fn(&str) -> f64, + iterations: usize, + rounds: usize, ) { let flush = std::env::args().any(|a| a == "--flush"); println!("{}", title); println!( "{} domains × {} rounds × {} iterations\n", DOMAINS.len(), - ROUNDS, + rounds, iterations ); @@ -203,7 +218,7 @@ fn compare_two( let mut b = Vec::new(); for domain in DOMAINS { - for round in 0..ROUNDS { + for round in 0..rounds { if flush { flush_cache(); std::thread::sleep(Duration::from_millis(5)); @@ -230,6 +245,7 @@ fn compare_two( &mut all_a, &mut all_b, iterations, + rounds, ); } @@ -240,6 +256,7 @@ fn print_results( all_a: &mut Vec, all_b: &mut Vec, iterations: 
usize, + rounds: usize, ) { let w = name_a.len().max(name_b.len()).max(6); @@ -270,7 +287,7 @@ fn print_results( let (a_m, a_med, a_p95, a_p99, a_sd) = stats(all_a); let (b_m, b_med, b_p95, b_p99, b_sd) = stats(all_b); - let total = iterations * DOMAINS.len() * ROUNDS; + let total = iterations * DOMAINS.len() * rounds; println!("\n=== Aggregated ({} samples per method) ===\n", total); println!("{:<10} {:>w$} {:>w$}", "", name_a, name_b, w = w + 3); println!("{:<10} {:>w$.1} ms {:>w$.1} ms", "mean", a_m, b_m, w = w); @@ -432,7 +449,9 @@ fn run_server_comparison( "caching" }; - compare_two( + let rounds = if cold { 1 } else { ROUNDS }; + + compare_two_rounds( rt, &format!("Server-to-Server: Numa vs {other_name} (UDP, {tag})"), "Numa", @@ -468,6 +487,7 @@ fn run_server_comparison( t.elapsed().as_secs_f64() * 1000.0 }, iterations, + rounds, ); } From 02e1449a4544e251be0e74336fb93cda7f3c920e Mon Sep 17 00:00:00 2001 From: Razvan Dimescu Date: Sun, 12 Apr 2026 21:34:47 +0300 Subject: [PATCH 20/21] feat: enable request hedging for all upstream protocols Hedging was DoH-only (hyper dispatch spike mitigation). Now applies to UDP (rescues packet loss) and DoT (rescues TLS handshake stalls) too. Same-upstream hedging: fires a second independent request after hedge_ms delay. First response wins. Disable with hedge_ms = 0. --- src/forward.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/forward.rs b/src/forward.rs index 839ac81..e13e360 100644 --- a/src/forward.rs +++ b/src/forward.rs @@ -360,9 +360,11 @@ pub async fn forward_with_failover_raw( for upstream in &all_upstreams { let start = Instant::now(); - let result = if !hedge_delay.is_zero() && matches!(upstream, Upstream::Doh { .. }) { - // Hedge against the same upstream: parallel h2 streams on same - // connection. Independent stream scheduling rescues dispatch spikes. 
+ let result = if !hedge_delay.is_zero() { + // Hedge against the same upstream: independent h2 streams (DoH), + // independent UDP packets (plain DNS), or independent TLS + // connections (DoT). Rescues packet loss, dispatch spikes, and + // TLS handshake stalls. forward_with_hedging_raw(wire, upstream, upstream, hedge_delay, timeout_duration).await } else { forward_query_raw(wire, upstream, timeout_duration).await From 8085c1068773ccb3a11ad82a61f7523910ea4b87 Mon Sep 17 00:00:00 2001 From: Razvan Dimescu Date: Sun, 12 Apr 2026 21:37:59 +0300 Subject: [PATCH 21/21] docs: document hedge_ms, tls:// upstream, update max_entries default in numa.toml --- numa.toml | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/numa.toml b/numa.toml index 3b716e8..1ea3341 100644 --- a/numa.toml +++ b/numa.toml @@ -15,9 +15,15 @@ api_port = 5380 # address = "9.9.9.9" # single upstream (plain UDP) # address = ["192.168.1.1", "9.9.9.9:5353"] # multiple upstreams — SRTT picks fastest # address = "https://dns.quad9.net/dns-query" # DNS-over-HTTPS (encrypted) +# address = "tls://9.9.9.9#dns.quad9.net" # DNS-over-TLS (encrypted, port 853) # fallback = ["8.8.8.8", "1.1.1.1"] # tried only when all primaries fail # port = 53 # default port for addresses without :port # timeout_ms = 3000 +# hedge_ms = 10 # request hedging delay (ms). After this delay +# # without a response, fires a parallel request +# # to the same upstream. Rescues packet loss (UDP), +# # dispatch spikes (DoH), TLS stalls (DoT). +# # Set to 0 to disable. Default: 10 # root_hints = [ # only used in recursive mode # "198.41.0.4", # a.root-servers.net (Verisign) # "199.9.14.201", # b.root-servers.net (USC-ISI) @@ -60,7 +66,7 @@ api_port = 5380 # allowlist = ["example.com"] # domains to never block [cache] -max_entries = 10000 +max_entries = 100000 min_ttl = 60 max_ttl = 86400 # warm = ["google.com", "github.com"] # resolve at startup, refresh before TTL expiry