diff --git a/src/api.rs b/src/api.rs index fcc0bd9..6ec3e48 100644 --- a/src/api.rs +++ b/src/api.rs @@ -1020,53 +1020,10 @@ mod tests { use super::*; use axum::body::Body; use http::Request; - use std::sync::{Mutex, RwLock}; use tower::ServiceExt; async fn test_ctx() -> Arc { - let socket = tokio::net::UdpSocket::bind("127.0.0.1:0").await.unwrap(); - Arc::new(ServerCtx { - socket, - zone_map: std::collections::HashMap::new(), - cache: RwLock::new(crate::cache::DnsCache::new(100, 60, 86400)), - refreshing: Mutex::new(std::collections::HashSet::new()), - stats: Mutex::new(crate::stats::ServerStats::new()), - overrides: RwLock::new(crate::override_store::OverrideStore::new()), - blocklist: RwLock::new(crate::blocklist::BlocklistStore::new()), - query_log: Mutex::new(crate::query_log::QueryLog::new(100)), - services: Mutex::new(crate::service_store::ServiceStore::new()), - lan_peers: Mutex::new(crate::lan::PeerStore::new(90)), - forwarding_rules: Vec::new(), - upstream_pool: Mutex::new(crate::forward::UpstreamPool::new( - vec![crate::forward::Upstream::Udp( - "127.0.0.1:53".parse().unwrap(), - )], - vec![], - )), - upstream_auto: false, - upstream_port: 53, - lan_ip: Mutex::new(std::net::Ipv4Addr::LOCALHOST), - timeout: std::time::Duration::from_secs(3), - hedge_delay: std::time::Duration::ZERO, - proxy_tld: "numa".to_string(), - proxy_tld_suffix: ".numa".to_string(), - lan_enabled: false, - config_path: "/tmp/test-numa.toml".to_string(), - config_found: false, - config_dir: std::path::PathBuf::from("/tmp"), - data_dir: std::path::PathBuf::from("/tmp"), - tls_config: None, - upstream_mode: crate::config::UpstreamMode::Forward, - root_hints: Vec::new(), - srtt: RwLock::new(crate::srtt::SrttCache::new(true)), - inflight: Mutex::new(std::collections::HashMap::new()), - dnssec_enabled: false, - dnssec_strict: false, - health_meta: crate::health::HealthMeta::test_fixture(), - ca_pem: None, - mobile_enabled: false, - mobile_port: 8765, - }) + Arc::new(crate::testutil::test_ctx().await) } #[tokio::test] diff --git a/src/ctx.rs b/src/ctx.rs index 65b76d3..2812bed 100644 --- a/src/ctx.rs +++ b/src/ctx.rs @@ -88,7 +88,7 @@ pub async fn resolve_query( src_addr: SocketAddr, ctx: &Arc, transport: Transport, -) -> crate::Result { +) -> crate::Result<(BytePacketBuffer, QueryPath)> { let start = Instant::now(); let (qname, qtype) = match query.questions.first() { @@ -96,7 +96,8 @@ pub async fn resolve_query( None => return Err("empty question section".into()), }; - // Pipeline: overrides -> .tld interception -> blocklist -> local zones -> cache -> upstream + // Pipeline: overrides -> .localhost -> local zones -> special-use (unless forwarded) + // -> .tld proxy -> blocklist -> cache -> forwarding -> recursive/upstream // Each lock is scoped to avoid holding MutexGuard across await points. let (response, path, dnssec) = { let override_record = ctx.overrides.read().unwrap().lookup(&qname); @@ -119,8 +120,10 @@ pub async fn resolve_query( let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR); resp.answers = records.clone(); (resp, QueryPath::Local, DnssecStatus::Indeterminate) - } else if is_special_use_domain(&qname) { - // RFC 6761/8880: private PTR, DDR, NAT64 — answer locally + } else if is_special_use_domain(&qname) + && crate::system_dns::match_forwarding_rule(&qname, &ctx.forwarding_rules).is_none() + { + // RFC 6761/8880: answer locally unless a forwarding rule covers this zone. let resp = special_use_response(&query, &qname, qtype); (resp, QueryPath::Local, DnssecStatus::Indeterminate) } else if !ctx.proxy_tld_suffix.is_empty() @@ -373,7 +376,7 @@ pub async fn resolve_query( dnssec, }); - Ok(resp_buffer) + Ok((resp_buffer, path)) } fn cache_and_parse( @@ -457,7 +460,7 @@ pub async fn handle_query( } }; match resolve_query(query, &buffer.buf[..raw_len], src_addr, ctx, transport).await { - Ok(resp_buffer) => { + Ok((resp_buffer, _)) => { ctx.socket.send_to(resp_buffer.filled(), src_addr).await?; } Err(e) => { @@ -1036,4 +1039,216 @@ mod tests { "error message must be preserved for logging" ); } + + // ---- Full-pipeline resolve_query tests ---- + + /// Send a query through the full resolve_query pipeline and return + /// the parsed response + query path. + async fn resolve_in_test( + ctx: &Arc, + domain: &str, + qtype: QueryType, + ) -> (DnsPacket, QueryPath) { + let query = DnsPacket::query(0xBEEF, domain, qtype); + let mut buf = BytePacketBuffer::new(); + query.write(&mut buf).unwrap(); + let raw = &buf.buf[..buf.pos]; + let src: SocketAddr = "127.0.0.1:1234".parse().unwrap(); + + let (resp_buf, path) = resolve_query(query, raw, src, ctx, Transport::Udp) + .await + .unwrap(); + + let mut resp_parse_buf = BytePacketBuffer::from_bytes(resp_buf.filled()); + let resp = DnsPacket::from_buffer(&mut resp_parse_buf).unwrap(); + (resp, path) + } + + #[tokio::test] + async fn special_use_private_ptr_returns_nxdomain() { + let ctx = Arc::new(crate::testutil::test_ctx().await); + let (resp, path) = + resolve_in_test(&ctx, "153.188.168.192.in-addr.arpa", QueryType::PTR).await; + assert_eq!(path, QueryPath::Local); + assert_eq!(resp.header.rescode, ResultCode::NXDOMAIN); + } + + #[tokio::test] + async fn forwarding_rule_overrides_special_use_domain() { + let mut resp = DnsPacket::new(); + resp.header.response = true; + resp.header.rescode = ResultCode::NOERROR; + let upstream_addr = crate::testutil::mock_upstream(resp).await; + + let mut ctx = crate::testutil::test_ctx().await; + ctx.forwarding_rules = vec![ForwardingRule::new( + "168.192.in-addr.arpa".to_string(), + upstream_addr, + )]; + let ctx = Arc::new(ctx); + + let (resp, path) = + resolve_in_test(&ctx, "153.188.168.192.in-addr.arpa", QueryType::PTR).await; + + assert_eq!( + path, + QueryPath::Forwarded, + "forwarding rule must take precedence over special-use NXDOMAIN" + ); + assert_eq!(resp.header.rescode, ResultCode::NOERROR); + } + + #[tokio::test] + async fn pipeline_override_takes_precedence() { + let ctx = crate::testutil::test_ctx().await; + ctx.overrides + .write() + .unwrap() + .insert("override.test", "1.2.3.4", 60, None) + .unwrap(); + let ctx = Arc::new(ctx); + + let (resp, path) = resolve_in_test(&ctx, "override.test", QueryType::A).await; + assert_eq!(path, QueryPath::Overridden); + assert_eq!(resp.header.rescode, ResultCode::NOERROR); + assert_eq!(resp.answers.len(), 1); + } + + #[tokio::test] + async fn pipeline_localhost_resolves_to_loopback() { + let ctx = Arc::new(crate::testutil::test_ctx().await); + + let (resp, path) = resolve_in_test(&ctx, "localhost", QueryType::A).await; + assert_eq!(path, QueryPath::Local); + assert_eq!(resp.header.rescode, ResultCode::NOERROR); + match &resp.answers[0] { + DnsRecord::A { addr, .. } => assert_eq!(*addr, Ipv4Addr::LOCALHOST), + other => panic!("expected A record, got {:?}", other), + } + } + + #[tokio::test] + async fn pipeline_localhost_subdomain_resolves_to_loopback() { + let ctx = Arc::new(crate::testutil::test_ctx().await); + + let (resp, path) = resolve_in_test(&ctx, "app.localhost", QueryType::A).await; + assert_eq!(path, QueryPath::Local); + assert_eq!(resp.header.rescode, ResultCode::NOERROR); + match &resp.answers[0] { + DnsRecord::A { addr, .. } => assert_eq!(*addr, Ipv4Addr::LOCALHOST), + other => panic!("expected A record, got {:?}", other), + } + } + + #[tokio::test] + async fn pipeline_local_zone_returns_configured_record() { + let mut ctx = crate::testutil::test_ctx().await; + let mut inner = HashMap::new(); + inner.insert( + QueryType::A, + vec![DnsRecord::A { + domain: "myapp.test".to_string(), + addr: Ipv4Addr::new(10, 0, 0, 42), + ttl: 300, + }], + ); + ctx.zone_map.insert("myapp.test".to_string(), inner); + let ctx = Arc::new(ctx); + + let (resp, path) = resolve_in_test(&ctx, "myapp.test", QueryType::A).await; + assert_eq!(path, QueryPath::Local); + assert_eq!(resp.header.rescode, ResultCode::NOERROR); + match &resp.answers[0] { + DnsRecord::A { addr, .. } => assert_eq!(*addr, Ipv4Addr::new(10, 0, 0, 42)), + other => panic!("expected A record, got {:?}", other), + } + } + + #[tokio::test] + async fn pipeline_tld_proxy_resolves_service() { + let ctx = crate::testutil::test_ctx().await; + ctx.services.lock().unwrap().insert("grafana", 3000); + let ctx = Arc::new(ctx); + + let (resp, path) = resolve_in_test(&ctx, "grafana.numa", QueryType::A).await; + assert_eq!(path, QueryPath::Local); + assert_eq!(resp.header.rescode, ResultCode::NOERROR); + match &resp.answers[0] { + DnsRecord::A { addr, .. } => assert_eq!(*addr, Ipv4Addr::LOCALHOST), + other => panic!("expected A record, got {:?}", other), + } + } + + #[tokio::test] + async fn pipeline_blocklist_sinkhole() { + let ctx = crate::testutil::test_ctx().await; + let mut domains = std::collections::HashSet::new(); + domains.insert("ads.tracker.test".to_string()); + ctx.blocklist.write().unwrap().swap_domains(domains, vec![]); + let ctx = Arc::new(ctx); + + let (resp, path) = resolve_in_test(&ctx, "ads.tracker.test", QueryType::A).await; + assert_eq!(path, QueryPath::Blocked); + assert_eq!(resp.header.rescode, ResultCode::NOERROR); + match &resp.answers[0] { + DnsRecord::A { addr, .. } => assert_eq!(*addr, Ipv4Addr::UNSPECIFIED), + other => panic!("expected sinkhole A record, got {:?}", other), + } + } + + #[tokio::test] + async fn pipeline_cache_hit() { + let ctx = Arc::new(crate::testutil::test_ctx().await); + + // Pre-populate cache with a response + let mut pkt = DnsPacket::new(); + pkt.header.response = true; + pkt.header.rescode = ResultCode::NOERROR; + pkt.questions.push(crate::question::DnsQuestion { + name: "cached.test".to_string(), + qtype: QueryType::A, + }); + pkt.answers.push(DnsRecord::A { + domain: "cached.test".to_string(), + addr: Ipv4Addr::new(5, 5, 5, 5), + ttl: 3600, + }); + ctx.cache + .write() + .unwrap() + .insert("cached.test", QueryType::A, &pkt); + + let (resp, path) = resolve_in_test(&ctx, "cached.test", QueryType::A).await; + assert_eq!(path, QueryPath::Cached); + assert_eq!(resp.header.rescode, ResultCode::NOERROR); + } + + #[tokio::test] + async fn pipeline_forwarding_returns_upstream_answer() { + let mut upstream_resp = DnsPacket::new(); + upstream_resp.header.response = true; + upstream_resp.header.rescode = ResultCode::NOERROR; + upstream_resp.answers.push(DnsRecord::A { + domain: "internal.corp".to_string(), + addr: Ipv4Addr::new(10, 1, 2, 3), + ttl: 600, + }); + let upstream_addr = crate::testutil::mock_upstream(upstream_resp).await; + + let mut ctx = crate::testutil::test_ctx().await; + ctx.forwarding_rules = vec![ForwardingRule::new("corp".to_string(), upstream_addr)]; + let ctx = Arc::new(ctx); + + let (resp, path) = resolve_in_test(&ctx, "internal.corp", QueryType::A).await; + assert_eq!(path, QueryPath::Forwarded); + assert_eq!(resp.header.rescode, ResultCode::NOERROR); + assert_eq!(resp.answers.len(), 1); + match &resp.answers[0] { + DnsRecord::A { domain, addr, .. } => { + assert_eq!(domain, "internal.corp"); + assert_eq!(*addr, Ipv4Addr::new(10, 1, 2, 3)); + } + other => panic!("expected A record, got {:?}", other), + } + } } diff --git a/src/doh.rs b/src/doh.rs index f90b919..900edb4 100644 --- a/src/doh.rs +++ b/src/doh.rs @@ -113,7 +113,7 @@ async fn resolve_doh( let questions = query.questions.clone(); match resolve_query(query, dns_bytes, src, ctx, Transport::Doh).await { - Ok(resp_buffer) => { + Ok((resp_buffer, _)) => { let min_ttl = extract_min_ttl(resp_buffer.filled()); dns_response(resp_buffer.filled(), min_ttl) } diff --git a/src/dot.rs b/src/dot.rs index e883e0b..b39d7fe 100644 --- a/src/dot.rs +++ b/src/dot.rs @@ -211,7 +211,7 @@ async fn handle_dot_connection( ) .await { - Ok(resp_buffer) => { + Ok((resp_buffer, _)) => { if write_framed(&mut stream, resp_buffer.filled()) .await .is_err() @@ -279,7 +279,7 @@ where mod tests { use super::*; use std::collections::HashMap; - use std::sync::{Mutex, RwLock}; + use std::sync::Mutex; use rcgen::{CertificateParams, DnType, KeyPair}; use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer, ServerName}; @@ -344,63 +344,29 @@ mod tests { async fn spawn_dot_server() -> (SocketAddr, CertificateDer<'static>) { let (server_tls, cert_der) = test_tls_configs(); - let socket = tokio::net::UdpSocket::bind("127.0.0.1:0").await.unwrap(); - // Bind an unresponsive upstream and leak it so it lives for the test duration. - let blackhole = Box::leak(Box::new(std::net::UdpSocket::bind("127.0.0.1:0").unwrap())); - let upstream_addr = blackhole.local_addr().unwrap(); - let ctx = Arc::new(ServerCtx { - socket, - zone_map: { - let mut m = HashMap::new(); - let mut inner = HashMap::new(); - inner.insert( - QueryType::A, - vec![DnsRecord::A { - domain: "dot-test.example".to_string(), - addr: std::net::Ipv4Addr::new(10, 0, 0, 1), - ttl: 300, - }], - ); - m.insert("dot-test.example".to_string(), inner); - m - }, - cache: RwLock::new(crate::cache::DnsCache::new(100, 60, 86400)), - refreshing: Mutex::new(std::collections::HashSet::new()), - stats: Mutex::new(crate::stats::ServerStats::new()), - overrides: RwLock::new(crate::override_store::OverrideStore::new()), - blocklist: RwLock::new(crate::blocklist::BlocklistStore::new()), - query_log: Mutex::new(crate::query_log::QueryLog::new(100)), - services: Mutex::new(crate::service_store::ServiceStore::new()), - lan_peers: Mutex::new(crate::lan::PeerStore::new(90)), - forwarding_rules: Vec::new(), - upstream_pool: Mutex::new(crate::forward::UpstreamPool::new( - vec![crate::forward::Upstream::Udp(upstream_addr)], - vec![], - )), - upstream_auto: false, - upstream_port: 53, - lan_ip: Mutex::new(std::net::Ipv4Addr::LOCALHOST), - timeout: Duration::from_millis(200), - hedge_delay: Duration::ZERO, - proxy_tld: "numa".to_string(), - proxy_tld_suffix: ".numa".to_string(), - lan_enabled: false, - config_path: String::new(), - config_found: false, - config_dir: std::path::PathBuf::from("/tmp"), - data_dir: std::path::PathBuf::from("/tmp"), - tls_config: Some(arc_swap::ArcSwap::from(server_tls)), - upstream_mode: crate::config::UpstreamMode::Forward, - root_hints: Vec::new(), - srtt: RwLock::new(crate::srtt::SrttCache::new(true)), - inflight: Mutex::new(HashMap::new()), - dnssec_enabled: false, - dnssec_strict: false, - health_meta: crate::health::HealthMeta::test_fixture(), - ca_pem: None, - mobile_enabled: false, - mobile_port: 8765, - }); + let upstream_addr = crate::testutil::blackhole_upstream(); + + let mut ctx = crate::testutil::test_ctx().await; + ctx.zone_map = { + let mut m = HashMap::new(); + let mut inner = HashMap::new(); + inner.insert( + QueryType::A, + vec![DnsRecord::A { + domain: "dot-test.example".to_string(), + addr: std::net::Ipv4Addr::new(10, 0, 0, 1), + ttl: 300, + }], + ); + m.insert("dot-test.example".to_string(), inner); + m + }; + ctx.upstream_pool = Mutex::new(crate::forward::UpstreamPool::new( + vec![crate::forward::Upstream::Udp(upstream_addr)], + vec![], + )); + ctx.tls_config = Some(arc_swap::ArcSwap::from(server_tls)); + let ctx = Arc::new(ctx); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); diff --git a/src/lib.rs b/src/lib.rs index 92a0b00..8933e2a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,6 +28,9 @@ pub mod system_dns; pub mod tls; pub mod wire; +#[cfg(test)] +pub(crate) mod testutil; + pub type Error = Box; pub type Result = std::result::Result; diff --git a/src/testutil.rs b/src/testutil.rs new file mode 100644 index 0000000..8687625 --- /dev/null +++ b/src/testutil.rs @@ -0,0 +1,95 @@ +use std::collections::{HashMap, HashSet}; +use std::net::{Ipv4Addr, SocketAddr}; +use std::path::PathBuf; +use std::sync::{Mutex, RwLock}; +use std::time::Duration; + +use tokio::net::UdpSocket; + +use crate::blocklist::BlocklistStore; +use crate::buffer::BytePacketBuffer; +use crate::cache::DnsCache; +use crate::config::UpstreamMode; +use crate::ctx::ServerCtx; +use crate::forward::{Upstream, UpstreamPool}; +use crate::health::HealthMeta; +use crate::lan::PeerStore; +use crate::override_store::OverrideStore; +use crate::packet::DnsPacket; +use crate::query_log::QueryLog; +use crate::service_store::ServiceStore; +use crate::srtt::SrttCache; +use crate::stats::ServerStats; +/// Minimal `ServerCtx` for tests. Override fields after construction +/// (all fields are `pub`), then wrap in `Arc`. +pub async fn test_ctx() -> ServerCtx { + let socket = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + ServerCtx { + socket, + zone_map: HashMap::new(), + cache: RwLock::new(DnsCache::new(100, 60, 86400)), + refreshing: Mutex::new(HashSet::new()), + stats: Mutex::new(ServerStats::new()), + overrides: RwLock::new(OverrideStore::new()), + blocklist: RwLock::new(BlocklistStore::new()), + query_log: Mutex::new(QueryLog::new(100)), + services: Mutex::new(ServiceStore::new()), + lan_peers: Mutex::new(PeerStore::new(90)), + forwarding_rules: Vec::new(), + upstream_pool: Mutex::new(UpstreamPool::new( + vec![Upstream::Udp("127.0.0.1:53".parse().unwrap())], + vec![], + )), + upstream_auto: false, + upstream_port: 53, + lan_ip: Mutex::new(Ipv4Addr::LOCALHOST), + timeout: Duration::from_millis(200), + hedge_delay: Duration::ZERO, + proxy_tld: "numa".to_string(), + proxy_tld_suffix: ".numa".to_string(), + lan_enabled: false, + config_path: "/tmp/test-numa.toml".to_string(), + config_found: false, + config_dir: PathBuf::from("/tmp"), + data_dir: PathBuf::from("/tmp"), + tls_config: None, + upstream_mode: UpstreamMode::Forward, + root_hints: Vec::new(), + srtt: RwLock::new(SrttCache::new(true)), + inflight: Mutex::new(HashMap::new()), + dnssec_enabled: false, + dnssec_strict: false, + health_meta: HealthMeta::test_fixture(), + ca_pem: None, + mobile_enabled: false, + mobile_port: 8765, + } +} + +/// Spawn a UDP socket that replies to the first DNS query with the given +/// response packet (patching the query ID to match). Returns the socket address. +pub async fn mock_upstream(response: DnsPacket) -> SocketAddr { + let sock = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let addr = sock.local_addr().unwrap(); + tokio::spawn(async move { + let mut buf = [0u8; 512]; + let (_, src) = sock.recv_from(&mut buf).await.unwrap(); + let query_id = u16::from_be_bytes([buf[0], buf[1]]); + let mut resp = response; + resp.header.id = query_id; + let mut out = BytePacketBuffer::new(); + resp.write(&mut out).unwrap(); + sock.send_to(out.filled(), src).await.unwrap(); + }); + addr +} + +/// UDP socket that accepts connections but never replies. +/// Useful as an upstream that triggers timeouts. +pub fn blackhole_upstream() -> SocketAddr { + let sock = std::net::UdpSocket::bind("127.0.0.1:0").unwrap(); + let addr = sock.local_addr().unwrap(); + // Leak so it stays bound for the duration of the test process. + Box::leak(Box::new(sock)); + addr +}