fix(forward): track SRTT for DoT upstreams, not just UDP
The SRTT ordering + failure penalty path was UDP-only, so a DoT primary in a forwarding-rule pool was never deprioritized on failure and all DoT entries tied at INITIAL_SRTT_MS in the sort key.

With [[forwarding]] now accepting arrays of upstreams, DoT pools are a first-class case and need the same healthiest-first behavior the default pool gets for UDP.

- Add Upstream::tracked_ip() → Some(ip) for Udp/Dot, None for Doh (DoH has no stable IP — reqwest pools connections by hostname).
- Rewire the three SRTT call sites in forward_with_failover_raw.
- Hoist srtt.read() out of the candidate-scoring loop — one lock per query instead of N (matters now that pools commonly have N>1).
- Drop unused #[derive(Debug)] on UpstreamPool and ForwardingRule.
- Regression tests: udp_failure_records_in_srtt + dot_failure_records_in_srtt.
This commit is contained in:
103
src/forward.rs
103
src/forward.rs
@@ -25,6 +25,18 @@ pub enum Upstream {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Upstream {
|
||||||
|
/// IP address to key SRTT tracking on, if the upstream has a stable one.
|
||||||
|
/// `Doh` routes through a URL + connection pool, so there's no single IP
|
||||||
|
/// to track; SRTT is skipped for it.
|
||||||
|
pub fn tracked_ip(&self) -> Option<IpAddr> {
|
||||||
|
match self {
|
||||||
|
Upstream::Udp(addr) | Upstream::Dot { addr, .. } => Some(addr.ip()),
|
||||||
|
Upstream::Doh { .. } => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl PartialEq for Upstream {
|
impl PartialEq for Upstream {
|
||||||
fn eq(&self, other: &Self) -> bool {
|
fn eq(&self, other: &Self) -> bool {
|
||||||
match (self, other) {
|
match (self, other) {
|
||||||
@@ -118,7 +130,7 @@ fn build_dot_connector() -> Result<tokio_rustls::TlsConnector> {
|
|||||||
)))
|
)))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone)]
|
||||||
pub struct UpstreamPool {
|
pub struct UpstreamPool {
|
||||||
primary: Vec<Upstream>,
|
primary: Vec<Upstream>,
|
||||||
fallback: Vec<Upstream>,
|
fallback: Vec<Upstream>,
|
||||||
@@ -345,18 +357,17 @@ pub async fn forward_with_failover_raw(
|
|||||||
timeout_duration: Duration,
|
timeout_duration: Duration,
|
||||||
hedge_delay: Duration,
|
hedge_delay: Duration,
|
||||||
) -> Result<Vec<u8>> {
|
) -> Result<Vec<u8>> {
|
||||||
let mut candidates: Vec<(usize, u64)> = pool
|
let mut candidates: Vec<(usize, u64)> = {
|
||||||
.primary
|
let srtt_read = srtt.read().unwrap();
|
||||||
.iter()
|
pool.primary
|
||||||
.enumerate()
|
.iter()
|
||||||
.map(|(i, u)| {
|
.enumerate()
|
||||||
let rtt = match u {
|
.map(|(i, u)| {
|
||||||
Upstream::Udp(addr) => srtt.read().unwrap().get(addr.ip()),
|
let rtt = u.tracked_ip().map(|ip| srtt_read.get(ip)).unwrap_or(0);
|
||||||
_ => 0,
|
(i, rtt)
|
||||||
};
|
})
|
||||||
(i, rtt)
|
.collect()
|
||||||
})
|
};
|
||||||
.collect();
|
|
||||||
candidates.sort_by_key(|&(_, rtt)| rtt);
|
candidates.sort_by_key(|&(_, rtt)| rtt);
|
||||||
|
|
||||||
let all_upstreams: Vec<&Upstream> = candidates
|
let all_upstreams: Vec<&Upstream> = candidates
|
||||||
@@ -380,15 +391,15 @@ pub async fn forward_with_failover_raw(
|
|||||||
};
|
};
|
||||||
match result {
|
match result {
|
||||||
Ok(resp) => {
|
Ok(resp) => {
|
||||||
if let Upstream::Udp(addr) = upstream {
|
if let Some(ip) = upstream.tracked_ip() {
|
||||||
let rtt_ms = start.elapsed().as_millis() as u64;
|
let rtt_ms = start.elapsed().as_millis() as u64;
|
||||||
srtt.write().unwrap().record_rtt(addr.ip(), rtt_ms, false);
|
srtt.write().unwrap().record_rtt(ip, rtt_ms, false);
|
||||||
}
|
}
|
||||||
return Ok(resp);
|
return Ok(resp);
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
if let Upstream::Udp(addr) = upstream {
|
if let Some(ip) = upstream.tracked_ip() {
|
||||||
srtt.write().unwrap().record_failure(addr.ip());
|
srtt.write().unwrap().record_failure(ip);
|
||||||
}
|
}
|
||||||
log::debug!("upstream {} failed: {}", upstream, e);
|
log::debug!("upstream {} failed: {}", upstream, e);
|
||||||
last_err = Some(e);
|
last_err = Some(e);
|
||||||
@@ -707,4 +718,62 @@ mod tests {
|
|||||||
assert!(!pool.maybe_update_primary("not-an-ip", 53));
|
assert!(!pool.maybe_update_primary("not-an-ip", 53));
|
||||||
assert_eq!(pool.preferred().unwrap().to_string(), "1.2.3.4:53");
|
assert_eq!(pool.preferred().unwrap().to_string(), "1.2.3.4:53");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn tcp_closed_port() -> SocketAddr {
|
||||||
|
// Bind a TCP listener, grab the port, drop → kernel returns RST on connect.
|
||||||
|
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
drop(listener);
|
||||||
|
addr
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn udp_failure_records_in_srtt() {
|
||||||
|
let blackhole = crate::testutil::blackhole_upstream();
|
||||||
|
let pool = UpstreamPool::new(vec![Upstream::Udp(blackhole)], vec![]);
|
||||||
|
let srtt = RwLock::new(SrttCache::new(true));
|
||||||
|
let _ = forward_with_failover_raw(
|
||||||
|
&[0u8; 12],
|
||||||
|
&pool,
|
||||||
|
&srtt,
|
||||||
|
Duration::from_millis(100),
|
||||||
|
Duration::ZERO,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
assert!(srtt.read().unwrap().is_known(blackhole.ip()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn dot_failure_records_in_srtt() {
|
||||||
|
let dead1 = tcp_closed_port();
|
||||||
|
let dead2 = tcp_closed_port();
|
||||||
|
let connector = build_dot_connector().unwrap();
|
||||||
|
let pool = UpstreamPool::new(
|
||||||
|
vec![
|
||||||
|
Upstream::Dot {
|
||||||
|
addr: dead1,
|
||||||
|
tls_name: Some("dns.quad9.net".to_string()),
|
||||||
|
connector: connector.clone(),
|
||||||
|
},
|
||||||
|
Upstream::Dot {
|
||||||
|
addr: dead2,
|
||||||
|
tls_name: Some("dns.quad9.net".to_string()),
|
||||||
|
connector,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
vec![],
|
||||||
|
);
|
||||||
|
let srtt = RwLock::new(SrttCache::new(true));
|
||||||
|
let _ = forward_with_failover_raw(
|
||||||
|
&[0u8; 12],
|
||||||
|
&pool,
|
||||||
|
&srtt,
|
||||||
|
Duration::from_millis(500),
|
||||||
|
Duration::ZERO,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
let cache = srtt.read().unwrap();
|
||||||
|
assert!(cache.is_known(dead1.ip()));
|
||||||
|
assert!(cache.is_known(dead2.ip()));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ fn is_loopback_or_stub(addr: &str) -> bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// A conditional forwarding rule: domains matching `suffix` are forwarded to `upstream`.
|
/// A conditional forwarding rule: domains matching `suffix` are forwarded to `upstream`.
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Clone)]
|
||||||
pub struct ForwardingRule {
|
pub struct ForwardingRule {
|
||||||
pub suffix: String,
|
pub suffix: String,
|
||||||
dot_suffix: String, // pre-computed ".suffix" for zero-alloc matching
|
dot_suffix: String, // pre-computed ".suffix" for zero-alloc matching
|
||||||
|
|||||||
Reference in New Issue
Block a user