feat: multi-forwarder with SRTT-based failover

address accepts string or array, with optional per-server port override.
New fallback pool tried only when all primaries fail. Sequential failover
with SRTT ranking ensures fastest upstream is tried first.

Closes #34 (items 1, 2, 3)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Razvan Dimescu
2026-04-10 23:16:56 +03:00
parent 23ff3ce455
commit db80d9f8f9
6 changed files with 320 additions and 63 deletions

View File

@@ -1,12 +1,14 @@
use std::fmt;
use std::net::SocketAddr;
use std::time::Duration;
use std::net::{IpAddr, SocketAddr};
use std::sync::RwLock;
use std::time::{Duration, Instant};
use tokio::net::UdpSocket;
use tokio::time::timeout;
use crate::buffer::BytePacketBuffer;
use crate::packet::DnsPacket;
use crate::srtt::SrttCache;
use crate::Result;
#[derive(Clone)]
@@ -37,6 +39,125 @@ impl fmt::Display for Upstream {
}
}
pub fn parse_upstream_addr(s: &str, default_port: u16) -> std::result::Result<SocketAddr, String> {
// Try full socket addr first: "1.2.3.4:5353" or "[::1]:5353"
if let Ok(addr) = s.parse::<SocketAddr>() {
return Ok(addr);
}
// Bare IP: "1.2.3.4" or "::1"
if let Ok(ip) = s.parse::<IpAddr>() {
return Ok(SocketAddr::new(ip, default_port));
}
Err(format!("invalid upstream address: {}", s))
}
pub fn parse_upstream(s: &str, default_port: u16) -> Result<Upstream> {
if s.starts_with("https://") {
let client = reqwest::Client::builder()
.use_rustls_tls()
.build()
.unwrap_or_default();
return Ok(Upstream::Doh {
url: s.to_string(),
client,
});
}
let addr = parse_upstream_addr(s, default_port)?;
Ok(Upstream::Udp(addr))
}
#[derive(Clone)]
pub struct UpstreamPool {
primary: Vec<Upstream>,
fallback: Vec<Upstream>,
}
impl UpstreamPool {
pub fn new(primary: Vec<Upstream>, fallback: Vec<Upstream>) -> Self {
Self { primary, fallback }
}
pub fn preferred(&self) -> Option<&Upstream> {
self.primary.first().or(self.fallback.first())
}
pub fn set_primary(&mut self, primary: Vec<Upstream>) {
self.primary = primary;
}
pub fn label(&self) -> String {
match self.preferred() {
Some(u) => {
let total = self.primary.len() + self.fallback.len();
if total > 1 {
format!("{} (+{} more)", u, total - 1)
} else {
u.to_string()
}
}
None => "none".to_string(),
}
}
}
pub async fn forward_with_failover(
query: &DnsPacket,
pool: &UpstreamPool,
srtt: &RwLock<SrttCache>,
timeout_duration: Duration,
) -> Result<DnsPacket> {
// Build candidate list: primary (sorted by SRTT) then fallback (config order)
let mut candidates: Vec<&Upstream> =
Vec::with_capacity(pool.primary.len() + pool.fallback.len());
// Sort primary UDP upstreams by SRTT; DoH keeps config order
let mut primary_udp_indices: Vec<(usize, u64)> = Vec::new();
for (i, u) in pool.primary.iter().enumerate() {
if let Upstream::Udp(addr) = u {
let rtt = srtt.read().unwrap().get(addr.ip());
primary_udp_indices.push((i, rtt));
}
}
primary_udp_indices.sort_by_key(|&(_, rtt)| rtt);
// Interleave: sorted UDP first, then DoH in config order
for &(i, _) in &primary_udp_indices {
candidates.push(&pool.primary[i]);
}
for u in &pool.primary {
if matches!(u, Upstream::Doh { .. }) {
candidates.push(u);
}
}
for u in &pool.fallback {
candidates.push(u);
}
let mut last_err: Option<Box<dyn std::error::Error + Send + Sync>> = None;
for upstream in &candidates {
let start = Instant::now();
match forward_query(query, upstream, timeout_duration).await {
Ok(resp) => {
if let Upstream::Udp(addr) = upstream {
let rtt_ms = start.elapsed().as_millis() as u64;
srtt.write().unwrap().record_rtt(addr.ip(), rtt_ms, false);
}
return Ok(resp);
}
Err(e) => {
if let Upstream::Udp(addr) = upstream {
srtt.write().unwrap().record_failure(addr.ip());
}
log::debug!("upstream {} failed: {}", upstream, e);
last_err = Some(e);
}
}
}
Err(last_err.unwrap_or_else(|| "no upstream configured".into()))
}
pub async fn forward_query(
query: &DnsPacket,
upstream: &Upstream,
@@ -271,4 +392,87 @@ mod tests {
let result = forward_query(&make_query(), &upstream, Duration::from_millis(100)).await;
assert!(result.is_err());
}
#[test]
fn parse_addr_ip_only() {
let addr = parse_upstream_addr("1.2.3.4", 53).unwrap();
assert_eq!(addr, "1.2.3.4:53".parse::<SocketAddr>().unwrap());
}
#[test]
fn parse_addr_ip_port() {
let addr = parse_upstream_addr("1.2.3.4:5353", 53).unwrap();
assert_eq!(addr, "1.2.3.4:5353".parse::<SocketAddr>().unwrap());
}
#[test]
fn parse_addr_ipv6_bracketed() {
let addr = parse_upstream_addr("[::1]:5553", 53).unwrap();
assert_eq!(addr, "[::1]:5553".parse::<SocketAddr>().unwrap());
}
#[test]
fn parse_addr_ipv6_bare() {
let addr = parse_upstream_addr("::1", 53).unwrap();
assert_eq!(addr, "[::1]:53".parse::<SocketAddr>().unwrap());
}
#[test]
fn pool_label_single() {
let pool = UpstreamPool::new(vec![Upstream::Udp("1.2.3.4:53".parse().unwrap())], vec![]);
assert_eq!(pool.label(), "1.2.3.4:53");
}
#[test]
fn pool_label_multi() {
let pool = UpstreamPool::new(
vec![Upstream::Udp("1.2.3.4:53".parse().unwrap())],
vec![Upstream::Udp("8.8.8.8:53".parse().unwrap())],
);
assert_eq!(pool.label(), "1.2.3.4:53 (+1 more)");
}
#[tokio::test]
async fn failover_tries_next_on_failure() {
// First upstream is unreachable, second responds
let query = make_query();
let response_bytes = to_wire(&make_response(&query));
let app = axum::Router::new().route(
"/dns-query",
axum::routing::post(move || {
let body = response_bytes.clone();
async move {
(
[(axum::http::header::CONTENT_TYPE, "application/dns-message")],
body,
)
}
}),
);
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let good_addr = listener.local_addr().unwrap();
tokio::spawn(axum::serve(listener, app).into_future());
// Unreachable UDP upstream + working DoH upstream
let pool = UpstreamPool::new(
vec![
Upstream::Udp("127.0.0.1:1".parse().unwrap()), // will fail
Upstream::Doh {
url: format!("http://{}/dns-query", good_addr),
client: reqwest::Client::new(),
},
],
vec![],
);
let srtt = RwLock::new(SrttCache::new(true));
let result = forward_with_failover(&query, &pool, &srtt, Duration::from_millis(500))
.await
.expect("should fail over to second upstream");
assert_eq!(result.header.id, 0xABCD);
assert_eq!(result.answers.len(), 1);
}
}