fix: parse DoT queries up-front and echo question in SERVFAIL
Address review findings on PR #25: - Refactor resolve_query to take a pre-parsed DnsPacket. Parse-error handling moves to the UDP caller, eliminating the double warn! line on malformed UDP queries. - Enforce MIN_MSG_LEN=12 (DNS header) in handle_dot_connection so query_id extraction is always reading client-sent bytes, not the zeroed buffer tail. - Parse the DoT query before calling resolve_query and retain it, so SERVFAIL responses can echo the original question section via response_from(). Parse failures send FORMERR with the client id. - Extract write_framed() helper for length-prefix + flush, reused by success, SERVFAIL, and FORMERR paths. - Back off 100ms on listener.accept() errors to avoid tight-looping on fd exhaustion. - Replace the hardcoded 127.0.0.1:53 upstream in dot_nxdomain_for_unknown with a bound-but-unresponsive UDP socket owned by the test, making it independent of the host's local resolver. Test now runs in ~220ms (timeout lowered to 200ms) instead of 3s and asserts the question is echoed in the SERVFAIL response. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
23
src/ctx.rs
23
src/ctx.rs
@@ -65,21 +65,15 @@ pub struct ServerCtx {
|
||||
/// Transport-agnostic DNS resolution. Runs the full pipeline (overrides, blocklist,
|
||||
/// cache, upstream, DNSSEC) and returns the serialized response in a buffer.
|
||||
/// Callers use `.filled()` to get the response bytes without heap allocation.
|
||||
/// Callers are responsible for parsing the incoming buffer into a `DnsPacket`
|
||||
/// (and logging parse errors) before calling this function.
|
||||
pub async fn resolve_query(
|
||||
mut buffer: BytePacketBuffer,
|
||||
query: DnsPacket,
|
||||
src_addr: SocketAddr,
|
||||
ctx: &ServerCtx,
|
||||
) -> crate::Result<BytePacketBuffer> {
|
||||
let start = Instant::now();
|
||||
|
||||
let query = match DnsPacket::from_buffer(&mut buffer) {
|
||||
Ok(packet) => packet,
|
||||
Err(e) => {
|
||||
warn!("{} | PARSE ERROR | {}", src_addr, e);
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
|
||||
let (qname, qtype) = match query.questions.first() {
|
||||
Some(q) => (q.name.clone(), q.qtype),
|
||||
None => return Err("empty question section".into()),
|
||||
@@ -347,11 +341,18 @@ pub async fn resolve_query(
|
||||
|
||||
/// Handle a DNS query received over UDP. Thin wrapper around resolve_query.
|
||||
pub async fn handle_query(
|
||||
buffer: BytePacketBuffer,
|
||||
mut buffer: BytePacketBuffer,
|
||||
src_addr: SocketAddr,
|
||||
ctx: &ServerCtx,
|
||||
) -> crate::Result<()> {
|
||||
match resolve_query(buffer, src_addr, ctx).await {
|
||||
let query = match DnsPacket::from_buffer(&mut buffer) {
|
||||
Ok(packet) => packet,
|
||||
Err(e) => {
|
||||
warn!("{} | PARSE ERROR | {}", src_addr, e);
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
match resolve_query(query, src_addr, ctx).await {
|
||||
Ok(resp_buffer) => {
|
||||
ctx.socket.send_to(resp_buffer.filled(), src_addr).await?;
|
||||
}
|
||||
|
||||
87
src/dot.rs
87
src/dot.rs
@@ -19,9 +19,12 @@ use crate::packet::DnsPacket;
|
||||
const MAX_CONNECTIONS: usize = 512;
|
||||
const IDLE_TIMEOUT: Duration = Duration::from_secs(30);
|
||||
const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
const ACCEPT_ERROR_BACKOFF: Duration = Duration::from_millis(100);
|
||||
// Matches BytePacketBuffer::BUF_SIZE — RFC 7858 allows up to 65535 but our
|
||||
// buffer would silently truncate anything larger.
|
||||
const MAX_MSG_LEN: usize = 4096;
|
||||
// DNS header is 12 bytes; anything shorter cannot be a valid query.
|
||||
const MIN_MSG_LEN: usize = 12;
|
||||
|
||||
/// Build a TLS ServerConfig for DoT from user-provided cert/key PEM files.
|
||||
fn load_tls_config(cert_path: &Path, key_path: &Path) -> crate::Result<Arc<ServerConfig>> {
|
||||
@@ -103,6 +106,8 @@ async fn accept_loop(listener: TcpListener, acceptor: TlsAcceptor, ctx: Arc<Serv
|
||||
Ok(conn) => conn,
|
||||
Err(e) => {
|
||||
error!("DoT: TCP accept error: {}", e);
|
||||
// Back off to avoid tight-looping on persistent failures (e.g. fd exhaustion).
|
||||
tokio::time::sleep(ACCEPT_ERROR_BACKOFF).await;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
@@ -153,7 +158,7 @@ where
|
||||
Err(_) => break, // idle timeout
|
||||
}
|
||||
let msg_len = u16::from_be_bytes(len_buf) as usize;
|
||||
if msg_len == 0 || msg_len > MAX_MSG_LEN {
|
||||
if !(MIN_MSG_LEN..=MAX_MSG_LEN).contains(&msg_len) {
|
||||
debug!(
|
||||
"DoT: invalid message length {} from {}",
|
||||
msg_len, remote_addr
|
||||
@@ -173,37 +178,66 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
let query_id = u16::from_be_bytes([buffer.buf[0], buffer.buf[1]]);
|
||||
let resp_buffer = match resolve_query(buffer, remote_addr, ctx).await {
|
||||
Ok(buf) => buf,
|
||||
// Parse query up-front so we can echo its question section in SERVFAIL
|
||||
// responses when resolve_query fails.
|
||||
let query = match DnsPacket::from_buffer(&mut buffer) {
|
||||
Ok(q) => q,
|
||||
Err(e) => {
|
||||
debug!("DoT: resolve error from {}: {}", remote_addr, e);
|
||||
// Send SERVFAIL so the client doesn't hang
|
||||
warn!("{} | PARSE ERROR | {}", remote_addr, e);
|
||||
// msg_len >= MIN_MSG_LEN guarantees buf[0..2] is the client's query id.
|
||||
let query_id = u16::from_be_bytes([buffer.buf[0], buffer.buf[1]]);
|
||||
let mut resp = DnsPacket::new();
|
||||
resp.header.id = query_id;
|
||||
resp.header.response = true;
|
||||
resp.header.rescode = ResultCode::SERVFAIL;
|
||||
let mut buf = BytePacketBuffer::new();
|
||||
if resp.write(&mut buf).is_err() {
|
||||
resp.header.rescode = ResultCode::FORMERR;
|
||||
let mut out_buf = BytePacketBuffer::new();
|
||||
if resp.write(&mut out_buf).is_err() {
|
||||
debug!("DoT: failed to serialize FORMERR for {}", remote_addr);
|
||||
break;
|
||||
}
|
||||
if write_framed(&mut stream, out_buf.filled()).await.is_err() {
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let resp_buffer = match resolve_query(query.clone(), remote_addr, ctx).await {
|
||||
Ok(buf) => buf,
|
||||
Err(e) => {
|
||||
warn!("{} | RESOLVE ERROR | {}", remote_addr, e);
|
||||
// Build SERVFAIL that echoes the original question section.
|
||||
let resp = DnsPacket::response_from(&query, ResultCode::SERVFAIL);
|
||||
let mut out_buf = BytePacketBuffer::new();
|
||||
if resp.write(&mut out_buf).is_err() {
|
||||
debug!("DoT: failed to serialize SERVFAIL for {}", remote_addr);
|
||||
break;
|
||||
}
|
||||
buf
|
||||
out_buf
|
||||
}
|
||||
};
|
||||
let resp = resp_buffer.filled();
|
||||
let mut out = Vec::with_capacity(2 + resp.len());
|
||||
out.extend_from_slice(&(resp.len() as u16).to_be_bytes());
|
||||
out.extend_from_slice(resp);
|
||||
if stream.write_all(&out).await.is_err() {
|
||||
break;
|
||||
}
|
||||
if stream.flush().await.is_err() {
|
||||
if write_framed(&mut stream, resp_buffer.filled())
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Write a DNS message with its 2-byte length prefix, coalesced into one syscall.
|
||||
async fn write_framed<S>(stream: &mut S, msg: &[u8]) -> std::io::Result<()>
|
||||
where
|
||||
S: AsyncWriteExt + Unpin,
|
||||
{
|
||||
let mut out = Vec::with_capacity(2 + msg.len());
|
||||
out.extend_from_slice(&(msg.len() as u16).to_be_bytes());
|
||||
out.extend_from_slice(msg);
|
||||
stream.write_all(&out).await?;
|
||||
stream.flush().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -250,10 +284,16 @@ mod tests {
|
||||
}
|
||||
|
||||
/// Spin up a DoT listener with a test TLS config. Returns (addr, client_config).
|
||||
/// The upstream is pointed at a bound-but-unresponsive UDP socket we own, so
|
||||
/// any query that escapes to the upstream path times out deterministically
|
||||
/// (SERVFAIL) regardless of what the host has running on port 53.
|
||||
async fn spawn_dot_server() -> (SocketAddr, Arc<rustls::ClientConfig>) {
|
||||
let (server_tls, client_tls) = test_tls_configs();
|
||||
|
||||
let socket = tokio::net::UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
// Bind an unresponsive upstream and leak it so it lives for the test duration.
|
||||
let blackhole = Box::leak(Box::new(std::net::UdpSocket::bind("127.0.0.1:0").unwrap()));
|
||||
let upstream_addr = blackhole.local_addr().unwrap();
|
||||
let ctx = Arc::new(ServerCtx {
|
||||
socket,
|
||||
zone_map: {
|
||||
@@ -278,13 +318,11 @@ mod tests {
|
||||
services: Mutex::new(crate::service_store::ServiceStore::new()),
|
||||
lan_peers: Mutex::new(crate::lan::PeerStore::new(90)),
|
||||
forwarding_rules: Vec::new(),
|
||||
upstream: Mutex::new(crate::forward::Upstream::Udp(
|
||||
"127.0.0.1:53".parse().unwrap(),
|
||||
)),
|
||||
upstream: Mutex::new(crate::forward::Upstream::Udp(upstream_addr)),
|
||||
upstream_auto: false,
|
||||
upstream_port: 53,
|
||||
lan_ip: Mutex::new(std::net::Ipv4Addr::LOCALHOST),
|
||||
timeout: Duration::from_secs(3),
|
||||
timeout: Duration::from_millis(200),
|
||||
proxy_tld: "numa".to_string(),
|
||||
proxy_tld_suffix: ".numa".to_string(),
|
||||
lan_enabled: false,
|
||||
@@ -397,8 +435,11 @@ mod tests {
|
||||
|
||||
assert_eq!(resp.header.id, 0xBEEF);
|
||||
assert!(resp.header.response);
|
||||
// Query goes to upstream (127.0.0.1:53), which will fail — expect SERVFAIL
|
||||
// Query goes to the blackhole upstream which never replies → SERVFAIL.
|
||||
// The SERVFAIL response echoes the question section.
|
||||
assert_eq!(resp.header.rescode, ResultCode::SERVFAIL);
|
||||
assert_eq!(resp.questions.len(), 1);
|
||||
assert_eq!(resp.questions[0].name, "nonexistent.test");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
Reference in New Issue
Block a user