From b40004fe5e41dc7800d25cbec5e49347a6e68674 Mon Sep 17 00:00:00 2001 From: Razvan Dimescu Date: Mon, 13 Apr 2026 07:56:47 +0300 Subject: [PATCH] refactor: extract shared test infrastructure into testutil module - test_ctx(): single ServerCtx builder, replaces 3 copies (ctx/api/dot) - mock_upstream(): canned DNS response server for forwarding tests - blackhole_upstream(): unresponsive socket for timeout tests - Removes ~100 lines of duplicated 30-field struct literals --- src/api.rs | 45 +---------------------- src/ctx.rs | 76 +++------------------------------------ src/dot.rs | 82 +++++++++++++----------------------------- src/lib.rs | 3 ++ src/testutil.rs | 95 +++++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 128 insertions(+), 173 deletions(-) create mode 100644 src/testutil.rs diff --git a/src/api.rs b/src/api.rs index fcc0bd9..6ec3e48 100644 --- a/src/api.rs +++ b/src/api.rs @@ -1020,53 +1020,10 @@ mod tests { use super::*; use axum::body::Body; use http::Request; - use std::sync::{Mutex, RwLock}; use tower::ServiceExt; async fn test_ctx() -> Arc { - let socket = tokio::net::UdpSocket::bind("127.0.0.1:0").await.unwrap(); - Arc::new(ServerCtx { - socket, - zone_map: std::collections::HashMap::new(), - cache: RwLock::new(crate::cache::DnsCache::new(100, 60, 86400)), - refreshing: Mutex::new(std::collections::HashSet::new()), - stats: Mutex::new(crate::stats::ServerStats::new()), - overrides: RwLock::new(crate::override_store::OverrideStore::new()), - blocklist: RwLock::new(crate::blocklist::BlocklistStore::new()), - query_log: Mutex::new(crate::query_log::QueryLog::new(100)), - services: Mutex::new(crate::service_store::ServiceStore::new()), - lan_peers: Mutex::new(crate::lan::PeerStore::new(90)), - forwarding_rules: Vec::new(), - upstream_pool: Mutex::new(crate::forward::UpstreamPool::new( - vec![crate::forward::Upstream::Udp( - "127.0.0.1:53".parse().unwrap(), - )], - vec![], - )), - upstream_auto: false, - upstream_port: 53, - lan_ip: Mutex::new(std::net::Ipv4Addr::LOCALHOST), - timeout: std::time::Duration::from_secs(3), - hedge_delay: std::time::Duration::ZERO, - proxy_tld: "numa".to_string(), - proxy_tld_suffix: ".numa".to_string(), - lan_enabled: false, - config_path: "/tmp/test-numa.toml".to_string(), - config_found: false, - config_dir: std::path::PathBuf::from("/tmp"), - data_dir: std::path::PathBuf::from("/tmp"), - tls_config: None, - upstream_mode: crate::config::UpstreamMode::Forward, - root_hints: Vec::new(), - srtt: RwLock::new(crate::srtt::SrttCache::new(true)), - inflight: Mutex::new(std::collections::HashMap::new()), - dnssec_enabled: false, - dnssec_strict: false, - health_meta: crate::health::HealthMeta::test_fixture(), - ca_pem: None, - mobile_enabled: false, - mobile_port: 8765, - }) + Arc::new(crate::testutil::test_ctx().await) } #[tokio::test] diff --git a/src/ctx.rs b/src/ctx.rs index 3f1370a..475dfe7 100644 --- a/src/ctx.rs +++ b/src/ctx.rs @@ -659,7 +659,6 @@ mod tests { use super::*; use std::collections::HashMap; use std::net::Ipv4Addr; - use std::path::PathBuf; use std::sync::{Arc, Mutex}; use tokio::sync::broadcast; @@ -1044,10 +1043,6 @@ mod tests { // ---- Full-pipeline resolve_query tests ---- - async fn test_ctx() -> Arc { - test_ctx_with_forwarding(Vec::new()).await - } - /// Send a query through the full resolve_query pipeline and return /// the parsed response + query path. async fn resolve_in_test( @@ -1072,87 +1067,26 @@ mod tests { #[tokio::test] async fn special_use_private_ptr_returns_nxdomain() { - let ctx = test_ctx().await; + let ctx = Arc::new(crate::testutil::test_ctx().await); let (resp, path) = resolve_in_test(&ctx, "153.188.168.192.in-addr.arpa", QueryType::PTR).await; assert_eq!(path, QueryPath::Local); assert_eq!(resp.header.rescode, ResultCode::NXDOMAIN); } - async fn test_ctx_with_forwarding(rules: Vec) -> Arc { - let socket = UdpSocket::bind("127.0.0.1:0").await.unwrap(); - Arc::new(ServerCtx { - socket, - zone_map: HashMap::new(), - cache: RwLock::new(DnsCache::new(100, 60, 86400)), - refreshing: Mutex::new(HashSet::new()), - stats: Mutex::new(ServerStats::new()), - overrides: RwLock::new(OverrideStore::new()), - blocklist: RwLock::new(BlocklistStore::new()), - query_log: Mutex::new(QueryLog::new(100)), - services: Mutex::new(ServiceStore::new()), - lan_peers: Mutex::new(PeerStore::new(90)), - forwarding_rules: rules, - upstream_pool: Mutex::new(UpstreamPool::new( - vec![Upstream::Udp("127.0.0.1:53".parse().unwrap())], - vec![], - )), - upstream_auto: false, - upstream_port: 53, - lan_ip: Mutex::new(Ipv4Addr::LOCALHOST), - timeout: Duration::from_millis(100), - hedge_delay: Duration::ZERO, - proxy_tld: "numa".to_string(), - proxy_tld_suffix: ".numa".to_string(), - lan_enabled: false, - config_path: "/tmp/test-numa.toml".to_string(), - config_found: false, - config_dir: PathBuf::from("/tmp"), - data_dir: PathBuf::from("/tmp"), - tls_config: None, - upstream_mode: UpstreamMode::Forward, - root_hints: Vec::new(), - srtt: RwLock::new(SrttCache::new(true)), - inflight: Mutex::new(HashMap::new()), - dnssec_enabled: false, - dnssec_strict: false, - health_meta: HealthMeta::test_fixture(), - ca_pem: None, - mobile_enabled: false, - mobile_port: 8765, - }) - } - - /// Spawn a UDP socket that replies to the first DNS query with the given - /// response packet (patching the query ID). Returns the socket address. - async fn mock_upstream(response: DnsPacket) -> SocketAddr { - let sock = UdpSocket::bind("127.0.0.1:0").await.unwrap(); - let addr = sock.local_addr().unwrap(); - tokio::spawn(async move { - let mut buf = [0u8; 512]; - let (_, src) = sock.recv_from(&mut buf).await.unwrap(); - let query_id = u16::from_be_bytes([buf[0], buf[1]]); - let mut resp = response; - resp.header.id = query_id; - let mut out = BytePacketBuffer::new(); - resp.write(&mut out).unwrap(); - sock.send_to(out.filled(), src).await.unwrap(); - }); - addr - } - #[tokio::test] async fn forwarding_rule_overrides_special_use_domain() { let mut resp = DnsPacket::new(); resp.header.response = true; resp.header.rescode = ResultCode::NOERROR; - let upstream_addr = mock_upstream(resp).await; + let upstream_addr = crate::testutil::mock_upstream(resp).await; - let rules = vec![ForwardingRule::new( + let mut ctx = crate::testutil::test_ctx().await; + ctx.forwarding_rules = vec![ForwardingRule::new( "168.192.in-addr.arpa".to_string(), upstream_addr, )]; - let ctx = test_ctx_with_forwarding(rules).await; + let ctx = Arc::new(ctx); let (resp, path) = resolve_in_test(&ctx, "153.188.168.192.in-addr.arpa", QueryType::PTR).await; diff --git a/src/dot.rs b/src/dot.rs index db8257d..b39d7fe 100644 --- a/src/dot.rs +++ b/src/dot.rs @@ -279,7 +279,7 @@ where mod tests { use super::*; use std::collections::HashMap; - use std::sync::{Mutex, RwLock}; + use std::sync::Mutex; use rcgen::{CertificateParams, DnType, KeyPair}; use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer, ServerName}; @@ -344,63 +344,29 @@ mod tests { async fn spawn_dot_server() -> (SocketAddr, CertificateDer<'static>) { let (server_tls, cert_der) = test_tls_configs(); - let socket = tokio::net::UdpSocket::bind("127.0.0.1:0").await.unwrap(); - // Bind an unresponsive upstream and leak it so it lives for the test duration. - let blackhole = Box::leak(Box::new(std::net::UdpSocket::bind("127.0.0.1:0").unwrap())); - let upstream_addr = blackhole.local_addr().unwrap(); - let ctx = Arc::new(ServerCtx { - socket, - zone_map: { - let mut m = HashMap::new(); - let mut inner = HashMap::new(); - inner.insert( - QueryType::A, - vec![DnsRecord::A { - domain: "dot-test.example".to_string(), - addr: std::net::Ipv4Addr::new(10, 0, 0, 1), - ttl: 300, - }], - ); - m.insert("dot-test.example".to_string(), inner); - m - }, - cache: RwLock::new(crate::cache::DnsCache::new(100, 60, 86400)), - refreshing: Mutex::new(std::collections::HashSet::new()), - stats: Mutex::new(crate::stats::ServerStats::new()), - overrides: RwLock::new(crate::override_store::OverrideStore::new()), - blocklist: RwLock::new(crate::blocklist::BlocklistStore::new()), - query_log: Mutex::new(crate::query_log::QueryLog::new(100)), - services: Mutex::new(crate::service_store::ServiceStore::new()), - lan_peers: Mutex::new(crate::lan::PeerStore::new(90)), - forwarding_rules: Vec::new(), - upstream_pool: Mutex::new(crate::forward::UpstreamPool::new( - vec![crate::forward::Upstream::Udp(upstream_addr)], - vec![], - )), - upstream_auto: false, - upstream_port: 53, - lan_ip: Mutex::new(std::net::Ipv4Addr::LOCALHOST), - timeout: Duration::from_millis(200), - hedge_delay: Duration::ZERO, - proxy_tld: "numa".to_string(), - proxy_tld_suffix: ".numa".to_string(), - lan_enabled: false, - config_path: String::new(), - config_found: false, - config_dir: std::path::PathBuf::from("/tmp"), - data_dir: std::path::PathBuf::from("/tmp"), - tls_config: Some(arc_swap::ArcSwap::from(server_tls)), - upstream_mode: crate::config::UpstreamMode::Forward, - root_hints: Vec::new(), - srtt: RwLock::new(crate::srtt::SrttCache::new(true)), - inflight: Mutex::new(HashMap::new()), - dnssec_enabled: false, - dnssec_strict: false, - health_meta: crate::health::HealthMeta::test_fixture(), - ca_pem: None, - mobile_enabled: false, - mobile_port: 8765, - }); + let upstream_addr = crate::testutil::blackhole_upstream(); + + let mut ctx = crate::testutil::test_ctx().await; + ctx.zone_map = { + let mut m = HashMap::new(); + let mut inner = HashMap::new(); + inner.insert( + QueryType::A, + vec![DnsRecord::A { + domain: "dot-test.example".to_string(), + addr: std::net::Ipv4Addr::new(10, 0, 0, 1), + ttl: 300, + }], + ); + m.insert("dot-test.example".to_string(), inner); + m + }; + ctx.upstream_pool = Mutex::new(crate::forward::UpstreamPool::new( + vec![crate::forward::Upstream::Udp(upstream_addr)], + vec![], + )); + ctx.tls_config = Some(arc_swap::ArcSwap::from(server_tls)); + let ctx = Arc::new(ctx); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); diff --git a/src/lib.rs b/src/lib.rs index 92a0b00..8933e2a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,6 +28,9 @@ pub mod system_dns; pub mod tls; pub mod wire; +#[cfg(test)] +pub(crate) mod testutil; + pub type Error = Box; pub type Result = std::result::Result; diff --git a/src/testutil.rs b/src/testutil.rs new file mode 100644 index 0000000..8687625 --- /dev/null +++ b/src/testutil.rs @@ -0,0 +1,95 @@ +use std::collections::{HashMap, HashSet}; +use std::net::{Ipv4Addr, SocketAddr}; +use std::path::PathBuf; +use std::sync::{Mutex, RwLock}; +use std::time::Duration; + +use tokio::net::UdpSocket; + +use crate::blocklist::BlocklistStore; +use crate::buffer::BytePacketBuffer; +use crate::cache::DnsCache; +use crate::config::UpstreamMode; +use crate::ctx::ServerCtx; +use crate::forward::{Upstream, UpstreamPool}; +use crate::health::HealthMeta; +use crate::lan::PeerStore; +use crate::override_store::OverrideStore; +use crate::packet::DnsPacket; +use crate::query_log::QueryLog; +use crate::service_store::ServiceStore; +use crate::srtt::SrttCache; +use crate::stats::ServerStats; +/// Minimal `ServerCtx` for tests. Override fields after construction +/// (all fields are `pub`), then wrap in `Arc`. +pub async fn test_ctx() -> ServerCtx { + let socket = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + ServerCtx { + socket, + zone_map: HashMap::new(), + cache: RwLock::new(DnsCache::new(100, 60, 86400)), + refreshing: Mutex::new(HashSet::new()), + stats: Mutex::new(ServerStats::new()), + overrides: RwLock::new(OverrideStore::new()), + blocklist: RwLock::new(BlocklistStore::new()), + query_log: Mutex::new(QueryLog::new(100)), + services: Mutex::new(ServiceStore::new()), + lan_peers: Mutex::new(PeerStore::new(90)), + forwarding_rules: Vec::new(), + upstream_pool: Mutex::new(UpstreamPool::new( + vec![Upstream::Udp("127.0.0.1:53".parse().unwrap())], + vec![], + )), + upstream_auto: false, + upstream_port: 53, + lan_ip: Mutex::new(Ipv4Addr::LOCALHOST), + timeout: Duration::from_millis(200), + hedge_delay: Duration::ZERO, + proxy_tld: "numa".to_string(), + proxy_tld_suffix: ".numa".to_string(), + lan_enabled: false, + config_path: "/tmp/test-numa.toml".to_string(), + config_found: false, + config_dir: PathBuf::from("/tmp"), + data_dir: PathBuf::from("/tmp"), + tls_config: None, + upstream_mode: UpstreamMode::Forward, + root_hints: Vec::new(), + srtt: RwLock::new(SrttCache::new(true)), + inflight: Mutex::new(HashMap::new()), + dnssec_enabled: false, + dnssec_strict: false, + health_meta: HealthMeta::test_fixture(), + ca_pem: None, + mobile_enabled: false, + mobile_port: 8765, + } +} + +/// Spawn a UDP socket that replies to the first DNS query with the given +/// response packet (patching the query ID to match). Returns the socket address. +pub async fn mock_upstream(response: DnsPacket) -> SocketAddr { + let sock = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let addr = sock.local_addr().unwrap(); + tokio::spawn(async move { + let mut buf = [0u8; 512]; + let (_, src) = sock.recv_from(&mut buf).await.unwrap(); + let query_id = u16::from_be_bytes([buf[0], buf[1]]); + let mut resp = response; + resp.header.id = query_id; + let mut out = BytePacketBuffer::new(); + resp.write(&mut out).unwrap(); + sock.send_to(out.filled(), src).await.unwrap(); + }); + addr +} + +/// UDP socket that accepts connections but never replies. +/// Useful as an upstream that triggers timeouts. +pub fn blackhole_upstream() -> SocketAddr { + let sock = std::net::UdpSocket::bind("127.0.0.1:0").unwrap(); + let addr = sock.local_addr().unwrap(); + // Leak so it stays bound for the duration of the test process. + Box::leak(Box::new(sock)); + addr +}