From 7efac85836bacd483e174e5509fad03bac3f548f Mon Sep 17 00:00:00 2001 From: Razvan Dimescu Date: Sun, 12 Apr 2026 04:20:18 +0300 Subject: [PATCH] feat: wire-level forwarding, cache, request hedging, and DoH keepalive MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Wire-level forwarding path skips DnsPacket parse/serialize on the hot path. Cache stores raw wire bytes with pre-scanned TTL offsets — patches ID + TTLs in-place on lookup instead of cloning parsed packets. Request hedging (Dean & Barroso "Tail at Scale") fires a second parallel request after a configurable delay (default 10ms) when the primary upstream stalls. DoH keepalive loop prevents idle HTTP/2 + TLS connection teardown. Recursive resolver now hedges across multiple NS addresses and caches NS delegation records to skip TLD re-queries. Integration test harness polls /blocking/stats instead of fixed sleep, eliminating the blocklist-download race condition. --- Cargo.lock | 458 +++++++++- Cargo.toml | 6 + benches/numa-bench.toml | 25 + benches/recursive_compare.rs | 1649 ++++++++++++++++++++++++++++++++++ scripts/bench-recursive.sh | 115 +++ src/api.rs | 1 + src/cache.rs | 177 ++-- src/config.rs | 6 + src/ctx.rs | 47 +- src/doh.rs | 11 +- src/dot.rs | 6 +- src/forward.rs | 186 +++- src/lib.rs | 1 + src/main.rs | 26 +- src/recursive.rs | 123 ++- src/srtt.rs | 5 + src/wire.rs | 1347 +++++++++++++++++++++++++++ tests/integration.sh | 12 +- 18 files changed, 4091 insertions(+), 110 deletions(-) create mode 100644 benches/numa-bench.toml create mode 100644 benches/recursive_compare.rs create mode 100755 scripts/bench-recursive.sh create mode 100644 src/wire.rs diff --git a/Cargo.lock b/Cargo.lock index c7cd38b..eaba214 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -82,6 +82,12 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "anyhow" +version = "1.0.102" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" + [[package]] name = "arc-swap" version = "1.9.0" @@ -142,6 +148,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "async-trait" +version = "0.1.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "atomic-waker" version = "1.1.2" @@ -410,6 +427,21 @@ dependencies = [ "itertools", ] +[[package]] +name = "critical-section" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "790eea4361631c5e7d22598ecd5723ff611904e3344ce8720784c93e3d83d40b" + +[[package]] +name = "crossbeam-channel" +version = "0.5.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-deque" version = "0.8.6" @@ -493,6 +525,18 @@ version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +[[package]] +name = "enum-as-inner" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1e6a265c649f3f5979b601d26f1d05ada116434c87741c9493cb56218f76cbc" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "env_filter" version = "1.0.1" @@ -554,6 +598,12 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + [[package]] name = "form_urlencoded" version = "1.2.2" @@ -679,11 +729,24 
@@ dependencies = [ "cfg-if", "js-sys", "libc", - "r-efi", + "r-efi 5.3.0", "wasip2", "wasm-bindgen", ] +[[package]] +name = "getrandom" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0de51e6874e94e7bf76d726fc5d13ba782deca734ff60d5bb2fb2607c7406555" +dependencies = [ + "cfg-if", + "libc", + "r-efi 6.0.0", + "wasip2", + "wasip3", +] + [[package]] name = "h2" version = "0.4.13" @@ -714,12 +777,82 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "foldhash", +] + [[package]] name = "hashbrown" version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "hickory-proto" +version = "0.25.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8a6fe56c0038198998a6f217ca4e7ef3a5e51f46163bd6dd60b5c71ca6c6502" +dependencies = [ + "async-trait", + "bytes", + "cfg-if", + "data-encoding", + "enum-as-inner", + "futures-channel", + "futures-io", + "futures-util", + "h2", + "http", + "idna", + "ipnet", + "once_cell", + "rand", + "ring", + "rustls", + "thiserror", + "tinyvec", + "tokio", + "tokio-rustls", + "tracing", + "url", + "webpki-roots 0.26.11", +] + +[[package]] +name = "hickory-resolver" +version = "0.25.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc62a9a99b0bfb44d2ab95a7208ac952d31060efc16241c87eaf36406fecf87a" +dependencies = [ + "cfg-if", + "futures-util", + "hickory-proto", + "ipconfig", + "moka", + "once_cell", + "parking_lot", + "rand", + 
"resolv-conf", + "rustls", + "smallvec", + "thiserror", + "tokio", + "tokio-rustls", + "tracing", + "webpki-roots 0.26.11", +] + [[package]] name = "http" version = "1.4.0" @@ -802,7 +935,7 @@ dependencies = [ "tokio", "tokio-rustls", "tower-service", - "webpki-roots", + "webpki-roots 1.0.6", ] [[package]] @@ -909,6 +1042,12 @@ dependencies = [ "zerovec", ] +[[package]] +name = "id-arena" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" + [[package]] name = "idna" version = "1.1.0" @@ -937,7 +1076,22 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.16.1", + "serde", + "serde_core", +] + +[[package]] +name = "ipconfig" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4d40460c0ce33d6ce4b0630ad68ff63d6661961c48b6dba35e5a4d81cfb48222" +dependencies = [ + "socket2", + "widestring", + "windows-registry", + "windows-result", + "windows-sys 0.61.2", ] [[package]] @@ -1029,6 +1183,12 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" +[[package]] +name = "leb128fmt" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" + [[package]] name = "libc" version = "0.2.183" @@ -1041,6 +1201,15 @@ version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77" +[[package]] +name = "lock_api" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" 
+dependencies = [ + "scopeguard", +] + [[package]] name = "log" version = "0.4.29" @@ -1098,6 +1267,23 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "moka" +version = "0.12.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "957228ad12042ee839f93c8f257b62b4c0ab5eaae1d4fa60de53b27c9d7c5046" +dependencies = [ + "crossbeam-channel", + "crossbeam-epoch", + "crossbeam-utils", + "equivalent", + "parking_lot", + "portable-atomic", + "smallvec", + "tagptr", + "uuid", +] + [[package]] name = "nom" version = "7.1.3" @@ -1151,6 +1337,8 @@ dependencies = [ "criterion", "env_logger", "futures", + "hickory-proto", + "hickory-resolver", "http", "http-body-util", "hyper", @@ -1187,6 +1375,10 @@ name = "once_cell" version = "1.21.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" +dependencies = [ + "critical-section", + "portable-atomic", +] [[package]] name = "once_cell_polyfill" @@ -1210,6 +1402,29 @@ dependencies = [ "winapi", ] +[[package]] +name = "parking_lot" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-link", +] + [[package]] name = "pem" version = "3.0.6" @@ -1305,6 +1520,16 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "prettyplease" +version = "0.2.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +dependencies = [ + "proc-macro2", + "syn", +] + [[package]] name 
= "proc-macro2" version = "1.0.106" @@ -1390,6 +1615,12 @@ version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" +[[package]] +name = "r-efi" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8dcc9c7d52a811697d2151c701e0d08956f92b0e24136cf4cf27b57a6a0d9bf" + [[package]] name = "rand" version = "0.9.2" @@ -1453,6 +1684,15 @@ dependencies = [ "yasna", ] +[[package]] +name = "redox_syscall" +version = "0.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" +dependencies = [ + "bitflags", +] + [[package]] name = "regex" version = "1.12.3" @@ -1518,9 +1758,15 @@ dependencies = [ "wasm-bindgen", "wasm-bindgen-futures", "web-sys", - "webpki-roots", + "webpki-roots 1.0.6", ] +[[package]] +name = "resolv-conf" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e061d1b48cb8d38042de4ae0a7a6401009d6143dc80d2e2d6f31f0bdd6470c7" + [[package]] name = "ring" version = "0.17.14" @@ -1618,6 +1864,18 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "semver" +version = "1.0.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a7852d02fc848982e0c167ef163aaff9cd91dc640ba85e263cb1ce46fae51cd" + [[package]] name = "serde" version = "1.0.228" @@ -1780,6 +2038,12 @@ dependencies = [ "syn", ] +[[package]] +name = "tagptr" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417" + [[package]] name = "thiserror" version = "2.0.18" @@ -2038,6 +2302,12 @@ 
version = "1.0.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + [[package]] name = "untrusted" version = "0.9.0" @@ -2068,6 +2338,17 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" +[[package]] +name = "uuid" +version = "1.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ac8b6f42ead25368cf5b098aeb3dc8a1a2c05a3eee8a9a1a68c640edbfc79d9" +dependencies = [ + "getrandom 0.4.2", + "js-sys", + "wasm-bindgen", +] + [[package]] name = "walkdir" version = "2.5.0" @@ -2102,6 +2383,15 @@ dependencies = [ "wit-bindgen", ] +[[package]] +name = "wasip3" +version = "0.4.0+wasi-0.3.0-rc-2026-01-06" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" +dependencies = [ + "wit-bindgen", +] + [[package]] name = "wasm-bindgen" version = "0.2.115" @@ -2157,6 +2447,40 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "wasm-encoder" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990065f2fe63003fe337b932cfb5e3b80e0b4d0f5ff650e6985b1048f62c8319" +dependencies = [ + "leb128fmt", + "wasmparser", +] + +[[package]] +name = "wasm-metadata" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909" +dependencies = [ + "anyhow", + "indexmap", + "wasm-encoder", + "wasmparser", +] + +[[package]] +name = "wasmparser" +version = "0.244.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" +dependencies = [ + "bitflags", + "hashbrown 0.15.5", + "indexmap", + "semver", +] + [[package]] name = "web-sys" version = "0.3.92" @@ -2177,6 +2501,15 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "webpki-roots" +version = "0.26.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "521bc38abb08001b01866da9f51eb7c5d647a19260e00054a8c7fd5f9e57f7a9" +dependencies = [ + "webpki-roots 1.0.6", +] + [[package]] name = "webpki-roots" version = "1.0.6" @@ -2186,6 +2519,12 @@ dependencies = [ "rustls-pki-types", ] +[[package]] +name = "widestring" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72069c3113ab32ab29e5584db3c6ec55d416895e60715417b5b883a357c3e471" + [[package]] name = "winapi" version = "0.3.9" @@ -2223,6 +2562,35 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" +[[package]] +name = "windows-registry" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02752bf7fbdcce7f2a27a742f798510f3e5ad88dbe84871e5168e2120c3d5720" +dependencies = [ + "windows-link", + "windows-result", + "windows-strings", +] + +[[package]] +name = "windows-result" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-strings" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" +dependencies = [ + "windows-link", +] + [[package]] name = "windows-sys" version = "0.52.0" @@ -2390,6 +2758,88 @@ name = "wit-bindgen" version = 
"0.51.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" +dependencies = [ + "wit-bindgen-rust-macro", +] + +[[package]] +name = "wit-bindgen-core" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea61de684c3ea68cb082b7a88508a8b27fcc8b797d738bfc99a82facf1d752dc" +dependencies = [ + "anyhow", + "heck", + "wit-parser", +] + +[[package]] +name = "wit-bindgen-rust" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" +dependencies = [ + "anyhow", + "heck", + "indexmap", + "prettyplease", + "syn", + "wasm-metadata", + "wit-bindgen-core", + "wit-component", +] + +[[package]] +name = "wit-bindgen-rust-macro" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c0f9bfd77e6a48eccf51359e3ae77140a7f50b1e2ebfe62422d8afdaffab17a" +dependencies = [ + "anyhow", + "prettyplease", + "proc-macro2", + "quote", + "syn", + "wit-bindgen-core", + "wit-bindgen-rust", +] + +[[package]] +name = "wit-component" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" +dependencies = [ + "anyhow", + "bitflags", + "indexmap", + "log", + "serde", + "serde_derive", + "serde_json", + "wasm-encoder", + "wasm-metadata", + "wasmparser", + "wit-parser", +] + +[[package]] +name = "wit-parser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736" +dependencies = [ + "anyhow", + "id-arena", + "indexmap", + "log", + "semver", + "serde", + "serde_derive", + "serde_json", + "unicode-xid", + "wasmparser", +] [[package]] name = "writeable" diff --git a/Cargo.toml b/Cargo.toml index 
c5d5e1d..d7f6f9f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,6 +37,8 @@ webpki-roots = "1" criterion = { version = "0.8", features = ["html_reports"] } tower = { version = "0.5", features = ["util"] } http = "1" +hickory-resolver = { version = "0.25", features = ["https-ring", "webpki-roots"] } +hickory-proto = "0.25" [[bench]] name = "hot_path" @@ -49,3 +51,7 @@ harness = false [[bench]] name = "dnssec" harness = false + +[[bench]] +name = "recursive_compare" +harness = false diff --git a/benches/numa-bench.toml b/benches/numa-bench.toml new file mode 100644 index 0000000..0e058af --- /dev/null +++ b/benches/numa-bench.toml @@ -0,0 +1,25 @@ +[server] +bind_addr = "127.0.0.1:5454" +api_port = 5381 +api_bind_addr = "127.0.0.1" +data_dir = "/tmp/numa-bench" + +[upstream] +mode = "recursive" +timeout_ms = 10000 + +[cache] +min_ttl = 60 +max_ttl = 3600 + +[blocking] +enabled = false + +[dot] +enabled = false + +[mobile] +enabled = false + +[lan] +enabled = false diff --git a/benches/recursive_compare.rs b/benches/recursive_compare.rs new file mode 100644 index 0000000..e35768c --- /dev/null +++ b/benches/recursive_compare.rs @@ -0,0 +1,1649 @@ +//! DoH forwarding benchmark: Numa vs hickory-resolver. +//! +//! Both forward to the same DoH upstream (Quad9). +//! Measures end-to-end resolution time through each implementation. +//! +//! Fairness: +//! - Both reuse a single TLS connection (Numa via persistent server, +//! Hickory via a shared resolver instance with cache_size=0). +//! - Measurement order is alternated each round to cancel order bias. +//! - Numa cache is flushed before each query. +//! - 100 domains × 10 rounds for statistical confidence. +//! +//! Setup: +//! 1. Start a bench Numa instance: +//! cargo run -- benches/numa-bench.toml +//! 2. Run: +//! 
cargo bench --bench recursive_compare + +use std::net::SocketAddr; +use std::time::{Duration, Instant}; + +const DOH_UPSTREAM: &str = "https://9.9.9.9/dns-query"; +const NUMA_BENCH: &str = "127.0.0.1:5454"; +const NUMA_API: u16 = 5381; + +const DOMAINS: &[&str] = &[ + "example.com", + "rust-lang.org", + "kernel.org", + "signal.org", + "archlinux.org", + "openbsd.org", + "git-scm.com", + "sqlite.org", + "wireguard.com", + "mozilla.org", + "cloudflare.com", + "google.com", + "github.com", + "stackoverflow.com", + "wikipedia.org", + "reddit.com", + "amazon.com", + "apple.com", + "microsoft.com", + "facebook.com", + "twitter.com", + "linkedin.com", + "netflix.com", + "spotify.com", + "discord.com", + "twitch.tv", + "youtube.com", + "instagram.com", + "whatsapp.com", + "telegram.org", + "debian.org", + "ubuntu.com", + "fedoraproject.org", + "nixos.org", + "gentoo.org", + "freebsd.org", + "netbsd.org", + "dragonflybsd.org", + "illumos.org", + "haiku-os.org", + "python.org", + "golang.org", + "nodejs.org", + "ruby-lang.org", + "php.net", + "swift.org", + "kotlinlang.org", + "scala-lang.org", + "haskell.org", + "elixir-lang.org", + "erlang.org", + "clojure.org", + "julialang.org", + "ziglang.org", + "nim-lang.org", + "dlang.org", + "vlang.io", + "crystal-lang.org", + "racket-lang.org", + "ocaml.org", + "crates.io", + "npmjs.com", + "pypi.org", + "rubygems.org", + "packagist.org", + "nuget.org", + "maven.apache.org", + "hex.pm", + "hackage.haskell.org", + "pkg.go.dev", + "docker.com", + "kubernetes.io", + "prometheus.io", + "grafana.com", + "elastic.co", + "datadog.com", + "sentry.io", + "pagerduty.com", + "atlassian.com", + "jetbrains.com", + "gitlab.com", + "bitbucket.org", + "sourcehut.org", + "codeberg.org", + "launchpad.net", + "savannah.gnu.org", + "letsencrypt.org", + "eff.org", + "torproject.org", + "privacyguides.org", + "matrix.org", + "element.io", + "jitsi.org", + "nextcloud.com", + "syncthing.net", + "tailscale.com", + "mullvad.net", + "proton.me", + 
"duckduckgo.com", + "brave.com", + "vivaldi.com", +]; + +const ROUNDS: usize = 10; + +fn main() { + let diag = std::env::args().any(|a| a == "--diag"); + let direct = std::env::args().any(|a| a == "--direct"); + + let rt = tokio::runtime::Runtime::new().unwrap(); + + if diag { + run_diag(&rt); + return; + } + + if direct { + run_direct(&rt); + return; + } + + if std::env::args().any(|a| a == "--diag-clients") { + run_diag_clients(&rt); + return; + } + + if std::env::args().any(|a| a == "--spike-trace") { + run_spike_trace(&rt); + return; + } + + if std::env::args().any(|a| a == "--spike-phases") { + run_spike_phases(&rt); + return; + } + + if std::env::args().any(|a| a == "--spike-heartbeat") { + run_spike_heartbeat(&rt); + return; + } + + if std::env::args().any(|a| a == "--hedge") { + run_hedge(&rt); + return; + } + + if std::env::args().any(|a| a == "--hedge-5x") { + run_hedge_multi(&rt, 5); + return; + } + + if std::env::args().any(|a| a == "--vs-dnscrypt") { + run_vs_dnscrypt(&rt, 5); + return; + } + + if std::env::args().any(|a| a == "--vs-unbound") { + run_vs_unbound(&rt, 5); + return; + } + + let numa_addr: SocketAddr = NUMA_BENCH.parse().unwrap(); + + println!("DoH Forwarding Benchmark: Numa vs hickory-resolver"); + println!("Both forwarding to {DOH_UPSTREAM}"); + println!("{} domains × {ROUNDS} rounds", DOMAINS.len()); + println!(); + + // Verify bench Numa is reachable + if rt.block_on(query_udp(numa_addr, "example.com")).is_none() { + eprintln!("Bench Numa not responding on {numa_addr}"); + eprintln!(); + eprintln!("Start it with:"); + eprintln!(" cargo run -- benches/numa-bench.toml"); + std::process::exit(1); + } + + // Build a shared Hickory resolver (reuses TLS connection, like Numa does) + let resolver = rt.block_on(build_hickory_resolver()); + + // Warm up both paths (TLS handshake, connection establishment) + println!("Warming up connections..."); + for _ in 0..3 { + rt.block_on(query_udp(numa_addr, "example.com")); + 
rt.block_on(query_hickory_doh(&resolver, "example.com")); + } + flush_cache(); + + println!( + "{:<30} {:>10} {:>10} {:>10} {:>8} {:>8}", + "Domain", "Numa (ms)", "Hickory", "Delta", "σ Numa", "σ Hick" + ); + println!("{}", "-".repeat(92)); + + let mut numa_all = Vec::new(); + let mut hickory_all = Vec::new(); + let mut per_domain: Vec<(&str, f64, f64, f64, f64, f64)> = Vec::new(); + + for domain in DOMAINS { + let mut numa_times = Vec::with_capacity(ROUNDS); + let mut hickory_times = Vec::with_capacity(ROUNDS); + + for round in 0..ROUNDS { + flush_cache(); + std::thread::sleep(Duration::from_millis(10)); + + // Alternate measurement order each round to cancel systematic bias + if round % 2 == 0 { + // Numa first + let t = measure(&rt, || rt.block_on(query_udp(numa_addr, domain))); + numa_times.push(t); + let t = measure(&rt, || rt.block_on(query_hickory_doh(&resolver, domain))); + hickory_times.push(t); + } else { + // Hickory first + let t = measure(&rt, || rt.block_on(query_hickory_doh(&resolver, domain))); + hickory_times.push(t); + flush_cache(); + std::thread::sleep(Duration::from_millis(10)); + let t = measure(&rt, || rt.block_on(query_udp(numa_addr, domain))); + numa_times.push(t); + } + } + + let numa_avg = mean(&numa_times); + let hickory_avg = mean(&hickory_times); + let numa_sd = stddev(&numa_times); + let hickory_sd = stddev(&hickory_times); + let delta = numa_avg - hickory_avg; + + numa_all.extend_from_slice(&numa_times); + hickory_all.extend_from_slice(&hickory_times); + per_domain.push((domain, numa_avg, hickory_avg, delta, numa_sd, hickory_sd)); + + let delta_str = format_delta(delta); + println!( + "{:<30} {:>7.1} ms {:>7.1} ms {:>7} ms {:>5.1}ms {:>5.1}ms", + domain, numa_avg, hickory_avg, delta_str, numa_sd, hickory_sd + ); + } + + println!("{}", "-".repeat(92)); + + let numa_mean = mean(&numa_all); + let hickory_mean = mean(&hickory_all); + let delta_mean = numa_mean - hickory_mean; + + println!( + "{:<30} {:>7.1} ms {:>7.1} ms {:>7} ms 
{:>5.1}ms {:>5.1}ms", + "OVERALL MEAN", + numa_mean, + hickory_mean, + format_delta(delta_mean), + stddev(&numa_all), + stddev(&hickory_all), + ); + + // Median + let numa_med = median(&mut numa_all); + let hickory_med = median(&mut hickory_all); + println!( + "{:<30} {:>7.1} ms {:>7.1} ms {:>7} ms", + "MEDIAN", + numa_med, + hickory_med, + format_delta(numa_med - hickory_med), + ); + + // P95 + let numa_p95 = percentile(&numa_all, 95.0); + let hickory_p95 = percentile(&hickory_all, 95.0); + println!( + "{:<30} {:>7.1} ms {:>7.1} ms {:>7} ms", + "P95", + numa_p95, + hickory_p95, + format_delta(numa_p95 - hickory_p95), + ); + + println!(); + let total_queries = DOMAINS.len() * ROUNDS; + if numa_mean < hickory_mean { + let pct = ((hickory_mean - numa_mean) / hickory_mean * 100.0).round(); + println!("Numa is ~{pct}% faster (mean over {total_queries} queries)."); + } else if hickory_mean < numa_mean { + let pct = ((numa_mean - hickory_mean) / numa_mean * 100.0).round(); + println!("Hickory is ~{pct}% faster (mean over {total_queries} queries)."); + } else { + println!("Both are equal (mean over {total_queries} queries)."); + } + + println!(); + println!("Methodology:"); + println!(" - Both forward to {DOH_UPSTREAM} over a reused TLS connection."); + println!(" - Numa cache flushed before each query. 
Hickory cache disabled."); + println!(" - Measurement order alternates each round to cancel order bias."); + println!(" - {} domains × {ROUNDS} rounds = {total_queries} queries per resolver.", DOMAINS.len()); +} + +fn run_diag(rt: &tokio::runtime::Runtime) { + println!("Hickory connection reuse diagnostic"); + println!("20 sequential queries to {DOH_UPSTREAM} via one shared resolver"); + println!("If conn is reused: query 1 slow (TLS handshake), rest fast.\n"); + + let resolver = rt.block_on(build_hickory_resolver()); + + let domains = [ + "example.com", "rust-lang.org", "kernel.org", "google.com", "github.com", + "example.com", "rust-lang.org", "kernel.org", "google.com", "github.com", + "example.com", "rust-lang.org", "kernel.org", "google.com", "github.com", + "example.com", "rust-lang.org", "kernel.org", "google.com", "github.com", + ]; + + println!("{:>3} {:<20} {:>10}", "#", "Domain", "Time (ms)"); + println!("{}", "-".repeat(40)); + + for (i, domain) in domains.iter().enumerate() { + use hickory_resolver::proto::rr::RecordType; + let start = Instant::now(); + let result = rt.block_on(resolver.lookup(*domain, RecordType::A)); + let ms = start.elapsed().as_secs_f64() * 1000.0; + match &result { + Ok(lookup) => { + let first = lookup.iter().next().map(|r| format!("{r}")).unwrap_or_default(); + println!("{:>3} {:<20} {:>7.1} ms OK {}", i + 1, domain, ms, first); + } + Err(e) => { + println!("{:>3} {:<20} {:>7.1} ms ERR {}", i + 1, domain, ms, e); + } + } + } +} + +/// Library-to-library comparison: Numa's forward_query_raw vs Hickory's resolver.lookup(). +/// No UDP, no server pipeline — just the DoH forwarding call. 
+fn run_direct(rt: &tokio::runtime::Runtime) { + println!("Direct DoH Forwarding: Numa forward_query_raw vs Hickory resolver.lookup()"); + println!("Both forwarding to {DOH_UPSTREAM} — no UDP, no server pipeline"); + println!("{} domains × {ROUNDS} rounds", DOMAINS.len()); + println!(); + + // Build Numa's upstream (shared reqwest client, reuses HTTP/2 connection) + let numa_upstream = + numa::forward::parse_upstream(DOH_UPSTREAM, 443).expect("failed to parse upstream"); + let timeout = Duration::from_secs(10); + + // Build Hickory's resolver (shared, reuses HTTP/2 connection) + let resolver = rt.block_on(build_hickory_resolver()); + + // Warm up both + println!("Warming up connections..."); + for _ in 0..3 { + let wire = build_query_vec("example.com"); + let _ = rt.block_on(numa::forward::forward_query_raw(&wire, &numa_upstream, timeout)); + let _ = rt.block_on(query_hickory_doh(&resolver, "example.com")); + } + + println!( + "{:<30} {:>10} {:>10} {:>10} {:>8} {:>8}", + "Domain", "Numa (ms)", "Hickory", "Delta", "σ Numa", "σ Hick" + ); + println!("{}", "-".repeat(92)); + + let mut numa_all = Vec::new(); + let mut hickory_all = Vec::new(); + + for domain in DOMAINS { + let mut numa_times = Vec::with_capacity(ROUNDS); + let mut hickory_times = Vec::with_capacity(ROUNDS); + + for round in 0..ROUNDS { + let wire = build_query_vec(domain); + + if round % 2 == 0 { + let w = wire.clone(); + let t = measure(rt, || { + rt.block_on(numa::forward::forward_query_raw(&w, &numa_upstream, timeout)) + }); + numa_times.push(t); + let t = measure(rt, || rt.block_on(query_hickory_doh(&resolver, domain))); + hickory_times.push(t); + } else { + let t = measure(rt, || rt.block_on(query_hickory_doh(&resolver, domain))); + hickory_times.push(t); + let w = wire.clone(); + let t = measure(rt, || { + rt.block_on(numa::forward::forward_query_raw(&w, &numa_upstream, timeout)) + }); + numa_times.push(t); + } + } + + let numa_avg = mean(&numa_times); + let hickory_avg = mean(&hickory_times); + 
let numa_sd = stddev(&numa_times); + let hickory_sd = stddev(&hickory_times); + let delta = numa_avg - hickory_avg; + + numa_all.extend_from_slice(&numa_times); + hickory_all.extend_from_slice(&hickory_times); + + println!( + "{:<30} {:>7.1} ms {:>7.1} ms {:>7} ms {:>5.1}ms {:>5.1}ms", + domain, numa_avg, hickory_avg, format_delta(delta), numa_sd, hickory_sd + ); + } + + println!("{}", "-".repeat(92)); + let numa_mean = mean(&numa_all); + let hickory_mean = mean(&hickory_all); + println!( + "{:<30} {:>7.1} ms {:>7.1} ms {:>7} ms {:>5.1}ms {:>5.1}ms", + "OVERALL MEAN", numa_mean, hickory_mean, format_delta(numa_mean - hickory_mean), + stddev(&numa_all), stddev(&hickory_all), + ); + let numa_med = median(&mut numa_all); + let hickory_med = median(&mut hickory_all); + println!( + "{:<30} {:>7.1} ms {:>7.1} ms {:>7} ms", + "MEDIAN", numa_med, hickory_med, format_delta(numa_med - hickory_med), + ); + let numa_p95 = percentile(&numa_all, 95.0); + let hickory_p95 = percentile(&hickory_all, 95.0); + println!( + "{:<30} {:>7.1} ms {:>7.1} ms {:>7} ms", + "P95", numa_p95, hickory_p95, format_delta(numa_p95 - hickory_p95), + ); + + println!(); + let total_queries = DOMAINS.len() * ROUNDS; + if numa_mean < hickory_mean { + let pct = ((hickory_mean - numa_mean) / hickory_mean * 100.0).round(); + println!("Numa is ~{pct}% faster (mean over {total_queries} queries)."); + } else if hickory_mean < numa_mean { + let pct = ((numa_mean - hickory_mean) / numa_mean * 100.0).round(); + println!("Hickory is ~{pct}% faster (mean over {total_queries} queries)."); + } else { + println!("Both are equal (mean over {total_queries} queries)."); + } + + println!(); + println!("Methodology:"); + println!(" - Both forward to {DOH_UPSTREAM} over a reused TLS/HTTP2 connection."); + println!(" - No UDP, no server pipeline, no cache — pure DoH forwarding."); + println!(" - Numa: forward_query_raw (reqwest). 
Hickory: resolver.lookup (h2)."); + println!(" - {} domains × {ROUNDS} rounds = {total_queries} queries per implementation.", DOMAINS.len()); +} + +/// Per-query timing diagnostic: 20 queries each through reqwest and Hickory. +/// Shows whether reqwest has connection reuse issues or per-request overhead. +fn run_diag_clients(rt: &tokio::runtime::Runtime) { + println!("Client diagnostic: reqwest vs Hickory per-query timing"); + println!("20 queries each to {DOH_UPSTREAM}\n"); + + let upstream = + numa::forward::parse_upstream(DOH_UPSTREAM, 443).expect("failed to parse upstream"); + let resolver = rt.block_on(build_hickory_resolver()); + let timeout = Duration::from_secs(10); + + // Warm both + for _ in 0..3 { + let w = build_query_vec("example.com"); + let _ = rt.block_on(numa::forward::forward_query_raw(&w, &upstream, timeout)); + let _ = rt.block_on(query_hickory_doh(&resolver, "example.com")); + } + + let domains = [ + "example.com", "google.com", "github.com", "rust-lang.org", "cloudflare.com", + "example.com", "google.com", "github.com", "rust-lang.org", "cloudflare.com", + "example.com", "google.com", "github.com", "rust-lang.org", "cloudflare.com", + "example.com", "google.com", "github.com", "rust-lang.org", "cloudflare.com", + ]; + + println!("{:>3} {:<20} {:>12} {:>12}", "#", "Domain", "reqwest", "Hickory"); + println!("{}", "-".repeat(55)); + + for (i, domain) in domains.iter().enumerate() { + let wire = build_query_vec(domain); + + let start = Instant::now(); + let r_result = rt.block_on(numa::forward::forward_query_raw(&wire, &upstream, timeout)); + let r_ms = start.elapsed().as_secs_f64() * 1000.0; + let r_ok = if r_result.is_ok() { "OK" } else { "FAIL" }; + + let start = Instant::now(); + let h_result = rt.block_on(query_hickory_doh(&resolver, domain)); + let h_ms = start.elapsed().as_secs_f64() * 1000.0; + let h_ok = if h_result.is_some() { "OK" } else { "FAIL" }; + + println!( + "{:>3} {:<20} {:>7.1} ms {} {:>7.1} ms {}", + i + 1, domain, r_ms, 
r_ok, h_ms, h_ok + ); + } +} + +/// Spike trace: fire 200 sequential queries through reqwest and log every one +/// with a timestamp. Analyze the distribution and find spike clusters. +fn run_spike_trace(rt: &tokio::runtime::Runtime) { + println!("Spike trace: 200 sequential reqwest DoH queries"); + println!("Target: {DOH_UPSTREAM}\n"); + + let upstream = + numa::forward::parse_upstream(DOH_UPSTREAM, 443).expect("failed to parse upstream"); + let timeout = Duration::from_secs(10); + + // Warm + for _ in 0..5 { + let w = build_query_vec("example.com"); + let _ = rt.block_on(numa::forward::forward_query_raw(&w, &upstream, timeout)); + } + + // Run the entire 200-query loop inside ONE block_on to eliminate + // per-query runtime re-entry overhead. + let samples: Vec<(u128, f64)> = rt.block_on(async { + let test_start = Instant::now(); + let mut s = Vec::with_capacity(200); + for i in 0..200 { + let domain = match i % 5 { + 0 => "example.com", + 1 => "google.com", + 2 => "github.com", + 3 => "rust-lang.org", + _ => "cloudflare.com", + }; + let wire = build_query_vec(domain); + let req_start = Instant::now(); + let t_from_start_us = test_start.elapsed().as_micros(); + let _ = numa::forward::forward_query_raw(&wire, &upstream, timeout).await; + let ms = req_start.elapsed().as_secs_f64() * 1000.0; + s.push((t_from_start_us, ms)); + } + s + }); + + // Compute stats + let mut sorted_times: Vec = samples.iter().map(|(_, t)| *t).collect(); + sorted_times.sort_by(|a, b| a.partial_cmp(b).unwrap()); + let n = sorted_times.len(); + let median = sorted_times[n / 2]; + let p90 = sorted_times[(n * 90) / 100]; + let p95 = sorted_times[(n * 95) / 100]; + let p99 = sorted_times[(n * 99) / 100]; + let max = sorted_times[n - 1]; + let mean: f64 = sorted_times.iter().sum::() / n as f64; + + println!("Distribution (n={}):", n); + println!(" mean: {:.1} ms", mean); + println!(" median: {:.1} ms", median); + println!(" p90: {:.1} ms", p90); + println!(" p95: {:.1} ms", p95); + println!(" 
p99: {:.1} ms", p99); + println!(" max: {:.1} ms", max); + println!(); + + // Define spike threshold as 3x median + let spike_threshold = median * 3.0; + let spikes: Vec<(usize, u128, f64)> = samples + .iter() + .enumerate() + .filter(|(_, (_, t))| *t > spike_threshold) + .map(|(i, (ts, t))| (i, *ts, *t)) + .collect(); + + println!("Spikes (> {:.1}ms, which is 3x median):", spike_threshold); + println!(" count: {}", spikes.len()); + if spikes.is_empty() { + return; + } + + // Inter-spike gaps (time between spikes) + let mut gaps_ms: Vec = Vec::new(); + for w in spikes.windows(2) { + let gap_us = w[1].1 - w[0].1; + gaps_ms.push(gap_us as f64 / 1000.0); + } + + println!(); + println!(" {:>4} {:>12} {:>10} {:>12}", "idx", "at (ms)", "latency", "gap from prev"); + for (i, ((idx, ts, latency), gap)) in spikes.iter().zip( + std::iter::once(&0.0).chain(gaps_ms.iter()) + ).enumerate() { + let _ = i; + let gap_str = if *gap > 0.0 { + format!("{:.0} ms", gap) + } else { + "-".to_string() + }; + println!(" {:>4} {:>9.1} {:>6.1} ms {:>12}", idx, *ts as f64 / 1000.0, latency, gap_str); + } + + if !gaps_ms.is_empty() { + let gap_mean: f64 = gaps_ms.iter().sum::() / gaps_ms.len() as f64; + let mut gap_sorted = gaps_ms.clone(); + gap_sorted.sort_by(|a, b| a.partial_cmp(b).unwrap()); + let gap_median = gap_sorted[gap_sorted.len() / 2]; + println!(); + println!(" Inter-spike gap: mean={:.0}ms, median={:.0}ms", gap_mean, gap_median); + } +} + +/// Spike phases: time each step of the reqwest DoH call to find which phase +/// is slow during a spike. Reports (build+send, send->resp headers, body read). 
+fn run_spike_phases(rt: &tokio::runtime::Runtime) { + println!("Spike phases: timing each phase of reqwest DoH call"); + println!("Target: {DOH_UPSTREAM}\n"); + + // Build the same tuned client our forward_doh uses + let client = reqwest::Client::builder() + .use_rustls_tls() + .http2_initial_stream_window_size(65_535) + .http2_initial_connection_window_size(65_535) + .http2_keep_alive_interval(Duration::from_secs(15)) + .http2_keep_alive_while_idle(true) + .http2_keep_alive_timeout(Duration::from_secs(10)) + .pool_idle_timeout(Duration::from_secs(300)) + .pool_max_idle_per_host(1) + .build() + .unwrap(); + + // Warm up + for _ in 0..5 { + let wire = build_query_vec("example.com"); + let _ = rt.block_on(async { + client + .post(DOH_UPSTREAM) + .header("content-type", "application/dns-message") + .header("accept", "application/dns-message") + .body(wire) + .send() + .await + .ok()? + .bytes() + .await + .ok() + }); + } + + println!("{:>4} {:>8} {:>8} {:>8} {:>8}", "idx", "total", "build", "send", "body"); + println!("{}", "-".repeat(50)); + + let samples: Vec<(f64, f64, f64, f64)> = rt.block_on(async { + let mut s = Vec::with_capacity(200); + for i in 0..200 { + let domain = match i % 5 { + 0 => "example.com", + 1 => "google.com", + 2 => "github.com", + 3 => "rust-lang.org", + _ => "cloudflare.com", + }; + let wire = build_query_vec(domain); + + let t0 = Instant::now(); + // Phase 1: build the request + let req = client + .post(DOH_UPSTREAM) + .header("content-type", "application/dns-message") + .header("accept", "application/dns-message") + .body(wire); + let t1 = Instant::now(); + // Phase 2: send() — this is the dispatch channel + round trip to headers + let resp_result = req.send().await; + let t2 = Instant::now(); + // Phase 3: read body + let body_result = match resp_result { + Ok(r) => r.bytes().await.ok().map(|b| b.len()), + Err(_) => None, + }; + let t3 = Instant::now(); + + let build_ms = (t1 - t0).as_secs_f64() * 1000.0; + let send_ms = (t2 - 
t1).as_secs_f64() * 1000.0; + let body_ms = (t3 - t2).as_secs_f64() * 1000.0; + let total_ms = (t3 - t0).as_secs_f64() * 1000.0; + + s.push((total_ms, build_ms, send_ms, body_ms)); + let _ = body_result; + } + s + }); + + // Compute distribution on total + let mut totals: Vec = samples.iter().map(|s| s.0).collect(); + totals.sort_by(|a, b| a.partial_cmp(b).unwrap()); + let median = totals[100]; + + // Print spikes (> 3x median) with phase breakdown + for (i, (total, build, send, body)) in samples.iter().enumerate() { + if *total > median * 3.0 { + println!( + "{:>4} {:>5.1} ms {:>5.1} ms {:>5.1} ms {:>5.1} ms", + i, total, build, send, body + ); + } + } + + // Summary: mean of each phase for spikes vs non-spikes + let (spike_samples, normal_samples): (Vec<_>, Vec<_>) = samples + .iter() + .partition(|(t, _, _, _)| *t > median * 3.0); + + let phase_means = |samples: &[&(f64, f64, f64, f64)]| -> (f64, f64, f64, f64) { + let n = samples.len() as f64; + if n == 0.0 { return (0.0, 0.0, 0.0, 0.0); } + let total: f64 = samples.iter().map(|s| s.0).sum::() / n; + let build: f64 = samples.iter().map(|s| s.1).sum::() / n; + let send: f64 = samples.iter().map(|s| s.2).sum::() / n; + let body: f64 = samples.iter().map(|s| s.3).sum::() / n; + (total, build, send, body) + }; + + let spike_refs: Vec<&(f64, f64, f64, f64)> = spike_samples.iter().copied().collect(); + let normal_refs: Vec<&(f64, f64, f64, f64)> = normal_samples.iter().copied().collect(); + let (s_total, s_build, s_send, s_body) = phase_means(&spike_refs); + let (n_total, n_build, n_send, n_body) = phase_means(&normal_refs); + + println!(); + println!("Summary (mean ms):"); + println!( + " {:<8} {:>8} {:>8} {:>8} {:>8}", + "", "total", "build", "send", "body" + ); + println!( + " {:<8} {:>5.1} ms {:>5.1} ms {:>5.1} ms {:>5.1} ms (n={})", + "normal", n_total, n_build, n_send, n_body, normal_refs.len() + ); + println!( + " {:<8} {:>5.1} ms {:>5.1} ms {:>5.1} ms {:>5.1} ms (n={})", + "spike", s_total, s_build, s_send, 
s_body, spike_refs.len() + ); + println!(); + println!("Delta (spike - normal):"); + println!( + " build: {:+.1} ms, send: {:+.1} ms, body: {:+.1} ms", + s_build - n_build, + s_send - n_send, + s_body - n_body + ); +} + +/// Heartbeat probe: run a parallel task that ticks every 5ms and records +/// how long each tick actually takes. If the heartbeat stalls during a DoH +/// spike, it's a tokio scheduling issue (runtime can't poll tasks). If +/// heartbeat is fine while send() is stuck, it's internal to hyper/h2. +fn run_spike_heartbeat(rt: &tokio::runtime::Runtime) { + use std::sync::{Arc, Mutex}; + + println!("Spike heartbeat probe"); + println!("Running DoH queries + parallel 5ms heartbeat task\n"); + + let upstream = + numa::forward::parse_upstream(DOH_UPSTREAM, 443).expect("failed to parse upstream"); + let timeout = Duration::from_secs(10); + + // Warm up + for _ in 0..5 { + let w = build_query_vec("example.com"); + let _ = rt.block_on(numa::forward::forward_query_raw(&w, &upstream, timeout)); + } + + // Shared vecs: (relative_ms_from_start, event_kind, latency_ms) + // event_kind: 0 = heartbeat, 1 = doh query + type EventLog = Vec<(f64, u8, f64)>; + let events: Arc> = Arc::new(Mutex::new(Vec::with_capacity(2000))); + let stop = Arc::new(std::sync::atomic::AtomicBool::new(false)); + + let test_start = Instant::now(); + + rt.block_on(async { + // Spawn heartbeat task + let hb_events = Arc::clone(&events); + let hb_stop = Arc::clone(&stop); + let hb_start = test_start; + let heartbeat = tokio::spawn(async move { + let mut next_tick = Instant::now(); + let target = Duration::from_millis(5); + while !hb_stop.load(std::sync::atomic::Ordering::Relaxed) { + next_tick += target; + // Sleep until the next scheduled tick + let now = Instant::now(); + if next_tick > now { + tokio::time::sleep(next_tick - now).await; + } + // Measure how much we overshot the scheduled tick + let actual = Instant::now(); + let lag_ms = if actual > next_tick { + (actual - 
next_tick).as_secs_f64() * 1000.0 + } else { + 0.0 + }; + let t = (actual - hb_start).as_secs_f64() * 1000.0; + if let Ok(mut e) = hb_events.lock() { + e.push((t, 0, lag_ms)); + } + } + }); + + // Run 200 DoH queries and record their timings + for i in 0..200 { + let domain = match i % 5 { + 0 => "example.com", + 1 => "google.com", + 2 => "github.com", + 3 => "rust-lang.org", + _ => "cloudflare.com", + }; + let wire = build_query_vec(domain); + let req_start = Instant::now(); + let _ = numa::forward::forward_query_raw(&wire, &upstream, timeout).await; + let elapsed = req_start.elapsed().as_secs_f64() * 1000.0; + let t = (req_start - test_start).as_secs_f64() * 1000.0; + if let Ok(mut e) = events.lock() { + e.push((t, 1, elapsed)); + } + } + + stop.store(true, std::sync::atomic::Ordering::Relaxed); + let _ = heartbeat.await; + }); + + let events = events.lock().unwrap(); + + // Separate heartbeats and doh events + let hb: Vec<(f64, f64)> = events + .iter() + .filter(|(_, k, _)| *k == 0) + .map(|(t, _, l)| (*t, *l)) + .collect(); + let doh: Vec<(f64, f64)> = events + .iter() + .filter(|(_, k, _)| *k == 1) + .map(|(t, _, l)| (*t, *l)) + .collect(); + + // Heartbeat stats + let mut hb_lags: Vec = hb.iter().map(|(_, l)| *l).collect(); + hb_lags.sort_by(|a, b| a.partial_cmp(b).unwrap()); + let hb_n = hb_lags.len(); + let hb_median = hb_lags[hb_n / 2]; + let hb_p95 = hb_lags[(hb_n * 95) / 100]; + let hb_p99 = hb_lags[(hb_n * 99) / 100]; + let hb_max = hb_lags[hb_n - 1]; + + // DoH stats + let mut doh_latencies: Vec = doh.iter().map(|(_, l)| *l).collect(); + doh_latencies.sort_by(|a, b| a.partial_cmp(b).unwrap()); + let doh_n = doh_latencies.len(); + let doh_median = doh_latencies[doh_n / 2]; + let doh_p95 = doh_latencies[(doh_n * 95) / 100]; + let doh_max = doh_latencies[doh_n - 1]; + + println!("Heartbeat lag (tick overshoot, {}ms target):", 5); + println!(" n: {}", hb_n); + println!(" median: {:.2} ms", hb_median); + println!(" p95: {:.2} ms", hb_p95); + println!(" p99: 
{:.2} ms", hb_p99); + println!(" max: {:.2} ms", hb_max); + println!(); + println!("DoH latency:"); + println!(" n: {}", doh_n); + println!(" median: {:.1} ms", doh_median); + println!(" p95: {:.1} ms", doh_p95); + println!(" max: {:.1} ms", doh_max); + println!(); + + // Find DoH spikes and check heartbeat activity DURING each spike + let doh_spike_threshold = doh_median * 3.0; + let mut spikes_with_hb_lag = 0; + let mut spikes_total = 0; + let mut max_hb_during_any_spike = 0.0_f64; + + println!( + "Correlation: during each DoH spike (>{:.1}ms), max heartbeat lag:", + doh_spike_threshold + ); + println!(" {:>6} {:>10} {:>18}", "doh_at", "doh_ms", "max_hb_lag_during"); + + for (doh_t, doh_ms) in &doh { + if *doh_ms > doh_spike_threshold { + spikes_total += 1; + // Find heartbeats that happened during this DoH query + let spike_start = *doh_t; + let spike_end = spike_start + *doh_ms; + let mut max_hb = 0.0_f64; + for (hb_t, hb_lag) in &hb { + if *hb_t >= spike_start && *hb_t <= spike_end + 20.0 { + if *hb_lag > max_hb { + max_hb = *hb_lag; + } + } + } + if max_hb > 5.0 { + spikes_with_hb_lag += 1; + } + max_hb_during_any_spike = max_hb_during_any_spike.max(max_hb); + println!( + " {:>5.0} ms {:>7.1} ms {:>14.2} ms", + doh_t, doh_ms, max_hb + ); + } + } + + println!(); + println!("Conclusion:"); + if spikes_total == 0 { + println!(" No DoH spikes in this run."); + } else { + let pct = (spikes_with_hb_lag as f64 / spikes_total as f64 * 100.0).round(); + println!( + " {}/{} spikes ({:.0}%) had concurrent heartbeat lag >5ms.", + spikes_with_hb_lag, spikes_total, pct + ); + println!(" Max heartbeat lag during any spike: {:.2}ms", max_hb_during_any_spike); + println!(); + if max_hb_during_any_spike > 20.0 { + println!(" → Heartbeat stalls during DoH spikes: tokio scheduling / OS thread issue."); + println!(" The runtime can't poll ANY task — likely QoS demotion, GC pause,"); + println!(" or the worker thread is blocked somewhere."); + } else { + println!(" → Heartbeat 
runs normally during DoH spikes: internal to hyper/h2."); + println!(" The runtime is fine, but send()'s await is stuck waiting for"); + println!(" the ClientTask to poll the dispatch channel."); + } + } +} + +/// Hedging benchmark: tests four configurations against Hickory. +/// Single: 1 client → Quad9 (baseline) +/// Hedge-same: hedge against same client/connection → Quad9 +/// Hedge-dual: hedge against 2 separate clients, both → Quad9 (same upstream, 2 HTTP/2 conns) +/// Hickory: Hickory resolver → Quad9 (reference) +fn run_hedge(rt: &tokio::runtime::Runtime) { + let hedge_delay = Duration::from_millis(10); + + println!("Hedging Benchmark (all paths → Quad9 only)"); + println!("Upstream: {}", DOH_UPSTREAM); + println!("Hedge delay: {:?}", hedge_delay); + println!("{} domains × {} rounds\n", DOMAINS.len(), ROUNDS); + + // Primary and secondary: two separate reqwest clients → same Quad9 URL. + // This gives two independent HTTP/2 connections, so dispatch spikes + // are uncorrelated (at most one stalls at a time). 
+ let primary_same = + numa::forward::parse_upstream(DOH_UPSTREAM, 443).expect("failed to parse primary"); + let primary_dual = + numa::forward::parse_upstream(DOH_UPSTREAM, 443).expect("failed to parse primary_dual"); + let secondary_dual = + numa::forward::parse_upstream(DOH_UPSTREAM, 443).expect("failed to parse secondary_dual"); + let timeout = Duration::from_secs(10); + + let resolver = rt.block_on(build_hickory_resolver()); + + // Warm up all paths (separate connections need their own TLS handshake) + println!("Warming up connections..."); + for _ in 0..5 { + let w = build_query_vec("example.com"); + let _ = rt.block_on(numa::forward::forward_query_raw(&w, &primary_same, timeout)); + let _ = rt.block_on(numa::forward::forward_query_raw(&w, &primary_dual, timeout)); + let _ = rt.block_on(numa::forward::forward_query_raw(&w, &secondary_dual, timeout)); + let _ = rt.block_on(query_hickory_doh(&resolver, "example.com")); + } + + let mut single_all = Vec::new(); + let mut hedge_same_all = Vec::new(); + let mut hedge_dual_all = Vec::new(); + let mut hickory_all = Vec::new(); + + println!( + "{:<24} {:>10} {:>10} {:>10} {:>10}", + "Domain", "Single", "Hedge-same", "Hedge-dual", "Hickory" + ); + println!("{}", "-".repeat(78)); + + for domain in DOMAINS { + let mut single_times = Vec::with_capacity(ROUNDS); + let mut hedge_same_times = Vec::with_capacity(ROUNDS); + let mut hedge_dual_times = Vec::with_capacity(ROUNDS); + let mut hickory_times = Vec::with_capacity(ROUNDS); + + for _ in 0..ROUNDS { + let wire = build_query_vec(domain); + + let t = Instant::now(); + let _ = rt.block_on(numa::forward::forward_query_raw(&wire, &primary_same, timeout)); + single_times.push(t.elapsed().as_secs_f64() * 1000.0); + + let t = Instant::now(); + let _ = rt.block_on(numa::forward::forward_with_hedging_raw( + &wire, &primary_same, &primary_same, hedge_delay, timeout, + )); + hedge_same_times.push(t.elapsed().as_secs_f64() * 1000.0); + + let t = Instant::now(); + let _ = 
rt.block_on(numa::forward::forward_with_hedging_raw( + &wire, &primary_dual, &secondary_dual, hedge_delay, timeout, + )); + hedge_dual_times.push(t.elapsed().as_secs_f64() * 1000.0); + + let t = Instant::now(); + let _ = rt.block_on(query_hickory_doh(&resolver, domain)); + hickory_times.push(t.elapsed().as_secs_f64() * 1000.0); + } + + single_all.extend_from_slice(&single_times); + hedge_same_all.extend_from_slice(&hedge_same_times); + hedge_dual_all.extend_from_slice(&hedge_dual_times); + hickory_all.extend_from_slice(&hickory_times); + + println!( + "{:<24} {:>7.1} ms {:>7.1} ms {:>7.1} ms {:>7.1} ms", + domain, + mean(&single_times), + mean(&hedge_same_times), + mean(&hedge_dual_times), + mean(&hickory_times) + ); + } + + println!("{}", "-".repeat(78)); + + let stats = |all: &mut Vec| -> (f64, f64, f64, f64, f64) { + let m = mean(all); + let med = median(all); + let p95 = percentile(all, 95.0); + let p99 = percentile(all, 99.0); + let sd = stddev(all); + (m, med, p95, p99, sd) + }; + + let (s_m, s_med, s_p95, s_p99, s_sd) = stats(&mut single_all); + let (hs_m, hs_med, hs_p95, hs_p99, hs_sd) = stats(&mut hedge_same_all); + let (hd_m, hd_med, hd_p95, hd_p99, hd_sd) = stats(&mut hedge_dual_all); + let (k_m, k_med, k_p95, k_p99, k_sd) = stats(&mut hickory_all); + + println!(); + println!( + "{:<10} {:>10} {:>10} {:>10} {:>10}", + "", "Single", "Hedge-same", "Hedge-dual", "Hickory" + ); + println!( + "{:<10} {:>7.1} ms {:>7.1} ms {:>7.1} ms {:>7.1} ms", + "mean", s_m, hs_m, hd_m, k_m + ); + println!( + "{:<10} {:>7.1} ms {:>7.1} ms {:>7.1} ms {:>7.1} ms", + "median", s_med, hs_med, hd_med, k_med + ); + println!( + "{:<10} {:>7.1} ms {:>7.1} ms {:>7.1} ms {:>7.1} ms", + "p95", s_p95, hs_p95, hd_p95, k_p95 + ); + println!( + "{:<10} {:>7.1} ms {:>7.1} ms {:>7.1} ms {:>7.1} ms", + "p99", s_p99, hs_p99, hd_p99, k_p99 + ); + println!( + "{:<10} {:>7.1} ms {:>7.1} ms {:>7.1} ms {:>7.1} ms", + "σ", s_sd, hs_sd, hd_sd, k_sd + ); + + println!(); + println!("Hedge-same 
improvement over single:"); + println!(" mean: {:+.0}%, p95: {:+.0}%, p99: {:+.0}%", + (hs_m - s_m) / s_m * 100.0, + (hs_p95 - s_p95) / s_p95 * 100.0, + (hs_p99 - s_p99) / s_p99 * 100.0); + println!("Hedge-dual improvement over single:"); + println!(" mean: {:+.0}%, p95: {:+.0}%, p99: {:+.0}%", + (hd_m - s_m) / s_m * 100.0, + (hd_p95 - s_p95) / s_p95 * 100.0, + (hd_p99 - s_p99) / s_p99 * 100.0); +} + +/// Run the hedging benchmark N times and aggregate samples across all runs. +/// Also reports per-run stats to show drift. +fn run_hedge_multi(rt: &tokio::runtime::Runtime, iterations: usize) { + let hedge_delay = Duration::from_millis(10); + + println!("Hedging Benchmark × {} iterations", iterations); + println!("Upstream: {}", DOH_UPSTREAM); + println!("Hedge delay: {:?}", hedge_delay); + println!("{} domains × {} rounds per iteration\n", DOMAINS.len(), ROUNDS); + + let primary_same = + numa::forward::parse_upstream(DOH_UPSTREAM, 443).expect("failed to parse"); + let primary_dual = + numa::forward::parse_upstream(DOH_UPSTREAM, 443).expect("failed to parse"); + let secondary_dual = + numa::forward::parse_upstream(DOH_UPSTREAM, 443).expect("failed to parse"); + let timeout = Duration::from_secs(10); + + let resolver = rt.block_on(build_hickory_resolver()); + + // Warm up + println!("Warming up..."); + for _ in 0..5 { + let w = build_query_vec("example.com"); + let _ = rt.block_on(numa::forward::forward_query_raw(&w, &primary_same, timeout)); + let _ = rt.block_on(numa::forward::forward_query_raw(&w, &primary_dual, timeout)); + let _ = rt.block_on(numa::forward::forward_query_raw(&w, &secondary_dual, timeout)); + let _ = rt.block_on(query_hickory_doh(&resolver, "example.com")); + } + + // Accumulated samples across all iterations + let mut all_single = Vec::new(); + let mut all_hedge_same = Vec::new(); + let mut all_hedge_dual = Vec::new(); + let mut all_hickory = Vec::new(); + + // Per-iteration summary stats + let mut iter_stats: Vec<[(f64, f64, f64, f64, f64); 4]> 
= Vec::new(); + + for iter in 1..=iterations { + println!(" iteration {}/{}...", iter, iterations); + + let mut single = Vec::new(); + let mut hedge_same = Vec::new(); + let mut hedge_dual = Vec::new(); + let mut hickory = Vec::new(); + + for domain in DOMAINS { + for _ in 0..ROUNDS { + let wire = build_query_vec(domain); + + let t = Instant::now(); + let _ = rt.block_on(numa::forward::forward_query_raw(&wire, &primary_same, timeout)); + single.push(t.elapsed().as_secs_f64() * 1000.0); + + let t = Instant::now(); + let _ = rt.block_on(numa::forward::forward_with_hedging_raw( + &wire, &primary_same, &primary_same, hedge_delay, timeout, + )); + hedge_same.push(t.elapsed().as_secs_f64() * 1000.0); + + let t = Instant::now(); + let _ = rt.block_on(numa::forward::forward_with_hedging_raw( + &wire, &primary_dual, &secondary_dual, hedge_delay, timeout, + )); + hedge_dual.push(t.elapsed().as_secs_f64() * 1000.0); + + let t = Instant::now(); + let _ = rt.block_on(query_hickory_doh(&resolver, domain)); + hickory.push(t.elapsed().as_secs_f64() * 1000.0); + } + } + + let stats = |v: &mut Vec| -> (f64, f64, f64, f64, f64) { + (mean(v), median(v), percentile(v, 95.0), percentile(v, 99.0), stddev(v)) + }; + iter_stats.push([ + stats(&mut single), + stats(&mut hedge_same), + stats(&mut hedge_dual), + stats(&mut hickory), + ]); + + all_single.extend_from_slice(&single); + all_hedge_same.extend_from_slice(&hedge_same); + all_hedge_dual.extend_from_slice(&hedge_dual); + all_hickory.extend_from_slice(&hickory); + } + + println!(); + println!("=== Per-iteration medians (drift check) ==="); + println!( + "{:<8} {:>10} {:>12} {:>12} {:>10}", + "iter", "Single", "Hedge-same", "Hedge-dual", "Hickory" + ); + for (i, s) in iter_stats.iter().enumerate() { + println!( + "{:<8} {:>7.1} ms {:>9.1} ms {:>9.1} ms {:>7.1} ms", + i + 1, + s[0].1, + s[1].1, + s[2].1, + s[3].1 + ); + } + + println!(); + println!("=== Per-iteration p99 (drift check) ==="); + println!( + "{:<8} {:>10} {:>12} {:>12} 
{:>10}", + "iter", "Single", "Hedge-same", "Hedge-dual", "Hickory" + ); + for (i, s) in iter_stats.iter().enumerate() { + println!( + "{:<8} {:>7.1} ms {:>9.1} ms {:>9.1} ms {:>7.1} ms", + i + 1, + s[0].3, + s[1].3, + s[2].3, + s[3].3 + ); + } + + let final_stats = |v: &mut Vec| -> (f64, f64, f64, f64, f64) { + (mean(v), median(v), percentile(v, 95.0), percentile(v, 99.0), stddev(v)) + }; + let (s_m, s_med, s_p95, s_p99, s_sd) = final_stats(&mut all_single); + let (hs_m, hs_med, hs_p95, hs_p99, hs_sd) = final_stats(&mut all_hedge_same); + let (hd_m, hd_med, hd_p95, hd_p99, hd_sd) = final_stats(&mut all_hedge_dual); + let (k_m, k_med, k_p95, k_p99, k_sd) = final_stats(&mut all_hickory); + + println!(); + let total = iterations * DOMAINS.len() * ROUNDS; + println!("=== Aggregated across all {} samples per method ===", total); + println!(); + println!( + "{:<10} {:>10} {:>12} {:>12} {:>10}", + "", "Single", "Hedge-same", "Hedge-dual", "Hickory" + ); + println!( + "{:<10} {:>7.1} ms {:>9.1} ms {:>9.1} ms {:>7.1} ms", + "mean", s_m, hs_m, hd_m, k_m + ); + println!( + "{:<10} {:>7.1} ms {:>9.1} ms {:>9.1} ms {:>7.1} ms", + "median", s_med, hs_med, hd_med, k_med + ); + println!( + "{:<10} {:>7.1} ms {:>9.1} ms {:>9.1} ms {:>7.1} ms", + "p95", s_p95, hs_p95, hd_p95, k_p95 + ); + println!( + "{:<10} {:>7.1} ms {:>9.1} ms {:>9.1} ms {:>7.1} ms", + "p99", s_p99, hs_p99, hd_p99, k_p99 + ); + println!( + "{:<10} {:>7.1} ms {:>9.1} ms {:>9.1} ms {:>7.1} ms", + "σ", s_sd, hs_sd, hd_sd, k_sd + ); + + println!(); + println!("Hedge-same vs Single: mean {:+.0}%, p95 {:+.0}%, p99 {:+.0}%", + (hs_m - s_m) / s_m * 100.0, + (hs_p95 - s_p95) / s_p95 * 100.0, + (hs_p99 - s_p99) / s_p99 * 100.0); + println!("Hedge-dual vs Single: mean {:+.0}%, p95 {:+.0}%, p99 {:+.0}%", + (hd_m - s_m) / s_m * 100.0, + (hd_p95 - s_p95) / s_p95 * 100.0, + (hd_p99 - s_p99) / s_p99 * 100.0); + println!("Hedge-same vs Hickory: mean {:+.0}%, p95 {:+.0}%, p99 {:+.0}%", + (hs_m - k_m) / k_m * 100.0, + (hs_p95 - 
k_p95) / k_p95 * 100.0, + (hs_p99 - k_p99) / k_p99 * 100.0); +} + +/// Server-to-server benchmark: Numa vs dnscrypt-proxy vs Unbound. +/// All are full servers: UDP in, encrypted forwarding to Quad9. +/// Numa + dnscrypt: DoH (HTTPS). Unbound: DoT (TLS port 853). +fn run_vs_dnscrypt(rt: &tokio::runtime::Runtime, iterations: usize) { + const DNSCRYPT_ADDR: &str = "127.0.0.1:5455"; + const UNBOUND_ADDR: &str = "127.0.0.1:5456"; + let numa_addr: SocketAddr = NUMA_BENCH.parse().unwrap(); + let dnscrypt_addr: SocketAddr = DNSCRYPT_ADDR.parse().unwrap(); + let unbound_addr: SocketAddr = UNBOUND_ADDR.parse().unwrap(); + + println!("Server-to-Server: Numa vs dnscrypt-proxy vs Unbound"); + println!("Numa (DoH): {}", NUMA_BENCH); + println!("dnscrypt-proxy (DoH): {}", DNSCRYPT_ADDR); + println!("Unbound (DoT): {}", UNBOUND_ADDR); + println!("All forwarding to Quad9 over encrypted transport"); + println!("{} domains × {} rounds × {} iterations\n", + DOMAINS.len(), ROUNDS, iterations); + + // Verify all are up + let servers: Vec<(&str, SocketAddr)> = vec![ + ("Numa", numa_addr), + ("dnscrypt-proxy", dnscrypt_addr), + ("Unbound", unbound_addr), + ]; + for (name, addr) in &servers { + if rt.block_on(query_udp(*addr, "example.com")).is_none() { + eprintln!("{} not responding on {}", name, addr); + std::process::exit(1); + } + } + println!("All servers reachable.\n"); + + // Warm up + println!("Warming up..."); + for _ in 0..5 { + for (_, addr) in &servers { + let _ = rt.block_on(query_udp(*addr, "example.com")); + } + } + + let mut all_numa = Vec::new(); + let mut all_dnscrypt = Vec::new(); + let mut all_unbound = Vec::new(); + let mut iter_stats: Vec<[(f64, f64, f64, f64, f64); 3]> = Vec::new(); + + for iter in 1..=iterations { + println!(" iteration {}/{}...", iter, iterations); + + let mut numa = Vec::new(); + let mut dnscrypt = Vec::new(); + let mut unbound = Vec::new(); + + for domain in DOMAINS { + for round in 0..ROUNDS { + flush_cache(); + 
std::thread::sleep(Duration::from_millis(5)); + + // Rotate order: 3 servers, 3 possible orderings + let order = round % 3; + let mut measure = |addr: SocketAddr| -> f64 { + let t = Instant::now(); + let _ = rt.block_on(query_udp(addr, domain)); + t.elapsed().as_secs_f64() * 1000.0 + }; + + match order { + 0 => { + numa.push(measure(numa_addr)); + dnscrypt.push(measure(dnscrypt_addr)); + unbound.push(measure(unbound_addr)); + } + 1 => { + dnscrypt.push(measure(dnscrypt_addr)); + unbound.push(measure(unbound_addr)); + numa.push(measure(numa_addr)); + } + _ => { + unbound.push(measure(unbound_addr)); + numa.push(measure(numa_addr)); + dnscrypt.push(measure(dnscrypt_addr)); + } + } + } + } + + let stats = |v: &mut Vec| -> (f64, f64, f64, f64, f64) { + (mean(v), median(v), percentile(v, 95.0), percentile(v, 99.0), stddev(v)) + }; + iter_stats.push([stats(&mut numa), stats(&mut dnscrypt), stats(&mut unbound)]); + + all_numa.extend_from_slice(&numa); + all_dnscrypt.extend_from_slice(&dnscrypt); + all_unbound.extend_from_slice(&unbound); + } + + println!(); + println!("=== Per-iteration medians ==="); + println!("{:<8} {:>10} {:>14} {:>10}", "iter", "Numa", "dnscrypt-proxy", "Unbound"); + for (i, s) in iter_stats.iter().enumerate() { + println!("{:<8} {:>7.1} ms {:>11.1} ms {:>7.1} ms", + i + 1, s[0].1, s[1].1, s[2].1); + } + + println!(); + println!("=== Per-iteration p99 ==="); + println!("{:<8} {:>10} {:>14} {:>10}", "iter", "Numa", "dnscrypt-proxy", "Unbound"); + for (i, s) in iter_stats.iter().enumerate() { + println!("{:<8} {:>7.1} ms {:>11.1} ms {:>7.1} ms", + i + 1, s[0].3, s[1].3, s[2].3); + } + + let stats = |v: &mut Vec| -> (f64, f64, f64, f64, f64) { + (mean(v), median(v), percentile(v, 95.0), percentile(v, 99.0), stddev(v)) + }; + let (n_m, n_med, n_p95, n_p99, n_sd) = stats(&mut all_numa); + let (d_m, d_med, d_p95, d_p99, d_sd) = stats(&mut all_dnscrypt); + let (u_m, u_med, u_p95, u_p99, u_sd) = stats(&mut all_unbound); + + println!(); + let total = 
iterations * DOMAINS.len() * ROUNDS; + println!("=== Aggregated ({} samples per method) ===", total); + println!(); + println!("{:<10} {:>10} {:>14} {:>10}", "", "Numa", "dnscrypt-proxy", "Unbound"); + println!("{:<10} {:>7.1} ms {:>11.1} ms {:>7.1} ms", "mean", n_m, d_m, u_m); + println!("{:<10} {:>7.1} ms {:>11.1} ms {:>7.1} ms", "median", n_med, d_med, u_med); + println!("{:<10} {:>7.1} ms {:>11.1} ms {:>7.1} ms", "p95", n_p95, d_p95, u_p95); + println!("{:<10} {:>7.1} ms {:>11.1} ms {:>7.1} ms", "p99", n_p99, d_p99, u_p99); + println!("{:<10} {:>7.1} ms {:>11.1} ms {:>7.1} ms", "σ", n_sd, d_sd, u_sd); + println!(); + + println!("Numa vs dnscrypt-proxy:"); + println!(" mean: {:+.0}%, median: {:+.0}%, p99: {:+.0}%", + (n_m - d_m) / d_m * 100.0, (n_med - d_med) / d_med * 100.0, (n_p99 - d_p99) / d_p99 * 100.0); + println!("Numa vs Unbound:"); + println!(" mean: {:+.0}%, median: {:+.0}%, p99: {:+.0}%", + (n_m - u_m) / u_m * 100.0, (n_med - u_med) / u_med * 100.0, (n_p99 - u_p99) / u_p99 * 100.0); +} + +/// Numa vs Unbound: both forward over plain UDP to Quad9, caching enabled. +/// Truly equal transport — no TLS, no HTTP/2, pure forwarding + cache. 
+fn run_vs_unbound(rt: &tokio::runtime::Runtime, iterations: usize) { + const UNBOUND_ADDR: &str = "127.0.0.1:5456"; + let numa_addr: SocketAddr = NUMA_BENCH.parse().unwrap(); + let unbound_addr: SocketAddr = UNBOUND_ADDR.parse().unwrap(); + + println!("Numa vs Unbound (both plain UDP forwarding to Quad9, caching enabled)"); + println!("Numa: {} → 9.9.9.9:53 UDP", NUMA_BENCH); + println!("Unbound: {} → 9.9.9.9:53 UDP", UNBOUND_ADDR); + println!("{} domains × {} rounds × {} iterations\n", + DOMAINS.len(), ROUNDS, iterations); + + if rt.block_on(query_udp(numa_addr, "example.com")).is_none() { + eprintln!("Numa not responding"); std::process::exit(1); + } + if rt.block_on(query_udp(unbound_addr, "example.com")).is_none() { + eprintln!("Unbound not responding"); std::process::exit(1); + } + println!("Both servers reachable.\n"); + + println!("Warming up..."); + for _ in 0..5 { + let _ = rt.block_on(query_udp(numa_addr, "example.com")); + let _ = rt.block_on(query_udp(unbound_addr, "example.com")); + } + + let mut all_numa = Vec::new(); + let mut all_unbound = Vec::new(); + let mut iter_stats: Vec<[(f64, f64, f64, f64, f64); 2]> = Vec::new(); + + for iter in 1..=iterations { + println!(" iteration {}/{}...", iter, iterations); + + let mut numa = Vec::new(); + let mut unbound = Vec::new(); + + for domain in DOMAINS { + for round in 0..ROUNDS { + // No cache flushing — both serve from cache after first hit + let mut measure = |addr: SocketAddr| -> f64 { + let t = Instant::now(); + let _ = rt.block_on(query_udp(addr, domain)); + t.elapsed().as_secs_f64() * 1000.0 + }; + + if round % 2 == 0 { + numa.push(measure(numa_addr)); + unbound.push(measure(unbound_addr)); + } else { + unbound.push(measure(unbound_addr)); + numa.push(measure(numa_addr)); + } + } + } + + let stats = |v: &mut Vec| -> (f64, f64, f64, f64, f64) { + (mean(v), median(v), percentile(v, 95.0), percentile(v, 99.0), stddev(v)) + }; + iter_stats.push([stats(&mut numa), stats(&mut unbound)]); + + 
all_numa.extend_from_slice(&numa);
+        all_unbound.extend_from_slice(&unbound);
+    }
+
+    println!();
+    println!("=== Per-iteration medians ===");
+    println!("{:<8} {:>10} {:>10}", "iter", "Numa", "Unbound");
+    for (i, s) in iter_stats.iter().enumerate() {
+        println!("{:<8} {:>7.1} ms {:>7.1} ms", i + 1, s[0].1, s[1].1);
+    }
+
+    println!();
+    println!("=== Per-iteration p99 ===");
+    println!("{:<8} {:>10} {:>10}", "iter", "Numa", "Unbound");
+    for (i, s) in iter_stats.iter().enumerate() {
+        println!("{:<8} {:>7.1} ms {:>7.1} ms", i + 1, s[0].3, s[1].3);
+    }
+
+    let stats = |v: &mut Vec<f64>| -> (f64, f64, f64, f64, f64) {
+        (mean(v), median(v), percentile(v, 95.0), percentile(v, 99.0), stddev(v))
+    };
+    let (n_m, n_med, n_p95, n_p99, n_sd) = stats(&mut all_numa);
+    let (u_m, u_med, u_p95, u_p99, u_sd) = stats(&mut all_unbound);
+
+    println!();
+    let total = iterations * DOMAINS.len() * ROUNDS;
+    println!("=== Aggregated ({} samples per method) ===", total);
+    println!();
+    println!("{:<10} {:>10} {:>10}", "", "Numa", "Unbound");
+    println!("{:<10} {:>7.1} ms {:>7.1} ms", "mean", n_m, u_m);
+    println!("{:<10} {:>7.1} ms {:>7.1} ms", "median", n_med, u_med);
+    println!("{:<10} {:>7.1} ms {:>7.1} ms", "p95", n_p95, u_p95);
+    println!("{:<10} {:>7.1} ms {:>7.1} ms", "p99", n_p99, u_p99);
+    println!("{:<10} {:>7.1} ms {:>7.1} ms", "σ", n_sd, u_sd);
+    println!();
+
+    println!("Numa vs Unbound:");
+    println!(" mean: {:+.1} ms ({:+.0}%)", n_m - u_m, (n_m - u_m) / u_m * 100.0);
+    println!(" median: {:+.1} ms ({:+.0}%)", n_med - u_med, (n_med - u_med) / u_med * 100.0);
+    println!(" p95: {:+.1} ms ({:+.0}%)", n_p95 - u_p95, (n_p95 - u_p95) / u_p95 * 100.0);
+    println!(" p99: {:+.1} ms ({:+.0}%)", n_p99 - u_p99, (n_p99 - u_p99) / u_p99 * 100.0);
+}
+
+/// Build a DNS query as a Vec<u8> for use with forward_query_raw.
+fn build_query_vec(domain: &str) -> Vec<u8> {
+    let mut buf = vec![0u8; 512];
+    let len = build_query(&mut buf, domain);
+    buf.truncate(len);
+    buf
+}
+
+fn measure<F: FnOnce() -> R, R>(_rt: &tokio::runtime::Runtime, f: F) -> f64 {
+    let start = Instant::now();
+    f();
+    start.elapsed().as_secs_f64() * 1000.0
+}
+
+fn mean(v: &[f64]) -> f64 {
+    v.iter().sum::<f64>() / v.len() as f64
+}
+
+fn stddev(v: &[f64]) -> f64 {
+    let m = mean(v);
+    let var = v.iter().map(|x| (x - m).powi(2)).sum::<f64>() / v.len() as f64;
+    var.sqrt()
+}
+
+fn median(v: &mut [f64]) -> f64 {
+    v.sort_by(|a, b| a.partial_cmp(b).unwrap());
+    let n = v.len();
+    if n % 2 == 0 {
+        (v[n / 2 - 1] + v[n / 2]) / 2.0
+    } else {
+        v[n / 2]
+    }
+}
+
+fn percentile(sorted: &[f64], p: f64) -> f64 {
+    let idx = (p / 100.0 * (sorted.len() - 1) as f64).round() as usize;
+    sorted[idx.min(sorted.len() - 1)]
+}
+
+fn format_delta(delta: f64) -> String {
+    if delta > 0.0 {
+        format!("+{:.1}", delta)
+    } else {
+        format!("{:.1}", delta)
+    }
+}
+
+/// Query a DNS server over UDP.
+async fn query_udp(addr: SocketAddr, domain: &str) -> Option<()> {
+    use tokio::net::UdpSocket;
+
+    let sock = UdpSocket::bind("0.0.0.0:0").await.ok()?;
+    let mut buf = vec![0u8; 512];
+    let len = build_query(&mut buf, domain);
+
+    sock.send_to(&buf[..len], addr).await.ok()?;
+
+    let mut resp = vec![0u8; 4096];
+    tokio::time::timeout(Duration::from_secs(10), sock.recv_from(&mut resp))
+        .await
+        .ok()?
+        .ok()?;
+
+    Some(())
+}
+
+/// Build a shared Hickory DoH resolver (reuses TLS connection across queries).
+async fn build_hickory_resolver() -> hickory_resolver::TokioResolver { + use hickory_resolver::config::*; + + let ns = NameServerConfig { + socket_addr: "9.9.9.9:443".parse().unwrap(), + protocol: hickory_proto::xfer::Protocol::Https, + tls_dns_name: Some("dns.quad9.net".to_string()), + trust_negative_responses: true, + bind_addr: None, + http_endpoint: Some("/dns-query".to_string()), + }; + + let config = ResolverConfig::from_parts(None, vec![], NameServerConfigGroup::from(vec![ns])); + + let mut opts = ResolverOpts::default(); + opts.cache_size = 0; + opts.num_concurrent_reqs = 1; + opts.timeout = Duration::from_secs(10); + + hickory_resolver::TokioResolver::builder_with_config(config, Default::default()) + .with_options(opts) + .build() +} + +/// Query using the shared Hickory resolver. +async fn query_hickory_doh( + resolver: &hickory_resolver::TokioResolver, + domain: &str, +) -> Option<()> { + use hickory_resolver::proto::rr::RecordType; + let _ = resolver.lookup(domain, RecordType::A).await.ok()?; + Some(()) +} + +fn build_query(buf: &mut [u8], domain: &str) -> usize { + let mut pos = 0; + buf[pos..pos + 2].copy_from_slice(&0x1234u16.to_be_bytes()); + pos += 2; + buf[pos..pos + 2].copy_from_slice(&0x0100u16.to_be_bytes()); + pos += 2; + buf[pos..pos + 2].copy_from_slice(&1u16.to_be_bytes()); + pos += 2; + buf[pos..pos + 6].fill(0); + pos += 6; + + for label in domain.split('.') { + buf[pos] = label.len() as u8; + pos += 1; + buf[pos..pos + label.len()].copy_from_slice(label.as_bytes()); + pos += label.len(); + } + buf[pos] = 0; + pos += 1; + buf[pos..pos + 2].copy_from_slice(&1u16.to_be_bytes()); + pos += 2; + buf[pos..pos + 2].copy_from_slice(&1u16.to_be_bytes()); + pos += 2; + pos +} + +fn flush_cache() { + let _ = std::process::Command::new("curl") + .args(["-s", "-X", "DELETE", &format!("http://127.0.0.1:{NUMA_API}/cache")]) + .output(); +} diff --git a/scripts/bench-recursive.sh b/scripts/bench-recursive.sh new file mode 100755 index 0000000..1a1ab71 
--- /dev/null +++ b/scripts/bench-recursive.sh @@ -0,0 +1,115 @@ +#!/usr/bin/env bash +# Bench: Numa cold-cache recursive resolution vs dig (forwarded through system resolver) +# +# Measures cold-cache recursive resolution time for Numa. +# Flushes Numa's cache before each query to ensure cold-cache. +# Compares against dig querying a public recursive resolver (no cache advantage). +# +# Usage: ./scripts/bench-recursive.sh [numa_port] + +set -euo pipefail + +NUMA_ADDR="${NUMA_ADDR:-127.0.0.1}" +NUMA_PORT="${NUMA_PORT:-${1:-53}}" +API_PORT="${API_PORT:-5380}" +ROUNDS=3 + +DOMAINS=( + "example.com" + "rust-lang.org" + "kernel.org" + "signal.org" + "archlinux.org" + "openbsd.org" + "git-scm.com" + "sqlite.org" + "wireguard.com" + "mozilla.org" +) + +GREEN='\033[0;32m' +AMBER='\033[0;33m' +CYAN='\033[0;36m' +DIM='\033[0;90m' +BOLD='\033[1m' +RESET='\033[0m' + +echo -e "${CYAN}${BOLD}Recursive DNS Resolution Benchmark${RESET}" +echo -e "${DIM}Numa (cold cache, recursive from root) vs dig @1.1.1.1 (public resolver)${RESET}" +echo -e "${DIM}Rounds per domain: ${ROUNDS}${RESET}" +echo "" + +# Verify Numa is reachable +if ! dig @${NUMA_ADDR} -p ${NUMA_PORT} +short +time=3 +tries=1 example.com A &>/dev/null; then + echo -e "${AMBER}Numa not responding on ${NUMA_ADDR}:${NUMA_PORT}${RESET}" >&2 + exit 1 +fi + +# Verify we can flush cache +if ! 
curl -s -X DELETE "http://${NUMA_ADDR}:${API_PORT}/cache" &>/dev/null; then + echo -e "${AMBER}Cannot flush cache via API at ${NUMA_ADDR}:${API_PORT}${RESET}" >&2 + exit 1 +fi + +measure_ms() { + local start end + start=$(python3 -c 'import time; print(time.time())') + eval "$1" &>/dev/null + end=$(python3 -c 'import time; print(time.time())') + python3 -c "print(round(($end - $start) * 1000, 1))" +} + +printf "${BOLD}%-22s %10s %10s %8s${RESET}\n" "Domain" "Numa (ms)" "1.1.1.1" "Delta" +printf "%-22s %10s %10s %8s\n" "----------------------" "----------" "----------" "--------" + +numa_total=0 +dig_total=0 +count=0 + +for domain in "${DOMAINS[@]}"; do + numa_sum=0 + dig_sum=0 + + for ((r=1; r<=ROUNDS; r++)); do + # Flush Numa cache + curl -s -X DELETE "http://${NUMA_ADDR}:${API_PORT}/cache" &>/dev/null + sleep 0.05 + + # Measure Numa (recursive from root, cold cache) + ms=$(measure_ms "dig @${NUMA_ADDR} -p ${NUMA_PORT} +short +time=10 +tries=1 ${domain} A") + numa_sum=$(python3 -c "print(round($numa_sum + $ms, 1))") + + # Measure dig against 1.1.1.1 (Cloudflare — warm cache, but shows baseline) + ms=$(measure_ms "dig @1.1.1.1 +short +time=10 +tries=1 ${domain} A") + dig_sum=$(python3 -c "print(round($dig_sum + $ms, 1))") + done + + numa_avg=$(python3 -c "print(round($numa_sum / $ROUNDS, 1))") + dig_avg=$(python3 -c "print(round($dig_sum / $ROUNDS, 1))") + delta=$(python3 -c "d = round($numa_avg - $dig_avg, 1); print(f'+{d}' if d > 0 else str(d))") + + # Color the delta + delta_color="$GREEN" + if python3 -c "exit(0 if $numa_avg > $dig_avg * 1.5 else 1)" 2>/dev/null; then + delta_color="$AMBER" + fi + + printf "%-22s %8s ms %8s ms ${delta_color}%6s ms${RESET}\n" "$domain" "$numa_avg" "$dig_avg" "$delta" + + numa_total=$(python3 -c "print(round($numa_total + $numa_avg, 1))") + dig_total=$(python3 -c "print(round($dig_total + $dig_avg, 1))") + count=$((count + 1)) +done + +echo "" +numa_mean=$(python3 -c "print(round($numa_total / $count, 1))") +dig_mean=$(python3 -c 
"print(round($dig_total / $count, 1))") +delta_mean=$(python3 -c "d = round($numa_mean - $dig_mean, 1); print(f'+{d}' if d > 0 else str(d))") + +printf "${BOLD}%-22s %8s ms %8s ms %6s ms${RESET}\n" "AVERAGE" "$numa_mean" "$dig_mean" "$delta_mean" + +echo "" +echo -e "${DIM}Note: Numa resolves recursively from root hints (cold cache).${RESET}" +echo -e "${DIM}1.1.1.1 serves from Cloudflare's global cache (warm). The comparison${RESET}" +echo -e "${DIM}is intentionally unfair — it shows Numa's worst case vs the best case${RESET}" +echo -e "${DIM}of a global anycast resolver. Cached Numa queries resolve in <1ms.${RESET}" diff --git a/src/api.rs b/src/api.rs index a0bae58..e638fba 100644 --- a/src/api.rs +++ b/src/api.rs @@ -1029,6 +1029,7 @@ mod tests { upstream_port: 53, lan_ip: Mutex::new(std::net::Ipv4Addr::LOCALHOST), timeout: std::time::Duration::from_secs(3), + hedge_delay: std::time::Duration::ZERO, proxy_tld: "numa".to_string(), proxy_tld_suffix: ".numa".to_string(), lan_enabled: false, diff --git a/src/cache.rs b/src/cache.rs index 5bdde85..82795bc 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -1,9 +1,10 @@ use std::collections::HashMap; use std::time::{Duration, Instant}; +use crate::buffer::BytePacketBuffer; use crate::packet::DnsPacket; use crate::question::QueryType; -use crate::record::DnsRecord; +use crate::wire::WireMeta; #[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] pub enum DnssecStatus { @@ -26,14 +27,16 @@ impl DnssecStatus { } struct CacheEntry { - packet: DnsPacket, + wire: Vec, + meta: WireMeta, inserted_at: Instant, ttl: Duration, dnssec_status: DnssecStatus, } -/// DNS cache using a two-level map (domain -> query_type -> entry) so that -/// lookups can borrow `&str` instead of allocating a `String` key. +const STALE_WINDOW: Duration = Duration::from_secs(3600); + +/// DNS cache with serve-stale (RFC 8767). Stores raw wire bytes. 
pub struct DnsCache { entries: HashMap>, entry_count: usize, @@ -53,6 +56,80 @@ impl DnsCache { } } + /// Look up cached wire bytes, patching ID and TTLs in the returned copy. + /// Implements serve-stale (RFC 8767): expired entries within STALE_WINDOW + /// are returned with TTL=1 and `stale=true` so callers can revalidate. + pub fn lookup_wire( + &self, + domain: &str, + qtype: QueryType, + new_id: u16, + ) -> Option<(Vec, DnssecStatus, bool)> { + let type_map = self.entries.get(domain)?; + let entry = type_map.get(&qtype)?; + + let elapsed = entry.inserted_at.elapsed(); + let (remaining, stale) = if elapsed < entry.ttl { + let secs = (entry.ttl - elapsed).as_secs() as u32; + (secs.max(1), false) + } else if elapsed < entry.ttl + STALE_WINDOW { + (1, true) + } else { + return None; + }; + + let mut wire = entry.wire.clone(); + crate::wire::patch_id(&mut wire, new_id); + crate::wire::patch_ttls(&mut wire, &entry.meta.ttl_offsets, remaining); + + Some((wire, entry.dnssec_status, stale)) + } + + pub fn insert_wire( + &mut self, + domain: &str, + qtype: QueryType, + wire: &[u8], + dnssec_status: DnssecStatus, + ) { + let meta = match crate::wire::scan_ttl_offsets(wire) { + Ok(m) => m, + Err(_) => return, // malformed wire, skip + }; + + if self.entry_count >= self.max_entries { + self.evict_expired(); + if self.entry_count >= self.max_entries { + return; + } + } + + let min_ttl = crate::wire::min_ttl_from_wire(wire, &meta) + .unwrap_or(self.min_ttl) + .clamp(self.min_ttl, self.max_ttl); + + let type_map = if let Some(existing) = self.entries.get_mut(domain) { + existing + } else { + self.entries.entry(domain.to_string()).or_default() + }; + + if !type_map.contains_key(&qtype) { + self.entry_count += 1; + } + + type_map.insert( + qtype, + CacheEntry { + wire: wire.to_vec(), + meta, + inserted_at: Instant::now(), + ttl: Duration::from_secs(min_ttl as u64), + dnssec_status, + }, + ); + } + /// Read-only lookup — expired entries are left in place (cleaned up on insert). 
pub fn lookup(&self, domain: &str, qtype: QueryType) -> Option { self.lookup_with_status(domain, qtype).map(|(pkt, _)| pkt) @@ -63,23 +140,28 @@ impl DnsCache { domain: &str, qtype: QueryType, ) -> Option<(DnsPacket, DnssecStatus)> { - let type_map = self.entries.get(domain)?; - let entry = type_map.get(&qtype)?; + let (wire, status, _stale) = self.lookup_wire(domain, qtype, 0)?; + let mut buf = BytePacketBuffer::from_bytes(&wire); + let pkt = DnsPacket::from_buffer(&mut buf).ok()?; + Some((pkt, status)) + } - let elapsed = entry.inserted_at.elapsed(); - if elapsed >= entry.ttl { - return None; + pub fn insert(&mut self, domain: &str, qtype: QueryType, packet: &DnsPacket) { + self.insert_with_status(domain, qtype, packet, DnssecStatus::Indeterminate); + } + + pub fn insert_with_status( + &mut self, + domain: &str, + qtype: QueryType, + packet: &DnsPacket, + dnssec_status: DnssecStatus, + ) { + let mut buf = BytePacketBuffer::new(); + if packet.write(&mut buf).is_err() { + return; } - - let remaining_secs = (entry.ttl - elapsed).as_secs() as u32; - let remaining = remaining_secs.max(1); - - let mut packet = entry.packet.clone(); - adjust_ttls(&mut packet.answers, remaining); - adjust_ttls(&mut packet.authorities, remaining); - adjust_ttls(&mut packet.resources, remaining); - - Some((packet, entry.dnssec_status)) + self.insert_wire(domain, qtype, buf.filled(), dnssec_status); } pub fn ttl_remaining(&self, domain: &str, qtype: QueryType) -> Option<(u32, u32)> { @@ -105,49 +187,6 @@ impl DnsCache { false } - pub fn insert(&mut self, domain: &str, qtype: QueryType, packet: &DnsPacket) { - self.insert_with_status(domain, qtype, packet, DnssecStatus::Indeterminate); - } - - pub fn insert_with_status( - &mut self, - domain: &str, - qtype: QueryType, - packet: &DnsPacket, - dnssec_status: DnssecStatus, - ) { - if self.entry_count >= self.max_entries { - self.evict_expired(); - if self.entry_count >= self.max_entries { - return; - } - } - - let min_ttl = 
extract_min_ttl(&packet.answers) - .unwrap_or(self.min_ttl) - .clamp(self.min_ttl, self.max_ttl); - - let type_map = if let Some(existing) = self.entries.get_mut(domain) { - existing - } else { - self.entries.entry(domain.to_string()).or_default() - }; - - if !type_map.contains_key(&qtype) { - self.entry_count += 1; - } - - type_map.insert( - qtype, - CacheEntry { - packet: packet.clone(), - inserted_at: Instant::now(), - ttl: Duration::from_secs(min_ttl as u64), - dnssec_status, - }, - ); - } - pub fn len(&self) -> usize { self.entry_count } @@ -179,7 +218,8 @@ impl DnsCache { + 1; total += type_map.capacity() * inner_slot; for entry in type_map.values() { - total += entry.packet.heap_bytes(); + total += entry.wire.capacity() + + entry.meta.ttl_offsets.capacity() * std::mem::size_of::(); } } total @@ -228,20 +268,11 @@ pub struct CacheInfo { pub ttl_remaining: u32, } -fn extract_min_ttl(records: &[DnsRecord]) -> Option { - records.iter().map(|r| r.ttl()).min() -} - -fn adjust_ttls(records: &mut [DnsRecord], new_ttl: u32) { - for record in records.iter_mut() { - record.set_ttl(new_ttl); - } -} - #[cfg(test)] mod tests { use super::*; use crate::packet::DnsPacket; + use crate::record::DnsRecord; #[test] fn heap_bytes_grows_with_entries() { diff --git a/src/config.rs b/src/config.rs index ae9f685..5f9db73 100644 --- a/src/config.rs +++ b/src/config.rs @@ -138,6 +138,8 @@ pub struct UpstreamConfig { pub fallback: Vec, #[serde(default = "default_timeout_ms")] pub timeout_ms: u64, + #[serde(default = "default_hedge_ms")] + pub hedge_ms: u64, #[serde(default = "default_root_hints")] pub root_hints: Vec, #[serde(default = "default_prime_tlds")] @@ -154,6 +156,7 @@ impl Default for UpstreamConfig { port: default_upstream_port(), fallback: Vec::new(), timeout_ms: default_timeout_ms(), + hedge_ms: default_hedge_ms(), root_hints: default_root_hints(), prime_tlds: default_prime_tlds(), srtt: default_srtt(), @@ -271,6 +274,9 @@ fn default_upstream_port() -> u16 { fn 
default_timeout_ms() -> u64 { 5000 } +fn default_hedge_ms() -> u64 { + 10 +} #[derive(Deserialize)] pub struct CacheConfig { diff --git a/src/ctx.rs b/src/ctx.rs index 3ef6a0a..2b26a06 100644 --- a/src/ctx.rs +++ b/src/ctx.rs @@ -16,7 +16,9 @@ use crate::blocklist::BlocklistStore; use crate::buffer::BytePacketBuffer; use crate::cache::{DnsCache, DnssecStatus}; use crate::config::{UpstreamMode, ZoneMap}; -use crate::forward::{forward_query, forward_with_failover, Upstream, UpstreamPool}; +use crate::forward::{ + forward_query_raw, forward_with_failover_raw, Upstream, UpstreamPool, +}; use crate::header::ResultCode; use crate::health::HealthMeta; use crate::lan::PeerStore; @@ -47,6 +49,7 @@ pub struct ServerCtx { pub upstream_port: u16, pub lan_ip: Mutex, pub timeout: Duration, + pub hedge_delay: Duration, pub proxy_tld: String, pub proxy_tld_suffix: String, // pre-computed ".{tld}" to avoid per-query allocation pub lan_enabled: bool, @@ -81,6 +84,7 @@ pub struct ServerCtx { /// (and logging parse errors) before calling this function. pub async fn resolve_query( query: DnsPacket, + raw_wire: &[u8], src_addr: SocketAddr, ctx: &ServerCtx, ) -> crate::Result { @@ -177,9 +181,8 @@ pub async fn resolve_query( // Conditional forwarding takes priority over recursive mode // (e.g. 
Tailscale .ts.net, VPC private zones) let upstream = Upstream::Udp(fwd_addr); - match forward_query(&query, &upstream, ctx.timeout).await { + match forward_and_cache(raw_wire, &upstream, ctx, &qname, qtype).await { Ok(resp) => { - ctx.cache.write().unwrap().insert(&qname, qtype, &resp); (resp, QueryPath::Forwarded, DnssecStatus::Indeterminate) } Err(e) => { @@ -221,10 +224,19 @@ pub async fn resolve_query( (resp, path, DnssecStatus::Indeterminate) } else { let pool = ctx.upstream_pool.lock().unwrap().clone(); - match forward_with_failover(&query, &pool, &ctx.srtt, ctx.timeout).await { - Ok(resp) => { - ctx.cache.write().unwrap().insert(&qname, qtype, &resp); - (resp, QueryPath::Forwarded, DnssecStatus::Indeterminate) + match forward_with_failover_raw(raw_wire, &pool, &ctx.srtt, ctx.timeout, ctx.hedge_delay).await { + Ok(resp_wire) => { + ctx.cache.write().unwrap().insert_wire( + &qname, qtype, &resp_wire, DnssecStatus::Indeterminate, + ); + let mut buf = BytePacketBuffer::from_bytes(&resp_wire); + match DnsPacket::from_buffer(&mut buf) { + Ok(resp) => (resp, QueryPath::Forwarded, DnssecStatus::Indeterminate), + Err(e) => { + error!("{} | {:?} {} | PARSE ERROR | {}", src_addr, qtype, qname, e); + (DnsPacket::response_from(&query, ResultCode::SERVFAIL), QueryPath::UpstreamError, DnssecStatus::Indeterminate) + } + } } Err(e) => { error!( @@ -347,12 +359,29 @@ pub async fn resolve_query( Ok(resp_buffer) } -/// Handle a DNS query received over UDP. Thin wrapper around resolve_query. 
+async fn forward_and_cache( + wire: &[u8], + upstream: &Upstream, + ctx: &ServerCtx, + qname: &str, + qtype: QueryType, +) -> crate::Result { + let resp_wire = forward_query_raw(wire, upstream, ctx.timeout).await?; + ctx.cache + .write() + .unwrap() + .insert_wire(qname, qtype, &resp_wire, DnssecStatus::Indeterminate); + let mut buf = BytePacketBuffer::from_bytes(&resp_wire); + DnsPacket::from_buffer(&mut buf) +} + pub async fn handle_query( mut buffer: BytePacketBuffer, + raw_len: usize, src_addr: SocketAddr, ctx: &ServerCtx, ) -> crate::Result<()> { + let raw_wire = buffer.buf[..raw_len].to_vec(); let query = match DnsPacket::from_buffer(&mut buffer) { Ok(packet) => packet, Err(e) => { @@ -360,7 +389,7 @@ pub async fn handle_query( return Ok(()); } }; - match resolve_query(query, src_addr, ctx).await { + match resolve_query(query, &raw_wire, src_addr, ctx).await { Ok(resp_buffer) => { ctx.socket.send_to(resp_buffer.filled(), src_addr).await?; } diff --git a/src/doh.rs b/src/doh.rs index cf50b31..e31b6fe 100644 --- a/src/doh.rs +++ b/src/doh.rs @@ -82,7 +82,7 @@ async fn resolve_doh(dns_bytes: &[u8], src: SocketAddr, ctx: &ServerCtx) -> Resp let query_rd = query.header.recursion_desired; let questions = query.questions.clone(); - match resolve_query(query, src, ctx).await { + match resolve_query(query, dns_bytes, src, ctx).await { Ok(resp_buffer) => { let min_ttl = extract_min_ttl(resp_buffer.filled()); dns_response(resp_buffer.filled(), min_ttl) @@ -102,11 +102,10 @@ async fn resolve_doh(dns_bytes: &[u8], src: SocketAddr, ctx: &ServerCtx) -> Resp } fn extract_min_ttl(wire: &[u8]) -> u32 { - let mut buf = BytePacketBuffer::from_bytes(wire); - match DnsPacket::from_buffer(&mut buf) { - Ok(pkt) => pkt.answers.iter().map(|r| r.ttl()).min().unwrap_or(0), - Err(_) => 0, - } + crate::wire::scan_ttl_offsets(wire) + .ok() + .and_then(|meta| crate::wire::min_ttl_from_wire(wire, &meta)) + .unwrap_or(0) } fn dns_response(wire: &[u8], min_ttl: u32) -> Response { diff --git 
a/src/dot.rs b/src/dot.rs index 0d48fa2..4513f60 100644 --- a/src/dot.rs +++ b/src/dot.rs @@ -177,8 +177,7 @@ where break; }; - // Parse query up-front so we can echo its question section in SERVFAIL - // responses when resolve_query fails. + let raw_wire = buffer.buf[..msg_len].to_vec(); let query = match DnsPacket::from_buffer(&mut buffer) { Ok(q) => q, Err(e) => { @@ -200,7 +199,7 @@ where } }; - match resolve_query(query.clone(), remote_addr, ctx).await { + match resolve_query(query.clone(), &raw_wire, remote_addr, ctx).await { Ok(resp_buffer) => { if write_framed(&mut stream, resp_buffer.filled()) .await @@ -370,6 +369,7 @@ mod tests { upstream_port: 53, lan_ip: Mutex::new(std::net::Ipv4Addr::LOCALHOST), timeout: Duration::from_millis(200), + hedge_delay: Duration::ZERO, proxy_tld: "numa".to_string(), proxy_tld_suffix: ".numa".to_string(), lan_enabled: false, diff --git a/src/forward.rs b/src/forward.rs index ea2f1e2..401ae1c 100644 --- a/src/forward.rs +++ b/src/forward.rs @@ -65,6 +65,13 @@ pub fn parse_upstream(s: &str, default_port: u16) -> Result { if s.starts_with("https://") { let client = reqwest::Client::builder() .use_rustls_tls() + .http2_initial_stream_window_size(65_535) + .http2_initial_connection_window_size(65_535) + .http2_keep_alive_interval(Duration::from_secs(15)) + .http2_keep_alive_while_idle(true) + .http2_keep_alive_timeout(Duration::from_secs(10)) + .pool_idle_timeout(Duration::from_secs(300)) + .pool_max_idle_per_host(1) .build() .unwrap_or_default(); return Ok(Upstream::Doh { @@ -325,13 +332,170 @@ async fn forward_doh( let mut send_buffer = BytePacketBuffer::new(); query.write(&mut send_buffer)?; + let resp_bytes = forward_doh_raw(send_buffer.filled(), url, client, timeout_duration).await?; + let mut recv_buffer = BytePacketBuffer::from_bytes(&resp_bytes); + DnsPacket::from_buffer(&mut recv_buffer) +} + +pub async fn forward_query_raw( + wire: &[u8], + upstream: &Upstream, + timeout_duration: Duration, +) -> Result> { + match 
upstream { + Upstream::Udp(addr) => forward_udp_raw(wire, *addr, timeout_duration).await, + Upstream::Doh { url, client } => forward_doh_raw(wire, url, client, timeout_duration).await, + } +} + +pub async fn forward_with_hedging_raw( + wire: &[u8], + primary: &Upstream, + secondary: &Upstream, + hedge_delay: Duration, + timeout_duration: Duration, +) -> Result> { + use tokio::time::sleep; + + let primary_fut = forward_query_raw(wire, primary, timeout_duration); + tokio::pin!(primary_fut); + + let delay = sleep(hedge_delay); + tokio::pin!(delay); + + // Phase 1: wait for either primary to return, or the hedge delay. + tokio::select! { + result = &mut primary_fut => return result, + _ = &mut delay => {} + } + + // Phase 2: hedge delay expired — fire secondary while still polling primary. + let secondary_fut = forward_query_raw(wire, secondary, timeout_duration); + tokio::pin!(secondary_fut); + + // First successful response wins. If one errors, wait for the other. + let mut primary_err: Option = None; + let mut secondary_err: Option = None; + + loop { + tokio::select! 
{ + r = &mut primary_fut, if primary_err.is_none() => { + match r { + Ok(resp) => return Ok(resp), + Err(e) => { + if let Some(se) = secondary_err.take() { + return Err(se); + } + primary_err = Some(e); + } + } + } + r = &mut secondary_fut, if secondary_err.is_none() => { + match r { + Ok(resp) => return Ok(resp), + Err(e) => { + if let Some(pe) = primary_err.take() { + return Err(pe); + } + secondary_err = Some(e); + } + } + } + } + + match (primary_err, secondary_err) { + (Some(pe), Some(_)) => return Err(pe), + (pe, se) => { primary_err = pe; secondary_err = se; } + } + } +} + +pub async fn forward_with_failover_raw( + wire: &[u8], + pool: &UpstreamPool, + srtt: &RwLock, + timeout_duration: Duration, + hedge_delay: Duration, +) -> Result> { + let mut candidates: Vec<(usize, u64)> = pool + .primary + .iter() + .enumerate() + .map(|(i, u)| { + let rtt = match u { + Upstream::Udp(addr) => srtt.read().unwrap().get(addr.ip()), + _ => 0, + }; + (i, rtt) + }) + .collect(); + candidates.sort_by_key(|&(_, rtt)| rtt); + + let all_upstreams: Vec<&Upstream> = candidates + .iter() + .map(|&(i, _)| &pool.primary[i]) + .chain(pool.fallback.iter()) + .collect(); + + let mut last_err: Option> = None; + + for upstream in &all_upstreams { + let start = Instant::now(); + let result = if !hedge_delay.is_zero() && matches!(upstream, Upstream::Doh { .. }) { + // Hedge against the same upstream: parallel h2 streams on same + // connection. Independent stream scheduling rescues dispatch spikes. 
+ forward_with_hedging_raw(wire, upstream, upstream, hedge_delay, timeout_duration).await + } else { + forward_query_raw(wire, upstream, timeout_duration).await + }; + match result { + Ok(resp) => { + if let Upstream::Udp(addr) = upstream { + let rtt_ms = start.elapsed().as_millis() as u64; + srtt.write().unwrap().record_rtt(addr.ip(), rtt_ms, false); + } + return Ok(resp); + } + Err(e) => { + if let Upstream::Udp(addr) = upstream { + srtt.write().unwrap().record_failure(addr.ip()); + } + log::debug!("upstream {} failed: {}", upstream, e); + last_err = Some(e); + } + } + } + + Err(last_err.unwrap_or_else(|| "no upstream configured".into())) +} + +async fn forward_udp_raw( + wire: &[u8], + upstream: SocketAddr, + timeout_duration: Duration, +) -> Result> { + let socket = UdpSocket::bind("0.0.0.0:0").await?; + socket.send_to(wire, upstream).await?; + + let mut recv_buf = vec![0u8; 4096]; + let (size, _) = timeout(timeout_duration, socket.recv_from(&mut recv_buf)).await??; + recv_buf.truncate(size); + Ok(recv_buf) +} + +async fn forward_doh_raw( + wire: &[u8], + url: &str, + client: &reqwest::Client, + timeout_duration: Duration, +) -> Result> { let resp = timeout( timeout_duration, client .post(url) .header("content-type", "application/dns-message") .header("accept", "application/dns-message") - .body(send_buffer.filled().to_vec()) + .body(wire.to_vec()) .send(), ) .await?? @@ -339,9 +503,25 @@ async fn forward_doh( let bytes = resp.bytes().await?; log::debug!("DoH response: {} bytes", bytes.len()); + Ok(bytes.to_vec()) +} - let mut recv_buffer = BytePacketBuffer::from_bytes(&bytes); - DnsPacket::from_buffer(&mut recv_buffer) +/// Send a lightweight keepalive query to a DoH upstream to prevent +/// the HTTP/2 + TLS connection from going idle and being torn down. +pub async fn keepalive_doh(upstream: &Upstream) { + if let Upstream::Doh { url, client } = upstream { + // Query for . 
NS — minimal, always succeeds, response is small + let wire: &[u8] = &[ + 0x00, 0x00, // ID + 0x01, 0x00, // flags: RD=1 + 0x00, 0x01, // QDCOUNT=1 + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // AN=0, NS=0, AR=0 + 0x00, // root name (.) + 0x00, 0x02, // type NS + 0x00, 0x01, // class IN + ]; + let _ = forward_doh_raw(wire, url, client, Duration::from_secs(5)).await; + } } #[cfg(test)] diff --git a/src/lib.rs b/src/lib.rs index 4074020..92a0b00 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -26,6 +26,7 @@ pub mod srtt; pub mod stats; pub mod system_dns; pub mod tls; +pub mod wire; pub type Error = Box; pub type Result = std::result::Result; diff --git a/src/main.rs b/src/main.rs index 7592186..0211a59 100644 --- a/src/main.rs +++ b/src/main.rs @@ -297,6 +297,7 @@ async fn main() -> numa::Result<()> { upstream_port: config.upstream.port, lan_ip: Mutex::new(numa::lan::detect_lan_ip().unwrap_or(std::net::Ipv4Addr::LOCALHOST)), timeout: Duration::from_millis(config.upstream.timeout_ms), + hedge_delay: Duration::from_millis(config.upstream.hedge_ms), proxy_tld_suffix: if config.proxy.tld.is_empty() { String::new() } else { @@ -511,6 +512,14 @@ async fn main() -> numa::Result<()> { }); } + // Spawn DoH connection keepalive — prevents idle TLS teardown + { + let keepalive_ctx = Arc::clone(&ctx); + tokio::spawn(async move { + doh_keepalive_loop(keepalive_ctx).await; + }); + } + // Spawn HTTP API server let api_ctx = Arc::clone(&ctx); let api_addr: SocketAddr = format!("{}:{}", config.server.api_bind_addr, api_port).parse()?; @@ -590,7 +599,7 @@ async fn main() -> numa::Result<()> { #[allow(clippy::infinite_loop)] loop { let mut buffer = BytePacketBuffer::new(); - let (_, src_addr) = match ctx.socket.recv_from(&mut buffer.buf).await { + let (len, src_addr) = match ctx.socket.recv_from(&mut buffer.buf).await { Ok(r) => r, Err(e) if e.kind() == std::io::ErrorKind::ConnectionReset => { // Windows delivers ICMP port-unreachable as ConnectionReset on UDP sockets @@ -598,10 +607,11 @@ 
async fn main() -> numa::Result<()> { } Err(e) => return Err(e.into()), }; + let raw_len = len; let ctx = Arc::clone(&ctx); tokio::spawn(async move { - if let Err(e) = handle_query(buffer, src_addr, &ctx).await { + if let Err(e) = handle_query(buffer, raw_len, src_addr, &ctx).await { error!("{} | HANDLER ERROR | {}", src_addr, e); } }); @@ -777,6 +787,18 @@ async fn warm_domain(ctx: &ServerCtx, domain: &str) { } } +async fn doh_keepalive_loop(ctx: Arc) { + let mut interval = tokio::time::interval(Duration::from_secs(25)); + interval.tick().await; // skip first immediate tick + loop { + interval.tick().await; + let pool = ctx.upstream_pool.lock().unwrap().clone(); + if let Some(upstream) = pool.preferred() { + numa::forward::keepalive_doh(upstream).await; + } + } +} + async fn cache_warm_loop(ctx: Arc, domains: Vec) { tokio::time::sleep(Duration::from_secs(2)).await; diff --git a/src/recursive.rs b/src/recursive.rs index 24d0367..2609f7f 100644 --- a/src/recursive.rs +++ b/src/recursive.rs @@ -202,23 +202,22 @@ pub(crate) fn resolve_iterative<'a>( let mut ns_idx = 0; for _ in 0..MAX_REFERRAL_DEPTH { - let ns_addr = match ns_addrs.get(ns_idx) { - Some(addr) => *addr, - None => return Err("no nameserver available".into()), - }; + if ns_idx >= ns_addrs.len() { + return Err("no nameserver available".into()); + } let (q_name, q_type) = minimize_query(qname, qtype, ¤t_zone); debug!( - "recursive: querying {} for {:?} {} (zone: {}, depth {})", - ns_addr, q_type, q_name, current_zone, referral_depth + "recursive: querying {} (+ hedge) for {:?} {} (zone: {}, depth {})", + ns_addrs[ns_idx], q_type, q_name, current_zone, referral_depth ); - let response = match send_query(q_name, q_type, ns_addr, srtt).await { + let response = match send_query_hedged(q_name, q_type, &ns_addrs[ns_idx..], srtt).await { Ok(r) => r, Err(e) => { - debug!("recursive: NS {} failed: {}", ns_addr, e); - ns_idx += 1; + debug!("recursive: NS query failed: {}", e); + ns_idx += 2; // both tried, skip past 
them continue; } }; @@ -228,6 +227,9 @@ pub(crate) fn resolve_iterative<'a>( { if let Some(zone) = referral_zone(&response) { current_zone = zone; + let mut cache_w = cache.write().unwrap(); + cache_ns_delegation(&mut cache_w, ¤t_zone, &response); + drop(cache_w); } let mut all_ns = extract_ns_from_records(&response.answers); if all_ns.is_empty() { @@ -296,6 +298,7 @@ pub(crate) fn resolve_iterative<'a>( { let mut cache_w = cache.write().unwrap(); + cache_ns_delegation(&mut cache_w, ¤t_zone, &response); cache_ds_from_authority(&mut cache_w, &response); } let mut new_ns_addrs = resolve_ns_addrs_from_glue(&response, &ns_names, cache); @@ -560,6 +563,23 @@ fn cache_ds_from_authority(cache: &mut DnsCache, response: &DnsPacket) { } } +/// Cache NS delegation records from a referral response so that +/// `find_closest_ns` can skip re-querying TLD servers on subsequent lookups. +fn cache_ns_delegation(cache: &mut DnsCache, zone: &str, response: &DnsPacket) { + let ns_records: Vec<_> = response + .authorities + .iter() + .filter(|r| matches!(r, DnsRecord::NS { .. })) + .cloned() + .collect(); + if ns_records.is_empty() { + return; + } + let mut pkt = make_glue_packet(); + pkt.answers = ns_records; + cache.insert(zone, QueryType::NS, &pkt); +} + fn make_glue_packet() -> DnsPacket { let mut pkt = DnsPacket::new(); pkt.header.response = true; @@ -587,6 +607,91 @@ async fn tcp_with_srtt( } } +/// Smart NS query: fire to two servers simultaneously when SRTT is unknown +/// (cold queries), or to the best server with SRTT-based hedge when known. 
+async fn send_query_hedged( + qname: &str, + qtype: QueryType, + servers: &[SocketAddr], + srtt: &RwLock, +) -> crate::Result { + if servers.is_empty() { + return Err("no nameserver available".into()); + } + if servers.len() == 1 { + return send_query(qname, qtype, servers[0], srtt).await; + } + + let primary = servers[0]; + let secondary = servers[1]; + let primary_known = srtt.read().unwrap().is_known(primary.ip()); + + if !primary_known { + // Cold: fire both simultaneously, first response wins + debug!( + "recursive: parallel query to {} and {} for {:?} {}", + primary, secondary, qtype, qname + ); + let fut_a = send_query(qname, qtype, primary, srtt); + let fut_b = send_query(qname, qtype, secondary, srtt); + tokio::pin!(fut_a); + tokio::pin!(fut_b); + + // First Ok wins. If one errors, wait for the other. + let mut a_done = false; + let mut b_done = false; + let mut a_err: Option = None; + let mut b_err: Option = None; + + loop { + tokio::select! { + r = &mut fut_a, if !a_done => { + match r { + Ok(resp) => return Ok(resp), + Err(e) => { a_done = true; a_err = Some(e); } + } + } + r = &mut fut_b, if !b_done => { + match r { + Ok(resp) => return Ok(resp), + Err(e) => { b_done = true; b_err = Some(e); } + } + } + } + match (a_err.take(), b_err.take()) { + (Some(e), Some(_)) => return Err(e), + (a, b) => { a_err = a; b_err = b; } + } + } + } else { + // Warm: send to best, hedge after SRTT × 3 if slow + let hedge_ms = srtt.read().unwrap().get(primary.ip()) * 3; + let hedge_delay = Duration::from_millis(hedge_ms.max(50)); + + let fut_a = send_query(qname, qtype, primary, srtt); + tokio::pin!(fut_a); + let delay = tokio::time::sleep(hedge_delay); + tokio::pin!(delay); + + tokio::select! { + r = &mut fut_a => return r, + _ = &mut delay => {} + } + + debug!( + "recursive: hedging {} -> {} after {}ms for {:?} {}", + primary, secondary, hedge_ms, qtype, qname + ); + let fut_b = send_query(qname, qtype, secondary, srtt); + tokio::pin!(fut_b); + + tokio::select! 
{ + r = fut_a => r, + r = fut_b => r, + } + } +} + async fn send_query( qname: &str, qtype: QueryType, diff --git a/src/srtt.rs b/src/srtt.rs index f763a37..fe4df1e 100644 --- a/src/srtt.rs +++ b/src/srtt.rs @@ -45,6 +45,11 @@ impl SrttCache { } } + /// Whether we have observed RTT data for this IP. + pub fn is_known(&self, ip: IpAddr) -> bool { + self.entries.contains_key(&ip) + } + /// Apply time-based decay: each DECAY_AFTER_SECS period halves distance to INITIAL. fn decayed_srtt(entry: &SrttEntry) -> u64 { Self::decay_for_age(entry.srtt_ms, entry.updated_at.elapsed().as_secs()) diff --git a/src/wire.rs b/src/wire.rs new file mode 100644 index 0000000..6b68c3a --- /dev/null +++ b/src/wire.rs @@ -0,0 +1,1347 @@ +//! Wire-level DNS utilities: question extraction, TTL offset scanning, and patching. +//! +//! These operate directly on raw DNS wire bytes without full packet parsing, +//! enabling zero-copy forwarding and wire-level caching. + +use crate::question::QueryType; +use crate::Result; + +/// Metadata extracted from scanning a DNS response's wire bytes. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct WireMeta { + /// Byte offsets of every TTL field in answer + authority + additional sections. + /// Each offset points to the first byte of a 4-byte big-endian TTL. + /// EDNS OPT pseudo-records are excluded (their "TTL" is flags, not a real TTL). + pub ttl_offsets: Vec, + /// How many of the offsets belong to the answer section (the first `answer_count` + /// entries). Used to extract min-TTL from answers only. + pub answer_count: usize, +} + +/// Extract the first question's (domain, query type) from raw DNS wire bytes. +/// +/// Reads only the 12-byte header + first question section. Returns the lowercased +/// domain name and query type without allocating a full `DnsPacket`. 
+pub fn extract_question(wire: &[u8]) -> Result<(String, QueryType)> { + if wire.len() < 12 { + return Err("wire too short for DNS header".into()); + } + let qdcount = u16::from_be_bytes([wire[4], wire[5]]); + if qdcount == 0 { + return Err("no questions in wire".into()); + } + + let mut pos = 12; + let mut domain = String::with_capacity(64); + read_wire_qname(wire, &mut pos, &mut domain)?; + + if pos + 4 > wire.len() { + return Err("wire truncated in question section".into()); + } + let qtype = u16::from_be_bytes([wire[pos], wire[pos + 1]]); + // skip QTYPE(2) + QCLASS(2) + + Ok((domain, QueryType::from_num(qtype))) +} + +/// Scan a DNS response's wire bytes and return metadata about TTL field locations. +/// +/// Walks the header, skips the question section, then for each resource record in +/// answer, authority, and additional sections, records the byte offset of the TTL +/// field. EDNS OPT records (type 41 with root name) are excluded. +pub fn scan_ttl_offsets(wire: &[u8]) -> Result { + if wire.len() < 12 { + return Err("wire too short for DNS header".into()); + } + + let qdcount = u16::from_be_bytes([wire[4], wire[5]]) as usize; + let ancount = u16::from_be_bytes([wire[6], wire[7]]) as usize; + let nscount = u16::from_be_bytes([wire[8], wire[9]]) as usize; + let arcount = u16::from_be_bytes([wire[10], wire[11]]) as usize; + + let mut pos = 12; + + // Skip question section + for _ in 0..qdcount { + skip_wire_name(wire, &mut pos)?; + if pos + 4 > wire.len() { + return Err("wire truncated in question section".into()); + } + pos += 4; // QTYPE(2) + QCLASS(2) + } + + let mut ttl_offsets = Vec::new(); + + // Process answer + authority + additional sections + let section_counts = [ancount, nscount, arcount]; + let mut answer_offset_count = 0; + + for (section_idx, &count) in section_counts.iter().enumerate() { + for _ in 0..count { + // Check if this is an OPT record: root name (0x00) + type 41 + let is_opt = pos < wire.len() + && wire[pos] == 0x00 + && pos + 3 <= 
wire.len() + && u16::from_be_bytes([wire[pos + 1], wire[pos + 2]]) == 41; + + // Skip name + skip_wire_name(wire, &mut pos)?; + + if pos + 10 > wire.len() { + return Err("wire truncated in resource record".into()); + } + + // TYPE(2) + CLASS(2) = 4 bytes before TTL + let ttl_offset = pos + 4; + + if !is_opt { + ttl_offsets.push(ttl_offset); + if section_idx == 0 { + answer_offset_count += 1; + } + } + + // Skip TYPE(2) + CLASS(2) + TTL(4) + RDLENGTH(2) = 10 bytes + let rdlength = u16::from_be_bytes([wire[pos + 8], wire[pos + 9]]) as usize; + pos += 10 + rdlength; + + if pos > wire.len() { + return Err("wire truncated in resource record RDATA".into()); + } + } + } + + Ok(WireMeta { + ttl_offsets, + answer_count: answer_offset_count, + }) +} + +/// Extract the minimum TTL from the answer section offsets of a wire response. +pub fn min_ttl_from_wire(wire: &[u8], meta: &WireMeta) -> Option { + meta.ttl_offsets + .iter() + .take(meta.answer_count) + .filter_map(|&off| { + if off + 4 <= wire.len() { + Some(u32::from_be_bytes([ + wire[off], + wire[off + 1], + wire[off + 2], + wire[off + 3], + ])) + } else { + None + } + }) + .min() +} + +/// Patch the transaction ID (bytes 0..2) in a DNS wire message. +pub fn patch_id(wire: &mut [u8], new_id: u16) { + let bytes = new_id.to_be_bytes(); + wire[0] = bytes[0]; + wire[1] = bytes[1]; +} + +/// Patch all TTL fields at the given offsets to `new_ttl`. +pub fn patch_ttls(wire: &mut [u8], offsets: &[usize], new_ttl: u32) { + let bytes = new_ttl.to_be_bytes(); + for &off in offsets { + wire[off] = bytes[0]; + wire[off + 1] = bytes[1]; + wire[off + 2] = bytes[2]; + wire[off + 3] = bytes[3]; + } +} + +/// Read a DNS name from wire bytes at `pos`, handling compression pointers. +/// Advances `pos` past the name as it appears at the current position +/// (compression pointer targets do NOT advance `pos`). 
+fn read_wire_qname(wire: &[u8], pos: &mut usize, out: &mut String) -> Result<()> { + let mut jumped = false; + let mut read_pos = *pos; + let mut jumps = 0; + let max_jumps = 20; + + loop { + if read_pos >= wire.len() { + return Err("wire truncated reading name".into()); + } + let len = wire[read_pos] as usize; + + // Compression pointer: top 2 bits set + if len & 0xC0 == 0xC0 { + if read_pos + 1 >= wire.len() { + return Err("wire truncated in compression pointer".into()); + } + if !jumped { + *pos = read_pos + 2; // advance past the pointer + } + let offset = ((len & 0x3F) << 8) | wire[read_pos + 1] as usize; + read_pos = offset; + jumped = true; + jumps += 1; + if jumps > max_jumps { + return Err("too many compression jumps".into()); + } + continue; + } + + if len == 0 { + if !jumped { + *pos = read_pos + 1; + } + break; + } + + if read_pos + 1 + len > wire.len() { + return Err("wire truncated in name label".into()); + } + + if !out.is_empty() { + out.push('.'); + } + for &b in &wire[read_pos + 1..read_pos + 1 + len] { + out.push(b.to_ascii_lowercase() as char); + } + read_pos += 1 + len; + } + + Ok(()) +} + +/// Skip a DNS name in wire bytes, advancing `pos` past it. +fn skip_wire_name(wire: &[u8], pos: &mut usize) -> Result<()> { + loop { + if *pos >= wire.len() { + return Err("wire truncated skipping name".into()); + } + let len = wire[*pos] as usize; + + if len & 0xC0 == 0xC0 { + *pos += 2; // compression pointer is 2 bytes + return Ok(()); + } + if len == 0 { + *pos += 1; + return Ok(()); + } + *pos += 1 + len; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::buffer::BytePacketBuffer; + use crate::cache::{DnsCache, DnssecStatus}; + use crate::header::ResultCode; + use crate::packet::{DnsPacket, EdnsOpt}; + use crate::question::DnsQuestion; + use crate::record::DnsRecord; + + // ── Helpers ────────────────────────────────────────────────────── + + /// Serialize a DnsPacket to wire bytes. 
+ fn to_wire(pkt: &DnsPacket) -> Vec { + let mut buf = BytePacketBuffer::new(); + pkt.write(&mut buf).unwrap(); + buf.filled().to_vec() + } + + /// Build a minimal response with given answers. + fn response(id: u16, domain: &str, answers: Vec) -> DnsPacket { + let mut pkt = DnsPacket::new(); + pkt.header.id = id; + pkt.header.response = true; + pkt.header.recursion_desired = true; + pkt.header.recursion_available = true; + pkt.header.rescode = ResultCode::NOERROR; + pkt.questions + .push(DnsQuestion::new(domain.to_string(), QueryType::A)); + pkt.answers = answers; + pkt + } + + fn a_record(domain: &str, ip: &str, ttl: u32) -> DnsRecord { + DnsRecord::A { + domain: domain.into(), + addr: ip.parse().unwrap(), + ttl, + } + } + + fn aaaa_record(domain: &str, ip: &str, ttl: u32) -> DnsRecord { + DnsRecord::AAAA { + domain: domain.into(), + addr: ip.parse().unwrap(), + ttl, + } + } + + fn cname_record(domain: &str, host: &str, ttl: u32) -> DnsRecord { + DnsRecord::CNAME { + domain: domain.into(), + host: host.into(), + ttl, + } + } + + fn ns_record(domain: &str, host: &str, ttl: u32) -> DnsRecord { + DnsRecord::NS { + domain: domain.into(), + host: host.into(), + ttl, + } + } + + fn mx_record(domain: &str, host: &str, priority: u16, ttl: u32) -> DnsRecord { + DnsRecord::MX { + domain: domain.into(), + priority, + host: host.into(), + ttl, + } + } + + // ── A. 
TTL offset extraction ──────────────────────────────────── + + #[test] + fn scan_single_a_record() { + let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + assert_eq!(meta.ttl_offsets.len(), 1); + assert_eq!(meta.answer_count, 1); + + let off = meta.ttl_offsets[0]; + let ttl = u32::from_be_bytes([wire[off], wire[off + 1], wire[off + 2], wire[off + 3]]); + assert_eq!(ttl, 300); + } + + #[test] + fn scan_multiple_a_records() { + let pkt = response( + 0x1234, + "example.com", + vec![ + a_record("example.com", "1.2.3.4", 300), + a_record("example.com", "5.6.7.8", 600), + a_record("example.com", "9.10.11.12", 120), + ], + ); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + assert_eq!(meta.ttl_offsets.len(), 3); + assert_eq!(meta.answer_count, 3); + + let ttls: Vec = meta + .ttl_offsets + .iter() + .map(|&off| u32::from_be_bytes([wire[off], wire[off + 1], wire[off + 2], wire[off + 3]])) + .collect(); + assert_eq!(ttls, vec![300, 600, 120]); + } + + #[test] + fn scan_mixed_sections() { + let mut pkt = + response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + pkt.authorities + .push(ns_record("example.com", "ns1.example.com", 3600)); + pkt.authorities + .push(ns_record("example.com", "ns2.example.com", 3600)); + pkt.resources + .push(a_record("ns1.example.com", "10.0.0.1", 1800)); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + assert_eq!(meta.ttl_offsets.len(), 4); // 1 answer + 2 authority + 1 additional + assert_eq!(meta.answer_count, 1); + } + + #[test] + fn scan_cname_chain() { + let pkt = response( + 0x1234, + "www.example.com", + vec![ + cname_record("www.example.com", "example.com", 300), + a_record("example.com", "1.2.3.4", 600), + ], + ); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + assert_eq!(meta.ttl_offsets.len(), 2); + 
assert_eq!(meta.answer_count, 2); + + let ttls: Vec = meta + .ttl_offsets + .iter() + .map(|&off| u32::from_be_bytes([wire[off], wire[off + 1], wire[off + 2], wire[off + 3]])) + .collect(); + assert_eq!(ttls, vec![300, 600]); + } + + #[test] + fn scan_compressed_names() { + // Build a packet with name compression (the serializer uses compression + // for repeated domain names). Two A records for the same domain will + // have the second name compressed as a pointer. + let pkt = response( + 0x1234, + "example.com", + vec![ + a_record("example.com", "1.2.3.4", 300), + a_record("example.com", "5.6.7.8", 600), + ], + ); + let wire = to_wire(&pkt); + + // Verify compression is actually present (second name should be a pointer) + // The first answer's name is at some offset, and the second should use 0xC0xx + let meta = scan_ttl_offsets(&wire).unwrap(); + assert_eq!(meta.ttl_offsets.len(), 2); + + let ttls: Vec = meta + .ttl_offsets + .iter() + .map(|&off| u32::from_be_bytes([wire[off], wire[off + 1], wire[off + 2], wire[off + 3]])) + .collect(); + assert_eq!(ttls, vec![300, 600]); + } + + #[test] + fn scan_edns_opt_excluded() { + let mut pkt = + response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + pkt.edns = Some(EdnsOpt { + udp_payload_size: 1232, + extended_rcode: 0, + version: 0, + do_bit: false, + options: vec![], + }); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + // Only the A record's TTL, not the OPT pseudo-record's "TTL" + assert_eq!(meta.ttl_offsets.len(), 1); + assert_eq!(meta.answer_count, 1); + } + + #[test] + fn scan_rrsig_only_wire_ttl() { + let mut pkt = + response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + pkt.answers.push(DnsRecord::RRSIG { + domain: "example.com".into(), + type_covered: 1, // A + algorithm: 13, + labels: 2, + original_ttl: 9999, // must NOT appear in offsets + expiration: 1700000000, + inception: 1690000000, + key_tag: 12345, + signer_name: 
"example.com".into(), + signature: vec![0x01, 0x02, 0x03, 0x04], + ttl: 300, + }); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + // 2 TTL offsets: A record + RRSIG wire TTL + assert_eq!(meta.ttl_offsets.len(), 2); + assert_eq!(meta.answer_count, 2); + + // Both wire TTLs should be 300, not 9999 + for &off in &meta.ttl_offsets { + let ttl = + u32::from_be_bytes([wire[off], wire[off + 1], wire[off + 2], wire[off + 3]]); + assert_eq!(ttl, 300); + } + + // Verify that 9999 (original_ttl) exists somewhere in the wire but is NOT in offsets + let original_ttl_bytes = 9999u32.to_be_bytes(); + let found_at = wire + .windows(4) + .position(|w| w == original_ttl_bytes) + .expect("original_ttl should be in wire"); + assert!( + !meta.ttl_offsets.contains(&found_at), + "original_ttl offset must not be in ttl_offsets" + ); + } + + #[test] + fn scan_nsec_variable_rdata() { + let mut pkt = + response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + pkt.authorities.push(DnsRecord::NSEC { + domain: "example.com".into(), + next_domain: "z.example.com".into(), + type_bitmap: vec![0x00, 0x06, 0x40, 0x01, 0x00, 0x00, 0x00, 0x03], + ttl: 1800, + }); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + assert_eq!(meta.ttl_offsets.len(), 2); // A + NSEC + assert_eq!(meta.answer_count, 1); + + let nsec_ttl_off = meta.ttl_offsets[1]; + let ttl = u32::from_be_bytes([ + wire[nsec_ttl_off], + wire[nsec_ttl_off + 1], + wire[nsec_ttl_off + 2], + wire[nsec_ttl_off + 3], + ]); + assert_eq!(ttl, 1800); + } + + #[test] + fn scan_empty_response() { + let pkt = response(0x1234, "nxdomain.example.com", vec![]); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + assert!(meta.ttl_offsets.is_empty()); + assert_eq!(meta.answer_count, 0); + } + + #[test] + fn scan_unknown_record_type() { + // Manually build a response with an unknown type (99) using raw wire bytes + let mut pkt = response(0x1234, 
"example.com", vec![]); + pkt.answers.push(DnsRecord::UNKNOWN { + domain: "example.com".into(), + qtype: 99, + data: vec![0xDE, 0xAD, 0xBE, 0xEF], + ttl: 500, + }); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + assert_eq!(meta.ttl_offsets.len(), 1); + let off = meta.ttl_offsets[0]; + let ttl = u32::from_be_bytes([wire[off], wire[off + 1], wire[off + 2], wire[off + 3]]); + assert_eq!(ttl, 500); + } + + #[test] + fn scan_truncated_wire_returns_error() { + let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let wire = to_wire(&pkt); + // Truncate mid-record + let truncated = &wire[..wire.len() - 2]; + assert!(scan_ttl_offsets(truncated).is_err()); + } + + #[test] + fn scan_too_short_for_header() { + assert!(scan_ttl_offsets(&[0u8; 5]).is_err()); + } + + #[test] + fn scan_query_packet_no_offsets() { + let pkt = DnsPacket::query(0x1234, "example.com", QueryType::A); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + assert!(meta.ttl_offsets.is_empty()); + } + + // ── B. 
TTL patching ───────────────────────────────────────────── + + #[test] + fn patch_ttl_single() { + let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let mut wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + patch_ttls(&mut wire, &meta.ttl_offsets, 120); + + let off = meta.ttl_offsets[0]; + assert_eq!( + u32::from_be_bytes([wire[off], wire[off + 1], wire[off + 2], wire[off + 3]]), + 120 + ); + } + + #[test] + fn patch_ttl_multiple() { + let pkt = response( + 0x1234, + "example.com", + vec![ + a_record("example.com", "1.2.3.4", 300), + a_record("example.com", "5.6.7.8", 600), + a_record("example.com", "9.10.11.12", 900), + ], + ); + let mut wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + patch_ttls(&mut wire, &meta.ttl_offsets, 42); + + for &off in &meta.ttl_offsets { + assert_eq!( + u32::from_be_bytes([wire[off], wire[off + 1], wire[off + 2], wire[off + 3]]), + 42 + ); + } + } + + #[test] + fn patch_ttl_preserves_other_bytes() { + let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let original = to_wire(&pkt); + let mut patched = original.clone(); + let meta = scan_ttl_offsets(&patched).unwrap(); + + patch_ttls(&mut patched, &meta.ttl_offsets, 120); + + // Every byte outside TTL offsets should be identical + for (i, (&orig, &patc)) in original.iter().zip(patched.iter()).enumerate() { + let in_ttl = meta + .ttl_offsets + .iter() + .any(|&off| i >= off && i < off + 4); + if !in_ttl { + assert_eq!( + orig, patc, + "byte {} changed (outside TTL): orig={:#04x}, patched={:#04x}", + i, orig, patc + ); + } + } + } + + #[test] + fn patch_ttl_zero() { + let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let mut wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + patch_ttls(&mut wire, &meta.ttl_offsets, 0); + + let off = meta.ttl_offsets[0]; + assert_eq!(&wire[off..off + 4], &[0, 0, 0, 
0]); + } + + #[test] + fn patch_ttl_max_u32() { + let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let mut wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + patch_ttls(&mut wire, &meta.ttl_offsets, u32::MAX); + + let off = meta.ttl_offsets[0]; + assert_eq!(&wire[off..off + 4], &[0xFF, 0xFF, 0xFF, 0xFF]); + } + + #[test] + fn patch_ttl_edns_untouched() { + let mut pkt = + response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + pkt.edns = Some(EdnsOpt { + udp_payload_size: 1232, + extended_rcode: 0, + version: 0, + do_bit: true, + options: vec![], + }); + let original = to_wire(&pkt); + let mut patched = original.clone(); + let meta = scan_ttl_offsets(&patched).unwrap(); + + patch_ttls(&mut patched, &meta.ttl_offsets, 42); + + // Only the A record's TTL bytes should differ; everything else + // (including the OPT "TTL" containing the DO bit) must be unchanged. + for (i, (&orig, &patc)) in original.iter().zip(patched.iter()).enumerate() { + let in_ttl = meta + .ttl_offsets + .iter() + .any(|&off| i >= off && i < off + 4); + if !in_ttl { + assert_eq!( + orig, patc, + "byte {} changed (outside TTL): orig={:#04x}, patched={:#04x}", + i, orig, patc + ); + } + } + } + + // ── C. 
ID patching ────────────────────────────────────────────── + + #[test] + fn patch_id_basic() { + let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let mut wire = to_wire(&pkt); + + patch_id(&mut wire, 0xABCD); + assert_eq!(&wire[0..2], &[0xAB, 0xCD]); + } + + #[test] + fn patch_id_preserves_flags() { + let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let original = to_wire(&pkt); + let mut patched = original.clone(); + + patch_id(&mut patched, 0x9999); + + // Bytes 2..12 (flags + counts) unchanged + assert_eq!(&original[2..12], &patched[2..12]); + } + + #[test] + fn patch_id_zero() { + let pkt = response(0xFFFF, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let mut wire = to_wire(&pkt); + + patch_id(&mut wire, 0x0000); + assert_eq!(&wire[0..2], &[0x00, 0x00]); + } + + // ── D. extract_question ───────────────────────────────────────── + + #[test] + fn extract_question_basic() { + let pkt = DnsPacket::query(0x1234, "Example.COM", QueryType::A); + let wire = to_wire(&pkt); + let (domain, qtype) = extract_question(&wire).unwrap(); + + assert_eq!(domain, "example.com"); // lowercased + assert_eq!(qtype, QueryType::A); + } + + #[test] + fn extract_question_aaaa() { + let pkt = DnsPacket::query(0x1234, "rust-lang.org", QueryType::AAAA); + let wire = to_wire(&pkt); + let (domain, qtype) = extract_question(&wire).unwrap(); + + assert_eq!(domain, "rust-lang.org"); + assert_eq!(qtype, QueryType::AAAA); + } + + #[test] + fn extract_question_too_short() { + assert!(extract_question(&[0u8; 5]).is_err()); + } + + #[test] + fn extract_question_no_questions() { + let mut wire = to_wire(&DnsPacket::query(0x1234, "example.com", QueryType::A)); + // Zero out QDCOUNT (bytes 4-5) + wire[4] = 0; + wire[5] = 0; + assert!(extract_question(&wire).is_err()); + } + + // ── E. 
min_ttl_from_wire ──────────────────────────────────────── + + #[test] + fn min_ttl_answers_only() { + let mut pkt = response( + 0x1234, + "example.com", + vec![ + a_record("example.com", "1.2.3.4", 300), + a_record("example.com", "5.6.7.8", 60), + ], + ); + pkt.authorities + .push(ns_record("example.com", "ns1.example.com", 10)); // lower but in authority, not answer + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + assert_eq!(min_ttl_from_wire(&wire, &meta), Some(60)); // from answers only + } + + #[test] + fn min_ttl_empty_answers() { + let pkt = response(0x1234, "example.com", vec![]); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + assert_eq!(min_ttl_from_wire(&wire, &meta), None); + } + + // ── F. Round-trip fidelity ────────────────────────────────────── + // + // These verify that wire bytes → scan → patch → parse produces the + // same semantic content as the original packet. They test the full + // integration path that the wire-level cache will use. 
+ + #[test] + fn round_trip_simple_a() { + let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + let mut patched = wire.clone(); + patch_id(&mut patched, 0xABCD); + patch_ttls(&mut patched, &meta.ttl_offsets, 120); + + // Parse the patched wire + let mut buf = BytePacketBuffer::from_bytes(&patched); + let parsed = DnsPacket::from_buffer(&mut buf).unwrap(); + + assert_eq!(parsed.header.id, 0xABCD); + assert_eq!(parsed.answers.len(), 1); + match &parsed.answers[0] { + DnsRecord::A { domain, addr, ttl } => { + assert_eq!(domain, "example.com"); + assert_eq!(*addr, "1.2.3.4".parse::().unwrap()); + assert_eq!(*ttl, 120); + } + other => panic!("expected A record, got {:?}", other), + } + } + + #[test] + fn round_trip_edns_survives() { + let mut pkt = + response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + pkt.edns = Some(EdnsOpt { + udp_payload_size: 1232, + extended_rcode: 0, + version: 0, + do_bit: true, + options: vec![], + }); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + let mut patched = wire.clone(); + patch_ttls(&mut patched, &meta.ttl_offsets, 42); + + let mut buf = BytePacketBuffer::from_bytes(&patched); + let parsed = DnsPacket::from_buffer(&mut buf).unwrap(); + + let edns = parsed.edns.as_ref().expect("EDNS should survive"); + assert_eq!(edns.udp_payload_size, 1232); + assert!(edns.do_bit); + } + + #[test] + fn round_trip_dnssec_full() { + let mut pkt = response( + 0x1234, + "example.com", + vec![ + a_record("example.com", "1.2.3.4", 300), + DnsRecord::RRSIG { + domain: "example.com".into(), + type_covered: 1, + algorithm: 13, + labels: 2, + original_ttl: 300, + expiration: 1700000000, + inception: 1690000000, + key_tag: 12345, + signer_name: "example.com".into(), + signature: vec![1, 2, 3, 4, 5, 6, 7, 8], + ttl: 300, + }, + ], + ); + pkt.authorities.push(DnsRecord::NSEC { + domain: 
"example.com".into(), + next_domain: "z.example.com".into(), + type_bitmap: vec![0x00, 0x06, 0x40, 0x01, 0x00, 0x00, 0x00, 0x03], + ttl: 300, + }); + pkt.resources.push(DnsRecord::DNSKEY { + domain: "example.com".into(), + flags: 257, + protocol: 3, + algorithm: 13, + public_key: vec![10, 20, 30, 40], + ttl: 3600, + }); + pkt.edns = Some(EdnsOpt { + udp_payload_size: 1232, + extended_rcode: 0, + version: 0, + do_bit: true, + options: vec![], + }); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + // 4 TTL offsets: A + RRSIG (answers) + NSEC (authority) + DNSKEY (additional) + // OPT excluded + assert_eq!(meta.ttl_offsets.len(), 4); + assert_eq!(meta.answer_count, 2); + + let mut patched = wire.clone(); + patch_ttls(&mut patched, &meta.ttl_offsets, 42); + + let mut buf = BytePacketBuffer::from_bytes(&patched); + let parsed = DnsPacket::from_buffer(&mut buf).unwrap(); + + assert_eq!(parsed.answers.len(), 2); + assert_eq!(parsed.authorities.len(), 1); + assert_eq!(parsed.resources.len(), 1); + assert!(parsed.edns.is_some()); + + // All TTLs should be 42 now + for ans in &parsed.answers { + assert_eq!(ans.ttl(), 42); + } + for auth in &parsed.authorities { + assert_eq!(auth.ttl(), 42); + } + for res in &parsed.resources { + assert_eq!(res.ttl(), 42); + } + + // RRSIG original_ttl must be preserved (it's inside RDATA, not a wire TTL) + match &parsed.answers[1] { + DnsRecord::RRSIG { original_ttl, .. 
} => assert_eq!(*original_ttl, 300), + other => panic!("expected RRSIG, got {:?}", other), + } + } + + #[test] + fn round_trip_nxdomain_soa() { + let mut pkt = DnsPacket::new(); + pkt.header.id = 0x5678; + pkt.header.response = true; + pkt.header.rescode = ResultCode::NXDOMAIN; + pkt.questions + .push(DnsQuestion::new("missing.example.com".into(), QueryType::A)); + // SOA in authority (we don't have a SOA variant, so use NS as proxy for offset testing) + pkt.authorities + .push(ns_record("example.com", "ns1.example.com", 900)); + + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + assert_eq!(meta.ttl_offsets.len(), 1); + assert_eq!(meta.answer_count, 0); // no answers, only authority + + let mut patched = wire.clone(); + patch_id(&mut patched, 0x9999); + patch_ttls(&mut patched, &meta.ttl_offsets, 60); + + let mut buf = BytePacketBuffer::from_bytes(&patched); + let parsed = DnsPacket::from_buffer(&mut buf).unwrap(); + + assert_eq!(parsed.header.id, 0x9999); + assert_eq!(parsed.header.rescode, ResultCode::NXDOMAIN); + assert_eq!(parsed.authorities[0].ttl(), 60); + } + + #[test] + fn round_trip_mx_record() { + let pkt = response( + 0x1234, + "example.com", + vec![mx_record("example.com", "mail.example.com", 10, 3600)], + ); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + let mut patched = wire.clone(); + patch_ttls(&mut patched, &meta.ttl_offsets, 100); + + let mut buf = BytePacketBuffer::from_bytes(&patched); + let parsed = DnsPacket::from_buffer(&mut buf).unwrap(); + + match &parsed.answers[0] { + DnsRecord::MX { + domain, + priority, + host, + ttl, + } => { + assert_eq!(domain, "example.com"); + assert_eq!(*priority, 10); + assert_eq!(host, "mail.example.com"); + assert_eq!(*ttl, 100); + } + other => panic!("expected MX, got {:?}", other), + } + } + + #[test] + fn round_trip_many_records() { + let answers: Vec = (0..20) + .map(|i| a_record("example.com", &format!("10.0.0.{}", i), 300 + i * 10)) + 
.collect(); + let pkt = response(0x1234, "example.com", answers); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + assert_eq!(meta.ttl_offsets.len(), 20); + + let mut patched = wire.clone(); + patch_ttls(&mut patched, &meta.ttl_offsets, 1); + + let mut buf = BytePacketBuffer::from_bytes(&patched); + let parsed = DnsPacket::from_buffer(&mut buf).unwrap(); + + assert_eq!(parsed.answers.len(), 20); + for ans in &parsed.answers { + assert_eq!(ans.ttl(), 1); + } + } + + // ── G. Edge cases ─────────────────────────────────────────────── + + #[test] + fn scan_rejects_empty_wire() { + assert!(scan_ttl_offsets(&[]).is_err()); + } + + #[test] + fn extract_question_rejects_empty_wire() { + assert!(extract_question(&[]).is_err()); + } + + // ── H. Cache behavior tests ───────────────────────────────────── + // + // These test existing DnsCache behavior that must be preserved after + // the wire-level migration. They use the current parsed-packet API + // and serve as a regression suite. 
+ + #[test] + fn cache_insert_lookup_hit() { + let mut cache = DnsCache::new(100, 1, 3600); + let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + cache.insert("example.com", QueryType::A, &pkt); + + let (result, status) = cache + .lookup_with_status("example.com", QueryType::A) + .expect("should hit"); + assert_eq!(result.answers.len(), 1); + assert_eq!(status, DnssecStatus::Indeterminate); + } + + #[test] + fn cache_lookup_adjusts_ttl() { + let mut cache = DnsCache::new(100, 1, 3600); + let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + cache.insert("example.com", QueryType::A, &pkt); + + let (result, _) = cache.lookup_with_status("example.com", QueryType::A).unwrap(); + // TTL should be <= 300 (at most original, reduced by elapsed time) + assert!(result.answers[0].ttl() <= 300); + assert!(result.answers[0].ttl() > 0); + } + + #[test] + fn cache_miss_wrong_domain() { + let mut cache = DnsCache::new(100, 1, 3600); + let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + cache.insert("example.com", QueryType::A, &pkt); + + assert!(cache + .lookup_with_status("other.com", QueryType::A) + .is_none()); + } + + #[test] + fn cache_miss_wrong_qtype() { + let mut cache = DnsCache::new(100, 1, 3600); + let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + cache.insert("example.com", QueryType::A, &pkt); + + assert!(cache + .lookup_with_status("example.com", QueryType::AAAA) + .is_none()); + } + + #[test] + fn cache_overwrite_no_double_count() { + let mut cache = DnsCache::new(100, 1, 3600); + let pkt1 = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let pkt2 = response(0x5678, "example.com", vec![a_record("example.com", "5.6.7.8", 600)]); + + cache.insert("example.com", QueryType::A, &pkt1); + assert_eq!(cache.len(), 1); + + cache.insert("example.com", QueryType::A, &pkt2); + 
assert_eq!(cache.len(), 1); // no double count + + let (result, _) = cache.lookup_with_status("example.com", QueryType::A).unwrap(); + match &result.answers[0] { + DnsRecord::A { addr, .. } => { + assert_eq!(*addr, "5.6.7.8".parse::<std::net::Ipv4Addr>().unwrap()) + } + _ => panic!("expected A record"), + } + } + + #[test] + fn cache_ttl_clamped_min() { + let mut cache = DnsCache::new(100, 60, 3600); + let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 5)]); + cache.insert("example.com", QueryType::A, &pkt); + + let (remaining, total) = cache.ttl_remaining("example.com", QueryType::A).unwrap(); + assert_eq!(total, 60); // clamped up from 5 + assert!(remaining <= 60); + } + + #[test] + fn cache_ttl_clamped_max() { + let mut cache = DnsCache::new(100, 1, 3600); + let pkt = + response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 999999)]); + cache.insert("example.com", QueryType::A, &pkt); + + let (_, total) = cache.ttl_remaining("example.com", QueryType::A).unwrap(); + assert_eq!(total, 3600); // clamped down from 999999 + } + + #[test] + fn cache_len_empty_clear() { + let mut cache = DnsCache::new(100, 1, 3600); + assert!(cache.is_empty()); + assert_eq!(cache.len(), 0); + + let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + cache.insert("example.com", QueryType::A, &pkt); + assert!(!cache.is_empty()); + assert_eq!(cache.len(), 1); + + cache.clear(); + assert!(cache.is_empty()); + assert_eq!(cache.len(), 0); + assert!(cache.lookup("example.com", QueryType::A).is_none()); + } + + #[test] + fn cache_remove_domain() { + let mut cache = DnsCache::new(100, 1, 3600); + let pkt_a = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let pkt_aaaa = response( + 0x5678, + "example.com", + vec![aaaa_record("example.com", "::1", 300)], + ); + cache.insert("example.com", QueryType::A, &pkt_a); + cache.insert("example.com", QueryType::AAAA, &pkt_aaaa); + assert_eq!(cache.len(), 2); 
+ + cache.remove("example.com"); + assert_eq!(cache.len(), 0); + assert!(cache.lookup("example.com", QueryType::A).is_none()); + assert!(cache.lookup("example.com", QueryType::AAAA).is_none()); + } + + #[test] + fn cache_list_entries() { + let mut cache = DnsCache::new(100, 1, 3600); + let pkt_a = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let pkt_b = response(0x5678, "test.org", vec![a_record("test.org", "5.6.7.8", 600)]); + cache.insert("example.com", QueryType::A, &pkt_a); + cache.insert("test.org", QueryType::A, &pkt_b); + + let list = cache.list(); + assert_eq!(list.len(), 2); + let domains: Vec<&str> = list.iter().map(|e| e.domain.as_str()).collect(); + assert!(domains.contains(&"example.com")); + assert!(domains.contains(&"test.org")); + } + + #[test] + fn cache_heap_bytes_grows() { + let mut cache = DnsCache::new(100, 1, 3600); + let empty = cache.heap_bytes(); + + let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + cache.insert("example.com", QueryType::A, &pkt); + assert!(cache.heap_bytes() > empty); + } + + #[test] + fn cache_needs_warm_behavior() { + let mut cache = DnsCache::new(100, 1, 3600); + + // Missing → needs warm + assert!(cache.needs_warm("example.com")); + + // Both A and AAAA cached → does not need warm + let pkt_a = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + let pkt_aaaa = response( + 0x5678, + "example.com", + vec![aaaa_record("example.com", "::1", 300)], + ); + cache.insert("example.com", QueryType::A, &pkt_a); + cache.insert("example.com", QueryType::AAAA, &pkt_aaaa); + assert!(!cache.needs_warm("example.com")); + + // Only A cached → needs warm (AAAA missing) + cache.remove("example.com"); + cache.insert("example.com", QueryType::A, &pkt_a); + assert!(cache.needs_warm("example.com")); + } + + #[test] + fn cache_ttl_remaining_api() { + let mut cache = DnsCache::new(100, 60, 3600); + 
assert!(cache.ttl_remaining("missing.com", QueryType::A).is_none()); + + let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + cache.insert("example.com", QueryType::A, &pkt); + let (remaining, total) = cache.ttl_remaining("example.com", QueryType::A).unwrap(); + assert_eq!(total, 300); + assert!(remaining > 0); + assert!(remaining <= 300); + } + + #[test] + fn cache_dnssec_status_preserved() { + let mut cache = DnsCache::new(100, 1, 3600); + let pkt = response(0x1234, "example.com", vec![a_record("example.com", "1.2.3.4", 300)]); + cache.insert_with_status("example.com", QueryType::A, &pkt, DnssecStatus::Secure); + + let (_, status) = cache + .lookup_with_status("example.com", QueryType::A) + .unwrap(); + assert_eq!(status, DnssecStatus::Secure); + } + + // ── I. Memory footprint baseline ────────────────────────────── + // + // Measures the current parsed-packet cache memory vs what wire-level + // storage would cost for the same entries. This is a baseline — after + // migration, re-run to verify improvement. 
+ + #[test] + fn memory_footprint_baseline() { + let mut cache = DnsCache::new(1000, 1, 3600); + + // Simulate a realistic cache: 50 domains, mix of record types + let domains: Vec<String> = (0..50).map(|i| format!("domain{}.example.com", i)).collect(); + + let mut total_wire_bytes = 0usize; + let mut total_wire_meta_bytes = 0usize; + + for (i, domain) in domains.iter().enumerate() { + // A record + let pkt_a = response( + i as u16, + domain, + vec![ + a_record(domain, &format!("10.0.{}.1", i % 256), 300), + a_record(domain, &format!("10.0.{}.2", i % 256), 300), + ], + ); + cache.insert(domain, QueryType::A, &pkt_a); + + let wire_a = to_wire(&pkt_a); + let meta_a = scan_ttl_offsets(&wire_a).unwrap(); + total_wire_bytes += wire_a.len(); + total_wire_meta_bytes += meta_a.ttl_offsets.len() * std::mem::size_of::<u16>(); + + // AAAA record for half of them + if i % 2 == 0 { + let pkt_aaaa = response( + (i + 1000) as u16, + domain, + vec![aaaa_record(domain, &format!("2001:db8::{:x}", i), 600)], + ); + cache.insert(domain, QueryType::AAAA, &pkt_aaaa); + + let wire_aaaa = to_wire(&pkt_aaaa); + let meta_aaaa = scan_ttl_offsets(&wire_aaaa).unwrap(); + total_wire_bytes += wire_aaaa.len(); + total_wire_meta_bytes += + meta_aaaa.ttl_offsets.len() * std::mem::size_of::<u16>(); + } + } + + // Compare only the variable per-entry data (what actually differs + // between parsed and wire storage). HashMap overhead, domain keys, + // Instant, Duration, DnssecStatus are identical in both approaches. 
+ let mut parsed_data_bytes = 0usize; + // Re-insert and measure just packet.heap_bytes() per entry + { + let mut cache2 = DnsCache::new(1000, 1, 3600); + for (i, domain) in domains.iter().enumerate() { + let pkt_a = response( + i as u16, + domain, + vec![ + a_record(domain, &format!("10.0.{}.1", i % 256), 300), + a_record(domain, &format!("10.0.{}.2", i % 256), 300), + ], + ); + parsed_data_bytes += pkt_a.heap_bytes(); + cache2.insert(domain, QueryType::A, &pkt_a); + + if i % 2 == 0 { + let pkt_aaaa = response( + (i + 1000) as u16, + domain, + vec![aaaa_record(domain, &format!("2001:db8::{:x}", i), 600)], + ); + parsed_data_bytes += pkt_aaaa.heap_bytes(); + cache2.insert(domain, QueryType::AAAA, &pkt_aaaa); + } + } + } + + let wire_total = total_wire_bytes + total_wire_meta_bytes; + let entry_count = cache.len(); + + // Also measure the struct size difference per entry + let parsed_struct = std::mem::size_of::<DnsPacket>(); + let wire_struct = std::mem::size_of::<Vec<u8>>() + std::mem::size_of::<Vec<u16>>() + std::mem::size_of::<usize>(); // wire + offsets + answer_count + + println!(); + println!("=== Cache Memory Footprint Baseline ({} entries) ===", entry_count); + println!(); + println!("Variable data (heap, per-entry payload):"); + println!(" Parsed (packet.heap_bytes): {} bytes ({:.1}/entry)", parsed_data_bytes, parsed_data_bytes as f64 / entry_count as f64); + println!(" Wire (bytes + TTL offsets): {} bytes ({:.1}/entry)", wire_total, wire_total as f64 / entry_count as f64); + println!(" Ratio: {:.1}x smaller with wire", parsed_data_bytes as f64 / wire_total as f64); + println!(); + println!("Struct overhead (stack, per entry):"); + println!(" DnsPacket: {} bytes", parsed_struct); + println!(" Wire (Vec+Vec+usize): {} bytes", wire_struct); + println!(); + println!("Total per-entry (struct + avg heap):"); + let parsed_total_per = parsed_struct as f64 + parsed_data_bytes as f64 / entry_count as f64; + let wire_total_per = wire_struct as f64 + wire_total as f64 / entry_count as f64; + 
println!(" Parsed: {:.0} bytes", parsed_total_per); + println!(" Wire: {:.0} bytes", wire_total_per); + println!(" Ratio: {:.1}x smaller with wire", parsed_total_per / wire_total_per); + println!(); + + // Assertions + assert!( + wire_total < parsed_data_bytes, + "wire data ({wire_total}) should be smaller than parsed data ({parsed_data_bytes})" + ); + } + + #[test] + fn cache_max_entries_cap() { + let mut cache = DnsCache::new(2, 1, 3600); + for i in 0..3 { + let domain = format!("test{}.com", i); + let pkt = response( + i as u16, + &domain, + vec![a_record(&domain, &format!("1.2.3.{}", i), 3600)], + ); + cache.insert(&domain, QueryType::A, &pkt); + } + // Should not exceed max (third insert is silently dropped or evicts) + assert!(cache.len() <= 2); + } +} diff --git a/tests/integration.sh b/tests/integration.sh index 92da878..c70ec59 100755 --- a/tests/integration.sh +++ b/tests/integration.sh @@ -53,7 +53,17 @@ CONF echo "Starting Numa on :$PORT ($SUITE_NAME)..." RUST_LOG=info "$BINARY" "$CONFIG" > "$LOG" 2>&1 & NUMA_PID=$! - sleep 4 + sleep 2 + + # Wait for blocklist to load (if blocking is enabled in this suite) + if echo "$SUITE_CONFIG" | grep -q 'enabled = true'; then + for i in $(seq 1 20); do + LOADED=$(curl -sf http://127.0.0.1:$API_PORT/blocking/stats 2>/dev/null \ + | grep -o '"domains_loaded":[0-9]*' | cut -d: -f2) + if [ "${LOADED:-0}" -gt 0 ]; then break; fi + sleep 1 + done + fi if ! kill -0 "$NUMA_PID" 2>/dev/null; then echo "Failed to start Numa:"