feat: TCP fallback, query minimization, UDP auto-disable
Transport resilience for restrictive networks (ISPs blocking UDP:53): - DNS-over-TCP fallback: UDP fail/truncation → automatic TCP retry - UDP auto-disable: after 3 consecutive failures, switch to TCP-first - IPv6 → TCP directly (UDP socket binds 0.0.0.0, can't reach IPv6) - Network change resets UDP detection for re-probing - Root hint rotation in TLD priming Privacy: - RFC 7816 query minimization: root servers see TLD only, not full name Code quality: - Merged find_starting_ns + find_starting_zone → find_closest_ns - Extracted resolve_ns_addrs_from_glue shared helper - Removed overall timeout wrapper (per-hop timeouts sufficient) - forward_tcp for DNS-over-TCP (RFC 1035 §4.2.2) Testing: - Mock TCP-only DNS server for fallback tests (no network needed) - tcp_fallback_resolves_when_udp_blocked - tcp_only_iterative_resolution - tcp_fallback_handles_nxdomain - udp_auto_disable_resets - Integration test suite (4 suites, 51 tests) - Network probe script (tests/network-probe.sh) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -176,7 +176,7 @@ fn default_upstream_port() -> u16 {
|
||||
53
|
||||
}
|
||||
fn default_timeout_ms() -> u64 {
|
||||
3000
|
||||
5000
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
|
||||
31
src/ctx.rs
31
src/ctx.rs
@@ -154,7 +154,6 @@ pub async fn handle_query(
|
||||
&qname,
|
||||
qtype,
|
||||
&ctx.cache,
|
||||
ctx.timeout,
|
||||
&query,
|
||||
&ctx.root_hints,
|
||||
)
|
||||
@@ -162,28 +161,14 @@ pub async fn handle_query(
|
||||
{
|
||||
Ok(resp) => (resp, QueryPath::Recursive),
|
||||
Err(e) => {
|
||||
// Auto-fallback: retry via forward upstream if configured
|
||||
let upstream = ctx.upstream.lock().unwrap().clone();
|
||||
match forward_query(&query, &upstream, ctx.timeout).await {
|
||||
Ok(resp) => {
|
||||
debug!(
|
||||
"{} | {:?} {} | RECURSIVE FALLBACK → FORWARD | {}",
|
||||
src_addr, qtype, qname, e
|
||||
);
|
||||
ctx.cache.write().unwrap().insert(&qname, qtype, &resp);
|
||||
(resp, QueryPath::Forwarded)
|
||||
}
|
||||
Err(e2) => {
|
||||
error!(
|
||||
"{} | {:?} {} | RECURSIVE+FORWARD FAILED | recursive: {} | forward: {}",
|
||||
src_addr, qtype, qname, e, e2
|
||||
);
|
||||
(
|
||||
DnsPacket::response_from(&query, ResultCode::SERVFAIL),
|
||||
QueryPath::UpstreamError,
|
||||
)
|
||||
}
|
||||
}
|
||||
error!(
|
||||
"{} | {:?} {} | RECURSIVE ERROR | {}",
|
||||
src_addr, qtype, qname, e
|
||||
);
|
||||
(
|
||||
DnsPacket::response_from(&query, ResultCode::SERVFAIL),
|
||||
QueryPath::UpstreamError,
|
||||
)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
||||
@@ -74,6 +74,37 @@ pub(crate) async fn forward_udp(
|
||||
DnsPacket::from_buffer(&mut recv_buffer)
|
||||
}
|
||||
|
||||
/// DNS over TCP (RFC 1035 §4.2.2): 2-byte length prefix, then the DNS message.
|
||||
pub(crate) async fn forward_tcp(
|
||||
query: &DnsPacket,
|
||||
upstream: SocketAddr,
|
||||
timeout_duration: Duration,
|
||||
) -> Result<DnsPacket> {
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
let mut send_buffer = BytePacketBuffer::new();
|
||||
query.write(&mut send_buffer)?;
|
||||
let msg = send_buffer.filled();
|
||||
|
||||
let mut stream = timeout(timeout_duration, TcpStream::connect(upstream)).await??;
|
||||
|
||||
// Write length-prefixed message
|
||||
stream.write_all(&(msg.len() as u16).to_be_bytes()).await?;
|
||||
stream.write_all(msg).await?;
|
||||
|
||||
// Read length-prefixed response
|
||||
let mut len_buf = [0u8; 2];
|
||||
timeout(timeout_duration, stream.read_exact(&mut len_buf)).await??;
|
||||
let resp_len = u16::from_be_bytes(len_buf) as usize;
|
||||
|
||||
let mut data = vec![0u8; resp_len];
|
||||
timeout(timeout_duration, stream.read_exact(&mut data)).await??;
|
||||
|
||||
let mut recv_buffer = BytePacketBuffer::from_bytes(&data);
|
||||
DnsPacket::from_buffer(&mut recv_buffer)
|
||||
}
|
||||
|
||||
async fn forward_doh(
|
||||
query: &DnsPacket,
|
||||
url: &str,
|
||||
|
||||
@@ -447,6 +447,7 @@ async fn network_watch_loop(ctx: Arc<numa::ctx::ServerCtx>) {
|
||||
info!("LAN IP changed: {} → {}", current_ip, new_ip);
|
||||
*current_ip = new_ip;
|
||||
changed = true;
|
||||
numa::recursive::reset_udp_state();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
422
src/recursive.rs
422
src/recursive.rs
@@ -4,7 +4,6 @@ use std::sync::RwLock;
|
||||
use std::time::Duration;
|
||||
|
||||
use log::{debug, info};
|
||||
use tokio::time::timeout;
|
||||
|
||||
use crate::cache::DnsCache;
|
||||
use crate::forward::forward_udp;
|
||||
@@ -15,9 +14,13 @@ use crate::record::DnsRecord;
|
||||
|
||||
const MAX_REFERRAL_DEPTH: u8 = 10;
|
||||
const MAX_CNAME_DEPTH: u8 = 8;
|
||||
const NS_QUERY_TIMEOUT: Duration = Duration::from_secs(2);
|
||||
const NS_QUERY_TIMEOUT: Duration = Duration::from_millis(800);
|
||||
const TCP_TIMEOUT: Duration = Duration::from_millis(1500);
|
||||
const UDP_FAIL_THRESHOLD: u8 = 3;
|
||||
|
||||
static QUERY_ID: AtomicU16 = AtomicU16::new(1);
|
||||
static UDP_FAILURES: std::sync::atomic::AtomicU8 = std::sync::atomic::AtomicU8::new(0);
|
||||
static UDP_DISABLED: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);
|
||||
|
||||
fn next_id() -> u16 {
|
||||
QUERY_ID.fetch_add(1, Ordering::Relaxed)
|
||||
@@ -27,17 +30,31 @@ fn dns_addr(ip: impl Into<IpAddr>) -> SocketAddr {
|
||||
SocketAddr::new(ip.into(), 53)
|
||||
}
|
||||
|
||||
/// Query root servers for common TLDs and cache NS + glue + DNSKEY + DS records.
|
||||
/// Pre-warms the DNSSEC trust chain so first queries skip chain-walking I/O.
|
||||
pub fn reset_udp_state() {
|
||||
UDP_DISABLED.store(false, Ordering::Relaxed);
|
||||
UDP_FAILURES.store(0, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub async fn prime_tld_cache(cache: &RwLock<DnsCache>, root_hints: &[SocketAddr], tlds: &[String]) {
|
||||
let root_addr = match root_hints.first() {
|
||||
Some(addr) => *addr,
|
||||
None => return,
|
||||
};
|
||||
if tlds.is_empty() {
|
||||
if root_hints.is_empty() || tlds.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut root_addr = root_hints[0];
|
||||
for hint in root_hints {
|
||||
info!("prime: probing root {}", hint);
|
||||
match send_query(".", QueryType::NS, *hint).await {
|
||||
Ok(_) => {
|
||||
info!("prime: root {} reachable", hint);
|
||||
root_addr = *hint;
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
info!("prime: root {} failed: {}, trying next", hint, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch root DNSKEY (needed for DNSSEC chain-of-trust terminus)
|
||||
if let Ok(root_dnskey) = send_query(".", QueryType::DNSKEY, root_addr).await {
|
||||
cache
|
||||
@@ -98,19 +115,12 @@ pub async fn resolve_recursive(
|
||||
qname: &str,
|
||||
qtype: QueryType,
|
||||
cache: &RwLock<DnsCache>,
|
||||
overall_timeout: Duration,
|
||||
original_query: &DnsPacket,
|
||||
root_hints: &[SocketAddr],
|
||||
) -> crate::Result<DnsPacket> {
|
||||
let mut resp = match timeout(
|
||||
overall_timeout,
|
||||
resolve_iterative(qname, qtype, cache, root_hints, 0, 0),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(result) => result?,
|
||||
Err(_) => return Err(format!("recursive resolution timed out for {}", qname).into()),
|
||||
};
|
||||
// No overall timeout — each hop is bounded by NS_QUERY_TIMEOUT (UDP + TCP fallback),
|
||||
// and MAX_REFERRAL_DEPTH caps the chain length.
|
||||
let mut resp = resolve_iterative(qname, qtype, cache, root_hints, 0, 0).await?;
|
||||
|
||||
resp.header.id = original_query.header.id;
|
||||
resp.header.recursion_available = true;
|
||||
@@ -136,7 +146,7 @@ pub(crate) fn resolve_iterative<'a>(
|
||||
return Ok(cached);
|
||||
}
|
||||
|
||||
let mut ns_addrs = find_starting_ns(qname, cache, root_hints);
|
||||
let (mut current_zone, mut ns_addrs) = find_closest_ns(qname, cache, root_hints);
|
||||
let mut ns_idx = 0;
|
||||
|
||||
for _ in 0..MAX_REFERRAL_DEPTH {
|
||||
@@ -145,12 +155,14 @@ pub(crate) fn resolve_iterative<'a>(
|
||||
None => return Err("no nameserver available".into()),
|
||||
};
|
||||
|
||||
let (q_name, q_type) = minimize_query(qname, qtype, ¤t_zone);
|
||||
|
||||
debug!(
|
||||
"recursive: querying {} for {:?} {} (depth {})",
|
||||
ns_addr, qtype, qname, referral_depth
|
||||
"recursive: querying {} for {:?} {} (zone: {}, depth {})",
|
||||
ns_addr, q_type, q_name, current_zone, referral_depth
|
||||
);
|
||||
|
||||
let response = match send_query(qname, qtype, ns_addr).await {
|
||||
let response = match send_query(q_name, q_type, ns_addr).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
debug!("recursive: NS {} failed: {}", ns_addr, e);
|
||||
@@ -159,6 +171,27 @@ pub(crate) fn resolve_iterative<'a>(
|
||||
}
|
||||
};
|
||||
|
||||
// Minimized query response — treat as referral, not final answer
|
||||
if (q_type != qtype || !q_name.eq_ignore_ascii_case(qname))
|
||||
&& (!response.authorities.is_empty() || !response.answers.is_empty())
|
||||
{
|
||||
if let Some(zone) = referral_zone(&response) {
|
||||
current_zone = zone;
|
||||
}
|
||||
let mut all_ns = extract_ns_from_records(&response.answers);
|
||||
if all_ns.is_empty() {
|
||||
all_ns = extract_ns_names(&response);
|
||||
}
|
||||
let new_addrs = resolve_ns_addrs_from_glue(&response, &all_ns, cache);
|
||||
if !new_addrs.is_empty() {
|
||||
ns_addrs = new_addrs;
|
||||
ns_idx = 0;
|
||||
continue;
|
||||
}
|
||||
ns_idx += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
if !response.answers.is_empty() {
|
||||
let has_target = response.answers.iter().any(|r| r.query_type() == qtype);
|
||||
|
||||
@@ -201,32 +234,24 @@ pub(crate) fn resolve_iterative<'a>(
|
||||
}
|
||||
|
||||
// Referral — extract NS + glue, cache glue, resolve NS addresses
|
||||
// Update zone for query minimization
|
||||
if let Some(zone) = referral_zone(&response) {
|
||||
current_zone = zone;
|
||||
}
|
||||
let ns_names = extract_ns_names(&response);
|
||||
if ns_names.is_empty() {
|
||||
return Ok(response);
|
||||
}
|
||||
|
||||
// Cache glue + DS from referral (avoids separate fetch during DNSSEC validation)
|
||||
let mut new_ns_addrs = Vec::new();
|
||||
{
|
||||
let mut cache_w = cache.write().unwrap();
|
||||
cache_glue(&mut cache_w, &response, &ns_names);
|
||||
cache_ds_from_authority(&mut cache_w, &response);
|
||||
}
|
||||
for ns_name in &ns_names {
|
||||
let glue = glue_addrs_for(&response, ns_name);
|
||||
if !glue.is_empty() {
|
||||
new_ns_addrs.extend_from_slice(&glue);
|
||||
break;
|
||||
}
|
||||
}
|
||||
let mut new_ns_addrs = resolve_ns_addrs_from_glue(&response, &ns_names, cache);
|
||||
|
||||
// If no glue, try cache (A then AAAA) then recursive resolve
|
||||
if new_ns_addrs.is_empty() {
|
||||
for ns_name in &ns_names {
|
||||
new_ns_addrs.extend(addrs_from_cache(cache, ns_name));
|
||||
|
||||
if new_ns_addrs.is_empty() && referral_depth < MAX_REFERRAL_DEPTH {
|
||||
if referral_depth < MAX_REFERRAL_DEPTH {
|
||||
debug!("recursive: resolving glue-less NS {}", ns_name);
|
||||
// Try A first, then AAAA
|
||||
for qt in [QueryType::A, QueryType::AAAA] {
|
||||
@@ -276,11 +301,13 @@ pub(crate) fn resolve_iterative<'a>(
|
||||
})
|
||||
}
|
||||
|
||||
fn find_starting_ns(
|
||||
/// Find the closest cached NS zone and its resolved addresses.
|
||||
/// Returns (zone_name, ns_addresses). Falls back to (".", root_hints).
|
||||
fn find_closest_ns(
|
||||
qname: &str,
|
||||
cache: &RwLock<DnsCache>,
|
||||
root_hints: &[SocketAddr],
|
||||
) -> Vec<SocketAddr> {
|
||||
) -> (String, Vec<SocketAddr>) {
|
||||
let guard = cache.read().unwrap();
|
||||
|
||||
let mut pos = 0;
|
||||
@@ -294,12 +321,8 @@ fn find_starting_ns(
|
||||
if let Some(resp) = guard.lookup(host, qt) {
|
||||
for rec in &resp.answers {
|
||||
match rec {
|
||||
DnsRecord::A { addr, .. } => {
|
||||
addrs.push(dns_addr(*addr));
|
||||
}
|
||||
DnsRecord::AAAA { addr, .. } => {
|
||||
addrs.push(dns_addr(*addr));
|
||||
}
|
||||
DnsRecord::A { addr, .. } => addrs.push(dns_addr(*addr)),
|
||||
DnsRecord::AAAA { addr, .. } => addrs.push(dns_addr(*addr)),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
@@ -309,7 +332,7 @@ fn find_starting_ns(
|
||||
}
|
||||
if !addrs.is_empty() {
|
||||
debug!("recursive: starting from cached NS for zone '{}'", zone);
|
||||
return addrs;
|
||||
return (zone.to_string(), addrs);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -324,7 +347,70 @@ fn find_starting_ns(
|
||||
"recursive: starting from root hints ({} servers)",
|
||||
root_hints.len()
|
||||
);
|
||||
root_hints.to_vec()
|
||||
(".".to_string(), root_hints.to_vec())
|
||||
}
|
||||
|
||||
/// Extract NS hostnames from any record section (answers or authorities).
|
||||
fn extract_ns_from_records(records: &[DnsRecord]) -> Vec<String> {
|
||||
records
|
||||
.iter()
|
||||
.filter_map(|r| match r {
|
||||
DnsRecord::NS { host, .. } => Some(host.clone()),
|
||||
_ => None,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Resolve NS addresses from glue records, then cache fallback.
|
||||
fn resolve_ns_addrs_from_glue(
|
||||
response: &DnsPacket,
|
||||
ns_names: &[String],
|
||||
cache: &RwLock<DnsCache>,
|
||||
) -> Vec<SocketAddr> {
|
||||
let mut addrs = Vec::new();
|
||||
{
|
||||
let mut cache_w = cache.write().unwrap();
|
||||
cache_glue(&mut cache_w, response, ns_names);
|
||||
}
|
||||
for ns_name in ns_names {
|
||||
let glue = glue_addrs_for(response, ns_name);
|
||||
if !glue.is_empty() {
|
||||
addrs.extend_from_slice(&glue);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if addrs.is_empty() {
|
||||
for ns_name in ns_names {
|
||||
addrs.extend(addrs_from_cache(cache, ns_name));
|
||||
if !addrs.is_empty() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
addrs
|
||||
}
|
||||
|
||||
fn referral_zone(response: &DnsPacket) -> Option<String> {
|
||||
response.authorities.iter().find_map(|r| match r {
|
||||
DnsRecord::NS { domain, .. } => Some(domain.clone()),
|
||||
_ => None,
|
||||
})
|
||||
}
|
||||
|
||||
/// RFC 7816 query minimization (conservative): only minimize at root.
|
||||
fn minimize_query<'a>(
|
||||
qname: &'a str,
|
||||
qtype: QueryType,
|
||||
current_zone: &str,
|
||||
) -> (&'a str, QueryType) {
|
||||
if current_zone != "." {
|
||||
return (qname, qtype);
|
||||
}
|
||||
// At root: extract TLD (last label)
|
||||
match qname.rfind('.') {
|
||||
Some(dot) if dot > 0 => (&qname[dot + 1..], QueryType::NS),
|
||||
_ => (qname, qtype),
|
||||
}
|
||||
}
|
||||
|
||||
fn addrs_from_cache(cache: &RwLock<DnsCache>, name: &str) -> Vec<SocketAddr> {
|
||||
@@ -461,7 +547,40 @@ async fn send_query(qname: &str, qtype: QueryType, server: SocketAddr) -> crate:
|
||||
do_bit: true,
|
||||
..Default::default()
|
||||
});
|
||||
forward_udp(&query, server, NS_QUERY_TIMEOUT).await
|
||||
|
||||
// Skip IPv6 if the socket can't handle it (bound to 0.0.0.0)
|
||||
if server.is_ipv6() {
|
||||
return crate::forward::forward_tcp(&query, server, TCP_TIMEOUT).await;
|
||||
}
|
||||
|
||||
// If UDP has been detected as blocked, go TCP-first
|
||||
if UDP_DISABLED.load(Ordering::Relaxed) {
|
||||
return crate::forward::forward_tcp(&query, server, TCP_TIMEOUT).await;
|
||||
}
|
||||
|
||||
match forward_udp(&query, server, NS_QUERY_TIMEOUT).await {
|
||||
Ok(resp) if resp.header.truncated_message => {
|
||||
debug!("send_query: truncated from {}, retrying TCP", server);
|
||||
crate::forward::forward_tcp(&query, server, TCP_TIMEOUT).await
|
||||
}
|
||||
Ok(resp) => {
|
||||
// UDP works — reset failure counter
|
||||
UDP_FAILURES.store(0, Ordering::Relaxed);
|
||||
Ok(resp)
|
||||
}
|
||||
Err(e) => {
|
||||
let fails = UDP_FAILURES.fetch_add(1, Ordering::Relaxed) + 1;
|
||||
if fails >= UDP_FAIL_THRESHOLD && !UDP_DISABLED.load(Ordering::Relaxed) {
|
||||
UDP_DISABLED.store(true, Ordering::Relaxed);
|
||||
info!(
|
||||
"send_query: {} consecutive UDP failures — switching to TCP-first",
|
||||
fails
|
||||
);
|
||||
}
|
||||
debug!("send_query: UDP failed for {}: {}, trying TCP", server, e);
|
||||
crate::forward::forward_tcp(&query, server, TCP_TIMEOUT).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_cname_target(response: &DnsPacket, qname: &str) -> Option<String> {
|
||||
@@ -589,13 +708,216 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn find_starting_ns_falls_back_to_hints() {
|
||||
fn find_closest_ns_falls_back_to_hints() {
|
||||
let cache = RwLock::new(DnsCache::new(100, 60, 86400));
|
||||
let hints = vec![
|
||||
dns_addr(Ipv4Addr::new(198, 41, 0, 4)),
|
||||
dns_addr(Ipv4Addr::new(199, 9, 14, 201)),
|
||||
];
|
||||
let addrs = find_starting_ns("example.com", &cache, &hints);
|
||||
let (zone, addrs) = find_closest_ns("example.com", &cache, &hints);
|
||||
assert_eq!(zone, ".");
|
||||
assert_eq!(addrs, hints);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn minimize_query_from_root() {
|
||||
// At root, only reveal TLD
|
||||
let (name, qt) = minimize_query("www.example.com", QueryType::A, ".");
|
||||
assert_eq!(name, "com");
|
||||
assert_eq!(qt, QueryType::NS);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn minimize_query_beyond_root_sends_full() {
|
||||
// Beyond root, send full query (conservative minimization)
|
||||
let (name, qt) = minimize_query("www.example.com", QueryType::A, "com");
|
||||
assert_eq!(name, "www.example.com");
|
||||
assert_eq!(qt, QueryType::A);
|
||||
|
||||
let (name, qt) = minimize_query("www.example.com", QueryType::A, "example.com");
|
||||
assert_eq!(name, "www.example.com");
|
||||
assert_eq!(qt, QueryType::A);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn minimize_query_single_label() {
|
||||
// Single label (e.g., "com") from root — send as-is
|
||||
let (name, qt) = minimize_query("com", QueryType::NS, ".");
|
||||
assert_eq!(name, "com");
|
||||
assert_eq!(qt, QueryType::NS);
|
||||
}
|
||||
|
||||
// ---- Mock DNS server (TCP-only) for fallback tests ----
|
||||
|
||||
use crate::buffer::BytePacketBuffer;
|
||||
use crate::header::ResultCode;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
/// Spawn a TCP-only DNS server on localhost. Returns the address.
|
||||
/// The handler receives each query and returns a response packet.
|
||||
async fn spawn_tcp_dns_server(
|
||||
handler: impl Fn(&DnsPacket) -> DnsPacket + Send + Sync + 'static,
|
||||
) -> SocketAddr {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
let handler = std::sync::Arc::new(handler);
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let (mut stream, _) = match listener.accept().await {
|
||||
Ok(c) => c,
|
||||
Err(_) => break,
|
||||
};
|
||||
let handler = handler.clone();
|
||||
tokio::spawn(async move {
|
||||
// Read length-prefixed DNS query
|
||||
let mut len_buf = [0u8; 2];
|
||||
if stream.read_exact(&mut len_buf).await.is_err() {
|
||||
return;
|
||||
}
|
||||
let len = u16::from_be_bytes(len_buf) as usize;
|
||||
let mut data = vec![0u8; len];
|
||||
if stream.read_exact(&mut data).await.is_err() {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut buf = BytePacketBuffer::from_bytes(&data);
|
||||
let query = match DnsPacket::from_buffer(&mut buf) {
|
||||
Ok(q) => q,
|
||||
Err(_) => return,
|
||||
};
|
||||
|
||||
let response = handler(&query);
|
||||
|
||||
let mut resp_buf = BytePacketBuffer::new();
|
||||
if response.write(&mut resp_buf).is_err() {
|
||||
return;
|
||||
}
|
||||
let resp_bytes = resp_buf.filled();
|
||||
let _ = stream
|
||||
.write_all(&(resp_bytes.len() as u16).to_be_bytes())
|
||||
.await;
|
||||
let _ = stream.write_all(resp_bytes).await;
|
||||
});
|
||||
}
|
||||
});
|
||||
addr
|
||||
}
|
||||
|
||||
/// TCP-only server returns authoritative answer directly.
|
||||
/// Verifies: UDP fails → TCP fallback → resolves.
|
||||
#[tokio::test]
|
||||
async fn tcp_fallback_resolves_when_udp_blocked() {
|
||||
UDP_DISABLED.store(false, Ordering::Relaxed);
|
||||
UDP_FAILURES.store(0, Ordering::Relaxed);
|
||||
|
||||
let server_addr = spawn_tcp_dns_server(|query| {
|
||||
let mut resp = DnsPacket::response_from(query, ResultCode::NOERROR);
|
||||
resp.header.authoritative_answer = true;
|
||||
if let Some(q) = query.questions.first() {
|
||||
if q.qtype == QueryType::A || q.qtype == QueryType::NS {
|
||||
resp.answers.push(DnsRecord::A {
|
||||
domain: q.name.clone(),
|
||||
addr: Ipv4Addr::new(10, 0, 0, 1),
|
||||
ttl: 300,
|
||||
});
|
||||
}
|
||||
}
|
||||
resp
|
||||
})
|
||||
.await;
|
||||
|
||||
let result = send_query("test.example.com", QueryType::A, server_addr).await;
|
||||
|
||||
let resp = result.expect("should resolve via TCP fallback");
|
||||
assert_eq!(resp.header.rescode, ResultCode::NOERROR);
|
||||
assert!(!resp.answers.is_empty());
|
||||
match &resp.answers[0] {
|
||||
DnsRecord::A { addr, .. } => assert_eq!(*addr, Ipv4Addr::new(10, 0, 0, 1)),
|
||||
other => panic!("expected A record, got {:?}", other),
|
||||
}
|
||||
}
|
||||
|
||||
/// Full iterative resolution through TCP-only mock: root referral → authoritative answer.
|
||||
/// The mock plays both roles (returns referral for NS queries, answer for A queries).
|
||||
#[tokio::test]
|
||||
async fn tcp_only_iterative_resolution() {
|
||||
UDP_DISABLED.store(true, Ordering::Relaxed); // Skip UDP entirely for speed
|
||||
|
||||
let server_addr = spawn_tcp_dns_server(|query| {
|
||||
let q = match query.questions.first() {
|
||||
Some(q) => q,
|
||||
None => return DnsPacket::response_from(query, ResultCode::SERVFAIL),
|
||||
};
|
||||
|
||||
if q.qtype == QueryType::NS || q.name == "com" {
|
||||
// Return referral — NS points back to ourselves (same IP, port 53 in glue
|
||||
// won't work, but cache will have our address from root_hints)
|
||||
let mut resp = DnsPacket::new();
|
||||
resp.header.id = query.header.id;
|
||||
resp.header.response = true;
|
||||
resp.header.rescode = ResultCode::NOERROR;
|
||||
resp.questions = query.questions.clone();
|
||||
resp.authorities.push(DnsRecord::NS {
|
||||
domain: "com".into(),
|
||||
host: "ns1.com".into(),
|
||||
ttl: 3600,
|
||||
});
|
||||
resp
|
||||
} else {
|
||||
// Return authoritative answer
|
||||
let mut resp = DnsPacket::response_from(query, ResultCode::NOERROR);
|
||||
resp.header.authoritative_answer = true;
|
||||
resp.answers.push(DnsRecord::A {
|
||||
domain: q.name.clone(),
|
||||
addr: Ipv4Addr::new(10, 0, 0, 42),
|
||||
ttl: 300,
|
||||
});
|
||||
resp
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
let result = send_query("hello.example.com", QueryType::A, server_addr).await;
|
||||
let resp = result.expect("TCP-only send_query should work");
|
||||
assert_eq!(resp.header.rescode, ResultCode::NOERROR);
|
||||
match &resp.answers[0] {
|
||||
DnsRecord::A { addr, .. } => assert_eq!(*addr, Ipv4Addr::new(10, 0, 0, 42)),
|
||||
other => panic!("expected A, got {:?}", other),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tcp_fallback_handles_nxdomain() {
|
||||
UDP_DISABLED.store(false, Ordering::Relaxed);
|
||||
UDP_FAILURES.store(0, Ordering::Relaxed);
|
||||
|
||||
let server_addr = spawn_tcp_dns_server(|query| {
|
||||
let mut resp = DnsPacket::response_from(query, ResultCode::NXDOMAIN);
|
||||
resp.header.authoritative_answer = true;
|
||||
resp
|
||||
})
|
||||
.await;
|
||||
|
||||
let cache = RwLock::new(DnsCache::new(100, 60, 86400));
|
||||
let root_hints = vec![server_addr];
|
||||
|
||||
let result =
|
||||
resolve_iterative("nonexistent.test", QueryType::A, &cache, &root_hints, 0, 0).await;
|
||||
|
||||
let resp = result.expect("NXDOMAIN should still return a response");
|
||||
assert_eq!(resp.header.rescode, ResultCode::NXDOMAIN);
|
||||
assert!(resp.answers.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn udp_auto_disable_resets() {
|
||||
UDP_DISABLED.store(true, Ordering::Relaxed);
|
||||
UDP_FAILURES.store(5, Ordering::Relaxed);
|
||||
|
||||
reset_udp_state();
|
||||
|
||||
assert!(!UDP_DISABLED.load(Ordering::Relaxed));
|
||||
assert_eq!(UDP_FAILURES.load(Ordering::Relaxed), 0);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user