fix: parse DoT queries up-front and echo question in SERVFAIL

Address review findings on PR #25:

- Refactor resolve_query to take a pre-parsed DnsPacket. Parse-error
  handling moves to the UDP caller, eliminating the double warn! line
  on malformed UDP queries.
- Enforce MIN_MSG_LEN=12 (DNS header) in handle_dot_connection so
  query_id extraction is always reading client-sent bytes, not the
  zeroed buffer tail.
- Parse the DoT query before calling resolve_query and retain it, so
  SERVFAIL responses can echo the original question section via
  response_from(). Parse failures send FORMERR with the client id.
- Extract write_framed() helper for length-prefix + flush, reused by
  success, SERVFAIL, and FORMERR paths.
- Back off 100ms on listener.accept() errors to avoid tight-looping
  on fd exhaustion.
- Replace the hardcoded 127.0.0.1:53 upstream in dot_nxdomain_for_unknown
  with a bound-but-unresponsive UDP socket owned by the test, making it
  independent of the host's local resolver. Test now runs in ~220ms
  (timeout lowered to 200ms) instead of 3s and asserts the question is
  echoed in the SERVFAIL response.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Razvan Dimescu
2026-04-07 16:47:54 +03:00
parent b78cd44c99
commit 7be923e6d2
2 changed files with 76 additions and 34 deletions

View File

