feat: TCP fallback, query minimization, UDP auto-disable

Transport resilience for restrictive networks (ISPs blocking UDP:53):
- DNS-over-TCP fallback: UDP fail/truncation → automatic TCP retry
- UDP auto-disable: after 3 consecutive failures, switch to TCP-first
- IPv6 → TCP directly (UDP socket binds 0.0.0.0, can't reach IPv6)
- Network change resets UDP detection for re-probing
- Root hint rotation in TLD priming

Privacy:
- RFC 7816 query minimization: root servers see TLD only, not full name

Code quality:
- Merged find_starting_ns + find_starting_zone → find_closest_ns
- Extracted resolve_ns_addrs_from_glue shared helper
- Removed overall timeout wrapper (per-hop timeouts sufficient)
- forward_tcp for DNS-over-TCP (RFC 1035 §4.2.2)

Testing:
- Mock TCP-only DNS server for fallback tests (no network needed)
- tcp_fallback_resolves_when_udp_blocked
- tcp_only_iterative_resolution
- tcp_fallback_handles_nxdomain
- udp_auto_disable_resets
- Integration test suite (4 suites, 51 tests)
- Network probe script (tests/network-probe.sh)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Razvan Dimescu
2026-03-27 19:50:19 +02:00
parent 637b374d8b
commit 5b2cc874a1
11 changed files with 1372 additions and 103 deletions

View File

@@ -176,7 +176,7 @@ fn default_upstream_port() -> u16 {
53
}
fn default_timeout_ms() -> u64 {
3000
5000
}
#[derive(Deserialize)]

View File

@@ -154,7 +154,6 @@ pub async fn handle_query(
&qname,
qtype,
&ctx.cache,
ctx.timeout,
&query,
&ctx.root_hints,
)
@@ -162,28 +161,14 @@ pub async fn handle_query(
{
Ok(resp) => (resp, QueryPath::Recursive),
Err(e) => {
// Auto-fallback: retry via forward upstream if configured
let upstream = ctx.upstream.lock().unwrap().clone();
match forward_query(&query, &upstream, ctx.timeout).await {
Ok(resp) => {
debug!(
"{} | {:?} {} | RECURSIVE FALLBACK → FORWARD | {}",
src_addr, qtype, qname, e
);
ctx.cache.write().unwrap().insert(&qname, qtype, &resp);
(resp, QueryPath::Forwarded)
}
Err(e2) => {
error!(
"{} | {:?} {} | RECURSIVE+FORWARD FAILED | recursive: {} | forward: {}",
src_addr, qtype, qname, e, e2
);
(
DnsPacket::response_from(&query, ResultCode::SERVFAIL),
QueryPath::UpstreamError,
)
}
}
error!(
"{} | {:?} {} | RECURSIVE ERROR | {}",
src_addr, qtype, qname, e
);
(
DnsPacket::response_from(&query, ResultCode::SERVFAIL),
QueryPath::UpstreamError,
)
}
}
} else {

View File

@@ -74,6 +74,37 @@ pub(crate) async fn forward_udp(
DnsPacket::from_buffer(&mut recv_buffer)
}
/// DNS over TCP (RFC 1035 §4.2.2): 2-byte length prefix, then the DNS message.
pub(crate) async fn forward_tcp(
query: &DnsPacket,
upstream: SocketAddr,
timeout_duration: Duration,
) -> Result<DnsPacket> {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
let mut send_buffer = BytePacketBuffer::new();
query.write(&mut send_buffer)?;
let msg = send_buffer.filled();
let mut stream = timeout(timeout_duration, TcpStream::connect(upstream)).await??;
// Write length-prefixed message
stream.write_all(&(msg.len() as u16).to_be_bytes()).await?;
stream.write_all(msg).await?;
// Read length-prefixed response
let mut len_buf = [0u8; 2];
timeout(timeout_duration, stream.read_exact(&mut len_buf)).await??;
let resp_len = u16::from_be_bytes(len_buf) as usize;
let mut data = vec![0u8; resp_len];
timeout(timeout_duration, stream.read_exact(&mut data)).await??;
let mut recv_buffer = BytePacketBuffer::from_bytes(&data);
DnsPacket::from_buffer(&mut recv_buffer)
}
async fn forward_doh(
query: &DnsPacket,
url: &str,

View File

@@ -447,6 +447,7 @@ async fn network_watch_loop(ctx: Arc<numa::ctx::ServerCtx>) {
info!("LAN IP changed: {} → {}", current_ip, new_ip);
*current_ip = new_ip;
changed = true;
numa::recursive::reset_udp_state();
}
}

View File

@@ -4,7 +4,6 @@ use std::sync::RwLock;
use std::time::Duration;
use log::{debug, info};
use tokio::time::timeout;
use crate::cache::DnsCache;
use crate::forward::forward_udp;
@@ -15,9 +14,13 @@ use crate::record::DnsRecord;
const MAX_REFERRAL_DEPTH: u8 = 10;
const MAX_CNAME_DEPTH: u8 = 8;
const NS_QUERY_TIMEOUT: Duration = Duration::from_secs(2);
const NS_QUERY_TIMEOUT: Duration = Duration::from_millis(800);
const TCP_TIMEOUT: Duration = Duration::from_millis(1500);
const UDP_FAIL_THRESHOLD: u8 = 3;
static QUERY_ID: AtomicU16 = AtomicU16::new(1);
static UDP_FAILURES: std::sync::atomic::AtomicU8 = std::sync::atomic::AtomicU8::new(0);
static UDP_DISABLED: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);
fn next_id() -> u16 {
QUERY_ID.fetch_add(1, Ordering::Relaxed)
@@ -27,17 +30,31 @@ fn dns_addr(ip: impl Into<IpAddr>) -> SocketAddr {
SocketAddr::new(ip.into(), 53)
}
/// Query root servers for common TLDs and cache NS + glue + DNSKEY + DS records.
/// Pre-warms the DNSSEC trust chain so first queries skip chain-walking I/O.
pub fn reset_udp_state() {
UDP_DISABLED.store(false, Ordering::Relaxed);
UDP_FAILURES.store(0, Ordering::Relaxed);
}
pub async fn prime_tld_cache(cache: &RwLock<DnsCache>, root_hints: &[SocketAddr], tlds: &[String]) {
let root_addr = match root_hints.first() {
Some(addr) => *addr,
None => return,
};
if tlds.is_empty() {
if root_hints.is_empty() || tlds.is_empty() {
return;
}
let mut root_addr = root_hints[0];
for hint in root_hints {
info!("prime: probing root {}", hint);
match send_query(".", QueryType::NS, *hint).await {
Ok(_) => {
info!("prime: root {} reachable", hint);
root_addr = *hint;
break;
}
Err(e) => {
info!("prime: root {} failed: {}, trying next", hint, e);
}
}
}
// Fetch root DNSKEY (needed for DNSSEC chain-of-trust terminus)
if let Ok(root_dnskey) = send_query(".", QueryType::DNSKEY, root_addr).await {
cache
@@ -98,19 +115,12 @@ pub async fn resolve_recursive(
qname: &str,
qtype: QueryType,
cache: &RwLock<DnsCache>,
overall_timeout: Duration,
original_query: &DnsPacket,
root_hints: &[SocketAddr],
) -> crate::Result<DnsPacket> {
let mut resp = match timeout(
overall_timeout,
resolve_iterative(qname, qtype, cache, root_hints, 0, 0),
)
.await
{
Ok(result) => result?,
Err(_) => return Err(format!("recursive resolution timed out for {}", qname).into()),
};
// No overall timeout — each hop is bounded by NS_QUERY_TIMEOUT (UDP + TCP fallback),
// and MAX_REFERRAL_DEPTH caps the chain length.
let mut resp = resolve_iterative(qname, qtype, cache, root_hints, 0, 0).await?;
resp.header.id = original_query.header.id;
resp.header.recursion_available = true;
@@ -136,7 +146,7 @@ pub(crate) fn resolve_iterative<'a>(
return Ok(cached);
}
let mut ns_addrs = find_starting_ns(qname, cache, root_hints);
let (mut current_zone, mut ns_addrs) = find_closest_ns(qname, cache, root_hints);
let mut ns_idx = 0;
for _ in 0..MAX_REFERRAL_DEPTH {
@@ -145,12 +155,14 @@ pub(crate) fn resolve_iterative<'a>(
None => return Err("no nameserver available".into()),
};
let (q_name, q_type) = minimize_query(qname, qtype, &current_zone);
debug!(
"recursive: querying {} for {:?} {} (depth {})",
ns_addr, qtype, qname, referral_depth
"recursive: querying {} for {:?} {} (zone: {}, depth {})",
ns_addr, q_type, q_name, current_zone, referral_depth
);
let response = match send_query(qname, qtype, ns_addr).await {
let response = match send_query(q_name, q_type, ns_addr).await {
Ok(r) => r,
Err(e) => {
debug!("recursive: NS {} failed: {}", ns_addr, e);
@@ -159,6 +171,27 @@ pub(crate) fn resolve_iterative<'a>(
}
};
// Minimized query response — treat as referral, not final answer
if (q_type != qtype || !q_name.eq_ignore_ascii_case(qname))
&& (!response.authorities.is_empty() || !response.answers.is_empty())
{
if let Some(zone) = referral_zone(&response) {
current_zone = zone;
}
let mut all_ns = extract_ns_from_records(&response.answers);
if all_ns.is_empty() {
all_ns = extract_ns_names(&response);
}
let new_addrs = resolve_ns_addrs_from_glue(&response, &all_ns, cache);
if !new_addrs.is_empty() {
ns_addrs = new_addrs;
ns_idx = 0;
continue;
}
ns_idx += 1;
continue;
}
if !response.answers.is_empty() {
let has_target = response.answers.iter().any(|r| r.query_type() == qtype);
@@ -201,32 +234,24 @@ pub(crate) fn resolve_iterative<'a>(
}
// Referral — extract NS + glue, cache glue, resolve NS addresses
// Update zone for query minimization
if let Some(zone) = referral_zone(&response) {
current_zone = zone;
}
let ns_names = extract_ns_names(&response);
if ns_names.is_empty() {
return Ok(response);
}
// Cache glue + DS from referral (avoids separate fetch during DNSSEC validation)
let mut new_ns_addrs = Vec::new();
{
let mut cache_w = cache.write().unwrap();
cache_glue(&mut cache_w, &response, &ns_names);
cache_ds_from_authority(&mut cache_w, &response);
}
for ns_name in &ns_names {
let glue = glue_addrs_for(&response, ns_name);
if !glue.is_empty() {
new_ns_addrs.extend_from_slice(&glue);
break;
}
}
let mut new_ns_addrs = resolve_ns_addrs_from_glue(&response, &ns_names, cache);
// If no glue, try cache (A then AAAA) then recursive resolve
if new_ns_addrs.is_empty() {
for ns_name in &ns_names {
new_ns_addrs.extend(addrs_from_cache(cache, ns_name));
if new_ns_addrs.is_empty() && referral_depth < MAX_REFERRAL_DEPTH {
if referral_depth < MAX_REFERRAL_DEPTH {
debug!("recursive: resolving glue-less NS {}", ns_name);
// Try A first, then AAAA
for qt in [QueryType::A, QueryType::AAAA] {
@@ -276,11 +301,13 @@ pub(crate) fn resolve_iterative<'a>(
})
}
fn find_starting_ns(
/// Find the closest cached NS zone and its resolved addresses.
/// Returns (zone_name, ns_addresses). Falls back to (".", root_hints).
fn find_closest_ns(
qname: &str,
cache: &RwLock<DnsCache>,
root_hints: &[SocketAddr],
) -> Vec<SocketAddr> {
) -> (String, Vec<SocketAddr>) {
let guard = cache.read().unwrap();
let mut pos = 0;
@@ -294,12 +321,8 @@ fn find_starting_ns(
if let Some(resp) = guard.lookup(host, qt) {
for rec in &resp.answers {
match rec {
DnsRecord::A { addr, .. } => {
addrs.push(dns_addr(*addr));
}
DnsRecord::AAAA { addr, .. } => {
addrs.push(dns_addr(*addr));
}
DnsRecord::A { addr, .. } => addrs.push(dns_addr(*addr)),
DnsRecord::AAAA { addr, .. } => addrs.push(dns_addr(*addr)),
_ => {}
}
}
@@ -309,7 +332,7 @@ fn find_starting_ns(
}
if !addrs.is_empty() {
debug!("recursive: starting from cached NS for zone '{}'", zone);
return addrs;
return (zone.to_string(), addrs);
}
}
@@ -324,7 +347,70 @@ fn find_starting_ns(
"recursive: starting from root hints ({} servers)",
root_hints.len()
);
root_hints.to_vec()
(".".to_string(), root_hints.to_vec())
}
/// Extract NS hostnames from any record section (answers or authorities).
fn extract_ns_from_records(records: &[DnsRecord]) -> Vec<String> {
records
.iter()
.filter_map(|r| match r {
DnsRecord::NS { host, .. } => Some(host.clone()),
_ => None,
})
.collect()
}
/// Resolve NS addresses from glue records, then cache fallback.
fn resolve_ns_addrs_from_glue(
response: &DnsPacket,
ns_names: &[String],
cache: &RwLock<DnsCache>,
) -> Vec<SocketAddr> {
let mut addrs = Vec::new();
{
let mut cache_w = cache.write().unwrap();
cache_glue(&mut cache_w, response, ns_names);
}
for ns_name in ns_names {
let glue = glue_addrs_for(response, ns_name);
if !glue.is_empty() {
addrs.extend_from_slice(&glue);
break;
}
}
if addrs.is_empty() {
for ns_name in ns_names {
addrs.extend(addrs_from_cache(cache, ns_name));
if !addrs.is_empty() {
break;
}
}
}
addrs
}
fn referral_zone(response: &DnsPacket) -> Option<String> {
response.authorities.iter().find_map(|r| match r {
DnsRecord::NS { domain, .. } => Some(domain.clone()),
_ => None,
})
}
/// RFC 7816 query minimization (conservative): only minimize at root.
fn minimize_query<'a>(
qname: &'a str,
qtype: QueryType,
current_zone: &str,
) -> (&'a str, QueryType) {
if current_zone != "." {
return (qname, qtype);
}
// At root: extract TLD (last label)
match qname.rfind('.') {
Some(dot) if dot > 0 => (&qname[dot + 1..], QueryType::NS),
_ => (qname, qtype),
}
}
fn addrs_from_cache(cache: &RwLock<DnsCache>, name: &str) -> Vec<SocketAddr> {
@@ -461,7 +547,40 @@ async fn send_query(qname: &str, qtype: QueryType, server: SocketAddr) -> crate:
do_bit: true,
..Default::default()
});
forward_udp(&query, server, NS_QUERY_TIMEOUT).await
// Skip IPv6 if the socket can't handle it (bound to 0.0.0.0)
if server.is_ipv6() {
return crate::forward::forward_tcp(&query, server, TCP_TIMEOUT).await;
}
// If UDP has been detected as blocked, go TCP-first
if UDP_DISABLED.load(Ordering::Relaxed) {
return crate::forward::forward_tcp(&query, server, TCP_TIMEOUT).await;
}
match forward_udp(&query, server, NS_QUERY_TIMEOUT).await {
Ok(resp) if resp.header.truncated_message => {
debug!("send_query: truncated from {}, retrying TCP", server);
crate::forward::forward_tcp(&query, server, TCP_TIMEOUT).await
}
Ok(resp) => {
// UDP works — reset failure counter
UDP_FAILURES.store(0, Ordering::Relaxed);
Ok(resp)
}
Err(e) => {
let fails = UDP_FAILURES.fetch_add(1, Ordering::Relaxed) + 1;
if fails >= UDP_FAIL_THRESHOLD && !UDP_DISABLED.load(Ordering::Relaxed) {
UDP_DISABLED.store(true, Ordering::Relaxed);
info!(
"send_query: {} consecutive UDP failures — switching to TCP-first",
fails
);
}
debug!("send_query: UDP failed for {}: {}, trying TCP", server, e);
crate::forward::forward_tcp(&query, server, TCP_TIMEOUT).await
}
}
}
fn extract_cname_target(response: &DnsPacket, qname: &str) -> Option<String> {
@@ -589,13 +708,216 @@ mod tests {
}
#[test]
fn find_starting_ns_falls_back_to_hints() {
fn find_closest_ns_falls_back_to_hints() {
let cache = RwLock::new(DnsCache::new(100, 60, 86400));
let hints = vec![
dns_addr(Ipv4Addr::new(198, 41, 0, 4)),
dns_addr(Ipv4Addr::new(199, 9, 14, 201)),
];
let addrs = find_starting_ns("example.com", &cache, &hints);
let (zone, addrs) = find_closest_ns("example.com", &cache, &hints);
assert_eq!(zone, ".");
assert_eq!(addrs, hints);
}
#[test]
fn minimize_query_from_root() {
// At root, only reveal TLD
let (name, qt) = minimize_query("www.example.com", QueryType::A, ".");
assert_eq!(name, "com");
assert_eq!(qt, QueryType::NS);
}
#[test]
fn minimize_query_beyond_root_sends_full() {
// Beyond root, send full query (conservative minimization)
let (name, qt) = minimize_query("www.example.com", QueryType::A, "com");
assert_eq!(name, "www.example.com");
assert_eq!(qt, QueryType::A);
let (name, qt) = minimize_query("www.example.com", QueryType::A, "example.com");
assert_eq!(name, "www.example.com");
assert_eq!(qt, QueryType::A);
}
#[test]
fn minimize_query_single_label() {
// Single label (e.g., "com") from root — send as-is
let (name, qt) = minimize_query("com", QueryType::NS, ".");
assert_eq!(name, "com");
assert_eq!(qt, QueryType::NS);
}
// ---- Mock DNS server (TCP-only) for fallback tests ----
use crate::buffer::BytePacketBuffer;
use crate::header::ResultCode;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
/// Spawn a TCP-only DNS server on localhost. Returns the address.
/// The handler receives each query and returns a response packet.
async fn spawn_tcp_dns_server(
handler: impl Fn(&DnsPacket) -> DnsPacket + Send + Sync + 'static,
) -> SocketAddr {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let handler = std::sync::Arc::new(handler);
tokio::spawn(async move {
loop {
let (mut stream, _) = match listener.accept().await {
Ok(c) => c,
Err(_) => break,
};
let handler = handler.clone();
tokio::spawn(async move {
// Read length-prefixed DNS query
let mut len_buf = [0u8; 2];
if stream.read_exact(&mut len_buf).await.is_err() {
return;
}
let len = u16::from_be_bytes(len_buf) as usize;
let mut data = vec![0u8; len];
if stream.read_exact(&mut data).await.is_err() {
return;
}
let mut buf = BytePacketBuffer::from_bytes(&data);
let query = match DnsPacket::from_buffer(&mut buf) {
Ok(q) => q,
Err(_) => return,
};
let response = handler(&query);
let mut resp_buf = BytePacketBuffer::new();
if response.write(&mut resp_buf).is_err() {
return;
}
let resp_bytes = resp_buf.filled();
let _ = stream
.write_all(&(resp_bytes.len() as u16).to_be_bytes())
.await;
let _ = stream.write_all(resp_bytes).await;
});
}
});
addr
}
/// TCP-only server returns authoritative answer directly.
/// Verifies: UDP fails → TCP fallback → resolves.
#[tokio::test]
async fn tcp_fallback_resolves_when_udp_blocked() {
UDP_DISABLED.store(false, Ordering::Relaxed);
UDP_FAILURES.store(0, Ordering::Relaxed);
let server_addr = spawn_tcp_dns_server(|query| {
let mut resp = DnsPacket::response_from(query, ResultCode::NOERROR);
resp.header.authoritative_answer = true;
if let Some(q) = query.questions.first() {
if q.qtype == QueryType::A || q.qtype == QueryType::NS {
resp.answers.push(DnsRecord::A {
domain: q.name.clone(),
addr: Ipv4Addr::new(10, 0, 0, 1),
ttl: 300,
});
}
}
resp
})
.await;
let result = send_query("test.example.com", QueryType::A, server_addr).await;
let resp = result.expect("should resolve via TCP fallback");
assert_eq!(resp.header.rescode, ResultCode::NOERROR);
assert!(!resp.answers.is_empty());
match &resp.answers[0] {
DnsRecord::A { addr, .. } => assert_eq!(*addr, Ipv4Addr::new(10, 0, 0, 1)),
other => panic!("expected A record, got {:?}", other),
}
}
/// Full iterative resolution through TCP-only mock: root referral → authoritative answer.
/// The mock plays both roles (returns referral for NS queries, answer for A queries).
#[tokio::test]
async fn tcp_only_iterative_resolution() {
UDP_DISABLED.store(true, Ordering::Relaxed); // Skip UDP entirely for speed
let server_addr = spawn_tcp_dns_server(|query| {
let q = match query.questions.first() {
Some(q) => q,
None => return DnsPacket::response_from(query, ResultCode::SERVFAIL),
};
if q.qtype == QueryType::NS || q.name == "com" {
// Return referral — NS points back to ourselves (same IP, port 53 in glue
// won't work, but cache will have our address from root_hints)
let mut resp = DnsPacket::new();
resp.header.id = query.header.id;
resp.header.response = true;
resp.header.rescode = ResultCode::NOERROR;
resp.questions = query.questions.clone();
resp.authorities.push(DnsRecord::NS {
domain: "com".into(),
host: "ns1.com".into(),
ttl: 3600,
});
resp
} else {
// Return authoritative answer
let mut resp = DnsPacket::response_from(query, ResultCode::NOERROR);
resp.header.authoritative_answer = true;
resp.answers.push(DnsRecord::A {
domain: q.name.clone(),
addr: Ipv4Addr::new(10, 0, 0, 42),
ttl: 300,
});
resp
}
})
.await;
let result = send_query("hello.example.com", QueryType::A, server_addr).await;
let resp = result.expect("TCP-only send_query should work");
assert_eq!(resp.header.rescode, ResultCode::NOERROR);
match &resp.answers[0] {
DnsRecord::A { addr, .. } => assert_eq!(*addr, Ipv4Addr::new(10, 0, 0, 42)),
other => panic!("expected A, got {:?}", other),
}
}
#[tokio::test]
async fn tcp_fallback_handles_nxdomain() {
UDP_DISABLED.store(false, Ordering::Relaxed);
UDP_FAILURES.store(0, Ordering::Relaxed);
let server_addr = spawn_tcp_dns_server(|query| {
let mut resp = DnsPacket::response_from(query, ResultCode::NXDOMAIN);
resp.header.authoritative_answer = true;
resp
})
.await;
let cache = RwLock::new(DnsCache::new(100, 60, 86400));
let root_hints = vec![server_addr];
let result =
resolve_iterative("nonexistent.test", QueryType::A, &cache, &root_hints, 0, 0).await;
let resp = result.expect("NXDOMAIN should still return a response");
assert_eq!(resp.header.rescode, ResultCode::NXDOMAIN);
assert!(resp.answers.is_empty());
}
#[tokio::test]
async fn udp_auto_disable_resets() {
UDP_DISABLED.store(true, Ordering::Relaxed);
UDP_FAILURES.store(5, Ordering::Relaxed);
reset_udp_state();
assert!(!UDP_DISABLED.load(Ordering::Relaxed));
assert_eq!(UDP_FAILURES.load(Ordering::Relaxed), 0);
}
}