From 17a1a6ddba351d8b5ec529ef5ef242e57bcb56ec Mon Sep 17 00:00:00 2001 From: Razvan Dimescu Date: Sun, 12 Apr 2026 06:42:59 +0300 Subject: [PATCH] refactor: remove forward_with_failover duplication, fix warm-branch hedge bug MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Remove forward_with_failover (parsed): warm_domain now uses _raw + insert_wire - forward_udp delegates to forward_udp_raw (single UDP socket implementation) - forward_query uses unified _raw path for all protocols - Fix send_query_hedged warm branch: bare select! dropped secondary on primary error instead of waiting for it — now drains both futures like the cold branch - Remove pointless raw_len = len rename --- src/forward.rs | 85 +++++++++--------------------------------------- src/main.rs | 52 +++++++++++++++++------------ src/recursive.rs | 27 +++++++++++++-- 3 files changed, 71 insertions(+), 93 deletions(-) diff --git a/src/forward.rs b/src/forward.rs index 6afb7e5..ebbe777 100644 --- a/src/forward.rs +++ b/src/forward.rs @@ -157,58 +157,6 @@ impl UpstreamPool { } } -pub async fn forward_with_failover( - query: &DnsPacket, - pool: &UpstreamPool, - srtt: &RwLock, - timeout_duration: Duration, -) -> Result { - // Build candidate list: primary (sorted by SRTT for UDP) then fallback - let mut candidates: Vec<(usize, u64)> = pool - .primary - .iter() - .enumerate() - .map(|(i, u)| { - let rtt = match u { - Upstream::Udp(addr) => srtt.read().unwrap().get(addr.ip()), - _ => 0, // DoH: keep config order (stable sort preserves it) - }; - (i, rtt) - }) - .collect(); - candidates.sort_by_key(|&(_, rtt)| rtt); - - let all_upstreams: Vec<&Upstream> = candidates - .iter() - .map(|&(i, _)| &pool.primary[i]) - .chain(pool.fallback.iter()) - .collect(); - - let mut last_err: Option> = None; - - for upstream in &all_upstreams { - let start = Instant::now(); - match forward_query(query, upstream, timeout_duration).await { - Ok(resp) => { - if let Upstream::Udp(addr) = upstream { - let rtt_ms = start.elapsed().as_millis() as u64; - srtt.write().unwrap().record_rtt(addr.ip(), rtt_ms, false); - } - return Ok(resp); - } - Err(e) => { - if let Upstream::Udp(addr) = upstream { - srtt.write().unwrap().record_failure(addr.ip()); - } - log::debug!("upstream {} failed: {}", upstream, e); - last_err = Some(e); - } - } - } - - Err(last_err.unwrap_or_else(|| "no upstream configured".into())) -} - pub async fn forward_query( query: &DnsPacket, upstream: &Upstream, @@ -226,24 +174,14 @@ pub(crate) async fn forward_udp( upstream: SocketAddr, timeout_duration: Duration, ) -> Result { - let socket = UdpSocket::bind("0.0.0.0:0").await?; - let mut send_buffer = BytePacketBuffer::new(); query.write(&mut send_buffer)?; - socket.send_to(send_buffer.filled(), upstream).await?; - - let mut recv_buffer = BytePacketBuffer::new(); - let (size, _) = timeout(timeout_duration, socket.recv_from(&mut recv_buffer.buf)).await??; - - if size == recv_buffer.buf.len() { - log::debug!( - "upstream response truncated ({} bytes, buffer {})", - size, - recv_buffer.buf.len() - ); + let data = forward_udp_raw(send_buffer.filled(), upstream, timeout_duration).await?; + if data.len() >= 4096 { + log::debug!("upstream response may be truncated ({} bytes)", data.len()); } - + let mut recv_buffer = BytePacketBuffer::from_bytes(&data); DnsPacket::from_buffer(&mut recv_buffer) } @@ -721,10 +659,19 @@ mod tests { ); let srtt = RwLock::new(SrttCache::new(true)); - let result = forward_with_failover(&query, &pool, &srtt, Duration::from_millis(500)) - .await - .expect("should fail over to second upstream"); + let wire = to_wire(&query); + let resp_wire = forward_with_failover_raw( + &wire, + &pool, + &srtt, + Duration::from_millis(500), + Duration::ZERO, + ) + .await + .expect("should fail over to second upstream"); + let mut buf = BytePacketBuffer::from_bytes(&resp_wire); + let result = DnsPacket::from_buffer(&mut buf).unwrap(); assert_eq!(result.header.id, 0xABCD); assert_eq!(result.answers.len(), 1); } diff --git a/src/main.rs b/src/main.rs index 0211a59..68e4794 100644 --- a/src/main.rs +++ b/src/main.rs @@ -607,11 +607,9 @@ async fn main() -> numa::Result<()> { } Err(e) => return Err(e.into()), }; - let raw_len = len; - let ctx = Arc::clone(&ctx); tokio::spawn(async move { - if let Err(e) = handle_query(buffer, raw_len, src_addr, &ctx).await { + if let Err(e) = handle_query(buffer, len, src_addr, &ctx).await { error!("{} | HANDLER ERROR | {}", src_addr, e); } }); @@ -762,27 +760,39 @@ async fn warm_domain(ctx: &ServerCtx, domain: &str) { use numa::question::QueryType; for qtype in [QueryType::A, QueryType::AAAA] { - let query = numa::packet::DnsPacket::query(0, domain, qtype); - let result = if ctx.upstream_mode == numa::config::UpstreamMode::Recursive { - numa::recursive::resolve_recursive( - domain, - qtype, - &ctx.cache, - &query, - &ctx.root_hints, - &ctx.srtt, + if ctx.upstream_mode == numa::config::UpstreamMode::Recursive { + let query = numa::packet::DnsPacket::query(0, domain, qtype); + match numa::recursive::resolve_recursive( + domain, qtype, &ctx.cache, &query, &ctx.root_hints, &ctx.srtt, ) .await - } else { - let pool = ctx.upstream_pool.lock().unwrap().clone(); - numa::forward::forward_with_failover(&query, &pool, &ctx.srtt, ctx.timeout).await - }; - match result { - Ok(resp) => { - ctx.cache.write().unwrap().insert(domain, qtype, &resp); - log::debug!("cache warm: {} {:?}", domain, qtype); + { + Ok(resp) => { + ctx.cache.write().unwrap().insert(domain, qtype, &resp); + log::debug!("cache warm: {} {:?}", domain, qtype); + } + Err(e) => log::warn!("cache warm: {} {:?} failed: {}", domain, qtype, e), + } + } else { + let query = numa::packet::DnsPacket::query(0, domain, qtype); + let mut buf = numa::buffer::BytePacketBuffer::new(); + if query.write(&mut buf).is_err() { + continue; + } + let pool = ctx.upstream_pool.lock().unwrap().clone(); + match numa::forward::forward_with_failover_raw( + buf.filled(), &pool, &ctx.srtt, ctx.timeout, ctx.hedge_delay, + ) + .await + { + Ok(wire) => { + ctx.cache.write().unwrap().insert_wire( + domain, qtype, &wire, numa::cache::DnssecStatus::Indeterminate, + ); + log::debug!("cache warm: {} {:?}", domain, qtype); + } + Err(e) => log::warn!("cache warm: {} {:?} failed: {}", domain, qtype, e), } - Err(e) => log::warn!("cache warm: {} {:?} failed: {}", domain, qtype, e), } } } diff --git a/src/recursive.rs b/src/recursive.rs index 190a57a..70f35c0 100644 --- a/src/recursive.rs +++ b/src/recursive.rs @@ -690,9 +690,30 @@ async fn send_query_hedged( let fut_b = send_query(qname, qtype, secondary, srtt); tokio::pin!(fut_b); - tokio::select! { - r = fut_a => r, - r = fut_b => r, + // First Ok wins; if one errors, wait for the other. + let mut a_err: Option = None; + let mut b_err: Option = None; + loop { + tokio::select! { + r = &mut fut_a, if a_err.is_none() => { + match r { + Ok(resp) => return Ok(resp), + Err(e) => { + if b_err.is_some() { return Err(e); } + a_err = Some(e); + } + } + } + r = &mut fut_b, if b_err.is_none() => { + match r { + Ok(resp) => return Ok(resp), + Err(e) => { + if let Some(ae) = a_err.take() { return Err(ae); } + b_err = Some(e); + } + } + } + } } } }