feat: wire-level forwarding, cache, request hedging, and DoH keepalive

Wire-level forwarding path skips DnsPacket parse/serialize on the hot
path. Cache stores raw wire bytes with pre-scanned TTL offsets — patches
ID + TTLs in-place on lookup instead of cloning parsed packets.

Request hedging (Dean & Barroso "Tail at Scale") fires a second
parallel request after a configurable delay (default 10ms) when
the primary upstream stalls. DoH keepalive loop prevents idle
HTTP/2 + TLS connection teardown.

Recursive resolver now hedges across multiple NS addresses and
caches NS delegation records to skip TLD re-queries.

Integration test harness polls /blocking/stats instead of fixed
sleep, eliminating the blocklist-download race condition.
This commit is contained in:
Razvan Dimescu
2026-04-12 04:20:18 +03:00
parent 4f46550283
commit 7efac85836
18 changed files with 4091 additions and 110 deletions

View File

@@ -1029,6 +1029,7 @@ mod tests {
upstream_port: 53,
lan_ip: Mutex::new(std::net::Ipv4Addr::LOCALHOST),
timeout: std::time::Duration::from_secs(3),
hedge_delay: std::time::Duration::ZERO,
proxy_tld: "numa".to_string(),
proxy_tld_suffix: ".numa".to_string(),
lan_enabled: false,

View File

@@ -1,9 +1,10 @@
use std::collections::HashMap;
use std::time::{Duration, Instant};
use crate::buffer::BytePacketBuffer;
use crate::packet::DnsPacket;
use crate::question::QueryType;
use crate::record::DnsRecord;
use crate::wire::WireMeta;
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum DnssecStatus {
@@ -26,14 +27,16 @@ impl DnssecStatus {
}
struct CacheEntry {
packet: DnsPacket,
wire: Vec<u8>,
meta: WireMeta,
inserted_at: Instant,
ttl: Duration,
dnssec_status: DnssecStatus,
}
/// DNS cache using a two-level map (domain -> query_type -> entry) so that
/// lookups can borrow `&str` instead of allocating a `String` key.
const STALE_WINDOW: Duration = Duration::from_secs(3600);
/// DNS cache with serve-stale (RFC 8767). Stores raw wire bytes.
pub struct DnsCache {
entries: HashMap<String, HashMap<QueryType, CacheEntry>>,
entry_count: usize,
@@ -53,6 +56,80 @@ impl DnsCache {
}
}
/// Look up cached wire bytes, patching ID and TTLs in the returned copy.
/// Implements serve-stale (RFC 8767): expired entries within STALE_WINDOW
/// are returned with TTL=1 and `stale=true` so callers can revalidate.
/// Fetch the cached wire bytes for `(domain, qtype)`, rewriting the message
/// ID and the remaining TTLs in a fresh copy of the stored bytes.
///
/// Serve-stale (RFC 8767): an entry past its TTL but still inside
/// `STALE_WINDOW` is returned with TTL = 1 and `stale == true` so the
/// caller can revalidate; anything older yields `None`.
pub fn lookup_wire(
    &self,
    domain: &str,
    qtype: QueryType,
    new_id: u16,
) -> Option<(Vec<u8>, DnssecStatus, bool)> {
    let entry = self.entries.get(domain)?.get(&qtype)?;
    let age = entry.inserted_at.elapsed();
    let remaining: u32;
    let stale: bool;
    if age < entry.ttl {
        // Fresh: report time left, never advertising a zero TTL.
        remaining = ((entry.ttl - age).as_secs() as u32).max(1);
        stale = false;
    } else if age < entry.ttl + STALE_WINDOW {
        // Expired but within the serve-stale window.
        remaining = 1;
        stale = true;
    } else {
        return None;
    }
    let mut bytes = entry.wire.clone();
    crate::wire::patch_id(&mut bytes, new_id);
    crate::wire::patch_ttls(&mut bytes, &entry.meta.ttl_offsets, remaining);
    Some((bytes, entry.dnssec_status, stale))
}
/// Insert raw response bytes into the cache, pre-scanning TTL offsets so
/// `lookup_wire` can patch TTLs in place without reparsing.
///
/// Malformed wire (offset scan failure) is silently skipped. The capacity
/// check only applies to *new* (domain, qtype) keys: replacing an entry
/// that is already present never grows the cache, so updates are always
/// allowed even when the cache is full — the previous code rejected those
/// too, which froze full caches on stale data.
pub fn insert_wire(
    &mut self,
    domain: &str,
    qtype: QueryType,
    wire: &[u8],
    dnssec_status: DnssecStatus,
) {
    let meta = match crate::wire::scan_ttl_offsets(wire) {
        Ok(m) => m,
        Err(_) => return, // malformed wire, skip
    };
    // Would this insert create a brand-new entry?
    let is_new = self
        .entries
        .get(domain)
        .map_or(true, |type_map| !type_map.contains_key(&qtype));
    if is_new && self.entry_count >= self.max_entries {
        self.evict_expired();
        if self.entry_count >= self.max_entries {
            return; // still full: drop the new entry
        }
    }
    let min_ttl = crate::wire::min_ttl_from_wire(wire, &meta)
        .unwrap_or(self.min_ttl)
        .clamp(self.min_ttl, self.max_ttl);
    // Entry API: one lookup instead of the get_mut / entry two-step.
    let type_map = self.entries.entry(domain.to_string()).or_default();
    if !type_map.contains_key(&qtype) {
        self.entry_count += 1;
    }
    type_map.insert(
        qtype,
        CacheEntry {
            wire: wire.to_vec(),
            meta,
            inserted_at: Instant::now(),
            ttl: Duration::from_secs(min_ttl as u64),
            dnssec_status,
        },
    );
}
/// Read-only lookup — expired entries are left in place (cleaned up on insert).
pub fn lookup(&self, domain: &str, qtype: QueryType) -> Option<DnsPacket> {
self.lookup_with_status(domain, qtype).map(|(pkt, _)| pkt)
@@ -63,23 +140,28 @@ impl DnsCache {
domain: &str,
qtype: QueryType,
) -> Option<(DnsPacket, DnssecStatus)> {
let type_map = self.entries.get(domain)?;
let entry = type_map.get(&qtype)?;
let (wire, status, _stale) = self.lookup_wire(domain, qtype, 0)?;
let mut buf = BytePacketBuffer::from_bytes(&wire);
let pkt = DnsPacket::from_buffer(&mut buf).ok()?;
Some((pkt, status))
}
let elapsed = entry.inserted_at.elapsed();
if elapsed >= entry.ttl {
return None;
pub fn insert(&mut self, domain: &str, qtype: QueryType, packet: &DnsPacket) {
self.insert_with_status(domain, qtype, packet, DnssecStatus::Indeterminate);
}
pub fn insert_with_status(
&mut self,
domain: &str,
qtype: QueryType,
packet: &DnsPacket,
dnssec_status: DnssecStatus,
) {
let mut buf = BytePacketBuffer::new();
if packet.write(&mut buf).is_err() {
return;
}
let remaining_secs = (entry.ttl - elapsed).as_secs() as u32;
let remaining = remaining_secs.max(1);
let mut packet = entry.packet.clone();
adjust_ttls(&mut packet.answers, remaining);
adjust_ttls(&mut packet.authorities, remaining);
adjust_ttls(&mut packet.resources, remaining);
Some((packet, entry.dnssec_status))
self.insert_wire(domain, qtype, buf.filled(), dnssec_status);
}
pub fn ttl_remaining(&self, domain: &str, qtype: QueryType) -> Option<(u32, u32)> {
@@ -105,49 +187,6 @@ impl DnsCache {
false
}
pub fn insert(&mut self, domain: &str, qtype: QueryType, packet: &DnsPacket) {
self.insert_with_status(domain, qtype, packet, DnssecStatus::Indeterminate);
}
pub fn insert_with_status(
&mut self,
domain: &str,
qtype: QueryType,
packet: &DnsPacket,
dnssec_status: DnssecStatus,
) {
if self.entry_count >= self.max_entries {
self.evict_expired();
if self.entry_count >= self.max_entries {
return;
}
}
let min_ttl = extract_min_ttl(&packet.answers)
.unwrap_or(self.min_ttl)
.clamp(self.min_ttl, self.max_ttl);
let type_map = if let Some(existing) = self.entries.get_mut(domain) {
existing
} else {
self.entries.entry(domain.to_string()).or_default()
};
if !type_map.contains_key(&qtype) {
self.entry_count += 1;
}
type_map.insert(
qtype,
CacheEntry {
packet: packet.clone(),
inserted_at: Instant::now(),
ttl: Duration::from_secs(min_ttl as u64),
dnssec_status,
},
);
}
pub fn len(&self) -> usize {
self.entry_count
}
@@ -179,7 +218,8 @@ impl DnsCache {
+ 1;
total += type_map.capacity() * inner_slot;
for entry in type_map.values() {
total += entry.packet.heap_bytes();
total += entry.wire.capacity()
+ entry.meta.ttl_offsets.capacity() * std::mem::size_of::<usize>();
}
}
total
@@ -228,20 +268,11 @@ pub struct CacheInfo {
pub ttl_remaining: u32,
}
fn extract_min_ttl(records: &[DnsRecord]) -> Option<u32> {
records.iter().map(|r| r.ttl()).min()
}
fn adjust_ttls(records: &mut [DnsRecord], new_ttl: u32) {
for record in records.iter_mut() {
record.set_ttl(new_ttl);
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::packet::DnsPacket;
use crate::record::DnsRecord;
#[test]
fn heap_bytes_grows_with_entries() {

View File

@@ -138,6 +138,8 @@ pub struct UpstreamConfig {
pub fallback: Vec<String>,
#[serde(default = "default_timeout_ms")]
pub timeout_ms: u64,
#[serde(default = "default_hedge_ms")]
pub hedge_ms: u64,
#[serde(default = "default_root_hints")]
pub root_hints: Vec<String>,
#[serde(default = "default_prime_tlds")]
@@ -154,6 +156,7 @@ impl Default for UpstreamConfig {
port: default_upstream_port(),
fallback: Vec::new(),
timeout_ms: default_timeout_ms(),
hedge_ms: default_hedge_ms(),
root_hints: default_root_hints(),
prime_tlds: default_prime_tlds(),
srtt: default_srtt(),
@@ -271,6 +274,9 @@ fn default_upstream_port() -> u16 {
/// serde default for `UpstreamConfig::timeout_ms`: 5 seconds.
fn default_timeout_ms() -> u64 {
    5_000
}
/// serde default for `UpstreamConfig::hedge_ms`: fire the hedged request
/// 10 ms after the primary (Dean & Barroso, "The Tail at Scale").
fn default_hedge_ms() -> u64 {
    10
}
#[derive(Deserialize)]
pub struct CacheConfig {

View File

@@ -16,7 +16,9 @@ use crate::blocklist::BlocklistStore;
use crate::buffer::BytePacketBuffer;
use crate::cache::{DnsCache, DnssecStatus};
use crate::config::{UpstreamMode, ZoneMap};
use crate::forward::{forward_query, forward_with_failover, Upstream, UpstreamPool};
use crate::forward::{
forward_query_raw, forward_with_failover_raw, Upstream, UpstreamPool,
};
use crate::header::ResultCode;
use crate::health::HealthMeta;
use crate::lan::PeerStore;
@@ -47,6 +49,7 @@ pub struct ServerCtx {
pub upstream_port: u16,
pub lan_ip: Mutex<std::net::Ipv4Addr>,
pub timeout: Duration,
pub hedge_delay: Duration,
pub proxy_tld: String,
pub proxy_tld_suffix: String, // pre-computed ".{tld}" to avoid per-query allocation
pub lan_enabled: bool,
@@ -81,6 +84,7 @@ pub struct ServerCtx {
/// (and logging parse errors) before calling this function.
pub async fn resolve_query(
query: DnsPacket,
raw_wire: &[u8],
src_addr: SocketAddr,
ctx: &ServerCtx,
) -> crate::Result<BytePacketBuffer> {
@@ -177,9 +181,8 @@ pub async fn resolve_query(
// Conditional forwarding takes priority over recursive mode
// (e.g. Tailscale .ts.net, VPC private zones)
let upstream = Upstream::Udp(fwd_addr);
match forward_query(&query, &upstream, ctx.timeout).await {
match forward_and_cache(raw_wire, &upstream, ctx, &qname, qtype).await {
Ok(resp) => {
ctx.cache.write().unwrap().insert(&qname, qtype, &resp);
(resp, QueryPath::Forwarded, DnssecStatus::Indeterminate)
}
Err(e) => {
@@ -221,10 +224,19 @@ pub async fn resolve_query(
(resp, path, DnssecStatus::Indeterminate)
} else {
let pool = ctx.upstream_pool.lock().unwrap().clone();
match forward_with_failover(&query, &pool, &ctx.srtt, ctx.timeout).await {
Ok(resp) => {
ctx.cache.write().unwrap().insert(&qname, qtype, &resp);
(resp, QueryPath::Forwarded, DnssecStatus::Indeterminate)
match forward_with_failover_raw(raw_wire, &pool, &ctx.srtt, ctx.timeout, ctx.hedge_delay).await {
Ok(resp_wire) => {
ctx.cache.write().unwrap().insert_wire(
&qname, qtype, &resp_wire, DnssecStatus::Indeterminate,
);
let mut buf = BytePacketBuffer::from_bytes(&resp_wire);
match DnsPacket::from_buffer(&mut buf) {
Ok(resp) => (resp, QueryPath::Forwarded, DnssecStatus::Indeterminate),
Err(e) => {
error!("{} | {:?} {} | PARSE ERROR | {}", src_addr, qtype, qname, e);
(DnsPacket::response_from(&query, ResultCode::SERVFAIL), QueryPath::UpstreamError, DnssecStatus::Indeterminate)
}
}
}
Err(e) => {
error!(
@@ -347,12 +359,29 @@ pub async fn resolve_query(
Ok(resp_buffer)
}
/// Forward raw query bytes to a single upstream, store the raw response in
/// the wire-level cache, then parse it into a `DnsPacket` for the caller.
/// (The stale "Handle a DNS query received over UDP" comment previously
/// attached here described `handle_query`, not this function.)
///
/// The response is cached before parsing, so bytes we fail to parse
/// locally are still validated and stored by `insert_wire` when possible.
async fn forward_and_cache(
    wire: &[u8],
    upstream: &Upstream,
    ctx: &ServerCtx,
    qname: &str,
    qtype: QueryType,
) -> crate::Result<DnsPacket> {
    let resp_wire = forward_query_raw(wire, upstream, ctx.timeout).await?;
    ctx.cache
        .write()
        .unwrap()
        .insert_wire(qname, qtype, &resp_wire, DnssecStatus::Indeterminate);
    let mut buf = BytePacketBuffer::from_bytes(&resp_wire);
    DnsPacket::from_buffer(&mut buf)
}
pub async fn handle_query(
mut buffer: BytePacketBuffer,
raw_len: usize,
src_addr: SocketAddr,
ctx: &ServerCtx,
) -> crate::Result<()> {
let raw_wire = buffer.buf[..raw_len].to_vec();
let query = match DnsPacket::from_buffer(&mut buffer) {
Ok(packet) => packet,
Err(e) => {
@@ -360,7 +389,7 @@ pub async fn handle_query(
return Ok(());
}
};
match resolve_query(query, src_addr, ctx).await {
match resolve_query(query, &raw_wire, src_addr, ctx).await {
Ok(resp_buffer) => {
ctx.socket.send_to(resp_buffer.filled(), src_addr).await?;
}

View File

@@ -82,7 +82,7 @@ async fn resolve_doh(dns_bytes: &[u8], src: SocketAddr, ctx: &ServerCtx) -> Resp
let query_rd = query.header.recursion_desired;
let questions = query.questions.clone();
match resolve_query(query, src, ctx).await {
match resolve_query(query, dns_bytes, src, ctx).await {
Ok(resp_buffer) => {
let min_ttl = extract_min_ttl(resp_buffer.filled());
dns_response(resp_buffer.filled(), min_ttl)
@@ -102,11 +102,10 @@ async fn resolve_doh(dns_bytes: &[u8], src: SocketAddr, ctx: &ServerCtx) -> Resp
}
/// Minimum TTL across all records of a raw DNS message, or 0 when the
/// message cannot be scanned or carries no TTL-bearing records.
fn extract_min_ttl(wire: &[u8]) -> u32 {
    let meta = match crate::wire::scan_ttl_offsets(wire) {
        Ok(m) => m,
        Err(_) => return 0,
    };
    crate::wire::min_ttl_from_wire(wire, &meta).unwrap_or(0)
}
fn dns_response(wire: &[u8], min_ttl: u32) -> Response {

View File

@@ -177,8 +177,7 @@ where
break;
};
// Parse query up-front so we can echo its question section in SERVFAIL
// responses when resolve_query fails.
let raw_wire = buffer.buf[..msg_len].to_vec();
let query = match DnsPacket::from_buffer(&mut buffer) {
Ok(q) => q,
Err(e) => {
@@ -200,7 +199,7 @@ where
}
};
match resolve_query(query.clone(), remote_addr, ctx).await {
match resolve_query(query.clone(), &raw_wire, remote_addr, ctx).await {
Ok(resp_buffer) => {
if write_framed(&mut stream, resp_buffer.filled())
.await
@@ -370,6 +369,7 @@ mod tests {
upstream_port: 53,
lan_ip: Mutex::new(std::net::Ipv4Addr::LOCALHOST),
timeout: Duration::from_millis(200),
hedge_delay: Duration::ZERO,
proxy_tld: "numa".to_string(),
proxy_tld_suffix: ".numa".to_string(),
lan_enabled: false,

View File

@@ -65,6 +65,13 @@ pub fn parse_upstream(s: &str, default_port: u16) -> Result<Upstream> {
if s.starts_with("https://") {
let client = reqwest::Client::builder()
.use_rustls_tls()
.http2_initial_stream_window_size(65_535)
.http2_initial_connection_window_size(65_535)
.http2_keep_alive_interval(Duration::from_secs(15))
.http2_keep_alive_while_idle(true)
.http2_keep_alive_timeout(Duration::from_secs(10))
.pool_idle_timeout(Duration::from_secs(300))
.pool_max_idle_per_host(1)
.build()
.unwrap_or_default();
return Ok(Upstream::Doh {
@@ -325,13 +332,170 @@ async fn forward_doh(
let mut send_buffer = BytePacketBuffer::new();
query.write(&mut send_buffer)?;
let resp_bytes = forward_doh_raw(send_buffer.filled(), url, client, timeout_duration).await?;
let mut recv_buffer = BytePacketBuffer::from_bytes(&resp_bytes);
DnsPacket::from_buffer(&mut recv_buffer)
}
/// Wire-level forward: send raw query bytes to `upstream` and return the
/// raw response bytes, skipping DnsPacket parse/serialize on the hot path.
/// Dispatches on the upstream transport (plain UDP vs DNS-over-HTTPS).
pub async fn forward_query_raw(
    wire: &[u8],
    upstream: &Upstream,
    timeout_duration: Duration,
) -> Result<Vec<u8>> {
    match upstream {
        Upstream::Udp(addr) => forward_udp_raw(wire, *addr, timeout_duration).await,
        Upstream::Doh { url, client } => forward_doh_raw(wire, url, client, timeout_duration).await,
    }
}
pub async fn forward_with_hedging_raw(
wire: &[u8],
primary: &Upstream,
secondary: &Upstream,
hedge_delay: Duration,
timeout_duration: Duration,
) -> Result<Vec<u8>> {
use tokio::time::sleep;
let primary_fut = forward_query_raw(wire, primary, timeout_duration);
tokio::pin!(primary_fut);
let delay = sleep(hedge_delay);
tokio::pin!(delay);
// Phase 1: wait for either primary to return, or the hedge delay.
tokio::select! {
result = &mut primary_fut => return result,
_ = &mut delay => {}
}
// Phase 2: hedge delay expired — fire secondary while still polling primary.
let secondary_fut = forward_query_raw(wire, secondary, timeout_duration);
tokio::pin!(secondary_fut);
// First successful response wins. If one errors, wait for the other.
let mut primary_err: Option<crate::Error> = None;
let mut secondary_err: Option<crate::Error> = None;
loop {
tokio::select! {
r = &mut primary_fut, if primary_err.is_none() => {
match r {
Ok(resp) => return Ok(resp),
Err(e) => {
if let Some(se) = secondary_err.take() {
return Err(se);
}
primary_err = Some(e);
}
}
}
r = &mut secondary_fut, if secondary_err.is_none() => {
match r {
Ok(resp) => return Ok(resp),
Err(e) => {
if let Some(pe) = primary_err.take() {
return Err(pe);
}
secondary_err = Some(e);
}
}
}
}
match (primary_err, secondary_err) {
(Some(pe), Some(_)) => return Err(pe),
(pe, se) => { primary_err = pe; secondary_err = se; }
}
}
}
pub async fn forward_with_failover_raw(
wire: &[u8],
pool: &UpstreamPool,
srtt: &RwLock<SrttCache>,
timeout_duration: Duration,
hedge_delay: Duration,
) -> Result<Vec<u8>> {
let mut candidates: Vec<(usize, u64)> = pool
.primary
.iter()
.enumerate()
.map(|(i, u)| {
let rtt = match u {
Upstream::Udp(addr) => srtt.read().unwrap().get(addr.ip()),
_ => 0,
};
(i, rtt)
})
.collect();
candidates.sort_by_key(|&(_, rtt)| rtt);
let all_upstreams: Vec<&Upstream> = candidates
.iter()
.map(|&(i, _)| &pool.primary[i])
.chain(pool.fallback.iter())
.collect();
let mut last_err: Option<Box<dyn std::error::Error + Send + Sync>> = None;
for upstream in &all_upstreams {
let start = Instant::now();
let result = if !hedge_delay.is_zero() && matches!(upstream, Upstream::Doh { .. }) {
// Hedge against the same upstream: parallel h2 streams on same
// connection. Independent stream scheduling rescues dispatch spikes.
forward_with_hedging_raw(wire, upstream, upstream, hedge_delay, timeout_duration).await
} else {
forward_query_raw(wire, upstream, timeout_duration).await
};
match result {
Ok(resp) => {
if let Upstream::Udp(addr) = upstream {
let rtt_ms = start.elapsed().as_millis() as u64;
srtt.write().unwrap().record_rtt(addr.ip(), rtt_ms, false);
}
return Ok(resp);
}
Err(e) => {
if let Upstream::Udp(addr) = upstream {
srtt.write().unwrap().record_failure(addr.ip());
}
log::debug!("upstream {} failed: {}", upstream, e);
last_err = Some(e);
}
}
}
Err(last_err.unwrap_or_else(|| "no upstream configured".into()))
}
/// Send raw query bytes to a UDP upstream and return the raw response.
///
/// The socket is `connect`ed to the upstream before sending so the kernel
/// discards datagrams from any other source address. The previous
/// `recv_from` accepted the first datagram from *anyone*, leaving an
/// off-path response-spoofing window.
async fn forward_udp_raw(
    wire: &[u8],
    upstream: SocketAddr,
    timeout_duration: Duration,
) -> Result<Vec<u8>> {
    let socket = UdpSocket::bind("0.0.0.0:0").await?;
    socket.connect(upstream).await?;
    socket.send(wire).await?;
    // 4096 bytes: matches the previous buffer size for EDNS0 payloads.
    let mut recv_buf = vec![0u8; 4096];
    let size = timeout(timeout_duration, socket.recv(&mut recv_buf)).await??;
    recv_buf.truncate(size);
    Ok(recv_buf)
}
async fn forward_doh_raw(
wire: &[u8],
url: &str,
client: &reqwest::Client,
timeout_duration: Duration,
) -> Result<Vec<u8>> {
let resp = timeout(
timeout_duration,
client
.post(url)
.header("content-type", "application/dns-message")
.header("accept", "application/dns-message")
.body(send_buffer.filled().to_vec())
.body(wire.to_vec())
.send(),
)
.await??
@@ -339,9 +503,25 @@ async fn forward_doh(
let bytes = resp.bytes().await?;
log::debug!("DoH response: {} bytes", bytes.len());
Ok(bytes.to_vec())
}
let mut recv_buffer = BytePacketBuffer::from_bytes(&bytes);
DnsPacket::from_buffer(&mut recv_buffer)
/// Fire a minimal query at a DoH upstream so its HTTP/2 + TLS connection
/// stays warm instead of being torn down while idle. No-op for other
/// upstream kinds.
pub async fn keepalive_doh(upstream: &Upstream) {
    // Hand-built ". NS" query: 12-byte header followed by the root question.
    // Minimal, always succeeds, response is small.
    const WIRE: [u8; 17] = [
        0x00, 0x00, // ID
        0x01, 0x00, // flags: RD=1
        0x00, 0x01, // QDCOUNT=1
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // AN=0, NS=0, AR=0
        0x00, // root name (.)
        0x00, 0x02, // type NS
        0x00, 0x01, // class IN
    ];
    match upstream {
        Upstream::Doh { url, client } => {
            let _ = forward_doh_raw(&WIRE, url, client, Duration::from_secs(5)).await;
        }
        _ => {}
    }
}
#[cfg(test)]

View File

@@ -26,6 +26,7 @@ pub mod srtt;
pub mod stats;
pub mod system_dns;
pub mod tls;
pub mod wire;
pub type Error = Box<dyn std::error::Error + Send + Sync>;
pub type Result<T> = std::result::Result<T, Error>;

View File

@@ -297,6 +297,7 @@ async fn main() -> numa::Result<()> {
upstream_port: config.upstream.port,
lan_ip: Mutex::new(numa::lan::detect_lan_ip().unwrap_or(std::net::Ipv4Addr::LOCALHOST)),
timeout: Duration::from_millis(config.upstream.timeout_ms),
hedge_delay: Duration::from_millis(config.upstream.hedge_ms),
proxy_tld_suffix: if config.proxy.tld.is_empty() {
String::new()
} else {
@@ -511,6 +512,14 @@ async fn main() -> numa::Result<()> {
});
}
// Spawn DoH connection keepalive — prevents idle TLS teardown
{
let keepalive_ctx = Arc::clone(&ctx);
tokio::spawn(async move {
doh_keepalive_loop(keepalive_ctx).await;
});
}
// Spawn HTTP API server
let api_ctx = Arc::clone(&ctx);
let api_addr: SocketAddr = format!("{}:{}", config.server.api_bind_addr, api_port).parse()?;
@@ -590,7 +599,7 @@ async fn main() -> numa::Result<()> {
#[allow(clippy::infinite_loop)]
loop {
let mut buffer = BytePacketBuffer::new();
let (_, src_addr) = match ctx.socket.recv_from(&mut buffer.buf).await {
let (len, src_addr) = match ctx.socket.recv_from(&mut buffer.buf).await {
Ok(r) => r,
Err(e) if e.kind() == std::io::ErrorKind::ConnectionReset => {
// Windows delivers ICMP port-unreachable as ConnectionReset on UDP sockets
@@ -598,10 +607,11 @@ async fn main() -> numa::Result<()> {
}
Err(e) => return Err(e.into()),
};
let raw_len = len;
let ctx = Arc::clone(&ctx);
tokio::spawn(async move {
if let Err(e) = handle_query(buffer, src_addr, &ctx).await {
if let Err(e) = handle_query(buffer, raw_len, src_addr, &ctx).await {
error!("{} | HANDLER ERROR | {}", src_addr, e);
}
});
@@ -777,6 +787,18 @@ async fn warm_domain(ctx: &ServerCtx, domain: &str) {
}
}
/// Background task: every 25 seconds, send a keepalive query to the pool's
/// preferred upstream (a no-op unless it is a DoH upstream) so the idle
/// HTTP/2 + TLS connection is not torn down.
async fn doh_keepalive_loop(ctx: Arc<ServerCtx>) {
    let mut interval = tokio::time::interval(Duration::from_secs(25));
    interval.tick().await; // skip first immediate tick
    loop {
        interval.tick().await;
        // Clone the pool so the mutex guard is dropped before the await below.
        let pool = ctx.upstream_pool.lock().unwrap().clone();
        if let Some(upstream) = pool.preferred() {
            numa::forward::keepalive_doh(upstream).await;
        }
    }
}
async fn cache_warm_loop(ctx: Arc<ServerCtx>, domains: Vec<String>) {
tokio::time::sleep(Duration::from_secs(2)).await;

View File

@@ -202,23 +202,22 @@ pub(crate) fn resolve_iterative<'a>(
let mut ns_idx = 0;
for _ in 0..MAX_REFERRAL_DEPTH {
let ns_addr = match ns_addrs.get(ns_idx) {
Some(addr) => *addr,
None => return Err("no nameserver available".into()),
};
if ns_idx >= ns_addrs.len() {
return Err("no nameserver available".into());
}
let (q_name, q_type) = minimize_query(qname, qtype, &current_zone);
debug!(
"recursive: querying {} for {:?} {} (zone: {}, depth {})",
ns_addr, q_type, q_name, current_zone, referral_depth
"recursive: querying {} (+ hedge) for {:?} {} (zone: {}, depth {})",
ns_addrs[ns_idx], q_type, q_name, current_zone, referral_depth
);
let response = match send_query(q_name, q_type, ns_addr, srtt).await {
let response = match send_query_hedged(q_name, q_type, &ns_addrs[ns_idx..], srtt).await {
Ok(r) => r,
Err(e) => {
debug!("recursive: NS {} failed: {}", ns_addr, e);
ns_idx += 1;
debug!("recursive: NS query failed: {}", e);
ns_idx += 2; // both tried, skip past them
continue;
}
};
@@ -228,6 +227,9 @@ pub(crate) fn resolve_iterative<'a>(
{
if let Some(zone) = referral_zone(&response) {
current_zone = zone;
let mut cache_w = cache.write().unwrap();
cache_ns_delegation(&mut cache_w, &current_zone, &response);
drop(cache_w);
}
let mut all_ns = extract_ns_from_records(&response.answers);
if all_ns.is_empty() {
@@ -296,6 +298,7 @@ pub(crate) fn resolve_iterative<'a>(
{
let mut cache_w = cache.write().unwrap();
cache_ns_delegation(&mut cache_w, &current_zone, &response);
cache_ds_from_authority(&mut cache_w, &response);
}
let mut new_ns_addrs = resolve_ns_addrs_from_glue(&response, &ns_names, cache);
@@ -560,6 +563,23 @@ fn cache_ds_from_authority(cache: &mut DnsCache, response: &DnsPacket) {
}
}
/// Cache NS delegation records from a referral response so that
/// `find_closest_ns` can skip re-querying TLD servers on subsequent lookups.
fn cache_ns_delegation(cache: &mut DnsCache, zone: &str, response: &DnsPacket) {
let ns_records: Vec<_> = response
.authorities
.iter()
.filter(|r| matches!(r, DnsRecord::NS { .. }))
.cloned()
.collect();
if ns_records.is_empty() {
return;
}
let mut pkt = make_glue_packet();
pkt.answers = ns_records;
cache.insert(zone, QueryType::NS, &pkt);
}
fn make_glue_packet() -> DnsPacket {
let mut pkt = DnsPacket::new();
pkt.header.response = true;
@@ -587,6 +607,91 @@ async fn tcp_with_srtt(
}
}
/// Smart NS query: fire to two servers simultaneously when SRTT is unknown
/// (cold queries), or to the best server with an SRTT-based hedge when known.
///
/// Fix: the warm path previously ended in a bare `select!` over both
/// futures, which returned whichever finished first *even on Err* — a fast
/// failure from the hedge aborted the still-pending primary, defeating the
/// hedge. Both paths now race to the first `Ok` via `race_first_ok`.
async fn send_query_hedged(
    qname: &str,
    qtype: QueryType,
    servers: &[SocketAddr],
    srtt: &RwLock<SrttCache>,
) -> crate::Result<DnsPacket> {
    if servers.is_empty() {
        return Err("no nameserver available".into());
    }
    if servers.len() == 1 {
        return send_query(qname, qtype, servers[0], srtt).await;
    }
    let primary = servers[0];
    let secondary = servers[1];
    let primary_known = srtt.read().unwrap().is_known(primary.ip());
    if !primary_known {
        // Cold: fire both simultaneously, first Ok wins.
        debug!(
            "recursive: parallel query to {} and {} for {:?} {}",
            primary, secondary, qtype, qname
        );
        race_first_ok(
            send_query(qname, qtype, primary, srtt),
            send_query(qname, qtype, secondary, srtt),
        )
        .await
    } else {
        // Warm: send to best, hedge after SRTT × 3 if slow.
        let hedge_ms = srtt.read().unwrap().get(primary.ip()) * 3;
        let hedge_delay = Duration::from_millis(hedge_ms.max(50));
        let fut_a = send_query(qname, qtype, primary, srtt);
        tokio::pin!(fut_a);
        let delay = tokio::time::sleep(hedge_delay);
        tokio::pin!(delay);
        tokio::select! {
            r = &mut fut_a => return r,
            _ = &mut delay => {}
        }
        debug!(
            "recursive: hedging {} -> {} after {}ms for {:?} {}",
            primary, secondary, hedge_ms, qtype, qname
        );
        // Hedge fired: keep polling the primary too — an error from either
        // side must not abort the other.
        race_first_ok(fut_a, send_query(qname, qtype, secondary, srtt)).await
    }
}

/// Race two queries: the first `Ok` wins; if one side errors we keep
/// waiting on the other; if both error, the first future's error is
/// returned (matching the old cold-path preference for the primary).
async fn race_first_ok<A, B>(fut_a: A, fut_b: B) -> crate::Result<DnsPacket>
where
    A: std::future::Future<Output = crate::Result<DnsPacket>>,
    B: std::future::Future<Output = crate::Result<DnsPacket>>,
{
    tokio::pin!(fut_a);
    tokio::pin!(fut_b);
    let mut a_err: Option<crate::Error> = None;
    let mut b_err: Option<crate::Error> = None;
    loop {
        tokio::select! {
            r = &mut fut_a, if a_err.is_none() => match r {
                Ok(resp) => return Ok(resp),
                Err(e) => {
                    if b_err.is_some() {
                        return Err(e);
                    }
                    a_err = Some(e);
                }
            },
            r = &mut fut_b, if b_err.is_none() => match r {
                Ok(resp) => return Ok(resp),
                Err(e) => {
                    if let Some(ae) = a_err.take() {
                        return Err(ae);
                    }
                    b_err = Some(e);
                }
            },
        }
    }
}
async fn send_query(
qname: &str,
qtype: QueryType,

View File

@@ -45,6 +45,11 @@ impl SrttCache {
}
}
/// Whether we have observed RTT data for this IP.
/// Callers use this to pick a query strategy: cold (parallel fan-out)
/// when unknown, warm (SRTT-based hedge) when known.
pub fn is_known(&self, ip: IpAddr) -> bool {
    self.entries.contains_key(&ip)
}
/// Apply time-based decay: each DECAY_AFTER_SECS period halves distance to INITIAL.
fn decayed_srtt(entry: &SrttEntry) -> u64 {
Self::decay_for_age(entry.srtt_ms, entry.updated_at.elapsed().as_secs())

1347
src/wire.rs Normal file

File diff suppressed because it is too large Load Diff