fix(forward): track SRTT for DoT upstreams, not just UDP

The SRTT ordering + failure penalty path was UDP-only, so a DoT primary
in a forwarding-rule pool was never deprioritized on failure and all
DoT entries tied at INITIAL_SRTT_MS in the sort key. With [[forwarding]]
now accepting arrays of upstreams, DoT pools are a first-class case and
need the same healthiest-first behavior the default pool gets for UDP.

- Add Upstream::tracked_ip() → Some(ip) for Udp/Dot, None for Doh
  (DoH has no stable IP — reqwest pools connections by hostname).
- Rewire the three SRTT call sites in forward_with_failover_raw.
- Hoist srtt.read() out of the candidate-scoring loop — one lock per
  query instead of N (matters now that pools commonly have N>1).
- Drop unused #[derive(Debug)] on UpstreamPool and ForwardingRule.
- Regression tests: udp_failure_records_in_srtt + dot_failure_records_in_srtt.
This commit is contained in:
Razvan Dimescu
2026-04-17 03:39:21 +03:00
parent ab6cda0c91
commit 5f77af55e9
2 changed files with 87 additions and 18 deletions

View File

@@ -25,6 +25,18 @@ pub enum Upstream {
}, },
} }
impl Upstream {
    /// IP address to key SRTT tracking on, if the upstream has a stable one.
    ///
    /// `Udp` and `Dot` both target a fixed socket address, so that address's
    /// IP is a usable SRTT key. `Doh` routes through a URL + connection pool,
    /// so there's no single IP to track; SRTT is skipped for it.
    pub fn tracked_ip(&self) -> Option<IpAddr> {
        if let Upstream::Udp(addr) | Upstream::Dot { addr, .. } = self {
            Some(addr.ip())
        } else {
            None
        }
    }
}
impl PartialEq for Upstream { impl PartialEq for Upstream {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
match (self, other) { match (self, other) {
@@ -118,7 +130,7 @@ fn build_dot_connector() -> Result<tokio_rustls::TlsConnector> {
))) )))
} }
#[derive(Clone, Debug)] #[derive(Clone)]
pub struct UpstreamPool { pub struct UpstreamPool {
primary: Vec<Upstream>, primary: Vec<Upstream>,
fallback: Vec<Upstream>, fallback: Vec<Upstream>,
@@ -345,18 +357,17 @@ pub async fn forward_with_failover_raw(
timeout_duration: Duration, timeout_duration: Duration,
hedge_delay: Duration, hedge_delay: Duration,
) -> Result<Vec<u8>> { ) -> Result<Vec<u8>> {
let mut candidates: Vec<(usize, u64)> = pool let mut candidates: Vec<(usize, u64)> = {
.primary let srtt_read = srtt.read().unwrap();
.iter() pool.primary
.enumerate() .iter()
.map(|(i, u)| { .enumerate()
let rtt = match u { .map(|(i, u)| {
Upstream::Udp(addr) => srtt.read().unwrap().get(addr.ip()), let rtt = u.tracked_ip().map(|ip| srtt_read.get(ip)).unwrap_or(0);
_ => 0, (i, rtt)
}; })
(i, rtt) .collect()
}) };
.collect();
candidates.sort_by_key(|&(_, rtt)| rtt); candidates.sort_by_key(|&(_, rtt)| rtt);
let all_upstreams: Vec<&Upstream> = candidates let all_upstreams: Vec<&Upstream> = candidates
@@ -380,15 +391,15 @@ pub async fn forward_with_failover_raw(
}; };
match result { match result {
Ok(resp) => { Ok(resp) => {
if let Upstream::Udp(addr) = upstream { if let Some(ip) = upstream.tracked_ip() {
let rtt_ms = start.elapsed().as_millis() as u64; let rtt_ms = start.elapsed().as_millis() as u64;
srtt.write().unwrap().record_rtt(addr.ip(), rtt_ms, false); srtt.write().unwrap().record_rtt(ip, rtt_ms, false);
} }
return Ok(resp); return Ok(resp);
} }
Err(e) => { Err(e) => {
if let Upstream::Udp(addr) = upstream { if let Some(ip) = upstream.tracked_ip() {
srtt.write().unwrap().record_failure(addr.ip()); srtt.write().unwrap().record_failure(ip);
} }
log::debug!("upstream {} failed: {}", upstream, e); log::debug!("upstream {} failed: {}", upstream, e);
last_err = Some(e); last_err = Some(e);
@@ -707,4 +718,62 @@ mod tests {
assert!(!pool.maybe_update_primary("not-an-ip", 53)); assert!(!pool.maybe_update_primary("not-an-ip", 53));
assert_eq!(pool.preferred().unwrap().to_string(), "1.2.3.4:53"); assert_eq!(pool.preferred().unwrap().to_string(), "1.2.3.4:53");
} }
fn tcp_closed_port() -> SocketAddr {
    // Briefly bind an ephemeral TCP port and capture its address; the
    // temporary listener is dropped at the end of the expression, so later
    // connects to this address get an immediate RST from the kernel —
    // a fast, deterministic connection failure for the tests below.
    std::net::TcpListener::bind("127.0.0.1:0")
        .unwrap()
        .local_addr()
        .unwrap()
}
#[tokio::test]
async fn udp_failure_records_in_srtt() {
    // A UDP upstream that never answers must time out, and the failure must
    // land in the SRTT cache so later queries deprioritize that IP.
    let dead = crate::testutil::blackhole_upstream();
    let pool = UpstreamPool::new(vec![Upstream::Udp(dead)], vec![]);
    let srtt = RwLock::new(SrttCache::new(true));
    let query = [0u8; 12];
    let _ = forward_with_failover_raw(
        &query,
        &pool,
        &srtt,
        Duration::from_millis(100),
        Duration::ZERO,
    )
    .await;
    assert!(srtt.read().unwrap().is_known(dead.ip()));
}
#[tokio::test]
async fn dot_failure_records_in_srtt() {
    // Both DoT upstreams point at closed loopback ports, so every connect
    // attempt is refused; each failure must be recorded against the IP of
    // the upstream that produced it.
    let (dead_a, dead_b) = (tcp_closed_port(), tcp_closed_port());
    let connector = build_dot_connector().unwrap();
    let mk = |addr| Upstream::Dot {
        addr,
        tls_name: Some("dns.quad9.net".to_string()),
        connector: connector.clone(),
    };
    let pool = UpstreamPool::new(vec![mk(dead_a), mk(dead_b)], vec![]);
    let srtt = RwLock::new(SrttCache::new(true));
    let _ = forward_with_failover_raw(
        &[0u8; 12],
        &pool,
        &srtt,
        Duration::from_millis(500),
        Duration::ZERO,
    )
    .await;
    let cache = srtt.read().unwrap();
    assert!(cache.is_known(dead_a.ip()));
    assert!(cache.is_known(dead_b.ip()));
}
} }

View File

@@ -22,7 +22,7 @@ fn is_loopback_or_stub(addr: &str) -> bool {
} }
/// A conditional forwarding rule: domains matching `suffix` are forwarded to `upstream`. /// A conditional forwarding rule: domains matching `suffix` are forwarded to `upstream`.
#[derive(Debug, Clone)] #[derive(Clone)]
pub struct ForwardingRule { pub struct ForwardingRule {
pub suffix: String, pub suffix: String,
dot_suffix: String, // pre-computed ".suffix" for zero-alloc matching dot_suffix: String, // pre-computed ".suffix" for zero-alloc matching