From ee310fe7cb415d62bcf4f8d5cd9c7e4711b25a7a Mon Sep 17 00:00:00 2001 From: Razvan Dimescu Date: Mon, 30 Mar 2026 00:50:04 +0300 Subject: [PATCH] fix: send SERVFAIL on DoT resolve errors, extract shared connection handler - Send SERVFAIL response (with correct query ID) when resolve_query fails, preventing DoT clients from hanging until idle timeout - Extract handle_dot_connection() so tests use the same logic as production, eliminating duplicated accept/read/resolve loop - Replace magic 4096 with named MAX_MSG_LEN constant tied to BUF_SIZE - Add flush() after each TLS write to prevent buffered responses - Extract fallback_tls() helper, handle partial cert/key config, support IPv6 bind address, remove redundant crypto provider init Co-Authored-By: Claude Opus 4.6 (1M context) --- src/dot.rs | 197 ++++++++++++++++++++++++++++------------------------- 1 file changed, 103 insertions(+), 94 deletions(-) diff --git a/src/dot.rs b/src/dot.rs index 4d86176..e10e7b7 100644 --- a/src/dot.rs +++ b/src/dot.rs @@ -1,4 +1,4 @@ -use std::net::SocketAddr; +use std::net::{IpAddr, SocketAddr}; use std::path::Path; use std::sync::Arc; use std::time::Duration; @@ -13,9 +13,14 @@ use tokio_rustls::TlsAcceptor; use crate::buffer::BytePacketBuffer; use crate::config::DotConfig; use crate::ctx::{resolve_query, ServerCtx}; +use crate::header::ResultCode; +use crate::packet::DnsPacket; const MAX_CONNECTIONS: usize = 512; const IDLE_TIMEOUT: Duration = Duration::from_secs(30); +// Matches BytePacketBuffer::BUF_SIZE — RFC 7858 allows up to 65535 but our +// buffer would silently truncate anything larger. +const MAX_MSG_LEN: usize = 4096; /// Build a TLS ServerConfig for DoT from user-provided cert/key PEM files. fn load_tls_config(cert_path: &Path, key_path: &Path) -> crate::Result> { @@ -26,8 +31,6 @@ fn load_tls_config(cert_path: &Path, key_path: &Path) -> crate::Result crate::Result Option> { + if let Some(arc_swap) = ctx.tls_config.as_ref() { + return Some(Arc::clone(&*arc_swap.load())); + } + match crate::tls::build_tls_config(&ctx.proxy_tld, &[]) { + Ok(cfg) => Some(cfg), + Err(e) => { + warn!( + "DoT: failed to generate self-signed TLS: {} — DoT disabled", + e + ); + None + } + } +} + /// Start the DNS-over-TLS listener (RFC 7858). pub async fn start_dot(ctx: Arc, config: &DotConfig) { let tls_config = match (&config.cert_path, &config.key_path) { @@ -45,26 +64,24 @@ pub async fn start_dot(ctx: Arc, config: &DotConfig) { return; } }, - _ => match ctx.tls_config.as_ref() { - Some(arc_swap) => Arc::clone(&*arc_swap.load()), - None => match crate::tls::build_tls_config(&ctx.proxy_tld, &[]) { - Ok(cfg) => cfg, - Err(e) => { - warn!( - "DoT: failed to generate self-signed TLS: {} — DoT disabled", - e - ); - return; - } - }, + (Some(_), None) | (None, Some(_)) => { + warn!("DoT: both cert_path and key_path must be set — ignoring partial config, using self-signed"); + match fallback_tls(&ctx) { + Some(cfg) => cfg, + None => return, + } + } + (None, None) => match fallback_tls(&ctx) { + Some(cfg) => cfg, + None => return, }, }; - let bind_addr: std::net::Ipv4Addr = config + let bind_addr: IpAddr = config .bind_addr .parse() - .unwrap_or(std::net::Ipv4Addr::UNSPECIFIED); - let addr: SocketAddr = (bind_addr, config.port).into(); + .unwrap_or(IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED)); + let addr = SocketAddr::new(bind_addr, config.port); let listener = match TcpListener::bind(addr).await { Ok(l) => l, Err(e) => { @@ -99,7 +116,7 @@ pub async fn start_dot(ctx: Arc, config: &DotConfig) { tokio::spawn(async move { let _permit = permit; // held until task exits - let mut tls_stream = match acceptor.accept(tcp_stream).await { + let tls_stream = match acceptor.accept(tcp_stream).await { Ok(s) => s, Err(e) => { debug!("DoT: TLS handshake failed from {}: {}", remote_addr, e); @@ -107,51 +124,75 @@ pub async fn start_dot(ctx: Arc, config: &DotConfig) { } }; - // RFC 7858: connection is persistent — read queries until EOF or idle timeout - loop { - // Read 2-byte length prefix (RFC 1035 §4.2.2) with idle timeout - let mut len_buf = [0u8; 2]; - match tokio::time::timeout(IDLE_TIMEOUT, tls_stream.read_exact(&mut len_buf)).await - { - Ok(Ok(_)) => {} - Ok(Err(_)) => break, // read error or EOF - Err(_) => break, // idle timeout - } - let msg_len = u16::from_be_bytes(len_buf) as usize; - if msg_len == 0 || msg_len > 4096 { - debug!( - "DoT: invalid message length {} from {}", - msg_len, remote_addr - ); - break; - } - - let mut data = vec![0u8; msg_len]; - if tls_stream.read_exact(&mut data).await.is_err() { - break; - } - - let buffer = BytePacketBuffer::from_bytes(&data); - match resolve_query(buffer, remote_addr, &ctx).await { - Ok(resp_buffer) => { - let resp = resp_buffer.filled(); - // Coalesce length prefix + response into a single TLS write - let mut out = Vec::with_capacity(2 + resp.len()); - out.extend_from_slice(&(resp.len() as u16).to_be_bytes()); - out.extend_from_slice(resp); - if tls_stream.write_all(&out).await.is_err() { - break; - } - } - Err(e) => { - debug!("DoT: resolve error from {}: {}", remote_addr, e); - } - } - } + handle_dot_connection(tls_stream, remote_addr, &ctx).await; }); } } +/// Handle a single persistent DoT connection (RFC 7858). +/// Reads length-prefixed DNS queries until EOF, idle timeout, or error. +async fn handle_dot_connection(mut stream: S, remote_addr: SocketAddr, ctx: &ServerCtx) +where + S: AsyncReadExt + AsyncWriteExt + Unpin, +{ + loop { + // Read 2-byte length prefix (RFC 1035 §4.2.2) with idle timeout + let mut len_buf = [0u8; 2]; + match tokio::time::timeout(IDLE_TIMEOUT, stream.read_exact(&mut len_buf)).await { + Ok(Ok(_)) => {} + Ok(Err(_)) => break, // read error or EOF + Err(_) => break, // idle timeout + } + let msg_len = u16::from_be_bytes(len_buf) as usize; + if msg_len == 0 || msg_len > MAX_MSG_LEN { + debug!( + "DoT: invalid message length {} from {}", + msg_len, remote_addr + ); + break; + } + + let mut data = vec![0u8; msg_len]; + if stream.read_exact(&mut data).await.is_err() { + break; + } + + // Extract query ID before resolve_query consumes the buffer + let query_id = data + .get(..2) + .map(|b| u16::from_be_bytes([b[0], b[1]])) + .unwrap_or(0); + + let buffer = BytePacketBuffer::from_bytes(&data); + let resp_buffer = match resolve_query(buffer, remote_addr, ctx).await { + Ok(buf) => buf, + Err(e) => { + debug!("DoT: resolve error from {}: {}", remote_addr, e); + // Send SERVFAIL so the client doesn't hang + let mut resp = DnsPacket::new(); + resp.header.id = query_id; + resp.header.response = true; + resp.header.rescode = ResultCode::SERVFAIL; + let mut buf = BytePacketBuffer::new(); + if resp.write(&mut buf).is_err() { + break; + } + buf + } + }; + let resp = resp_buffer.filled(); + let mut out = Vec::with_capacity(2 + resp.len()); + out.extend_from_slice(&(resp.len() as u16).to_be_bytes()); + out.extend_from_slice(resp); + if stream.write_all(&out).await.is_err() { + break; + } + if stream.flush().await.is_err() { + break; + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -270,43 +311,11 @@ mod tests { let ctx = Arc::clone(&ctx); tokio::spawn(async move { let _permit = permit; - let mut tls_stream = match acceptor.accept(tcp_stream).await { + let tls_stream = match acceptor.accept(tcp_stream).await { Ok(s) => s, Err(_) => return, }; - loop { - let mut len_buf = [0u8; 2]; - match tokio::time::timeout( - IDLE_TIMEOUT, - tls_stream.read_exact(&mut len_buf), - ) - .await - { - Ok(Ok(_)) => {} - _ => break, - } - let msg_len = u16::from_be_bytes(len_buf) as usize; - if msg_len == 0 || msg_len > 4096 { - break; - } - let mut data = vec![0u8; msg_len]; - if tls_stream.read_exact(&mut data).await.is_err() { - break; - } - let buffer = BytePacketBuffer::from_bytes(&data); - match resolve_query(buffer, remote_addr, &ctx).await { - Ok(resp_buffer) => { - let resp = resp_buffer.filled(); - let mut out = Vec::with_capacity(2 + resp.len()); - out.extend_from_slice(&(resp.len() as u16).to_be_bytes()); - out.extend_from_slice(resp); - if tls_stream.write_all(&out).await.is_err() { - break; - } - } - Err(_) => {} - } - } + handle_dot_connection(tls_stream, remote_addr, &ctx).await; }); } });