refactor: remove forward_with_failover duplication, fix warm-branch hedge bug

- Remove forward_with_failover (parsed): warm_domain now uses _raw + insert_wire
- forward_udp delegates to forward_udp_raw (single UDP socket implementation)
- forward_query uses unified _raw path for all protocols
- Fix send_query_hedged warm branch: bare select! dropped secondary on primary
  error instead of waiting for it — now drains both futures like the cold branch
- Remove pointless raw_len = len rename
This commit is contained in:
Razvan Dimescu
2026-04-12 06:42:59 +03:00
parent 72b540a44a
commit 17a1a6ddba
3 changed files with 71 additions and 93 deletions

View File

@@ -157,58 +157,6 @@ impl UpstreamPool {
} }
} }
pub async fn forward_with_failover(
query: &DnsPacket,
pool: &UpstreamPool,
srtt: &RwLock<SrttCache>,
timeout_duration: Duration,
) -> Result<DnsPacket> {
// Build candidate list: primary (sorted by SRTT for UDP) then fallback
let mut candidates: Vec<(usize, u64)> = pool
.primary
.iter()
.enumerate()
.map(|(i, u)| {
let rtt = match u {
Upstream::Udp(addr) => srtt.read().unwrap().get(addr.ip()),
_ => 0, // DoH: keep config order (stable sort preserves it)
};
(i, rtt)
})
.collect();
candidates.sort_by_key(|&(_, rtt)| rtt);
let all_upstreams: Vec<&Upstream> = candidates
.iter()
.map(|&(i, _)| &pool.primary[i])
.chain(pool.fallback.iter())
.collect();
let mut last_err: Option<Box<dyn std::error::Error + Send + Sync>> = None;
for upstream in &all_upstreams {
let start = Instant::now();
match forward_query(query, upstream, timeout_duration).await {
Ok(resp) => {
if let Upstream::Udp(addr) = upstream {
let rtt_ms = start.elapsed().as_millis() as u64;
srtt.write().unwrap().record_rtt(addr.ip(), rtt_ms, false);
}
return Ok(resp);
}
Err(e) => {
if let Upstream::Udp(addr) = upstream {
srtt.write().unwrap().record_failure(addr.ip());
}
log::debug!("upstream {} failed: {}", upstream, e);
last_err = Some(e);
}
}
}
Err(last_err.unwrap_or_else(|| "no upstream configured".into()))
}
pub async fn forward_query( pub async fn forward_query(
query: &DnsPacket, query: &DnsPacket,
upstream: &Upstream, upstream: &Upstream,
@@ -226,24 +174,14 @@ pub(crate) async fn forward_udp(
upstream: SocketAddr, upstream: SocketAddr,
timeout_duration: Duration, timeout_duration: Duration,
) -> Result<DnsPacket> { ) -> Result<DnsPacket> {
let socket = UdpSocket::bind("0.0.0.0:0").await?;
let mut send_buffer = BytePacketBuffer::new(); let mut send_buffer = BytePacketBuffer::new();
query.write(&mut send_buffer)?; query.write(&mut send_buffer)?;
socket.send_to(send_buffer.filled(), upstream).await?; let data = forward_udp_raw(send_buffer.filled(), upstream, timeout_duration).await?;
if data.len() >= 4096 {
let mut recv_buffer = BytePacketBuffer::new(); log::debug!("upstream response may be truncated ({} bytes)", data.len());
let (size, _) = timeout(timeout_duration, socket.recv_from(&mut recv_buffer.buf)).await??;
if size == recv_buffer.buf.len() {
log::debug!(
"upstream response truncated ({} bytes, buffer {})",
size,
recv_buffer.buf.len()
);
} }
let mut recv_buffer = BytePacketBuffer::from_bytes(&data);
DnsPacket::from_buffer(&mut recv_buffer) DnsPacket::from_buffer(&mut recv_buffer)
} }
@@ -721,10 +659,19 @@ mod tests {
); );
let srtt = RwLock::new(SrttCache::new(true)); let srtt = RwLock::new(SrttCache::new(true));
let result = forward_with_failover(&query, &pool, &srtt, Duration::from_millis(500)) let wire = to_wire(&query);
let resp_wire = forward_with_failover_raw(
&wire,
&pool,
&srtt,
Duration::from_millis(500),
Duration::ZERO,
)
.await .await
.expect("should fail over to second upstream"); .expect("should fail over to second upstream");
let mut buf = BytePacketBuffer::from_bytes(&resp_wire);
let result = DnsPacket::from_buffer(&mut buf).unwrap();
assert_eq!(result.header.id, 0xABCD); assert_eq!(result.header.id, 0xABCD);
assert_eq!(result.answers.len(), 1); assert_eq!(result.answers.len(), 1);
} }

