From 8abcd91f95d805eb36782b2837851579d8ec2c4f Mon Sep 17 00:00:00 2001 From: Razvan Dimescu Date: Sat, 11 Apr 2026 00:26:58 +0300 Subject: [PATCH] feat: multi-forwarder with SRTT-based failover (#77) * feat: multi-forwarder with SRTT-based failover address accepts string or array, with optional per-server port override. New fallback pool tried only when all primaries fail. Sequential failover with SRTT ranking ensures fastest upstream is tried first. Closes #34 (items 1, 2, 3) Co-Authored-By: Claude Opus 4.6 (1M context) * refactor: simplify failover candidate list and deduplicate recursive pool Co-Authored-By: Claude Opus 4.6 * refactor: extract maybe_update_primary for testable upstream re-detection Co-Authored-By: Claude Opus 4.6 * style: rustfmt Co-Authored-By: Claude Opus 4.6 --------- Co-authored-by: Claude Opus 4.6 (1M context) --- src/api.rs | 18 ++-- src/config.rs | 66 ++++++++++++-- src/ctx.rs | 12 +-- src/dot.rs | 5 +- src/forward.rs | 241 ++++++++++++++++++++++++++++++++++++++++++++++++- src/main.rs | 96 ++++++++------------ 6 files changed, 357 insertions(+), 81 deletions(-) diff --git a/src/api.rs b/src/api.rs index 2e66931..a0bae58 100644 --- a/src/api.rs +++ b/src/api.rs @@ -411,9 +411,12 @@ async fn diagnose( } // Check upstream (async, no locks held) - let upstream = ctx.upstream.lock().unwrap().clone(); - let (upstream_matched, upstream_detail) = - forward_query_for_diagnose(&domain_lower, &upstream, ctx.timeout).await; + let upstream = ctx.upstream_pool.lock().unwrap().preferred().cloned(); + let (upstream_matched, upstream_detail) = if let Some(ref u) = upstream { + forward_query_for_diagnose(&domain_lower, u, ctx.timeout).await + } else { + (false, "no upstream configured".to_string()) + }; steps.push(DiagnoseStep { source: "upstream".to_string(), matched: upstream_matched, @@ -520,7 +523,7 @@ async fn stats(State(ctx): State>) -> Json { let upstream = if ctx.upstream_mode == crate::config::UpstreamMode::Recursive { "recursive (root hints)".to_string() } else { - ctx.upstream.lock().unwrap().to_string() + ctx.upstream_pool.lock().unwrap().label() }; Json(StatsResponse { @@ -1016,8 +1019,11 @@ mod tests { services: Mutex::new(crate::service_store::ServiceStore::new()), lan_peers: Mutex::new(crate::lan::PeerStore::new(90)), forwarding_rules: Vec::new(), - upstream: Mutex::new(crate::forward::Upstream::Udp( - "127.0.0.1:53".parse().unwrap(), + upstream_pool: Mutex::new(crate::forward::UpstreamPool::new( + vec![crate::forward::Upstream::Udp( + "127.0.0.1:53".parse().unwrap(), + )], + vec![], )), upstream_auto: false, upstream_port: 53, diff --git a/src/config.rs b/src/config.rs index 9373d33..fa794d7 100644 --- a/src/config.rs +++ b/src/config.rs @@ -97,10 +97,12 @@ impl UpstreamMode { pub struct UpstreamConfig { #[serde(default)] pub mode: UpstreamMode, - #[serde(default = "default_upstream_addr")] - pub address: String, + #[serde(default, deserialize_with = "string_or_vec")] + pub address: Vec, #[serde(default = "default_upstream_port")] pub port: u16, + #[serde(default)] + pub fallback: Vec, #[serde(default = "default_timeout_ms")] pub timeout_ms: u64, #[serde(default = "default_root_hints")] @@ -115,8 +117,9 @@ impl Default for UpstreamConfig { fn default() -> Self { UpstreamConfig { mode: UpstreamMode::default(), - address: default_upstream_addr(), + address: Vec::new(), port: default_upstream_port(), + fallback: Vec::new(), timeout_ms: default_timeout_ms(), root_hints: default_root_hints(), prime_tlds: default_prime_tlds(), @@ -125,6 +128,33 @@ impl Default for UpstreamConfig { } } +fn string_or_vec<'de, D>(deserializer: D) -> std::result::Result, D::Error> +where + D: serde::Deserializer<'de>, +{ + struct Visitor; + impl<'de> serde::de::Visitor<'de> for Visitor { + type Value = Vec; + fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.write_str("string or array of strings") + } + fn visit_str(self, v: &str) -> std::result::Result { + Ok(vec![v.to_string()]) + } + fn visit_seq>( + self, + mut seq: A, + ) -> std::result::Result { + let mut v = Vec::new(); + while let Some(s) = seq.next_element::()? { + v.push(s); + } + Ok(v) + } + } + deserializer.deserialize_any(Visitor) +} + fn default_true() -> bool { true } @@ -202,9 +232,6 @@ fn default_root_hints() -> Vec { ] } -fn default_upstream_addr() -> String { - String::new() // empty = auto-detect from system resolver -} fn default_upstream_port() -> u16 { 53 } @@ -525,6 +552,33 @@ mod tests { assert!(config.services[0].routes[0].strip); assert!(!config.services[0].routes[1].strip); // default false } + + #[test] + fn address_string_parses_to_vec() { + let config: Config = toml::from_str("[upstream]\naddress = \"1.2.3.4\"").unwrap(); + assert_eq!(config.upstream.address, vec!["1.2.3.4"]); + } + + #[test] + fn address_array_parses() { + let config: Config = + toml::from_str("[upstream]\naddress = [\"1.2.3.4\", \"5.6.7.8:5353\"]").unwrap(); + assert_eq!(config.upstream.address, vec!["1.2.3.4", "5.6.7.8:5353"]); + } + + #[test] + fn fallback_parses() { + let config: Config = + toml::from_str("[upstream]\nfallback = [\"8.8.8.8\", \"1.1.1.1\"]").unwrap(); + assert_eq!(config.upstream.fallback, vec!["8.8.8.8", "1.1.1.1"]); + } + + #[test] + fn empty_address_gives_empty_vec() { + let config: Config = toml::from_str("").unwrap(); + assert!(config.upstream.address.is_empty()); + assert!(config.upstream.fallback.is_empty()); + } } pub struct ConfigLoad { diff --git a/src/ctx.rs b/src/ctx.rs index 6b774eb..b4e0777 100644 --- a/src/ctx.rs +++ b/src/ctx.rs @@ -16,7 +16,7 @@ use crate::blocklist::BlocklistStore; use crate::buffer::BytePacketBuffer; use crate::cache::{DnsCache, DnssecStatus}; use crate::config::{UpstreamMode, ZoneMap}; -use crate::forward::{forward_query, Upstream}; +use crate::forward::{forward_query, forward_with_failover, Upstream, UpstreamPool}; use crate::header::ResultCode; use crate::health::HealthMeta; use crate::lan::PeerStore; @@ -42,7 +42,7 @@ pub struct ServerCtx { pub services: Mutex, pub lan_peers: Mutex, pub forwarding_rules: Vec, - pub upstream: Mutex, + pub upstream_pool: Mutex, pub upstream_auto: bool, pub upstream_port: u16, pub lan_ip: Mutex, @@ -220,12 +220,8 @@ pub async fn resolve_query( } (resp, path, DnssecStatus::Indeterminate) } else { - let upstream = - match crate::system_dns::match_forwarding_rule(&qname, &ctx.forwarding_rules) { - Some(addr) => Upstream::Udp(addr), - None => ctx.upstream.lock().unwrap().clone(), - }; - match forward_query(&query, &upstream, ctx.timeout).await { + let pool = ctx.upstream_pool.lock().unwrap().clone(); + match forward_with_failover(&query, &pool, &ctx.srtt, ctx.timeout).await { Ok(resp) => { ctx.cache.write().unwrap().insert(&qname, qtype, &resp); (resp, QueryPath::Forwarded, DnssecStatus::Indeterminate) diff --git a/src/dot.rs b/src/dot.rs index 3ed47ba..0d48fa2 100644 --- a/src/dot.rs +++ b/src/dot.rs @@ -362,7 +362,10 @@ mod tests { services: Mutex::new(crate::service_store::ServiceStore::new()), lan_peers: Mutex::new(crate::lan::PeerStore::new(90)), forwarding_rules: Vec::new(), - upstream: Mutex::new(crate::forward::Upstream::Udp(upstream_addr)), + upstream_pool: Mutex::new(crate::forward::UpstreamPool::new( + vec![crate::forward::Upstream::Udp(upstream_addr)], + vec![], + )), upstream_auto: false, upstream_port: 53, lan_ip: Mutex::new(std::net::Ipv4Addr::LOCALHOST), diff --git a/src/forward.rs b/src/forward.rs index ea2b03e..78efcb9 100644 --- a/src/forward.rs +++ b/src/forward.rs @@ -1,12 +1,14 @@ use std::fmt; -use std::net::SocketAddr; -use std::time::Duration; +use std::net::{IpAddr, SocketAddr}; +use std::sync::RwLock; +use std::time::{Duration, Instant}; use tokio::net::UdpSocket; use tokio::time::timeout; use crate::buffer::BytePacketBuffer; use crate::packet::DnsPacket; +use crate::srtt::SrttCache; use crate::Result; #[derive(Clone)] @@ -37,6 +39,133 @@ impl fmt::Display for Upstream { } } +pub fn parse_upstream_addr(s: &str, default_port: u16) -> std::result::Result { + // Try full socket addr first: "1.2.3.4:5353" or "[::1]:5353" + if let Ok(addr) = s.parse::() { + return Ok(addr); + } + // Bare IP: "1.2.3.4" or "::1" + if let Ok(ip) = s.parse::() { + return Ok(SocketAddr::new(ip, default_port)); + } + Err(format!("invalid upstream address: {}", s)) +} + +pub fn parse_upstream(s: &str, default_port: u16) -> Result { + if s.starts_with("https://") { + let client = reqwest::Client::builder() + .use_rustls_tls() + .build() + .unwrap_or_default(); + return Ok(Upstream::Doh { + url: s.to_string(), + client, + }); + } + let addr = parse_upstream_addr(s, default_port)?; + Ok(Upstream::Udp(addr)) +} + +#[derive(Clone)] +pub struct UpstreamPool { + primary: Vec, + fallback: Vec, +} + +impl UpstreamPool { + pub fn new(primary: Vec, fallback: Vec) -> Self { + Self { primary, fallback } + } + + pub fn preferred(&self) -> Option<&Upstream> { + self.primary.first().or(self.fallback.first()) + } + + pub fn set_primary(&mut self, primary: Vec) { + self.primary = primary; + } + + /// Update the primary upstream if `new_addr` (parsed with `port`) differs + /// from the current preferred upstream. Returns `true` if the pool changed. + pub fn maybe_update_primary(&mut self, new_addr: &str, port: u16) -> bool { + let Ok(new_sock) = format!("{}:{}", new_addr, port).parse::() else { + return false; + }; + let new_upstream = Upstream::Udp(new_sock); + if self.preferred() == Some(&new_upstream) { + return false; + } + self.primary = vec![new_upstream]; + true + } + + pub fn label(&self) -> String { + match self.preferred() { + Some(u) => { + let total = self.primary.len() + self.fallback.len(); + if total > 1 { + format!("{} (+{} more)", u, total - 1) + } else { + u.to_string() + } + } + None => "none".to_string(), + } + } +} + +pub async fn forward_with_failover( + query: &DnsPacket, + pool: &UpstreamPool, + srtt: &RwLock, + timeout_duration: Duration, +) -> Result { + // Build candidate list: primary (sorted by SRTT for UDP) then fallback + let mut candidates: Vec<(usize, u64)> = pool + .primary + .iter() + .enumerate() + .map(|(i, u)| { + let rtt = match u { + Upstream::Udp(addr) => srtt.read().unwrap().get(addr.ip()), + _ => 0, // DoH: keep config order (stable sort preserves it) + }; + (i, rtt) + }) + .collect(); + candidates.sort_by_key(|&(_, rtt)| rtt); + + let all_upstreams: Vec<&Upstream> = candidates + .iter() + .map(|&(i, _)| &pool.primary[i]) + .chain(pool.fallback.iter()) + .collect(); + + let mut last_err: Option> = None; + + for upstream in &all_upstreams { + let start = Instant::now(); + match forward_query(query, upstream, timeout_duration).await { + Ok(resp) => { + if let Upstream::Udp(addr) = upstream { + let rtt_ms = start.elapsed().as_millis() as u64; + srtt.write().unwrap().record_rtt(addr.ip(), rtt_ms, false); + } + return Ok(resp); + } + Err(e) => { + if let Upstream::Udp(addr) = upstream { + srtt.write().unwrap().record_failure(addr.ip()); + } + log::debug!("upstream {} failed: {}", upstream, e); + last_err = Some(e); + } + } + } + + Err(last_err.unwrap_or_else(|| "no upstream configured".into())) +} + pub async fn forward_query( query: &DnsPacket, upstream: &Upstream, @@ -271,4 +400,112 @@ mod tests { let result = forward_query(&make_query(), &upstream, Duration::from_millis(100)).await; assert!(result.is_err()); } + + #[test] + fn parse_addr_ip_only() { + let addr = parse_upstream_addr("1.2.3.4", 53).unwrap(); + assert_eq!(addr, "1.2.3.4:53".parse::().unwrap()); + } + + #[test] + fn parse_addr_ip_port() { + let addr = parse_upstream_addr("1.2.3.4:5353", 53).unwrap(); + assert_eq!(addr, "1.2.3.4:5353".parse::().unwrap()); + } + + #[test] + fn parse_addr_ipv6_bracketed() { + let addr = parse_upstream_addr("[::1]:5553", 53).unwrap(); + assert_eq!(addr, "[::1]:5553".parse::().unwrap()); + } + + #[test] + fn parse_addr_ipv6_bare() { + let addr = parse_upstream_addr("::1", 53).unwrap(); + assert_eq!(addr, "[::1]:53".parse::().unwrap()); + } + + #[test] + fn pool_label_single() { + let pool = UpstreamPool::new(vec![Upstream::Udp("1.2.3.4:53".parse().unwrap())], vec![]); + assert_eq!(pool.label(), "1.2.3.4:53"); + } + + #[test] + fn pool_label_multi() { + let pool = UpstreamPool::new( + vec![Upstream::Udp("1.2.3.4:53".parse().unwrap())], + vec![Upstream::Udp("8.8.8.8:53".parse().unwrap())], + ); + assert_eq!(pool.label(), "1.2.3.4:53 (+1 more)"); + } + + #[tokio::test] + async fn failover_tries_next_on_failure() { + // First upstream is unreachable, second responds + let query = make_query(); + let response_bytes = to_wire(&make_response(&query)); + + let app = axum::Router::new().route( + "/dns-query", + axum::routing::post(move || { + let body = response_bytes.clone(); + async move { + ( + [(axum::http::header::CONTENT_TYPE, "application/dns-message")], + body, + ) + } + }), + ); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let good_addr = listener.local_addr().unwrap(); + tokio::spawn(axum::serve(listener, app).into_future()); + + // Unreachable UDP upstream + working DoH upstream + let pool = UpstreamPool::new( + vec![ + Upstream::Udp("127.0.0.1:1".parse().unwrap()), // will fail + Upstream::Doh { + url: format!("http://{}/dns-query", good_addr), + client: reqwest::Client::new(), + }, + ], + vec![], + ); + + let srtt = RwLock::new(SrttCache::new(true)); + let result = forward_with_failover(&query, &pool, &srtt, Duration::from_millis(500)) + .await + .expect("should fail over to second upstream"); + + assert_eq!(result.header.id, 0xABCD); + assert_eq!(result.answers.len(), 1); + } + + #[test] + fn maybe_update_primary_swaps_when_different() { + let mut pool = UpstreamPool::new( + vec![Upstream::Udp("1.2.3.4:53".parse().unwrap())], + vec![Upstream::Udp("8.8.8.8:53".parse().unwrap())], + ); + assert!(pool.maybe_update_primary("5.6.7.8", 53)); + assert_eq!(pool.preferred().unwrap().to_string(), "5.6.7.8:53"); + } + + #[test] + fn maybe_update_primary_noop_when_same() { + let mut pool = + UpstreamPool::new(vec![Upstream::Udp("1.2.3.4:53".parse().unwrap())], vec![]); + assert!(!pool.maybe_update_primary("1.2.3.4", 53)); + } + + #[test] + fn maybe_update_primary_rejects_invalid_addr() { + let mut pool = + UpstreamPool::new(vec![Upstream::Udp("1.2.3.4:53".parse().unwrap())], vec![]); + assert!(!pool.maybe_update_primary("not-an-ip", 53)); + assert_eq!(pool.preferred().unwrap().to_string(), "1.2.3.4:53"); + } } diff --git a/src/main.rs b/src/main.rs index 62acb69..9e2d2f8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -11,7 +11,7 @@ use numa::buffer::BytePacketBuffer; use numa::cache::DnsCache; use numa::config::{build_zone_map, load_config, ConfigLoad}; use numa::ctx::{handle_query, ServerCtx}; -use numa::forward::Upstream; +use numa::forward::{parse_upstream, Upstream, UpstreamPool}; use numa::override_store::OverrideStore; use numa::query_log::QueryLog; use numa::service_store::ServiceStore; @@ -129,18 +129,18 @@ async fn main() -> numa::Result<()> { let root_hints = numa::recursive::parse_root_hints(&config.upstream.root_hints); - let (resolved_mode, upstream_auto, upstream, upstream_label) = match config.upstream.mode { + let recursive_pool = || { + let dummy = UpstreamPool::new(vec![Upstream::Udp("0.0.0.0:0".parse().unwrap())], vec![]); + (dummy, "recursive (root hints)".to_string()) + }; + + let (resolved_mode, upstream_auto, pool, upstream_label) = match config.upstream.mode { numa::config::UpstreamMode::Auto => { info!("auto mode: probing recursive resolution..."); if numa::recursive::probe_recursive(&root_hints).await { info!("recursive probe succeeded — self-sovereign mode"); - let dummy = Upstream::Udp("0.0.0.0:0".parse().unwrap()); - ( - numa::config::UpstreamMode::Recursive, - false, - dummy, - "recursive (root hints)".to_string(), - ) + let (pool, label) = recursive_pool(); + (numa::config::UpstreamMode::Recursive, false, pool, label) } else { log::warn!("recursive probe failed — falling back to Quad9 DoH"); let client = reqwest::Client::builder() @@ -149,55 +149,45 @@ async fn main() -> numa::Result<()> { .unwrap_or_default(); let url = DOH_FALLBACK.to_string(); let label = url.clone(); - ( - numa::config::UpstreamMode::Forward, - false, - Upstream::Doh { url, client }, - label, - ) + let pool = UpstreamPool::new(vec![Upstream::Doh { url, client }], vec![]); + (numa::config::UpstreamMode::Forward, false, pool, label) } } numa::config::UpstreamMode::Recursive => { - let dummy = Upstream::Udp("0.0.0.0:0".parse().unwrap()); - ( - numa::config::UpstreamMode::Recursive, - false, - dummy, - "recursive (root hints)".to_string(), - ) + let (pool, label) = recursive_pool(); + (numa::config::UpstreamMode::Recursive, false, pool, label) } numa::config::UpstreamMode::Forward => { - let upstream_addr = if config.upstream.address.is_empty() { - system_dns + let addrs = if config.upstream.address.is_empty() { + let detected = system_dns .default_upstream .or_else(numa::system_dns::detect_dhcp_dns) .unwrap_or_else(|| { info!("could not detect system DNS, falling back to Quad9 DoH"); DOH_FALLBACK.to_string() - }) + }); + vec![detected] } else { config.upstream.address.clone() }; - let upstream: Upstream = if upstream_addr.starts_with("https://") { - let client = reqwest::Client::builder() - .use_rustls_tls() - .build() - .unwrap_or_default(); - Upstream::Doh { - url: upstream_addr, - client, - } - } else { - let addr: SocketAddr = - format!("{}:{}", upstream_addr, config.upstream.port).parse()?; - Upstream::Udp(addr) - }; - let label = upstream.to_string(); + let primary: Vec = addrs + .iter() + .map(|s| parse_upstream(s, config.upstream.port)) + .collect::>>()?; + let fallback: Vec = config + .upstream + .fallback + .iter() + .map(|s| parse_upstream(s, config.upstream.port)) + .collect::>>()?; + + let pool = UpstreamPool::new(primary, fallback); + let label = pool.label(); ( numa::config::UpstreamMode::Forward, config.upstream.address.is_empty(), - upstream, + pool, label, ) } @@ -294,7 +284,7 @@ async fn main() -> numa::Result<()> { services: Mutex::new(service_store), lan_peers: Mutex::new(numa::lan::PeerStore::new(config.lan.peer_timeout_secs)), forwarding_rules, - upstream: Mutex::new(upstream), + upstream_pool: Mutex::new(pool), upstream_auto, upstream_port: config.upstream.port, lan_ip: Mutex::new(numa::lan::detect_lan_ip().unwrap_or(std::net::Ipv4Addr::LOCALHOST)), @@ -613,27 +603,17 @@ async fn network_watch_loop(ctx: Arc) { } } - // Re-detect upstream every 30s or on LAN IP change (UDP only — - // DoH upstreams are explicitly configured via URL, not auto-detected) - if ctx.upstream_auto - && matches!(*ctx.upstream.lock().unwrap(), Upstream::Udp(_)) - && (changed || tick.is_multiple_of(6)) - { + // Re-detect upstream every 30s or on LAN IP change (auto-detect only) + if ctx.upstream_auto && (changed || tick.is_multiple_of(6)) { let dns_info = numa::system_dns::discover_system_dns(); let new_addr = dns_info .default_upstream .or_else(numa::system_dns::detect_dhcp_dns) .unwrap_or_else(|| QUAD9_IP.to_string()); - if let Ok(new_sock) = - format!("{}:{}", new_addr, ctx.upstream_port).parse::() - { - let new_upstream = Upstream::Udp(new_sock); - let mut upstream = ctx.upstream.lock().unwrap(); - if *upstream != new_upstream { - info!("upstream changed: {} → {}", upstream, new_upstream); - *upstream = new_upstream; - changed = true; - } + let mut pool = ctx.upstream_pool.lock().unwrap(); + if pool.maybe_update_primary(&new_addr, ctx.upstream_port) { + info!("upstream changed → {}", pool.label()); + changed = true; } }