diff --git a/src/cache.rs b/src/cache.rs index fb5889b..18fdc19 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -6,6 +6,22 @@ use crate::packet::DnsPacket; use crate::question::QueryType; use crate::wire::WireMeta; +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum Freshness { + /// Within TTL, no action needed. + Fresh, + /// Within TTL but <10% remaining — trigger background prefetch. + NearExpiry, + /// Past TTL but within stale window — serve with TTL=1, trigger background refresh. + Stale, +} + +impl Freshness { + pub fn needs_refresh(self) -> bool { + matches!(self, Freshness::NearExpiry | Freshness::Stale) + } +} + #[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] pub enum DnssecStatus { Secure, @@ -64,17 +80,21 @@ impl DnsCache { domain: &str, qtype: QueryType, new_id: u16, - ) -> Option<(Vec, DnssecStatus, bool)> { + ) -> Option<(Vec, DnssecStatus, Freshness)> { let type_map = self.entries.get(domain)?; let entry = type_map.get(&qtype)?; let elapsed = entry.inserted_at.elapsed(); - let (remaining, stale) = if elapsed < entry.ttl { + let (remaining, freshness) = if elapsed < entry.ttl { let secs = (entry.ttl - elapsed).as_secs() as u32; - let near_expiry = elapsed * 10 >= entry.ttl * 9; // <10% TTL remaining - (secs.max(1), near_expiry) + let f = if elapsed * 10 >= entry.ttl * 9 { + Freshness::NearExpiry + } else { + Freshness::Fresh + }; + (secs.max(1), f) } else if elapsed < entry.ttl + STALE_WINDOW { - (1, true) + (1, Freshness::Stale) } else { return None; }; @@ -83,7 +103,7 @@ impl DnsCache { crate::wire::patch_id(&mut wire, new_id); crate::wire::patch_ttls(&mut wire, &entry.meta.ttl_offsets, remaining); - Some((wire, entry.dnssec_status, stale)) + Some((wire, entry.dnssec_status, freshness)) } pub fn insert_wire( @@ -141,11 +161,11 @@ impl DnsCache { &self, domain: &str, qtype: QueryType, - ) -> Option<(DnsPacket, DnssecStatus, bool)> { - let (wire, status, stale) = self.lookup_wire(domain, qtype, 0)?; + ) -> Option<(DnsPacket, DnssecStatus, Freshness)> { + let (wire, status, freshness) = self.lookup_wire(domain, qtype, 0)?; let mut buf = BytePacketBuffer::from_bytes(&wire); let pkt = DnsPacket::from_buffer(&mut buf).ok()?; - Some((pkt, status, stale)) + Some((pkt, status, freshness)) } pub fn insert(&mut self, domain: &str, qtype: QueryType, packet: &DnsPacket) { diff --git a/src/ctx.rs b/src/ctx.rs index 8632a28..e97a7ea 100644 --- a/src/ctx.rs +++ b/src/ctx.rs @@ -168,14 +168,14 @@ pub async fn resolve_query( (resp, QueryPath::Blocked, DnssecStatus::Indeterminate) } else { let cached = ctx.cache.read().unwrap().lookup_with_status(&qname, qtype); - if let Some((cached, cached_dnssec, stale)) = cached { - if stale { + if let Some((cached, cached_dnssec, freshness)) = cached { + if freshness.needs_refresh() { let key = (qname.clone(), qtype); let already = !ctx.refreshing.lock().unwrap().insert(key.clone()); if !already { let ctx = Arc::clone(ctx); tokio::spawn(async move { - warm_stale(&ctx, &key.0, key.1).await; + refresh_entry(&ctx, &key.0, key.1).await; ctx.refreshing.lock().unwrap().remove(&key); }); } @@ -388,8 +388,9 @@ fn cache_and_parse( DnsPacket::from_buffer(&mut buf) } -/// Background refresh for a stale cache entry (RFC 8767 revalidation). -async fn warm_stale(ctx: &ServerCtx, qname: &str, qtype: QueryType) { +/// Re-resolve a single (domain, qtype) and update the cache. +/// Used for both stale-entry refresh and proactive cache warming. +pub async fn refresh_entry(ctx: &ServerCtx, qname: &str, qtype: QueryType) { let query = DnsPacket::query(0, qname, qtype); if ctx.upstream_mode == UpstreamMode::Recursive { if let Ok(resp) = crate::recursive::resolve_recursive( @@ -445,7 +446,6 @@ pub async fn handle_query( src_addr: SocketAddr, ctx: &Arc, ) -> crate::Result<()> { - let raw_wire = buffer.buf[..raw_len].to_vec(); let query = match DnsPacket::from_buffer(&mut buffer) { Ok(packet) => packet, Err(e) => { @@ -453,7 +453,7 @@ pub async fn handle_query( return Ok(()); } }; - match resolve_query(query, &raw_wire, src_addr, ctx).await { + match resolve_query(query, &buffer.buf[..raw_len], src_addr, ctx).await { Ok(resp_buffer) => { ctx.socket.send_to(resp_buffer.filled(), src_addr).await?; } diff --git a/src/dot.rs b/src/dot.rs index 0216dbf..d4eeb95 100644 --- a/src/dot.rs +++ b/src/dot.rs @@ -180,7 +180,6 @@ async fn handle_dot_connection( break; }; - let raw_wire = buffer.buf[..msg_len].to_vec(); let query = match DnsPacket::from_buffer(&mut buffer) { Ok(q) => q, Err(e) => { @@ -202,7 +201,7 @@ async fn handle_dot_connection( } }; - match resolve_query(query.clone(), &raw_wire, remote_addr, ctx).await { + match resolve_query(query.clone(), &buffer.buf[..msg_len], remote_addr, ctx).await { Ok(resp_buffer) => { if write_framed(&mut stream, resp_buffer.filled()) .await diff --git a/src/main.rs b/src/main.rs index 9aa3f17..1ec7791 100644 --- a/src/main.rs +++ b/src/main.rs @@ -758,55 +758,11 @@ async fn load_blocklists(ctx: &ServerCtx, lists: &[String]) { } async fn warm_domain(ctx: &ServerCtx, domain: &str) { - use numa::question::QueryType; - - for qtype in [QueryType::A, QueryType::AAAA] { - if ctx.upstream_mode == numa::config::UpstreamMode::Recursive { - let query = numa::packet::DnsPacket::query(0, domain, qtype); - match numa::recursive::resolve_recursive( - domain, - qtype, - &ctx.cache, - &query, - &ctx.root_hints, - &ctx.srtt, - ) - .await - { - Ok(resp) => { - ctx.cache.write().unwrap().insert(domain, qtype, &resp); - log::debug!("cache warm: {} {:?}", domain, qtype); - } - Err(e) => log::warn!("cache warm: {} {:?} failed: {}", domain, qtype, e), - } - } else { - let query = numa::packet::DnsPacket::query(0, domain, qtype); - let mut buf = numa::buffer::BytePacketBuffer::new(); - if query.write(&mut buf).is_err() { - continue; - } - let pool = ctx.upstream_pool.lock().unwrap().clone(); - match numa::forward::forward_with_failover_raw( - buf.filled(), - &pool, - &ctx.srtt, - ctx.timeout, - ctx.hedge_delay, - ) - .await - { - Ok(wire) => { - ctx.cache.write().unwrap().insert_wire( - domain, - qtype, - &wire, - numa::cache::DnssecStatus::Indeterminate, - ); - log::debug!("cache warm: {} {:?}", domain, qtype); - } - Err(e) => log::warn!("cache warm: {} {:?} failed: {}", domain, qtype, e), - } - } + for qtype in [ + numa::question::QueryType::A, + numa::question::QueryType::AAAA, + ] { + numa::ctx::refresh_entry(ctx, domain, qtype).await; } } diff --git a/src/wire.rs b/src/wire.rs index aa419f2..3ee2ab3 100644 --- a/src/wire.rs +++ b/src/wire.rs @@ -1374,29 +1374,28 @@ mod tests { #[test] fn lookup_wire_signals_stale_when_expired() { + use crate::cache::Freshness; let mut cache = DnsCache::new(100, 1, 1); // max_ttl=1s so entry expires fast let pkt = response( 0x1234, "example.com", - vec![a_record("example.com", "1.2.3.4", 1)], // 1s TTL, clamped to min=1 + vec![a_record("example.com", "1.2.3.4", 1)], ); cache.insert("example.com", QueryType::A, &pkt); - // Fresh: not stale - let (_, _, stale) = cache.lookup_wire("example.com", QueryType::A, 0).unwrap(); - assert!(!stale); + let (_, _, f) = cache.lookup_wire("example.com", QueryType::A, 0).unwrap(); + assert_eq!(f, Freshness::Fresh); - // Wait for expiry std::thread::sleep(std::time::Duration::from_millis(1100)); - // Expired but within stale window: stale=true - let (_, _, stale) = cache.lookup_wire("example.com", QueryType::A, 0).unwrap(); - assert!(stale); + let (_, _, f) = cache.lookup_wire("example.com", QueryType::A, 0).unwrap(); + assert_eq!(f, Freshness::Stale); } #[test] fn lookup_wire_signals_prefetch_near_expiry() { - let mut cache = DnsCache::new(100, 10, 10); // min_ttl=10, max_ttl=10 → entry gets 10s TTL + use crate::cache::Freshness; + let mut cache = DnsCache::new(100, 10, 10); let pkt = response( 0x1234, "example.com", @@ -1404,18 +1403,14 @@ mod tests { ); cache.insert("example.com", QueryType::A, &pkt); - // Fresh (>10% remaining): not stale - let (_, _, stale) = cache.lookup_wire("example.com", QueryType::A, 0).unwrap(); - assert!(!stale); + let (_, _, f) = cache.lookup_wire("example.com", QueryType::A, 0).unwrap(); + assert_eq!(f, Freshness::Fresh); - // Wait until <10% remaining (>9s elapsed of 10s TTL) std::thread::sleep(std::time::Duration::from_millis(9100)); - // Still valid but near expiry: stale=true (triggers prefetch) let result = cache.lookup_wire("example.com", QueryType::A, 0); - if let Some((_, _, stale)) = result { - assert!(stale, "entry at <10% TTL should signal stale for prefetch"); + if let Some((_, _, f)) = result { + assert_eq!(f, Freshness::NearExpiry); } - // (entry may have fully expired on slow CI, so we don't assert Some) } }