refactor: remove forward_with_failover duplication, fix warm-branch hedge bug
- Remove forward_with_failover (parsed): warm_domain now uses _raw + insert_wire - forward_udp delegates to forward_udp_raw (single UDP socket implementation) - forward_query uses unified _raw path for all protocols - Fix send_query_hedged warm branch: bare select! dropped secondary on primary error instead of waiting for it — now drains both futures like the cold branch - Remove pointless raw_len = len rename
This commit is contained in:
@@ -157,58 +157,6 @@ impl UpstreamPool {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn forward_with_failover(
|
||||
query: &DnsPacket,
|
||||
pool: &UpstreamPool,
|
||||
srtt: &RwLock<SrttCache>,
|
||||
timeout_duration: Duration,
|
||||
) -> Result<DnsPacket> {
|
||||
// Build candidate list: primary (sorted by SRTT for UDP) then fallback
|
||||
let mut candidates: Vec<(usize, u64)> = pool
|
||||
.primary
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, u)| {
|
||||
let rtt = match u {
|
||||
Upstream::Udp(addr) => srtt.read().unwrap().get(addr.ip()),
|
||||
_ => 0, // DoH: keep config order (stable sort preserves it)
|
||||
};
|
||||
(i, rtt)
|
||||
})
|
||||
.collect();
|
||||
candidates.sort_by_key(|&(_, rtt)| rtt);
|
||||
|
||||
let all_upstreams: Vec<&Upstream> = candidates
|
||||
.iter()
|
||||
.map(|&(i, _)| &pool.primary[i])
|
||||
.chain(pool.fallback.iter())
|
||||
.collect();
|
||||
|
||||
let mut last_err: Option<Box<dyn std::error::Error + Send + Sync>> = None;
|
||||
|
||||
for upstream in &all_upstreams {
|
||||
let start = Instant::now();
|
||||
match forward_query(query, upstream, timeout_duration).await {
|
||||
Ok(resp) => {
|
||||
if let Upstream::Udp(addr) = upstream {
|
||||
let rtt_ms = start.elapsed().as_millis() as u64;
|
||||
srtt.write().unwrap().record_rtt(addr.ip(), rtt_ms, false);
|
||||
}
|
||||
return Ok(resp);
|
||||
}
|
||||
Err(e) => {
|
||||
if let Upstream::Udp(addr) = upstream {
|
||||
srtt.write().unwrap().record_failure(addr.ip());
|
||||
}
|
||||
log::debug!("upstream {} failed: {}", upstream, e);
|
||||
last_err = Some(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(last_err.unwrap_or_else(|| "no upstream configured".into()))
|
||||
}
|
||||
|
||||
pub async fn forward_query(
|
||||
query: &DnsPacket,
|
||||
upstream: &Upstream,
|
||||
@@ -226,24 +174,14 @@ pub(crate) async fn forward_udp(
|
||||
upstream: SocketAddr,
|
||||
timeout_duration: Duration,
|
||||
) -> Result<DnsPacket> {
|
||||
let socket = UdpSocket::bind("0.0.0.0:0").await?;
|
||||
|
||||
let mut send_buffer = BytePacketBuffer::new();
|
||||
query.write(&mut send_buffer)?;
|
||||
|
||||
socket.send_to(send_buffer.filled(), upstream).await?;
|
||||
|
||||
let mut recv_buffer = BytePacketBuffer::new();
|
||||
let (size, _) = timeout(timeout_duration, socket.recv_from(&mut recv_buffer.buf)).await??;
|
||||
|
||||
if size == recv_buffer.buf.len() {
|
||||
log::debug!(
|
||||
"upstream response truncated ({} bytes, buffer {})",
|
||||
size,
|
||||
recv_buffer.buf.len()
|
||||
);
|
||||
let data = forward_udp_raw(send_buffer.filled(), upstream, timeout_duration).await?;
|
||||
if data.len() >= 4096 {
|
||||
log::debug!("upstream response may be truncated ({} bytes)", data.len());
|
||||
}
|
||||
|
||||
let mut recv_buffer = BytePacketBuffer::from_bytes(&data);
|
||||
DnsPacket::from_buffer(&mut recv_buffer)
|
||||
}
|
||||
|
||||
@@ -721,10 +659,19 @@ mod tests {
|
||||
);
|
||||
|
||||
let srtt = RwLock::new(SrttCache::new(true));
|
||||
let result = forward_with_failover(&query, &pool, &srtt, Duration::from_millis(500))
|
||||
.await
|
||||
.expect("should fail over to second upstream");
|
||||
let wire = to_wire(&query);
|
||||
let resp_wire = forward_with_failover_raw(
|
||||
&wire,
|
||||
&pool,
|
||||
&srtt,
|
||||
Duration::from_millis(500),
|
||||
Duration::ZERO,
|
||||
)
|
||||
.await
|
||||
.expect("should fail over to second upstream");
|
||||
|
||||
let mut buf = BytePacketBuffer::from_bytes(&resp_wire);
|
||||
let result = DnsPacket::from_buffer(&mut buf).unwrap();
|
||||
assert_eq!(result.header.id, 0xABCD);
|
||||
assert_eq!(result.answers.len(), 1);
|
||||
}
|
||||
|
||||
52
src/main.rs
52
src/main.rs
@@ -607,11 +607,9 @@ async fn main() -> numa::Result<()> {
|
||||
}
|
||||
Err(e) => return Err(e.into()),
|
||||
};
|
||||
let raw_len = len;
|
||||
|
||||
let ctx = Arc::clone(&ctx);
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = handle_query(buffer, raw_len, src_addr, &ctx).await {
|
||||
if let Err(e) = handle_query(buffer, len, src_addr, &ctx).await {
|
||||
error!("{} | HANDLER ERROR | {}", src_addr, e);
|
||||
}
|
||||
});
|
||||
@@ -762,27 +760,39 @@ async fn warm_domain(ctx: &ServerCtx, domain: &str) {
|
||||
use numa::question::QueryType;
|
||||
|
||||
for qtype in [QueryType::A, QueryType::AAAA] {
|
||||
let query = numa::packet::DnsPacket::query(0, domain, qtype);
|
||||
let result = if ctx.upstream_mode == numa::config::UpstreamMode::Recursive {
|
||||
numa::recursive::resolve_recursive(
|
||||
domain,
|
||||
qtype,
|
||||
&ctx.cache,
|
||||
&query,
|
||||
&ctx.root_hints,
|
||||
&ctx.srtt,
|
||||
if ctx.upstream_mode == numa::config::UpstreamMode::Recursive {
|
||||
let query = numa::packet::DnsPacket::query(0, domain, qtype);
|
||||
match numa::recursive::resolve_recursive(
|
||||
domain, qtype, &ctx.cache, &query, &ctx.root_hints, &ctx.srtt,
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
let pool = ctx.upstream_pool.lock().unwrap().clone();
|
||||
numa::forward::forward_with_failover(&query, &pool, &ctx.srtt, ctx.timeout).await
|
||||
};
|
||||
match result {
|
||||
Ok(resp) => {
|
||||
ctx.cache.write().unwrap().insert(domain, qtype, &resp);
|
||||
log::debug!("cache warm: {} {:?}", domain, qtype);
|
||||
{
|
||||
Ok(resp) => {
|
||||
ctx.cache.write().unwrap().insert(domain, qtype, &resp);
|
||||
log::debug!("cache warm: {} {:?}", domain, qtype);
|
||||
}
|
||||
Err(e) => log::warn!("cache warm: {} {:?} failed: {}", domain, qtype, e),
|
||||
}
|
||||
} else {
|
||||
let query = numa::packet::DnsPacket::query(0, domain, qtype);
|
||||
let mut buf = numa::buffer::BytePacketBuffer::new();
|
||||
if query.write(&mut buf).is_err() {
|
||||
continue;
|
||||
}
|
||||
let pool = ctx.upstream_pool.lock().unwrap().clone();
|
||||
match numa::forward::forward_with_failover_raw(
|
||||
buf.filled(), &pool, &ctx.srtt, ctx.timeout, ctx.hedge_delay,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(wire) => {
|
||||
ctx.cache.write().unwrap().insert_wire(
|
||||
domain, qtype, &wire, numa::cache::DnssecStatus::Indeterminate,
|
||||
);
|
||||
log::debug!("cache warm: {} {:?}", domain, qtype);
|
||||
}
|
||||
Err(e) => log::warn!("cache warm: {} {:?} failed: {}", domain, qtype, e),
|
||||
}
|
||||
Err(e) => log::warn!("cache warm: {} {:?} failed: {}", domain, qtype, e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -690,9 +690,30 @@ async fn send_query_hedged(
|
||||
let fut_b = send_query(qname, qtype, secondary, srtt);
|
||||
tokio::pin!(fut_b);
|
||||
|
||||
tokio::select! {
|
||||
r = fut_a => r,
|
||||
r = fut_b => r,
|
||||
// First Ok wins; if one errors, wait for the other.
|
||||
let mut a_err: Option<crate::Error> = None;
|
||||
let mut b_err: Option<crate::Error> = None;
|
||||
loop {
|
||||
tokio::select! {
|
||||
r = &mut fut_a, if a_err.is_none() => {
|
||||
match r {
|
||||
Ok(resp) => return Ok(resp),
|
||||
Err(e) => {
|
||||
if b_err.is_some() { return Err(e); }
|
||||
a_err = Some(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
r = &mut fut_b, if b_err.is_none() => {
|
||||
match r {
|
||||
Ok(resp) => return Ok(resp),
|
||||
Err(e) => {
|
||||
if let Some(ae) = a_err.take() { return Err(ae); }
|
||||
b_err = Some(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user