refactor: remove forward_with_failover duplication, fix warm-branch hedge bug
- Remove forward_with_failover (parsed): warm_domain now uses _raw + insert_wire - forward_udp delegates to forward_udp_raw (single UDP socket implementation) - forward_query uses unified _raw path for all protocols - Fix send_query_hedged warm branch: bare select! dropped secondary on primary error instead of waiting for it — now drains both futures like the cold branch - Remove pointless raw_len = len rename
This commit is contained in:
@@ -157,58 +157,6 @@ impl UpstreamPool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn forward_with_failover(
|
|
||||||
query: &DnsPacket,
|
|
||||||
pool: &UpstreamPool,
|
|
||||||
srtt: &RwLock<SrttCache>,
|
|
||||||
timeout_duration: Duration,
|
|
||||||
) -> Result<DnsPacket> {
|
|
||||||
// Build candidate list: primary (sorted by SRTT for UDP) then fallback
|
|
||||||
let mut candidates: Vec<(usize, u64)> = pool
|
|
||||||
.primary
|
|
||||||
.iter()
|
|
||||||
.enumerate()
|
|
||||||
.map(|(i, u)| {
|
|
||||||
let rtt = match u {
|
|
||||||
Upstream::Udp(addr) => srtt.read().unwrap().get(addr.ip()),
|
|
||||||
_ => 0, // DoH: keep config order (stable sort preserves it)
|
|
||||||
};
|
|
||||||
(i, rtt)
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
candidates.sort_by_key(|&(_, rtt)| rtt);
|
|
||||||
|
|
||||||
let all_upstreams: Vec<&Upstream> = candidates
|
|
||||||
.iter()
|
|
||||||
.map(|&(i, _)| &pool.primary[i])
|
|
||||||
.chain(pool.fallback.iter())
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let mut last_err: Option<Box<dyn std::error::Error + Send + Sync>> = None;
|
|
||||||
|
|
||||||
for upstream in &all_upstreams {
|
|
||||||
let start = Instant::now();
|
|
||||||
match forward_query(query, upstream, timeout_duration).await {
|
|
||||||
Ok(resp) => {
|
|
||||||
if let Upstream::Udp(addr) = upstream {
|
|
||||||
let rtt_ms = start.elapsed().as_millis() as u64;
|
|
||||||
srtt.write().unwrap().record_rtt(addr.ip(), rtt_ms, false);
|
|
||||||
}
|
|
||||||
return Ok(resp);
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
if let Upstream::Udp(addr) = upstream {
|
|
||||||
srtt.write().unwrap().record_failure(addr.ip());
|
|
||||||
}
|
|
||||||
log::debug!("upstream {} failed: {}", upstream, e);
|
|
||||||
last_err = Some(e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Err(last_err.unwrap_or_else(|| "no upstream configured".into()))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn forward_query(
|
pub async fn forward_query(
|
||||||
query: &DnsPacket,
|
query: &DnsPacket,
|
||||||
upstream: &Upstream,
|
upstream: &Upstream,
|
||||||
@@ -226,24 +174,14 @@ pub(crate) async fn forward_udp(
|
|||||||
upstream: SocketAddr,
|
upstream: SocketAddr,
|
||||||
timeout_duration: Duration,
|
timeout_duration: Duration,
|
||||||
) -> Result<DnsPacket> {
|
) -> Result<DnsPacket> {
|
||||||
let socket = UdpSocket::bind("0.0.0.0:0").await?;
|
|
||||||
|
|
||||||
let mut send_buffer = BytePacketBuffer::new();
|
let mut send_buffer = BytePacketBuffer::new();
|
||||||
query.write(&mut send_buffer)?;
|
query.write(&mut send_buffer)?;
|
||||||
|
|
||||||
socket.send_to(send_buffer.filled(), upstream).await?;
|
let data = forward_udp_raw(send_buffer.filled(), upstream, timeout_duration).await?;
|
||||||
|
if data.len() >= 4096 {
|
||||||
let mut recv_buffer = BytePacketBuffer::new();
|
log::debug!("upstream response may be truncated ({} bytes)", data.len());
|
||||||
let (size, _) = timeout(timeout_duration, socket.recv_from(&mut recv_buffer.buf)).await??;
|
|
||||||
|
|
||||||
if size == recv_buffer.buf.len() {
|
|
||||||
log::debug!(
|
|
||||||
"upstream response truncated ({} bytes, buffer {})",
|
|
||||||
size,
|
|
||||||
recv_buffer.buf.len()
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
let mut recv_buffer = BytePacketBuffer::from_bytes(&data);
|
||||||
DnsPacket::from_buffer(&mut recv_buffer)
|
DnsPacket::from_buffer(&mut recv_buffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -721,10 +659,19 @@ mod tests {
|
|||||||
);
|
);
|
||||||
|
|
||||||
let srtt = RwLock::new(SrttCache::new(true));
|
let srtt = RwLock::new(SrttCache::new(true));
|
||||||
let result = forward_with_failover(&query, &pool, &srtt, Duration::from_millis(500))
|
let wire = to_wire(&query);
|
||||||
.await
|
let resp_wire = forward_with_failover_raw(
|
||||||
.expect("should fail over to second upstream");
|
&wire,
|
||||||
|
&pool,
|
||||||
|
&srtt,
|
||||||
|
Duration::from_millis(500),
|
||||||
|
Duration::ZERO,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("should fail over to second upstream");
|
||||||
|
|
||||||
|
let mut buf = BytePacketBuffer::from_bytes(&resp_wire);
|
||||||
|
let result = DnsPacket::from_buffer(&mut buf).unwrap();
|
||||||
assert_eq!(result.header.id, 0xABCD);
|
assert_eq!(result.header.id, 0xABCD);
|
||||||
assert_eq!(result.answers.len(), 1);
|
assert_eq!(result.answers.len(), 1);
|
||||||
}
|
}
|
||||||
|
|||||||
52
src/main.rs
52
src/main.rs
@@ -607,11 +607,9 @@ async fn main() -> numa::Result<()> {
|
|||||||
}
|
}
|
||||||
Err(e) => return Err(e.into()),
|
Err(e) => return Err(e.into()),
|
||||||
};
|
};
|
||||||
let raw_len = len;
|
|
||||||
|
|
||||||
let ctx = Arc::clone(&ctx);
|
let ctx = Arc::clone(&ctx);
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
if let Err(e) = handle_query(buffer, raw_len, src_addr, &ctx).await {
|
if let Err(e) = handle_query(buffer, len, src_addr, &ctx).await {
|
||||||
error!("{} | HANDLER ERROR | {}", src_addr, e);
|
error!("{} | HANDLER ERROR | {}", src_addr, e);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@@ -762,27 +760,39 @@ async fn warm_domain(ctx: &ServerCtx, domain: &str) {
|
|||||||
use numa::question::QueryType;
|
use numa::question::QueryType;
|
||||||
|
|
||||||
for qtype in [QueryType::A, QueryType::AAAA] {
|
for qtype in [QueryType::A, QueryType::AAAA] {
|
||||||
let query = numa::packet::DnsPacket::query(0, domain, qtype);
|
if ctx.upstream_mode == numa::config::UpstreamMode::Recursive {
|
||||||
let result = if ctx.upstream_mode == numa::config::UpstreamMode::Recursive {
|
let query = numa::packet::DnsPacket::query(0, domain, qtype);
|
||||||
numa::recursive::resolve_recursive(
|
match numa::recursive::resolve_recursive(
|
||||||
domain,
|
domain, qtype, &ctx.cache, &query, &ctx.root_hints, &ctx.srtt,
|
||||||
qtype,
|
|
||||||
&ctx.cache,
|
|
||||||
&query,
|
|
||||||
&ctx.root_hints,
|
|
||||||
&ctx.srtt,
|
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
} else {
|
{
|
||||||
let pool = ctx.upstream_pool.lock().unwrap().clone();
|
Ok(resp) => {
|
||||||
numa::forward::forward_with_failover(&query, &pool, &ctx.srtt, ctx.timeout).await
|
ctx.cache.write().unwrap().insert(domain, qtype, &resp);
|
||||||
};
|
log::debug!("cache warm: {} {:?}", domain, qtype);
|
||||||
match result {
|
}
|
||||||
Ok(resp) => {
|
Err(e) => log::warn!("cache warm: {} {:?} failed: {}", domain, qtype, e),
|
||||||
ctx.cache.write().unwrap().insert(domain, qtype, &resp);
|
}
|
||||||
log::debug!("cache warm: {} {:?}", domain, qtype);
|
} else {
|
||||||
|
let query = numa::packet::DnsPacket::query(0, domain, qtype);
|
||||||
|
let mut buf = numa::buffer::BytePacketBuffer::new();
|
||||||
|
if query.write(&mut buf).is_err() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let pool = ctx.upstream_pool.lock().unwrap().clone();
|
||||||
|
match numa::forward::forward_with_failover_raw(
|
||||||
|
buf.filled(), &pool, &ctx.srtt, ctx.timeout, ctx.hedge_delay,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(wire) => {
|
||||||
|
ctx.cache.write().unwrap().insert_wire(
|
||||||
|
domain, qtype, &wire, numa::cache::DnssecStatus::Indeterminate,
|
||||||
|
);
|
||||||
|
log::debug!("cache warm: {} {:?}", domain, qtype);
|
||||||
|
}
|
||||||
|
Err(e) => log::warn!("cache warm: {} {:?} failed: {}", domain, qtype, e),
|
||||||
}
|
}
|
||||||
Err(e) => log::warn!("cache warm: {} {:?} failed: {}", domain, qtype, e),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -690,9 +690,30 @@ async fn send_query_hedged(
|
|||||||
let fut_b = send_query(qname, qtype, secondary, srtt);
|
let fut_b = send_query(qname, qtype, secondary, srtt);
|
||||||
tokio::pin!(fut_b);
|
tokio::pin!(fut_b);
|
||||||
|
|
||||||
tokio::select! {
|
// First Ok wins; if one errors, wait for the other.
|
||||||
r = fut_a => r,
|
let mut a_err: Option<crate::Error> = None;
|
||||||
r = fut_b => r,
|
let mut b_err: Option<crate::Error> = None;
|
||||||
|
loop {
|
||||||
|
tokio::select! {
|
||||||
|
r = &mut fut_a, if a_err.is_none() => {
|
||||||
|
match r {
|
||||||
|
Ok(resp) => return Ok(resp),
|
||||||
|
Err(e) => {
|
||||||
|
if b_err.is_some() { return Err(e); }
|
||||||
|
a_err = Some(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r = &mut fut_b, if b_err.is_none() => {
|
||||||
|
match r {
|
||||||
|
Ok(resp) => return Ok(resp),
|
||||||
|
Err(e) => {
|
||||||
|
if let Some(ae) = a_err.take() { return Err(ae); }
|
||||||
|
b_err = Some(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user