fix: forwarding rules override special-use NXDOMAIN #95

Merged
razvandimescu merged 7 commits from fix/forwarding-precedes-special-use into main 2026-04-13 14:37:19 +08:00
5 changed files with 128 additions and 173 deletions
Showing only changes of commit b40004fe5e - Show all commits

View File

@@ -1020,53 +1020,10 @@ mod tests {
use super::*; use super::*;
use axum::body::Body; use axum::body::Body;
use http::Request; use http::Request;
use std::sync::{Mutex, RwLock};
use tower::ServiceExt; use tower::ServiceExt;
async fn test_ctx() -> Arc<ServerCtx> { async fn test_ctx() -> Arc<ServerCtx> {
let socket = tokio::net::UdpSocket::bind("127.0.0.1:0").await.unwrap(); Arc::new(crate::testutil::test_ctx().await)
Arc::new(ServerCtx {
socket,
zone_map: std::collections::HashMap::new(),
cache: RwLock::new(crate::cache::DnsCache::new(100, 60, 86400)),
refreshing: Mutex::new(std::collections::HashSet::new()),
stats: Mutex::new(crate::stats::ServerStats::new()),
overrides: RwLock::new(crate::override_store::OverrideStore::new()),
blocklist: RwLock::new(crate::blocklist::BlocklistStore::new()),
query_log: Mutex::new(crate::query_log::QueryLog::new(100)),
services: Mutex::new(crate::service_store::ServiceStore::new()),
lan_peers: Mutex::new(crate::lan::PeerStore::new(90)),
forwarding_rules: Vec::new(),
upstream_pool: Mutex::new(crate::forward::UpstreamPool::new(
vec![crate::forward::Upstream::Udp(
"127.0.0.1:53".parse().unwrap(),
)],
vec![],
)),
upstream_auto: false,
upstream_port: 53,
lan_ip: Mutex::new(std::net::Ipv4Addr::LOCALHOST),
timeout: std::time::Duration::from_secs(3),
hedge_delay: std::time::Duration::ZERO,
proxy_tld: "numa".to_string(),
proxy_tld_suffix: ".numa".to_string(),
lan_enabled: false,
config_path: "/tmp/test-numa.toml".to_string(),
config_found: false,
config_dir: std::path::PathBuf::from("/tmp"),
data_dir: std::path::PathBuf::from("/tmp"),
tls_config: None,
upstream_mode: crate::config::UpstreamMode::Forward,
root_hints: Vec::new(),
srtt: RwLock::new(crate::srtt::SrttCache::new(true)),
inflight: Mutex::new(std::collections::HashMap::new()),
dnssec_enabled: false,
dnssec_strict: false,
health_meta: crate::health::HealthMeta::test_fixture(),
ca_pem: None,
mobile_enabled: false,
mobile_port: 8765,
})
} }
#[tokio::test] #[tokio::test]

View File

@@ -659,7 +659,6 @@ mod tests {
use super::*; use super::*;
use std::collections::HashMap; use std::collections::HashMap;
use std::net::Ipv4Addr; use std::net::Ipv4Addr;
use std::path::PathBuf;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use tokio::sync::broadcast; use tokio::sync::broadcast;
@@ -1044,10 +1043,6 @@ mod tests {
// ---- Full-pipeline resolve_query tests ---- // ---- Full-pipeline resolve_query tests ----
async fn test_ctx() -> Arc<ServerCtx> {
test_ctx_with_forwarding(Vec::new()).await
}
/// Send a query through the full resolve_query pipeline and return /// Send a query through the full resolve_query pipeline and return
/// the parsed response + query path. /// the parsed response + query path.
async fn resolve_in_test( async fn resolve_in_test(
@@ -1072,87 +1067,26 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn special_use_private_ptr_returns_nxdomain() { async fn special_use_private_ptr_returns_nxdomain() {
let ctx = test_ctx().await; let ctx = Arc::new(crate::testutil::test_ctx().await);
let (resp, path) = let (resp, path) =
resolve_in_test(&ctx, "153.188.168.192.in-addr.arpa", QueryType::PTR).await; resolve_in_test(&ctx, "153.188.168.192.in-addr.arpa", QueryType::PTR).await;
assert_eq!(path, QueryPath::Local); assert_eq!(path, QueryPath::Local);
assert_eq!(resp.header.rescode, ResultCode::NXDOMAIN); assert_eq!(resp.header.rescode, ResultCode::NXDOMAIN);
} }
async fn test_ctx_with_forwarding(rules: Vec<ForwardingRule>) -> Arc<ServerCtx> {
let socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
Arc::new(ServerCtx {
socket,
zone_map: HashMap::new(),
cache: RwLock::new(DnsCache::new(100, 60, 86400)),
refreshing: Mutex::new(HashSet::new()),
stats: Mutex::new(ServerStats::new()),
overrides: RwLock::new(OverrideStore::new()),
blocklist: RwLock::new(BlocklistStore::new()),
query_log: Mutex::new(QueryLog::new(100)),
services: Mutex::new(ServiceStore::new()),
lan_peers: Mutex::new(PeerStore::new(90)),
forwarding_rules: rules,
upstream_pool: Mutex::new(UpstreamPool::new(
vec![Upstream::Udp("127.0.0.1:53".parse().unwrap())],
vec![],
)),
upstream_auto: false,
upstream_port: 53,
lan_ip: Mutex::new(Ipv4Addr::LOCALHOST),
timeout: Duration::from_millis(100),
hedge_delay: Duration::ZERO,
proxy_tld: "numa".to_string(),
proxy_tld_suffix: ".numa".to_string(),
lan_enabled: false,
config_path: "/tmp/test-numa.toml".to_string(),
config_found: false,
config_dir: PathBuf::from("/tmp"),
data_dir: PathBuf::from("/tmp"),
tls_config: None,
upstream_mode: UpstreamMode::Forward,
root_hints: Vec::new(),
srtt: RwLock::new(SrttCache::new(true)),
inflight: Mutex::new(HashMap::new()),
dnssec_enabled: false,
dnssec_strict: false,
health_meta: HealthMeta::test_fixture(),
ca_pem: None,
mobile_enabled: false,
mobile_port: 8765,
})
}
/// Spawn a UDP socket that replies to the first DNS query with the given
/// response packet (patching the query ID). Returns the socket address.
async fn mock_upstream(response: DnsPacket) -> SocketAddr {
let sock = UdpSocket::bind("127.0.0.1:0").await.unwrap();
let addr = sock.local_addr().unwrap();
tokio::spawn(async move {
let mut buf = [0u8; 512];
let (_, src) = sock.recv_from(&mut buf).await.unwrap();
let query_id = u16::from_be_bytes([buf[0], buf[1]]);
let mut resp = response;
resp.header.id = query_id;
let mut out = BytePacketBuffer::new();
resp.write(&mut out).unwrap();
sock.send_to(out.filled(), src).await.unwrap();
});
addr
}
#[tokio::test] #[tokio::test]
async fn forwarding_rule_overrides_special_use_domain() { async fn forwarding_rule_overrides_special_use_domain() {
let mut resp = DnsPacket::new(); let mut resp = DnsPacket::new();
resp.header.response = true; resp.header.response = true;
resp.header.rescode = ResultCode::NOERROR; resp.header.rescode = ResultCode::NOERROR;
let upstream_addr = mock_upstream(resp).await; let upstream_addr = crate::testutil::mock_upstream(resp).await;
let rules = vec![ForwardingRule::new( let mut ctx = crate::testutil::test_ctx().await;
ctx.forwarding_rules = vec![ForwardingRule::new(
"168.192.in-addr.arpa".to_string(), "168.192.in-addr.arpa".to_string(),
upstream_addr, upstream_addr,
)]; )];
let ctx = test_ctx_with_forwarding(rules).await; let ctx = Arc::new(ctx);
let (resp, path) = let (resp, path) =
resolve_in_test(&ctx, "153.188.168.192.in-addr.arpa", QueryType::PTR).await; resolve_in_test(&ctx, "153.188.168.192.in-addr.arpa", QueryType::PTR).await;

View File

@@ -279,7 +279,7 @@ where
mod tests { mod tests {
use super::*; use super::*;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{Mutex, RwLock}; use std::sync::Mutex;
use rcgen::{CertificateParams, DnType, KeyPair}; use rcgen::{CertificateParams, DnType, KeyPair};
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer, ServerName}; use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer, ServerName};
@@ -344,13 +344,10 @@ mod tests {
async fn spawn_dot_server() -> (SocketAddr, CertificateDer<'static>) { async fn spawn_dot_server() -> (SocketAddr, CertificateDer<'static>) {
let (server_tls, cert_der) = test_tls_configs(); let (server_tls, cert_der) = test_tls_configs();
let socket = tokio::net::UdpSocket::bind("127.0.0.1:0").await.unwrap(); let upstream_addr = crate::testutil::blackhole_upstream();
// Bind an unresponsive upstream and leak it so it lives for the test duration.
let blackhole = Box::leak(Box::new(std::net::UdpSocket::bind("127.0.0.1:0").unwrap())); let mut ctx = crate::testutil::test_ctx().await;
let upstream_addr = blackhole.local_addr().unwrap(); ctx.zone_map = {
let ctx = Arc::new(ServerCtx {
socket,
zone_map: {
let mut m = HashMap::new(); let mut m = HashMap::new();
let mut inner = HashMap::new(); let mut inner = HashMap::new();
inner.insert( inner.insert(
@@ -363,44 +360,13 @@ mod tests {
); );
m.insert("dot-test.example".to_string(), inner); m.insert("dot-test.example".to_string(), inner);
m m
}, };
cache: RwLock::new(crate::cache::DnsCache::new(100, 60, 86400)), ctx.upstream_pool = Mutex::new(crate::forward::UpstreamPool::new(
refreshing: Mutex::new(std::collections::HashSet::new()),
stats: Mutex::new(crate::stats::ServerStats::new()),
overrides: RwLock::new(crate::override_store::OverrideStore::new()),
blocklist: RwLock::new(crate::blocklist::BlocklistStore::new()),
query_log: Mutex::new(crate::query_log::QueryLog::new(100)),
services: Mutex::new(crate::service_store::ServiceStore::new()),
lan_peers: Mutex::new(crate::lan::PeerStore::new(90)),
forwarding_rules: Vec::new(),
upstream_pool: Mutex::new(crate::forward::UpstreamPool::new(
vec![crate::forward::Upstream::Udp(upstream_addr)], vec![crate::forward::Upstream::Udp(upstream_addr)],
vec![], vec![],
)), ));
upstream_auto: false, ctx.tls_config = Some(arc_swap::ArcSwap::from(server_tls));
upstream_port: 53, let ctx = Arc::new(ctx);
lan_ip: Mutex::new(std::net::Ipv4Addr::LOCALHOST),
timeout: Duration::from_millis(200),
hedge_delay: Duration::ZERO,
proxy_tld: "numa".to_string(),
proxy_tld_suffix: ".numa".to_string(),
lan_enabled: false,
config_path: String::new(),
config_found: false,
config_dir: std::path::PathBuf::from("/tmp"),
data_dir: std::path::PathBuf::from("/tmp"),
tls_config: Some(arc_swap::ArcSwap::from(server_tls)),
upstream_mode: crate::config::UpstreamMode::Forward,
root_hints: Vec::new(),
srtt: RwLock::new(crate::srtt::SrttCache::new(true)),
inflight: Mutex::new(HashMap::new()),
dnssec_enabled: false,
dnssec_strict: false,
health_meta: crate::health::HealthMeta::test_fixture(),
ca_pem: None,
mobile_enabled: false,
mobile_port: 8765,
});
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap(); let addr = listener.local_addr().unwrap();

View File

@@ -28,6 +28,9 @@ pub mod system_dns;
pub mod tls; pub mod tls;
pub mod wire; pub mod wire;
#[cfg(test)]
pub(crate) mod testutil;
pub type Error = Box<dyn std::error::Error + Send + Sync>; pub type Error = Box<dyn std::error::Error + Send + Sync>;
pub type Result<T> = std::result::Result<T, Error>; pub type Result<T> = std::result::Result<T, Error>;

95
src/testutil.rs Normal file
View File

@@ -0,0 +1,95 @@
use std::collections::{HashMap, HashSet};
use std::net::{Ipv4Addr, SocketAddr};
use std::path::PathBuf;
use std::sync::{Mutex, RwLock};
use std::time::Duration;
use tokio::net::UdpSocket;
use crate::blocklist::BlocklistStore;
use crate::buffer::BytePacketBuffer;
use crate::cache::DnsCache;
use crate::config::UpstreamMode;
use crate::ctx::ServerCtx;
use crate::forward::{Upstream, UpstreamPool};
use crate::health::HealthMeta;
use crate::lan::PeerStore;
use crate::override_store::OverrideStore;
use crate::packet::DnsPacket;
use crate::query_log::QueryLog;
use crate::service_store::ServiceStore;
use crate::srtt::SrttCache;
use crate::stats::ServerStats;
/// Minimal `ServerCtx` for tests. Override fields after construction
/// (all fields are `pub`), then wrap in `Arc`.
pub async fn test_ctx() -> ServerCtx {
let socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
ServerCtx {
socket,
zone_map: HashMap::new(),
cache: RwLock::new(DnsCache::new(100, 60, 86400)),
refreshing: Mutex::new(HashSet::new()),
stats: Mutex::new(ServerStats::new()),
overrides: RwLock::new(OverrideStore::new()),
blocklist: RwLock::new(BlocklistStore::new()),
query_log: Mutex::new(QueryLog::new(100)),
services: Mutex::new(ServiceStore::new()),
lan_peers: Mutex::new(PeerStore::new(90)),
forwarding_rules: Vec::new(),
upstream_pool: Mutex::new(UpstreamPool::new(
vec![Upstream::Udp("127.0.0.1:53".parse().unwrap())],
vec![],
)),
upstream_auto: false,
upstream_port: 53,
lan_ip: Mutex::new(Ipv4Addr::LOCALHOST),
timeout: Duration::from_millis(200),
hedge_delay: Duration::ZERO,
proxy_tld: "numa".to_string(),
proxy_tld_suffix: ".numa".to_string(),
lan_enabled: false,
config_path: "/tmp/test-numa.toml".to_string(),
config_found: false,
config_dir: PathBuf::from("/tmp"),
data_dir: PathBuf::from("/tmp"),
tls_config: None,
upstream_mode: UpstreamMode::Forward,
root_hints: Vec::new(),
srtt: RwLock::new(SrttCache::new(true)),
inflight: Mutex::new(HashMap::new()),
dnssec_enabled: false,
dnssec_strict: false,
health_meta: HealthMeta::test_fixture(),
ca_pem: None,
mobile_enabled: false,
mobile_port: 8765,
}
}
/// Spawn a UDP socket that replies to the first DNS query with the given
/// response packet (patching the query ID to match). Returns the socket address.
pub async fn mock_upstream(response: DnsPacket) -> SocketAddr {
let sock = UdpSocket::bind("127.0.0.1:0").await.unwrap();
let addr = sock.local_addr().unwrap();
tokio::spawn(async move {
let mut buf = [0u8; 512];
let (_, src) = sock.recv_from(&mut buf).await.unwrap();
let query_id = u16::from_be_bytes([buf[0], buf[1]]);
let mut resp = response;
resp.header.id = query_id;
let mut out = BytePacketBuffer::new();
resp.write(&mut out).unwrap();
sock.send_to(out.filled(), src).await.unwrap();
});
addr
}
/// UDP socket that accepts connections but never replies.
/// Useful as an upstream that triggers timeouts.
pub fn blackhole_upstream() -> SocketAddr {
let sock = std::net::UdpSocket::bind("127.0.0.1:0").unwrap();
let addr = sock.local_addr().unwrap();
// Leak so it stays bound for the duration of the test process.
Box::leak(Box::new(sock));
addr
}