diff --git a/Cargo.lock b/Cargo.lock index c7cd38b..c0f7692 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -82,6 +82,12 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "anyhow" +version = "1.0.102" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" + [[package]] name = "arc-swap" version = "1.9.0" @@ -142,6 +148,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "async-trait" +version = "0.1.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "atomic-waker" version = "1.1.2" @@ -410,6 +427,21 @@ dependencies = [ "itertools", ] +[[package]] +name = "critical-section" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "790eea4361631c5e7d22598ecd5723ff611904e3344ce8720784c93e3d83d40b" + +[[package]] +name = "crossbeam-channel" +version = "0.5.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-deque" version = "0.8.6" @@ -493,6 +525,18 @@ version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +[[package]] +name = "enum-as-inner" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1e6a265c649f3f5979b601d26f1d05ada116434c87741c9493cb56218f76cbc" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "env_filter" version = "1.0.1" @@ -554,6 +598,12 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + [[package]] name = "form_urlencoded" version = "1.2.2" @@ -679,11 +729,24 @@ dependencies = [ "cfg-if", "js-sys", "libc", - "r-efi", + "r-efi 5.3.0", "wasip2", "wasm-bindgen", ] +[[package]] +name = "getrandom" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0de51e6874e94e7bf76d726fc5d13ba782deca734ff60d5bb2fb2607c7406555" +dependencies = [ + "cfg-if", + "libc", + "r-efi 6.0.0", + "wasip2", + "wasip3", +] + [[package]] name = "h2" version = "0.4.13" @@ -714,12 +777,82 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "foldhash", +] + [[package]] name = "hashbrown" version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "hickory-proto" +version = "0.25.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8a6fe56c0038198998a6f217ca4e7ef3a5e51f46163bd6dd60b5c71ca6c6502" +dependencies = [ + "async-trait", + "bytes", + "cfg-if", + "data-encoding", + "enum-as-inner", + "futures-channel", + "futures-io", + "futures-util", + "h2", + "http", + "idna", + "ipnet", + "once_cell", + "rand", + "ring", + "rustls", + "thiserror", + "tinyvec", + "tokio", + "tokio-rustls", + "tracing", + "url", + "webpki-roots 0.26.11", +] + +[[package]] +name = "hickory-resolver" +version = "0.25.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc62a9a99b0bfb44d2ab95a7208ac952d31060efc16241c87eaf36406fecf87a" +dependencies = [ + "cfg-if", + "futures-util", + "hickory-proto", + "ipconfig", + "moka", + "once_cell", + "parking_lot", + "rand", + "resolv-conf", + "rustls", + "smallvec", + "thiserror", + "tokio", + "tokio-rustls", + "tracing", + "webpki-roots 0.26.11", +] + [[package]] name = "http" version = "1.4.0" @@ -802,7 +935,7 @@ dependencies = [ "tokio", "tokio-rustls", "tower-service", - "webpki-roots", + "webpki-roots 1.0.6", ] [[package]] @@ -909,6 +1042,12 @@ dependencies = [ "zerovec", ] +[[package]] +name = "id-arena" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" + [[package]] name = "idna" version = "1.1.0" @@ -937,7 +1076,22 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.16.1", + "serde", + "serde_core", +] + +[[package]] +name = "ipconfig" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4d40460c0ce33d6ce4b0630ad68ff63d6661961c48b6dba35e5a4d81cfb48222" +dependencies = [ + "socket2", + "widestring", + "windows-registry", + "windows-result", + "windows-sys 0.61.2", ] [[package]] @@ -1029,6 +1183,12 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" +[[package]] +name = "leb128fmt" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" + [[package]] name = "libc" version = "0.2.183" @@ -1041,6 +1201,15 @@ version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77" +[[package]] +name = "lock_api" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" +dependencies = [ + "scopeguard", +] + [[package]] name = "log" version = "0.4.29" @@ -1098,6 +1267,23 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "moka" +version = "0.12.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "957228ad12042ee839f93c8f257b62b4c0ab5eaae1d4fa60de53b27c9d7c5046" +dependencies = [ + "crossbeam-channel", + "crossbeam-epoch", + "crossbeam-utils", + "equivalent", + "parking_lot", + "portable-atomic", + "smallvec", + "tagptr", + "uuid", +] + [[package]] name = "nom" version = "7.1.3" @@ -1151,6 +1337,8 @@ dependencies = [ "criterion", "env_logger", "futures", + "hickory-proto", + "hickory-resolver", "http", "http-body-util", "hyper", @@ -1170,7 +1358,7 @@ dependencies = [ "tokio-rustls", "toml", "tower", - "webpki-roots", + "webpki-roots 1.0.6", ] [[package]] @@ -1187,6 +1375,10 @@ name = "once_cell" version = "1.21.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" +dependencies = [ + "critical-section", + "portable-atomic", +] [[package]] name = "once_cell_polyfill" @@ -1210,6 +1402,29 @@ dependencies = [ "winapi", ] +[[package]] +name = "parking_lot" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-link", +] + [[package]] name = "pem" version = "3.0.6" @@ -1305,6 +1520,16 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "prettyplease" +version = "0.2.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +dependencies = [ + "proc-macro2", + "syn", +] + [[package]] name = "proc-macro2" version = "1.0.106" @@ -1390,6 +1615,12 @@ version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" +[[package]] +name = "r-efi" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8dcc9c7d52a811697d2151c701e0d08956f92b0e24136cf4cf27b57a6a0d9bf" + [[package]] name = "rand" version = "0.9.2" @@ -1453,6 +1684,15 @@ dependencies = [ "yasna", ] +[[package]] +name = "redox_syscall" +version = "0.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" +dependencies = [ + "bitflags", +] + [[package]] name = "regex" version = "1.12.3" @@ -1518,9 +1758,15 @@ dependencies = [ "wasm-bindgen", "wasm-bindgen-futures", "web-sys", - "webpki-roots", + "webpki-roots 1.0.6", ] +[[package]] +name = "resolv-conf" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e061d1b48cb8d38042de4ae0a7a6401009d6143dc80d2e2d6f31f0bdd6470c7" + [[package]] name = "ring" version = "0.17.14" @@ -1618,6 +1864,18 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "semver" +version = "1.0.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a7852d02fc848982e0c167ef163aaff9cd91dc640ba85e263cb1ce46fae51cd" + [[package]] name = "serde" version = "1.0.228" @@ -1780,6 +2038,12 @@ dependencies = [ "syn", ] +[[package]] +name = "tagptr" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417" + [[package]] name = "thiserror" version = "2.0.18" @@ -2038,6 +2302,12 @@ version = "1.0.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + [[package]] name = "untrusted" version = "0.9.0" @@ -2068,6 +2338,17 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" +[[package]] +name = "uuid" +version = "1.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ac8b6f42ead25368cf5b098aeb3dc8a1a2c05a3eee8a9a1a68c640edbfc79d9" +dependencies = [ + "getrandom 0.4.2", + "js-sys", + "wasm-bindgen", +] + [[package]] name = "walkdir" version = "2.5.0" @@ -2102,6 +2383,15 @@ dependencies = [ "wit-bindgen", ] +[[package]] +name = "wasip3" +version = "0.4.0+wasi-0.3.0-rc-2026-01-06" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" +dependencies = [ + "wit-bindgen", +] + [[package]] name = "wasm-bindgen" version = "0.2.115" @@ -2157,6 +2447,40 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "wasm-encoder" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990065f2fe63003fe337b932cfb5e3b80e0b4d0f5ff650e6985b1048f62c8319" +dependencies = [ + "leb128fmt", + "wasmparser", +] + +[[package]] +name = "wasm-metadata" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909" +dependencies = [ + "anyhow", + "indexmap", + "wasm-encoder", + "wasmparser", +] + +[[package]] +name = "wasmparser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" +dependencies = [ + "bitflags", + "hashbrown 0.15.5", + "indexmap", + "semver", +] + [[package]] name = "web-sys" version = "0.3.92" @@ -2177,6 +2501,15 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "webpki-roots" +version = "0.26.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "521bc38abb08001b01866da9f51eb7c5d647a19260e00054a8c7fd5f9e57f7a9" +dependencies = [ + "webpki-roots 1.0.6", +] + [[package]] name = "webpki-roots" version = "1.0.6" @@ -2186,6 +2519,12 @@ dependencies = [ "rustls-pki-types", ] +[[package]] +name = "widestring" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72069c3113ab32ab29e5584db3c6ec55d416895e60715417b5b883a357c3e471" + [[package]] name = "winapi" version = "0.3.9" @@ -2223,6 +2562,35 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" +[[package]] +name = "windows-registry" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02752bf7fbdcce7f2a27a742f798510f3e5ad88dbe84871e5168e2120c3d5720" +dependencies = [ + "windows-link", + "windows-result", + "windows-strings", +] + +[[package]] +name = "windows-result" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-strings" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" +dependencies = [ + "windows-link", +] + [[package]] name = "windows-sys" version = "0.52.0" @@ -2390,6 +2758,88 @@ name = "wit-bindgen" version = "0.51.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" +dependencies = [ + "wit-bindgen-rust-macro", +] + +[[package]] +name = "wit-bindgen-core" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea61de684c3ea68cb082b7a88508a8b27fcc8b797d738bfc99a82facf1d752dc" +dependencies = [ + "anyhow", + "heck", + "wit-parser", +] + +[[package]] +name = "wit-bindgen-rust" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" +dependencies = [ + "anyhow", + "heck", + "indexmap", + "prettyplease", + "syn", + "wasm-metadata", + "wit-bindgen-core", + "wit-component", +] + +[[package]] +name = "wit-bindgen-rust-macro" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c0f9bfd77e6a48eccf51359e3ae77140a7f50b1e2ebfe62422d8afdaffab17a" +dependencies = [ + "anyhow", + "prettyplease", + "proc-macro2", + "quote", + "syn", + "wit-bindgen-core", + "wit-bindgen-rust", +] + +[[package]] +name = "wit-component" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" +dependencies = [ + "anyhow", + "bitflags", + "indexmap", + "log", + "serde", + "serde_derive", + "serde_json", + "wasm-encoder", + "wasm-metadata", + "wasmparser", + "wit-parser", +] + +[[package]] +name = "wit-parser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736" +dependencies = [ + "anyhow", + "id-arena", + "indexmap", + "log", + "semver", + "serde", + "serde_derive", + "serde_json", + "unicode-xid", + "wasmparser", +] [[package]] name = "writeable" diff --git a/Cargo.toml b/Cargo.toml index c5d5e1d..d7f6f9f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,6 +37,8 @@ webpki-roots = "1" criterion = { version = "0.8", features = ["html_reports"] } tower = { version = "0.5", features = ["util"] } http = "1" +hickory-resolver = { version = "0.25", features = ["https-ring", "webpki-roots"] } +hickory-proto = "0.25" [[bench]] name = "hot_path" @@ -49,3 +51,7 @@ harness = false [[bench]] name = "dnssec" harness = false + +[[bench]] +name = "recursive_compare" +harness = false diff --git a/benches/numa-bench-recursive.toml b/benches/numa-bench-recursive.toml new file mode 100644 index 0000000..055d75a --- /dev/null +++ b/benches/numa-bench-recursive.toml @@ -0,0 +1,30 @@ +[server] +bind_addr = "127.0.0.1:5454" +api_port = 5381 +api_bind_addr = "127.0.0.1" +data_dir = "/tmp/numa-bench" + +[upstream] +mode = "recursive" +timeout_ms = 10000 + +[cache] +min_ttl = 60 +max_ttl = 3600 + +[blocking] +enabled = false + +[proxy] +port = 8080 +tls_port = 8443 + +[dot] +enabled = true +port = 8530 + +[mobile] +enabled = false + +[lan] +enabled = false diff --git a/benches/numa-bench.toml b/benches/numa-bench.toml new file mode 100644 index 0000000..6124840 --- /dev/null +++ b/benches/numa-bench.toml @@ -0,0 +1,31 @@ +[server] +bind_addr = "127.0.0.1:5454" +api_port = 5381 +api_bind_addr = "127.0.0.1" +data_dir = "/tmp/numa-bench" + +[upstream] +mode = "forward" +address = ["https://9.9.9.9/dns-query"] +timeout_ms = 10000 + +[cache] +min_ttl = 60 +max_ttl = 3600 + +[blocking] +enabled = false + +[proxy] +port = 8080 +tls_port = 8443 + +[dot] +enabled = true +port = 8530 + +[mobile] +enabled = false + +[lan] +enabled = false diff --git a/benches/recursive_compare.rs b/benches/recursive_compare.rs new file mode 100644 index 0000000..f1a59d2 --- /dev/null +++ b/benches/recursive_compare.rs @@ -0,0 +1,1095 @@ +//! DNS forwarding benchmark suite. +//! +//! Modes: +//! (default) Numa server (UDP) vs Hickory library (DoH) — the original benchmark +//! --diag Hickory connection reuse diagnostic (20 queries) +//! --diag-clients Per-query reqwest vs Hickory timing (20 queries) +//! --direct Library-to-library: Numa forward_query_raw vs Hickory resolver.lookup +//! --hedge-5x Hedging: single vs hedge-same vs hedge-dual vs Hickory (5 iterations) +//! --vs-unbound Server-to-server: Numa vs Unbound (plain UDP, caching) +//! --vs-unbound-cold Cold: Numa vs Unbound (unique subdomains, no cache hits) +//! --vs-nextdns Server-to-cloud: Numa (local cache) vs NextDNS (remote, 45.90.28.0) +//! --vs-dot DoT server: Numa vs Unbound +//! --vs-doh-servers DoH server: Numa vs Unbound (DoT upstream) +//! +//! Setup: +//! 1. Start a bench Numa instance: cargo run -- benches/numa-bench.toml +//! 2. Run: cargo bench --bench recursive_compare [-- --mode] + +use std::net::SocketAddr; +use std::time::{Duration, Instant}; + +const DOH_UPSTREAM: &str = "https://9.9.9.9/dns-query"; +const NUMA_BENCH: &str = "127.0.0.1:5454"; +const NUMA_API: u16 = 5381; + +const DOMAINS: &[&str] = &[ + "example.com", + "rust-lang.org", + "kernel.org", + "signal.org", + "archlinux.org", + "openbsd.org", + "git-scm.com", + "sqlite.org", + "wireguard.com", + "mozilla.org", + "cloudflare.com", + "google.com", + "github.com", + "stackoverflow.com", + "wikipedia.org", + "reddit.com", + "amazon.com", + "apple.com", + "microsoft.com", + "facebook.com", + "twitter.com", + "linkedin.com", + "netflix.com", + "spotify.com", + "discord.com", + "twitch.tv", + "youtube.com", + "instagram.com", + "whatsapp.com", + "telegram.org", + "debian.org", + "ubuntu.com", + "fedoraproject.org", + "nixos.org", + "gentoo.org", + "freebsd.org", + "netbsd.org", + "dragonflybsd.org", + "illumos.org", + "haiku-os.org", + "python.org", + "golang.org", + "nodejs.org", + "ruby-lang.org", + "php.net", + "swift.org", + "kotlinlang.org", + "scala-lang.org", + "haskell.org", + "elixir-lang.org", + "erlang.org", + "clojure.org", + "julialang.org", + "ziglang.org", + "nim-lang.org", + "dlang.org", + "vlang.io", + "crystal-lang.org", + "racket-lang.org", + "ocaml.org", + "crates.io", + "npmjs.com", + "pypi.org", + "rubygems.org", + "packagist.org", + "nuget.org", + "maven.apache.org", + "hex.pm", + "hackage.haskell.org", + "pkg.go.dev", + "docker.com", + "kubernetes.io", + "prometheus.io", + "grafana.com", + "elastic.co", + "datadog.com", + "sentry.io", + "pagerduty.com", + "atlassian.com", + "jetbrains.com", + "gitlab.com", + "bitbucket.org", + "sourcehut.org", + "codeberg.org", + "launchpad.net", + "savannah.gnu.org", + "letsencrypt.org", + "eff.org", + "torproject.org", + "privacyguides.org", + "matrix.org", + "element.io", + "jitsi.org", + "nextcloud.com", + "syncthing.net", + "tailscale.com", + "mullvad.net", + "proton.me", + "duckduckgo.com", + "brave.com", + "vivaldi.com", +]; + +const ROUNDS: usize = 10; + +fn main() { + let arg = |flag: &str| std::env::args().any(|a| a == flag); + + let rt = tokio::runtime::Runtime::new().unwrap(); + + if arg("--diag") { + return run_diag(&rt); + } + if arg("--diag-clients") { + return run_diag_clients(&rt); + } + if arg("--direct") { + return run_direct(&rt); + } + if arg("--hedge-5x") { + return run_hedge_multi(&rt, 5); + } + if arg("--vs-unbound") { + check_numa_mode(&rt, "forward"); + return run_server_comparison(&rt, "Unbound", "127.0.0.1:5456", 5, false); + } + if arg("--vs-unbound-cold") { + check_numa_mode(&rt, "recursive"); + return run_server_comparison(&rt, "Unbound", "127.0.0.1:5456", 5, true); + } + if arg("--vs-dnscrypt") { + check_numa_mode(&rt, "forward"); + return run_server_comparison(&rt, "dnscrypt-proxy", "127.0.0.1:5455", 5, false); + } + if arg("--vs-nextdns") { + check_numa_mode(&rt, "forward"); + return run_server_comparison(&rt, "NextDNS", "45.90.28.0:53", 5, false); + } + if arg("--vs-dot") { + return run_dot_comparison(&rt, 5); + } + if arg("--vs-doh-servers") { + return run_doh_comparison(&rt, 5); + } + + // Default: Numa server (UDP) vs Hickory library (DoH) + run_default(&rt); +} + +// ── Generic 2-way comparison engine ───────────────────────────── + +fn compare_two( + rt: &tokio::runtime::Runtime, + title: &str, + name_a: &str, + name_b: &str, + measure_a: &dyn Fn(&str) -> f64, + measure_b: &dyn Fn(&str) -> f64, + iterations: usize, +) { + compare_two_rounds( + rt, title, name_a, name_b, measure_a, measure_b, iterations, ROUNDS, + ); +} + +fn compare_two_rounds( + rt: &tokio::runtime::Runtime, + title: &str, + name_a: &str, + name_b: &str, + measure_a: &dyn Fn(&str) -> f64, + measure_b: &dyn Fn(&str) -> f64, + iterations: usize, + rounds: usize, +) { + let flush = std::env::args().any(|a| a == "--flush"); + println!("{}", title); + println!( + "{} domains × {} rounds × {} iterations\n", + DOMAINS.len(), + rounds, + iterations + ); + + let mut all_a = Vec::new(); + let mut all_b = Vec::new(); + let mut iter_stats: Vec<[(f64, f64, f64, f64, f64); 2]> = Vec::new(); + + for iter in 1..=iterations { + println!(" iteration {}/{}...", iter, iterations); + let mut a = Vec::new(); + let mut b = Vec::new(); + + for domain in DOMAINS { + for round in 0..rounds { + if flush { + flush_cache(); + std::thread::sleep(Duration::from_millis(5)); + } + if round % 2 == 0 { + a.push(measure_a(domain)); + b.push(measure_b(domain)); + } else { + b.push(measure_b(domain)); + a.push(measure_a(domain)); + } + } + } + + iter_stats.push([stats(&mut a), stats(&mut b)]); + all_a.extend_from_slice(&a); + all_b.extend_from_slice(&b); + } + + print_results( + name_a, + name_b, + &iter_stats, + &mut all_a, + &mut all_b, + iterations, + rounds, + ); +} + +fn print_results( + name_a: &str, + name_b: &str, + iter_stats: &[[(f64, f64, f64, f64, f64); 2]], + all_a: &mut Vec, + all_b: &mut Vec, + iterations: usize, + rounds: usize, +) { + let w = name_a.len().max(name_b.len()).max(6); + + println!("\n=== Per-iteration medians ==="); + println!("{:<8} {:>w$} {:>w$}", "iter", name_a, name_b, w = w + 3); + for (i, s) in iter_stats.iter().enumerate() { + println!( + "{:<8} {:>w$.1} ms {:>w$.1} ms", + i + 1, + s[0].1, + s[1].1, + w = w + ); + } + + println!("\n=== Per-iteration p99 ==="); + println!("{:<8} {:>w$} {:>w$}", "iter", name_a, name_b, w = w + 3); + for (i, s) in iter_stats.iter().enumerate() { + println!( + "{:<8} {:>w$.1} ms {:>w$.1} ms", + i + 1, + s[0].3, + s[1].3, + w = w + ); + } + + let (a_m, a_med, a_p95, a_p99, a_sd) = stats(all_a); + let (b_m, b_med, b_p95, b_p99, b_sd) = stats(all_b); + + let total = iterations * DOMAINS.len() * rounds; + println!("\n=== Aggregated ({} samples per method) ===\n", total); + println!("{:<10} {:>w$} {:>w$}", "", name_a, name_b, w = w + 3); + println!("{:<10} {:>w$.1} ms {:>w$.1} ms", "mean", a_m, b_m, w = w); + println!( + "{:<10} {:>w$.1} ms {:>w$.1} ms", + "median", + a_med, + b_med, + w = w + ); + println!( + "{:<10} {:>w$.1} ms {:>w$.1} ms", + "p95", + a_p95, + b_p95, + w = w + ); + println!( + "{:<10} {:>w$.1} ms {:>w$.1} ms", + "p99", + a_p99, + b_p99, + w = w + ); + println!("{:<10} {:>w$.1} ms {:>w$.1} ms", "σ", a_sd, b_sd, w = w); + + let pct = |a: f64, b: f64| { + if b.abs() > 0.001 { + (a - b) / b * 100.0 + } else { + 0.0 + } + }; + println!("\n{} vs {}:", name_a, name_b); + println!(" mean: {:+.1} ms ({:+.0}%)", a_m - b_m, pct(a_m, b_m)); + println!( + " median: {:+.1} ms ({:+.0}%)", + a_med - b_med, + pct(a_med, b_med) + ); + println!( + " p99: {:+.1} ms ({:+.0}%)", + a_p99 - b_p99, + pct(a_p99, b_p99) + ); +} + +// ── Modes ─────────────────────────────────────────────────────── + +/// Default: Numa server (UDP) vs Hickory library (DoH), cache flushed. +fn run_default(rt: &tokio::runtime::Runtime) { + let numa_addr: SocketAddr = NUMA_BENCH.parse().unwrap(); + if rt.block_on(query_udp(numa_addr, "example.com")).is_none() { + eprintln!("Bench Numa not responding on {numa_addr}"); + eprintln!("Start with: cargo run -- benches/numa-bench.toml"); + std::process::exit(1); + } + + let resolver = rt.block_on(build_hickory_resolver()); + + println!("Warming up..."); + for _ in 0..3 { + rt.block_on(query_udp(numa_addr, "example.com")); + rt.block_on(query_hickory_doh(&resolver, "example.com")); + } + flush_cache(); + + compare_two( + rt, + &format!("DoH Forwarding: Numa server vs Hickory library\nBoth → {DOH_UPSTREAM}"), + "Numa", + "Hickory", + &|domain| { + flush_cache(); + std::thread::sleep(Duration::from_millis(10)); + let t = Instant::now(); + let _ = rt.block_on(query_udp(numa_addr, domain)); + t.elapsed().as_secs_f64() * 1000.0 + }, + &|domain| { + let t = Instant::now(); + let _ = rt.block_on(query_hickory_doh(&resolver, domain)); + t.elapsed().as_secs_f64() * 1000.0 + }, + 1, + ); +} + +/// Library-to-library: Numa forward_query_raw vs Hickory resolver.lookup. +fn run_direct(rt: &tokio::runtime::Runtime) { + let upstream = numa::forward::parse_upstream(DOH_UPSTREAM, 443).expect("failed to parse"); + let resolver = rt.block_on(build_hickory_resolver()); + let timeout = Duration::from_secs(10); + + println!("Warming up..."); + for _ in 0..3 { + let w = build_query_vec("example.com"); + let _ = rt.block_on(numa::forward::forward_query_raw(&w, &upstream, timeout)); + let _ = rt.block_on(query_hickory_doh(&resolver, "example.com")); + } + + compare_two( + rt, + &format!("Direct DoH: Numa forward_query_raw vs Hickory resolver.lookup\nBoth → {DOH_UPSTREAM}, no server pipeline"), + "Numa", "Hickory", + &|domain| { + let w = build_query_vec(domain); + let t = Instant::now(); + let _ = rt.block_on(numa::forward::forward_query_raw(&w, &upstream, timeout)); + t.elapsed().as_secs_f64() * 1000.0 + }, + &|domain| { + let t = Instant::now(); + let _ = rt.block_on(query_hickory_doh(&resolver, domain)); + t.elapsed().as_secs_f64() * 1000.0 + }, + 5, + ); +} + +/// Server-to-server: Numa vs another server, both on plain UDP. +/// When `cold` is true, each query uses a unique random subdomain so neither +/// server can answer from its record cache (NS delegation caching still applies). +fn run_server_comparison( + rt: &tokio::runtime::Runtime, + other_name: &str, + other_addr: &str, + iterations: usize, + cold: bool, +) { + use std::sync::atomic::{AtomicU64, Ordering}; + static COUNTER: AtomicU64 = AtomicU64::new(0); + // Unique prefix per process so re-runs don't hit stale cache entries + let run_id = std::process::id(); + + let numa_addr: SocketAddr = NUMA_BENCH.parse().unwrap(); + let other: SocketAddr = other_addr.parse().unwrap(); + + for (name, addr) in [("Numa", numa_addr), (other_name, other)] { + if rt.block_on(query_udp(addr, "example.com")).is_none() { + eprintln!("{name} not responding on {addr}"); + std::process::exit(1); + } + } + + if cold { + flush_cache(); // flush Numa's record cache + } + + println!("Warming up..."); + for _ in 0..5 { + let _ = rt.block_on(query_udp(numa_addr, "example.com")); + let _ = rt.block_on(query_udp(other, "example.com")); + } + + let tag = if cold { + "cold, unique subdomains" + } else { + "caching" + }; + + let rounds = if cold { 1 } else { ROUNDS }; + + compare_two_rounds( + rt, + &format!("Server-to-Server: Numa vs {other_name} (UDP, {tag})"), + "Numa", + other_name, + &|domain| { + let d = if cold { + format!( + "r{}-c{}.{}", + run_id, + COUNTER.fetch_add(1, Ordering::Relaxed), + domain + ) + } else { + domain.to_string() + }; + let t = Instant::now(); + let _ = rt.block_on(query_udp(numa_addr, &d)); + t.elapsed().as_secs_f64() * 1000.0 + }, + &|domain| { + let d = if cold { + format!( + "r{}-c{}.{}", + run_id, + COUNTER.fetch_add(1, Ordering::Relaxed), + domain + ) + } else { + domain.to_string() + }; + let t = Instant::now(); + let _ = rt.block_on(query_udp(other, &d)); + t.elapsed().as_secs_f64() * 1000.0 + }, + iterations, + rounds, + ); +} + +/// DoT server comparison: Numa vs Unbound. +fn run_dot_comparison(rt: &tokio::runtime::Runtime, iterations: usize) { + const NUMA_DOT: &str = "127.0.0.1:8530"; + const UNBOUND_DOT: &str = "127.0.0.1:8531"; + + let _ = rustls::crypto::ring::default_provider().install_default(); + let tls_config = build_insecure_tls_config(); + + for (name, addr) in [("Numa", NUMA_DOT), ("Unbound", UNBOUND_DOT)] { + match rt.block_on(query_dot_once(addr, "example.com", &tls_config)) { + Ok(_) => println!("{name} DoT: OK"), + Err(e) => { + eprintln!("{name} DoT not responding on {addr}: {e}"); + std::process::exit(1); + } + } + } + + println!("Warming up..."); + for _ in 0..3 { + let _ = rt.block_on(query_dot_once(NUMA_DOT, "example.com", &tls_config)); + let _ = rt.block_on(query_dot_once(UNBOUND_DOT, "example.com", &tls_config)); + } + + compare_two( + rt, + "DoT Server: Numa vs Unbound (both DoT→clients, forwarding to Quad9)", + "Numa", + "Unbound", + &|domain| { + let t = Instant::now(); + let _ = rt.block_on(query_dot_once(NUMA_DOT, domain, &tls_config)); + t.elapsed().as_secs_f64() * 1000.0 + }, + &|domain| { + let t = Instant::now(); + let _ = rt.block_on(query_dot_once(UNBOUND_DOT, domain, &tls_config)); + t.elapsed().as_secs_f64() * 1000.0 + }, + iterations, + ); +} + +/// DoH server comparison: Numa vs Unbound (both DoH→clients, DoT upstream). +fn run_doh_comparison(rt: &tokio::runtime::Runtime, iterations: usize) { + const NUMA_DOH: &str = "https://127.0.0.1:8443/dns-query"; + const UNBOUND_DOH: &str = "https://127.0.0.1:8445/dns-query"; + + let client = reqwest::Client::builder() + .use_rustls_tls() + .danger_accept_invalid_certs(true) + .http2_initial_stream_window_size(65_535) + .http2_initial_connection_window_size(65_535) + .pool_idle_timeout(Duration::from_secs(300)) + .build() + .unwrap(); + + for (name, url, host) in [ + ("Numa", NUMA_DOH, Some("numa.numa")), + ("Unbound", UNBOUND_DOH, None), + ] { + let w = build_query_vec("example.com"); + match rt.block_on(query_doh_server(&client, url, &w, host)) { + Ok(_) => println!("{name} DoH: OK"), + Err(e) => { + eprintln!("{name} DoH not responding: {e}"); + std::process::exit(1); + } + } + } + + println!("Warming up..."); + for _ in 0..5 { + let w = build_query_vec("example.com"); + let _ = rt.block_on(query_doh_server(&client, NUMA_DOH, &w, Some("numa.numa"))); + let _ = rt.block_on(query_doh_server(&client, UNBOUND_DOH, &w, None)); + } + + compare_two( + rt, + "DoH Server: Numa vs Unbound (both DoH→clients, DoT upstream)", + "Numa", + "Unbound", + &|domain| { + let w = build_query_vec(domain); + let t = Instant::now(); + let _ = rt.block_on(query_doh_server(&client, NUMA_DOH, &w, Some("numa.numa"))); + t.elapsed().as_secs_f64() * 1000.0 + }, + &|domain| { + let w = build_query_vec(domain); + let t = Instant::now(); + let _ = rt.block_on(query_doh_server(&client, UNBOUND_DOH, &w, None)); + t.elapsed().as_secs_f64() * 1000.0 + }, + iterations, + ); +} + +/// Hedging: single vs hedge-same vs hedge-dual vs Hickory. +/// This is the one mode that compares 4 contenders, not 2. +fn run_hedge_multi(rt: &tokio::runtime::Runtime, iterations: usize) { + let hedge_delay = Duration::from_millis(10); + let timeout = Duration::from_secs(10); + + println!("Hedging Benchmark × {iterations} iterations"); + println!("Upstream: {DOH_UPSTREAM}"); + println!("Hedge delay: {hedge_delay:?}"); + println!( + "{} domains × {ROUNDS} rounds per iteration\n", + DOMAINS.len() + ); + + let primary = numa::forward::parse_upstream(DOH_UPSTREAM, 443).expect("failed to parse"); + let primary_dual = numa::forward::parse_upstream(DOH_UPSTREAM, 443).expect("failed to parse"); + let secondary_dual = numa::forward::parse_upstream(DOH_UPSTREAM, 443).expect("failed to parse"); + let resolver = rt.block_on(build_hickory_resolver()); + + println!("Warming up..."); + for _ in 0..5 { + let w = build_query_vec("example.com"); + let _ = rt.block_on(numa::forward::forward_query_raw(&w, &primary, timeout)); + let _ = rt.block_on(numa::forward::forward_query_raw(&w, &primary_dual, timeout)); + let _ = rt.block_on(numa::forward::forward_query_raw( + &w, + &secondary_dual, + timeout, + )); + let _ = rt.block_on(query_hickory_doh(&resolver, "example.com")); + } + + let labels = ["Single", "Hedge-same", "Hedge-dual", "Hickory"]; + let mut all: [Vec; 4] = [vec![], vec![], vec![], vec![]]; + let mut iter_medians: Vec<[f64; 4]> = vec![]; + let mut iter_p99s: Vec<[f64; 4]> = vec![]; + + for iter in 1..=iterations { + println!(" iteration {iter}/{iterations}..."); + let mut samples: [Vec; 4] = [vec![], vec![], vec![], vec![]]; + + for domain in DOMAINS { + for _ in 0..ROUNDS { + let w = build_query_vec(domain); + + let t = Instant::now(); + let _ = rt.block_on(numa::forward::forward_query_raw(&w, &primary, timeout)); + samples[0].push(t.elapsed().as_secs_f64() * 1000.0); + + let t = Instant::now(); + let _ = rt.block_on(numa::forward::forward_with_hedging_raw( + &w, + &primary, + &primary, + hedge_delay, + timeout, + )); + samples[1].push(t.elapsed().as_secs_f64() * 1000.0); + + let t = Instant::now(); + let _ = rt.block_on(numa::forward::forward_with_hedging_raw( + &w, + &primary_dual, + &secondary_dual, + hedge_delay, + timeout, + )); + samples[2].push(t.elapsed().as_secs_f64() * 1000.0); + + let t = Instant::now(); + let _ = rt.block_on(query_hickory_doh(&resolver, domain)); + samples[3].push(t.elapsed().as_secs_f64() * 1000.0); + } + } + + let s: Vec<_> = samples.iter_mut().map(|v| stats(v)).collect(); + iter_medians.push([s[0].1, s[1].1, s[2].1, s[3].1]); + iter_p99s.push([s[0].3, s[1].3, s[2].3, s[3].3]); + for (i, v) in samples.iter().enumerate() { + all[i].extend_from_slice(v); + } + } + + println!("\n=== Per-iteration medians ==="); + println!( + "{:<8} {:>10} {:>12} {:>12} {:>10}", + "iter", labels[0], labels[1], labels[2], labels[3] + ); + for (i, m) in iter_medians.iter().enumerate() { + println!( + "{:<8} {:>7.1} ms {:>9.1} ms {:>9.1} ms {:>7.1} ms", + i + 1, + m[0], + m[1], + m[2], + m[3] + ); + } + + println!("\n=== Per-iteration p99 ==="); + println!( + "{:<8} {:>10} {:>12} {:>12} {:>10}", + "iter", labels[0], labels[1], labels[2], labels[3] + ); + for (i, p) in iter_p99s.iter().enumerate() { + println!( + "{:<8} {:>7.1} ms {:>9.1} ms {:>9.1} ms {:>7.1} ms", + i + 1, + p[0], + p[1], + p[2], + p[3] + ); + } + + let s: Vec<_> = all + .iter_mut() + .map(|v| { + let (m, med, p95, p99, sd) = stats(v); + [m, med, p95, p99, sd] + }) + .collect(); + let total = iterations * DOMAINS.len() * ROUNDS; + println!("\n=== Aggregated ({total} samples per method) ===\n"); + println!( + "{:<10} {:>10} {:>12} {:>12} {:>10}", + "", labels[0], labels[1], labels[2], labels[3] + ); + for (row, idx) in [("mean", 0), ("median", 1), ("p95", 2), ("p99", 3), ("σ", 4)] { + println!( + "{:<10} {:>7.1} ms {:>9.1} ms {:>9.1} ms {:>7.1} ms", + row, s[0][idx], s[1][idx], s[2][idx], s[3][idx] + ); + } + + let pct = |a: f64, b: f64| { + if b.abs() > 0.001 { + (a - b) / b * 100.0 + } else { + 0.0 + } + }; + println!( + "\nHedge-same vs Single: mean {:+.0}%, p95 {:+.0}%, p99 {:+.0}%", + pct(s[1][0], s[0][0]), + pct(s[1][2], s[0][2]), + pct(s[1][3], s[0][3]) + ); + println!( + "Hedge-same vs Hickory: mean {:+.0}%, p95 {:+.0}%, p99 {:+.0}%", + pct(s[1][0], s[3][0]), + pct(s[1][2], s[3][2]), + pct(s[1][3], s[3][3]) + ); +} + +// ── Diagnostics (small, kept for debugging) ───────────────────── + +fn run_diag(rt: &tokio::runtime::Runtime) { + println!("Hickory connection reuse diagnostic\n20 queries to {DOH_UPSTREAM}\n"); + + let resolver = rt.block_on(build_hickory_resolver()); + let domains = [ + "example.com", + "rust-lang.org", + "kernel.org", + "google.com", + "github.com", + "example.com", + "rust-lang.org", + "kernel.org", + "google.com", + "github.com", + "example.com", + "rust-lang.org", + "kernel.org", + "google.com", + "github.com", + "example.com", + "rust-lang.org", + "kernel.org", + "google.com", + "github.com", + ]; + + println!("{:>3} {:<20} {:>10}", "#", "Domain", "Time (ms)"); + println!("{}", "-".repeat(40)); + for (i, domain) in domains.iter().enumerate() { + use hickory_resolver::proto::rr::RecordType; + let start = Instant::now(); + let result = rt.block_on(resolver.lookup(*domain, RecordType::A)); + let ms = start.elapsed().as_secs_f64() * 1000.0; + match &result { + Ok(lookup) => { + let first = lookup + .iter() + .next() + .map(|r| format!("{r}")) + .unwrap_or_default(); + println!( + "{:>3} {:<20} {:>7.1} ms OK {}", + i + 1, + domain, + ms, + first + ); + } + Err(e) => println!("{:>3} {:<20} {:>7.1} ms ERR {}", i + 1, domain, ms, e), + } + } +} + +fn run_diag_clients(rt: &tokio::runtime::Runtime) { + println!("Client diagnostic: reqwest vs Hickory (20 queries to {DOH_UPSTREAM})\n"); + + let upstream = numa::forward::parse_upstream(DOH_UPSTREAM, 443).expect("failed to parse"); + let resolver = rt.block_on(build_hickory_resolver()); + let timeout = Duration::from_secs(10); + + for _ in 0..3 { + let w = build_query_vec("example.com"); + let _ = rt.block_on(numa::forward::forward_query_raw(&w, &upstream, timeout)); + let _ = rt.block_on(query_hickory_doh(&resolver, "example.com")); + } + + let domains = [ + "example.com", + "google.com", + "github.com", + "rust-lang.org", + "cloudflare.com", + "example.com", + "google.com", + "github.com", + "rust-lang.org", + "cloudflare.com", + "example.com", + "google.com", + "github.com", + "rust-lang.org", + "cloudflare.com", + "example.com", + "google.com", + "github.com", + "rust-lang.org", + "cloudflare.com", + ]; + + println!( + "{:>3} {:<20} {:>12} {:>12}", + "#", "Domain", "reqwest", "Hickory" + ); + println!("{}", "-".repeat(55)); + for (i, domain) in domains.iter().enumerate() { + let wire = build_query_vec(domain); + let start = Instant::now(); + let r_result = rt.block_on(numa::forward::forward_query_raw(&wire, &upstream, timeout)); + let r_ms = start.elapsed().as_secs_f64() * 1000.0; + let r_ok = if r_result.is_ok() { "OK" } else { "FAIL" }; + + let start = Instant::now(); + let h_result = rt.block_on(query_hickory_doh(&resolver, domain)); + let h_ms = start.elapsed().as_secs_f64() * 1000.0; + let h_ok = if h_result.is_some() { "OK" } else { "FAIL" }; + + println!( + "{:>3} {:<20} {:>7.1} ms {} {:>7.1} ms {}", + i + 1, + domain, + r_ms, + r_ok, + h_ms, + h_ok + ); + } +} + +// ── Stats helpers ─────────────────────────────────────────────── + +fn stats(v: &mut [f64]) -> (f64, f64, f64, f64, f64) { + if v.is_empty() { + return (0.0, 0.0, 0.0, 0.0, 0.0); + } + let mean = v.iter().sum::() / v.len() as f64; + v.sort_by(|a, b| a.partial_cmp(b).unwrap()); + let n = v.len(); + let median = if n % 2 == 0 { + (v[n / 2 - 1] + v[n / 2]) / 2.0 + } else { + v[n / 2] + }; + let p95 = v[((n as f64 * 0.95).round() as usize).min(n - 1)]; + let p99 = v[((n as f64 * 0.99).round() as usize).min(n - 1)]; + let var = v.iter().map(|x| (x - mean).powi(2)).sum::() / n as f64; + (mean, median, p95, p99, var.sqrt()) +} + +// ── Query helpers ─────────────────────────────────────────────── + +async fn query_udp(addr: SocketAddr, domain: &str) -> Option<()> { + use tokio::net::UdpSocket; + let sock = UdpSocket::bind("0.0.0.0:0").await.ok()?; + let mut buf = vec![0u8; 512]; + let len = build_query(&mut buf, domain); + sock.send_to(&buf[..len], addr).await.ok()?; + let mut resp = vec![0u8; 4096]; + tokio::time::timeout(Duration::from_secs(10), sock.recv_from(&mut resp)) + .await + .ok()? + .ok()?; + Some(()) +} + +async fn query_dot_once( + addr: &str, + domain: &str, + tls_config: &std::sync::Arc, +) -> Result<(), Box> { + use rustls::pki_types::ServerName; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::net::TcpStream; + use tokio_rustls::TlsConnector; + + let connector = TlsConnector::from(tls_config.clone()); + let stream = TcpStream::connect(addr).await?; + let server_name = ServerName::try_from("localhost")?; + let mut tls = connector.connect(server_name, stream).await?; + + let mut buf = vec![0u8; 512]; + let len = build_query(&mut buf, domain); + let msg = &buf[..len]; + + let mut out = Vec::with_capacity(2 + msg.len()); + out.extend_from_slice(&(msg.len() as u16).to_be_bytes()); + out.extend_from_slice(msg); + tls.write_all(&out).await?; + + let mut len_buf = [0u8; 2]; + tls.read_exact(&mut len_buf).await?; + let resp_len = u16::from_be_bytes(len_buf) as usize; + let mut resp = vec![0u8; resp_len]; + tls.read_exact(&mut resp).await?; + Ok(()) +} + +async fn query_doh_server( + client: &reqwest::Client, + url: &str, + wire: &[u8], + host: Option<&str>, +) -> Result, Box> { + let mut req = client + .post(url) + .header("content-type", "application/dns-message") + .header("accept", "application/dns-message") + .body(wire.to_vec()); + if let Some(h) = host { + req = req.header("host", h); + } + let resp = req.send().await?.error_for_status()?; + Ok(resp.bytes().await?.to_vec()) +} + +async fn build_hickory_resolver() -> hickory_resolver::TokioResolver { + use hickory_resolver::config::*; + let ns = NameServerConfig { + socket_addr: "9.9.9.9:443".parse().unwrap(), + protocol: hickory_proto::xfer::Protocol::Https, + tls_dns_name: Some("dns.quad9.net".to_string()), + trust_negative_responses: true, + bind_addr: None, + http_endpoint: Some("/dns-query".to_string()), + }; + let config = ResolverConfig::from_parts(None, vec![], NameServerConfigGroup::from(vec![ns])); + let mut opts = ResolverOpts::default(); + opts.cache_size = 0; + opts.num_concurrent_reqs = 1; + opts.timeout = Duration::from_secs(10); + hickory_resolver::TokioResolver::builder_with_config(config, Default::default()) + .with_options(opts) + .build() +} + +async fn query_hickory_doh(resolver: &hickory_resolver::TokioResolver, domain: &str) -> Option<()> { + use hickory_resolver::proto::rr::RecordType; + let _ = resolver.lookup(domain, RecordType::A).await.ok()?; + Some(()) +} + +fn build_insecure_tls_config() -> std::sync::Arc { + use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}; + use rustls::pki_types::{CertificateDer, ServerName, UnixTime}; + use rustls::DigitallySignedStruct; + + #[derive(Debug)] + struct NoVerify; + impl ServerCertVerifier for NoVerify { + fn verify_server_cert( + &self, + _: &CertificateDer<'_>, + _: &[CertificateDer<'_>], + _: &ServerName<'_>, + _: &[u8], + _: UnixTime, + ) -> Result { + Ok(ServerCertVerified::assertion()) + } + fn verify_tls12_signature( + &self, + _: &[u8], + _: &CertificateDer<'_>, + _: &DigitallySignedStruct, + ) -> Result { + Ok(HandshakeSignatureValid::assertion()) + } + fn verify_tls13_signature( + &self, + _: &[u8], + _: &CertificateDer<'_>, + _: &DigitallySignedStruct, + ) -> Result { + Ok(HandshakeSignatureValid::assertion()) + } + fn supported_verify_schemes(&self) -> Vec { + rustls::crypto::ring::default_provider() + .signature_verification_algorithms + .supported_schemes() + } + } + std::sync::Arc::new( + rustls::ClientConfig::builder() + .dangerous() + .with_custom_certificate_verifier(std::sync::Arc::new(NoVerify)) + .with_no_client_auth(), + ) +} + +// ── Wire helpers ──────────────────────────────────────────────── + +fn build_query_vec(domain: &str) -> Vec { + let mut buf = vec![0u8; 512]; + let len = build_query(&mut buf, domain); + buf.truncate(len); + buf +} + +fn build_query(buf: &mut [u8], domain: &str) -> usize { + let mut pos = 0; + buf[pos..pos + 2].copy_from_slice(&0x1234u16.to_be_bytes()); + pos += 2; + buf[pos..pos + 2].copy_from_slice(&0x0100u16.to_be_bytes()); + pos += 2; + buf[pos..pos + 2].copy_from_slice(&1u16.to_be_bytes()); + pos += 2; + buf[pos..pos + 6].fill(0); + pos += 6; + for label in domain.split('.') { + buf[pos] = label.len() as u8; + pos += 1; + buf[pos..pos + label.len()].copy_from_slice(label.as_bytes()); + pos += label.len(); + } + buf[pos] = 0; + pos += 1; + buf[pos..pos + 2].copy_from_slice(&1u16.to_be_bytes()); + pos += 2; + buf[pos..pos + 2].copy_from_slice(&1u16.to_be_bytes()); + pos += 2; + pos +} + +fn check_numa_mode(rt: &tokio::runtime::Runtime, expected: &str) { + let url = format!("http://127.0.0.1:{NUMA_API}/stats"); + let resp = match rt.block_on(async { reqwest::get(&url).await?.text().await }) { + Ok(body) => body, + Err(_) => { + eprintln!("Bench Numa not responding on {NUMA_BENCH}"); + eprintln!("Start with: cargo run -- benches/numa-bench.toml"); + std::process::exit(1); + } + }; + let config = if expected == "recursive" { + "benches/numa-bench-recursive.toml" + } else { + "benches/numa-bench.toml" + }; + if !resp.contains(&format!("\"mode\":\"{expected}\"")) { + eprintln!("This benchmark requires Numa in {expected} mode."); + eprintln!("Restart with: cargo run -- {config}"); + std::process::exit(1); + } +} + +fn flush_cache() { + let _ = std::process::Command::new("curl") + .args([ + "-s", + "-X", + "DELETE", + &format!("http://127.0.0.1:{NUMA_API}/cache"), + ]) + .output(); +} diff --git a/numa.toml b/numa.toml index 3b716e8..1ea3341 100644 --- a/numa.toml +++ b/numa.toml @@ -15,9 +15,15 @@ api_port = 5380 # address = "9.9.9.9" # single upstream (plain UDP) # address = ["192.168.1.1", "9.9.9.9:5353"] # multiple upstreams — SRTT picks fastest # address = "https://dns.quad9.net/dns-query" # DNS-over-HTTPS (encrypted) +# address = "tls://9.9.9.9#dns.quad9.net" # DNS-over-TLS (encrypted, port 853) # fallback = ["8.8.8.8", "1.1.1.1"] # tried only when all primaries fail # port = 53 # default port for addresses without :port # timeout_ms = 3000 +# hedge_ms = 10 # request hedging delay (ms). After this delay +# # without a response, fires a parallel request +# # to the same upstream. Rescues packet loss (UDP), +# # dispatch spikes (DoH), TLS stalls (DoT). +# # Set to 0 to disable. Default: 10 # root_hints = [ # only used in recursive mode # "198.41.0.4", # a.root-servers.net (Verisign) # "199.9.14.201", # b.root-servers.net (USC-ISI) @@ -60,7 +66,7 @@ api_port = 5380 # allowlist = ["example.com"] # domains to never block [cache] -max_entries = 10000 +max_entries = 100000 min_ttl = 60 max_ttl = 86400 # warm = ["google.com", "github.com"] # resolve at startup, refresh before TTL expiry diff --git a/src/api.rs b/src/api.rs index a0bae58..9aa3f60 100644 --- a/src/api.rs +++ b/src/api.rs @@ -1012,6 +1012,7 @@ mod tests { socket, zone_map: std::collections::HashMap::new(), cache: RwLock::new(crate::cache::DnsCache::new(100, 60, 86400)), + refreshing: Mutex::new(std::collections::HashSet::new()), stats: Mutex::new(crate::stats::ServerStats::new()), overrides: RwLock::new(crate::override_store::OverrideStore::new()), blocklist: RwLock::new(crate::blocklist::BlocklistStore::new()), @@ -1029,6 +1030,7 @@ mod tests { upstream_port: 53, lan_ip: Mutex::new(std::net::Ipv4Addr::LOCALHOST), timeout: std::time::Duration::from_secs(3), + hedge_delay: std::time::Duration::ZERO, proxy_tld: "numa".to_string(), proxy_tld_suffix: ".numa".to_string(), lan_enabled: false, diff --git a/src/cache.rs b/src/cache.rs index 5bdde85..18fdc19 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -1,9 +1,26 @@ use std::collections::HashMap; use std::time::{Duration, Instant}; +use crate::buffer::BytePacketBuffer; use crate::packet::DnsPacket; use crate::question::QueryType; -use crate::record::DnsRecord; +use crate::wire::WireMeta; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum Freshness { + /// Within TTL, no action needed. + Fresh, + /// Within TTL but <10% remaining — trigger background prefetch. + NearExpiry, + /// Past TTL but within stale window — serve with TTL=1, trigger background refresh. + Stale, +} + +impl Freshness { + pub fn needs_refresh(self) -> bool { + matches!(self, Freshness::NearExpiry | Freshness::Stale) + } +} #[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] pub enum DnssecStatus { @@ -26,14 +43,16 @@ impl DnssecStatus { } struct CacheEntry { - packet: DnsPacket, + wire: Vec, + meta: WireMeta, inserted_at: Instant, ttl: Duration, dnssec_status: DnssecStatus, } -/// DNS cache using a two-level map (domain -> query_type -> entry) so that -/// lookups can borrow `&str` instead of allocating a `String` key. +const STALE_WINDOW: Duration = Duration::from_secs(3600); + +/// DNS cache with serve-stale (RFC 8767). Stores raw wire bytes. pub struct DnsCache { entries: HashMap>, entry_count: usize, @@ -53,33 +72,118 @@ impl DnsCache { } } + /// Look up cached wire bytes, patching ID and TTLs in the returned copy. + /// Implements serve-stale (RFC 8767): expired entries within STALE_WINDOW + /// are returned with TTL=1 and `stale=true` so callers can revalidate. + pub fn lookup_wire( + &self, + domain: &str, + qtype: QueryType, + new_id: u16, + ) -> Option<(Vec, DnssecStatus, Freshness)> { + let type_map = self.entries.get(domain)?; + let entry = type_map.get(&qtype)?; + + let elapsed = entry.inserted_at.elapsed(); + let (remaining, freshness) = if elapsed < entry.ttl { + let secs = (entry.ttl - elapsed).as_secs() as u32; + let f = if elapsed * 10 >= entry.ttl * 9 { + Freshness::NearExpiry + } else { + Freshness::Fresh + }; + (secs.max(1), f) + } else if elapsed < entry.ttl + STALE_WINDOW { + (1, Freshness::Stale) + } else { + return None; + }; + + let mut wire = entry.wire.clone(); + crate::wire::patch_id(&mut wire, new_id); + crate::wire::patch_ttls(&mut wire, &entry.meta.ttl_offsets, remaining); + + Some((wire, entry.dnssec_status, freshness)) + } + + pub fn insert_wire( + &mut self, + domain: &str, + qtype: QueryType, + wire: &[u8], + dnssec_status: DnssecStatus, + ) { + let meta = match crate::wire::scan_ttl_offsets(wire) { + Ok(m) => m, + Err(_) => return, // malformed wire, skip + }; + + if self.entry_count >= self.max_entries { + self.evict_expired(); + if self.entry_count >= self.max_entries { + self.evict_stalest(); + } + } + + let min_ttl = crate::wire::min_ttl_from_wire(wire, &meta) + .unwrap_or(self.min_ttl) + .clamp(self.min_ttl, self.max_ttl); + + let type_map = if let Some(existing) = self.entries.get_mut(domain) { + existing + } else { + self.entries.entry(domain.to_string()).or_default() + }; + + if !type_map.contains_key(&qtype) { + self.entry_count += 1; + } + + type_map.insert( + qtype, + CacheEntry { + wire: wire.to_vec(), + meta, + inserted_at: Instant::now(), + ttl: Duration::from_secs(min_ttl as u64), + dnssec_status, + }, + ); + } + /// Read-only lookup — expired entries are left in place (cleaned up on insert). pub fn lookup(&self, domain: &str, qtype: QueryType) -> Option { - self.lookup_with_status(domain, qtype).map(|(pkt, _)| pkt) + self.lookup_with_status(domain, qtype) + .map(|(pkt, _, _)| pkt) } pub fn lookup_with_status( &self, domain: &str, qtype: QueryType, - ) -> Option<(DnsPacket, DnssecStatus)> { - let type_map = self.entries.get(domain)?; - let entry = type_map.get(&qtype)?; + ) -> Option<(DnsPacket, DnssecStatus, Freshness)> { + let (wire, status, freshness) = self.lookup_wire(domain, qtype, 0)?; + let mut buf = BytePacketBuffer::from_bytes(&wire); + let pkt = DnsPacket::from_buffer(&mut buf).ok()?; + Some((pkt, status, freshness)) + } - let elapsed = entry.inserted_at.elapsed(); - if elapsed >= entry.ttl { - return None; + pub fn insert(&mut self, domain: &str, qtype: QueryType, packet: &DnsPacket) { + self.insert_with_status(domain, qtype, packet, DnssecStatus::Indeterminate); + } + + pub fn insert_with_status( + &mut self, + domain: &str, + qtype: QueryType, + packet: &DnsPacket, + dnssec_status: DnssecStatus, + ) { + let mut buf = BytePacketBuffer::new(); + if packet.write(&mut buf).is_err() { + return; } - - let remaining_secs = (entry.ttl - elapsed).as_secs() as u32; - let remaining = remaining_secs.max(1); - - let mut packet = entry.packet.clone(); - adjust_ttls(&mut packet.answers, remaining); - adjust_ttls(&mut packet.authorities, remaining); - adjust_ttls(&mut packet.resources, remaining); - - Some((packet, entry.dnssec_status)) + self.insert_wire(domain, qtype, buf.filled(), dnssec_status); } pub fn ttl_remaining(&self, domain: &str, qtype: QueryType) -> Option<(u32, u32)> { @@ -105,49 +209,6 @@ impl DnsCache { false } - pub fn insert(&mut self, domain: &str, qtype: QueryType, packet: &DnsPacket) { - self.insert_with_status(domain, qtype, packet, DnssecStatus::Indeterminate); - } - - pub fn insert_with_status( - &mut self, - domain: &str, - qtype: QueryType, - packet: &DnsPacket, - dnssec_status: DnssecStatus, - ) { - if self.entry_count >= self.max_entries { - self.evict_expired(); - if self.entry_count >= self.max_entries { - return; - } - } - - let min_ttl = extract_min_ttl(&packet.answers) - .unwrap_or(self.min_ttl) - .clamp(self.min_ttl, self.max_ttl); - - let type_map = if let Some(existing) = self.entries.get_mut(domain) { - existing - } else { - self.entries.entry(domain.to_string()).or_default() - }; - - if !type_map.contains_key(&qtype) { - self.entry_count += 1; - } - - type_map.insert( - qtype, - CacheEntry { - packet: packet.clone(), - inserted_at: Instant::now(), - ttl: Duration::from_secs(min_ttl as u64), - dnssec_status, - }, - ); - } - pub fn len(&self) -> usize { self.entry_count } @@ -179,7 +240,8 @@ impl DnsCache { + 1; total += type_map.capacity() * inner_slot; for entry in type_map.values() { - total += entry.packet.heap_bytes(); + total += entry.wire.capacity() + + entry.meta.ttl_offsets.capacity() * std::mem::size_of::(); } } total @@ -220,6 +282,34 @@ impl DnsCache { }); self.entry_count -= count; } + + /// Evict the single entry closest to (or furthest past) expiry. + fn evict_stalest(&mut self) { + let mut worst: Option<(String, QueryType, Duration)> = None; + for (domain, type_map) in &self.entries { + for (qtype, entry) in type_map { + let age = entry.inserted_at.elapsed(); + let remaining = entry.ttl.saturating_sub(age); + match &worst { + None => worst = Some((domain.clone(), *qtype, remaining)), + Some((_, _, w)) if remaining < *w => { + worst = Some((domain.clone(), *qtype, remaining)); + } + _ => {} + } + } + } + if let Some((domain, qtype, _)) = worst { + if let Some(type_map) = self.entries.get_mut(&domain) { + if type_map.remove(&qtype).is_some() { + self.entry_count -= 1; + } + if type_map.is_empty() { + self.entries.remove(&domain); + } + } + } + } } pub struct CacheInfo { @@ -228,20 +318,11 @@ pub struct CacheInfo { pub ttl_remaining: u32, } -fn extract_min_ttl(records: &[DnsRecord]) -> Option { - records.iter().map(|r| r.ttl()).min() -} - -fn adjust_ttls(records: &mut [DnsRecord], new_ttl: u32) { - for record in records.iter_mut() { - record.set_ttl(new_ttl); - } -} - #[cfg(test)] mod tests { use super::*; use crate::packet::DnsPacket; + use crate::record::DnsRecord; #[test] fn heap_bytes_grows_with_entries() { diff --git a/src/config.rs b/src/config.rs index ae9f685..237f3bd 100644 --- a/src/config.rs +++ b/src/config.rs @@ -138,6 +138,8 @@ pub struct UpstreamConfig { pub fallback: Vec, #[serde(default = "default_timeout_ms")] pub timeout_ms: u64, + #[serde(default = "default_hedge_ms")] + pub hedge_ms: u64, #[serde(default = "default_root_hints")] pub root_hints: Vec, #[serde(default = "default_prime_tlds")] @@ -154,6 +156,7 @@ impl Default for UpstreamConfig { port: default_upstream_port(), fallback: Vec::new(), timeout_ms: default_timeout_ms(), + hedge_ms: default_hedge_ms(), root_hints: default_root_hints(), prime_tlds: default_prime_tlds(), srtt: default_srtt(), @@ -271,6 +274,9 @@ fn default_upstream_port() -> u16 { fn default_timeout_ms() -> u64 { 5000 } +fn default_hedge_ms() -> u64 { + 10 +} #[derive(Deserialize)] pub struct CacheConfig { @@ -296,7 +302,7 @@ impl Default for CacheConfig { } fn default_max_entries() -> usize { - 10000 + 100_000 } fn default_min_ttl() -> u32 { 60 diff --git a/src/ctx.rs b/src/ctx.rs index 3ef6a0a..e97a7ea 100644 --- a/src/ctx.rs +++ b/src/ctx.rs @@ -1,7 +1,7 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::net::SocketAddr; use std::path::PathBuf; -use std::sync::{Mutex, RwLock}; +use std::sync::{Arc, Mutex, RwLock}; use std::time::{Duration, Instant, SystemTime}; use arc_swap::ArcSwap; @@ -16,7 +16,7 @@ use crate::blocklist::BlocklistStore; use crate::buffer::BytePacketBuffer; use crate::cache::{DnsCache, DnssecStatus}; use crate::config::{UpstreamMode, ZoneMap}; -use crate::forward::{forward_query, forward_with_failover, Upstream, UpstreamPool}; +use crate::forward::{forward_query_raw, forward_with_failover_raw, Upstream, UpstreamPool}; use crate::header::ResultCode; use crate::health::HealthMeta; use crate::lan::PeerStore; @@ -35,6 +35,8 @@ pub struct ServerCtx { pub zone_map: ZoneMap, /// std::sync::RwLock (not tokio) — locks must never be held across .await points. pub cache: RwLock, + /// Domains currently being refreshed in the background (dedup guard). + pub refreshing: Mutex>, pub stats: Mutex, pub overrides: RwLock, pub blocklist: RwLock, @@ -47,6 +49,7 @@ pub struct ServerCtx { pub upstream_port: u16, pub lan_ip: Mutex, pub timeout: Duration, + pub hedge_delay: Duration, pub proxy_tld: String, pub proxy_tld_suffix: String, // pre-computed ".{tld}" to avoid per-query allocation pub lan_enabled: bool, @@ -81,8 +84,9 @@ pub struct ServerCtx { /// (and logging parse errors) before calling this function. pub async fn resolve_query( query: DnsPacket, + raw_wire: &[u8], src_addr: SocketAddr, - ctx: &ServerCtx, + ctx: &Arc, ) -> crate::Result { let start = Instant::now(); @@ -164,7 +168,18 @@ pub async fn resolve_query( (resp, QueryPath::Blocked, DnssecStatus::Indeterminate) } else { let cached = ctx.cache.read().unwrap().lookup_with_status(&qname, qtype); - if let Some((cached, cached_dnssec)) = cached { + if let Some((cached, cached_dnssec, freshness)) = cached { + if freshness.needs_refresh() { + let key = (qname.clone(), qtype); + let already = !ctx.refreshing.lock().unwrap().insert(key.clone()); + if !already { + let ctx = Arc::clone(ctx); + tokio::spawn(async move { + refresh_entry(&ctx, &key.0, key.1).await; + ctx.refreshing.lock().unwrap().remove(&key); + }); + } + } let mut resp = cached; resp.header.id = query.header.id; if cached_dnssec == DnssecStatus::Secure { @@ -177,11 +192,8 @@ pub async fn resolve_query( // Conditional forwarding takes priority over recursive mode // (e.g. Tailscale .ts.net, VPC private zones) let upstream = Upstream::Udp(fwd_addr); - match forward_query(&query, &upstream, ctx.timeout).await { - Ok(resp) => { - ctx.cache.write().unwrap().insert(&qname, qtype, &resp); - (resp, QueryPath::Forwarded, DnssecStatus::Indeterminate) - } + match forward_and_cache(raw_wire, &upstream, ctx, &qname, qtype).await { + Ok(resp) => (resp, QueryPath::Forwarded, DnssecStatus::Indeterminate), Err(e) => { error!( "{} | {:?} {} | FORWARD ERROR | {}", @@ -221,11 +233,26 @@ pub async fn resolve_query( (resp, path, DnssecStatus::Indeterminate) } else { let pool = ctx.upstream_pool.lock().unwrap().clone(); - match forward_with_failover(&query, &pool, &ctx.srtt, ctx.timeout).await { - Ok(resp) => { - ctx.cache.write().unwrap().insert(&qname, qtype, &resp); - (resp, QueryPath::Forwarded, DnssecStatus::Indeterminate) - } + match forward_with_failover_raw( + raw_wire, + &pool, + &ctx.srtt, + ctx.timeout, + ctx.hedge_delay, + ) + .await + { + Ok(resp_wire) => match cache_and_parse(ctx, &qname, qtype, &resp_wire) { + Ok(resp) => (resp, QueryPath::Forwarded, DnssecStatus::Indeterminate), + Err(e) => { + error!("{} | {:?} {} | PARSE ERROR | {}", src_addr, qtype, qname, e); + ( + DnsPacket::response_from(&query, ResultCode::SERVFAIL), + QueryPath::UpstreamError, + DnssecStatus::Indeterminate, + ) + } + }, Err(e) => { error!( "{} | {:?} {} | UPSTREAM ERROR | {}", @@ -347,11 +374,77 @@ pub async fn resolve_query( Ok(resp_buffer) } -/// Handle a DNS query received over UDP. Thin wrapper around resolve_query. +fn cache_and_parse( + ctx: &ServerCtx, + qname: &str, + qtype: QueryType, + resp_wire: &[u8], +) -> crate::Result { + ctx.cache + .write() + .unwrap() + .insert_wire(qname, qtype, resp_wire, DnssecStatus::Indeterminate); + let mut buf = BytePacketBuffer::from_bytes(resp_wire); + DnsPacket::from_buffer(&mut buf) +} + +/// Re-resolve a single (domain, qtype) and update the cache. +/// Used for both stale-entry refresh and proactive cache warming. +pub async fn refresh_entry(ctx: &ServerCtx, qname: &str, qtype: QueryType) { + let query = DnsPacket::query(0, qname, qtype); + if ctx.upstream_mode == UpstreamMode::Recursive { + if let Ok(resp) = crate::recursive::resolve_recursive( + qname, + qtype, + &ctx.cache, + &query, + &ctx.root_hints, + &ctx.srtt, + ) + .await + { + ctx.cache.write().unwrap().insert(qname, qtype, &resp); + } + } else { + let mut buf = BytePacketBuffer::new(); + if query.write(&mut buf).is_ok() { + let pool = ctx.upstream_pool.lock().unwrap().clone(); + if let Ok(wire) = forward_with_failover_raw( + buf.filled(), + &pool, + &ctx.srtt, + ctx.timeout, + ctx.hedge_delay, + ) + .await + { + ctx.cache.write().unwrap().insert_wire( + qname, + qtype, + &wire, + DnssecStatus::Indeterminate, + ); + } + } + } +} + +async fn forward_and_cache( + wire: &[u8], + upstream: &Upstream, + ctx: &ServerCtx, + qname: &str, + qtype: QueryType, +) -> crate::Result { + let resp_wire = forward_query_raw(wire, upstream, ctx.timeout).await?; + cache_and_parse(ctx, qname, qtype, &resp_wire) +} + pub async fn handle_query( mut buffer: BytePacketBuffer, + raw_len: usize, src_addr: SocketAddr, - ctx: &ServerCtx, + ctx: &Arc, ) -> crate::Result<()> { let query = match DnsPacket::from_buffer(&mut buffer) { Ok(packet) => packet, @@ -360,7 +453,7 @@ pub async fn handle_query( return Ok(()); } }; - match resolve_query(query, src_addr, ctx).await { + match resolve_query(query, &buffer.buf[..raw_len], src_addr, ctx).await { Ok(resp_buffer) => { ctx.socket.send_to(resp_buffer.filled(), src_addr).await?; } diff --git a/src/doh.rs b/src/doh.rs index cf50b31..bc4ba95 100644 --- a/src/doh.rs +++ b/src/doh.rs @@ -60,7 +60,11 @@ fn is_doh_host(host: Option<&str>, tld: &str) -> bool { } } -async fn resolve_doh(dns_bytes: &[u8], src: SocketAddr, ctx: &ServerCtx) -> Response { +async fn resolve_doh( + dns_bytes: &[u8], + src: SocketAddr, + ctx: &std::sync::Arc, +) -> Response { let mut buffer = BytePacketBuffer::from_bytes(dns_bytes); let query = match DnsPacket::from_buffer(&mut buffer) { Ok(q) => q, @@ -82,7 +86,7 @@ async fn resolve_doh(dns_bytes: &[u8], src: SocketAddr, ctx: &ServerCtx) -> Resp let query_rd = query.header.recursion_desired; let questions = query.questions.clone(); - match resolve_query(query, src, ctx).await { + match resolve_query(query, dns_bytes, src, ctx).await { Ok(resp_buffer) => { let min_ttl = extract_min_ttl(resp_buffer.filled()); dns_response(resp_buffer.filled(), min_ttl) @@ -102,11 +106,10 @@ async fn resolve_doh(dns_bytes: &[u8], src: SocketAddr, ctx: &ServerCtx) -> Resp } fn extract_min_ttl(wire: &[u8]) -> u32 { - let mut buf = BytePacketBuffer::from_bytes(wire); - match DnsPacket::from_buffer(&mut buf) { - Ok(pkt) => pkt.answers.iter().map(|r| r.ttl()).min().unwrap_or(0), - Err(_) => 0, - } + crate::wire::scan_ttl_offsets(wire) + .ok() + .and_then(|meta| crate::wire::min_ttl_from_wire(wire, &meta)) + .unwrap_or(0) } fn dns_response(wire: &[u8], min_ttl: u32) -> Response { diff --git a/src/dot.rs b/src/dot.rs index 0d48fa2..d4eeb95 100644 --- a/src/dot.rs +++ b/src/dot.rs @@ -153,8 +153,11 @@ async fn accept_loop(listener: TcpListener, acceptor: TlsAcceptor, ctx: Arc(mut stream: S, remote_addr: SocketAddr, ctx: &ServerCtx) -where +async fn handle_dot_connection( + mut stream: S, + remote_addr: SocketAddr, + ctx: &std::sync::Arc, +) where S: AsyncReadExt + AsyncWriteExt + Unpin, { loop { @@ -177,8 +180,6 @@ where break; }; - // Parse query up-front so we can echo its question section in SERVFAIL - // responses when resolve_query fails. let query = match DnsPacket::from_buffer(&mut buffer) { Ok(q) => q, Err(e) => { @@ -200,7 +201,7 @@ where } }; - match resolve_query(query.clone(), remote_addr, ctx).await { + match resolve_query(query.clone(), &buffer.buf[..msg_len], remote_addr, ctx).await { Ok(resp_buffer) => { if write_framed(&mut stream, resp_buffer.filled()) .await @@ -355,6 +356,7 @@ mod tests { m }, cache: RwLock::new(crate::cache::DnsCache::new(100, 60, 86400)), + refreshing: Mutex::new(std::collections::HashSet::new()), stats: Mutex::new(crate::stats::ServerStats::new()), overrides: RwLock::new(crate::override_store::OverrideStore::new()), blocklist: RwLock::new(crate::blocklist::BlocklistStore::new()), @@ -370,6 +372,7 @@ mod tests { upstream_port: 53, lan_ip: Mutex::new(std::net::Ipv4Addr::LOCALHOST), timeout: Duration::from_millis(200), + hedge_delay: Duration::ZERO, proxy_tld: "numa".to_string(), proxy_tld_suffix: ".numa".to_string(), lan_enabled: false, diff --git a/src/forward.rs b/src/forward.rs index ea2f1e2..e13e360 100644 --- a/src/forward.rs +++ b/src/forward.rs @@ -65,6 +65,13 @@ pub fn parse_upstream(s: &str, default_port: u16) -> Result { if s.starts_with("https://") { let client = reqwest::Client::builder() .use_rustls_tls() + .http2_initial_stream_window_size(65_535) + .http2_initial_connection_window_size(65_535) + .http2_keep_alive_interval(Duration::from_secs(15)) + .http2_keep_alive_while_idle(true) + .http2_keep_alive_timeout(Duration::from_secs(10)) + .pool_idle_timeout(Duration::from_secs(300)) + .pool_max_idle_per_host(1) .build() .unwrap_or_default(); return Ok(Upstream::Doh { @@ -150,72 +157,16 @@ impl UpstreamPool { } } -pub async fn forward_with_failover( - query: &DnsPacket, - pool: &UpstreamPool, - srtt: &RwLock, - timeout_duration: Duration, -) -> Result { - // Build candidate list: primary (sorted by SRTT for UDP) then fallback - let mut candidates: Vec<(usize, u64)> = pool - .primary - .iter() - .enumerate() - .map(|(i, u)| { - let rtt = match u { - Upstream::Udp(addr) => srtt.read().unwrap().get(addr.ip()), - _ => 0, // DoH: keep config order (stable sort preserves it) - }; - (i, rtt) - }) - .collect(); - candidates.sort_by_key(|&(_, rtt)| rtt); - - let all_upstreams: Vec<&Upstream> = candidates - .iter() - .map(|&(i, _)| &pool.primary[i]) - .chain(pool.fallback.iter()) - .collect(); - - let mut last_err: Option> = None; - - for upstream in &all_upstreams { - let start = Instant::now(); - match forward_query(query, upstream, timeout_duration).await { - Ok(resp) => { - if let Upstream::Udp(addr) = upstream { - let rtt_ms = start.elapsed().as_millis() as u64; - srtt.write().unwrap().record_rtt(addr.ip(), rtt_ms, false); - } - return Ok(resp); - } - Err(e) => { - if let Upstream::Udp(addr) = upstream { - srtt.write().unwrap().record_failure(addr.ip()); - } - log::debug!("upstream {} failed: {}", upstream, e); - last_err = Some(e); - } - } - } - - Err(last_err.unwrap_or_else(|| "no upstream configured".into())) -} - pub async fn forward_query( query: &DnsPacket, upstream: &Upstream, timeout_duration: Duration, ) -> Result { - match upstream { - Upstream::Udp(addr) => forward_udp(query, *addr, timeout_duration).await, - Upstream::Doh { url, client } => forward_doh(query, url, client, timeout_duration).await, - Upstream::Dot { - addr, - tls_name, - connector, - } => forward_dot(query, *addr, tls_name, connector, timeout_duration).await, - } + let mut send_buffer = BytePacketBuffer::new(); + query.write(&mut send_buffer)?; + let data = forward_query_raw(send_buffer.filled(), upstream, timeout_duration).await?; + let mut recv_buffer = BytePacketBuffer::from_bytes(&data); + DnsPacket::from_buffer(&mut recv_buffer) } pub(crate) async fn forward_udp( @@ -223,24 +174,10 @@ pub(crate) async fn forward_udp( upstream: SocketAddr, timeout_duration: Duration, ) -> Result { - let socket = UdpSocket::bind("0.0.0.0:0").await?; - let mut send_buffer = BytePacketBuffer::new(); query.write(&mut send_buffer)?; - - socket.send_to(send_buffer.filled(), upstream).await?; - - let mut recv_buffer = BytePacketBuffer::new(); - let (size, _) = timeout(timeout_duration, socket.recv_from(&mut recv_buffer.buf)).await??; - - if size == recv_buffer.buf.len() { - log::debug!( - "upstream response truncated ({} bytes, buffer {})", - size, - recv_buffer.buf.len() - ); - } - + let data = forward_udp_raw(send_buffer.filled(), upstream, timeout_duration).await?; + let mut recv_buffer = BytePacketBuffer::from_bytes(&data); DnsPacket::from_buffer(&mut recv_buffer) } @@ -277,13 +214,13 @@ pub(crate) async fn forward_tcp( DnsPacket::from_buffer(&mut recv_buffer) } -async fn forward_dot( - query: &DnsPacket, +async fn forward_dot_raw( + wire: &[u8], addr: SocketAddr, tls_name: &Option, connector: &tokio_rustls::TlsConnector, timeout_duration: Duration, -) -> Result { +) -> Result> { use rustls::pki_types::ServerName; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; @@ -296,10 +233,6 @@ async fn forward_dot( let tcp = timeout(timeout_duration, TcpStream::connect(addr)).await??; let mut tls = timeout(timeout_duration, connector.connect(server_name, tcp)).await??; - let mut send_buffer = BytePacketBuffer::new(); - query.write(&mut send_buffer)?; - let wire = send_buffer.filled(); - let mut outbuf = Vec::with_capacity(2 + wire.len()); outbuf.extend_from_slice(&(wire.len() as u16).to_be_bytes()); outbuf.extend_from_slice(wire); @@ -312,26 +245,178 @@ async fn forward_dot( let mut data = vec![0u8; resp_len]; timeout(timeout_duration, tls.read_exact(&mut data)).await??; - let mut recv_buffer = BytePacketBuffer::from_bytes(&data); - DnsPacket::from_buffer(&mut recv_buffer) + Ok(data) } -async fn forward_doh( - query: &DnsPacket, +pub async fn forward_query_raw( + wire: &[u8], + upstream: &Upstream, + timeout_duration: Duration, +) -> Result> { + match upstream { + Upstream::Udp(addr) => forward_udp_raw(wire, *addr, timeout_duration).await, + Upstream::Doh { url, client } => forward_doh_raw(wire, url, client, timeout_duration).await, + Upstream::Dot { + addr, + tls_name, + connector, + } => forward_dot_raw(wire, *addr, tls_name, connector, timeout_duration).await, + } +} + +pub async fn forward_with_hedging_raw( + wire: &[u8], + primary: &Upstream, + secondary: &Upstream, + hedge_delay: Duration, + timeout_duration: Duration, +) -> Result> { + use tokio::time::sleep; + + let primary_fut = forward_query_raw(wire, primary, timeout_duration); + tokio::pin!(primary_fut); + + let delay = sleep(hedge_delay); + tokio::pin!(delay); + + // Phase 1: wait for either primary to return, or the hedge delay. + tokio::select! { + result = &mut primary_fut => return result, + _ = &mut delay => {} + } + + // Phase 2: hedge delay expired — fire secondary while still polling primary. + let secondary_fut = forward_query_raw(wire, secondary, timeout_duration); + tokio::pin!(secondary_fut); + + // First successful response wins. If one errors, wait for the other. + let mut primary_err: Option = None; + let mut secondary_err: Option = None; + + loop { + tokio::select! { + r = &mut primary_fut, if primary_err.is_none() => { + match r { + Ok(resp) => return Ok(resp), + Err(e) => { + if let Some(se) = secondary_err.take() { + return Err(se); + } + primary_err = Some(e); + } + } + } + r = &mut secondary_fut, if secondary_err.is_none() => { + match r { + Ok(resp) => return Ok(resp), + Err(e) => { + if let Some(pe) = primary_err.take() { + return Err(pe); + } + secondary_err = Some(e); + } + } + } + } + + match (primary_err, secondary_err) { + (Some(pe), Some(_)) => return Err(pe), + (pe, se) => { + primary_err = pe; + secondary_err = se; + } + } + } +} + +pub async fn forward_with_failover_raw( + wire: &[u8], + pool: &UpstreamPool, + srtt: &RwLock, + timeout_duration: Duration, + hedge_delay: Duration, +) -> Result> { + let mut candidates: Vec<(usize, u64)> = pool + .primary + .iter() + .enumerate() + .map(|(i, u)| { + let rtt = match u { + Upstream::Udp(addr) => srtt.read().unwrap().get(addr.ip()), + _ => 0, + }; + (i, rtt) + }) + .collect(); + candidates.sort_by_key(|&(_, rtt)| rtt); + + let all_upstreams: Vec<&Upstream> = candidates + .iter() + .map(|&(i, _)| &pool.primary[i]) + .chain(pool.fallback.iter()) + .collect(); + + let mut last_err: Option> = None; + + for upstream in &all_upstreams { + let start = Instant::now(); + let result = if !hedge_delay.is_zero() { + // Hedge against the same upstream: independent h2 streams (DoH), + // independent UDP packets (plain DNS), or independent TLS + // connections (DoT). Rescues packet loss, dispatch spikes, and + // TLS handshake stalls. + forward_with_hedging_raw(wire, upstream, upstream, hedge_delay, timeout_duration).await + } else { + forward_query_raw(wire, upstream, timeout_duration).await + }; + match result { + Ok(resp) => { + if let Upstream::Udp(addr) = upstream { + let rtt_ms = start.elapsed().as_millis() as u64; + srtt.write().unwrap().record_rtt(addr.ip(), rtt_ms, false); + } + return Ok(resp); + } + Err(e) => { + if let Upstream::Udp(addr) = upstream { + srtt.write().unwrap().record_failure(addr.ip()); + } + log::debug!("upstream {} failed: {}", upstream, e); + last_err = Some(e); + } + } + } + + Err(last_err.unwrap_or_else(|| "no upstream configured".into())) +} + +async fn forward_udp_raw( + wire: &[u8], + upstream: SocketAddr, + timeout_duration: Duration, +) -> Result> { + let socket = UdpSocket::bind("0.0.0.0:0").await?; + socket.send_to(wire, upstream).await?; + + let mut recv_buf = vec![0u8; 4096]; + let (size, _) = timeout(timeout_duration, socket.recv_from(&mut recv_buf)).await??; + recv_buf.truncate(size); + Ok(recv_buf) +} + +async fn forward_doh_raw( + wire: &[u8], url: &str, client: &reqwest::Client, timeout_duration: Duration, -) -> Result { - let mut send_buffer = BytePacketBuffer::new(); - query.write(&mut send_buffer)?; - +) -> Result> { let resp = timeout( timeout_duration, client .post(url) .header("content-type", "application/dns-message") .header("accept", "application/dns-message") - .body(send_buffer.filled().to_vec()) + .body(wire.to_vec()) .send(), ) .await?? @@ -339,9 +424,25 @@ async fn forward_doh( let bytes = resp.bytes().await?; log::debug!("DoH response: {} bytes", bytes.len()); + Ok(bytes.to_vec()) +} - let mut recv_buffer = BytePacketBuffer::from_bytes(&bytes); - DnsPacket::from_buffer(&mut recv_buffer) +/// Send a lightweight keepalive query to a DoH upstream to prevent +/// the HTTP/2 + TLS connection from going idle and being torn down. +pub async fn keepalive_doh(upstream: &Upstream) { + if let Upstream::Doh { url, client } = upstream { + // Query for . NS — minimal, always succeeds, response is small + let wire: &[u8] = &[ + 0x00, 0x00, // ID + 0x01, 0x00, // flags: RD=1 + 0x00, 0x01, // QDCOUNT=1 + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // AN=0, NS=0, AR=0 + 0x00, // root name (.) + 0x00, 0x02, // type NS + 0x00, 0x01, // class IN + ]; + let _ = forward_doh_raw(wire, url, client, Duration::from_secs(5)).await; + } } #[cfg(test)] @@ -556,10 +657,19 @@ mod tests { ); let srtt = RwLock::new(SrttCache::new(true)); - let result = forward_with_failover(&query, &pool, &srtt, Duration::from_millis(500)) - .await - .expect("should fail over to second upstream"); + let wire = to_wire(&query); + let resp_wire = forward_with_failover_raw( + &wire, + &pool, + &srtt, + Duration::from_millis(500), + Duration::ZERO, + ) + .await + .expect("should fail over to second upstream"); + let mut buf = BytePacketBuffer::from_bytes(&resp_wire); + let result = DnsPacket::from_buffer(&mut buf).unwrap(); assert_eq!(result.header.id, 0xABCD); assert_eq!(result.answers.len(), 1); } diff --git a/src/lib.rs b/src/lib.rs index 4074020..92a0b00 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -26,6 +26,7 @@ pub mod srtt; pub mod stats; pub mod system_dns; pub mod tls; +pub mod wire; pub type Error = Box; pub type Result = std::result::Result; diff --git a/src/main.rs b/src/main.rs index 7592186..1ec7791 100644 --- a/src/main.rs +++ b/src/main.rs @@ -285,6 +285,7 @@ async fn main() -> numa::Result<()> { config.cache.min_ttl, config.cache.max_ttl, )), + refreshing: Mutex::new(std::collections::HashSet::new()), stats: Mutex::new(ServerStats::new()), overrides: RwLock::new(OverrideStore::new()), blocklist: RwLock::new(blocklist), @@ -297,6 +298,7 @@ async fn main() -> numa::Result<()> { upstream_port: config.upstream.port, lan_ip: Mutex::new(numa::lan::detect_lan_ip().unwrap_or(std::net::Ipv4Addr::LOCALHOST)), timeout: Duration::from_millis(config.upstream.timeout_ms), + hedge_delay: Duration::from_millis(config.upstream.hedge_ms), proxy_tld_suffix: if config.proxy.tld.is_empty() { String::new() } else { @@ -511,6 +513,14 @@ async fn main() -> numa::Result<()> { }); } + // Spawn DoH connection keepalive — prevents idle TLS teardown + { + let keepalive_ctx = Arc::clone(&ctx); + tokio::spawn(async move { + doh_keepalive_loop(keepalive_ctx).await; + }); + } + // Spawn HTTP API server let api_ctx = Arc::clone(&ctx); let api_addr: SocketAddr = format!("{}:{}", config.server.api_bind_addr, api_port).parse()?; @@ -590,7 +600,7 @@ async fn main() -> numa::Result<()> { #[allow(clippy::infinite_loop)] loop { let mut buffer = BytePacketBuffer::new(); - let (_, src_addr) = match ctx.socket.recv_from(&mut buffer.buf).await { + let (len, src_addr) = match ctx.socket.recv_from(&mut buffer.buf).await { Ok(r) => r, Err(e) if e.kind() == std::io::ErrorKind::ConnectionReset => { // Windows delivers ICMP port-unreachable as ConnectionReset on UDP sockets @@ -598,10 +608,9 @@ async fn main() -> numa::Result<()> { } Err(e) => return Err(e.into()), }; - let ctx = Arc::clone(&ctx); tokio::spawn(async move { - if let Err(e) = handle_query(buffer, src_addr, &ctx).await { + if let Err(e) = handle_query(buffer, len, src_addr, &ctx).await { error!("{} | HANDLER ERROR | {}", src_addr, e); } }); @@ -749,30 +758,22 @@ async fn load_blocklists(ctx: &ServerCtx, lists: &[String]) { } async fn warm_domain(ctx: &ServerCtx, domain: &str) { - use numa::question::QueryType; + for qtype in [ + numa::question::QueryType::A, + numa::question::QueryType::AAAA, + ] { + numa::ctx::refresh_entry(ctx, domain, qtype).await; + } +} - for qtype in [QueryType::A, QueryType::AAAA] { - let query = numa::packet::DnsPacket::query(0, domain, qtype); - let result = if ctx.upstream_mode == numa::config::UpstreamMode::Recursive { - numa::recursive::resolve_recursive( - domain, - qtype, - &ctx.cache, - &query, - &ctx.root_hints, - &ctx.srtt, - ) - .await - } else { - let pool = ctx.upstream_pool.lock().unwrap().clone(); - numa::forward::forward_with_failover(&query, &pool, &ctx.srtt, ctx.timeout).await - }; - match result { - Ok(resp) => { - ctx.cache.write().unwrap().insert(domain, qtype, &resp); - log::debug!("cache warm: {} {:?}", domain, qtype); - } - Err(e) => log::warn!("cache warm: {} {:?} failed: {}", domain, qtype, e), +async fn doh_keepalive_loop(ctx: Arc) { + let mut interval = tokio::time::interval(Duration::from_secs(25)); + interval.tick().await; // skip first immediate tick + loop { + interval.tick().await; + let pool = ctx.upstream_pool.lock().unwrap().clone(); + if let Some(upstream) = pool.preferred() { + numa::forward::keepalive_doh(upstream).await; } } } diff --git a/src/recursive.rs b/src/recursive.rs index 24d0367..a4dff08 100644 --- a/src/recursive.rs +++ b/src/recursive.rs @@ -15,8 +15,8 @@ use crate::srtt::SrttCache; const MAX_REFERRAL_DEPTH: u8 = 10; const MAX_CNAME_DEPTH: u8 = 8; -const NS_QUERY_TIMEOUT: Duration = Duration::from_millis(800); -const TCP_TIMEOUT: Duration = Duration::from_millis(1500); +const NS_QUERY_TIMEOUT: Duration = Duration::from_millis(400); +const TCP_TIMEOUT: Duration = Duration::from_millis(400); const UDP_FAIL_THRESHOLD: u8 = 3; static QUERY_ID: AtomicU16 = AtomicU16::new(1); @@ -202,23 +202,24 @@ pub(crate) fn resolve_iterative<'a>( let mut ns_idx = 0; for _ in 0..MAX_REFERRAL_DEPTH { - let ns_addr = match ns_addrs.get(ns_idx) { - Some(addr) => *addr, - None => return Err("no nameserver available".into()), - }; + if ns_idx >= ns_addrs.len() { + return Err("no nameserver available".into()); + } let (q_name, q_type) = minimize_query(qname, qtype, ¤t_zone); debug!( - "recursive: querying {} for {:?} {} (zone: {}, depth {})", - ns_addr, q_type, q_name, current_zone, referral_depth + "recursive: querying {} (+ hedge) for {:?} {} (zone: {}, depth {})", + ns_addrs[ns_idx], q_type, q_name, current_zone, referral_depth ); - let response = match send_query(q_name, q_type, ns_addr, srtt).await { + let response = match send_query_hedged(q_name, q_type, &ns_addrs[ns_idx..], srtt).await + { Ok(r) => r, Err(e) => { - debug!("recursive: NS {} failed: {}", ns_addr, e); - ns_idx += 1; + debug!("recursive: NS query failed: {}", e); + let remaining = ns_addrs.len().saturating_sub(ns_idx); + ns_idx += remaining.min(2); continue; } }; @@ -228,6 +229,9 @@ pub(crate) fn resolve_iterative<'a>( { if let Some(zone) = referral_zone(&response) { current_zone = zone; + let mut cache_w = cache.write().unwrap(); + cache_ns_delegation(&mut cache_w, ¤t_zone, &response); + drop(cache_w); } let mut all_ns = extract_ns_from_records(&response.answers); if all_ns.is_empty() { @@ -296,6 +300,7 @@ pub(crate) fn resolve_iterative<'a>( { let mut cache_w = cache.write().unwrap(); + cache_ns_delegation(&mut cache_w, ¤t_zone, &response); cache_ds_from_authority(&mut cache_w, &response); } let mut new_ns_addrs = resolve_ns_addrs_from_glue(&response, &ns_names, cache); @@ -560,6 +565,23 @@ fn cache_ds_from_authority(cache: &mut DnsCache, response: &DnsPacket) { } } +/// Cache NS delegation records from a referral response so that +/// `find_closest_ns` can skip re-querying TLD servers on subsequent lookups. +fn cache_ns_delegation(cache: &mut DnsCache, zone: &str, response: &DnsPacket) { + let ns_records: Vec<_> = response + .authorities + .iter() + .filter(|r| matches!(r, DnsRecord::NS { .. })) + .cloned() + .collect(); + if ns_records.is_empty() { + return; + } + let mut pkt = make_glue_packet(); + pkt.answers = ns_records; + cache.insert(zone, QueryType::NS, &pkt); +} + fn make_glue_packet() -> DnsPacket { let mut pkt = DnsPacket::new(); pkt.header.response = true; @@ -587,6 +609,115 @@ async fn tcp_with_srtt( } } +/// Smart NS query: fire to two servers simultaneously when SRTT is unknown +/// (cold queries), or to the best server with SRTT-based hedge when known. +async fn send_query_hedged( + qname: &str, + qtype: QueryType, + servers: &[SocketAddr], + srtt: &RwLock, +) -> crate::Result { + if servers.is_empty() { + return Err("no nameserver available".into()); + } + if servers.len() == 1 { + return send_query(qname, qtype, servers[0], srtt).await; + } + + let primary = servers[0]; + let secondary = servers[1]; + let primary_known = srtt.read().unwrap().is_known(primary.ip()); + + if !primary_known { + // Cold: fire both simultaneously, first response wins + debug!( + "recursive: parallel query to {} and {} for {:?} {}", + primary, secondary, qtype, qname + ); + let fut_a = send_query(qname, qtype, primary, srtt); + let fut_b = send_query(qname, qtype, secondary, srtt); + tokio::pin!(fut_a); + tokio::pin!(fut_b); + + // First Ok wins. If one errors, wait for the other. + let mut a_done = false; + let mut b_done = false; + let mut a_err: Option = None; + let mut b_err: Option = None; + + loop { + tokio::select! { + r = &mut fut_a, if !a_done => { + match r { + Ok(resp) => return Ok(resp), + Err(e) => { a_done = true; a_err = Some(e); } + } + } + r = &mut fut_b, if !b_done => { + match r { + Ok(resp) => return Ok(resp), + Err(e) => { b_done = true; b_err = Some(e); } + } + } + } + match (a_err.take(), b_err.take()) { + (Some(e), Some(_)) => return Err(e), + (a, b) => { + a_err = a; + b_err = b; + } + } + } + } else { + // Warm: send to best, hedge after SRTT × 3 if slow + let hedge_ms = srtt.read().unwrap().get(primary.ip()) * 3; + let hedge_delay = Duration::from_millis(hedge_ms.max(50)); + + let fut_a = send_query(qname, qtype, primary, srtt); + tokio::pin!(fut_a); + let delay = tokio::time::sleep(hedge_delay); + tokio::pin!(delay); + + tokio::select! { + r = &mut fut_a => return r, + _ = &mut delay => {} + } + + debug!( + "recursive: hedging {} -> {} after {}ms for {:?} {}", + primary, secondary, hedge_ms, qtype, qname + ); + let fut_b = send_query(qname, qtype, secondary, srtt); + tokio::pin!(fut_b); + + // First Ok wins; if one errors, wait for the other. + let mut a_err: Option = None; + let mut b_err: Option = None; + loop { + tokio::select! { + r = &mut fut_a, if a_err.is_none() => { + match r { + Ok(resp) => return Ok(resp), + Err(e) => { + if b_err.is_some() { return Err(e); } + a_err = Some(e); + } + } + } + r = &mut fut_b, if b_err.is_none() => { + match r { + Ok(resp) => return Ok(resp), + Err(e) => { + if let Some(ae) = a_err.take() { return Err(ae); } + b_err = Some(e); + } + } + } + } + } + } +} + async fn send_query( qname: &str, qtype: QueryType, @@ -634,9 +765,13 @@ async fn send_query( "send_query: {} consecutive UDP failures — switching to TCP-first", fails ); + // Now that UDP is disabled, retry this query via TCP + return tcp_with_srtt(&query, server, srtt, start).await; } - debug!("send_query: UDP failed for {}: {}, trying TCP", server, e); - tcp_with_srtt(&query, server, srtt, start).await + // UDP works in general (priming succeeded) but this server timed out. + // Don't waste another 400ms on TCP — the server is unreachable. + srtt.write().unwrap().record_failure(server.ip()); + Err(e) } } } @@ -678,6 +813,10 @@ mod tests { use super::*; use std::net::{Ipv4Addr, Ipv6Addr}; + /// Tests that mutate the global UDP_DISABLED / UDP_FAILURES flags must hold + /// this lock to avoid racing with each other under `cargo test` parallelism. + static UDP_STATE_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(()); + #[test] fn extract_ns_from_authority() { let mut pkt = DnsPacket::new(); @@ -916,10 +1055,11 @@ mod tests { } /// TCP-only server returns authoritative answer directly. - /// Verifies: UDP fails → TCP fallback → resolves. + /// Verifies: when UDP is disabled, TCP-first resolves. #[tokio::test] async fn tcp_fallback_resolves_when_udp_blocked() { - UDP_DISABLED.store(false, Ordering::Relaxed); + let _guard = UDP_STATE_LOCK.lock().unwrap(); + UDP_DISABLED.store(true, Ordering::Relaxed); UDP_FAILURES.store(0, Ordering::Release); let server_addr = spawn_tcp_dns_server(|query| { @@ -950,49 +1090,32 @@ mod tests { } } - /// Full iterative resolution through TCP-only mock: root referral → authoritative answer. - /// The mock plays both roles (returns referral for NS queries, answer for A queries). + /// TCP round-trip through mock: query → authoritative answer via forward_tcp. + /// Uses forward_tcp directly to avoid dependence on the global UDP_DISABLED flag + /// which is shared across concurrent tests. #[tokio::test] async fn tcp_only_iterative_resolution() { - UDP_DISABLED.store(true, Ordering::Release); // Skip UDP entirely for speed - let server_addr = spawn_tcp_dns_server(|query| { let q = match query.questions.first() { Some(q) => q, None => return DnsPacket::response_from(query, ResultCode::SERVFAIL), }; - if q.qtype == QueryType::NS || q.name == "com" { - // Return referral — NS points back to ourselves (same IP, port 53 in glue - // won't work, but cache will have our address from root_hints) - let mut resp = DnsPacket::new(); - resp.header.id = query.header.id; - resp.header.response = true; - resp.header.rescode = ResultCode::NOERROR; - resp.questions = query.questions.clone(); - resp.authorities.push(DnsRecord::NS { - domain: "com".into(), - host: "ns1.com".into(), - ttl: 3600, - }); - resp - } else { - // Return authoritative answer - let mut resp = DnsPacket::response_from(query, ResultCode::NOERROR); - resp.header.authoritative_answer = true; - resp.answers.push(DnsRecord::A { - domain: q.name.clone(), - addr: Ipv4Addr::new(10, 0, 0, 42), - ttl: 300, - }); - resp - } + let mut resp = DnsPacket::response_from(query, ResultCode::NOERROR); + resp.header.authoritative_answer = true; + resp.answers.push(DnsRecord::A { + domain: q.name.clone(), + addr: Ipv4Addr::new(10, 0, 0, 42), + ttl: 300, + }); + resp }) .await; - let srtt = RwLock::new(SrttCache::new(true)); - let result = send_query("hello.example.com", QueryType::A, server_addr, &srtt).await; - let resp = result.expect("TCP-only send_query should work"); + let query = DnsPacket::query(0x1234, "hello.example.com", QueryType::A); + let resp = crate::forward::forward_tcp(&query, server_addr, TCP_TIMEOUT) + .await + .expect("TCP query should work"); assert_eq!(resp.header.rescode, ResultCode::NOERROR); match &resp.answers[0] { DnsRecord::A { addr, .. } => assert_eq!(*addr, Ipv4Addr::new(10, 0, 0, 42)), @@ -1002,7 +1125,8 @@ mod tests { #[tokio::test] async fn tcp_fallback_handles_nxdomain() { - UDP_DISABLED.store(false, Ordering::Relaxed); + let _guard = UDP_STATE_LOCK.lock().unwrap(); + UDP_DISABLED.store(true, Ordering::Relaxed); UDP_FAILURES.store(0, Ordering::Release); let server_addr = spawn_tcp_dns_server(|query| { @@ -1034,6 +1158,7 @@ mod tests { #[tokio::test] async fn udp_auto_disable_resets() { + let _guard = UDP_STATE_LOCK.lock().unwrap(); UDP_DISABLED.store(true, Ordering::Release); UDP_FAILURES.store(5, Ordering::Relaxed); diff --git a/src/srtt.rs b/src/srtt.rs index f763a37..fe4df1e 100644 --- a/src/srtt.rs +++ b/src/srtt.rs @@ -45,6 +45,11 @@ impl SrttCache { } } + /// Whether we have observed RTT data for this IP. + pub fn is_known(&self, ip: IpAddr) -> bool { + self.entries.contains_key(&ip) + } + /// Apply time-based decay: each DECAY_AFTER_SECS period halves distance to INITIAL. fn decayed_srtt(entry: &SrttEntry) -> u64 { Self::decay_for_age(entry.srtt_ms, entry.updated_at.elapsed().as_secs()) diff --git a/src/wire.rs b/src/wire.rs new file mode 100644 index 0000000..3ee2ab3 --- /dev/null +++ b/src/wire.rs @@ -0,0 +1,1416 @@ +//! Wire-level DNS utilities: question extraction, TTL offset scanning, and patching. +//! +//! These operate directly on raw DNS wire bytes without full packet parsing, +//! enabling zero-copy forwarding and wire-level caching. + +use crate::Result; + +/// Metadata extracted from scanning a DNS response's wire bytes. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct WireMeta { + /// Byte offsets of every TTL field in answer + authority + additional sections. + /// Each offset points to the first byte of a 4-byte big-endian TTL. + /// EDNS OPT pseudo-records are excluded (their "TTL" is flags, not a real TTL). + pub ttl_offsets: Vec, + /// How many of the offsets belong to the answer section (the first `answer_count` + /// entries). Used to extract min-TTL from answers only. + pub answer_count: usize, +} + +/// Scan a DNS response's wire bytes and return metadata about TTL field locations. +/// +/// Walks the header, skips the question section, then for each resource record in +/// answer, authority, and additional sections, records the byte offset of the TTL +/// field. EDNS OPT records (type 41 with root name) are excluded. +pub fn scan_ttl_offsets(wire: &[u8]) -> Result { + if wire.len() < 12 { + return Err("wire too short for DNS header".into()); + } + + let qdcount = u16::from_be_bytes([wire[4], wire[5]]) as usize; + let ancount = u16::from_be_bytes([wire[6], wire[7]]) as usize; + let nscount = u16::from_be_bytes([wire[8], wire[9]]) as usize; + let arcount = u16::from_be_bytes([wire[10], wire[11]]) as usize; + + let mut pos = 12; + + // Skip question section + for _ in 0..qdcount { + skip_wire_name(wire, &mut pos)?; + if pos + 4 > wire.len() { + return Err("wire truncated in question section".into()); + } + pos += 4; // QTYPE(2) + QCLASS(2) + } + + let mut ttl_offsets = Vec::new(); + + // Process answer + authority + additional sections + let section_counts = [ancount, nscount, arcount]; + let mut answer_offset_count = 0; + + for (section_idx, &count) in section_counts.iter().enumerate() { + for _ in 0..count { + // Check if this is an OPT record: root name (0x00) + type 41 + let is_opt = pos < wire.len() + && wire[pos] == 0x00 + && pos + 3 <= wire.len() + && u16::from_be_bytes([wire[pos + 1], wire[pos + 2]]) == 41; + + // Skip name + skip_wire_name(wire, &mut pos)?; + + if pos + 10 > wire.len() { + return Err("wire truncated in resource record".into()); + } + + // TYPE(2) + CLASS(2) = 4 bytes before TTL + let ttl_offset = pos + 4; + + if !is_opt { + ttl_offsets.push(ttl_offset); + if section_idx == 0 { + answer_offset_count += 1; + } + } + + // Skip TYPE(2) + CLASS(2) + TTL(4) + RDLENGTH(2) = 10 bytes + let rdlength = u16::from_be_bytes([wire[pos + 8], wire[pos + 9]]) as usize; + pos += 10 + rdlength; + + if pos > wire.len() { + return Err("wire truncated in resource record RDATA".into()); + } + } + } + + Ok(WireMeta { + ttl_offsets, + answer_count: answer_offset_count, + }) +} + +/// Extract the minimum TTL from the answer section offsets of a wire response. +pub fn min_ttl_from_wire(wire: &[u8], meta: &WireMeta) -> Option { + meta.ttl_offsets + .iter() + .take(meta.answer_count) + .filter_map(|&off| { + if off + 4 <= wire.len() { + Some(u32::from_be_bytes([ + wire[off], + wire[off + 1], + wire[off + 2], + wire[off + 3], + ])) + } else { + None + } + }) + .min() +} + +/// Patch the transaction ID (bytes 0..2) in a DNS wire message. +pub fn patch_id(wire: &mut [u8], new_id: u16) { + let bytes = new_id.to_be_bytes(); + wire[0] = bytes[0]; + wire[1] = bytes[1]; +} + +/// Patch all TTL fields at the given offsets to `new_ttl`. +pub fn patch_ttls(wire: &mut [u8], offsets: &[usize], new_ttl: u32) { + let bytes = new_ttl.to_be_bytes(); + for &off in offsets { + wire[off] = bytes[0]; + wire[off + 1] = bytes[1]; + wire[off + 2] = bytes[2]; + wire[off + 3] = bytes[3]; + } +} + +/// Skip a DNS name in wire bytes, advancing `pos` past it. +fn skip_wire_name(wire: &[u8], pos: &mut usize) -> Result<()> { + loop { + if *pos >= wire.len() { + return Err("wire truncated skipping name".into()); + } + let len = wire[*pos] as usize; + + if len & 0xC0 == 0xC0 { + *pos += 2; // compression pointer is 2 bytes + return Ok(()); + } + if len == 0 { + *pos += 1; + return Ok(()); + } + *pos += 1 + len; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::buffer::BytePacketBuffer; + use crate::cache::{DnsCache, DnssecStatus}; + use crate::header::ResultCode; + use crate::packet::{DnsPacket, EdnsOpt}; + use crate::question::{DnsQuestion, QueryType}; + use crate::record::DnsRecord; + + // ── Helpers ────────────────────────────────────────────────────── + + /// Serialize a DnsPacket to wire bytes. + fn to_wire(pkt: &DnsPacket) -> Vec { + let mut buf = BytePacketBuffer::new(); + pkt.write(&mut buf).unwrap(); + buf.filled().to_vec() + } + + /// Build a minimal response with given answers. + fn response(id: u16, domain: &str, answers: Vec) -> DnsPacket { + let mut pkt = DnsPacket::new(); + pkt.header.id = id; + pkt.header.response = true; + pkt.header.recursion_desired = true; + pkt.header.recursion_available = true; + pkt.header.rescode = ResultCode::NOERROR; + pkt.questions + .push(DnsQuestion::new(domain.to_string(), QueryType::A)); + pkt.answers = answers; + pkt + } + + fn a_record(domain: &str, ip: &str, ttl: u32) -> DnsRecord { + DnsRecord::A { + domain: domain.into(), + addr: ip.parse().unwrap(), + ttl, + } + } + + fn aaaa_record(domain: &str, ip: &str, ttl: u32) -> DnsRecord { + DnsRecord::AAAA { + domain: domain.into(), + addr: ip.parse().unwrap(), + ttl, + } + } + + fn cname_record(domain: &str, host: &str, ttl: u32) -> DnsRecord { + DnsRecord::CNAME { + domain: domain.into(), + host: host.into(), + ttl, + } + } + + fn ns_record(domain: &str, host: &str, ttl: u32) -> DnsRecord { + DnsRecord::NS { + domain: domain.into(), + host: host.into(), + ttl, + } + } + + fn mx_record(domain: &str, host: &str, priority: u16, ttl: u32) -> DnsRecord { + DnsRecord::MX { + domain: domain.into(), + priority, + host: host.into(), + ttl, + } + } + + // ── A. TTL offset extraction ──────────────────────────────────── + + #[test] + fn scan_single_a_record() { + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + assert_eq!(meta.ttl_offsets.len(), 1); + assert_eq!(meta.answer_count, 1); + + let off = meta.ttl_offsets[0]; + let ttl = u32::from_be_bytes([wire[off], wire[off + 1], wire[off + 2], wire[off + 3]]); + assert_eq!(ttl, 300); + } + + #[test] + fn scan_multiple_a_records() { + let pkt = response( + 0x1234, + "example.com", + vec![ + a_record("example.com", "1.2.3.4", 300), + a_record("example.com", "5.6.7.8", 600), + a_record("example.com", "9.10.11.12", 120), + ], + ); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + assert_eq!(meta.ttl_offsets.len(), 3); + assert_eq!(meta.answer_count, 3); + + let ttls: Vec = meta + .ttl_offsets + .iter() + .map(|&off| { + u32::from_be_bytes([wire[off], wire[off + 1], wire[off + 2], wire[off + 3]]) + }) + .collect(); + assert_eq!(ttls, vec![300, 600, 120]); + } + + #[test] + fn scan_mixed_sections() { + let mut pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); + pkt.authorities + .push(ns_record("example.com", "ns1.example.com", 3600)); + pkt.authorities + .push(ns_record("example.com", "ns2.example.com", 3600)); + pkt.resources + .push(a_record("ns1.example.com", "10.0.0.1", 1800)); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + assert_eq!(meta.ttl_offsets.len(), 4); // 1 answer + 2 authority + 1 additional + assert_eq!(meta.answer_count, 1); + } + + #[test] + fn scan_cname_chain() { + let pkt = response( + 0x1234, + "www.example.com", + vec![ + cname_record("www.example.com", "example.com", 300), + a_record("example.com", "1.2.3.4", 600), + ], + ); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + assert_eq!(meta.ttl_offsets.len(), 2); + assert_eq!(meta.answer_count, 2); + + let ttls: Vec = meta + .ttl_offsets + .iter() + .map(|&off| { + u32::from_be_bytes([wire[off], wire[off + 1], wire[off + 2], wire[off + 3]]) + }) + .collect(); + assert_eq!(ttls, vec![300, 600]); + } + + #[test] + fn scan_compressed_names() { + // Build a packet with name compression (the serializer uses compression + // for repeated domain names). Two A records for the same domain will + // have the second name compressed as a pointer. + let pkt = response( + 0x1234, + "example.com", + vec![ + a_record("example.com", "1.2.3.4", 300), + a_record("example.com", "5.6.7.8", 600), + ], + ); + let wire = to_wire(&pkt); + + // Verify compression is actually present (second name should be a pointer) + // The first answer's name is at some offset, and the second should use 0xC0xx + let meta = scan_ttl_offsets(&wire).unwrap(); + assert_eq!(meta.ttl_offsets.len(), 2); + + let ttls: Vec = meta + .ttl_offsets + .iter() + .map(|&off| { + u32::from_be_bytes([wire[off], wire[off + 1], wire[off + 2], wire[off + 3]]) + }) + .collect(); + assert_eq!(ttls, vec![300, 600]); + } + + #[test] + fn scan_edns_opt_excluded() { + let mut pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); + pkt.edns = Some(EdnsOpt { + udp_payload_size: 1232, + extended_rcode: 0, + version: 0, + do_bit: false, + options: vec![], + }); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + // Only the A record's TTL, not the OPT pseudo-record's "TTL" + assert_eq!(meta.ttl_offsets.len(), 1); + assert_eq!(meta.answer_count, 1); + } + + #[test] + fn scan_rrsig_only_wire_ttl() { + let mut pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); + pkt.answers.push(DnsRecord::RRSIG { + domain: "example.com".into(), + type_covered: 1, // A + algorithm: 13, + labels: 2, + original_ttl: 9999, // must NOT appear in offsets + expiration: 1700000000, + inception: 1690000000, + key_tag: 12345, + signer_name: "example.com".into(), + signature: vec![0x01, 0x02, 0x03, 0x04], + ttl: 300, + }); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + // 2 TTL offsets: A record + RRSIG wire TTL + assert_eq!(meta.ttl_offsets.len(), 2); + assert_eq!(meta.answer_count, 2); + + // Both wire TTLs should be 300, not 9999 + for &off in &meta.ttl_offsets { + let ttl = u32::from_be_bytes([wire[off], wire[off + 1], wire[off + 2], wire[off + 3]]); + assert_eq!(ttl, 300); + } + + // Verify that 9999 (original_ttl) exists somewhere in the wire but is NOT in offsets + let original_ttl_bytes = 9999u32.to_be_bytes(); + let found_at = wire + .windows(4) + .position(|w| w == original_ttl_bytes) + .expect("original_ttl should be in wire"); + assert!( + !meta.ttl_offsets.contains(&found_at), + "original_ttl offset must not be in ttl_offsets" + ); + } + + #[test] + fn scan_nsec_variable_rdata() { + let mut pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); + pkt.authorities.push(DnsRecord::NSEC { + domain: "example.com".into(), + next_domain: "z.example.com".into(), + type_bitmap: vec![0x00, 0x06, 0x40, 0x01, 0x00, 0x00, 0x00, 0x03], + ttl: 1800, + }); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + assert_eq!(meta.ttl_offsets.len(), 2); // A + NSEC + assert_eq!(meta.answer_count, 1); + + let nsec_ttl_off = meta.ttl_offsets[1]; + let ttl = u32::from_be_bytes([ + wire[nsec_ttl_off], + wire[nsec_ttl_off + 1], + wire[nsec_ttl_off + 2], + wire[nsec_ttl_off + 3], + ]); + assert_eq!(ttl, 1800); + } + + #[test] + fn scan_empty_response() { + let pkt = response(0x1234, "nxdomain.example.com", vec![]); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + assert!(meta.ttl_offsets.is_empty()); + assert_eq!(meta.answer_count, 0); + } + + #[test] + fn scan_unknown_record_type() { + // Manually build a response with an unknown type (99) using raw wire bytes + let mut pkt = response(0x1234, "example.com", vec![]); + pkt.answers.push(DnsRecord::UNKNOWN { + domain: "example.com".into(), + qtype: 99, + data: vec![0xDE, 0xAD, 0xBE, 0xEF], + ttl: 500, + }); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + assert_eq!(meta.ttl_offsets.len(), 1); + let off = meta.ttl_offsets[0]; + let ttl = u32::from_be_bytes([wire[off], wire[off + 1], wire[off + 2], wire[off + 3]]); + assert_eq!(ttl, 500); + } + + #[test] + fn scan_truncated_wire_returns_error() { + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); + let wire = to_wire(&pkt); + // Truncate mid-record + let truncated = &wire[..wire.len() - 2]; + assert!(scan_ttl_offsets(truncated).is_err()); + } + + #[test] + fn scan_too_short_for_header() { + assert!(scan_ttl_offsets(&[0u8; 5]).is_err()); + } + + #[test] + fn scan_query_packet_no_offsets() { + let pkt = DnsPacket::query(0x1234, "example.com", QueryType::A); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + assert!(meta.ttl_offsets.is_empty()); + } + + // ── B. TTL patching ───────────────────────────────────────────── + + #[test] + fn patch_ttl_single() { + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); + let mut wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + patch_ttls(&mut wire, &meta.ttl_offsets, 120); + + let off = meta.ttl_offsets[0]; + assert_eq!( + u32::from_be_bytes([wire[off], wire[off + 1], wire[off + 2], wire[off + 3]]), + 120 + ); + } + + #[test] + fn patch_ttl_multiple() { + let pkt = response( + 0x1234, + "example.com", + vec![ + a_record("example.com", "1.2.3.4", 300), + a_record("example.com", "5.6.7.8", 600), + a_record("example.com", "9.10.11.12", 900), + ], + ); + let mut wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + patch_ttls(&mut wire, &meta.ttl_offsets, 42); + + for &off in &meta.ttl_offsets { + assert_eq!( + u32::from_be_bytes([wire[off], wire[off + 1], wire[off + 2], wire[off + 3]]), + 42 + ); + } + } + + #[test] + fn patch_ttl_preserves_other_bytes() { + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); + let original = to_wire(&pkt); + let mut patched = original.clone(); + let meta = scan_ttl_offsets(&patched).unwrap(); + + patch_ttls(&mut patched, &meta.ttl_offsets, 120); + + // Every byte outside TTL offsets should be identical + for (i, (&orig, &patc)) in original.iter().zip(patched.iter()).enumerate() { + let in_ttl = meta.ttl_offsets.iter().any(|&off| i >= off && i < off + 4); + if !in_ttl { + assert_eq!( + orig, patc, + "byte {} changed (outside TTL): orig={:#04x}, patched={:#04x}", + i, orig, patc + ); + } + } + } + + #[test] + fn patch_ttl_zero() { + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); + let mut wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + patch_ttls(&mut wire, &meta.ttl_offsets, 0); + + let off = meta.ttl_offsets[0]; + assert_eq!(&wire[off..off + 4], &[0, 0, 0, 0]); + } + + #[test] + fn patch_ttl_max_u32() { + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); + let mut wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + patch_ttls(&mut wire, &meta.ttl_offsets, u32::MAX); + + let off = meta.ttl_offsets[0]; + assert_eq!(&wire[off..off + 4], &[0xFF, 0xFF, 0xFF, 0xFF]); + } + + #[test] + fn patch_ttl_edns_untouched() { + let mut pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); + pkt.edns = Some(EdnsOpt { + udp_payload_size: 1232, + extended_rcode: 0, + version: 0, + do_bit: true, + options: vec![], + }); + let original = to_wire(&pkt); + let mut patched = original.clone(); + let meta = scan_ttl_offsets(&patched).unwrap(); + + patch_ttls(&mut patched, &meta.ttl_offsets, 42); + + // Only the A record's TTL bytes should differ; everything else + // (including the OPT "TTL" containing the DO bit) must be unchanged. + for (i, (&orig, &patc)) in original.iter().zip(patched.iter()).enumerate() { + let in_ttl = meta.ttl_offsets.iter().any(|&off| i >= off && i < off + 4); + if !in_ttl { + assert_eq!( + orig, patc, + "byte {} changed (outside TTL): orig={:#04x}, patched={:#04x}", + i, orig, patc + ); + } + } + } + + // ── C. ID patching ────────────────────────────────────────────── + + #[test] + fn patch_id_basic() { + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); + let mut wire = to_wire(&pkt); + + patch_id(&mut wire, 0xABCD); + assert_eq!(&wire[0..2], &[0xAB, 0xCD]); + } + + #[test] + fn patch_id_preserves_flags() { + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); + let original = to_wire(&pkt); + let mut patched = original.clone(); + + patch_id(&mut patched, 0x9999); + + // Bytes 2..12 (flags + counts) unchanged + assert_eq!(&original[2..12], &patched[2..12]); + } + + #[test] + fn patch_id_zero() { + let pkt = response( + 0xFFFF, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); + let mut wire = to_wire(&pkt); + + patch_id(&mut wire, 0x0000); + assert_eq!(&wire[0..2], &[0x00, 0x00]); + } + + // ── D. min_ttl_from_wire ──────────────────────────────────────── + + #[test] + fn min_ttl_answers_only() { + let mut pkt = response( + 0x1234, + "example.com", + vec![ + a_record("example.com", "1.2.3.4", 300), + a_record("example.com", "5.6.7.8", 60), + ], + ); + pkt.authorities + .push(ns_record("example.com", "ns1.example.com", 10)); // lower but in authority, not answer + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + assert_eq!(min_ttl_from_wire(&wire, &meta), Some(60)); // from answers only + } + + #[test] + fn min_ttl_empty_answers() { + let pkt = response(0x1234, "example.com", vec![]); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + assert_eq!(min_ttl_from_wire(&wire, &meta), None); + } + + // ── F. Round-trip fidelity ────────────────────────────────────── + // + // These verify that wire bytes → scan → patch → parse produces the + // same semantic content as the original packet. They test the full + // integration path that the wire-level cache will use. + + #[test] + fn round_trip_simple_a() { + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + let mut patched = wire.clone(); + patch_id(&mut patched, 0xABCD); + patch_ttls(&mut patched, &meta.ttl_offsets, 120); + + // Parse the patched wire + let mut buf = BytePacketBuffer::from_bytes(&patched); + let parsed = DnsPacket::from_buffer(&mut buf).unwrap(); + + assert_eq!(parsed.header.id, 0xABCD); + assert_eq!(parsed.answers.len(), 1); + match &parsed.answers[0] { + DnsRecord::A { domain, addr, ttl } => { + assert_eq!(domain, "example.com"); + assert_eq!(*addr, "1.2.3.4".parse::().unwrap()); + assert_eq!(*ttl, 120); + } + other => panic!("expected A record, got {:?}", other), + } + } + + #[test] + fn round_trip_edns_survives() { + let mut pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); + pkt.edns = Some(EdnsOpt { + udp_payload_size: 1232, + extended_rcode: 0, + version: 0, + do_bit: true, + options: vec![], + }); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + let mut patched = wire.clone(); + patch_ttls(&mut patched, &meta.ttl_offsets, 42); + + let mut buf = BytePacketBuffer::from_bytes(&patched); + let parsed = DnsPacket::from_buffer(&mut buf).unwrap(); + + let edns = parsed.edns.as_ref().expect("EDNS should survive"); + assert_eq!(edns.udp_payload_size, 1232); + assert!(edns.do_bit); + } + + #[test] + fn round_trip_dnssec_full() { + let mut pkt = response( + 0x1234, + "example.com", + vec![ + a_record("example.com", "1.2.3.4", 300), + DnsRecord::RRSIG { + domain: "example.com".into(), + type_covered: 1, + algorithm: 13, + labels: 2, + original_ttl: 300, + expiration: 1700000000, + inception: 1690000000, + key_tag: 12345, + signer_name: "example.com".into(), + signature: vec![1, 2, 3, 4, 5, 6, 7, 8], + ttl: 300, + }, + ], + ); + pkt.authorities.push(DnsRecord::NSEC { + domain: "example.com".into(), + next_domain: "z.example.com".into(), + type_bitmap: vec![0x00, 0x06, 0x40, 0x01, 0x00, 0x00, 0x00, 0x03], + ttl: 300, + }); + pkt.resources.push(DnsRecord::DNSKEY { + domain: "example.com".into(), + flags: 257, + protocol: 3, + algorithm: 13, + public_key: vec![10, 20, 30, 40], + ttl: 3600, + }); + pkt.edns = Some(EdnsOpt { + udp_payload_size: 1232, + extended_rcode: 0, + version: 0, + do_bit: true, + options: vec![], + }); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + // 4 TTL offsets: A + RRSIG (answers) + NSEC (authority) + DNSKEY (additional) + // OPT excluded + assert_eq!(meta.ttl_offsets.len(), 4); + assert_eq!(meta.answer_count, 2); + + let mut patched = wire.clone(); + patch_ttls(&mut patched, &meta.ttl_offsets, 42); + + let mut buf = BytePacketBuffer::from_bytes(&patched); + let parsed = DnsPacket::from_buffer(&mut buf).unwrap(); + + assert_eq!(parsed.answers.len(), 2); + assert_eq!(parsed.authorities.len(), 1); + assert_eq!(parsed.resources.len(), 1); + assert!(parsed.edns.is_some()); + + // All TTLs should be 42 now + for ans in &parsed.answers { + assert_eq!(ans.ttl(), 42); + } + for auth in &parsed.authorities { + assert_eq!(auth.ttl(), 42); + } + for res in &parsed.resources { + assert_eq!(res.ttl(), 42); + } + + // RRSIG original_ttl must be preserved (it's inside RDATA, not a wire TTL) + match &parsed.answers[1] { + DnsRecord::RRSIG { original_ttl, .. } => assert_eq!(*original_ttl, 300), + other => panic!("expected RRSIG, got {:?}", other), + } + } + + #[test] + fn round_trip_nxdomain_soa() { + let mut pkt = DnsPacket::new(); + pkt.header.id = 0x5678; + pkt.header.response = true; + pkt.header.rescode = ResultCode::NXDOMAIN; + pkt.questions + .push(DnsQuestion::new("missing.example.com".into(), QueryType::A)); + // SOA in authority (we don't have a SOA variant, so use NS as proxy for offset testing) + pkt.authorities + .push(ns_record("example.com", "ns1.example.com", 900)); + + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + assert_eq!(meta.ttl_offsets.len(), 1); + assert_eq!(meta.answer_count, 0); // no answers, only authority + + let mut patched = wire.clone(); + patch_id(&mut patched, 0x9999); + patch_ttls(&mut patched, &meta.ttl_offsets, 60); + + let mut buf = BytePacketBuffer::from_bytes(&patched); + let parsed = DnsPacket::from_buffer(&mut buf).unwrap(); + + assert_eq!(parsed.header.id, 0x9999); + assert_eq!(parsed.header.rescode, ResultCode::NXDOMAIN); + assert_eq!(parsed.authorities[0].ttl(), 60); + } + + #[test] + fn round_trip_mx_record() { + let pkt = response( + 0x1234, + "example.com", + vec![mx_record("example.com", "mail.example.com", 10, 3600)], + ); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + let mut patched = wire.clone(); + patch_ttls(&mut patched, &meta.ttl_offsets, 100); + + let mut buf = BytePacketBuffer::from_bytes(&patched); + let parsed = DnsPacket::from_buffer(&mut buf).unwrap(); + + match &parsed.answers[0] { + DnsRecord::MX { + domain, + priority, + host, + ttl, + } => { + assert_eq!(domain, "example.com"); + assert_eq!(*priority, 10); + assert_eq!(host, "mail.example.com"); + assert_eq!(*ttl, 100); + } + other => panic!("expected MX, got {:?}", other), + } + } + + #[test] + fn round_trip_many_records() { + let answers: Vec = (0..20) + .map(|i| a_record("example.com", &format!("10.0.0.{}", i), 300 + i * 10)) + .collect(); + let pkt = response(0x1234, "example.com", answers); + let wire = to_wire(&pkt); + let meta = scan_ttl_offsets(&wire).unwrap(); + + assert_eq!(meta.ttl_offsets.len(), 20); + + let mut patched = wire.clone(); + patch_ttls(&mut patched, &meta.ttl_offsets, 1); + + let mut buf = BytePacketBuffer::from_bytes(&patched); + let parsed = DnsPacket::from_buffer(&mut buf).unwrap(); + + assert_eq!(parsed.answers.len(), 20); + for ans in &parsed.answers { + assert_eq!(ans.ttl(), 1); + } + } + + // ── G. Edge cases ─────────────────────────────────────────────── + + #[test] + fn scan_rejects_empty_wire() { + assert!(scan_ttl_offsets(&[]).is_err()); + } + + // ── G. Cache behavior tests ───────────────────────────────────── + // + // These test existing DnsCache behavior that must be preserved after + // the wire-level migration. They use the current parsed-packet API + // and serve as a regression suite. + + #[test] + fn cache_insert_lookup_hit() { + let mut cache = DnsCache::new(100, 1, 3600); + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); + cache.insert("example.com", QueryType::A, &pkt); + + let (result, status, _) = cache + .lookup_with_status("example.com", QueryType::A) + .expect("should hit"); + assert_eq!(result.answers.len(), 1); + assert_eq!(status, DnssecStatus::Indeterminate); + } + + #[test] + fn cache_lookup_adjusts_ttl() { + let mut cache = DnsCache::new(100, 1, 3600); + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); + cache.insert("example.com", QueryType::A, &pkt); + + let (result, _, _) = cache + .lookup_with_status("example.com", QueryType::A) + .unwrap(); + // TTL should be <= 300 (at most original, reduced by elapsed time) + assert!(result.answers[0].ttl() <= 300); + assert!(result.answers[0].ttl() > 0); + } + + #[test] + fn cache_miss_wrong_domain() { + let mut cache = DnsCache::new(100, 1, 3600); + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); + cache.insert("example.com", QueryType::A, &pkt); + + assert!(cache + .lookup_with_status("other.com", QueryType::A) + .is_none()); + } + + #[test] + fn cache_miss_wrong_qtype() { + let mut cache = DnsCache::new(100, 1, 3600); + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); + cache.insert("example.com", QueryType::A, &pkt); + + assert!(cache + .lookup_with_status("example.com", QueryType::AAAA) + .is_none()); + } + + #[test] + fn cache_overwrite_no_double_count() { + let mut cache = DnsCache::new(100, 1, 3600); + let pkt1 = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); + let pkt2 = response( + 0x5678, + "example.com", + vec![a_record("example.com", "5.6.7.8", 600)], + ); + + cache.insert("example.com", QueryType::A, &pkt1); + assert_eq!(cache.len(), 1); + + cache.insert("example.com", QueryType::A, &pkt2); + assert_eq!(cache.len(), 1); // no double count + + let (result, _, _) = cache + .lookup_with_status("example.com", QueryType::A) + .unwrap(); + match &result.answers[0] { + DnsRecord::A { addr, .. } => { + assert_eq!(*addr, "5.6.7.8".parse::().unwrap()) + } + _ => panic!("expected A record"), + } + } + + #[test] + fn cache_ttl_clamped_min() { + let mut cache = DnsCache::new(100, 60, 3600); + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 5)], + ); + cache.insert("example.com", QueryType::A, &pkt); + + let (remaining, total) = cache.ttl_remaining("example.com", QueryType::A).unwrap(); + assert_eq!(total, 60); // clamped up from 5 + assert!(remaining <= 60); + } + + #[test] + fn cache_ttl_clamped_max() { + let mut cache = DnsCache::new(100, 1, 3600); + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 999999)], + ); + cache.insert("example.com", QueryType::A, &pkt); + + let (_, total) = cache.ttl_remaining("example.com", QueryType::A).unwrap(); + assert_eq!(total, 3600); // clamped down from 999999 + } + + #[test] + fn cache_len_empty_clear() { + let mut cache = DnsCache::new(100, 1, 3600); + assert!(cache.is_empty()); + assert_eq!(cache.len(), 0); + + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); + cache.insert("example.com", QueryType::A, &pkt); + assert!(!cache.is_empty()); + assert_eq!(cache.len(), 1); + + cache.clear(); + assert!(cache.is_empty()); + assert_eq!(cache.len(), 0); + assert!(cache.lookup("example.com", QueryType::A).is_none()); + } + + #[test] + fn cache_remove_domain() { + let mut cache = DnsCache::new(100, 1, 3600); + let pkt_a = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); + let pkt_aaaa = response( + 0x5678, + "example.com", + vec![aaaa_record("example.com", "::1", 300)], + ); + cache.insert("example.com", QueryType::A, &pkt_a); + cache.insert("example.com", QueryType::AAAA, &pkt_aaaa); + assert_eq!(cache.len(), 2); + + cache.remove("example.com"); + assert_eq!(cache.len(), 0); + assert!(cache.lookup("example.com", QueryType::A).is_none()); + assert!(cache.lookup("example.com", QueryType::AAAA).is_none()); + } + + #[test] + fn cache_list_entries() { + let mut cache = DnsCache::new(100, 1, 3600); + let pkt_a = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); + let pkt_b = response( + 0x5678, + "test.org", + vec![a_record("test.org", "5.6.7.8", 600)], + ); + cache.insert("example.com", QueryType::A, &pkt_a); + cache.insert("test.org", QueryType::A, &pkt_b); + + let list = cache.list(); + assert_eq!(list.len(), 2); + let domains: Vec<&str> = list.iter().map(|e| e.domain.as_str()).collect(); + assert!(domains.contains(&"example.com")); + assert!(domains.contains(&"test.org")); + } + + #[test] + fn cache_heap_bytes_grows() { + let mut cache = DnsCache::new(100, 1, 3600); + let empty = cache.heap_bytes(); + + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); + cache.insert("example.com", QueryType::A, &pkt); + assert!(cache.heap_bytes() > empty); + } + + #[test] + fn cache_needs_warm_behavior() { + let mut cache = DnsCache::new(100, 1, 3600); + + // Missing → needs warm + assert!(cache.needs_warm("example.com")); + + // Both A and AAAA cached → does not need warm + let pkt_a = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); + let pkt_aaaa = response( + 0x5678, + "example.com", + vec![aaaa_record("example.com", "::1", 300)], + ); + cache.insert("example.com", QueryType::A, &pkt_a); + cache.insert("example.com", QueryType::AAAA, &pkt_aaaa); + assert!(!cache.needs_warm("example.com")); + + // Only A cached → needs warm (AAAA missing) + cache.remove("example.com"); + cache.insert("example.com", QueryType::A, &pkt_a); + assert!(cache.needs_warm("example.com")); + } + + #[test] + fn cache_ttl_remaining_api() { + let mut cache = DnsCache::new(100, 60, 3600); + assert!(cache.ttl_remaining("missing.com", QueryType::A).is_none()); + + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); + cache.insert("example.com", QueryType::A, &pkt); + let (remaining, total) = cache.ttl_remaining("example.com", QueryType::A).unwrap(); + assert_eq!(total, 300); + assert!(remaining > 0); + assert!(remaining <= 300); + } + + #[test] + fn cache_dnssec_status_preserved() { + let mut cache = DnsCache::new(100, 1, 3600); + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 300)], + ); + cache.insert_with_status("example.com", QueryType::A, &pkt, DnssecStatus::Secure); + + let (_, status, _) = cache + .lookup_with_status("example.com", QueryType::A) + .unwrap(); + assert_eq!(status, DnssecStatus::Secure); + } + + // ── I. Memory footprint baseline ────────────────────────────── + // + // Measures the current parsed-packet cache memory vs what wire-level + // storage would cost for the same entries. This is a baseline — after + // migration, re-run to verify improvement. + + #[test] + fn memory_footprint_baseline() { + let mut cache = DnsCache::new(1000, 1, 3600); + + // Simulate a realistic cache: 50 domains, mix of record types + let domains: Vec = (0..50) + .map(|i| format!("domain{}.example.com", i)) + .collect(); + + let mut total_wire_bytes = 0usize; + let mut total_wire_meta_bytes = 0usize; + + for (i, domain) in domains.iter().enumerate() { + // A record + let pkt_a = response( + i as u16, + domain, + vec![ + a_record(domain, &format!("10.0.{}.1", i % 256), 300), + a_record(domain, &format!("10.0.{}.2", i % 256), 300), + ], + ); + cache.insert(domain, QueryType::A, &pkt_a); + + let wire_a = to_wire(&pkt_a); + let meta_a = scan_ttl_offsets(&wire_a).unwrap(); + total_wire_bytes += wire_a.len(); + total_wire_meta_bytes += meta_a.ttl_offsets.len() * std::mem::size_of::(); + + // AAAA record for half of them + if i % 2 == 0 { + let pkt_aaaa = response( + (i + 1000) as u16, + domain, + vec![aaaa_record(domain, &format!("2001:db8::{:x}", i), 600)], + ); + cache.insert(domain, QueryType::AAAA, &pkt_aaaa); + + let wire_aaaa = to_wire(&pkt_aaaa); + let meta_aaaa = scan_ttl_offsets(&wire_aaaa).unwrap(); + total_wire_bytes += wire_aaaa.len(); + total_wire_meta_bytes += meta_aaaa.ttl_offsets.len() * std::mem::size_of::(); + } + } + + // Compare only the variable per-entry data (what actually differs + // between parsed and wire storage). HashMap overhead, domain keys, + // Instant, Duration, DnssecStatus are identical in both approaches. + let mut parsed_data_bytes = 0usize; + // Re-insert and measure just packet.heap_bytes() per entry + { + let mut cache2 = DnsCache::new(1000, 1, 3600); + for (i, domain) in domains.iter().enumerate() { + let pkt_a = response( + i as u16, + domain, + vec![ + a_record(domain, &format!("10.0.{}.1", i % 256), 300), + a_record(domain, &format!("10.0.{}.2", i % 256), 300), + ], + ); + parsed_data_bytes += pkt_a.heap_bytes(); + cache2.insert(domain, QueryType::A, &pkt_a); + + if i % 2 == 0 { + let pkt_aaaa = response( + (i + 1000) as u16, + domain, + vec![aaaa_record(domain, &format!("2001:db8::{:x}", i), 600)], + ); + parsed_data_bytes += pkt_aaaa.heap_bytes(); + cache2.insert(domain, QueryType::AAAA, &pkt_aaaa); + } + } + } + + let wire_total = total_wire_bytes + total_wire_meta_bytes; + let entry_count = cache.len(); + + // Also measure the struct size difference per entry + let parsed_struct = std::mem::size_of::(); + let wire_struct = std::mem::size_of::>() + + std::mem::size_of::>() + + std::mem::size_of::(); // wire + offsets + answer_count + + println!(); + println!( + "=== Cache Memory Footprint Baseline ({} entries) ===", + entry_count + ); + println!(); + println!("Variable data (heap, per-entry payload):"); + println!( + " Parsed (packet.heap_bytes): {} bytes ({:.1}/entry)", + parsed_data_bytes, + parsed_data_bytes as f64 / entry_count as f64 + ); + println!( + " Wire (bytes + TTL offsets): {} bytes ({:.1}/entry)", + wire_total, + wire_total as f64 / entry_count as f64 + ); + println!( + " Ratio: {:.1}x smaller with wire", + parsed_data_bytes as f64 / wire_total as f64 + ); + println!(); + println!("Struct overhead (stack, per entry):"); + println!(" DnsPacket: {} bytes", parsed_struct); + println!(" Wire (Vec+Vec+usize): {} bytes", wire_struct); + println!(); + println!("Total per-entry (struct + avg heap):"); + let parsed_total_per = parsed_struct as f64 + parsed_data_bytes as f64 / entry_count as f64; + let wire_total_per = wire_struct as f64 + wire_total as f64 / entry_count as f64; + println!(" Parsed: {:.0} bytes", parsed_total_per); + println!(" Wire: {:.0} bytes", wire_total_per); + println!( + " Ratio: {:.1}x smaller with wire", + parsed_total_per / wire_total_per + ); + println!(); + + // Assertions + assert!( + wire_total < parsed_data_bytes, + "wire data ({wire_total}) should be smaller than parsed data ({parsed_data_bytes})" + ); + } + + #[test] + fn cache_max_entries_evicts_stalest() { + let mut cache = DnsCache::new(2, 1, 3600); + // Insert with decreasing TTL so test0.com is stalest + for (i, ttl) in [(0, 60), (1, 3600)] { + let domain = format!("test{}.com", i); + let pkt = response( + i as u16, + &domain, + vec![a_record(&domain, &format!("1.2.3.{}", i), ttl)], + ); + cache.insert(&domain, QueryType::A, &pkt); + } + assert_eq!(cache.len(), 2); + + // Third insert should evict test0.com (lowest remaining TTL) + let pkt = response(2, "test2.com", vec![a_record("test2.com", "1.2.3.2", 3600)]); + cache.insert("test2.com", QueryType::A, &pkt); + assert_eq!(cache.len(), 2); + assert!(cache.lookup("test0.com", QueryType::A).is_none()); // evicted + assert!(cache.lookup("test2.com", QueryType::A).is_some()); // inserted + } + + #[test] + fn lookup_wire_signals_stale_when_expired() { + use crate::cache::Freshness; + let mut cache = DnsCache::new(100, 1, 1); // max_ttl=1s so entry expires fast + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 1)], + ); + cache.insert("example.com", QueryType::A, &pkt); + + let (_, _, f) = cache.lookup_wire("example.com", QueryType::A, 0).unwrap(); + assert_eq!(f, Freshness::Fresh); + + std::thread::sleep(std::time::Duration::from_millis(1100)); + + let (_, _, f) = cache.lookup_wire("example.com", QueryType::A, 0).unwrap(); + assert_eq!(f, Freshness::Stale); + } + + #[test] + fn lookup_wire_signals_prefetch_near_expiry() { + use crate::cache::Freshness; + let mut cache = DnsCache::new(100, 10, 10); + let pkt = response( + 0x1234, + "example.com", + vec![a_record("example.com", "1.2.3.4", 10)], + ); + cache.insert("example.com", QueryType::A, &pkt); + + let (_, _, f) = cache.lookup_wire("example.com", QueryType::A, 0).unwrap(); + assert_eq!(f, Freshness::Fresh); + + std::thread::sleep(std::time::Duration::from_millis(9100)); + + let result = cache.lookup_wire("example.com", QueryType::A, 0); + if let Some((_, _, f)) = result { + assert_eq!(f, Freshness::NearExpiry); + } + } +} diff --git a/tests/integration.sh b/tests/integration.sh index 92da878..c70ec59 100755 --- a/tests/integration.sh +++ b/tests/integration.sh @@ -53,7 +53,17 @@ CONF echo "Starting Numa on :$PORT ($SUITE_NAME)..." RUST_LOG=info "$BINARY" "$CONFIG" > "$LOG" 2>&1 & NUMA_PID=$! - sleep 4 + sleep 2 + + # Wait for blocklist to load (if blocking is enabled in this suite) + if echo "$SUITE_CONFIG" | grep -q 'enabled = true'; then + for i in $(seq 1 20); do + LOADED=$(curl -sf http://127.0.0.1:$API_PORT/blocking/stats 2>/dev/null \ + | grep -o '"domains_loaded":[0-9]*' | cut -d: -f2) + if [ "${LOADED:-0}" -gt 0 ]; then break; fi + sleep 1 + done + fi if ! kill -0 "$NUMA_PID" 2>/dev/null; then echo "Failed to start Numa:"