fix: parse DoT queries up-front and echo question in SERVFAIL
Address review findings on PR #25:

- Refactor resolve_query to take a pre-parsed DnsPacket. Parse-error handling moves to the UDP caller, eliminating the double warn! line on malformed UDP queries.
- Enforce MIN_MSG_LEN=12 (DNS header) in handle_dot_connection so query_id extraction is always reading client-sent bytes, not the zeroed buffer tail.
- Parse the DoT query before calling resolve_query and retain it, so SERVFAIL responses can echo the original question section via response_from(). Parse failures send FORMERR with the client id.
- Extract write_framed() helper for length-prefix + flush, reused by success, SERVFAIL, and FORMERR paths.
- Back off 100ms on listener.accept() errors to avoid tight-looping on fd exhaustion.
- Replace the hardcoded 127.0.0.1:53 upstream in dot_nxdomain_for_unknown with a bound-but-unresponsive UDP socket owned by the test, making it independent of the host's local resolver. Test now runs in ~220ms (timeout lowered to 200ms) instead of 3s and asserts the question is echoed in the SERVFAIL response.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
23
src/ctx.rs
23
src/ctx.rs
@@ -65,21 +65,15 @@ pub struct ServerCtx {
|
|||||||
/// Transport-agnostic DNS resolution. Runs the full pipeline (overrides, blocklist,
|
/// Transport-agnostic DNS resolution. Runs the full pipeline (overrides, blocklist,
|
||||||
/// cache, upstream, DNSSEC) and returns the serialized response in a buffer.
|
/// cache, upstream, DNSSEC) and returns the serialized response in a buffer.
|
||||||
/// Callers use `.filled()` to get the response bytes without heap allocation.
|
/// Callers use `.filled()` to get the response bytes without heap allocation.
|
||||||
|
/// Callers are responsible for parsing the incoming buffer into a `DnsPacket`
|
||||||
|
/// (and logging parse errors) before calling this function.
|
||||||
pub async fn resolve_query(
|
pub async fn resolve_query(
|
||||||
mut buffer: BytePacketBuffer,
|
query: DnsPacket,
|
||||||
src_addr: SocketAddr,
|
src_addr: SocketAddr,
|
||||||
ctx: &ServerCtx,
|
ctx: &ServerCtx,
|
||||||
) -> crate::Result<BytePacketBuffer> {
|
) -> crate::Result<BytePacketBuffer> {
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
|
|
||||||
let query = match DnsPacket::from_buffer(&mut buffer) {
|
|
||||||
Ok(packet) => packet,
|
|
||||||
Err(e) => {
|
|
||||||
warn!("{} | PARSE ERROR | {}", src_addr, e);
|
|
||||||
return Err(e);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let (qname, qtype) = match query.questions.first() {
|
let (qname, qtype) = match query.questions.first() {
|
||||||
Some(q) => (q.name.clone(), q.qtype),
|
Some(q) => (q.name.clone(), q.qtype),
|
||||||
None => return Err("empty question section".into()),
|
None => return Err("empty question section".into()),
|
||||||
@@ -347,11 +341,18 @@ pub async fn resolve_query(
|
|||||||
|
|
||||||
/// Handle a DNS query received over UDP. Thin wrapper around resolve_query.
|
/// Handle a DNS query received over UDP. Thin wrapper around resolve_query.
|
||||||
pub async fn handle_query(
|
pub async fn handle_query(
|
||||||
buffer: BytePacketBuffer,
|
mut buffer: BytePacketBuffer,
|
||||||
src_addr: SocketAddr,
|
src_addr: SocketAddr,
|
||||||
ctx: &ServerCtx,
|
ctx: &ServerCtx,
|
||||||
) -> crate::Result<()> {
|
) -> crate::Result<()> {
|
||||||
match resolve_query(buffer, src_addr, ctx).await {
|
let query = match DnsPacket::from_buffer(&mut buffer) {
|
||||||
|
Ok(packet) => packet,
|
||||||
|
Err(e) => {
|
||||||
|
warn!("{} | PARSE ERROR | {}", src_addr, e);
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
match resolve_query(query, src_addr, ctx).await {
|
||||||
Ok(resp_buffer) => {
|
Ok(resp_buffer) => {
|
||||||
ctx.socket.send_to(resp_buffer.filled(), src_addr).await?;
|
ctx.socket.send_to(resp_buffer.filled(), src_addr).await?;
|
||||||
}
|
}
|
||||||
|
|||||||
87
src/dot.rs
87
src/dot.rs
@@ -19,9 +19,12 @@ use crate::packet::DnsPacket;
|
|||||||
const MAX_CONNECTIONS: usize = 512;
|
const MAX_CONNECTIONS: usize = 512;
|
||||||
const IDLE_TIMEOUT: Duration = Duration::from_secs(30);
|
const IDLE_TIMEOUT: Duration = Duration::from_secs(30);
|
||||||
const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
|
const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
|
||||||
|
const ACCEPT_ERROR_BACKOFF: Duration = Duration::from_millis(100);
|
||||||
// Matches BytePacketBuffer::BUF_SIZE — RFC 7858 allows up to 65535 but our
|
// Matches BytePacketBuffer::BUF_SIZE — RFC 7858 allows up to 65535 but our
|
||||||
// buffer would silently truncate anything larger.
|
// buffer would silently truncate anything larger.
|
||||||
const MAX_MSG_LEN: usize = 4096;
|
const MAX_MSG_LEN: usize = 4096;
|
||||||
|
// DNS header is 12 bytes; anything shorter cannot be a valid query.
|
||||||
|
const MIN_MSG_LEN: usize = 12;
|
||||||
|
|
||||||
/// Build a TLS ServerConfig for DoT from user-provided cert/key PEM files.
|
/// Build a TLS ServerConfig for DoT from user-provided cert/key PEM files.
|
||||||
fn load_tls_config(cert_path: &Path, key_path: &Path) -> crate::Result<Arc<ServerConfig>> {
|
fn load_tls_config(cert_path: &Path, key_path: &Path) -> crate::Result<Arc<ServerConfig>> {
|
||||||
@@ -103,6 +106,8 @@ async fn accept_loop(listener: TcpListener, acceptor: TlsAcceptor, ctx: Arc<Serv
|
|||||||
Ok(conn) => conn,
|
Ok(conn) => conn,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("DoT: TCP accept error: {}", e);
|
error!("DoT: TCP accept error: {}", e);
|
||||||
|
// Back off to avoid tight-looping on persistent failures (e.g. fd exhaustion).
|
||||||
|
tokio::time::sleep(ACCEPT_ERROR_BACKOFF).await;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -153,7 +158,7 @@ where
|
|||||||
Err(_) => break, // idle timeout
|
Err(_) => break, // idle timeout
|
||||||
}
|
}
|
||||||
let msg_len = u16::from_be_bytes(len_buf) as usize;
|
let msg_len = u16::from_be_bytes(len_buf) as usize;
|
||||||
if msg_len == 0 || msg_len > MAX_MSG_LEN {
|
if !(MIN_MSG_LEN..=MAX_MSG_LEN).contains(&msg_len) {
|
||||||
debug!(
|
debug!(
|
||||||
"DoT: invalid message length {} from {}",
|
"DoT: invalid message length {} from {}",
|
||||||
msg_len, remote_addr
|
msg_len, remote_addr
|
||||||
@@ -173,37 +178,66 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let query_id = u16::from_be_bytes([buffer.buf[0], buffer.buf[1]]);
|
// Parse query up-front so we can echo its question section in SERVFAIL
|
||||||
let resp_buffer = match resolve_query(buffer, remote_addr, ctx).await {
|
// responses when resolve_query fails.
|
||||||
Ok(buf) => buf,
|
let query = match DnsPacket::from_buffer(&mut buffer) {
|
||||||
|
Ok(q) => q,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
debug!("DoT: resolve error from {}: {}", remote_addr, e);
|
warn!("{} | PARSE ERROR | {}", remote_addr, e);
|
||||||
// Send SERVFAIL so the client doesn't hang
|
// msg_len >= MIN_MSG_LEN guarantees buf[0..2] is the client's query id.
|
||||||
|
let query_id = u16::from_be_bytes([buffer.buf[0], buffer.buf[1]]);
|
||||||
let mut resp = DnsPacket::new();
|
let mut resp = DnsPacket::new();
|
||||||
resp.header.id = query_id;
|
resp.header.id = query_id;
|
||||||
resp.header.response = true;
|
resp.header.response = true;
|
||||||
resp.header.rescode = ResultCode::SERVFAIL;
|
resp.header.rescode = ResultCode::FORMERR;
|
||||||
let mut buf = BytePacketBuffer::new();
|
let mut out_buf = BytePacketBuffer::new();
|
||||||
if resp.write(&mut buf).is_err() {
|
if resp.write(&mut out_buf).is_err() {
|
||||||
|
debug!("DoT: failed to serialize FORMERR for {}", remote_addr);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if write_framed(&mut stream, out_buf.filled()).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let resp_buffer = match resolve_query(query.clone(), remote_addr, ctx).await {
|
||||||
|
Ok(buf) => buf,
|
||||||
|
Err(e) => {
|
||||||
|
warn!("{} | RESOLVE ERROR | {}", remote_addr, e);
|
||||||
|
// Build SERVFAIL that echoes the original question section.
|
||||||
|
let resp = DnsPacket::response_from(&query, ResultCode::SERVFAIL);
|
||||||
|
let mut out_buf = BytePacketBuffer::new();
|
||||||
|
if resp.write(&mut out_buf).is_err() {
|
||||||
debug!("DoT: failed to serialize SERVFAIL for {}", remote_addr);
|
debug!("DoT: failed to serialize SERVFAIL for {}", remote_addr);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
buf
|
out_buf
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
let resp = resp_buffer.filled();
|
if write_framed(&mut stream, resp_buffer.filled())
|
||||||
let mut out = Vec::with_capacity(2 + resp.len());
|
.await
|
||||||
out.extend_from_slice(&(resp.len() as u16).to_be_bytes());
|
.is_err()
|
||||||
out.extend_from_slice(resp);
|
{
|
||||||
if stream.write_all(&out).await.is_err() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if stream.flush().await.is_err() {
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Write a DNS message with its 2-byte length prefix, coalesced into one syscall.
|
||||||
|
async fn write_framed<S>(stream: &mut S, msg: &[u8]) -> std::io::Result<()>
|
||||||
|
where
|
||||||
|
S: AsyncWriteExt + Unpin,
|
||||||
|
{
|
||||||
|
let mut out = Vec::with_capacity(2 + msg.len());
|
||||||
|
out.extend_from_slice(&(msg.len() as u16).to_be_bytes());
|
||||||
|
out.extend_from_slice(msg);
|
||||||
|
stream.write_all(&out).await?;
|
||||||
|
stream.flush().await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
@@ -250,10 +284,16 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Spin up a DoT listener with a test TLS config. Returns (addr, client_config).
|
/// Spin up a DoT listener with a test TLS config. Returns (addr, client_config).
|
||||||
|
/// The upstream is pointed at a bound-but-unresponsive UDP socket we own, so
|
||||||
|
/// any query that escapes to the upstream path times out deterministically
|
||||||
|
/// (SERVFAIL) regardless of what the host has running on port 53.
|
||||||
async fn spawn_dot_server() -> (SocketAddr, Arc<rustls::ClientConfig>) {
|
async fn spawn_dot_server() -> (SocketAddr, Arc<rustls::ClientConfig>) {
|
||||||
let (server_tls, client_tls) = test_tls_configs();
|
let (server_tls, client_tls) = test_tls_configs();
|
||||||
|
|
||||||
let socket = tokio::net::UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
let socket = tokio::net::UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
// Bind an unresponsive upstream and leak it so it lives for the test duration.
|
||||||
|
let blackhole = Box::leak(Box::new(std::net::UdpSocket::bind("127.0.0.1:0").unwrap()));
|
||||||
|
let upstream_addr = blackhole.local_addr().unwrap();
|
||||||
let ctx = Arc::new(ServerCtx {
|
let ctx = Arc::new(ServerCtx {
|
||||||
socket,
|
socket,
|
||||||
zone_map: {
|
zone_map: {
|
||||||
@@ -278,13 +318,11 @@ mod tests {
|
|||||||
services: Mutex::new(crate::service_store::ServiceStore::new()),
|
services: Mutex::new(crate::service_store::ServiceStore::new()),
|
||||||
lan_peers: Mutex::new(crate::lan::PeerStore::new(90)),
|
lan_peers: Mutex::new(crate::lan::PeerStore::new(90)),
|
||||||
forwarding_rules: Vec::new(),
|
forwarding_rules: Vec::new(),
|
||||||
upstream: Mutex::new(crate::forward::Upstream::Udp(
|
upstream: Mutex::new(crate::forward::Upstream::Udp(upstream_addr)),
|
||||||
"127.0.0.1:53".parse().unwrap(),
|
|
||||||
)),
|
|
||||||
upstream_auto: false,
|
upstream_auto: false,
|
||||||
upstream_port: 53,
|
upstream_port: 53,
|
||||||
lan_ip: Mutex::new(std::net::Ipv4Addr::LOCALHOST),
|
lan_ip: Mutex::new(std::net::Ipv4Addr::LOCALHOST),
|
||||||
timeout: Duration::from_secs(3),
|
timeout: Duration::from_millis(200),
|
||||||
proxy_tld: "numa".to_string(),
|
proxy_tld: "numa".to_string(),
|
||||||
proxy_tld_suffix: ".numa".to_string(),
|
proxy_tld_suffix: ".numa".to_string(),
|
||||||
lan_enabled: false,
|
lan_enabled: false,
|
||||||
@@ -397,8 +435,11 @@ mod tests {
|
|||||||
|
|
||||||
assert_eq!(resp.header.id, 0xBEEF);
|
assert_eq!(resp.header.id, 0xBEEF);
|
||||||
assert!(resp.header.response);
|
assert!(resp.header.response);
|
||||||
// Query goes to upstream (127.0.0.1:53), which will fail — expect SERVFAIL
|
// Query goes to the blackhole upstream which never replies → SERVFAIL.
|
||||||
|
// The SERVFAIL response echoes the question section.
|
||||||
assert_eq!(resp.header.rescode, ResultCode::SERVFAIL);
|
assert_eq!(resp.header.rescode, ResultCode::SERVFAIL);
|
||||||
|
assert_eq!(resp.questions.len(), 1);
|
||||||
|
assert_eq!(resp.questions[0].name, "nonexistent.test");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
|
|||||||
Reference in New Issue
Block a user