feat: cache warming for configured domains #78
85
src/cache.rs
85
src/cache.rs
@@ -82,6 +82,29 @@ impl DnsCache {
|
||||
Some((packet, entry.dnssec_status))
|
||||
}
|
||||
|
||||
pub fn ttl_remaining(&self, domain: &str, qtype: QueryType) -> Option<(u32, u32)> {
|
||||
let type_map = self.entries.get(domain)?;
|
||||
let entry = type_map.get(&qtype)?;
|
||||
let elapsed = entry.inserted_at.elapsed();
|
||||
if elapsed >= entry.ttl {
|
||||
return None;
|
||||
}
|
||||
let total = entry.ttl.as_secs() as u32;
|
||||
let remaining = (entry.ttl - elapsed).as_secs() as u32;
|
||||
Some((remaining, total))
|
||||
}
|
||||
|
||||
pub fn needs_warm(&self, domain: &str) -> bool {
|
||||
for qtype in [QueryType::A, QueryType::AAAA] {
|
||||
match self.ttl_remaining(domain, qtype) {
|
||||
None => return true,
|
||||
Some((remaining, total)) if remaining < total / 4 => return true,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
pub fn insert(&mut self, domain: &str, qtype: QueryType, packet: &DnsPacket) {
|
||||
self.insert_with_status(domain, qtype, packet, DnssecStatus::Indeterminate);
|
||||
}
|
||||
@@ -233,4 +256,66 @@ mod tests {
|
||||
cache.insert("example.com", QueryType::A, &pkt);
|
||||
assert!(cache.heap_bytes() > empty);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ttl_remaining_returns_values_for_fresh_entry() {
|
||||
let mut cache = DnsCache::new(100, 60, 3600);
|
||||
let mut pkt = DnsPacket::new();
|
||||
pkt.answers.push(DnsRecord::A {
|
||||
domain: "example.com".into(),
|
||||
addr: "1.2.3.4".parse().unwrap(),
|
||||
ttl: 300,
|
||||
});
|
||||
cache.insert("example.com", QueryType::A, &pkt);
|
||||
let (remaining, total) = cache.ttl_remaining("example.com", QueryType::A).unwrap();
|
||||
assert_eq!(total, 300);
|
||||
assert!(remaining <= 300);
|
||||
assert!(remaining > 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ttl_remaining_none_for_missing() {
|
||||
let cache = DnsCache::new(100, 1, 3600);
|
||||
assert!(cache.ttl_remaining("missing.com", QueryType::A).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn needs_warm_true_when_missing() {
|
||||
let cache = DnsCache::new(100, 1, 3600);
|
||||
assert!(cache.needs_warm("missing.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn needs_warm_false_when_fresh() {
|
||||
let mut cache = DnsCache::new(100, 1, 3600);
|
||||
let mut pkt_a = DnsPacket::new();
|
||||
pkt_a.answers.push(DnsRecord::A {
|
||||
domain: "example.com".into(),
|
||||
addr: "1.2.3.4".parse().unwrap(),
|
||||
ttl: 300,
|
||||
});
|
||||
let mut pkt_aaaa = DnsPacket::new();
|
||||
pkt_aaaa.answers.push(DnsRecord::AAAA {
|
||||
domain: "example.com".into(),
|
||||
addr: "::1".parse().unwrap(),
|
||||
ttl: 300,
|
||||
});
|
||||
cache.insert("example.com", QueryType::A, &pkt_a);
|
||||
cache.insert("example.com", QueryType::AAAA, &pkt_aaaa);
|
||||
assert!(!cache.needs_warm("example.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn needs_warm_true_when_only_a_cached() {
|
||||
let mut cache = DnsCache::new(100, 1, 3600);
|
||||
let mut pkt = DnsPacket::new();
|
||||
pkt.answers.push(DnsRecord::A {
|
||||
domain: "example.com".into(),
|
||||
addr: "1.2.3.4".parse().unwrap(),
|
||||
ttl: 300,
|
||||
});
|
||||
cache.insert("example.com", QueryType::A, &pkt);
|
||||
// AAAA missing → needs warm
|
||||
assert!(cache.needs_warm("example.com"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -247,6 +247,8 @@ pub struct CacheConfig {
|
||||
pub min_ttl: u32,
|
||||
#[serde(default = "default_max_ttl")]
|
||||
pub max_ttl: u32,
|
||||
#[serde(default)]
|
||||
pub warm: Vec<String>,
|
||||
}
|
||||
|
||||
impl Default for CacheConfig {
|
||||
@@ -255,6 +257,7 @@ impl Default for CacheConfig {
|
||||
max_entries: default_max_entries(),
|
||||
min_ttl: default_min_ttl(),
|
||||
max_ttl: default_max_ttl(),
|
||||
warm: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
62
src/main.rs
62
src/main.rs
@@ -402,6 +402,9 @@ async fn main() -> numa::Result<()> {
|
||||
g,
|
||||
&format!("max {} entries", config.cache.max_entries),
|
||||
);
|
||||
if !config.cache.warm.is_empty() {
|
||||
row("Warm", g, &format!("{} domains", config.cache.warm.len()));
|
||||
}
|
||||
row(
|
||||
"Blocking",
|
||||
g,
|
||||
@@ -484,6 +487,15 @@ async fn main() -> numa::Result<()> {
|
||||
});
|
||||
}
|
||||
|
||||
// Spawn cache warming for user-configured domains
|
||||
if !config.cache.warm.is_empty() {
|
||||
let warm_ctx = Arc::clone(&ctx);
|
||||
let warm_domains = config.cache.warm.clone();
|
||||
tokio::spawn(async move {
|
||||
cache_warm_loop(warm_ctx, warm_domains).await;
|
||||
});
|
||||
}
|
||||
|
||||
// Spawn HTTP API server
|
||||
let api_ctx = Arc::clone(&ctx);
|
||||
let api_addr: SocketAddr = format!("{}:{}", config.server.api_bind_addr, api_port).parse()?;
|
||||
@@ -720,3 +732,53 @@ async fn load_blocklists(ctx: &ServerCtx, lists: &[String]) {
|
||||
downloaded.len()
|
||||
);
|
||||
}
|
||||
|
||||
async fn warm_domain(ctx: &ServerCtx, domain: &str) {
|
||||
use numa::question::QueryType;
|
||||
|
||||
for qtype in [QueryType::A, QueryType::AAAA] {
|
||||
let query = numa::packet::DnsPacket::query(0, domain, qtype);
|
||||
let result = if ctx.upstream_mode == numa::config::UpstreamMode::Recursive {
|
||||
numa::recursive::resolve_recursive(
|
||||
domain,
|
||||
qtype,
|
||||
&ctx.cache,
|
||||
&query,
|
||||
&ctx.root_hints,
|
||||
&ctx.srtt,
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
let pool = ctx.upstream_pool.lock().unwrap().clone();
|
||||
numa::forward::forward_with_failover(&query, &pool, &ctx.srtt, ctx.timeout).await
|
||||
};
|
||||
match result {
|
||||
Ok(resp) => {
|
||||
ctx.cache.write().unwrap().insert(domain, qtype, &resp);
|
||||
log::debug!("cache warm: {} {:?}", domain, qtype);
|
||||
}
|
||||
Err(e) => log::warn!("cache warm: {} {:?} failed: {}", domain, qtype, e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn cache_warm_loop(ctx: Arc<ServerCtx>, domains: Vec<String>) {
|
||||
tokio::time::sleep(Duration::from_secs(2)).await;
|
||||
|
||||
for domain in &domains {
|
||||
warm_domain(&ctx, domain).await;
|
||||
}
|
||||
info!("cache warm: {} domains resolved at startup", domains.len());
|
||||
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(30));
|
||||
interval.tick().await;
|
||||
loop {
|
||||
interval.tick().await;
|
||||
for domain in &domains {
|
||||
let refresh = ctx.cache.read().unwrap().needs_warm(domain);
|
||||
if refresh {
|
||||
warm_domain(&ctx, domain).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user