From 5f77af55e9595110b4d3da39084b5d5578af77d8 Mon Sep 17 00:00:00 2001 From: Razvan Dimescu Date: Fri, 17 Apr 2026 03:39:21 +0300 Subject: [PATCH] fix(forward): track SRTT for DoT upstreams, not just UDP MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The SRTT ordering + failure penalty path was UDP-only, so a DoT primary in a forwarding-rule pool was never deprioritized on failure and all DoT entries tied at INITIAL_SRTT_MS in the sort key. With [[forwarding]] now accepting arrays of upstreams, DoT pools are a first-class case and need the same healthiest-first behavior the default pool gets for UDP. - Add Upstream::tracked_ip() → Some(ip) for Udp/Dot, None for Doh (DoH has no stable IP — reqwest pools connections by hostname). - Rewire the three SRTT call sites in forward_with_failover_raw. - Hoist srtt.read() out of the candidate-scoring loop — one lock per query instead of N (matters now that pools commonly have N>1). - Drop unused #[derive(Debug)] on UpstreamPool and ForwardingRule. - Regression tests: udp_failure_records_in_srtt + dot_failure_records_in_srtt. --- src/forward.rs | 103 ++++++++++++++++++++++++++++++++++++++-------- src/system_dns.rs | 2 +- 2 files changed, 87 insertions(+), 18 deletions(-) diff --git a/src/forward.rs b/src/forward.rs index 8bb548e..9bfa426 100644 --- a/src/forward.rs +++ b/src/forward.rs @@ -25,6 +25,18 @@ pub enum Upstream { }, } +impl Upstream { + /// IP address to key SRTT tracking on, if the upstream has a stable one. + /// `Doh` routes through a URL + connection pool, so there's no single IP + /// to track; SRTT is skipped for it. + pub fn tracked_ip(&self) -> Option { + match self { + Upstream::Udp(addr) | Upstream::Dot { addr, .. } => Some(addr.ip()), + Upstream::Doh { .. } => None, + } + } +} + impl PartialEq for Upstream { fn eq(&self, other: &Self) -> bool { match (self, other) { @@ -118,7 +130,7 @@ fn build_dot_connector() -> Result { ))) } -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct UpstreamPool { primary: Vec, fallback: Vec, @@ -345,18 +357,17 @@ pub async fn forward_with_failover_raw( timeout_duration: Duration, hedge_delay: Duration, ) -> Result> { - let mut candidates: Vec<(usize, u64)> = pool - .primary - .iter() - .enumerate() - .map(|(i, u)| { - let rtt = match u { - Upstream::Udp(addr) => srtt.read().unwrap().get(addr.ip()), - _ => 0, - }; - (i, rtt) - }) - .collect(); + let mut candidates: Vec<(usize, u64)> = { + let srtt_read = srtt.read().unwrap(); + pool.primary + .iter() + .enumerate() + .map(|(i, u)| { + let rtt = u.tracked_ip().map(|ip| srtt_read.get(ip)).unwrap_or(0); + (i, rtt) + }) + .collect() + }; candidates.sort_by_key(|&(_, rtt)| rtt); let all_upstreams: Vec<&Upstream> = candidates @@ -380,15 +391,15 @@ pub async fn forward_with_failover_raw( }; match result { Ok(resp) => { - if let Upstream::Udp(addr) = upstream { + if let Some(ip) = upstream.tracked_ip() { let rtt_ms = start.elapsed().as_millis() as u64; - srtt.write().unwrap().record_rtt(addr.ip(), rtt_ms, false); + srtt.write().unwrap().record_rtt(ip, rtt_ms, false); } return Ok(resp); } Err(e) => { - if let Upstream::Udp(addr) = upstream { - srtt.write().unwrap().record_failure(addr.ip()); + if let Some(ip) = upstream.tracked_ip() { + srtt.write().unwrap().record_failure(ip); } log::debug!("upstream {} failed: {}", upstream, e); last_err = Some(e); @@ -707,4 +718,62 @@ mod tests { assert!(!pool.maybe_update_primary("not-an-ip", 53)); assert_eq!(pool.preferred().unwrap().to_string(), "1.2.3.4:53"); } + + fn tcp_closed_port() -> SocketAddr { + // Bind a TCP listener, grab the port, drop → kernel returns RST on connect. + let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = listener.local_addr().unwrap(); + drop(listener); + addr + } + + #[tokio::test] + async fn udp_failure_records_in_srtt() { + let blackhole = crate::testutil::blackhole_upstream(); + let pool = UpstreamPool::new(vec![Upstream::Udp(blackhole)], vec![]); + let srtt = RwLock::new(SrttCache::new(true)); + let _ = forward_with_failover_raw( + &[0u8; 12], + &pool, + &srtt, + Duration::from_millis(100), + Duration::ZERO, + ) + .await; + assert!(srtt.read().unwrap().is_known(blackhole.ip())); + } + + #[tokio::test] + async fn dot_failure_records_in_srtt() { + let dead1 = tcp_closed_port(); + let dead2 = tcp_closed_port(); + let connector = build_dot_connector().unwrap(); + let pool = UpstreamPool::new( + vec![ + Upstream::Dot { + addr: dead1, + tls_name: Some("dns.quad9.net".to_string()), + connector: connector.clone(), + }, + Upstream::Dot { + addr: dead2, + tls_name: Some("dns.quad9.net".to_string()), + connector, + }, + ], + vec![], + ); + let srtt = RwLock::new(SrttCache::new(true)); + let _ = forward_with_failover_raw( + &[0u8; 12], + &pool, + &srtt, + Duration::from_millis(500), + Duration::ZERO, + ) + .await; + let cache = srtt.read().unwrap(); + assert!(cache.is_known(dead1.ip())); + assert!(cache.is_known(dead2.ip())); + } } diff --git a/src/system_dns.rs b/src/system_dns.rs index 7f6304b..b70b9d9 100644 --- a/src/system_dns.rs +++ b/src/system_dns.rs @@ -22,7 +22,7 @@ fn is_loopback_or_stub(addr: &str) -> bool { } /// A conditional forwarding rule: domains matching `suffix` are forwarded to `upstream`. -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct ForwardingRule { pub suffix: String, dot_suffix: String, // pre-computed ".suffix" for zero-alloc matching