diff --git a/src/forward.rs b/src/forward.rs index a55723c..09157cb 100644 --- a/src/forward.rs +++ b/src/forward.rs @@ -106,36 +106,30 @@ pub async fn forward_with_failover( srtt: &RwLock, timeout_duration: Duration, ) -> Result { - // Build candidate list: primary (sorted by SRTT) then fallback (config order) - let mut candidates: Vec<&Upstream> = - Vec::with_capacity(pool.primary.len() + pool.fallback.len()); + // Build candidate list: primary (sorted by SRTT for UDP) then fallback + let mut candidates: Vec<(usize, u64)> = pool + .primary + .iter() + .enumerate() + .map(|(i, u)| { + let rtt = match u { + Upstream::Udp(addr) => srtt.read().unwrap().get(addr.ip()), + _ => 0, // DoH: keep config order (stable sort preserves it) + }; + (i, rtt) + }) + .collect(); + candidates.sort_by_key(|&(_, rtt)| rtt); - // Sort primary UDP upstreams by SRTT; DoH keeps config order - let mut primary_udp_indices: Vec<(usize, u64)> = Vec::new(); - for (i, u) in pool.primary.iter().enumerate() { - if let Upstream::Udp(addr) = u { - let rtt = srtt.read().unwrap().get(addr.ip()); - primary_udp_indices.push((i, rtt)); - } - } - primary_udp_indices.sort_by_key(|&(_, rtt)| rtt); - - // Interleave: sorted UDP first, then DoH in config order - for &(i, _) in &primary_udp_indices { - candidates.push(&pool.primary[i]); - } - for u in &pool.primary { - if matches!(u, Upstream::Doh { .. }) { - candidates.push(u); - } - } - for u in &pool.fallback { - candidates.push(u); - } + let all_upstreams: Vec<&Upstream> = candidates + .iter() + .map(|&(i, _)| &pool.primary[i]) + .chain(pool.fallback.iter()) + .collect(); let mut last_err: Option> = None; - for upstream in &candidates { + for upstream in &all_upstreams { let start = Instant::now(); match forward_query(query, upstream, timeout_duration).await { Ok(resp) => { diff --git a/src/main.rs b/src/main.rs index 807ec61..e92e63d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -129,19 +129,18 @@ async fn main() -> numa::Result<()> { let root_hints = numa::recursive::parse_root_hints(&config.upstream.root_hints); + let recursive_pool = || { + let dummy = UpstreamPool::new(vec![Upstream::Udp("0.0.0.0:0".parse().unwrap())], vec![]); + (dummy, "recursive (root hints)".to_string()) + }; + let (resolved_mode, upstream_auto, pool, upstream_label) = match config.upstream.mode { numa::config::UpstreamMode::Auto => { info!("auto mode: probing recursive resolution..."); if numa::recursive::probe_recursive(&root_hints).await { info!("recursive probe succeeded — self-sovereign mode"); - let dummy = - UpstreamPool::new(vec![Upstream::Udp("0.0.0.0:0".parse().unwrap())], vec![]); - ( - numa::config::UpstreamMode::Recursive, - false, - dummy, - "recursive (root hints)".to_string(), - ) + let (pool, label) = recursive_pool(); + (numa::config::UpstreamMode::Recursive, false, pool, label) } else { log::warn!("recursive probe failed — falling back to Quad9 DoH"); let client = reqwest::Client::builder() @@ -155,14 +154,8 @@ async fn main() -> numa::Result<()> { } } numa::config::UpstreamMode::Recursive => { - let dummy = - UpstreamPool::new(vec![Upstream::Udp("0.0.0.0:0".parse().unwrap())], vec![]); - ( - numa::config::UpstreamMode::Recursive, - false, - dummy, - "recursive (root hints)".to_string(), - ) + let (pool, label) = recursive_pool(); + (numa::config::UpstreamMode::Recursive, false, pool, label) } numa::config::UpstreamMode::Forward => { let addrs = if config.upstream.address.is_empty() {