@@ -19,9 +19,12 @@ use crate::packet::DnsPacket;
const MAX_CONNECTIONS: usize = 512;
const IDLE_TIMEOUT: Duration = Duration::from_secs(30);
const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
const ACCEPT_ERROR_BACKOFF: Duration = Duration::from_millis(100);
// Matches BytePacketBuffer::BUF_SIZE — RFC 7858 allows up to 65535 but our
// buffer would silently truncate anything larger.
const MAX_MSG_LEN: usize = 4096;
// DNS header is 12 bytes; anything shorter cannot be a valid query.
const MIN_MSG_LEN: usize = 12;
/// Build a TLS ServerConfig for DoT from user-provided cert/key PEM files.
fn load_tls_config(cert_path: &Path, key_path: &Path) -> crate::Result<Arc<ServerConfig>> {
@@ -103,6 +106,8 @@ async fn accept_loop(listener: TcpListener, acceptor: TlsAcceptor, ctx: Arc<Serv
Ok(conn) => conn,
Err(e) => {
error!("DoT: TCP accept error: {}", e);
// Back off to avoid tight-looping on persistent failures (e.g. fd exhaustion).
tokio::time::sleep(ACCEPT_ERROR_BACKOFF).await;
continue;
}
};
@@ -153,7 +158,7 @@ where
Err(_) => break, // idle timeout
}
let msg_len = u16::from_be_bytes(len_buf) as usize;
if msg_len == 0 || msg_len > MAX_MSG_LEN {
if !(MIN_MSG_LEN..=MAX_MSG_LEN).contains(&msg_len) {
debug!(
"DoT: invalid message length {} from {}",
msg_len, remote_addr
@@ -173,37 +178,66 @@ where
}
}
let query_id = u16::from_be_bytes([buffer.buf[0], buffer.buf[1]]);
let resp_buffer = match resolve_query(buffer, remote_addr, ctx).await {
Ok(buf) => buf,
// Parse query up-front so we can echo its question section in SERVFAIL
// responses when resolve_query fails.
let query = match DnsPacket::from_buffer(&mut buffer) {
Ok(q) => q,
Err(e) => {
debug!("DoT: resolve error from {}: {}", remote_addr, e);
// Send SERVFAIL so the client doesn't hang
warn!("{} | PARSE ERROR | {}", remote_addr, e);
// msg_len >= MIN_MSG_LEN guarantees buf[0..2] is the client's query id.
let query_id = u16::from_be_bytes([buffer.buf[0], buffer.buf[1]]);
let mut resp = DnsPacket::new();
resp.header.id = query_id;
resp.header.response = true;
resp.header.rescode = ResultCode::SERVFAIL;
let mut buf = BytePacketBuffer::new();
if resp.write(&mut buf).is_err() {
resp.header.rescode = ResultCode::FORMERR;
let mut out_buf = BytePacketBuffer::new();
if resp.write(&mut out_buf).is_err() {
debug!("DoT: failed to serialize FORMERR for {}", remote_addr);
break;
}
if write_framed(&mut stream, out_buf.filled()).await.is_err() {
break;
}
continue;
}
};
let resp_buffer = match resolve_query(query.clone(), remote_addr, ctx).await {
Ok(buf) => buf,
Err(e) => {
warn!("{} | RESOLVE ERROR | {}", remote_addr, e);
// Build SERVFAIL that echoes the original question section.
let resp = DnsPacket::response_from(&query, ResultCode::SERVFAIL);
let mut out_buf = BytePacketBuffer::new();
if resp.write(&mut out_buf).is_err() {
debug!("DoT: failed to serialize SERVFAIL for {}", remote_addr);
break;
}
buf
out_buf
}
};
let resp = resp_buffer.filled();
let mut out = Vec::with_capacity(2 + resp.len());
out.extend_from_slice(&(resp.len() as u16).to_be_bytes());
out.extend_from_slice(resp);
if stream.write_all(&out).await.is_err() {
break;
}
if stream.flush().await.is_err() {
if write_framed(&mut stream, resp_buffer.filled())
.await
.is_err()
{
break;
}
}
}
/// Write a DNS message with its 2-byte length prefix, coalesced into one syscall.
async fn write_framed<S>(stream: &mut S, msg: &[u8]) -> std::io::Result<()>
where
S: AsyncWriteExt + Unpin,
{
let mut out = Vec::with_capacity(2 + msg.len());
out.extend_from_slice(&(msg.len() as u16).to_be_bytes());
out.extend_from_slice(msg);
stream.write_all(&out).await?;
stream.flush().await?;
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
@@ -250,10 +284,16 @@ mod tests {
}
/// Spin up a DoT listener with a test TLS config. Returns (addr, client_config).
/// The upstream is pointed at a bound-but-unresponsive UDP socket we own, so
/// any query that escapes to the upstream path times out deterministically
/// (SERVFAIL) regardless of what the host has running on port 53.
async fn spawn_dot_server() -> (SocketAddr, Arc<rustls::ClientConfig>) {
let (server_tls, client_tls) = test_tls_configs();
let socket = tokio::net::UdpSocket::bind("127.0.0.1:0").await.unwrap();
// Bind an unresponsive upstream and leak it so it lives for the test duration.
let blackhole = Box::leak(Box::new(std::net::UdpSocket::bind("127.0.0.1:0").unwrap()));
let upstream_addr = blackhole.local_addr().unwrap();
let ctx = Arc::new(ServerCtx {
socket,
zone_map: {
@@ -278,13 +318,11 @@ mod tests {
services: Mutex::new(crate::service_store::ServiceStore::new()),
lan_peers: Mutex::new(crate::lan::PeerStore::new(90)),
forwarding_rules: Vec::new(),
upstream: Mutex::new(crate::forward::Upstream::Udp(
"127.0.0.1:53".parse().unwrap(),
)),
upstream: Mutex::new(crate::forward::Upstream::Udp(upstream_addr)),
upstream_auto: false,
upstream_port: 53,
lan_ip: Mutex::new(std::net::Ipv4Addr::LOCALHOST),
timeout: Duration::from_secs(3),
timeout: Duration::from_millis(200),
proxy_tld: "numa".to_string(),
proxy_tld_suffix: ".numa".to_string(),
lan_enabled: false,
@@ -397,8 +435,11 @@ mod tests {
assert_eq!(resp.header.id, 0xBEEF);
assert!(resp.header.response);
// Query goes to upstream (127.0.0.1:53), which will fail — expect SERVFAIL
// Query goes to the blackhole upstream which never replies → SERVFAIL.
// The SERVFAIL response echoes the question section.
assert_eq!(resp.header.rescode, ResultCode::SERVFAIL);
assert_eq!(resp.questions.len(), 1);
assert_eq!(resp.questions[0].name, "nonexistent.test");
}
#[tokio::test]