From 777012958917d454a3323f177d272b660ff4972a Mon Sep 17 00:00:00 2001 From: Razvan Dimescu Date: Sat, 11 Apr 2026 01:14:04 +0300 Subject: [PATCH] =?UTF-8?q?feat:=20cache=20warming=20=E2=80=94=20proactive?= =?UTF-8?q?=20DNS=20resolution=20for=20configured=20domains=20(#78)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Resolves A + AAAA at startup for domains listed in [cache] warm, then re-resolves before TTL expiry (at 75% elapsed). Keeps critical domains always hot in cache with zero client-visible latency. Closes #34 (item 4) Co-authored-by: Claude Opus 4.6 --- src/cache.rs | 85 +++++++++++++++++++++++++++++++++++++++++++++++++++ src/config.rs | 3 ++ src/main.rs | 62 +++++++++++++++++++++++++++++++++++++ 3 files changed, 150 insertions(+) diff --git a/src/cache.rs b/src/cache.rs index d9a2a76..5bdde85 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -82,6 +82,29 @@ impl DnsCache { Some((packet, entry.dnssec_status)) } + pub fn ttl_remaining(&self, domain: &str, qtype: QueryType) -> Option<(u32, u32)> { + let type_map = self.entries.get(domain)?; + let entry = type_map.get(&qtype)?; + let elapsed = entry.inserted_at.elapsed(); + if elapsed >= entry.ttl { + return None; + } + let total = entry.ttl.as_secs() as u32; + let remaining = (entry.ttl - elapsed).as_secs() as u32; + Some((remaining, total)) + } + + pub fn needs_warm(&self, domain: &str) -> bool { + for qtype in [QueryType::A, QueryType::AAAA] { + match self.ttl_remaining(domain, qtype) { + None => return true, + Some((remaining, total)) if remaining < total / 4 => return true, + _ => {} + } + } + false + } + pub fn insert(&mut self, domain: &str, qtype: QueryType, packet: &DnsPacket) { self.insert_with_status(domain, qtype, packet, DnssecStatus::Indeterminate); } @@ -233,4 +256,66 @@ mod tests { cache.insert("example.com", QueryType::A, &pkt); assert!(cache.heap_bytes() > empty); } + + #[test] + fn ttl_remaining_returns_values_for_fresh_entry() { + let mut cache = DnsCache::new(100, 60, 3600); + let mut pkt = DnsPacket::new(); + pkt.answers.push(DnsRecord::A { + domain: "example.com".into(), + addr: "1.2.3.4".parse().unwrap(), + ttl: 300, + }); + cache.insert("example.com", QueryType::A, &pkt); + let (remaining, total) = cache.ttl_remaining("example.com", QueryType::A).unwrap(); + assert_eq!(total, 300); + assert!(remaining <= 300); + assert!(remaining > 0); + } + + #[test] + fn ttl_remaining_none_for_missing() { + let cache = DnsCache::new(100, 1, 3600); + assert!(cache.ttl_remaining("missing.com", QueryType::A).is_none()); + } + + #[test] + fn needs_warm_true_when_missing() { + let cache = DnsCache::new(100, 1, 3600); + assert!(cache.needs_warm("missing.com")); + } + + #[test] + fn needs_warm_false_when_fresh() { + let mut cache = DnsCache::new(100, 1, 3600); + let mut pkt_a = DnsPacket::new(); + pkt_a.answers.push(DnsRecord::A { + domain: "example.com".into(), + addr: "1.2.3.4".parse().unwrap(), + ttl: 300, + }); + let mut pkt_aaaa = DnsPacket::new(); + pkt_aaaa.answers.push(DnsRecord::AAAA { + domain: "example.com".into(), + addr: "::1".parse().unwrap(), + ttl: 300, + }); + cache.insert("example.com", QueryType::A, &pkt_a); + cache.insert("example.com", QueryType::AAAA, &pkt_aaaa); + assert!(!cache.needs_warm("example.com")); + } + + #[test] + fn needs_warm_true_when_only_a_cached() { + let mut cache = DnsCache::new(100, 1, 3600); + let mut pkt = DnsPacket::new(); + pkt.answers.push(DnsRecord::A { + domain: "example.com".into(), + addr: "1.2.3.4".parse().unwrap(), + ttl: 300, + }); + cache.insert("example.com", QueryType::A, &pkt); + // AAAA missing → needs warm + assert!(cache.needs_warm("example.com")); + } } diff --git a/src/config.rs b/src/config.rs index fa794d7..708ed4f 100644 --- a/src/config.rs +++ b/src/config.rs @@ -247,6 +247,8 @@ pub struct CacheConfig { pub min_ttl: u32, #[serde(default = "default_max_ttl")] pub max_ttl: u32, + #[serde(default)] + pub warm: Vec, } impl Default for CacheConfig { @@ -255,6 +257,7 @@ impl Default for CacheConfig { max_entries: default_max_entries(), min_ttl: default_min_ttl(), max_ttl: default_max_ttl(), + warm: Vec::new(), } } } diff --git a/src/main.rs b/src/main.rs index 9e2d2f8..cee680a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -402,6 +402,9 @@ async fn main() -> numa::Result<()> { g, &format!("max {} entries", config.cache.max_entries), ); + if !config.cache.warm.is_empty() { + row("Warm", g, &format!("{} domains", config.cache.warm.len())); + } row( "Blocking", g, @@ -484,6 +487,15 @@ async fn main() -> numa::Result<()> { }); } + // Spawn cache warming for user-configured domains + if !config.cache.warm.is_empty() { + let warm_ctx = Arc::clone(&ctx); + let warm_domains = config.cache.warm.clone(); + tokio::spawn(async move { + cache_warm_loop(warm_ctx, warm_domains).await; + }); + } + // Spawn HTTP API server let api_ctx = Arc::clone(&ctx); let api_addr: SocketAddr = format!("{}:{}", config.server.api_bind_addr, api_port).parse()?; @@ -720,3 +732,53 @@ async fn load_blocklists(ctx: &ServerCtx, lists: &[String]) { downloaded.len() ); } + +async fn warm_domain(ctx: &ServerCtx, domain: &str) { + use numa::question::QueryType; + + for qtype in [QueryType::A, QueryType::AAAA] { + let query = numa::packet::DnsPacket::query(0, domain, qtype); + let result = if ctx.upstream_mode == numa::config::UpstreamMode::Recursive { + numa::recursive::resolve_recursive( + domain, + qtype, + &ctx.cache, + &query, + &ctx.root_hints, + &ctx.srtt, + ) + .await + } else { + let pool = ctx.upstream_pool.lock().unwrap().clone(); + numa::forward::forward_with_failover(&query, &pool, &ctx.srtt, ctx.timeout).await + }; + match result { + Ok(resp) => { + ctx.cache.write().unwrap().insert(domain, qtype, &resp); + log::debug!("cache warm: {} {:?}", domain, qtype); + } + Err(e) => log::warn!("cache warm: {} {:?} failed: {}", domain, qtype, e), + } + } +} + +async fn cache_warm_loop(ctx: Arc, domains: Vec) { + tokio::time::sleep(Duration::from_secs(2)).await; + + for domain in &domains { + warm_domain(&ctx, domain).await; + } + info!("cache warm: {} domains resolved at startup", domains.len()); + + let mut interval = tokio::time::interval(Duration::from_secs(30)); + interval.tick().await; + loop { + interval.tick().await; + for domain in &domains { + let refresh = ctx.cache.read().unwrap().needs_warm(domain); + if refresh { + warm_domain(&ctx, domain).await; + } + } + } +}