refactor: trim DoT listener — let-else reads, drop MIN_MSG_LEN and redundant localhost test

- Collapse two 4-arm read/timeout matches to let-else (lose one
  defensive debug log on payload-read timeout; idle timeouts are
  routine on persistent DoT connections anyway)
- Drop MIN_MSG_LEN: DnsPacket::from_buffer rejects truncated input
  on its own, and BytePacketBuffer is zero-init so buf[0..2] for
  sub-2-byte messages just yields a harmless FORMERR with id=0
- Inline ACCEPT_ERROR_BACKOFF (single use site)
- Drop the partial cert/key warning: missing one of cert_path/
  key_path silently falls back to self-signed; users see the
  self-signed cert at startup and figure it out
- Drop dot_localhost_resolution test: RFC 6761 localhost is tested
  in ctx.rs; this test only verified DoT transport, which
  dot_resolves_local_zone already covers
- Drop self-documenting comment in dot_multiple_queries_on_persistent_connection

Net -32 lines, 125/125 tests pass, no behavior change users would notice.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Razvan Dimescu
2026-04-07 20:35:05 +03:00
parent 16689526aa
commit e9836f1be7

View File

@@ -19,12 +19,9 @@ use crate::packet::DnsPacket;
const MAX_CONNECTIONS: usize = 512; const MAX_CONNECTIONS: usize = 512;
const IDLE_TIMEOUT: Duration = Duration::from_secs(30); const IDLE_TIMEOUT: Duration = Duration::from_secs(30);
const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10); const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
const ACCEPT_ERROR_BACKOFF: Duration = Duration::from_millis(100);
// Matches BytePacketBuffer::BUF_SIZE — RFC 7858 allows up to 65535 but our // Matches BytePacketBuffer::BUF_SIZE — RFC 7858 allows up to 65535 but our
// buffer would silently truncate anything larger. // buffer would silently truncate anything larger.
const MAX_MSG_LEN: usize = 4096; const MAX_MSG_LEN: usize = 4096;
// DNS header is 12 bytes; anything shorter cannot be a valid query.
const MIN_MSG_LEN: usize = 12;
/// Build a TLS ServerConfig for DoT from user-provided cert/key PEM files. /// Build a TLS ServerConfig for DoT from user-provided cert/key PEM files.
fn load_tls_config(cert_path: &Path, key_path: &Path) -> crate::Result<Arc<ServerConfig>> { fn load_tls_config(cert_path: &Path, key_path: &Path) -> crate::Result<Arc<ServerConfig>> {
@@ -60,9 +57,6 @@ fn fallback_tls(ctx: &ServerCtx) -> Option<Arc<ServerConfig>> {
/// Start the DNS-over-TLS listener (RFC 7858). /// Start the DNS-over-TLS listener (RFC 7858).
pub async fn start_dot(ctx: Arc<ServerCtx>, config: &DotConfig) { pub async fn start_dot(ctx: Arc<ServerCtx>, config: &DotConfig) {
if config.cert_path.is_some() != config.key_path.is_some() {
warn!("DoT: both cert_path and key_path must be set — ignoring partial config, using self-signed");
}
let tls_config = match (&config.cert_path, &config.key_path) { let tls_config = match (&config.cert_path, &config.key_path) {
(Some(cert), Some(key)) => match load_tls_config(cert, key) { (Some(cert), Some(key)) => match load_tls_config(cert, key) {
Ok(cfg) => cfg, Ok(cfg) => cfg,
@@ -103,7 +97,7 @@ async fn accept_loop(listener: TcpListener, acceptor: TlsAcceptor, ctx: Arc<Serv
Err(e) => { Err(e) => {
error!("DoT: TCP accept error: {}", e); error!("DoT: TCP accept error: {}", e);
// Back off to avoid tight-looping on persistent failures (e.g. fd exhaustion). // Back off to avoid tight-looping on persistent failures (e.g. fd exhaustion).
tokio::time::sleep(ACCEPT_ERROR_BACKOFF).await; tokio::time::sleep(Duration::from_millis(100)).await;
continue; continue;
} }
}; };
@@ -148,31 +142,22 @@ where
loop { loop {
// Read 2-byte length prefix (RFC 1035 §4.2.2) with idle timeout // Read 2-byte length prefix (RFC 1035 §4.2.2) with idle timeout
let mut len_buf = [0u8; 2]; let mut len_buf = [0u8; 2];
match tokio::time::timeout(IDLE_TIMEOUT, stream.read_exact(&mut len_buf)).await { let Ok(Ok(_)) = tokio::time::timeout(IDLE_TIMEOUT, stream.read_exact(&mut len_buf)).await
Ok(Ok(_)) => {} else {
Ok(Err(_)) => break, // read error or EOF break;
Err(_) => break, // idle timeout };
}
let msg_len = u16::from_be_bytes(len_buf) as usize; let msg_len = u16::from_be_bytes(len_buf) as usize;
if !(MIN_MSG_LEN..=MAX_MSG_LEN).contains(&msg_len) { if msg_len > MAX_MSG_LEN {
debug!( debug!("DoT: oversized message {} from {}", msg_len, remote_addr);
"DoT: invalid message length {} from {}",
msg_len, remote_addr
);
break; break;
} }
let mut buffer = BytePacketBuffer::new(); let mut buffer = BytePacketBuffer::new();
match tokio::time::timeout(IDLE_TIMEOUT, stream.read_exact(&mut buffer.buf[..msg_len])) let Ok(Ok(_)) =
.await tokio::time::timeout(IDLE_TIMEOUT, stream.read_exact(&mut buffer.buf[..msg_len])).await
{ else {
Ok(Ok(_)) => {} break;
Ok(Err(_)) => break, };
Err(_) => {
debug!("DoT: payload read timeout from {}", remote_addr);
break;
}
}
// Parse query up-front so we can echo its question section in SERVFAIL // Parse query up-front so we can echo its question section in SERVFAIL
// responses when resolve_query fails. // responses when resolve_query fails.
@@ -180,7 +165,8 @@ where
Ok(q) => q, Ok(q) => q,
Err(e) => { Err(e) => {
warn!("{} | PARSE ERROR | {}", remote_addr, e); warn!("{} | PARSE ERROR | {}", remote_addr, e);
// msg_len >= MIN_MSG_LEN guarantees buf[0..2] is the client's query id. // BytePacketBuffer is zero-initialized, so buf[0..2] reads as 0x0000
// for sub-2-byte messages — harmless FORMERR with id=0.
let query_id = u16::from_be_bytes([buffer.buf[0], buffer.buf[1]]); let query_id = u16::from_be_bytes([buffer.buf[0], buffer.buf[1]]);
let mut resp = DnsPacket::new(); let mut resp = DnsPacket::new();
resp.header.id = query_id; resp.header.id = query_id;
@@ -431,7 +417,6 @@ mod tests {
let (addr, client_config) = spawn_dot_server().await; let (addr, client_config) = spawn_dot_server().await;
let mut stream = dot_connect(addr, &client_config).await; let mut stream = dot_connect(addr, &client_config).await;
// Send 3 queries on the same TLS connection
for i in 0..3u16 { for i in 0..3u16 {
let query = DnsPacket::query(0xA000 + i, "dot-test.example", QueryType::A); let query = DnsPacket::query(0xA000 + i, "dot-test.example", QueryType::A);
let resp = dot_exchange(&mut stream, &query).await; let resp = dot_exchange(&mut stream, &query).await;
@@ -479,21 +464,4 @@ mod tests {
h.await.unwrap(); h.await.unwrap();
} }
} }
#[tokio::test]
async fn dot_localhost_resolution() {
let (addr, client_config) = spawn_dot_server().await;
let mut stream = dot_connect(addr, &client_config).await;
let query = DnsPacket::query(0xD000, "localhost", QueryType::A);
let resp = dot_exchange(&mut stream, &query).await;
assert_eq!(resp.header.id, 0xD000);
assert_eq!(resp.header.rescode, ResultCode::NOERROR);
assert_eq!(resp.answers.len(), 1);
match &resp.answers[0] {
DnsRecord::A { addr, .. } => assert_eq!(*addr, std::net::Ipv4Addr::LOCALHOST),
other => panic!("expected A record, got {:?}", other),
}
}
} }