feat: multi-forwarder with SRTT-based failover #77

Merged
razvandimescu merged 4 commits from feat/multi-forwarder-failover into main 2026-04-11 05:26:59 +08:00
6 changed files with 357 additions and 81 deletions

View File

@@ -411,9 +411,12 @@ async fn diagnose(
}
// Check upstream (async, no locks held)
let upstream = ctx.upstream.lock().unwrap().clone();
let (upstream_matched, upstream_detail) =
forward_query_for_diagnose(&domain_lower, &upstream, ctx.timeout).await;
let upstream = ctx.upstream_pool.lock().unwrap().preferred().cloned();
let (upstream_matched, upstream_detail) = if let Some(ref u) = upstream {
forward_query_for_diagnose(&domain_lower, u, ctx.timeout).await
} else {
(false, "no upstream configured".to_string())
};
steps.push(DiagnoseStep {
source: "upstream".to_string(),
matched: upstream_matched,
@@ -520,7 +523,7 @@ async fn stats(State(ctx): State<Arc<ServerCtx>>) -> Json<StatsResponse> {
let upstream = if ctx.upstream_mode == crate::config::UpstreamMode::Recursive {
"recursive (root hints)".to_string()
} else {
ctx.upstream.lock().unwrap().to_string()
ctx.upstream_pool.lock().unwrap().label()
};
Json(StatsResponse {
@@ -1016,8 +1019,11 @@ mod tests {
services: Mutex::new(crate::service_store::ServiceStore::new()),
lan_peers: Mutex::new(crate::lan::PeerStore::new(90)),
forwarding_rules: Vec::new(),
upstream: Mutex::new(crate::forward::Upstream::Udp(
"127.0.0.1:53".parse().unwrap(),
upstream_pool: Mutex::new(crate::forward::UpstreamPool::new(
vec![crate::forward::Upstream::Udp(
"127.0.0.1:53".parse().unwrap(),
)],
vec![],
)),
upstream_auto: false,
upstream_port: 53,

View File

@@ -97,10 +97,12 @@ impl UpstreamMode {
pub struct UpstreamConfig {
#[serde(default)]
pub mode: UpstreamMode,
#[serde(default = "default_upstream_addr")]
pub address: String,
#[serde(default, deserialize_with = "string_or_vec")]
pub address: Vec<String>,
#[serde(default = "default_upstream_port")]
pub port: u16,
#[serde(default)]
pub fallback: Vec<String>,
#[serde(default = "default_timeout_ms")]
pub timeout_ms: u64,
#[serde(default = "default_root_hints")]
@@ -115,8 +117,9 @@ impl Default for UpstreamConfig {
fn default() -> Self {
UpstreamConfig {
mode: UpstreamMode::default(),
address: default_upstream_addr(),
address: Vec::new(),
port: default_upstream_port(),
fallback: Vec::new(),
timeout_ms: default_timeout_ms(),
root_hints: default_root_hints(),
prime_tlds: default_prime_tlds(),
@@ -125,6 +128,33 @@ impl Default for UpstreamConfig {
}
}
/// Deserializes a field that may be written either as a single string or as
/// an array of strings, always producing a `Vec<String>`.
///
/// A single empty string yields an empty vector rather than `vec![""]`.
/// Callers decide between "auto-detect" and "explicitly configured" by
/// checking whether the resulting list is empty, so a config that spells out
/// `address = ""` must not produce a one-element list holding an address
/// that can never be parsed.
///
/// # Errors
/// Fails with the deserializer's error when the value is neither a string
/// nor a sequence of strings.
fn string_or_vec<'de, D>(deserializer: D) -> std::result::Result<Vec<String>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    struct Visitor;

    impl<'de> serde::de::Visitor<'de> for Visitor {
        type Value = Vec<String>;

        fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
            f.write_str("string or array of strings")
        }

        fn visit_str<E: serde::de::Error>(self, v: &str) -> std::result::Result<Self::Value, E> {
            // An empty scalar means "nothing configured", not "one empty entry".
            if v.is_empty() {
                return Ok(Vec::new());
            }
            Ok(vec![v.to_string()])
        }

        fn visit_seq<A: serde::de::SeqAccess<'de>>(
            self,
            mut seq: A,
        ) -> std::result::Result<Self::Value, A::Error> {
            let mut v = Vec::new();
            while let Some(s) = seq.next_element::<String>()? {
                v.push(s);
            }
            Ok(v)
        }
    }

    // The input shape is not known up front, so let the format drive the visitor.
    deserializer.deserialize_any(Visitor)
}
/// Always returns `true`.
// NOTE(review): presumably a serde `default = "..."` helper for boolean
// fields that should default to enabled; no use site is visible here, confirm.
fn default_true() -> bool {
true
}
@@ -202,9 +232,6 @@ fn default_root_hints() -> Vec<String> {
]
}
/// Serde default for the string form of `UpstreamConfig::address`: an empty
/// string.
fn default_upstream_addr() -> String {
String::new() // empty = auto-detect from system resolver
}
/// Serde default for `UpstreamConfig::port`: 53.
fn default_upstream_port() -> u16 {
53
}
@@ -525,6 +552,33 @@ mod tests {
assert!(config.services[0].routes[0].strip);
assert!(!config.services[0].routes[1].strip); // default false
}
/// A scalar `address` value deserializes into a one-element list.
#[test]
fn address_string_parses_to_vec() {
    let source = "[upstream]\naddress = \"1.2.3.4\"";
    let parsed: Config = toml::from_str(source).unwrap();
    assert_eq!(parsed.upstream.address, ["1.2.3.4"]);
}
/// An array `address` value keeps every entry, in order, including entries
/// that carry their own port.
#[test]
fn address_array_parses() {
    let source = "[upstream]\naddress = [\"1.2.3.4\", \"5.6.7.8:5353\"]";
    let parsed: Config = toml::from_str(source).unwrap();
    assert_eq!(parsed.upstream.address, ["1.2.3.4", "5.6.7.8:5353"]);
}
/// The `fallback` list is read as a plain array of strings.
#[test]
fn fallback_parses() {
    let source = "[upstream]\nfallback = [\"8.8.8.8\", \"1.1.1.1\"]";
    let parsed: Config = toml::from_str(source).unwrap();
    assert_eq!(parsed.upstream.fallback, ["8.8.8.8", "1.1.1.1"]);
}
/// With no `[upstream]` table at all, both address lists default to empty.
#[test]
fn empty_address_gives_empty_vec() {
    let parsed: Config = toml::from_str("").unwrap();
    let upstream = &parsed.upstream;
    assert!(upstream.address.is_empty());
    assert!(upstream.fallback.is_empty());
}
}
pub struct ConfigLoad {

View File

@@ -16,7 +16,7 @@ use crate::blocklist::BlocklistStore;
use crate::buffer::BytePacketBuffer;
use crate::cache::{DnsCache, DnssecStatus};
use crate::config::{UpstreamMode, ZoneMap};
use crate::forward::{forward_query, Upstream};
use crate::forward::{forward_query, forward_with_failover, Upstream, UpstreamPool};
use crate::header::ResultCode;
use crate::health::HealthMeta;
use crate::lan::PeerStore;
@@ -42,7 +42,7 @@ pub struct ServerCtx {
pub services: Mutex<ServiceStore>,
pub lan_peers: Mutex<PeerStore>,
pub forwarding_rules: Vec<ForwardingRule>,
pub upstream: Mutex<Upstream>,
pub upstream_pool: Mutex<UpstreamPool>,
pub upstream_auto: bool,
pub upstream_port: u16,
pub lan_ip: Mutex<std::net::Ipv4Addr>,
@@ -220,12 +220,8 @@ pub async fn resolve_query(
}
(resp, path, DnssecStatus::Indeterminate)
} else {
let upstream =
match crate::system_dns::match_forwarding_rule(&qname, &ctx.forwarding_rules) {
Some(addr) => Upstream::Udp(addr),
None => ctx.upstream.lock().unwrap().clone(),
};
match forward_query(&query, &upstream, ctx.timeout).await {
let pool = ctx.upstream_pool.lock().unwrap().clone();
match forward_with_failover(&query, &pool, &ctx.srtt, ctx.timeout).await {
Ok(resp) => {
ctx.cache.write().unwrap().insert(&qname, qtype, &resp);
(resp, QueryPath::Forwarded, DnssecStatus::Indeterminate)

View File

@@ -362,7 +362,10 @@ mod tests {
services: Mutex::new(crate::service_store::ServiceStore::new()),
lan_peers: Mutex::new(crate::lan::PeerStore::new(90)),
forwarding_rules: Vec::new(),
upstream: Mutex::new(crate::forward::Upstream::Udp(upstream_addr)),
upstream_pool: Mutex::new(crate::forward::UpstreamPool::new(
vec![crate::forward::Upstream::Udp(upstream_addr)],
vec![],
)),
upstream_auto: false,
upstream_port: 53,
lan_ip: Mutex::new(std::net::Ipv4Addr::LOCALHOST),

View File

@@ -1,12 +1,14 @@
use std::fmt;
use std::net::SocketAddr;
use std::time::Duration;
use std::net::{IpAddr, SocketAddr};
use std::sync::RwLock;
use std::time::{Duration, Instant};
use tokio::net::UdpSocket;
use tokio::time::timeout;
use crate::buffer::BytePacketBuffer;
use crate::packet::DnsPacket;
use crate::srtt::SrttCache;
use crate::Result;
#[derive(Clone)]
@@ -37,6 +39,133 @@ impl fmt::Display for Upstream {
}
}
/// Parses an upstream address written either as a full socket address
/// (`"1.2.3.4:5353"`, `"[::1]:5353"`) or as a bare IP (`"1.2.3.4"`, `"::1"`).
///
/// The socket-address form is tried first; a bare IP is paired with
/// `default_port`.
///
/// # Errors
/// Returns a message naming the offending input when it is neither form.
pub fn parse_upstream_addr(s: &str, default_port: u16) -> std::result::Result<SocketAddr, String> {
    s.parse::<SocketAddr>()
        .or_else(|_| s.parse::<IpAddr>().map(|ip| SocketAddr::new(ip, default_port)))
        .map_err(|_| format!("invalid upstream address: {}", s))
}
/// Builds an [`Upstream`] from a configuration string.
///
/// Strings starting with `https://` become DoH upstreams backed by a
/// rustls-enabled HTTP client (the default client is used if the builder
/// fails). Every other string is handed to [`parse_upstream_addr`] and
/// becomes a UDP upstream, with `default_port` applied to bare IPs.
///
/// # Errors
/// Propagates the address-parsing error for non-`https://` inputs.
pub fn parse_upstream(s: &str, default_port: u16) -> Result<Upstream> {
    // Plain addresses are the common case; deal with them up front.
    if !s.starts_with("https://") {
        return Ok(Upstream::Udp(parse_upstream_addr(s, default_port)?));
    }

    let doh_client = reqwest::Client::builder()
        .use_rustls_tls()
        .build()
        .unwrap_or_default();
    Ok(Upstream::Doh {
        url: s.to_owned(),
        client: doh_client,
    })
}
#[derive(Clone)]
pub struct UpstreamPool {
primary: Vec<Upstream>,
fallback: Vec<Upstream>,
}
impl UpstreamPool {
pub fn new(primary: Vec<Upstream>, fallback: Vec<Upstream>) -> Self {
Self { primary, fallback }
}
pub fn preferred(&self) -> Option<&Upstream> {
self.primary.first().or(self.fallback.first())
}
pub fn set_primary(&mut self, primary: Vec<Upstream>) {
self.primary = primary;
}
/// Update the primary upstream if `new_addr` (parsed with `port`) differs
/// from the current preferred upstream. Returns `true` if the pool changed.
pub fn maybe_update_primary(&mut self, new_addr: &str, port: u16) -> bool {
let Ok(new_sock) = format!("{}:{}", new_addr, port).parse::<SocketAddr>() else {
return false;
};
let new_upstream = Upstream::Udp(new_sock);
if self.preferred() == Some(&new_upstream) {
return false;
}
self.primary = vec![new_upstream];
true
}
pub fn label(&self) -> String {
match self.preferred() {
Some(u) => {
let total = self.primary.len() + self.fallback.len();
if total > 1 {
format!("{} (+{} more)", u, total - 1)
} else {
u.to_string()
}
}
None => "none".to_string(),
}
}
}
/// Forwards `query` to the pool's upstreams one at a time and returns the
/// first successful response.
///
/// Primary upstreams are tried in ascending order of the value `srtt` reports
/// for their IP; non-UDP primaries (DoH) use a key of 0. Fallback upstreams
/// follow in their configured order. Each UDP attempt is fed back into
/// `srtt`: the measured round trip on success, a failure otherwise.
///
/// # Errors
/// Returns the error of the last upstream tried, or "no upstream configured"
/// when the pool holds no upstreams at all.
pub async fn forward_with_failover(
    query: &DnsPacket,
    pool: &UpstreamPool,
    srtt: &RwLock<SrttCache>,
    timeout_duration: Duration,
) -> Result<DnsPacket> {
    // Pair every primary upstream with its sort key.
    let mut ranked: Vec<(&Upstream, u64)> = Vec::with_capacity(pool.primary.len());
    for upstream in &pool.primary {
        let key = match upstream {
            Upstream::Udp(addr) => srtt.read().unwrap().get(addr.ip()),
            _ => 0, // DoH: keep config order (stable sort preserves it)
        };
        ranked.push((upstream, key));
    }
    // Stable sort: entries with equal keys stay in configured order.
    ranked.sort_by_key(|entry| entry.1);

    // Ranked primaries first, then the fallbacks exactly as configured.
    let attempt_order: Vec<&Upstream> = ranked
        .into_iter()
        .map(|(upstream, _)| upstream)
        .chain(pool.fallback.iter())
        .collect();

    let mut last_err: Option<Box<dyn std::error::Error + Send + Sync>> = None;
    for upstream in attempt_order {
        let started = Instant::now();
        let outcome = forward_query(query, upstream, timeout_duration).await;

        // Only UDP upstreams take part in SRTT bookkeeping.
        let udp_ip = match upstream {
            Upstream::Udp(addr) => Some(addr.ip()),
            _ => None,
        };

        match outcome {
            Ok(resp) => {
                if let Some(ip) = udp_ip {
                    let rtt_ms = started.elapsed().as_millis() as u64;
                    srtt.write().unwrap().record_rtt(ip, rtt_ms, false);
                }
                return Ok(resp);
            }
            Err(e) => {
                if let Some(ip) = udp_ip {
                    srtt.write().unwrap().record_failure(ip);
                }
                log::debug!("upstream {} failed: {}", upstream, e);
                last_err = Some(e);
            }
        }
    }

    match last_err {
        Some(e) => Err(e),
        None => Err("no upstream configured".into()),
    }
}
pub async fn forward_query(
query: &DnsPacket,
upstream: &Upstream,
@@ -271,4 +400,112 @@ mod tests {
let result = forward_query(&make_query(), &upstream, Duration::from_millis(100)).await;
assert!(result.is_err());
}
/// A bare IPv4 address picks up the default port.
#[test]
fn parse_addr_ip_only() {
    let expected = SocketAddr::from(([1, 2, 3, 4], 53));
    assert_eq!(parse_upstream_addr("1.2.3.4", 53), Ok(expected));
}
/// An explicit port in the input wins over the default port.
#[test]
fn parse_addr_ip_port() {
    let expected = SocketAddr::from(([1, 2, 3, 4], 5353));
    assert_eq!(parse_upstream_addr("1.2.3.4:5353", 53), Ok(expected));
}
/// A bracketed IPv6 socket address keeps its own port.
#[test]
fn parse_addr_ipv6_bracketed() {
    let loopback_v6 = [0u16, 0, 0, 0, 0, 0, 0, 1];
    let expected = SocketAddr::from((loopback_v6, 5553));
    assert_eq!(parse_upstream_addr("[::1]:5553", 53), Ok(expected));
}
/// A bare IPv6 address picks up the default port.
#[test]
fn parse_addr_ipv6_bare() {
    let loopback_v6 = [0u16, 0, 0, 0, 0, 0, 0, 1];
    let expected = SocketAddr::from((loopback_v6, 53));
    assert_eq!(parse_upstream_addr("::1", 53), Ok(expected));
}
/// A pool with a single upstream is labelled with just that upstream.
#[test]
fn pool_label_single() {
    let only = Upstream::Udp("1.2.3.4:53".parse().unwrap());
    let pool = UpstreamPool::new(vec![only], Vec::new());
    assert_eq!(pool.label(), "1.2.3.4:53");
}
/// With more than one upstream, the label counts the remaining entries.
#[test]
fn pool_label_multi() {
    let primary = vec![Upstream::Udp("1.2.3.4:53".parse().unwrap())];
    let fallback = vec![Upstream::Udp("8.8.8.8:53".parse().unwrap())];
    let pool = UpstreamPool::new(primary, fallback);
    assert_eq!(pool.label(), "1.2.3.4:53 (+1 more)");
}
/// When the first upstream fails, the query is answered by the next one.
#[tokio::test]
async fn failover_tries_next_on_failure() {
    let query = make_query();
    let wire_response = to_wire(&make_response(&query));

    // Local HTTP server that answers every POST to /dns-query with the
    // canned response.
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let doh_addr = listener.local_addr().unwrap();
    let router = axum::Router::new().route(
        "/dns-query",
        axum::routing::post(move || {
            let payload = wire_response.clone();
            async move {
                (
                    [(axum::http::header::CONTENT_TYPE, "application/dns-message")],
                    payload,
                )
            }
        }),
    );
    tokio::spawn(axum::serve(listener, router).into_future());

    // First primary: UDP upstream expected to fail. Second: the working DoH
    // server started above.
    let failing_udp = Upstream::Udp("127.0.0.1:1".parse().unwrap());
    let working_doh = Upstream::Doh {
        url: format!("http://{}/dns-query", doh_addr),
        client: reqwest::Client::new(),
    };
    let pool = UpstreamPool::new(vec![failing_udp, working_doh], vec![]);
    let srtt_cache = RwLock::new(SrttCache::new(true));

    let resp = forward_with_failover(&query, &pool, &srtt_cache, Duration::from_millis(500))
        .await
        .expect("should fail over to second upstream");

    assert_eq!(resp.header.id, 0xABCD);
    assert_eq!(resp.answers.len(), 1);
}
/// A new address replaces the primary and becomes the preferred upstream.
#[test]
fn maybe_update_primary_swaps_when_different() {
    let primary = vec![Upstream::Udp("1.2.3.4:53".parse().unwrap())];
    let fallback = vec![Upstream::Udp("8.8.8.8:53".parse().unwrap())];
    let mut pool = UpstreamPool::new(primary, fallback);

    let changed = pool.maybe_update_primary("5.6.7.8", 53);

    assert!(changed);
    assert_eq!(pool.preferred().unwrap().to_string(), "5.6.7.8:53");
}
/// Re-announcing the current preferred address reports no change.
#[test]
fn maybe_update_primary_noop_when_same() {
    let current = Upstream::Udp("1.2.3.4:53".parse().unwrap());
    let mut pool = UpstreamPool::new(vec![current], Vec::new());

    let changed = pool.maybe_update_primary("1.2.3.4", 53);

    assert!(!changed);
}
/// An unparseable address is rejected and leaves the pool as it was.
#[test]
fn maybe_update_primary_rejects_invalid_addr() {
    let current = Upstream::Udp("1.2.3.4:53".parse().unwrap());
    let mut pool = UpstreamPool::new(vec![current], Vec::new());

    let changed = pool.maybe_update_primary("not-an-ip", 53);

    assert!(!changed);
    assert_eq!(pool.preferred().unwrap().to_string(), "1.2.3.4:53");
}
}

View File

@@ -11,7 +11,7 @@ use numa::buffer::BytePacketBuffer;
use numa::cache::DnsCache;
use numa::config::{build_zone_map, load_config, ConfigLoad};
use numa::ctx::{handle_query, ServerCtx};
use numa::forward::Upstream;
use numa::forward::{parse_upstream, Upstream, UpstreamPool};
use numa::override_store::OverrideStore;
use numa::query_log::QueryLog;
use numa::service_store::ServiceStore;
@@ -129,18 +129,18 @@ async fn main() -> numa::Result<()> {
let root_hints = numa::recursive::parse_root_hints(&config.upstream.root_hints);
let (resolved_mode, upstream_auto, upstream, upstream_label) = match config.upstream.mode {
let recursive_pool = || {
let dummy = UpstreamPool::new(vec![Upstream::Udp("0.0.0.0:0".parse().unwrap())], vec![]);
(dummy, "recursive (root hints)".to_string())
};
let (resolved_mode, upstream_auto, pool, upstream_label) = match config.upstream.mode {
numa::config::UpstreamMode::Auto => {
info!("auto mode: probing recursive resolution...");
if numa::recursive::probe_recursive(&root_hints).await {
info!("recursive probe succeeded — self-sovereign mode");
let dummy = Upstream::Udp("0.0.0.0:0".parse().unwrap());
(
numa::config::UpstreamMode::Recursive,
false,
dummy,
"recursive (root hints)".to_string(),
)
let (pool, label) = recursive_pool();
(numa::config::UpstreamMode::Recursive, false, pool, label)
} else {
log::warn!("recursive probe failed — falling back to Quad9 DoH");
let client = reqwest::Client::builder()
@@ -149,55 +149,45 @@ async fn main() -> numa::Result<()> {
.unwrap_or_default();
let url = DOH_FALLBACK.to_string();
let label = url.clone();
(
numa::config::UpstreamMode::Forward,
false,
Upstream::Doh { url, client },
label,
)
let pool = UpstreamPool::new(vec![Upstream::Doh { url, client }], vec![]);
(numa::config::UpstreamMode::Forward, false, pool, label)
}
}
numa::config::UpstreamMode::Recursive => {
let dummy = Upstream::Udp("0.0.0.0:0".parse().unwrap());
(
numa::config::UpstreamMode::Recursive,
false,
dummy,
"recursive (root hints)".to_string(),
)
let (pool, label) = recursive_pool();
(numa::config::UpstreamMode::Recursive, false, pool, label)
}
numa::config::UpstreamMode::Forward => {
let upstream_addr = if config.upstream.address.is_empty() {
system_dns
let addrs = if config.upstream.address.is_empty() {
let detected = system_dns
.default_upstream
.or_else(numa::system_dns::detect_dhcp_dns)
.unwrap_or_else(|| {
info!("could not detect system DNS, falling back to Quad9 DoH");
DOH_FALLBACK.to_string()
})
});
vec![detected]
} else {
config.upstream.address.clone()
};
let upstream: Upstream = if upstream_addr.starts_with("https://") {
let client = reqwest::Client::builder()
.use_rustls_tls()
.build()
.unwrap_or_default();
Upstream::Doh {
url: upstream_addr,
client,
}
} else {
let addr: SocketAddr =
format!("{}:{}", upstream_addr, config.upstream.port).parse()?;
Upstream::Udp(addr)
};
let label = upstream.to_string();
let primary: Vec<Upstream> = addrs
.iter()
.map(|s| parse_upstream(s, config.upstream.port))
.collect::<numa::Result<Vec<_>>>()?;
let fallback: Vec<Upstream> = config
.upstream
.fallback
.iter()
.map(|s| parse_upstream(s, config.upstream.port))
.collect::<numa::Result<Vec<_>>>()?;
let pool = UpstreamPool::new(primary, fallback);
let label = pool.label();
(
numa::config::UpstreamMode::Forward,
config.upstream.address.is_empty(),
upstream,
pool,
label,
)
}
@@ -294,7 +284,7 @@ async fn main() -> numa::Result<()> {
services: Mutex::new(service_store),
lan_peers: Mutex::new(numa::lan::PeerStore::new(config.lan.peer_timeout_secs)),
forwarding_rules,
upstream: Mutex::new(upstream),
upstream_pool: Mutex::new(pool),
upstream_auto,
upstream_port: config.upstream.port,
lan_ip: Mutex::new(numa::lan::detect_lan_ip().unwrap_or(std::net::Ipv4Addr::LOCALHOST)),
@@ -613,27 +603,17 @@ async fn network_watch_loop(ctx: Arc<numa::ctx::ServerCtx>) {
}
}
// Re-detect upstream every 30s or on LAN IP change (UDP only
// DoH upstreams are explicitly configured via URL, not auto-detected)
if ctx.upstream_auto
&& matches!(*ctx.upstream.lock().unwrap(), Upstream::Udp(_))
&& (changed || tick.is_multiple_of(6))
{
// Re-detect upstream every 30s or on LAN IP change (auto-detect only)
if ctx.upstream_auto && (changed || tick.is_multiple_of(6)) {
let dns_info = numa::system_dns::discover_system_dns();
let new_addr = dns_info
.default_upstream
.or_else(numa::system_dns::detect_dhcp_dns)
.unwrap_or_else(|| QUAD9_IP.to_string());
if let Ok(new_sock) =
format!("{}:{}", new_addr, ctx.upstream_port).parse::<SocketAddr>()
{
let new_upstream = Upstream::Udp(new_sock);
let mut upstream = ctx.upstream.lock().unwrap();
if *upstream != new_upstream {
info!("upstream changed: {} → {}", upstream, new_upstream);
*upstream = new_upstream;
changed = true;
}
let mut pool = ctx.upstream_pool.lock().unwrap();
if pool.maybe_update_primary(&new_addr, ctx.upstream_port) {
info!("upstream changed → {}", pool.label());
changed = true;
}
}