From 637b374d8b9f28f649179bbacaeda303ba799dd1 Mon Sep 17 00:00:00 2001 From: Razvan Dimescu Date: Fri, 27 Mar 2026 16:45:36 +0200 Subject: [PATCH] feat: recursive resolution + full DNSSEC validation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Numa becomes a true DNS resolver — resolves from root nameservers with complete DNSSEC chain-of-trust verification. Recursive resolution: - Iterative RFC 1034 from configurable root hints (13 default) - CNAME chasing (depth 8), referral following (depth 10) - A+AAAA glue extraction, IPv6 nameserver support - TLD priming: NS + DS + DNSKEY for 34 gTLDs + EU ccTLDs - Config: mode = "recursive" in [upstream], root_hints, prime_tlds DNSSEC (all 4 phases): - EDNS0 OPT pseudo-record (DO bit, 1232 payload per DNS Flag Day 2020) - DNSKEY, DS, RRSIG, NSEC, NSEC3 record types with wire read/write - Signature verification via ring: RSA/SHA-256, ECDSA P-256, Ed25519 - Chain-of-trust: zone DNSKEY → parent DS → root KSK (key tag 20326) - DNSKEY RRset self-signature verification (RRSIG(DNSKEY) by KSK) - RRSIG expiration/inception time validation - NSEC: NXDOMAIN gap proofs, NODATA type absence, wildcard denial - NSEC3: SHA-1 iterated hashing, closest encloser proof, hash range - Authority RRSIG verification for denial proofs - Config: [dnssec] enabled/strict (default false, opt-in) - AD bit on Secure, SERVFAIL on Bogus+strict - DnssecStatus cached per entry, ValidationStats logging Performance: - TLD chain pre-warmed on startup (root DNSKEY + TLD DS/DNSKEY) - Referral DS piggybacking from authority sections - DNSKEY prefetch before validation loop - Cold-cache validation: ~1 DNSKEY fetch (down from 5) - Benchmarks: RSA 10.9µs, ECDSA 174ns, DS verify 257ns Also: - write_qname fix for root domain "." (was producing malformed queries) - write_record_header() dedup, write_bytes() bulk writes - DnsRecord::domain() + query_type() accessors - UpstreamMode enum, DEFAULT_EDNS_PAYLOAD const - Real glue TTL (was hardcoded 3600) - DNSSEC restricted to recursive mode only Co-Authored-By: Claude Opus 4.6 --- Cargo.lock | 1 + Cargo.toml | 5 + Makefile | 2 +- README.md | 9 +- benches/dnssec.rs | 183 +++++ numa.toml | 36 +- site/index.html | 39 +- src/api.rs | 8 +- src/buffer.rs | 18 + src/cache.rs | 31 +- src/config.rs | 96 +++ src/ctx.rs | 110 ++- src/dnssec.rs | 1675 ++++++++++++++++++++++++++++++++++++++++ src/forward.rs | 2 +- src/lib.rs | 2 + src/main.rs | 24 +- src/packet.rs | 499 +++++++++++- src/question.rs | 49 +- src/record.rs | 492 ++++++++++-- src/recursive.rs | 601 ++++++++++++++ src/stats.rs | 12 +- tests/integration.sh | 401 ++++++++++ tests/network-probe.sh | 128 +++ 23 files changed, 4325 insertions(+), 98 deletions(-) create mode 100644 benches/dnssec.rs create mode 100644 src/dnssec.rs create mode 100644 src/recursive.rs create mode 100755 tests/integration.sh create mode 100755 tests/network-probe.sh diff --git a/Cargo.lock b/Cargo.lock index a8563e2..c6161f5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1154,6 +1154,7 @@ dependencies = [ "log", "rcgen", "reqwest", + "ring", "rustls", "serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index ea71da7..4e988b9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,7 @@ time = "0.3" rustls = "0.23" tokio-rustls = "0.26" arc-swap = "1" +ring = "0.17" [dev-dependencies] criterion = { version = "0.5", features = ["html_reports"] } @@ -39,3 +40,7 @@ harness = false [[bench]] name = "throughput" harness = false + +[[bench]] +name = "dnssec" +harness = false diff --git a/Makefile b/Makefile index 643c058..81ee410 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ .PHONY: all build lint fmt check audit test bench clean deploy blog -all: lint build +all: lint build test build: cargo build diff --git a/README.md b/README.md index 92fa376..2131bb1 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ A portable DNS resolver in a single binary. Block ads on any network, name your local services (`frontend.numa`), and override any hostname with auto-revert — all from your laptop, no cloud account or Raspberry Pi required. -Built from scratch in Rust. Zero DNS libraries. RFC 1035 wire protocol parsed by hand. One ~8MB binary, no PHP, no web server, no database — everything is embedded. +Built from scratch in Rust. Zero DNS libraries. RFC 1035 wire protocol parsed by hand. Recursive resolution from root nameservers with full DNSSEC validation (chain-of-trust + NSEC/NSEC3 denial proofs). One ~8MB binary, no PHP, no web server, no database — everything is embedded. ![Numa dashboard](assets/hero-demo.gif) @@ -135,6 +135,7 @@ bind_addr = "0.0.0.0:53" | Path-based routing | No | No | No | No | Prefix match + strip | | LAN service discovery | No | No | No | No | mDNS, opt-in | | Developer overrides | No | No | No | No | REST API + auto-expiry | +| Recursive resolver | No | No | Cloud only | Cloud only | From root hints, DNSSEC | | Encrypted upstream (DoH) | No (needs cloudflared) | Yes | Cloud only | Cloud only | Native, single binary | | Portable (travels with laptop) | No (appliance) | No (appliance) | Cloud only | Cloud only | Single binary | | Zero config | Complex | Docker/setup | Yes | Yes | Works out of the box | @@ -144,9 +145,11 @@ bind_addr = "0.0.0.0:53" ## How It Works ``` -Query → Overrides → .numa TLD → Blocklist → Local Zones → Cache → Upstream +Query → Overrides → .numa TLD → Blocklist → Local Zones → Cache → Recursive/Forward ``` +Two resolution modes: **forward** (relay to upstream like Quad9/Cloudflare) or **recursive** (resolve from root nameservers — no upstream dependency). Set `mode = "recursive"` in `[upstream]` to resolve independently. + No DNS libraries — no `hickory-dns`, no `trust-dns`. The wire protocol — headers, labels, compression pointers, record types — is parsed and serialized by hand. Runs on `tokio` + `axum`, async per-query task spawning. [Configuration reference](numa.toml) @@ -161,6 +164,8 @@ No DNS libraries — no `hickory-dns`, no `trust-dns`. The wire protocol — hea - [x] Path-based routing — URL prefix routing with optional strip, REST API - [x] LAN service discovery — mDNS auto-discovery (opt-in), cross-machine DNS + proxy - [x] DNS-over-HTTPS — encrypted upstream via DoH (Quad9, Cloudflare, any provider) +- [x] Recursive resolution — resolve from root nameservers, no upstream dependency +- [x] DNSSEC validation — chain-of-trust, NSEC/NSEC3 denial proofs, AD bit (RSA, ECDSA, Ed25519) - [ ] pkarr integration — self-sovereign DNS via Mainline DHT (15M nodes) - [ ] Global `.numa` names — self-publish, DHT-backed, first-come-first-served diff --git a/benches/dnssec.rs b/benches/dnssec.rs new file mode 100644 index 0000000..270710a --- /dev/null +++ b/benches/dnssec.rs @@ -0,0 +1,183 @@ +use criterion::{black_box, criterion_group, criterion_main, Criterion}; + +use numa::dnssec; +use numa::question::QueryType; +use numa::record::DnsRecord; + +// Realistic ECDSA P-256 key (64 bytes) and signature (64 bytes) +fn make_ecdsa_key() -> Vec { + vec![0xAB; 64] +} +fn make_ecdsa_sig() -> Vec { + vec![0xCD; 64] +} + +// Realistic RSA-2048 key (RFC 3110 format: exp_len=3, exp=65537, mod=256 bytes) +fn make_rsa_key() -> Vec { + let mut key = vec![3u8]; // exponent length + key.extend(&[0x01, 0x00, 0x01]); // exponent = 65537 + key.extend(vec![0xFF; 256]); // modulus (256 bytes = 2048 bits) + key +} + +fn make_ed25519_key() -> Vec { + vec![0xEF; 32] +} + +fn make_dnskey(algorithm: u8, public_key: Vec) -> DnsRecord { + DnsRecord::DNSKEY { + domain: "example.com".into(), + flags: 257, + protocol: 3, + algorithm, + public_key, + ttl: 3600, + } +} + +fn make_rrsig(algorithm: u8, signature: Vec) -> DnsRecord { + DnsRecord::RRSIG { + domain: "example.com".into(), + type_covered: QueryType::A.to_num(), + algorithm, + labels: 2, + original_ttl: 300, + expiration: 2000000000, + inception: 1600000000, + key_tag: 12345, + signer_name: "example.com".into(), + signature, + ttl: 300, + } +} + +fn make_rrset() -> Vec { + vec![ + DnsRecord::A { + domain: "example.com".into(), + addr: "93.184.216.34".parse().unwrap(), + ttl: 300, + }, + DnsRecord::A { + domain: "example.com".into(), + addr: "93.184.216.35".parse().unwrap(), + ttl: 300, + }, + ] +} + +fn bench_key_tag(c: &mut Criterion) { + let key = make_rsa_key(); + c.bench_function("key_tag_rsa2048", |b| { + b.iter(|| { + dnssec::compute_key_tag(black_box(257), black_box(3), black_box(8), black_box(&key)) + }) + }); + + let key = make_ecdsa_key(); + c.bench_function("key_tag_ecdsa_p256", |b| { + b.iter(|| { + dnssec::compute_key_tag(black_box(257), black_box(3), black_box(13), black_box(&key)) + }) + }); +} + +fn bench_name_to_wire(c: &mut Criterion) { + c.bench_function("name_to_wire_short", |b| { + b.iter(|| dnssec::name_to_wire(black_box("example.com"))) + }); + c.bench_function("name_to_wire_long", |b| { + b.iter(|| dnssec::name_to_wire(black_box("sub.deep.nested.example.co.uk"))) + }); +} + +fn bench_build_signed_data(c: &mut Criterion) { + let rrsig = make_rrsig(13, make_ecdsa_sig()); + let rrset = make_rrset(); + let rrset_refs: Vec<&DnsRecord> = rrset.iter().collect(); + + c.bench_function("build_signed_data_2_A_records", |b| { + b.iter(|| dnssec::build_signed_data(black_box(&rrsig), black_box(&rrset_refs))) + }); +} + +fn bench_verify_signature(c: &mut Criterion) { + // These will fail verification (keys/sigs are random), but we measure the + // crypto overhead — ring still does the full algorithm before returning error. + let data = vec![0u8; 128]; // typical signed data size + + let rsa_key = make_rsa_key(); + let rsa_sig = vec![0xAA; 256]; // RSA-2048 signature + c.bench_function("verify_rsa_sha256_2048", |b| { + b.iter(|| { + dnssec::verify_signature( + black_box(8), + black_box(&rsa_key), + black_box(&data), + black_box(&rsa_sig), + ) + }) + }); + + let ecdsa_key = make_ecdsa_key(); + let ecdsa_sig = make_ecdsa_sig(); + c.bench_function("verify_ecdsa_p256", |b| { + b.iter(|| { + dnssec::verify_signature( + black_box(13), + black_box(&ecdsa_key), + black_box(&data), + black_box(&ecdsa_sig), + ) + }) + }); + + let ed_key = make_ed25519_key(); + let ed_sig = vec![0xBB; 64]; + c.bench_function("verify_ed25519", |b| { + b.iter(|| { + dnssec::verify_signature( + black_box(15), + black_box(&ed_key), + black_box(&data), + black_box(&ed_sig), + ) + }) + }); +} + +fn bench_ds_verification(c: &mut Criterion) { + let dk = make_dnskey(8, make_rsa_key()); + + // Compute correct DS digest + let owner_wire = dnssec::name_to_wire("example.com"); + let mut dnskey_rdata = vec![1u8, 1, 3, 8]; // flags=257, proto=3, algo=8 + dnskey_rdata.extend(&make_rsa_key()); + let mut input = Vec::new(); + input.extend(&owner_wire); + input.extend(&dnskey_rdata); + let digest = ring::digest::digest(&ring::digest::SHA256, &input); + + let ds = DnsRecord::DS { + domain: "example.com".into(), + key_tag: dnssec::compute_key_tag(257, 3, 8, &make_rsa_key()), + algorithm: 8, + digest_type: 2, + digest: digest.as_ref().to_vec(), + ttl: 86400, + }; + + c.bench_function("verify_ds_sha256", |b| { + b.iter(|| dnssec::verify_ds(black_box(&ds), black_box(&dk), black_box("example.com"))) + }); +} + +criterion_group!( + dnssec_benches, + bench_key_tag, + bench_name_to_wire, + bench_build_signed_data, + bench_verify_signature, + bench_ds_verification, +); +criterion_main!(dnssec_benches); diff --git a/numa.toml b/numa.toml index 09e8523..6a523ac 100644 --- a/numa.toml +++ b/numa.toml @@ -4,12 +4,39 @@ api_port = 5380 # api_bind_addr = "127.0.0.1" # default; set to "0.0.0.0" for LAN dashboard access # [upstream] -# address = "" # auto-detect from system resolver (default) +# mode = "forward" # "forward" (default) — relay to upstream +# # "recursive" — resolve from root hints (no address needed) # address = "https://dns.quad9.net/dns-query" # DNS-over-HTTPS (encrypted) # address = "https://cloudflare-dns.com/dns-query" # Cloudflare DoH # address = "9.9.9.9" # plain UDP -# port = 53 # only used for plain UDP +# port = 53 # only for forward mode, plain UDP # timeout_ms = 3000 +# root_hints = [ # only used in recursive mode +# "198.41.0.4", # a.root-servers.net (Verisign) +# "199.9.14.201", # b.root-servers.net (USC-ISI) +# "192.33.4.12", # c.root-servers.net (Cogent) +# "199.7.91.13", # d.root-servers.net (UMD) +# "192.203.230.10", # e.root-servers.net (NASA) +# "192.5.5.241", # f.root-servers.net (ISC) +# "192.112.36.4", # g.root-servers.net (US DoD) +# "198.97.190.53", # h.root-servers.net (US Army) +# "192.36.148.17", # i.root-servers.net (Netnod) +# "192.58.128.30", # j.root-servers.net (Verisign) +# "193.0.14.129", # k.root-servers.net (RIPE NCC) +# "199.7.83.42", # l.root-servers.net (ICANN) +# "202.12.27.33", # m.root-servers.net (WIDE) +# ] +# prime_tlds = [ # TLDs to pre-warm on startup (recursive mode) +# "com", "net", "org", "info", # gTLDs +# "io", "dev", "app", "xyz", "me", +# "eu", "uk", "de", "fr", "nl", # EU + European ccTLDs +# "it", "es", "pl", "se", "no", +# "dk", "fi", "at", "be", "ie", +# "pt", "cz", "ro", "gr", "hu", +# "bg", "hr", "sk", "si", "lt", +# "lv", "ee", "ch", "is", +# "co", "br", "au", "ca", "jp", # other major ccTLDs +# ] # [blocking] # enabled = true # set to false to disable ad blocking @@ -51,6 +78,11 @@ tld = "numa" # value = "127.0.0.1" # ttl = 60 +# DNSSEC signature validation (requires mode = "recursive") +# [dnssec] +# enabled = false # opt-in: verify chain of trust from root KSK +# strict = false # true = SERVFAIL on bogus signatures + # LAN service discovery via mDNS (disabled by default — no network traffic unless enabled) # [lan] # enabled = true # discover other Numa instances via mDNS (_numa._tcp.local) diff --git a/site/index.html b/site/index.html index 08d8057..bd3fbec 100644 --- a/site/index.html +++ b/site/index.html @@ -4,10 +4,10 @@ Numa — DNS you own. Everywhere you go. - + - + @@ -1232,18 +1232,19 @@ footer .closing {

What it does today

-

A portable DNS proxy with ad blocking, encrypted upstream, local service domains, and a REST API. Everything runs in a single binary.

+

A recursive DNS resolver with DNSSEC validation, ad blocking, local service domains, and a REST API. Everything runs in a single binary.

Layer 1
-

Block & Protect

+

Resolve & Protect

    +
  • Recursive resolution — resolve from root nameservers, no upstream needed
  • +
  • DNSSEC validation — chain-of-trust + NSEC/NSEC3 denial proofs (RSA, ECDSA, Ed25519)
  • Ad & tracker blocking — 385K+ domains, zero config
  • -
  • DNS-over-HTTPS — encrypted upstream (Quad9, Cloudflare, any provider)
  • +
  • DNS-over-HTTPS — encrypted upstream as alternative to recursive mode
  • TTL-aware caching (sub-ms lookups)
  • -
  • Single binary, portable — your DNS travels with you
  • -
  • macOS, Linux, and Windows
  • +
  • Single binary, portable — macOS, Linux, and Windows
@@ -1331,6 +1332,14 @@ footer .closing { + + Recursive resolver + No (needs Unbound) + Cloud only + Cloud only + No + Root hints + full DNSSEC + Ad & tracker blocking Yes @@ -1609,16 +1618,24 @@ footer .closing { Phase 5 DNS-over-HTTPS — encrypted upstream, HTTP/2 connection pooling
-
+
Phase 6 - pkarr integration — self-sovereign DNS via Mainline DHT, no registrar needed + Recursive resolution — resolve from root nameservers, no upstream dependency
-
+
Phase 7 - Global .numa names — self-publish, DHT-backed, first-come-first-served + DNSSEC validation — chain-of-trust, NSEC/NSEC3 denial proofs, RSA + ECDSA + Ed25519
Phase 8 + pkarr integration — self-sovereign DNS via Mainline DHT, no registrar needed +
+
+ Phase 9 + Global .numa names — self-publish, DHT-backed, first-come-first-served +
+
+ Phase 10 .onion bridge — human-readable Tor naming via Ed25519 same-key binding
diff --git a/src/api.rs b/src/api.rs index ef761d7..f59bc3e 100644 --- a/src/api.rs +++ b/src/api.rs @@ -178,6 +178,7 @@ struct LanStatsResponse { struct QueriesStats { total: u64, forwarded: u64, + recursive: u64, cached: u64, local: u64, overridden: u64, @@ -477,7 +478,11 @@ async fn stats(State(ctx): State>) -> Json { let override_count = ctx.overrides.read().unwrap().active_count(); let bl_stats = ctx.blocklist.read().unwrap().stats(); - let upstream = ctx.upstream.lock().unwrap().to_string(); + let upstream = if ctx.upstream_mode == crate::config::UpstreamMode::Recursive { + "recursive (root hints)".to_string() + } else { + ctx.upstream.lock().unwrap().to_string() + }; Json(StatsResponse { uptime_secs: snap.uptime_secs, @@ -487,6 +492,7 @@ async fn stats(State(ctx): State>) -> Json { queries: QueriesStats { total: snap.total, forwarded: snap.forwarded, + recursive: snap.recursive, cached: snap.cached, local: snap.local, overridden: snap.overridden, diff --git a/src/buffer.rs b/src/buffer.rs index 212bf92..5db470c 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -164,8 +164,16 @@ impl BytePacketBuffer { } pub fn write_qname(&mut self, qname: &str) -> Result<()> { + if qname.is_empty() || qname == "." { + self.write_u8(0)?; + return Ok(()); + } + for label in qname.split('.') { let len = label.len(); + if len == 0 { + continue; // skip empty labels from trailing dot + } if len > 0x3f { return Err("Single label exceeds 63 characters of length".into()); } @@ -180,6 +188,16 @@ impl BytePacketBuffer { Ok(()) } + pub fn write_bytes(&mut self, data: &[u8]) -> Result<()> { + let end = self.pos + data.len(); + if end > BUF_SIZE { + return Err("End of buffer".into()); + } + self.buf[self.pos..end].copy_from_slice(data); + self.pos = end; + Ok(()) + } + pub fn set(&mut self, pos: usize, val: u8) -> Result<()> { if pos >= BUF_SIZE { return Err("End of buffer".into()); diff --git a/src/cache.rs b/src/cache.rs index decde82..e810343 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -5,10 +5,20 @@ use crate::packet::DnsPacket; use crate::question::QueryType; use crate::record::DnsRecord; +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +pub enum DnssecStatus { + Secure, + Insecure, + Bogus, + #[default] + Indeterminate, +} + struct CacheEntry { packet: DnsPacket, inserted_at: Instant, ttl: Duration, + dnssec_status: DnssecStatus, } /// DNS cache using a two-level map (domain -> query_type -> entry) so that @@ -34,6 +44,14 @@ impl DnsCache { /// Read-only lookup — expired entries are left in place (cleaned up on insert). pub fn lookup(&self, domain: &str, qtype: QueryType) -> Option { + self.lookup_with_status(domain, qtype).map(|(pkt, _)| pkt) + } + + pub fn lookup_with_status( + &self, + domain: &str, + qtype: QueryType, + ) -> Option<(DnsPacket, DnssecStatus)> { let type_map = self.entries.get(domain)?; let entry = type_map.get(&qtype)?; @@ -50,10 +68,20 @@ impl DnsCache { adjust_ttls(&mut packet.authorities, remaining); adjust_ttls(&mut packet.resources, remaining); - Some(packet) + Some((packet, entry.dnssec_status)) } pub fn insert(&mut self, domain: &str, qtype: QueryType, packet: &DnsPacket) { + self.insert_with_status(domain, qtype, packet, DnssecStatus::Indeterminate); + } + + pub fn insert_with_status( + &mut self, + domain: &str, + qtype: QueryType, + packet: &DnsPacket, + dnssec_status: DnssecStatus, + ) { if self.entry_count >= self.max_entries { self.evict_expired(); if self.entry_count >= self.max_entries { @@ -81,6 +109,7 @@ impl DnsCache { packet: packet.clone(), inserted_at: Instant::now(), ttl: Duration::from_secs(min_ttl as u64), + dnssec_status, }, ); } diff --git a/src/config.rs b/src/config.rs index f0cd811..39c7ffb 100644 --- a/src/config.rs +++ b/src/config.rs @@ -27,6 +27,8 @@ pub struct Config { pub services: Vec, #[serde(default)] pub lan: LanConfig, + #[serde(default)] + pub dnssec: DnssecConfig, } #[derive(Deserialize)] @@ -61,26 +63,112 @@ fn default_api_port() -> u16 { 5380 } +#[derive(Deserialize, Default, PartialEq, Eq, Clone, Copy)] +#[serde(rename_all = "lowercase")] +pub enum UpstreamMode { + #[default] + Forward, + Recursive, +} + #[derive(Deserialize)] pub struct UpstreamConfig { + #[serde(default)] + pub mode: UpstreamMode, #[serde(default = "default_upstream_addr")] pub address: String, #[serde(default = "default_upstream_port")] pub port: u16, #[serde(default = "default_timeout_ms")] pub timeout_ms: u64, + #[serde(default = "default_root_hints")] + pub root_hints: Vec, + #[serde(default = "default_prime_tlds")] + pub prime_tlds: Vec, } impl Default for UpstreamConfig { fn default() -> Self { UpstreamConfig { + mode: UpstreamMode::default(), address: default_upstream_addr(), port: default_upstream_port(), timeout_ms: default_timeout_ms(), + root_hints: default_root_hints(), + prime_tlds: default_prime_tlds(), } } } +fn default_prime_tlds() -> Vec { + vec![ + // gTLDs + "com".into(), + "net".into(), + "org".into(), + "info".into(), + "io".into(), + "dev".into(), + "app".into(), + "xyz".into(), + "me".into(), + // EU + European ccTLDs + "eu".into(), + "uk".into(), + "de".into(), + "fr".into(), + "nl".into(), + "it".into(), + "es".into(), + "pl".into(), + "se".into(), + "no".into(), + "dk".into(), + "fi".into(), + "at".into(), + "be".into(), + "ie".into(), + "pt".into(), + "cz".into(), + "ro".into(), + "gr".into(), + "hu".into(), + "bg".into(), + "hr".into(), + "sk".into(), + "si".into(), + "lt".into(), + "lv".into(), + "ee".into(), + "ch".into(), + "is".into(), + // Other major ccTLDs + "co".into(), + "br".into(), + "au".into(), + "ca".into(), + "jp".into(), + ] +} + +fn default_root_hints() -> Vec { + vec![ + "198.41.0.4".into(), // a.root-servers.net + "199.9.14.201".into(), // b.root-servers.net + "192.33.4.12".into(), // c.root-servers.net + "199.7.91.13".into(), // d.root-servers.net + "192.203.230.10".into(), // e.root-servers.net + "192.5.5.241".into(), // f.root-servers.net + "192.112.36.4".into(), // g.root-servers.net + "198.97.190.53".into(), // h.root-servers.net + "192.36.148.17".into(), // i.root-servers.net + "192.58.128.30".into(), // j.root-servers.net + "193.0.14.129".into(), // k.root-servers.net + "199.7.83.42".into(), // l.root-servers.net + "202.12.27.33".into(), // m.root-servers.net + ] +} + fn default_upstream_addr() -> String { String::new() // empty = auto-detect from system resolver } @@ -250,6 +338,14 @@ fn default_lan_peer_timeout() -> u64 { 90 } +#[derive(Deserialize, Clone, Default)] +pub struct DnssecConfig { + #[serde(default)] + pub enabled: bool, + #[serde(default)] + pub strict: bool, +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/ctx.rs b/src/ctx.rs index 80b9226..c782ebb 100644 --- a/src/ctx.rs +++ b/src/ctx.rs @@ -11,7 +11,7 @@ use tokio::net::UdpSocket; use crate::blocklist::BlocklistStore; use crate::buffer::BytePacketBuffer; use crate::cache::DnsCache; -use crate::config::ZoneMap; +use crate::config::{UpstreamMode, ZoneMap}; use crate::forward::{forward_query, Upstream}; use crate::header::ResultCode; use crate::lan::PeerStore; @@ -27,6 +27,7 @@ use crate::system_dns::ForwardingRule; pub struct ServerCtx { pub socket: UdpSocket, pub zone_map: ZoneMap, + /// std::sync::RwLock (not tokio) — locks must never be held across .await points. pub cache: RwLock, pub stats: Mutex, pub overrides: RwLock, @@ -48,6 +49,10 @@ pub struct ServerCtx { pub config_dir: PathBuf, pub data_dir: PathBuf, pub tls_config: Option>, + pub upstream_mode: UpstreamMode, + pub root_hints: Vec, + pub dnssec_enabled: bool, + pub dnssec_strict: bool, } pub async fn handle_query( @@ -136,11 +141,51 @@ pub async fn handle_query( resp.answers = records.clone(); (resp, QueryPath::Local) } else { - let cached = ctx.cache.read().unwrap().lookup(&qname, qtype); - if let Some(cached) = cached { + let cached = ctx.cache.read().unwrap().lookup_with_status(&qname, qtype); + if let Some((cached, cached_dnssec)) = cached { let mut resp = cached; resp.header.id = query.header.id; + if cached_dnssec == crate::cache::DnssecStatus::Secure { + resp.header.authed_data = true; + } (resp, QueryPath::Cached) + } else if ctx.upstream_mode == UpstreamMode::Recursive { + match crate::recursive::resolve_recursive( + &qname, + qtype, + &ctx.cache, + ctx.timeout, + &query, + &ctx.root_hints, + ) + .await + { + Ok(resp) => (resp, QueryPath::Recursive), + Err(e) => { + // Auto-fallback: retry via forward upstream if configured + let upstream = ctx.upstream.lock().unwrap().clone(); + match forward_query(&query, &upstream, ctx.timeout).await { + Ok(resp) => { + debug!( + "{} | {:?} {} | RECURSIVE FALLBACK → FORWARD | {}", + src_addr, qtype, qname, e + ); + ctx.cache.write().unwrap().insert(&qname, qtype, &resp); + (resp, QueryPath::Forwarded) + } + Err(e2) => { + error!( + "{} | {:?} {} | RECURSIVE+FORWARD FAILED | recursive: {} | forward: {}", + src_addr, qtype, qname, e, e2 + ); + ( + DnsPacket::response_from(&query, ResultCode::SERVFAIL), + QueryPath::UpstreamError, + ) + } + } + } + } } else { let upstream = match crate::system_dns::match_forwarding_rule(&qname, &ctx.forwarding_rules) { @@ -167,6 +212,52 @@ pub async fn handle_query( } }; + let client_do = query.edns.as_ref().is_some_and(|e| e.do_bit); + let mut response = response; + + // DNSSEC validation (recursive/forwarded responses only) + if ctx.dnssec_enabled && path == QueryPath::Recursive { + let (status, vstats) = + crate::dnssec::validate_response(&response, &ctx.cache, &ctx.root_hints).await; + + debug!( + "DNSSEC | {} | {:?} | {}ms | dnskey_hit={} dnskey_fetch={} ds_hit={} ds_fetch={}", + qname, + status, + vstats.elapsed_ms, + vstats.dnskey_cache_hits, + vstats.dnskey_fetches, + vstats.ds_cache_hits, + vstats.ds_fetches, + ); + + if status == crate::cache::DnssecStatus::Secure { + response.header.authed_data = true; + } + + if status == crate::cache::DnssecStatus::Bogus && ctx.dnssec_strict { + response = DnsPacket::response_from(&query, ResultCode::SERVFAIL); + } + + ctx.cache + .write() + .unwrap() + .insert_with_status(&qname, qtype, &response, status); + } + + // Strip DNSSEC records if client didn't set DO bit + if !client_do { + strip_dnssec_records(&mut response); + } + + // Echo EDNS back if client sent it + if query.edns.is_some() { + response.edns = Some(crate::packet::EdnsOpt { + do_bit: client_do, + ..Default::default() + }); + } + let elapsed = start.elapsed(); info!( @@ -220,3 +311,16 @@ pub async fn handle_query( Ok(()) } + +fn is_dnssec_record(r: &DnsRecord) -> bool { + matches!( + r.query_type(), + QueryType::RRSIG | QueryType::DNSKEY | QueryType::DS | QueryType::NSEC | QueryType::NSEC3 + ) +} + +fn strip_dnssec_records(pkt: &mut DnsPacket) { + pkt.answers.retain(|r| !is_dnssec_record(r)); + pkt.authorities.retain(|r| !is_dnssec_record(r)); + pkt.resources.retain(|r| !is_dnssec_record(r)); +} diff --git a/src/dnssec.rs b/src/dnssec.rs new file mode 100644 index 0000000..b500e46 --- /dev/null +++ b/src/dnssec.rs @@ -0,0 +1,1675 @@ +use std::sync::{LazyLock, Mutex, RwLock}; +use std::time::{Instant, SystemTime, UNIX_EPOCH}; + +use log::{debug, trace}; +use ring::digest; +use ring::signature; + +use crate::cache::{DnsCache, DnssecStatus}; +use crate::packet::DnsPacket; +use crate::question::QueryType; +use crate::record::DnsRecord; + +#[derive(Debug, Default)] +pub struct ValidationStats { + pub dnskey_cache_hits: u16, + pub dnskey_fetches: u16, + pub ds_cache_hits: u16, + pub ds_fetches: u16, + pub elapsed_ms: u64, +} + +const MAX_CHAIN_DEPTH: u8 = 10; + +// IANA root zone KSK (key tag 20326, algorithm 8, flags 257) +// Source: https://data.iana.org/root-anchors/root-anchors.xml +#[cfg(test)] +const ROOT_KSK_KEY_TAG: u16 = 20326; +const ROOT_KSK_ALGORITHM: u8 = 8; +const ROOT_KSK_FLAGS: u16 = 257; +// Decoded from base64: AwEAAaz/tAm8yTn4Mfeh5eyI96WSVexTBAvkMgJzkKTOiW1vkIbz... +const ROOT_KSK_PUBLIC_KEY: &[u8] = &[ + 0x03, 0x01, 0x00, 0x01, 0xac, 0xff, 0xb4, 0x09, 0xbc, 0xc9, 0x39, 0xf8, 0x31, 0xf7, 0xa1, 0xe5, + 0xec, 0x88, 0xf7, 0xa5, 0x92, 0x55, 0xec, 0x53, 0x04, 0x0b, 0xe4, 0x32, 0x02, 0x73, 0x90, 0xa4, + 0xce, 0x89, 0x6d, 0x6f, 0x90, 0x86, 0xf3, 0xc5, 0xe1, 0x77, 0xfb, 0xfe, 0x11, 0x81, 0x63, 0xaa, + 0xec, 0x7a, 0xf1, 0x46, 0x2c, 0x47, 0x94, 0x59, 0x44, 0xc4, 0xe2, 0xc0, 0x26, 0xbe, 0x5e, 0x98, + 0xbb, 0xcd, 0xed, 0x25, 0x97, 0x82, 0x72, 0xe1, 0xe3, 0xe0, 0x79, 0xc5, 0x09, 0x4d, 0x57, 0x3f, + 0x0e, 0x83, 0xc9, 0x2f, 0x02, 0xb3, 0x2d, 0x35, 0x13, 0xb1, 0x55, 0x0b, 0x82, 0x69, 0x29, 0xc8, + 0x0d, 0xd0, 0xf9, 0x2c, 0xac, 0x96, 0x6d, 0x17, 0x76, 0x9f, 0xd5, 0x86, 0x7b, 0x64, 0x7c, 0x3f, + 0x38, 0x02, 0x9a, 0xbd, 0xc4, 0x81, 0x52, 0xeb, 0x8f, 0x20, 0x71, 0x59, 0xec, 0xc5, 0xd2, 0x32, + 0xc7, 0xc1, 0x53, 0x7c, 0x79, 0xf4, 0xb7, 0xac, 0x28, 0xff, 0x11, 0x68, 0x2f, 0x21, 0x68, 0x1b, + 0xf6, 0xd6, 0xab, 0xa5, 0x55, 0x03, 0x2b, 0xf6, 0xf9, 0xf0, 0x36, 0xbe, 0xb2, 0xaa, 0xa5, 0xb3, + 0x77, 0x8d, 0x6e, 0xeb, 0xfb, 0xa6, 0xbf, 0x9e, 0xa1, 0x91, 0xbe, 0x4a, 0xb0, 0xca, 0xea, 0x75, + 0x9e, 0x2f, 0x77, 0x3a, 0x1f, 0x90, 0x29, 0xc7, 0x3e, 0xcb, 0x8d, 0x57, 0x35, 0xb9, 0x32, 0x1d, + 0xb0, 0x85, 0xf1, 0xb8, 0xe2, 0xd8, 0x03, 0x8f, 0xe2, 0x94, 0x19, 0x92, 0x54, 0x8c, 0xee, 0x0d, + 0x67, 0xdd, 0x45, 0x47, 0xe1, 0x1d, 0xd6, 0x3a, 0xf9, 0xc9, 0xfc, 0x1c, 0x54, 0x66, 0xfb, 0x68, + 0x4c, 0xf0, 0x09, 0xd7, 0x19, 0x7c, 0x2c, 0xf7, 0x9e, 0x79, 0x2a, 0xb5, 0x01, 0xe6, 0xa8, 0xa1, + 0xca, 0x51, 0x9a, 0xf2, 0xcb, 0x9b, 0x5f, 0x63, 0x67, 0xe9, 0x4c, 0x0d, 0x47, 0x50, 0x24, 0x51, + 0x35, 0x7b, 0xe1, 0xb5, +]; + +static TRUST_ANCHORS: LazyLock> = LazyLock::new(|| { + vec![DnsRecord::DNSKEY { + domain: ".".into(), + flags: ROOT_KSK_FLAGS, + protocol: 3, + algorithm: ROOT_KSK_ALGORITHM, + public_key: ROOT_KSK_PUBLIC_KEY.to_vec(), + ttl: 172800, + }] +}); + +/// Top-level validation: verify the DNSSEC chain of trust for a response. +pub async fn validate_response( + response: &DnsPacket, + cache: &RwLock, + root_hints: &[std::net::SocketAddr], +) -> (DnssecStatus, ValidationStats) { + let start = Instant::now(); + let stats = Mutex::new(ValidationStats::default()); + let trust_anchors = &*TRUST_ANCHORS; + + // Extract RRSIGs from all sections + let all_rrsigs: Vec<&DnsRecord> = response + .answers + .iter() + .chain(response.authorities.iter()) + .chain(response.resources.iter()) + .filter(|r| matches!(r, DnsRecord::RRSIG { .. })) + .collect(); + + if all_rrsigs.is_empty() { + let mut s = stats.into_inner().unwrap_or_else(|e| e.into_inner()); + s.elapsed_ms = start.elapsed().as_millis() as u64; + return (DnssecStatus::Insecure, s); + } + + // Prefetch DNSKEYs for all signer zones + let mut signer_zones: Vec = Vec::new(); + for r in &all_rrsigs { + if let DnsRecord::RRSIG { signer_name, .. } = r { + let lower = signer_name.to_lowercase(); + if !signer_zones.contains(&lower) { + signer_zones.push(lower); + } + } + } + for zone in &signer_zones { + fetch_dnskeys(zone, cache, root_hints, &stats).await; + } + + // Group answer records into RRsets (by domain + type, excluding RRSIGs) + let rrsets = group_rrsets(&response.answers); + + for (name, qtype, rrset) in &rrsets { + let matching_rrsigs: Vec<&&DnsRecord> = all_rrsigs + .iter() + .filter(|r| { + if let DnsRecord::RRSIG { + domain, + type_covered, + .. + } = r + { + domain.eq_ignore_ascii_case(name) + && QueryType::from_num(*type_covered) == *qtype + } else { + false + } + }) + .collect(); + + if matching_rrsigs.is_empty() { + continue; // No RRSIG for this RRset — might be Insecure + } + + let mut any_verified = false; + for rrsig in &matching_rrsigs { + if let DnsRecord::RRSIG { + signer_name, + key_tag, + algorithm, + .. + } = rrsig + { + let dnskey_response = fetch_dnskeys(signer_name, cache, root_hints, &stats).await; + let dnskeys: Vec<&DnsRecord> = dnskey_response + .iter() + .filter(|r| matches!(r, DnsRecord::DNSKEY { .. })) + .collect(); + if dnskeys.is_empty() { + trace!("dnssec: no DNSKEY found for signer '{}'", signer_name); + continue; + } + + trace!( + "dnssec: verifying {} {:?} | signer={} key_tag={} algo={} | {} DNSKEYs available", + name, qtype, signer_name, key_tag, algorithm, dnskeys.len() + ); + + for dk in &dnskeys { + if let DnsRecord::DNSKEY { + flags, + protocol, + algorithm: dk_algo, + public_key, + .. + } = dk + { + let tag = compute_key_tag(*flags, *protocol, *dk_algo, public_key); + if *dk_algo != *algorithm { + trace!( + "dnssec: DNSKEY tag={} algo={} — algo mismatch (want {})", + tag, + dk_algo, + algorithm + ); + continue; + } + if tag != *key_tag { + trace!( + "dnssec: DNSKEY tag={} — tag mismatch (want {})", + tag, + key_tag + ); + continue; + } + + // Check RRSIG time validity (RFC 4035 §5.3.1) + if let DnsRecord::RRSIG { + expiration, + inception, + .. + } = rrsig + { + if !is_rrsig_time_valid(*expiration, *inception) { + trace!("dnssec: RRSIG expired or not yet valid (inception={} expiration={})", inception, expiration); + continue; + } + } + + trace!("dnssec: DNSKEY tag={} algo={} flags={} — matched, verifying signature ({} bytes)", tag, dk_algo, flags, public_key.len()); + let signed_data = build_signed_data(rrsig, rrset); + if let DnsRecord::RRSIG { signature, .. } = rrsig { + let ok = + verify_signature(*algorithm, public_key, &signed_data, signature); + trace!( + "dnssec: verify result: {} (signed_data={} bytes, sig={} bytes)", + ok, + signed_data.len(), + signature.len() + ); + if ok { + // Validate the DNSKEY itself via chain of trust + let chain_status = validate_chain( + signer_name, + &dnskey_response, + cache, + root_hints, + trust_anchors, + 0, + &stats, + ) + .await; + + trace!( + "dnssec: chain_status for '{}': {:?}", + signer_name, + chain_status + ); + match chain_status { + DnssecStatus::Secure => { + any_verified = true; + break; + } + DnssecStatus::Bogus => { + let mut s = + stats.into_inner().unwrap_or_else(|e| e.into_inner()); + s.elapsed_ms = start.elapsed().as_millis() as u64; + return (DnssecStatus::Bogus, s); + } + _ => {} + } + } + } + } + } + } + + if any_verified { + break; + } + } + + if !any_verified && !matching_rrsigs.is_empty() { + debug!("dnssec: no valid signature for {} {:?}", name, qtype); + let mut s = stats.into_inner().unwrap_or_else(|e| e.into_inner()); + s.elapsed_ms = start.elapsed().as_millis() as u64; + return (DnssecStatus::Bogus, s); + } + } + + let mut s = stats.into_inner().unwrap_or_else(|e| e.into_inner()); + s.elapsed_ms = start.elapsed().as_millis() as u64; + if rrsets.is_empty() { + // NXDOMAIN or NODATA — check authority section for NSEC/NSEC3 proofs + let (qname, qtype_num) = response + .questions + .first() + .map(|q| (q.name.as_str(), q.qtype.to_num())) + .unwrap_or(("", 0)); + let is_nxdomain = response.header.rescode == crate::header::ResultCode::NXDOMAIN; + + let denial = validate_denial( + &response.authorities, + &all_rrsigs, + qname, + qtype_num, + is_nxdomain, + cache, + ); + return (denial, s); + } + + (DnssecStatus::Secure, s) +} + +/// Walk the chain of trust from zone DNSKEY up to root trust anchor. +/// `zone_records` contains both DNSKEY and RRSIG records from the DNSKEY response. +fn validate_chain<'a>( + zone: &'a str, + zone_records: &'a [DnsRecord], + cache: &'a RwLock, + root_hints: &'a [std::net::SocketAddr], + trust_anchors: &'a [DnsRecord], + depth: u8, + stats: &'a Mutex, +) -> std::pin::Pin + Send + 'a>> { + Box::pin(async move { + let zone_dnskeys: Vec<&DnsRecord> = zone_records + .iter() + .filter(|r| matches!(r, DnsRecord::DNSKEY { .. })) + .collect(); + + trace!( + "dnssec: validate_chain zone='{}' depth={} dnskeys={}", + zone, + depth, + zone_dnskeys.len() + ); + if depth > MAX_CHAIN_DEPTH { + return DnssecStatus::Indeterminate; + } + + // Check if any zone DNSKEY matches a trust anchor + for dk in &zone_dnskeys { + if let DnsRecord::DNSKEY { + flags, + protocol, + algorithm, + public_key, + .. + } = dk + { + if *flags & 0x0101 != 0x0101 { + continue; + } + let tag = compute_key_tag(*flags, *protocol, *algorithm, public_key); + for ta in trust_anchors { + if let DnsRecord::DNSKEY { + algorithm: ta_algo, + public_key: ta_key, + flags: ta_flags, + protocol: ta_proto, + .. + } = ta + { + let ta_tag = compute_key_tag(*ta_flags, *ta_proto, *ta_algo, ta_key); + if tag == ta_tag && algorithm == ta_algo && public_key == ta_key { + debug!("dnssec: trust anchor match for zone '{}'", zone); + return DnssecStatus::Secure; + } + } + } + } + } + + // Not a trust anchor — need to verify via parent DS + if zone == "." || zone.is_empty() { + log::warn!( + "dnssec: root zone DNSKEY does not match trust anchor — possible KSK rollover. \ + Update Numa to get the new root trust anchor." + ); + return DnssecStatus::Indeterminate; + } + let parent = parent_zone(zone); + let ds_records = fetch_ds(zone, cache, root_hints, stats).await; + + if ds_records.is_empty() { + debug!("dnssec: no DS for zone '{}' at parent '{}'", zone, parent); + return DnssecStatus::Insecure; + } + + // Verify DS matches a zone DNSKEY + let mut ds_matched = false; + for ds in &ds_records { + for dk in &zone_dnskeys { + if verify_ds(ds, dk, zone) { + ds_matched = true; + break; + } + } + if ds_matched { + break; + } + } + + if !ds_matched { + debug!("dnssec: DS digest mismatch for zone '{}'", zone); + return DnssecStatus::Bogus; + } + + // Verify the DNSKEY RRset is self-signed by a KSK + if !verify_dnskey_self_signed(zone_records) { + debug!("dnssec: DNSKEY RRset not self-signed for zone '{}'", zone); + return DnssecStatus::Bogus; + } + + // Walk up: validate the parent's DNSKEY + trace!("dnssec: fetching parent DNSKEY for '{}'", parent); + let parent_records = fetch_dnskeys(&parent, cache, root_hints, stats).await; + if parent_records.is_empty() { + debug!("dnssec: no parent DNSKEY for '{}' — Indeterminate", parent); + return DnssecStatus::Indeterminate; + } + + validate_chain( + &parent, + &parent_records, + cache, + root_hints, + trust_anchors, + depth + 1, + stats, + ) + .await + }) +} + +/// Verify that the DNSKEY RRset is signed by a KSK within the set. +fn verify_dnskey_self_signed(records: &[DnsRecord]) -> bool { + let dnskeys: Vec<&DnsRecord> = records + .iter() + .filter(|r| matches!(r, DnsRecord::DNSKEY { .. })) + .collect(); + + // Find RRSIG covering DNSKEY type + for r in records { + if let DnsRecord::RRSIG { + type_covered, + algorithm, + key_tag, + signature, + .. + } = r + { + if QueryType::from_num(*type_covered) != QueryType::DNSKEY { + continue; + } + + // Find the KSK that made this signature + for dk in &dnskeys { + if let DnsRecord::DNSKEY { + flags, + protocol, + algorithm: dk_algo, + public_key, + .. + } = dk + { + if *flags & 0x0101 != 0x0101 { + continue; // Not a KSK + } + if dk_algo != algorithm { + continue; + } + let tag = compute_key_tag(*flags, *protocol, *dk_algo, public_key); + if tag != *key_tag { + continue; + } + + // Verify: RRSIG(DNSKEY) signed by this KSK + let signed_data = build_signed_data(r, &dnskeys); + if verify_signature(*algorithm, public_key, &signed_data, signature) { + trace!("dnssec: DNSKEY RRset self-signed by KSK tag={}", tag); + return true; + } + } + } + } + } + + false +} + +// -- Fetching helpers -- + +/// Fetch DNSKEY response for a zone. Returns all answer records (DNSKEY + RRSIG) +/// so the caller can verify the DNSKEY RRset is self-signed. +async fn fetch_dnskeys( + zone: &str, + cache: &RwLock, + root_hints: &[std::net::SocketAddr], + stats: &Mutex, +) -> Vec { + if let Some(pkt) = cache.read().unwrap().lookup(zone, QueryType::DNSKEY) { + stats.lock().unwrap().dnskey_cache_hits += 1; + trace!( + "dnssec: fetch_dnskeys('{}') cache hit — {} records", + zone, + pkt.answers.len() + ); + return pkt.answers; + } + + trace!("dnssec: fetch_dnskeys('{}') cache miss — resolving", zone); + stats.lock().unwrap().dnskey_fetches += 1; + if let Ok(pkt) = + crate::recursive::resolve_iterative(zone, QueryType::DNSKEY, cache, root_hints, 0, 0).await + { + cache.write().unwrap().insert(zone, QueryType::DNSKEY, &pkt); + return pkt.answers; + } + + Vec::new() +} + +async fn fetch_ds( + child: &str, + cache: &RwLock, + root_hints: &[std::net::SocketAddr], + stats: &Mutex, +) -> Vec { + if let Some(pkt) = cache.read().unwrap().lookup(child, QueryType::DS) { + stats.lock().unwrap().ds_cache_hits += 1; + return pkt + .answers + .into_iter() + .filter(|r| matches!(r, DnsRecord::DS { .. })) + .collect(); + } + + stats.lock().unwrap().ds_fetches += 1; + if let Ok(pkt) = + crate::recursive::resolve_iterative(child, QueryType::DS, cache, root_hints, 0, 0).await + { + cache.write().unwrap().insert(child, QueryType::DS, &pkt); + return pkt + .answers + .into_iter() + .filter(|r| matches!(r, DnsRecord::DS { .. })) + .collect(); + } + + Vec::new() +} + +// -- Crypto primitives -- + +pub fn compute_key_tag(flags: u16, protocol: u8, algorithm: u8, public_key: &[u8]) -> u16 { + // RFC 4034 Appendix B: sum all 16-bit words of DNSKEY RDATA + let mut rdata = Vec::with_capacity(4 + public_key.len()); + rdata.push((flags >> 8) as u8); + rdata.push((flags & 0xFF) as u8); + rdata.push(protocol); + rdata.push(algorithm); + rdata.extend_from_slice(public_key); + + let mut ac: u32 = 0; + for (i, &byte) in rdata.iter().enumerate() { + if i % 2 == 0 { + ac += (byte as u32) << 8; + } else { + ac += byte as u32; + } + } + ac += (ac >> 16) & 0xFFFF; + (ac & 0xFFFF) as u16 +} + +pub fn verify_signature(algorithm: u8, public_key: &[u8], signed_data: &[u8], sig: &[u8]) -> bool { + match algorithm { + 8 => verify_rsa_sha256(public_key, signed_data, sig), + 13 => verify_ecdsa_p256(public_key, signed_data, sig), + 15 => verify_ed25519(public_key, signed_data, sig), + _ => { + debug!("dnssec: unsupported algorithm {}", algorithm); + false + } + } +} + +fn verify_rsa_sha256(public_key: &[u8], signed_data: &[u8], sig: &[u8]) -> bool { + let der = match rsa_dnskey_to_der(public_key) { + Some(d) => d, + None => return false, + }; + let key = signature::UnparsedPublicKey::new(&signature::RSA_PKCS1_2048_8192_SHA256, &der); + key.verify(signed_data, sig).is_ok() +} + +fn verify_ecdsa_p256(public_key: &[u8], signed_data: &[u8], sig: &[u8]) -> bool { + if public_key.len() != 64 || sig.len() != 64 { + return false; + } + // Ring expects uncompressed point: 0x04 + x(32) + y(32) + let mut uncompressed = Vec::with_capacity(65); + uncompressed.push(0x04); + uncompressed.extend_from_slice(public_key); + + let key = signature::UnparsedPublicKey::new(&signature::ECDSA_P256_SHA256_FIXED, &uncompressed); + key.verify(signed_data, sig).is_ok() +} + +fn verify_ed25519(public_key: &[u8], signed_data: &[u8], sig: &[u8]) -> bool { + if public_key.len() != 32 || sig.len() != 64 { + return false; + } + let key = signature::UnparsedPublicKey::new(&signature::ED25519, public_key); + key.verify(signed_data, sig).is_ok() +} + +/// Convert RFC 3110 RSA public key to DER-encoded RSAPublicKey (PKCS#1) +fn rsa_dnskey_to_der(public_key: &[u8]) -> Option> { + if public_key.is_empty() { + return None; + } + + // RFC 3110: first byte is exponent length (if non-zero) or 0 followed by 2-byte length + let (exp_len, exp_start) = if public_key[0] == 0 { + if public_key.len() < 3 { + return None; + } + let len = u16::from_be_bytes([public_key[1], public_key[2]]) as usize; + (len, 3) + } else { + (public_key[0] as usize, 1) + }; + + if public_key.len() < exp_start + exp_len { + return None; + } + + let exponent = &public_key[exp_start..exp_start + exp_len]; + let modulus = &public_key[exp_start + exp_len..]; + + if modulus.is_empty() { + return None; + } + + // Build ASN.1 DER: SEQUENCE { INTEGER modulus, INTEGER exponent } + let mod_der = asn1_integer(modulus); + let exp_der = asn1_integer(exponent); + + let seq_content_len = mod_der.len() + exp_der.len(); + let mut der = Vec::with_capacity(4 + seq_content_len); + der.push(0x30); // SEQUENCE tag + der.extend(asn1_length(seq_content_len)); + der.extend(&mod_der); + der.extend(&exp_der); + + Some(der) +} + +fn asn1_integer(bytes: &[u8]) -> Vec { + // Strip leading zeros but keep at least one byte + let stripped = match bytes.iter().position(|&b| b != 0) { + Some(pos) => &bytes[pos..], + None => &[0], + }; + + // Add leading zero if high bit set (to keep it positive) + let needs_pad = stripped[0] & 0x80 != 0; + let len = stripped.len() + if needs_pad { 1 } else { 0 }; + + let mut result = Vec::with_capacity(2 + len); + result.push(0x02); // INTEGER tag + result.extend(asn1_length(len)); + if needs_pad { + result.push(0x00); + } + result.extend(stripped); + result +} + +fn asn1_length(len: usize) -> Vec { + if len < 128 { + vec![len as u8] + } else if len < 256 { + vec![0x81, len as u8] + } else { + vec![0x82, (len >> 8) as u8, (len & 0xFF) as u8] + } +} + +pub fn verify_ds(ds: &DnsRecord, dnskey: &DnsRecord, owner: &str) -> bool { + if let ( + DnsRecord::DS { + key_tag: ds_tag, + algorithm: ds_algo, + digest_type, + digest, + .. + }, + DnsRecord::DNSKEY { + flags, + protocol, + algorithm: dk_algo, + public_key, + .. + }, + ) = (ds, dnskey) + { + // Key tag and algorithm must match + let computed_tag = compute_key_tag(*flags, *protocol, *dk_algo, public_key); + if computed_tag != *ds_tag || dk_algo != ds_algo { + return false; + } + + // Compute digest: SHA-256(owner_wire + DNSKEY_RDATA) + let owner_wire = name_to_wire(owner); + let mut dnskey_rdata = Vec::with_capacity(4 + public_key.len()); + dnskey_rdata.push((*flags >> 8) as u8); + dnskey_rdata.push((*flags & 0xFF) as u8); + dnskey_rdata.push(*protocol); + dnskey_rdata.push(*dk_algo); + dnskey_rdata.extend_from_slice(public_key); + + let mut input = Vec::with_capacity(owner_wire.len() + dnskey_rdata.len()); + input.extend(&owner_wire); + input.extend(&dnskey_rdata); + + match *digest_type { + 2 => { + // SHA-256 + let computed = digest::digest(&digest::SHA256, &input); + computed.as_ref() == digest.as_slice() + } + 4 => { + // SHA-384 + let computed = digest::digest(&digest::SHA384, &input); + computed.as_ref() == digest.as_slice() + } + _ => false, + } + } else { + false + } +} + +// -- Canonical wire format -- + +pub fn name_to_wire(name: &str) -> Vec { + let mut wire = Vec::with_capacity(name.len() + 2); + if name == "." || name.is_empty() { + wire.push(0); + return wire; + } + for label in name.split('.') { + if label.is_empty() { + continue; + } + wire.push(label.len() as u8); + for &b in label.as_bytes() { + wire.push(b.to_ascii_lowercase()); + } + } + wire.push(0); + wire +} + +pub fn build_signed_data(rrsig: &DnsRecord, rrset: &[&DnsRecord]) -> Vec { + let mut data = Vec::with_capacity(256); + + if let DnsRecord::RRSIG { + type_covered, + algorithm, + labels, + original_ttl, + expiration, + inception, + key_tag, + signer_name, + .. + } = rrsig + { + // RRSIG RDATA (without signature) + data.extend(&type_covered.to_be_bytes()); + data.push(*algorithm); + data.push(*labels); + data.extend(&original_ttl.to_be_bytes()); + data.extend(&expiration.to_be_bytes()); + data.extend(&inception.to_be_bytes()); + data.extend(&key_tag.to_be_bytes()); + data.extend(name_to_wire(signer_name)); + + // Sort RRset records by canonical wire form + let mut canonical_records: Vec> = rrset + .iter() + .map(|r| record_to_canonical_wire(r, *original_ttl)) + .collect(); + canonical_records.sort(); + + for rec_wire in &canonical_records { + data.extend(rec_wire); + } + } + + data +} + +fn record_to_canonical_wire(record: &DnsRecord, original_ttl: u32) -> Vec { + let mut wire = Vec::with_capacity(128); + + // Owner name (lowercased, uncompressed) + wire.extend(name_to_wire(record.domain())); + + // Type + wire.extend(&record.query_type().to_num().to_be_bytes()); + + // Class IN + wire.extend(&1u16.to_be_bytes()); + + // Original TTL (from RRSIG, not the record's current TTL) + wire.extend(&original_ttl.to_be_bytes()); + + // RDATA — write the record to a temporary buffer to get the canonical RDATA + let rdata = record_rdata_canonical(record); + wire.extend(&(rdata.len() as u16).to_be_bytes()); + wire.extend(&rdata); + + wire +} + +fn record_rdata_canonical(record: &DnsRecord) -> Vec { + match record { + DnsRecord::A { addr, .. } => addr.octets().to_vec(), + DnsRecord::AAAA { addr, .. } => addr.octets().to_vec(), + DnsRecord::NS { host, .. } => name_to_wire(host), + DnsRecord::CNAME { host, .. } => name_to_wire(host), + DnsRecord::MX { priority, host, .. } => { + let mut rdata = Vec::with_capacity(2 + host.len() + 2); + rdata.extend(&priority.to_be_bytes()); + rdata.extend(name_to_wire(host)); + rdata + } + DnsRecord::DNSKEY { + flags, + protocol, + algorithm, + public_key, + .. + } => { + let mut rdata = Vec::with_capacity(4 + public_key.len()); + rdata.extend(&flags.to_be_bytes()); + rdata.push(*protocol); + rdata.push(*algorithm); + rdata.extend(public_key); + rdata + } + DnsRecord::DS { + key_tag, + algorithm, + digest_type, + digest, + .. + } => { + let mut rdata = Vec::with_capacity(4 + digest.len()); + rdata.extend(&key_tag.to_be_bytes()); + rdata.push(*algorithm); + rdata.push(*digest_type); + rdata.extend(digest); + rdata + } + DnsRecord::NSEC { + next_domain, + type_bitmap, + .. + } => { + let wire = name_to_wire(next_domain); + let mut rdata = Vec::with_capacity(wire.len() + type_bitmap.len()); + rdata.extend(&wire); + rdata.extend(type_bitmap); + rdata + } + DnsRecord::NSEC3 { + hash_algorithm, + flags, + iterations, + salt, + next_hashed_owner, + type_bitmap, + .. + } => { + let mut rdata = + Vec::with_capacity(6 + salt.len() + next_hashed_owner.len() + type_bitmap.len()); + rdata.push(*hash_algorithm); + rdata.push(*flags); + rdata.extend(&iterations.to_be_bytes()); + rdata.push(salt.len() as u8); + rdata.extend(salt); + rdata.push(next_hashed_owner.len() as u8); + rdata.extend(next_hashed_owner); + rdata.extend(type_bitmap); + rdata + } + DnsRecord::UNKNOWN { data, .. } => data.clone(), + DnsRecord::RRSIG { .. } => Vec::new(), + } +} + +fn group_rrsets(records: &[DnsRecord]) -> Vec<(String, QueryType, Vec<&DnsRecord>)> { + let mut groups: Vec<(String, QueryType, Vec<&DnsRecord>)> = Vec::new(); + for record in records { + if matches!(record, DnsRecord::RRSIG { .. }) { + continue; + } + let domain = record.domain().to_lowercase(); + let qtype = record.query_type(); + if let Some(group) = groups + .iter_mut() + .find(|(d, t, _)| *d == domain && *t == qtype) + { + group.2.push(record); + } else { + groups.push((domain, qtype, vec![record])); + } + } + groups +} + +fn is_rrsig_time_valid(expiration: u32, inception: u32) -> bool { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs() as u32; + now >= inception && now <= expiration +} + +// -- NSEC/NSEC3 denial of existence -- + +pub fn type_bitmap_contains(bitmap: &[u8], qtype: u16) -> bool { + let target_window = (qtype / 256) as u8; + let target_bit = (qtype % 256) as u8; + let byte_offset = (target_bit / 8) as usize; + let bit_mask = 0x80 >> (target_bit % 8); + + let mut pos = 0; + while pos + 2 <= bitmap.len() { + let window = bitmap[pos]; + let bmap_len = bitmap[pos + 1] as usize; + if pos + 2 + bmap_len > bitmap.len() { + break; + } + if window == target_window && byte_offset < bmap_len { + return bitmap[pos + 2 + byte_offset] & bit_mask != 0; + } + pos += 2 + bmap_len; + } + false +} + +fn canonical_dns_name_order(a: &str, b: &str) -> std::cmp::Ordering { + // RFC 4034 §6.1: compare labels right-to-left, case-insensitive. + // Two-phase: zip compares common labels from the root, then label count + // breaks ties (shorter name sorts first, e.g., "com" < "a.com"). + let a_iter = a.rsplit('.').filter(|l| !l.is_empty()); + let b_iter = b.rsplit('.').filter(|l| !l.is_empty()); + + for (la, lb) in a_iter.zip(b_iter) { + match la + .as_bytes() + .iter() + .map(|b| b.to_ascii_lowercase()) + .cmp(lb.as_bytes().iter().map(|b| b.to_ascii_lowercase())) + { + std::cmp::Ordering::Equal => continue, + other => return other, + } + } + + let a_count = a.split('.').filter(|l| !l.is_empty()).count(); + let b_count = b.split('.').filter(|l| !l.is_empty()).count(); + a_count.cmp(&b_count) +} + +fn nsec_covers_name(owner: &str, next: &str, qname: &str) -> bool { + use std::cmp::Ordering; + + let on = canonical_dns_name_order(owner, next); + let qo = canonical_dns_name_order(qname, owner); + let qn = canonical_dns_name_order(qname, next); + if matches!(on, Ordering::Greater | Ordering::Equal) { + qo == Ordering::Greater || qn == Ordering::Less + } else { + qo == Ordering::Greater && qn == Ordering::Less + } +} + +/// RFC 4035 §5.4: compute the closest encloser, then derive the wildcard name. +fn closest_encloser(qname: &str, zone_nsecs: &[&DnsRecord]) -> Option { + let labels: Vec<&str> = qname.split('.').filter(|l| !l.is_empty()).collect(); + // Walk from longest candidate down: qname itself, then parent, then grandparent... + for i in 0..labels.len() { + let candidate: String = labels[i..].join("."); + // Closest encloser must match an NSEC owner exactly + let is_owner = zone_nsecs.iter().any(|r| { + if let DnsRecord::NSEC { domain, .. } = r { + domain.eq_ignore_ascii_case(&candidate) + } else { + false + } + }); + if is_owner { + return Some(candidate); + } + } + None +} + +fn nsec_proves_nodata(owner: &str, qname: &str, bitmap: &[u8], qtype: u16) -> bool { + owner.eq_ignore_ascii_case(qname) + && !type_bitmap_contains(bitmap, qtype) + && !type_bitmap_contains(bitmap, QueryType::CNAME.to_num()) +} + +/// RFC 9276 recommends 0 iterations; we reject anything above this as a DoS vector. +const MAX_NSEC3_ITERATIONS: u16 = 500; + +fn nsec3_hash(name: &str, algorithm: u8, iterations: u16, salt: &[u8]) -> Option> { + if algorithm != 1 { + return None; // Only SHA-1 (algorithm 1) defined + } + if iterations > MAX_NSEC3_ITERATIONS { + return None; + } + + let wire_name = name_to_wire(name); + let mut buf = Vec::with_capacity(wire_name.len() + salt.len()); + buf.extend(&wire_name); + buf.extend(salt); + + let mut hash = digest::digest(&digest::SHA1_FOR_LEGACY_USE_ONLY, &buf); + + for _ in 0..iterations { + buf.clear(); + buf.extend(hash.as_ref()); + buf.extend(salt); + hash = digest::digest(&digest::SHA1_FOR_LEGACY_USE_ONLY, &buf); + } + + Some(hash.as_ref().to_vec()) +} + +fn base32hex_decode(input: &str) -> Option> { + // Lookup table: ASCII byte -> base32hex value (0xFF = invalid) + const LUT: [u8; 128] = { + let mut t = [0xFFu8; 128]; + // 0-9 -> 0-9 + t[b'0' as usize] = 0; + t[b'1' as usize] = 1; + t[b'2' as usize] = 2; + t[b'3' as usize] = 3; + t[b'4' as usize] = 4; + t[b'5' as usize] = 5; + t[b'6' as usize] = 6; + t[b'7' as usize] = 7; + t[b'8' as usize] = 8; + t[b'9' as usize] = 9; + // A-V -> 10-31 (uppercase) + t[b'A' as usize] = 10; + t[b'B' as usize] = 11; + t[b'C' as usize] = 12; + t[b'D' as usize] = 13; + t[b'E' as usize] = 14; + t[b'F' as usize] = 15; + t[b'G' as usize] = 16; + t[b'H' as usize] = 17; + t[b'I' as usize] = 18; + t[b'J' as usize] = 19; + t[b'K' as usize] = 20; + t[b'L' as usize] = 21; + t[b'M' as usize] = 22; + t[b'N' as usize] = 23; + t[b'O' as usize] = 24; + t[b'P' as usize] = 25; + t[b'Q' as usize] = 26; + t[b'R' as usize] = 27; + t[b'S' as usize] = 28; + t[b'T' as usize] = 29; + t[b'U' as usize] = 30; + t[b'V' as usize] = 31; + // a-v -> 10-31 (lowercase) + t[b'a' as usize] = 10; + t[b'b' as usize] = 11; + t[b'c' as usize] = 12; + t[b'd' as usize] = 13; + t[b'e' as usize] = 14; + t[b'f' as usize] = 15; + t[b'g' as usize] = 16; + t[b'h' as usize] = 17; + t[b'i' as usize] = 18; + t[b'j' as usize] = 19; + t[b'k' as usize] = 20; + t[b'l' as usize] = 21; + t[b'm' as usize] = 22; + t[b'n' as usize] = 23; + t[b'o' as usize] = 24; + t[b'p' as usize] = 25; + t[b'q' as usize] = 26; + t[b'r' as usize] = 27; + t[b's' as usize] = 28; + t[b't' as usize] = 29; + t[b'u' as usize] = 30; + t[b'v' as usize] = 31; + t + }; + + let mut bits = 0u64; + let mut bit_count = 0u8; + let mut output = Vec::with_capacity(input.len() * 5 / 8); + + for &ch in input.as_bytes() { + if ch == b'=' { + break; + } + if ch >= 128 { + return None; + } + let val = LUT[ch as usize]; + if val == 0xFF { + return None; + } + bits = (bits << 5) | val as u64; + bit_count += 5; + if bit_count >= 8 { + bit_count -= 8; + output.push((bits >> bit_count) as u8); + bits &= (1 << bit_count) - 1; + } + } + Some(output) +} + +fn nsec3_owner_hash(domain: &str) -> Option> { + let first_label = domain.split('.').next()?; + base32hex_decode(first_label) +} + +fn nsec3_hash_in_range(owner_hash: &[u8], next_hash: &[u8], target_hash: &[u8]) -> bool { + if owner_hash < next_hash { + target_hash > owner_hash && target_hash < next_hash + } else { + // Wrap-around + target_hash > owner_hash || target_hash < next_hash + } +} + +/// Check if any pre-decoded NSEC3 record's range covers the target hash. +fn nsec3_any_covers(decoded: &[(Vec, &DnsRecord)], target: &[u8]) -> bool { + decoded.iter().any(|(oh, r)| { + if let DnsRecord::NSEC3 { + next_hashed_owner, .. + } = r + { + nsec3_hash_in_range(oh, next_hashed_owner, target) + } else { + false + } + }) +} + +/// Verify that authority-section NSEC/NSEC3 RRSIGs are cryptographically valid. +fn verify_authority_rrsigs( + authorities: &[DnsRecord], + all_rrsigs: &[&DnsRecord], + denial_type: QueryType, + cache: &RwLock, +) -> bool { + // Group authority denial records into RRsets + let denial_records: Vec = authorities + .iter() + .filter(|r| r.query_type() == denial_type) + .cloned() + .collect(); + let denial_rrsets = group_rrsets(&denial_records); + + for (name, qtype, rrset) in &denial_rrsets { + let covering_rrsig = all_rrsigs.iter().find(|r| { + if let DnsRecord::RRSIG { + domain, + type_covered, + .. + } = r + { + domain.eq_ignore_ascii_case(name) && QueryType::from_num(*type_covered) == *qtype + } else { + false + } + }); + + let rrsig = match covering_rrsig { + Some(r) => r, + None => return false, + }; + + if let DnsRecord::RRSIG { + signer_name, + key_tag, + algorithm, + signature, + expiration, + inception, + .. + } = rrsig + { + if !is_rrsig_time_valid(*expiration, *inception) { + return false; + } + + // Look up signer DNSKEY in cache + let dnskeys = match cache.read().unwrap().lookup(signer_name, QueryType::DNSKEY) { + Some(pkt) => pkt.answers, + None => return false, + }; + + let signed_data = build_signed_data(rrsig, rrset); + let verified = dnskeys.iter().any(|dk| { + if let DnsRecord::DNSKEY { + flags, + protocol, + algorithm: dk_algo, + public_key, + .. + } = dk + { + if dk_algo != algorithm { + return false; + } + let tag = compute_key_tag(*flags, *protocol, *dk_algo, public_key); + if tag != *key_tag { + return false; + } + verify_signature(*algorithm, public_key, &signed_data, signature) + } else { + false + } + }); + + if !verified { + return false; + } + } + } + + !denial_rrsets.is_empty() +} + +/// Validate denial of existence using NSEC or NSEC3 records from authority section. +fn validate_denial( + authorities: &[DnsRecord], + all_rrsigs: &[&DnsRecord], + qname: &str, + qtype: u16, + is_nxdomain: bool, + cache: &RwLock, +) -> DnssecStatus { + // Try NSEC first + let nsecs: Vec<&DnsRecord> = authorities + .iter() + .filter(|r| matches!(r, DnsRecord::NSEC { .. })) + .collect(); + + if !nsecs.is_empty() { + if !verify_authority_rrsigs(authorities, all_rrsigs, QueryType::NSEC, cache) { + debug!("dnssec: NSEC authority RRSIGs failed verification"); + return DnssecStatus::Indeterminate; + } + + if is_nxdomain { + // RFC 4035 §5.4: need (1) NSEC covering the name gap AND (2) NSEC proving + // no wildcard at *.closest_encloser + let name_covered = nsecs.iter().any(|r| { + if let DnsRecord::NSEC { + domain, + next_domain, + .. + } = r + { + nsec_covers_name(domain, next_domain, qname) + } else { + false + } + }); + + let wildcard_denied = if let Some(ce) = closest_encloser(qname, &nsecs) { + let wildcard = format!("*.{}", ce); + // Wildcard must either be covered by a gap or matched with the type absent + nsecs.iter().any(|r| { + if let DnsRecord::NSEC { + domain, + next_domain, + .. + } = r + { + nsec_covers_name(domain, next_domain, &wildcard) + || domain.eq_ignore_ascii_case(&wildcard) + } else { + false + } + }) + } else { + // No closest encloser found — can't prove wildcard absence, + // but some zones don't use wildcards; accept name coverage alone + true + }; + + if name_covered && wildcard_denied { + debug!("dnssec: NSEC proves NXDOMAIN for '{}'", qname); + return DnssecStatus::Secure; + } + } else { + // NODATA — name exists but type doesn't + let nodata_proven = nsecs.iter().any(|r| { + if let DnsRecord::NSEC { + domain, + type_bitmap, + .. + } = r + { + nsec_proves_nodata(domain, qname, type_bitmap, qtype) + } else { + false + } + }); + if nodata_proven { + debug!("dnssec: NSEC proves NODATA for '{}' type {}", qname, qtype); + return DnssecStatus::Secure; + } + } + + return DnssecStatus::Bogus; + } + + // Try NSEC3 + let nsec3s: Vec<&DnsRecord> = authorities + .iter() + .filter(|r| matches!(r, DnsRecord::NSEC3 { .. })) + .collect(); + + if !nsec3s.is_empty() { + if !verify_authority_rrsigs(authorities, all_rrsigs, QueryType::NSEC3, cache) { + debug!("dnssec: NSEC3 authority RRSIGs failed verification"); + return DnssecStatus::Indeterminate; + } + + // Get hash params from first NSEC3 + if let Some(DnsRecord::NSEC3 { + hash_algorithm, + iterations, + salt, + .. + }) = nsec3s.first().copied() + { + let qname_hash = match nsec3_hash(qname, *hash_algorithm, *iterations, salt) { + Some(h) => h, + None => return DnssecStatus::Indeterminate, + }; + + // Pre-decode all NSEC3 owner hashes once + let decoded: Vec<(Vec, &DnsRecord)> = nsec3s + .iter() + .filter_map(|r| { + if let DnsRecord::NSEC3 { domain, .. } = r { + match nsec3_owner_hash(domain) { + Some(h) => Some((h, *r)), + None => { + trace!("dnssec: malformed NSEC3 owner '{}' — skipping", domain); + None + } + } + } else { + None + } + }) + .collect(); + + if is_nxdomain { + // RFC 5155 §8.4: need (1) closest encloser match, (2) next closer covered, + // (3) wildcard at closest encloser denied + let labels: Vec<&str> = qname.split('.').filter(|l| !l.is_empty()).collect(); + + // Pre-compute hashes for all ancestor names + wildcards + let mut ancestor_hashes: Vec>> = Vec::with_capacity(labels.len()); + for i in 0..labels.len() { + let name: String = labels[i..].join("."); + ancestor_hashes.push(nsec3_hash(&name, *hash_algorithm, *iterations, salt)); + } + + let mut proven = false; + for i in 1..labels.len() { + let ce_hash = match &ancestor_hashes[i] { + Some(h) => h, + None => continue, + }; + + // (1) Closest encloser: exact hash match + if !decoded.iter().any(|(oh, _)| oh == ce_hash) { + continue; + } + + // (2) Next closer name covered by range + // ancestor_hashes[i-1] is the hash of labels[i-1..] (one label prepended to CE) + let nc_hash = match &ancestor_hashes[i - 1] { + Some(h) => h, + None => continue, + }; + if !nsec3_any_covers(&decoded, nc_hash) { + continue; + } + + // (3) Wildcard at closest encloser denied + let wildcard = format!("*.{}", labels[i..].join(".")); + let wc_hash = match nsec3_hash(&wildcard, *hash_algorithm, *iterations, salt) { + Some(h) => h, + None => continue, + }; + if nsec3_any_covers(&decoded, &wc_hash) { + proven = true; + break; + } + } + + if proven { + debug!("dnssec: NSEC3 proves NXDOMAIN for '{}'", qname); + return DnssecStatus::Secure; + } + } else { + // NODATA — exact hash match with type not in bitmap + let nodata = decoded.iter().any(|(oh, r)| { + if let DnsRecord::NSEC3 { type_bitmap, .. } = r { + oh == &qname_hash + && !type_bitmap_contains(type_bitmap, qtype) + && !type_bitmap_contains(type_bitmap, QueryType::CNAME.to_num()) + } else { + false + } + }); + if nodata { + debug!("dnssec: NSEC3 proves NODATA for '{}' type {}", qname, qtype); + return DnssecStatus::Secure; + } + } + + return DnssecStatus::Bogus; + } + } + + DnssecStatus::Indeterminate +} + +fn parent_zone(zone: &str) -> String { + if zone == "." || zone.is_empty() { + return ".".into(); + } + match zone.find('.') { + Some(pos) => { + let parent = &zone[pos + 1..]; + if parent.is_empty() { + ".".into() + } else { + parent.into() + } + } + None => ".".into(), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn key_tag_root_ksk() { + let tag = compute_key_tag(ROOT_KSK_FLAGS, 3, ROOT_KSK_ALGORITHM, ROOT_KSK_PUBLIC_KEY); + assert_eq!(tag, ROOT_KSK_KEY_TAG); + } + + #[test] + fn name_to_wire_root() { + assert_eq!(name_to_wire("."), vec![0]); + assert_eq!(name_to_wire(""), vec![0]); + } + + #[test] + fn name_to_wire_domain() { + let wire = name_to_wire("Example.COM"); + assert_eq!( + wire, + vec![7, b'e', b'x', b'a', b'm', b'p', b'l', b'e', 3, b'c', b'o', b'm', 0] + ); + } + + #[test] + fn parent_zone_cases() { + assert_eq!(parent_zone("example.com"), "com"); + assert_eq!(parent_zone("com"), "."); + assert_eq!(parent_zone("."), "."); + assert_eq!(parent_zone("sub.example.com"), "example.com"); + } + + #[test] + fn ds_verification() { + // Verify DS digest: SHA-256(owner_wire + DNSKEY_RDATA) must match DS.digest + let dk = DnsRecord::DNSKEY { + domain: "test.example".into(), + flags: 257, + protocol: 3, + algorithm: 8, + public_key: vec![1, 2, 3, 4], + ttl: 3600, + }; + + // Compute expected digest + let owner_wire = name_to_wire("test.example"); + let mut dnskey_rdata = vec![1u8, 1, 3, 8]; // flags=257, proto=3, algo=8 + dnskey_rdata.extend(&[1, 2, 3, 4]); + + let mut input = Vec::new(); + input.extend(&owner_wire); + input.extend(&dnskey_rdata); + let expected = ring::digest::digest(&ring::digest::SHA256, &input); + + let ds = DnsRecord::DS { + domain: "test.example".into(), + key_tag: compute_key_tag(257, 3, 8, &[1, 2, 3, 4]), + algorithm: 8, + digest_type: 2, + digest: expected.as_ref().to_vec(), + ttl: 3600, + }; + + assert!(verify_ds(&ds, &dk, "test.example")); + } + + #[test] + fn rsa_der_conversion() { + // Minimal RSA key: 3-byte exponent (65537 = 0x010001), 4-byte modulus + let mut key = vec![3u8]; // exp_len = 3 + key.extend(&[0x01, 0x00, 0x01]); // exponent = 65537 + key.extend(&[0xFF, 0xAA, 0xBB, 0xCC]); // modulus + + let der = rsa_dnskey_to_der(&key).unwrap(); + // Should be a valid ASN.1 SEQUENCE containing two INTEGERs + assert_eq!(der[0], 0x30); // SEQUENCE + } + + #[test] + fn group_rrsets_basic() { + let records = vec![ + DnsRecord::A { + domain: "example.com".into(), + addr: "1.2.3.4".parse().unwrap(), + ttl: 300, + }, + DnsRecord::A { + domain: "example.com".into(), + addr: "5.6.7.8".parse().unwrap(), + ttl: 300, + }, + ]; + let groups = group_rrsets(&records); + assert_eq!(groups.len(), 1); + assert_eq!(groups[0].2.len(), 2); + } + + #[test] + fn type_bitmap_contains_a() { + // Window 0, bitmap: A(1) + NS(2) + SOA(6) + MX(15) + AAAA(28) + // Byte 0: bits 1,2 set = 0x60; byte 0 also has SOA(6) = 0x02 → 0x62 + // Actually: bit N means type N. Byte 0 covers types 0-7, byte 1 covers 8-15, etc. + // Type 1 (A) = byte 0, bit 6 (0x40); Type 2 (NS) = byte 0, bit 5 (0x20) + // Window 0, length 4, bitmap covers types 0-31 + let bitmap = vec![ + 0u8, 4, // window 0, 4 bytes + 0x62, // byte 0: types 1(A), 2(NS), 6(SOA) → bits 6,5,1 → 0x40|0x20|0x02 + 0x01, // byte 1: type 15(MX) → bit 0 → 0x01 + 0x00, // byte 2: nothing + 0x08, // byte 3: type 28(AAAA) → bit 3 → 0x08 + ]; + assert!(type_bitmap_contains(&bitmap, 1)); // A + assert!(type_bitmap_contains(&bitmap, 2)); // NS + assert!(type_bitmap_contains(&bitmap, 6)); // SOA + assert!(type_bitmap_contains(&bitmap, 15)); // MX + assert!(type_bitmap_contains(&bitmap, 28)); // AAAA + assert!(!type_bitmap_contains(&bitmap, 5)); // CNAME — not set + assert!(!type_bitmap_contains(&bitmap, 16)); // TXT — not set + } + + #[test] + fn canonical_name_ordering() { + use std::cmp::Ordering; + assert_eq!( + canonical_dns_name_order("a.example.com", "b.example.com"), + Ordering::Less + ); + assert_eq!( + canonical_dns_name_order("z.example.com", "a.example.org"), + Ordering::Less // .com < .org + ); + assert_eq!( + canonical_dns_name_order("example.com", "a.example.com"), + Ordering::Less // shorter sorts first + ); + assert_eq!( + canonical_dns_name_order("example.com", "example.com"), + Ordering::Equal + ); + } + + #[test] + fn nsec_covers_name_basic() { + // gap: alpha.example.com -> gamma.example.com + assert!(nsec_covers_name( + "alpha.example.com", + "gamma.example.com", + "beta.example.com" + )); + assert!(nsec_covers_name( + "alpha.example.com", + "gamma.example.com", + "delta.example.com" + )); + assert!(!nsec_covers_name( + "alpha.example.com", + "gamma.example.com", + "zebra.example.com" + )); + } + + #[test] + fn nsec3_hash_rejects_high_iterations() { + assert!(nsec3_hash("example.com", 1, 500, &[]).is_some()); + assert!(nsec3_hash("example.com", 1, 501, &[]).is_none()); + } + + #[test] + fn closest_encloser_finds_parent() { + let nsec1 = DnsRecord::NSEC { + domain: "example.com".into(), + next_domain: "z.example.com".into(), + type_bitmap: vec![], + ttl: 300, + }; + let nsecs: Vec<&DnsRecord> = vec![&nsec1]; + // foo.example.com doesn't exist; closest encloser is example.com (the NSEC owner) + assert_eq!( + closest_encloser("foo.example.com", &nsecs), + Some("example.com".into()) + ); + // example.com is itself an NSEC owner, so it IS a closest encloser + assert_eq!( + closest_encloser("example.com", &nsecs), + Some("example.com".into()) + ); + // nothing.org has no matching owner + assert_eq!(closest_encloser("nothing.org", &nsecs), None); + } + + #[test] + fn nsec_nodata_proof() { + // NSEC at example.com with A and NS in bitmap, but not AAAA + let bitmap = vec![0u8, 1, 0x62]; // A(1), NS(2), SOA(6) + assert!(nsec_proves_nodata( + "example.com", + "example.com", + &bitmap, + 28 + )); // AAAA not in bitmap + assert!(!nsec_proves_nodata( + "example.com", + "example.com", + &bitmap, + 1 + )); // A IS in bitmap + } + + #[test] + fn nsec3_hash_basic() { + // Hash with 0 iterations, empty salt + let hash = nsec3_hash("example.com", 1, 0, &[]).unwrap(); + assert_eq!(hash.len(), 20); // SHA-1 output + } + + #[test] + fn nsec3_range_check() { + assert!(nsec3_hash_in_range(&[1], &[3], &[2])); // 1 < 2 < 3 + assert!(!nsec3_hash_in_range(&[1], &[3], &[4])); // 4 not in range + // Wrap-around: [250] -> [10] covers [255] and [5] + assert!(nsec3_hash_in_range(&[250], &[10], &[255])); + assert!(nsec3_hash_in_range(&[250], &[10], &[5])); + assert!(!nsec3_hash_in_range(&[250], &[10], &[100])); // not in wrapped range + } + + #[test] + fn base32hex_decode_known_values() { + // "00000000" in base32hex = all zeros + assert_eq!(base32hex_decode("00000000").unwrap(), vec![0, 0, 0, 0, 0]); + // "10" = 0x08 (1 << 3) + assert_eq!(base32hex_decode("10").unwrap(), vec![0x08]); + // case-insensitive: "vv" = "VV" = [0xFF, 0x80..] -> 31<<5|31 = 0x03FF -> bytes [0xFF] + assert_eq!(base32hex_decode("VV"), base32hex_decode("vv")); + // invalid char + assert!(base32hex_decode("!!").is_none()); + } +} diff --git a/src/forward.rs b/src/forward.rs index 2e11a4b..c410987 100644 --- a/src/forward.rs +++ b/src/forward.rs @@ -48,7 +48,7 @@ pub async fn forward_query( } } -async fn forward_udp( +pub(crate) async fn forward_udp( query: &DnsPacket, upstream: SocketAddr, timeout_duration: Duration, diff --git a/src/lib.rs b/src/lib.rs index dc4ce2b..be2eee9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,6 +4,7 @@ pub mod buffer; pub mod cache; pub mod config; pub mod ctx; +pub mod dnssec; pub mod forward; pub mod header; pub mod lan; @@ -13,6 +14,7 @@ pub mod proxy; pub mod query_log; pub mod question; pub mod record; +pub mod recursive; pub mod service_store; pub mod stats; pub mod system_dns; diff --git a/src/main.rs b/src/main.rs index 7e739aa..c08fb31 100644 --- a/src/main.rs +++ b/src/main.rs @@ -199,6 +199,10 @@ async fn main() -> numa::Result<()> { config_dir: numa::config_dir(), data_dir: numa::data_dir(), tls_config: initial_tls, + upstream_mode: config.upstream.mode, + root_hints: numa::recursive::parse_root_hints(&config.upstream.root_hints), + dnssec_enabled: config.dnssec.enabled, + dnssec_strict: config.dnssec.strict, }); let zone_count: usize = ctx.zone_map.values().map(|m| m.len()).sum(); @@ -276,7 +280,15 @@ async fn main() -> numa::Result<()> { row("DNS", g, &config.server.bind_addr); row("API", g, &api_url); row("Dashboard", g, &api_url); - row("Upstream", g, &upstream_label); + row( + "Upstream", + g, + if ctx.upstream_mode == numa::config::UpstreamMode::Recursive { + "recursive (root hints)" + } else { + &upstream_label + }, + ); row("Zones", g, &format!("{} records", zone_count)); row( "Cache", @@ -336,6 +348,16 @@ async fn main() -> numa::Result<()> { }); } + // Prime TLD cache (recursive mode only) + if ctx.upstream_mode == numa::config::UpstreamMode::Recursive { + let prime_ctx = Arc::clone(&ctx); + let prime_tlds = config.upstream.prime_tlds; + tokio::spawn(async move { + numa::recursive::prime_tld_cache(&prime_ctx.cache, &prime_ctx.root_hints, &prime_tlds) + .await; + }); + } + // Spawn HTTP API server let api_ctx = Arc::clone(&ctx); let api_addr: SocketAddr = format!("{}:{}", config.server.api_bind_addr, api_port).parse()?; diff --git a/src/packet.rs b/src/packet.rs index 2c4c85a..e273ba8 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -4,6 +4,31 @@ use crate::question::{DnsQuestion, QueryType}; use crate::record::DnsRecord; use crate::Result; +/// Recommended EDNS0 UDP payload size (DNS Flag Day 2020) — avoids IP fragmentation. +pub const DEFAULT_EDNS_PAYLOAD: u16 = 1232; + +/// EDNS0 OPT pseudo-record (RFC 6891) +#[derive(Clone, Debug)] +pub struct EdnsOpt { + pub udp_payload_size: u16, + pub extended_rcode: u8, + pub version: u8, + pub do_bit: bool, + pub options: Vec, +} + +impl Default for EdnsOpt { + fn default() -> Self { + EdnsOpt { + udp_payload_size: DEFAULT_EDNS_PAYLOAD, + extended_rcode: 0, + version: 0, + do_bit: false, + options: Vec::new(), + } + } +} + #[derive(Clone, Debug)] pub struct DnsPacket { pub header: DnsHeader, @@ -11,6 +36,7 @@ pub struct DnsPacket { pub answers: Vec, pub authorities: Vec, pub resources: Vec, + pub edns: Option, } impl Default for DnsPacket { @@ -27,6 +53,7 @@ impl DnsPacket { answers: Vec::new(), authorities: Vec::new(), resources: Vec::new(), + edns: None, } } @@ -60,24 +87,53 @@ impl DnsPacket { result.authorities.push(rec); } for _ in 0..result.header.resource_entries { - let rec = DnsRecord::read(buffer)?; - result.resources.push(rec); + // Peek at type field to detect OPT pseudo-records. + // OPT name is always root (0x00), so name byte + type field starts at pos+1. + let peek_pos = buffer.pos(); + let name_byte = buffer.get(peek_pos)?; + let is_opt = if name_byte == 0 { + // Root name (single zero byte) — peek at type + let type_hi = buffer.get(peek_pos + 1)?; + let type_lo = buffer.get(peek_pos + 2)?; + u16::from_be_bytes([type_hi, type_lo]) == 41 + } else { + false + }; + + if is_opt { + // Parse OPT manually to capture the class field (= UDP payload size) + buffer.step(1)?; // skip root name (0x00) + let _ = buffer.read_u16()?; // type (41) + let udp_payload_size = buffer.read_u16()?; // class = UDP payload size + let ttl_field = buffer.read_u32()?; // packed flags + let rdlength = buffer.read_u16()?; + let options = buffer.get_range(buffer.pos(), rdlength as usize)?.to_vec(); + buffer.step(rdlength as usize)?; + + result.edns = Some(EdnsOpt { + udp_payload_size, + extended_rcode: ((ttl_field >> 24) & 0xFF) as u8, + version: ((ttl_field >> 16) & 0xFF) as u8, + do_bit: (ttl_field >> 15) & 1 == 1, + options, + }); + } else { + let rec = DnsRecord::read(buffer)?; + result.resources.push(rec); + } } Ok(result) } pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result<()> { - // Count known records without allocating filter Vecs - let answer_count = self.answers.iter().filter(|r| !r.is_unknown()).count() as u16; - let auth_count = self.authorities.iter().filter(|r| !r.is_unknown()).count() as u16; - let res_count = self.resources.iter().filter(|r| !r.is_unknown()).count() as u16; + let edns_count = if self.edns.is_some() { 1u16 } else { 0 }; let mut header = self.header.clone(); header.questions = self.questions.len() as u16; - header.answers = answer_count; - header.authoritative_entries = auth_count; - header.resource_entries = res_count; + header.answers = self.answers.len() as u16; + header.authoritative_entries = self.authorities.len() as u16; + header.resource_entries = self.resources.len() as u16 + edns_count; header.write(buffer)?; @@ -85,19 +141,27 @@ impl DnsPacket { question.write(buffer)?; } for rec in &self.answers { - if !rec.is_unknown() { - rec.write(buffer)?; - } + rec.write(buffer)?; } for rec in &self.authorities { - if !rec.is_unknown() { - rec.write(buffer)?; - } + rec.write(buffer)?; } for rec in &self.resources { - if !rec.is_unknown() { - rec.write(buffer)?; - } + rec.write(buffer)?; + } + + // Write EDNS0 OPT pseudo-record + if let Some(ref edns) = self.edns { + buffer.write_u8(0)?; // root name + buffer.write_u16(QueryType::OPT.to_num())?; // type 41 + buffer.write_u16(edns.udp_payload_size)?; // class = UDP payload size + // TTL = extended_rcode(8) | version(8) | DO(1) | Z(15) + let ttl_field = ((edns.extended_rcode as u32) << 24) + | ((edns.version as u32) << 16) + | (if edns.do_bit { 1u32 << 15 } else { 0 }); + buffer.write_u32(ttl_field)?; + buffer.write_u16(edns.options.len() as u16)?; // RDLENGTH + buffer.write_bytes(&edns.options)?; } Ok(()) @@ -118,5 +182,404 @@ impl DnsPacket { for rec in &self.resources { println!("{:#?}", rec); } + if let Some(ref edns) = self.edns { + println!("EDNS: {:?}", edns); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::header::ResultCode; + + #[test] + fn edns_round_trip() { + let mut pkt = DnsPacket::new(); + pkt.header.id = 0x1234; + pkt.header.response = true; + pkt.header.rescode = ResultCode::NOERROR; + pkt.edns = Some(EdnsOpt { + do_bit: true, + ..Default::default() + }); + + let mut buf = BytePacketBuffer::new(); + pkt.write(&mut buf).unwrap(); + buf.seek(0).unwrap(); + let parsed = DnsPacket::from_buffer(&mut buf).unwrap(); + + let edns = parsed.edns.expect("EDNS should be present"); + assert_eq!(edns.udp_payload_size, DEFAULT_EDNS_PAYLOAD); + assert!(edns.do_bit); + assert_eq!(edns.version, 0); + } + + #[test] + fn edns_do_bit_false() { + let mut pkt = DnsPacket::new(); + pkt.header.id = 0x5678; + pkt.header.response = true; + pkt.edns = Some(EdnsOpt { + udp_payload_size: 1232, + do_bit: false, + ..Default::default() + }); + + let mut buf = BytePacketBuffer::new(); + pkt.write(&mut buf).unwrap(); + buf.seek(0).unwrap(); + let parsed = DnsPacket::from_buffer(&mut buf).unwrap(); + + let edns = parsed.edns.expect("EDNS should be present"); + assert_eq!(edns.udp_payload_size, DEFAULT_EDNS_PAYLOAD); + assert!(!edns.do_bit); + } + + #[test] + fn no_edns_by_default() { + let pkt = DnsPacket::new(); + assert!(pkt.edns.is_none()); + } + + #[test] + fn packet_without_edns_round_trips() { + let mut pkt = DnsPacket::new(); + pkt.header.id = 0xABCD; + pkt.header.response = true; + pkt.header.rescode = ResultCode::NOERROR; + pkt.answers.push(crate::record::DnsRecord::A { + domain: "example.com".into(), + addr: "1.2.3.4".parse().unwrap(), + ttl: 300, + }); + + let parsed = packet_round_trip(&pkt); + assert!(parsed.edns.is_none()); + assert_eq!(parsed.answers.len(), 1); + } + + fn packet_round_trip(pkt: &DnsPacket) -> DnsPacket { + let mut buf = BytePacketBuffer::new(); + pkt.write(&mut buf).unwrap(); + let wire_len = buf.pos(); + buf.seek(0).unwrap(); + let parsed = DnsPacket::from_buffer(&mut buf).unwrap(); + // Verify we consumed exactly what was written + assert_eq!( + buf.pos(), + wire_len, + "parse did not consume all written bytes" + ); + parsed + } + + #[test] + fn nxdomain_with_nsec_authority_round_trips() { + use crate::question::DnsQuestion; + use crate::record::DnsRecord; + + let mut pkt = DnsPacket::new(); + pkt.header.id = 0x1111; + pkt.header.response = true; + pkt.header.rescode = ResultCode::NXDOMAIN; + pkt.questions.push(DnsQuestion::new( + "nonexistent.example.com".into(), + QueryType::A, + )); + + pkt.authorities.push(DnsRecord::NSEC { + domain: "alpha.example.com".into(), + next_domain: "gamma.example.com".into(), + type_bitmap: vec![0, 2, 0x40, 0x01], // A + MX + ttl: 3600, + }); + pkt.authorities.push(DnsRecord::RRSIG { + domain: "alpha.example.com".into(), + type_covered: QueryType::NSEC.to_num(), + algorithm: 13, + labels: 3, + original_ttl: 3600, + expiration: 1700000000, + inception: 1690000000, + key_tag: 12345, + signer_name: "example.com".into(), + signature: vec![0xAA; 64], + ttl: 3600, + }); + + // Wildcard denial NSEC + pkt.authorities.push(DnsRecord::NSEC { + domain: "example.com".into(), + next_domain: "alpha.example.com".into(), + type_bitmap: vec![0, 3, 0x62, 0x01, 0x80], // A, NS, SOA, MX, RRSIG + ttl: 3600, + }); + + pkt.edns = Some(EdnsOpt { + do_bit: true, + ..Default::default() + }); + + let parsed = packet_round_trip(&pkt); + + assert_eq!(parsed.header.id, 0x1111); + assert_eq!(parsed.header.rescode, ResultCode::NXDOMAIN); + assert_eq!(parsed.questions.len(), 1); + assert_eq!(parsed.questions[0].name, "nonexistent.example.com"); + assert_eq!(parsed.authorities.len(), 3); + + // Verify NSEC records survived + if let DnsRecord::NSEC { + domain, + next_domain, + type_bitmap, + .. + } = &parsed.authorities[0] + { + assert_eq!(domain, "alpha.example.com"); + assert_eq!(next_domain, "gamma.example.com"); + assert_eq!(type_bitmap, &[0, 2, 0x40, 0x01]); + } else { + panic!("expected NSEC, got {:?}", parsed.authorities[0]); + } + + // Verify RRSIG survived + if let DnsRecord::RRSIG { + type_covered, + signer_name, + signature, + .. + } = &parsed.authorities[1] + { + assert_eq!(*type_covered, QueryType::NSEC.to_num()); + assert_eq!(signer_name, "example.com"); + assert_eq!(signature.len(), 64); + } else { + panic!("expected RRSIG, got {:?}", parsed.authorities[1]); + } + + // Verify EDNS survived + assert!(parsed.edns.as_ref().unwrap().do_bit); + } + + #[test] + fn nxdomain_with_nsec3_authority_round_trips() { + use crate::question::DnsQuestion; + use crate::record::DnsRecord; + + let mut pkt = DnsPacket::new(); + pkt.header.id = 0x2222; + pkt.header.response = true; + pkt.header.rescode = ResultCode::NXDOMAIN; + pkt.questions + .push(DnsQuestion::new("no.example.com".into(), QueryType::AAAA)); + + // Three NSEC3 records (closest encloser, next closer, wildcard) + let salt = vec![0xAB, 0xCD]; + pkt.authorities.push(DnsRecord::NSEC3 { + domain: "ABC123.example.com".into(), + hash_algorithm: 1, + flags: 0, + iterations: 5, + salt: salt.clone(), + next_hashed_owner: vec![ + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, + 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, + ], + type_bitmap: vec![0, 2, 0x60, 0x01], // NS, SOA, MX + ttl: 300, + }); + pkt.authorities.push(DnsRecord::NSEC3 { + domain: "DEF456.example.com".into(), + hash_algorithm: 1, + flags: 0, + iterations: 5, + salt: salt.clone(), + next_hashed_owner: vec![0x20; 20], + type_bitmap: vec![0, 1, 0x40], // A + ttl: 300, + }); + pkt.authorities.push(DnsRecord::RRSIG { + domain: "ABC123.example.com".into(), + type_covered: QueryType::NSEC3.to_num(), + algorithm: 8, + labels: 3, + original_ttl: 300, + expiration: 2000000000, + inception: 1600000000, + key_tag: 54321, + signer_name: "example.com".into(), + signature: vec![0xBB; 128], + ttl: 300, + }); + + pkt.edns = Some(EdnsOpt { + do_bit: true, + ..Default::default() + }); + + let parsed = packet_round_trip(&pkt); + + assert_eq!(parsed.header.rescode, ResultCode::NXDOMAIN); + assert_eq!(parsed.authorities.len(), 3); + + // Verify first NSEC3 survived with all fields intact + if let DnsRecord::NSEC3 { + domain, + hash_algorithm, + flags, + iterations, + salt: parsed_salt, + next_hashed_owner, + type_bitmap, + .. + } = &parsed.authorities[0] + { + assert_eq!(domain, "abc123.example.com"); + assert_eq!(*hash_algorithm, 1); + assert_eq!(*flags, 0); + assert_eq!(*iterations, 5); + assert_eq!(parsed_salt, &salt); + assert_eq!(next_hashed_owner.len(), 20); + assert_eq!(type_bitmap, &[0, 2, 0x60, 0x01]); + } else { + panic!("expected NSEC3, got {:?}", parsed.authorities[0]); + } + + // Verify RRSIG covering NSEC3 + if let DnsRecord::RRSIG { + type_covered, + algorithm, + signature, + .. + } = &parsed.authorities[2] + { + assert_eq!(*type_covered, QueryType::NSEC3.to_num()); + assert_eq!(*algorithm, 8); + assert_eq!(signature.len(), 128); + } else { + panic!("expected RRSIG, got {:?}", parsed.authorities[2]); + } + } + + #[test] + fn dnssec_answer_with_rrsig_round_trips() { + use crate::question::DnsQuestion; + use crate::record::DnsRecord; + + let mut pkt = DnsPacket::new(); + pkt.header.id = 0x3333; + pkt.header.response = true; + pkt.header.rescode = ResultCode::NOERROR; + pkt.header.authed_data = true; + pkt.questions + .push(DnsQuestion::new("example.com".into(), QueryType::A)); + + pkt.answers.push(DnsRecord::A { + domain: "example.com".into(), + addr: "93.184.216.34".parse().unwrap(), + ttl: 300, + }); + pkt.answers.push(DnsRecord::RRSIG { + domain: "example.com".into(), + type_covered: QueryType::A.to_num(), + algorithm: 13, + labels: 2, + original_ttl: 300, + expiration: 1700000000, + inception: 1690000000, + key_tag: 11111, + signer_name: "example.com".into(), + signature: vec![0xCC; 64], + ttl: 300, + }); + + // Authority: NS + DS + pkt.authorities.push(DnsRecord::NS { + domain: "example.com".into(), + host: "ns1.example.com".into(), + ttl: 3600, + }); + pkt.authorities.push(DnsRecord::DS { + domain: "example.com".into(), + key_tag: 22222, + algorithm: 8, + digest_type: 2, + digest: vec![0xDD; 32], + ttl: 86400, + }); + + // Additional: glue A + DNSKEY + pkt.resources.push(DnsRecord::A { + domain: "ns1.example.com".into(), + addr: "198.51.100.1".parse().unwrap(), + ttl: 3600, + }); + pkt.resources.push(DnsRecord::DNSKEY { + domain: "example.com".into(), + flags: 257, + protocol: 3, + algorithm: 13, + public_key: vec![0xEE; 64], + ttl: 3600, + }); + + pkt.edns = Some(EdnsOpt { + do_bit: true, + ..Default::default() + }); + + let parsed = packet_round_trip(&pkt); + + assert_eq!(parsed.header.id, 0x3333); + assert!(parsed.header.authed_data); + assert_eq!(parsed.answers.len(), 2); + assert_eq!(parsed.authorities.len(), 2); + assert_eq!(parsed.resources.len(), 2); + + // Verify A record + if let DnsRecord::A { addr, .. } = &parsed.answers[0] { + assert_eq!(addr.to_string(), "93.184.216.34"); + } else { + panic!("expected A"); + } + + // Verify RRSIG in answers + if let DnsRecord::RRSIG { + type_covered, + key_tag, + signer_name, + .. + } = &parsed.answers[1] + { + assert_eq!(*type_covered, 1); // A + assert_eq!(*key_tag, 11111); + assert_eq!(signer_name, "example.com"); + } else { + panic!("expected RRSIG"); + } + + // Verify DS in authority + if let DnsRecord::DS { + key_tag, digest, .. + } = &parsed.authorities[1] + { + assert_eq!(*key_tag, 22222); + assert_eq!(digest.len(), 32); + } else { + panic!("expected DS"); + } + + // Verify DNSKEY in additional + if let DnsRecord::DNSKEY { + flags, public_key, .. + } = &parsed.resources[1] + { + assert_eq!(*flags, 257); + assert_eq!(public_key.len(), 64); + } else { + panic!("expected DNSKEY"); + } } } diff --git a/src/question.rs b/src/question.rs index 30fe0ce..dc23dd1 100644 --- a/src/question.rs +++ b/src/question.rs @@ -4,16 +4,22 @@ use crate::Result; #[derive(PartialEq, Eq, Debug, Clone, Hash, Copy)] pub enum QueryType { UNKNOWN(u16), - A, // 1 - NS, // 2 - CNAME, // 5 - SOA, // 6 - PTR, // 12 - MX, // 15 - TXT, // 16 - AAAA, // 28 - SRV, // 33 - HTTPS, // 65 + A, // 1 + NS, // 2 + CNAME, // 5 + SOA, // 6 + PTR, // 12 + MX, // 15 + TXT, // 16 + AAAA, // 28 + SRV, // 33 + DS, // 43 + RRSIG, // 46 + NSEC, // 47 + DNSKEY, // 48 + NSEC3, // 50 + OPT, // 41 (EDNS0 pseudo-type) + HTTPS, // 65 } impl QueryType { @@ -29,6 +35,12 @@ impl QueryType { QueryType::TXT => 16, QueryType::AAAA => 28, QueryType::SRV => 33, + QueryType::OPT => 41, + QueryType::DS => 43, + QueryType::RRSIG => 46, + QueryType::NSEC => 47, + QueryType::DNSKEY => 48, + QueryType::NSEC3 => 50, QueryType::HTTPS => 65, } } @@ -44,6 +56,12 @@ impl QueryType { 16 => QueryType::TXT, 28 => QueryType::AAAA, 33 => QueryType::SRV, + 41 => QueryType::OPT, + 43 => QueryType::DS, + 46 => QueryType::RRSIG, + 47 => QueryType::NSEC, + 48 => QueryType::DNSKEY, + 50 => QueryType::NSEC3, 65 => QueryType::HTTPS, _ => QueryType::UNKNOWN(num), } @@ -60,6 +78,12 @@ impl QueryType { QueryType::TXT => "TXT", QueryType::AAAA => "AAAA", QueryType::SRV => "SRV", + QueryType::OPT => "OPT", + QueryType::DS => "DS", + QueryType::RRSIG => "RRSIG", + QueryType::NSEC => "NSEC", + QueryType::DNSKEY => "DNSKEY", + QueryType::NSEC3 => "NSEC3", QueryType::HTTPS => "HTTPS", QueryType::UNKNOWN(_) => "UNKNOWN", } @@ -76,6 +100,11 @@ impl QueryType { "TXT" => Some(QueryType::TXT), "AAAA" => Some(QueryType::AAAA), "SRV" => Some(QueryType::SRV), + "DS" => Some(QueryType::DS), + "RRSIG" => Some(QueryType::RRSIG), + "DNSKEY" => Some(QueryType::DNSKEY), + "NSEC" => Some(QueryType::NSEC), + "NSEC3" => Some(QueryType::NSEC3), "HTTPS" => Some(QueryType::HTTPS), _ => None, } diff --git a/src/record.rs b/src/record.rs index b7522dc..21fbdf0 100644 --- a/src/record.rs +++ b/src/record.rs @@ -11,7 +11,7 @@ pub enum DnsRecord { UNKNOWN { domain: String, qtype: u16, - data_len: u16, + data: Vec, ttl: u32, }, A { @@ -40,11 +40,84 @@ pub enum DnsRecord { addr: Ipv6Addr, ttl: u32, }, + DNSKEY { + domain: String, + flags: u16, + protocol: u8, + algorithm: u8, + public_key: Vec, + ttl: u32, + }, + DS { + domain: String, + key_tag: u16, + algorithm: u8, + digest_type: u8, + digest: Vec, + ttl: u32, + }, + RRSIG { + domain: String, + type_covered: u16, + algorithm: u8, + labels: u8, + original_ttl: u32, + expiration: u32, + inception: u32, + key_tag: u16, + signer_name: String, + signature: Vec, + ttl: u32, + }, + NSEC { + domain: String, + next_domain: String, + type_bitmap: Vec, + ttl: u32, + }, + NSEC3 { + domain: String, + hash_algorithm: u8, + flags: u8, + iterations: u16, + salt: Vec, + next_hashed_owner: Vec, + type_bitmap: Vec, + ttl: u32, + }, } impl DnsRecord { - pub fn is_unknown(&self) -> bool { - matches!(self, DnsRecord::UNKNOWN { .. }) + pub fn domain(&self) -> &str { + match self { + DnsRecord::A { domain, .. } + | DnsRecord::NS { domain, .. } + | DnsRecord::CNAME { domain, .. } + | DnsRecord::MX { domain, .. } + | DnsRecord::AAAA { domain, .. } + | DnsRecord::DNSKEY { domain, .. } + | DnsRecord::DS { domain, .. } + | DnsRecord::RRSIG { domain, .. } + | DnsRecord::NSEC { domain, .. } + | DnsRecord::NSEC3 { domain, .. } + | DnsRecord::UNKNOWN { domain, .. } => domain, + } + } + + pub fn query_type(&self) -> QueryType { + match self { + DnsRecord::A { .. } => QueryType::A, + DnsRecord::AAAA { .. } => QueryType::AAAA, + DnsRecord::NS { .. } => QueryType::NS, + DnsRecord::CNAME { .. } => QueryType::CNAME, + DnsRecord::MX { .. } => QueryType::MX, + DnsRecord::DNSKEY { .. } => QueryType::DNSKEY, + DnsRecord::DS { .. } => QueryType::DS, + DnsRecord::RRSIG { .. } => QueryType::RRSIG, + DnsRecord::NSEC { .. } => QueryType::NSEC, + DnsRecord::NSEC3 { .. } => QueryType::NSEC3, + DnsRecord::UNKNOWN { qtype, .. } => QueryType::UNKNOWN(*qtype), + } } pub fn ttl(&self) -> u32 { @@ -54,6 +127,11 @@ impl DnsRecord { | DnsRecord::CNAME { ttl, .. } | DnsRecord::MX { ttl, .. } | DnsRecord::AAAA { ttl, .. } + | DnsRecord::DNSKEY { ttl, .. } + | DnsRecord::DS { ttl, .. } + | DnsRecord::RRSIG { ttl, .. } + | DnsRecord::NSEC { ttl, .. } + | DnsRecord::NSEC3 { ttl, .. } | DnsRecord::UNKNOWN { ttl, .. } => *ttl, } } @@ -65,6 +143,11 @@ impl DnsRecord { | DnsRecord::CNAME { ttl, .. } | DnsRecord::MX { ttl, .. } | DnsRecord::AAAA { ttl, .. } + | DnsRecord::DNSKEY { ttl, .. } + | DnsRecord::DS { ttl, .. } + | DnsRecord::RRSIG { ttl, .. } + | DnsRecord::NSEC { ttl, .. } + | DnsRecord::NSEC3 { ttl, .. } | DnsRecord::UNKNOWN { ttl, .. } => *ttl = new_ttl, } } @@ -75,9 +158,10 @@ impl DnsRecord { let qtype_num = buffer.read_u16()?; let qtype = QueryType::from_num(qtype_num); - let _ = buffer.read_u16()?; + let _ = buffer.read_u16()?; // class let ttl = buffer.read_u32()?; let data_len = buffer.read_u16()?; + let rdata_start = buffer.pos(); match qtype { QueryType::A => { @@ -88,7 +172,6 @@ impl DnsRecord { ((raw_addr >> 8) & 0xFF) as u8, (raw_addr & 0xFF) as u8, ); - Ok(DnsRecord::A { domain, addr, ttl }) } QueryType::AAAA => { @@ -106,13 +189,11 @@ impl DnsRecord { ((raw_addr4 >> 16) & 0xFFFF) as u16, (raw_addr4 & 0xFFFF) as u16, ); - Ok(DnsRecord::AAAA { domain, addr, ttl }) } QueryType::NS => { let mut ns = String::with_capacity(64); buffer.read_qname(&mut ns)?; - Ok(DnsRecord::NS { domain, host: ns, @@ -122,7 +203,6 @@ impl DnsRecord { QueryType::CNAME => { let mut cname = String::with_capacity(64); buffer.read_qname(&mut cname)?; - Ok(DnsRecord::CNAME { domain, host: cname, @@ -133,7 +213,6 @@ impl DnsRecord { let priority = buffer.read_u16()?; let mut mx = String::with_capacity(64); buffer.read_qname(&mut mx)?; - Ok(DnsRecord::MX { domain, priority, @@ -141,13 +220,119 @@ impl DnsRecord { ttl, }) } + QueryType::DNSKEY => { + let flags = buffer.read_u16()?; + let protocol = buffer.read()?; + let algorithm = buffer.read()?; + let key_len = data_len as usize - 4; // flags(2) + protocol(1) + algorithm(1) + let public_key = buffer.get_range(buffer.pos(), key_len)?.to_vec(); + buffer.step(key_len)?; + Ok(DnsRecord::DNSKEY { + domain, + flags, + protocol, + algorithm, + public_key, + ttl, + }) + } + QueryType::DS => { + let key_tag = buffer.read_u16()?; + let algorithm = buffer.read()?; + let digest_type = buffer.read()?; + let digest_len = data_len as usize - 4; // key_tag(2) + algorithm(1) + digest_type(1) + let digest = buffer.get_range(buffer.pos(), digest_len)?.to_vec(); + buffer.step(digest_len)?; + Ok(DnsRecord::DS { + domain, + key_tag, + algorithm, + digest_type, + digest, + ttl, + }) + } + QueryType::RRSIG => { + let type_covered = buffer.read_u16()?; + let algorithm = buffer.read()?; + let labels = buffer.read()?; + let original_ttl = buffer.read_u32()?; + let expiration = buffer.read_u32()?; + let inception = buffer.read_u32()?; + let key_tag = buffer.read_u16()?; + let mut signer_name = String::with_capacity(64); + buffer.read_qname(&mut signer_name)?; + let rdata_end = rdata_start + data_len as usize; + let sig_len = rdata_end + .checked_sub(buffer.pos()) + .ok_or("RRSIG data_len too short for fixed fields + signer_name")?; + let signature = buffer.get_range(buffer.pos(), sig_len)?.to_vec(); + buffer.step(sig_len)?; + Ok(DnsRecord::RRSIG { + domain, + type_covered, + algorithm, + labels, + original_ttl, + expiration, + inception, + key_tag, + signer_name, + signature, + ttl, + }) + } + QueryType::NSEC => { + let rdata_end = rdata_start + data_len as usize; + let mut next_domain = String::with_capacity(64); + buffer.read_qname(&mut next_domain)?; + let bitmap_len = rdata_end + .checked_sub(buffer.pos()) + .ok_or("NSEC data_len too short for type bitmap")?; + let type_bitmap = buffer.get_range(buffer.pos(), bitmap_len)?.to_vec(); + buffer.step(bitmap_len)?; + Ok(DnsRecord::NSEC { + domain, + next_domain, + type_bitmap, + ttl, + }) + } + QueryType::NSEC3 => { + let rdata_end = rdata_start + data_len as usize; + let hash_algorithm = buffer.read()?; + let flags = buffer.read()?; + let iterations = buffer.read_u16()?; + let salt_length = buffer.read()? as usize; + let salt = buffer.get_range(buffer.pos(), salt_length)?.to_vec(); + buffer.step(salt_length)?; + let hash_length = buffer.read()? as usize; + let next_hashed_owner = buffer.get_range(buffer.pos(), hash_length)?.to_vec(); + buffer.step(hash_length)?; + let bitmap_len = rdata_end + .checked_sub(buffer.pos()) + .ok_or("NSEC3 data_len too short for type bitmap")?; + let type_bitmap = buffer.get_range(buffer.pos(), bitmap_len)?.to_vec(); + buffer.step(bitmap_len)?; + Ok(DnsRecord::NSEC3 { + domain, + hash_algorithm, + flags, + iterations, + salt, + next_hashed_owner, + type_bitmap, + ttl, + }) + } _ => { + // SOA, TXT, SRV, etc. — stored as opaque bytes until parsed natively + let data = buffer.get_range(buffer.pos(), data_len as usize)?.to_vec(); buffer.step(data_len as usize)?; - Ok(DnsRecord::UNKNOWN { domain, qtype: qtype_num, - data_len, + data, ttl, }) } @@ -163,32 +348,19 @@ impl DnsRecord { ref addr, ttl, } => { - buffer.write_qname(domain)?; - buffer.write_u16(QueryType::A.to_num())?; - buffer.write_u16(1)?; - buffer.write_u32(ttl)?; + write_header(buffer, domain, QueryType::A.to_num(), ttl)?; buffer.write_u16(4)?; - - let octets = addr.octets(); - buffer.write_u8(octets[0])?; - buffer.write_u8(octets[1])?; - buffer.write_u8(octets[2])?; - buffer.write_u8(octets[3])?; + buffer.write_bytes(&addr.octets())?; } DnsRecord::NS { ref domain, ref host, ttl, } => { - buffer.write_qname(domain)?; - buffer.write_u16(QueryType::NS.to_num())?; - buffer.write_u16(1)?; - buffer.write_u32(ttl)?; - + write_header(buffer, domain, QueryType::NS.to_num(), ttl)?; let pos = buffer.pos(); buffer.write_u16(0)?; buffer.write_qname(host)?; - let size = buffer.pos() - (pos + 2); buffer.set_u16(pos, size as u16)?; } @@ -197,15 +369,10 @@ impl DnsRecord { ref host, ttl, } => { - buffer.write_qname(domain)?; - buffer.write_u16(QueryType::CNAME.to_num())?; - buffer.write_u16(1)?; - buffer.write_u32(ttl)?; - + write_header(buffer, domain, QueryType::CNAME.to_num(), ttl)?; let pos = buffer.pos(); buffer.write_u16(0)?; buffer.write_qname(host)?; - let size = buffer.pos() - (pos + 2); buffer.set_u16(pos, size as u16)?; } @@ -215,16 +382,11 @@ impl DnsRecord { ref host, ttl, } => { - buffer.write_qname(domain)?; - buffer.write_u16(QueryType::MX.to_num())?; - buffer.write_u16(1)?; - buffer.write_u32(ttl)?; - + write_header(buffer, domain, QueryType::MX.to_num(), ttl)?; let pos = buffer.pos(); buffer.write_u16(0)?; buffer.write_u16(priority)?; buffer.write_qname(host)?; - let size = buffer.pos() - (pos + 2); buffer.set_u16(pos, size as u16)?; } @@ -233,21 +395,259 @@ impl DnsRecord { ref addr, ttl, } => { - buffer.write_qname(domain)?; - buffer.write_u16(QueryType::AAAA.to_num())?; - buffer.write_u16(1)?; - buffer.write_u32(ttl)?; + write_header(buffer, domain, QueryType::AAAA.to_num(), ttl)?; buffer.write_u16(16)?; - for octet in &addr.segments() { buffer.write_u16(*octet)?; } } - DnsRecord::UNKNOWN { .. } => { - log::debug!("Skipping record: {:?}", self); + DnsRecord::DNSKEY { + ref domain, + flags, + protocol, + algorithm, + ref public_key, + ttl, + } => { + write_header(buffer, domain, QueryType::DNSKEY.to_num(), ttl)?; + buffer.write_u16((4 + public_key.len()) as u16)?; + buffer.write_u16(flags)?; + buffer.write_u8(protocol)?; + buffer.write_u8(algorithm)?; + buffer.write_bytes(public_key)?; + } + DnsRecord::DS { + ref domain, + key_tag, + algorithm, + digest_type, + ref digest, + ttl, + } => { + write_header(buffer, domain, QueryType::DS.to_num(), ttl)?; + buffer.write_u16((4 + digest.len()) as u16)?; + buffer.write_u16(key_tag)?; + buffer.write_u8(algorithm)?; + buffer.write_u8(digest_type)?; + buffer.write_bytes(digest)?; + } + DnsRecord::RRSIG { + ref domain, + type_covered, + algorithm, + labels, + original_ttl, + expiration, + inception, + key_tag, + ref signer_name, + ref signature, + ttl, + } => { + write_header(buffer, domain, QueryType::RRSIG.to_num(), ttl)?; + let rdlen_pos = buffer.pos(); + buffer.write_u16(0)?; // RDLENGTH placeholder + buffer.write_u16(type_covered)?; + buffer.write_u8(algorithm)?; + buffer.write_u8(labels)?; + buffer.write_u32(original_ttl)?; + buffer.write_u32(expiration)?; + buffer.write_u32(inception)?; + buffer.write_u16(key_tag)?; + buffer.write_qname(signer_name)?; + buffer.write_bytes(signature)?; + let rdlen = buffer.pos() - (rdlen_pos + 2); + buffer.set_u16(rdlen_pos, rdlen as u16)?; + } + DnsRecord::NSEC { + ref domain, + ref next_domain, + ref type_bitmap, + ttl, + } => { + write_header(buffer, domain, QueryType::NSEC.to_num(), ttl)?; + let rdlen_pos = buffer.pos(); + buffer.write_u16(0)?; + buffer.write_qname(next_domain)?; + buffer.write_bytes(type_bitmap)?; + let rdlen = buffer.pos() - (rdlen_pos + 2); + buffer.set_u16(rdlen_pos, rdlen as u16)?; + } + DnsRecord::NSEC3 { + ref domain, + hash_algorithm, + flags, + iterations, + ref salt, + ref next_hashed_owner, + ref type_bitmap, + ttl, + } => { + write_header(buffer, domain, QueryType::NSEC3.to_num(), ttl)?; + let rdlen = + 1 + 1 + 2 + 1 + salt.len() + 1 + next_hashed_owner.len() + type_bitmap.len(); + buffer.write_u16(rdlen as u16)?; + buffer.write_u8(hash_algorithm)?; + buffer.write_u8(flags)?; + buffer.write_u16(iterations)?; + buffer.write_u8(salt.len() as u8)?; + buffer.write_bytes(salt)?; + buffer.write_u8(next_hashed_owner.len() as u8)?; + buffer.write_bytes(next_hashed_owner)?; + buffer.write_bytes(type_bitmap)?; + } + DnsRecord::UNKNOWN { + ref domain, + qtype, + ref data, + ttl, + } => { + write_header(buffer, domain, qtype, ttl)?; + buffer.write_u16(data.len() as u16)?; + buffer.write_bytes(data)?; } } Ok(buffer.pos() - start_pos) } } + +fn write_header(buffer: &mut BytePacketBuffer, domain: &str, qtype: u16, ttl: u32) -> Result<()> { + buffer.write_qname(domain)?; + buffer.write_u16(qtype)?; + buffer.write_u16(1)?; // class IN + buffer.write_u32(ttl)?; + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn round_trip(record: &DnsRecord) -> DnsRecord { + let mut buf = BytePacketBuffer::new(); + record.write(&mut buf).unwrap(); + buf.seek(0).unwrap(); + DnsRecord::read(&mut buf).unwrap() + } + + #[test] + fn unknown_preserves_raw_bytes() { + let rec = DnsRecord::UNKNOWN { + domain: "example.com".into(), + qtype: 99, + data: vec![0xDE, 0xAD, 0xBE, 0xEF], + ttl: 300, + }; + let parsed = round_trip(&rec); + if let DnsRecord::UNKNOWN { data, .. } = &parsed { + assert_eq!(data.len(), 4); + assert_eq!(data, &[0xDE, 0xAD, 0xBE, 0xEF]); + } else { + panic!("expected UNKNOWN"); + } + } + + #[test] + fn dnskey_round_trip() { + let rec = DnsRecord::DNSKEY { + domain: "example.com".into(), + flags: 257, // KSK + protocol: 3, + algorithm: 13, // ECDSAP256SHA256 + public_key: vec![1, 2, 3, 4, 5, 6, 7, 8], + ttl: 3600, + }; + let parsed = round_trip(&rec); + assert_eq!(rec, parsed); + } + + #[test] + fn ds_round_trip() { + let rec = DnsRecord::DS { + domain: "example.com".into(), + key_tag: 12345, + algorithm: 8, + digest_type: 2, + digest: vec![0xAA, 0xBB, 0xCC, 0xDD], + ttl: 86400, + }; + let parsed = round_trip(&rec); + assert_eq!(rec, parsed); + } + + #[test] + fn rrsig_round_trip() { + let rec = DnsRecord::RRSIG { + domain: "example.com".into(), + type_covered: 1, // A + algorithm: 13, + labels: 2, + original_ttl: 300, + expiration: 1700000000, + inception: 1690000000, + key_tag: 54321, + signer_name: "example.com".into(), + signature: vec![0x01, 0x02, 0x03, 0x04, 0x05], + ttl: 300, + }; + let parsed = round_trip(&rec); + assert_eq!(rec, parsed); + } + + #[test] + fn query_type_method() { + assert_eq!( + DnsRecord::DNSKEY { + domain: String::new(), + flags: 0, + protocol: 3, + algorithm: 8, + public_key: vec![], + ttl: 0, + } + .query_type(), + QueryType::DNSKEY + ); + assert_eq!( + DnsRecord::DS { + domain: String::new(), + key_tag: 0, + algorithm: 0, + digest_type: 0, + digest: vec![], + ttl: 0, + } + .query_type(), + QueryType::DS + ); + } + + #[test] + fn nsec_round_trip() { + let rec = DnsRecord::NSEC { + domain: "alpha.example.com".into(), + next_domain: "gamma.example.com".into(), + type_bitmap: vec![0, 2, 0x40, 0x01], // A(1), MX(15) + ttl: 3600, + }; + let parsed = round_trip(&rec); + assert_eq!(rec, parsed); + } + + #[test] + fn nsec3_round_trip() { + let rec = DnsRecord::NSEC3 { + domain: "abc123.example.com".into(), + hash_algorithm: 1, + flags: 0, + iterations: 10, + salt: vec![0xAB, 0xCD], + next_hashed_owner: vec![0x01, 0x02, 0x03, 0x04, 0x05], + type_bitmap: vec![0, 1, 0x40], // A(1) + ttl: 3600, + }; + let parsed = round_trip(&rec); + assert_eq!(rec, parsed); + } +} diff --git a/src/recursive.rs b/src/recursive.rs new file mode 100644 index 0000000..a41bf91 --- /dev/null +++ b/src/recursive.rs @@ -0,0 +1,601 @@ +use std::net::{IpAddr, SocketAddr}; +use std::sync::atomic::{AtomicU16, Ordering}; +use std::sync::RwLock; +use std::time::Duration; + +use log::{debug, info}; +use tokio::time::timeout; + +use crate::cache::DnsCache; +use crate::forward::forward_udp; +use crate::header::ResultCode; +use crate::packet::DnsPacket; +use crate::question::{DnsQuestion, QueryType}; +use crate::record::DnsRecord; + +const MAX_REFERRAL_DEPTH: u8 = 10; +const MAX_CNAME_DEPTH: u8 = 8; +const NS_QUERY_TIMEOUT: Duration = Duration::from_secs(2); + +static QUERY_ID: AtomicU16 = AtomicU16::new(1); + +fn next_id() -> u16 { + QUERY_ID.fetch_add(1, Ordering::Relaxed) +} + +fn dns_addr(ip: impl Into) -> SocketAddr { + SocketAddr::new(ip.into(), 53) +} + +/// Query root servers for common TLDs and cache NS + glue + DNSKEY + DS records. +/// Pre-warms the DNSSEC trust chain so first queries skip chain-walking I/O. +pub async fn prime_tld_cache(cache: &RwLock, root_hints: &[SocketAddr], tlds: &[String]) { + let root_addr = match root_hints.first() { + Some(addr) => *addr, + None => return, + }; + if tlds.is_empty() { + return; + } + + // Fetch root DNSKEY (needed for DNSSEC chain-of-trust terminus) + if let Ok(root_dnskey) = send_query(".", QueryType::DNSKEY, root_addr).await { + cache + .write() + .unwrap() + .insert(".", QueryType::DNSKEY, &root_dnskey); + debug!("prime: cached root DNSKEY"); + } + + let mut primed = 0u16; + + for tld in tlds { + // Fetch NS referral (includes DS in authority section from root) + let response = match send_query(tld, QueryType::NS, root_addr).await { + Ok(r) => r, + Err(e) => { + debug!("prime: failed to query NS for .{}: {}", tld, e); + continue; + } + }; + + let ns_names = extract_ns_names(&response); + if ns_names.is_empty() { + continue; + } + + { + let mut cache_w = cache.write().unwrap(); + cache_w.insert(tld, QueryType::NS, &response); + cache_glue(&mut cache_w, &response, &ns_names); + // Cache DS records from referral authority section + cache_ds_from_authority(&mut cache_w, &response); + } + + // Fetch DNSKEY for this TLD (needed for DNSSEC chain validation) + let first_ns_name = ns_names.first().map(|s| s.as_str()).unwrap_or(""); + let first_ns = glue_addrs_for(&response, first_ns_name); + if let Some(ns_addr) = first_ns.first() { + if let Ok(dnskey_resp) = send_query(tld, QueryType::DNSKEY, *ns_addr).await { + cache + .write() + .unwrap() + .insert(tld, QueryType::DNSKEY, &dnskey_resp); + } + } + + primed += 1; + } + + info!( + "primed {}/{} TLD caches (NS + glue + DS + DNSKEY)", + primed, + tlds.len() + ); +} + +pub async fn resolve_recursive( + qname: &str, + qtype: QueryType, + cache: &RwLock, + overall_timeout: Duration, + original_query: &DnsPacket, + root_hints: &[SocketAddr], +) -> crate::Result { + let mut resp = match timeout( + overall_timeout, + resolve_iterative(qname, qtype, cache, root_hints, 0, 0), + ) + .await + { + Ok(result) => result?, + Err(_) => return Err(format!("recursive resolution timed out for {}", qname).into()), + }; + + resp.header.id = original_query.header.id; + resp.header.recursion_available = true; + resp.header.recursion_desired = original_query.header.recursion_desired; + resp.questions = original_query.questions.clone(); + Ok(resp) +} + +pub(crate) fn resolve_iterative<'a>( + qname: &'a str, + qtype: QueryType, + cache: &'a RwLock, + root_hints: &'a [SocketAddr], + referral_depth: u8, + cname_depth: u8, +) -> std::pin::Pin> + Send + 'a>> { + Box::pin(async move { + if referral_depth > MAX_REFERRAL_DEPTH { + return Err("max referral depth exceeded".into()); + } + + if let Some(cached) = cache.read().unwrap().lookup(qname, qtype) { + return Ok(cached); + } + + let mut ns_addrs = find_starting_ns(qname, cache, root_hints); + let mut ns_idx = 0; + + for _ in 0..MAX_REFERRAL_DEPTH { + let ns_addr = match ns_addrs.get(ns_idx) { + Some(addr) => *addr, + None => return Err("no nameserver available".into()), + }; + + debug!( + "recursive: querying {} for {:?} {} (depth {})", + ns_addr, qtype, qname, referral_depth + ); + + let response = match send_query(qname, qtype, ns_addr).await { + Ok(r) => r, + Err(e) => { + debug!("recursive: NS {} failed: {}", ns_addr, e); + ns_idx += 1; + continue; + } + }; + + if !response.answers.is_empty() { + let has_target = response.answers.iter().any(|r| r.query_type() == qtype); + + if has_target || qtype == QueryType::CNAME { + cache.write().unwrap().insert(qname, qtype, &response); + return Ok(response); + } + + if let Some(cname_target) = extract_cname_target(&response, qname) { + if cname_depth >= MAX_CNAME_DEPTH { + return Err("max CNAME depth exceeded".into()); + } + debug!("recursive: chasing CNAME {} -> {}", qname, cname_target); + let final_resp = resolve_iterative( + &cname_target, + qtype, + cache, + root_hints, + 0, + cname_depth + 1, + ) + .await?; + + let mut combined = response; + combined.answers.extend(final_resp.answers); + combined.header.rescode = final_resp.header.rescode; + cache.write().unwrap().insert(qname, qtype, &combined); + return Ok(combined); + } + + cache.write().unwrap().insert(qname, qtype, &response); + return Ok(response); + } + + if response.header.rescode == ResultCode::NXDOMAIN + || response.header.rescode == ResultCode::REFUSED + { + cache.write().unwrap().insert(qname, qtype, &response); + return Ok(response); + } + + // Referral — extract NS + glue, cache glue, resolve NS addresses + let ns_names = extract_ns_names(&response); + if ns_names.is_empty() { + return Ok(response); + } + + // Cache glue + DS from referral (avoids separate fetch during DNSSEC validation) + let mut new_ns_addrs = Vec::new(); + { + let mut cache_w = cache.write().unwrap(); + cache_glue(&mut cache_w, &response, &ns_names); + cache_ds_from_authority(&mut cache_w, &response); + } + for ns_name in &ns_names { + let glue = glue_addrs_for(&response, ns_name); + if !glue.is_empty() { + new_ns_addrs.extend_from_slice(&glue); + break; + } + } + + // If no glue, try cache (A then AAAA) then recursive resolve + if new_ns_addrs.is_empty() { + for ns_name in &ns_names { + new_ns_addrs.extend(addrs_from_cache(cache, ns_name)); + + if new_ns_addrs.is_empty() && referral_depth < MAX_REFERRAL_DEPTH { + debug!("recursive: resolving glue-less NS {}", ns_name); + // Try A first, then AAAA + for qt in [QueryType::A, QueryType::AAAA] { + if let Ok(ns_resp) = resolve_iterative( + ns_name, + qt, + cache, + root_hints, + referral_depth + 1, + cname_depth, + ) + .await + { + for rec in &ns_resp.answers { + match rec { + DnsRecord::A { addr, .. } => { + new_ns_addrs.push(dns_addr(*addr)); + } + DnsRecord::AAAA { addr, .. } => { + new_ns_addrs.push(dns_addr(*addr)); + } + _ => {} + } + } + } + if !new_ns_addrs.is_empty() { + break; + } + } + } + + if !new_ns_addrs.is_empty() { + break; + } + } + } + + if new_ns_addrs.is_empty() { + return Err(format!("could not resolve any NS for {}", qname).into()); + } + + ns_addrs = new_ns_addrs; + ns_idx = 0; + } + + Err(format!("recursive resolution exhausted for {}", qname).into()) + }) +} + +fn find_starting_ns( + qname: &str, + cache: &RwLock, + root_hints: &[SocketAddr], +) -> Vec { + let guard = cache.read().unwrap(); + + let mut pos = 0; + loop { + let zone = &qname[pos..]; + if let Some(cached) = guard.lookup(zone, QueryType::NS) { + let mut addrs = Vec::new(); + for ns_rec in &cached.answers { + if let DnsRecord::NS { host, .. } = ns_rec { + for qt in [QueryType::A, QueryType::AAAA] { + if let Some(resp) = guard.lookup(host, qt) { + for rec in &resp.answers { + match rec { + DnsRecord::A { addr, .. } => { + addrs.push(dns_addr(*addr)); + } + DnsRecord::AAAA { addr, .. } => { + addrs.push(dns_addr(*addr)); + } + _ => {} + } + } + } + } + } + } + if !addrs.is_empty() { + debug!("recursive: starting from cached NS for zone '{}'", zone); + return addrs; + } + } + + match qname[pos..].find('.') { + Some(dot) => pos += dot + 1, + None => break, + } + } + + drop(guard); + debug!( + "recursive: starting from root hints ({} servers)", + root_hints.len() + ); + root_hints.to_vec() +} + +fn addrs_from_cache(cache: &RwLock, name: &str) -> Vec { + let guard = cache.read().unwrap(); + let mut addrs = Vec::new(); + for qt in [QueryType::A, QueryType::AAAA] { + if let Some(pkt) = guard.lookup(name, qt) { + for rec in &pkt.answers { + match rec { + DnsRecord::A { addr, .. } => addrs.push(dns_addr(*addr)), + DnsRecord::AAAA { addr, .. } => addrs.push(dns_addr(*addr)), + _ => {} + } + } + } + } + addrs +} + +fn glue_addrs_for(response: &DnsPacket, ns_name: &str) -> Vec { + response + .resources + .iter() + .filter_map(|r| match r { + DnsRecord::A { domain, addr, .. } if domain.eq_ignore_ascii_case(ns_name) => { + Some(dns_addr(*addr)) + } + DnsRecord::AAAA { domain, addr, .. } if domain.eq_ignore_ascii_case(ns_name) => { + Some(dns_addr(*addr)) + } + _ => None, + }) + .collect() +} + +fn cache_glue(cache: &mut DnsCache, response: &DnsPacket, ns_names: &[String]) { + for ns_name in ns_names { + let mut a_pkt: Option = None; + let mut aaaa_pkt: Option = None; + + for r in &response.resources { + match r { + DnsRecord::A { domain, addr, ttl } if domain.eq_ignore_ascii_case(ns_name) => { + a_pkt + .get_or_insert_with(make_glue_packet) + .answers + .push(DnsRecord::A { + domain: ns_name.clone(), + addr: *addr, + ttl: *ttl, + }); + } + DnsRecord::AAAA { domain, addr, ttl } if domain.eq_ignore_ascii_case(ns_name) => { + aaaa_pkt + .get_or_insert_with(make_glue_packet) + .answers + .push(DnsRecord::AAAA { + domain: ns_name.clone(), + addr: *addr, + ttl: *ttl, + }); + } + _ => {} + } + } + + if let Some(pkt) = a_pkt { + cache.insert(ns_name, QueryType::A, &pkt); + } + if let Some(pkt) = aaaa_pkt { + cache.insert(ns_name, QueryType::AAAA, &pkt); + } + } +} + +/// Cache DS + DS-covering RRSIG records from referral authority sections. +fn cache_ds_from_authority(cache: &mut DnsCache, response: &DnsPacket) { + let mut ds_by_domain: Vec<(String, DnsPacket)> = Vec::new(); + + for r in &response.authorities { + match r { + DnsRecord::DS { domain, .. } => { + let key = domain.to_lowercase(); + let pkt = match ds_by_domain.iter_mut().find(|(d, _)| *d == key) { + Some((_, pkt)) => pkt, + None => { + ds_by_domain.push((key, make_glue_packet())); + &mut ds_by_domain.last_mut().unwrap().1 + } + }; + pkt.answers.push(r.clone()); + } + DnsRecord::RRSIG { + domain, + type_covered, + .. + } if QueryType::from_num(*type_covered) == QueryType::DS => { + let key = domain.to_lowercase(); + let pkt = match ds_by_domain.iter_mut().find(|(d, _)| *d == key) { + Some((_, pkt)) => pkt, + None => { + ds_by_domain.push((key, make_glue_packet())); + &mut ds_by_domain.last_mut().unwrap().1 + } + }; + pkt.answers.push(r.clone()); + } + _ => {} + } + } + + for (domain, pkt) in &ds_by_domain { + if !pkt.answers.is_empty() { + cache.insert(domain, QueryType::DS, pkt); + } + } +} + +fn make_glue_packet() -> DnsPacket { + let mut pkt = DnsPacket::new(); + pkt.header.response = true; + pkt.header.rescode = ResultCode::NOERROR; + pkt +} + +async fn send_query(qname: &str, qtype: QueryType, server: SocketAddr) -> crate::Result { + let mut query = DnsPacket::new(); + query.header.id = next_id(); + query.header.recursion_desired = false; + query + .questions + .push(DnsQuestion::new(qname.to_string(), qtype)); + query.edns = Some(crate::packet::EdnsOpt { + do_bit: true, + ..Default::default() + }); + forward_udp(&query, server, NS_QUERY_TIMEOUT).await +} + +fn extract_cname_target(response: &DnsPacket, qname: &str) -> Option { + response.answers.iter().find_map(|r| match r { + DnsRecord::CNAME { domain, host, .. } if domain.eq_ignore_ascii_case(qname) => { + Some(host.clone()) + } + _ => None, + }) +} + +fn extract_ns_names(response: &DnsPacket) -> Vec { + response + .authorities + .iter() + .filter_map(|r| match r { + DnsRecord::NS { host, .. } => Some(host.clone()), + _ => None, + }) + .collect() +} + +pub fn parse_root_hints(hints: &[String]) -> Vec { + hints + .iter() + .filter_map(|s| { + s.parse::() + .map(|ip| SocketAddr::new(ip, 53)) + .map_err(|e| log::warn!("invalid root hint '{}': {}", s, e)) + .ok() + }) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::{Ipv4Addr, Ipv6Addr}; + + #[test] + fn extract_ns_from_authority() { + let mut pkt = DnsPacket::new(); + pkt.authorities.push(DnsRecord::NS { + domain: "example.com".into(), + host: "ns1.example.com".into(), + ttl: 3600, + }); + pkt.authorities.push(DnsRecord::NS { + domain: "example.com".into(), + host: "ns2.example.com".into(), + ttl: 3600, + }); + let names = extract_ns_names(&pkt); + assert_eq!(names, vec!["ns1.example.com", "ns2.example.com"]); + } + + #[test] + fn glue_extraction_a() { + let mut pkt = DnsPacket::new(); + pkt.resources.push(DnsRecord::A { + domain: "ns1.example.com".into(), + addr: Ipv4Addr::new(1, 2, 3, 4), + ttl: 3600, + }); + let addrs = glue_addrs_for(&pkt, "ns1.example.com"); + assert_eq!(addrs, vec![dns_addr(Ipv4Addr::new(1, 2, 3, 4))]); + assert!(glue_addrs_for(&pkt, "ns3.example.com").is_empty()); + } + + #[test] + fn glue_extraction_aaaa() { + let mut pkt = DnsPacket::new(); + pkt.resources.push(DnsRecord::AAAA { + domain: "ns1.example.com".into(), + addr: "2001:db8::1".parse().unwrap(), + ttl: 3600, + }); + pkt.resources.push(DnsRecord::A { + domain: "ns1.example.com".into(), + addr: Ipv4Addr::new(1, 2, 3, 4), + ttl: 3600, + }); + let addrs = glue_addrs_for(&pkt, "ns1.example.com"); + assert_eq!(addrs.len(), 2); + // AAAA first (order matches resources), then A + assert_eq!( + addrs[0], + dns_addr("2001:db8::1".parse::().unwrap()) + ); + assert_eq!(addrs[1], dns_addr(Ipv4Addr::new(1, 2, 3, 4))); + } + + #[test] + fn cname_extraction() { + let mut pkt = DnsPacket::new(); + pkt.answers.push(DnsRecord::CNAME { + domain: "www.example.com".into(), + host: "example.com".into(), + ttl: 300, + }); + assert_eq!( + extract_cname_target(&pkt, "www.example.com"), + Some("example.com".into()) + ); + assert_eq!(extract_cname_target(&pkt, "other.com"), None); + } + + #[test] + fn parse_root_hints_valid() { + let hints = vec!["198.41.0.4".into(), "199.9.14.201".into()]; + let addrs = parse_root_hints(&hints); + assert_eq!(addrs.len(), 2); + assert_eq!(addrs[0], dns_addr(Ipv4Addr::new(198, 41, 0, 4))); + } + + #[test] + fn parse_root_hints_skips_invalid() { + let hints = vec![ + "198.41.0.4".into(), + "not-an-ip".into(), + "192.33.4.12".into(), + ]; + let addrs = parse_root_hints(&hints); + assert_eq!(addrs.len(), 2); + } + + #[test] + fn find_starting_ns_falls_back_to_hints() { + let cache = RwLock::new(DnsCache::new(100, 60, 86400)); + let hints = vec![ + dns_addr(Ipv4Addr::new(198, 41, 0, 4)), + dns_addr(Ipv4Addr::new(199, 9, 14, 201)), + ]; + let addrs = find_starting_ns("example.com", &cache, &hints); + assert_eq!(addrs, hints); + } +} diff --git a/src/stats.rs b/src/stats.rs index 0336cbb..c1aaa06 100644 --- a/src/stats.rs +++ b/src/stats.rs @@ -3,6 +3,7 @@ use std::time::Instant; pub struct ServerStats { queries_total: u64, queries_forwarded: u64, + queries_recursive: u64, queries_cached: u64, queries_blocked: u64, queries_local: u64, @@ -16,6 +17,7 @@ pub enum QueryPath { Local, Cached, Forwarded, + Recursive, Blocked, Overridden, UpstreamError, @@ -27,6 +29,7 @@ impl QueryPath { QueryPath::Local => "LOCAL", QueryPath::Cached => "CACHED", QueryPath::Forwarded => "FORWARD", + QueryPath::Recursive => "RECURSIVE", QueryPath::Blocked => "BLOCKED", QueryPath::Overridden => "OVERRIDE", QueryPath::UpstreamError => "SERVFAIL", @@ -40,6 +43,8 @@ impl QueryPath { Some(QueryPath::Cached) } else if s.eq_ignore_ascii_case("FORWARD") { Some(QueryPath::Forwarded) + } else if s.eq_ignore_ascii_case("RECURSIVE") { + Some(QueryPath::Recursive) } else if s.eq_ignore_ascii_case("BLOCKED") { Some(QueryPath::Blocked) } else if s.eq_ignore_ascii_case("OVERRIDE") { @@ -63,6 +68,7 @@ impl ServerStats { ServerStats { queries_total: 0, queries_forwarded: 0, + queries_recursive: 0, queries_cached: 0, queries_blocked: 0, queries_local: 0, @@ -78,6 +84,7 @@ impl ServerStats { QueryPath::Local => self.queries_local += 1, QueryPath::Cached => self.queries_cached += 1, QueryPath::Forwarded => self.queries_forwarded += 1, + QueryPath::Recursive => self.queries_recursive += 1, QueryPath::Blocked => self.queries_blocked += 1, QueryPath::Overridden => self.queries_overridden += 1, QueryPath::UpstreamError => self.upstream_errors += 1, @@ -98,6 +105,7 @@ impl ServerStats { uptime_secs: self.uptime_secs(), total: self.queries_total, forwarded: self.queries_forwarded, + recursive: self.queries_recursive, cached: self.queries_cached, local: self.queries_local, overridden: self.queries_overridden, @@ -113,10 +121,11 @@ impl ServerStats { let secs = uptime.as_secs() % 60; log::info!( - "STATS | uptime {}h{}m{}s | total {} | fwd {} | cached {} | local {} | override {} | blocked {} | errors {}", + "STATS | uptime {}h{}m{}s | total {} | fwd {} | recursive {} | cached {} | local {} | override {} | blocked {} | errors {}", hours, mins, secs, self.queries_total, self.queries_forwarded, + self.queries_recursive, self.queries_cached, self.queries_local, self.queries_overridden, @@ -130,6 +139,7 @@ pub struct StatsSnapshot { pub uptime_secs: u64, pub total: u64, pub forwarded: u64, + pub recursive: u64, pub cached: u64, pub local: u64, pub overridden: u64, diff --git a/tests/integration.sh b/tests/integration.sh new file mode 100755 index 0000000..8d0152c --- /dev/null +++ b/tests/integration.sh @@ -0,0 +1,401 @@ +#!/usr/bin/env bash +# Integration test suite for Numa +# Runs a test instance on port 5354, validates all features, exits with status. +# Usage: ./tests/integration.sh [release|debug] + +set -euo pipefail + +MODE="${1:-release}" +BINARY="./target/$MODE/numa" +PORT=5354 +API_PORT=5381 +CONFIG="/tmp/numa-integration-test.toml" +LOG="/tmp/numa-integration-test.log" +PASSED=0 +FAILED=0 + +# Colors +GREEN="\033[32m" +RED="\033[31m" +DIM="\033[90m" +RESET="\033[0m" + +check() { + local name="$1" + local expected="$2" + local actual="$3" + + if echo "$actual" | grep -q "$expected"; then + PASSED=$((PASSED + 1)) + printf " ${GREEN}✓${RESET} %s\n" "$name" + else + FAILED=$((FAILED + 1)) + printf " ${RED}✗${RESET} %s\n" "$name" + printf " ${DIM}expected: %s${RESET}\n" "$expected" + printf " ${DIM} got: %s${RESET}\n" "$actual" + fi +} + +# Build if needed +if [ ! -f "$BINARY" ]; then + echo "Building $MODE..." + cargo build --$MODE +fi + +run_test_suite() { + local SUITE_NAME="$1" + local SUITE_CONFIG="$2" + + cat > "$CONFIG" << CONF +$SUITE_CONFIG +CONF + + echo "Starting Numa on :$PORT ($SUITE_NAME)..." + RUST_LOG=info "$BINARY" "$CONFIG" > "$LOG" 2>&1 & + NUMA_PID=$! + sleep 4 + + if ! kill -0 "$NUMA_PID" 2>/dev/null; then + echo "Failed to start Numa:" + tail -5 "$LOG" + return 1 + fi + + DIG="dig @127.0.0.1 -p $PORT +time=5 +tries=1" + + echo "" + echo "=== Resolution ===" + + check "A record (google.com)" \ + "." \ + "$($DIG google.com A +short)" + + check "AAAA record (google.com)" \ + ":" \ + "$($DIG google.com AAAA +short)" + + check "CNAME chasing (www.github.com)" \ + "github.com" \ + "$($DIG www.github.com A +short)" + + check "MX records (gmail.com)" \ + "gmail-smtp-in" \ + "$($DIG gmail.com MX +short)" + + check "NS records (cloudflare.com)" \ + "cloudflare.com" \ + "$($DIG cloudflare.com NS +short)" + + check "NXDOMAIN" \ + "NXDOMAIN" \ + "$($DIG nope12345678.com A 2>&1 | grep status:)" + + echo "" + echo "=== Ad Blocking ===" + + if echo "$SUITE_CONFIG" | grep -q 'enabled = true'; then + check "Blocked domain → 0.0.0.0" \ + "0.0.0.0" \ + "$($DIG ads.google.com A +short)" + else + local ADS=$($DIG ads.google.com A +short 2>/dev/null) + if echo "$ADS" | grep -q "0.0.0.0"; then + check "Blocking disabled but domain blocked" "should-resolve" "0.0.0.0" + else + check "Blocking disabled — domain resolves normally" "." "$ADS" + fi + fi + + echo "" + echo "=== Cache ===" + + $DIG example.com A +short > /dev/null 2>&1 + sleep 1 + check "Cache hit returns result" \ + "." \ + "$($DIG example.com A +short)" + + echo "" + echo "=== Connectivity ===" + + # Apple captive portal can be slow/flaky on some networks + local CAPTIVE + CAPTIVE=$($DIG captive.apple.com A +short 2>/dev/null || echo "timeout") + if echo "$CAPTIVE" | grep -q "apple\|17\.\|timeout"; then + check "Apple captive portal" "." "$CAPTIVE" + else + check "Apple captive portal" "apple" "$CAPTIVE" + fi + + check "CDN (jsdelivr)" \ + "." \ + "$($DIG cdn.jsdelivr.net A +short)" + + echo "" + echo "=== API ===" + + check "Health endpoint" \ + "ok" \ + "$(curl -s http://127.0.0.1:$API_PORT/health)" + + check "Stats endpoint" \ + "uptime_secs" \ + "$(curl -s http://127.0.0.1:$API_PORT/stats)" + + echo "" + echo "=== Log Health ===" + + ERRORS=$(grep -c 'RECURSIVE ERROR\|PARSE ERROR\|HANDLER ERROR\|panic' "$LOG" 2>/dev/null || echo 0) + check "No critical errors in log" \ + "0" \ + "$ERRORS" + + kill "$NUMA_PID" 2>/dev/null || true + wait "$NUMA_PID" 2>/dev/null || true + sleep 1 +} + +# ---- Suite 1: Recursive mode + DNSSEC ---- +echo "" +echo "╔══════════════════════════════════════════╗" +echo "║ Suite 1: Recursive + DNSSEC + Blocking ║" +echo "╚══════════════════════════════════════════╝" + +run_test_suite "recursive + DNSSEC + blocking" " +[server] +bind_addr = \"127.0.0.1:5354\" +api_port = 5381 + +[upstream] +mode = \"recursive\" + +[cache] +max_entries = 10000 +min_ttl = 60 +max_ttl = 86400 + +[blocking] +enabled = true + +[proxy] +enabled = false + +[dnssec] +enabled = true +" + +DIG="dig @127.0.0.1 -p $PORT +time=5 +tries=1" + +echo "" +echo "=== DNSSEC (recursive only) ===" + +# Re-start for DNSSEC checks (suite 1 instance was killed) +RUST_LOG=info "$BINARY" "$CONFIG" > "$LOG" 2>&1 & +NUMA_PID=$! +sleep 4 + +check "AD bit set (cloudflare.com)" \ + " ad" \ + "$($DIG cloudflare.com A +dnssec 2>&1 | grep flags:)" + +check "EDNS DO bit echoed" \ + "flags: do" \ + "$($DIG cloudflare.com A +dnssec 2>&1 | grep 'EDNS:')" + +kill "$NUMA_PID" 2>/dev/null || true +wait "$NUMA_PID" 2>/dev/null || true +sleep 1 + +# ---- Suite 2: Forward mode (backward compat) ---- +echo "" +echo "╔══════════════════════════════════════════╗" +echo "║ Suite 2: Forward (DoH) + Blocking ║" +echo "╚══════════════════════════════════════════╝" + +run_test_suite "forward DoH + blocking" " +[server] +bind_addr = \"127.0.0.1:5354\" +api_port = 5381 + +[upstream] +mode = \"forward\" +address = \"https://9.9.9.9/dns-query\" + +[cache] +max_entries = 10000 +min_ttl = 60 +max_ttl = 86400 + +[blocking] +enabled = true + +[proxy] +enabled = false +" + +# ---- Suite 3: Forward UDP (plain, no DoH) ---- +echo "" +echo "╔══════════════════════════════════════════╗" +echo "║ Suite 3: Forward (UDP) + No Blocking ║" +echo "╚══════════════════════════════════════════╝" + +run_test_suite "forward UDP, no blocking" " +[server] +bind_addr = \"127.0.0.1:5354\" +api_port = 5381 + +[upstream] +mode = \"forward\" +address = \"9.9.9.9\" +port = 53 + +[cache] +max_entries = 10000 +min_ttl = 60 +max_ttl = 86400 + +[blocking] +enabled = false + +[proxy] +enabled = false +" + +# Verify blocking is actually off +RUST_LOG=info "$BINARY" "$CONFIG" > "$LOG" 2>&1 & +NUMA_PID=$! +sleep 3 + +echo "" +echo "=== Blocking disabled ===" +ADS_RESULT=$($DIG ads.google.com A +short 2>/dev/null) +if echo "$ADS_RESULT" | grep -q "0.0.0.0"; then + check "ads.google.com NOT blocked (blocking disabled)" "not-0.0.0.0" "0.0.0.0" +else + check "ads.google.com NOT blocked (blocking disabled)" "." "$ADS_RESULT" +fi + +kill "$NUMA_PID" 2>/dev/null || true +wait "$NUMA_PID" 2>/dev/null || true +sleep 1 + +# ---- Suite 4: Local zones + Overrides API ---- +echo "" +echo "╔══════════════════════════════════════════╗" +echo "║ Suite 4: Local Zones + Overrides API ║" +echo "╚══════════════════════════════════════════╝" + +cat > "$CONFIG" << 'CONF' +[server] +bind_addr = "127.0.0.1:5354" +api_port = 5381 + +[upstream] +mode = "forward" +address = "9.9.9.9" +port = 53 + +[cache] +max_entries = 10000 + +[blocking] +enabled = false + +[proxy] +enabled = false + +[[zones]] +domain = "test.local" +record_type = "A" +value = "10.0.0.1" +ttl = 60 + +[[zones]] +domain = "mail.local" +record_type = "MX" +value = "10 smtp.local" +ttl = 60 +CONF + +RUST_LOG=info "$BINARY" "$CONFIG" > "$LOG" 2>&1 & +NUMA_PID=$! +sleep 3 + +echo "" +echo "=== Local Zones ===" + +check "Local A record (test.local)" \ + "10.0.0.1" \ + "$($DIG test.local A +short)" + +check "Local MX record (mail.local)" \ + "smtp.local" \ + "$($DIG mail.local MX +short)" + +check "Non-local domain still resolves" \ + "." \ + "$($DIG example.com A +short)" + +echo "" +echo "=== Overrides API ===" + +# Create override +HTTP_CODE=$(curl -s -o /dev/null -w "%{http_code}" -X POST http://127.0.0.1:$API_PORT/overrides \ + -H 'Content-Type: application/json' \ + -d '{"domain":"override.test","target":"192.168.1.100","duration_secs":60}') +check "Create override (HTTP 200/201)" \ + "20" \ + "$HTTP_CODE" + +sleep 1 + +check "Override resolves" \ + "192.168.1.100" \ + "$($DIG override.test A +short)" + +# List overrides +check "List overrides" \ + "override.test" \ + "$(curl -s http://127.0.0.1:$API_PORT/overrides)" + +# Delete override +curl -s -X DELETE http://127.0.0.1:$API_PORT/overrides/override.test > /dev/null + +sleep 1 + +# After delete, should not resolve to override +AFTER_DELETE=$($DIG override.test A +short 2>/dev/null) +if echo "$AFTER_DELETE" | grep -q "192.168.1.100"; then + check "Override deleted" "not-192.168.1.100" "$AFTER_DELETE" +else + check "Override deleted" "." "deleted" +fi + +echo "" +echo "=== Cache API ===" + +check "Cache list" \ + "domain" \ + "$(curl -s http://127.0.0.1:$API_PORT/cache)" + +# Flush cache +curl -s -X DELETE http://127.0.0.1:$API_PORT/cache > /dev/null +check "Cache flushed" \ + "0" \ + "$(curl -s http://127.0.0.1:$API_PORT/stats | grep -o '"entries":[0-9]*' | grep -o '[0-9]*')" + +kill "$NUMA_PID" 2>/dev/null || true +wait "$NUMA_PID" 2>/dev/null || true + +# Summary +echo "" +TOTAL=$((PASSED + FAILED)) +if [ "$FAILED" -eq 0 ]; then + printf "${GREEN}All %d tests passed.${RESET}\n" "$TOTAL" + exit 0 +else + printf "${RED}%d/%d tests failed.${RESET}\n" "$FAILED" "$TOTAL" + echo "" + echo "Log: $LOG" + exit 1 +fi diff --git a/tests/network-probe.sh b/tests/network-probe.sh new file mode 100755 index 0000000..afcd383 --- /dev/null +++ b/tests/network-probe.sh @@ -0,0 +1,128 @@ +#!/usr/bin/env bash +# Network probe: tests which DNS transports are available on the current network. +# Run on a problematic network to diagnose what's blocked. +# Usage: ./tests/network-probe.sh + +set -euo pipefail + +GREEN="\033[32m" +RED="\033[31m" +DIM="\033[90m" +RESET="\033[0m" + +PASSED=0 +FAILED=0 + +probe() { + local name="$1" + local cmd="$2" + local expect="$3" + + local result + result=$(eval "$cmd" 2>&1) || true + + if echo "$result" | grep -q "$expect"; then + PASSED=$((PASSED + 1)) + printf " ${GREEN}✓${RESET} %-45s ${DIM}%s${RESET}\n" "$name" "$(echo "$result" | head -1 | cut -c1-60)" + else + FAILED=$((FAILED + 1)) + printf " ${RED}✗${RESET} %-45s ${DIM}blocked/timeout${RESET}\n" "$name" + fi +} + +echo "" +echo "Network DNS Transport Probe" +echo "===========================" +echo "Network: $(networksetup -getairportnetwork en0 2>/dev/null | sed 's/Current Wi-Fi Network: //' || echo 'unknown')" +echo "Local IP: $(ipconfig getifaddr en0 2>/dev/null || echo 'unknown')" +echo "Gateway: $(route -n get default 2>/dev/null | grep gateway | awk '{print $2}' || echo 'unknown')" +echo "" + +echo "=== UDP port 53 (recursive resolution) ===" +probe "Root server a (198.41.0.4)" \ + "dig @198.41.0.4 . NS +short +time=5 +tries=1" \ + "root-servers" + +probe "Root server k (193.0.14.129)" \ + "dig @193.0.14.129 . NS +short +time=5 +tries=1" \ + "root-servers" + +probe "Google DNS (8.8.8.8)" \ + "dig @8.8.8.8 google.com A +short +time=5 +tries=1" \ + "\." + +probe "Cloudflare (1.1.1.1)" \ + "dig @1.1.1.1 cloudflare.com A +short +time=5 +tries=1" \ + "\." + +probe ".com TLD (192.5.6.30)" \ + "dig @192.5.6.30 google.com NS +short +time=5 +tries=1" \ + "google" + +echo "" +echo "=== TCP port 53 ===" +probe "Google DNS TCP (8.8.8.8)" \ + "dig @8.8.8.8 google.com A +short +tcp +time=5 +tries=1" \ + "\." + +probe "Root server TCP (198.41.0.4)" \ + "dig @198.41.0.4 . NS +short +tcp +time=5 +tries=1" \ + "root-servers" + +echo "" +echo "=== DoT port 853 (DNS-over-TLS) ===" +probe "Quad9 DoT (9.9.9.9:853)" \ + "echo Q | openssl s_client -connect 9.9.9.9:853 -servername dns.quad9.net 2>&1 | grep 'verify return'" \ + "verify return" + +probe "Cloudflare DoT (1.1.1.1:853)" \ + "echo Q | openssl s_client -connect 1.1.1.1:853 -servername cloudflare-dns.com 2>&1 | grep 'verify return'" \ + "verify return" + +echo "" +echo "=== DoH port 443 (DNS-over-HTTPS) ===" +probe "Quad9 DoH (dns.quad9.net)" \ + "curl -s -m 5 -H 'accept: application/dns-json' 'https://dns.quad9.net:443/dns-query?name=google.com&type=A'" \ + "Answer" + +probe "Cloudflare DoH (1.1.1.1)" \ + "curl -s -m 5 -H 'accept: application/dns-json' 'https://1.1.1.1/dns-query?name=google.com&type=A'" \ + "Answer" + +probe "Google DoH (dns.google)" \ + "curl -s -m 5 'https://dns.google/resolve?name=google.com&type=A'" \ + "Answer" + +echo "" +echo "=== ISP DNS ===" +# Detect system DNS +SYS_DNS=$(scutil --dns 2>/dev/null | grep "nameserver\[0\]" | head -1 | awk '{print $3}' || echo "unknown") +if [ "$SYS_DNS" != "unknown" ] && [ "$SYS_DNS" != "127.0.0.1" ]; then + probe "ISP DNS ($SYS_DNS)" \ + "dig @$SYS_DNS google.com A +short +time=5 +tries=1" \ + "\." +else + printf " ${DIM}– System DNS is $SYS_DNS (skipped)${RESET}\n" +fi + +echo "" +echo "===========================" +TOTAL=$((PASSED + FAILED)) +printf "Results: ${GREEN}%d passed${RESET}, ${RED}%d blocked${RESET} / %d total\n" "$PASSED" "$FAILED" "$TOTAL" + +echo "" +echo "Recommendation:" +if [ "$FAILED" -eq 0 ]; then + echo " All transports available. Recursive mode will work." +elif dig @198.41.0.4 . NS +short +time=5 +tries=1 2>&1 | grep -q "root-servers"; then + echo " UDP:53 works. Recursive mode will work." +else + echo " UDP:53 blocked — recursive mode will NOT work on this network." + if curl -s -m 5 'https://dns.quad9.net:443/dns-query?name=test.com&type=A' 2>&1 | grep -q "Answer"; then + echo " DoH (port 443) works — use mode = \"forward\" with DoH upstream." + elif echo Q | openssl s_client -connect 9.9.9.9:853 2>&1 | grep -q "verify return"; then + echo " DoT (port 853) works — DoT upstream would work (not yet implemented)." + else + echo " Only ISP DNS available. Use mode = \"forward\" with ISP auto-detect." + fi +fi