feat: cache warming for configured domains #78

Merged
razvandimescu merged 1 commits from feat/cache-warming into main 2026-04-11 06:14:05 +08:00
3 changed files with 150 additions and 0 deletions
Showing only changes of commit b342ee8692 - Show all commits

View File

@@ -82,6 +82,29 @@ impl DnsCache {
Some((packet, entry.dnssec_status))
}
pub fn ttl_remaining(&self, domain: &str, qtype: QueryType) -> Option<(u32, u32)> {
let type_map = self.entries.get(domain)?;
let entry = type_map.get(&qtype)?;
let elapsed = entry.inserted_at.elapsed();
if elapsed >= entry.ttl {
return None;
}
let total = entry.ttl.as_secs() as u32;
let remaining = (entry.ttl - elapsed).as_secs() as u32;
Some((remaining, total))
}
pub fn needs_warm(&self, domain: &str) -> bool {
for qtype in [QueryType::A, QueryType::AAAA] {
match self.ttl_remaining(domain, qtype) {
None => return true,
Some((remaining, total)) if remaining < total / 4 => return true,
_ => {}
}
}
false
}
pub fn insert(&mut self, domain: &str, qtype: QueryType, packet: &DnsPacket) {
self.insert_with_status(domain, qtype, packet, DnssecStatus::Indeterminate);
}
@@ -233,4 +256,66 @@ mod tests {
cache.insert("example.com", QueryType::A, &pkt);
assert!(cache.heap_bytes() > empty);
}
#[test]
fn ttl_remaining_returns_values_for_fresh_entry() {
let mut cache = DnsCache::new(100, 60, 3600);
let mut pkt = DnsPacket::new();
pkt.answers.push(DnsRecord::A {
domain: "example.com".into(),
addr: "1.2.3.4".parse().unwrap(),
ttl: 300,
});
cache.insert("example.com", QueryType::A, &pkt);
let (remaining, total) = cache.ttl_remaining("example.com", QueryType::A).unwrap();
assert_eq!(total, 300);
assert!(remaining <= 300);
assert!(remaining > 0);
}
#[test]
fn ttl_remaining_none_for_missing() {
let cache = DnsCache::new(100, 1, 3600);
assert!(cache.ttl_remaining("missing.com", QueryType::A).is_none());
}
#[test]
fn needs_warm_true_when_missing() {
let cache = DnsCache::new(100, 1, 3600);
assert!(cache.needs_warm("missing.com"));
}
#[test]
fn needs_warm_false_when_fresh() {
let mut cache = DnsCache::new(100, 1, 3600);
let mut pkt_a = DnsPacket::new();
pkt_a.answers.push(DnsRecord::A {
domain: "example.com".into(),
addr: "1.2.3.4".parse().unwrap(),
ttl: 300,
});
let mut pkt_aaaa = DnsPacket::new();
pkt_aaaa.answers.push(DnsRecord::AAAA {
domain: "example.com".into(),
addr: "::1".parse().unwrap(),
ttl: 300,
});
cache.insert("example.com", QueryType::A, &pkt_a);
cache.insert("example.com", QueryType::AAAA, &pkt_aaaa);
assert!(!cache.needs_warm("example.com"));
}
#[test]
fn needs_warm_true_when_only_a_cached() {
let mut cache = DnsCache::new(100, 1, 3600);
let mut pkt = DnsPacket::new();
pkt.answers.push(DnsRecord::A {
domain: "example.com".into(),
addr: "1.2.3.4".parse().unwrap(),
ttl: 300,
});
cache.insert("example.com", QueryType::A, &pkt);
// AAAA missing → needs warm
assert!(cache.needs_warm("example.com"));
}
}

View File

@@ -247,6 +247,8 @@ pub struct CacheConfig {
pub min_ttl: u32,
#[serde(default = "default_max_ttl")]
pub max_ttl: u32,
#[serde(default)]
pub warm: Vec<String>,
}
impl Default for CacheConfig {
@@ -255,6 +257,7 @@ impl Default for CacheConfig {
max_entries: default_max_entries(),
min_ttl: default_min_ttl(),
max_ttl: default_max_ttl(),
warm: Vec::new(),
}
}
}

View File

@@ -402,6 +402,9 @@ async fn main() -> numa::Result<()> {
g,
&format!("max {} entries", config.cache.max_entries),
);
if !config.cache.warm.is_empty() {
row("Warm", g, &format!("{} domains", config.cache.warm.len()));
}
row(
"Blocking",
g,
@@ -484,6 +487,15 @@ async fn main() -> numa::Result<()> {
});
}
// Spawn cache warming for user-configured domains
if !config.cache.warm.is_empty() {
let warm_ctx = Arc::clone(&ctx);
let warm_domains = config.cache.warm.clone();
tokio::spawn(async move {
cache_warm_loop(warm_ctx, warm_domains).await;
});
}
// Spawn HTTP API server
let api_ctx = Arc::clone(&ctx);
let api_addr: SocketAddr = format!("{}:{}", config.server.api_bind_addr, api_port).parse()?;
@@ -720,3 +732,53 @@ async fn load_blocklists(ctx: &ServerCtx, lists: &[String]) {
downloaded.len()
);
}
async fn warm_domain(ctx: &ServerCtx, domain: &str) {
use numa::question::QueryType;
for qtype in [QueryType::A, QueryType::AAAA] {
let query = numa::packet::DnsPacket::query(0, domain, qtype);
let result = if ctx.upstream_mode == numa::config::UpstreamMode::Recursive {
numa::recursive::resolve_recursive(
domain,
qtype,
&ctx.cache,
&query,
&ctx.root_hints,
&ctx.srtt,
)
.await
} else {
let pool = ctx.upstream_pool.lock().unwrap().clone();
numa::forward::forward_with_failover(&query, &pool, &ctx.srtt, ctx.timeout).await
};
match result {
Ok(resp) => {
ctx.cache.write().unwrap().insert(domain, qtype, &resp);
log::debug!("cache warm: {} {:?}", domain, qtype);
}
Err(e) => log::warn!("cache warm: {} {:?} failed: {}", domain, qtype, e),
}
}
}
async fn cache_warm_loop(ctx: Arc<ServerCtx>, domains: Vec<String>) {
tokio::time::sleep(Duration::from_secs(2)).await;
for domain in &domains {
warm_domain(&ctx, domain).await;
}
info!("cache warm: {} domains resolved at startup", domains.len());
let mut interval = tokio::time::interval(Duration::from_secs(30));
interval.tick().await;
loop {
interval.tick().await;
for domain in &domains {
let refresh = ctx.cache.read().unwrap().needs_warm(domain);
if refresh {
warm_domain(&ctx, domain).await;
}
}
}
}