View File

@@ -607,11 +607,9 @@ async fn main() -> numa::Result<()> {
} }
Err(e) => return Err(e.into()), Err(e) => return Err(e.into()),
}; };
let raw_len = len;
let ctx = Arc::clone(&ctx); let ctx = Arc::clone(&ctx);
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = handle_query(buffer, raw_len, src_addr, &ctx).await { if let Err(e) = handle_query(buffer, len, src_addr, &ctx).await {
error!("{} | HANDLER ERROR | {}", src_addr, e); error!("{} | HANDLER ERROR | {}", src_addr, e);
} }
}); });
@@ -762,28 +760,40 @@ async fn warm_domain(ctx: &ServerCtx, domain: &str) {
use numa::question::QueryType; use numa::question::QueryType;
for qtype in [QueryType::A, QueryType::AAAA] { for qtype in [QueryType::A, QueryType::AAAA] {
if ctx.upstream_mode == numa::config::UpstreamMode::Recursive {
let query = numa::packet::DnsPacket::query(0, domain, qtype); let query = numa::packet::DnsPacket::query(0, domain, qtype);
let result = if ctx.upstream_mode == numa::config::UpstreamMode::Recursive { match numa::recursive::resolve_recursive(
numa::recursive::resolve_recursive( domain, qtype, &ctx.cache, &query, &ctx.root_hints, &ctx.srtt,
domain,
qtype,
&ctx.cache,
&query,
&ctx.root_hints,
&ctx.srtt,
) )
.await .await
} else { {
let pool = ctx.upstream_pool.lock().unwrap().clone();
numa::forward::forward_with_failover(&query, &pool, &ctx.srtt, ctx.timeout).await
};
match result {
Ok(resp) => { Ok(resp) => {
ctx.cache.write().unwrap().insert(domain, qtype, &resp); ctx.cache.write().unwrap().insert(domain, qtype, &resp);
log::debug!("cache warm: {} {:?}", domain, qtype); log::debug!("cache warm: {} {:?}", domain, qtype);
} }
Err(e) => log::warn!("cache warm: {} {:?} failed: {}", domain, qtype, e), Err(e) => log::warn!("cache warm: {} {:?} failed: {}", domain, qtype, e),
} }
} else {
let query = numa::packet::DnsPacket::query(0, domain, qtype);
let mut buf = numa::buffer::BytePacketBuffer::new();
if query.write(&mut buf).is_err() {
continue;
}
let pool = ctx.upstream_pool.lock().unwrap().clone();
match numa::forward::forward_with_failover_raw(
buf.filled(), &pool, &ctx.srtt, ctx.timeout, ctx.hedge_delay,
)
.await
{
Ok(wire) => {
ctx.cache.write().unwrap().insert_wire(
domain, qtype, &wire, numa::cache::DnssecStatus::Indeterminate,
);
log::debug!("cache warm: {} {:?}", domain, qtype);
}
Err(e) => log::warn!("cache warm: {} {:?} failed: {}", domain, qtype, e),
}
}
} }
} }

View File

@@ -690,9 +690,30 @@ async fn send_query_hedged(
let fut_b = send_query(qname, qtype, secondary, srtt); let fut_b = send_query(qname, qtype, secondary, srtt);
tokio::pin!(fut_b); tokio::pin!(fut_b);
// First Ok wins; if one errors, wait for the other.
let mut a_err: Option<crate::Error> = None;
let mut b_err: Option<crate::Error> = None;
loop {
tokio::select! { tokio::select! {
r = fut_a => r, r = &mut fut_a, if a_err.is_none() => {
r = fut_b => r, match r {
Ok(resp) => return Ok(resp),
Err(e) => {
if b_err.is_some() { return Err(e); }
a_err = Some(e);
}
}
}
r = &mut fut_b, if b_err.is_none() => {
match r {
Ok(resp) => return Ok(resp),
Err(e) => {
if let Some(ae) = a_err.take() { return Err(ae); }
b_err = Some(e);
}
}
}
}
} }
} }
} }