diff --git a/src/dot.rs b/src/dot.rs index de2f9a8..360bf4a 100644 --- a/src/dot.rs +++ b/src/dot.rs @@ -19,12 +19,9 @@ use crate::packet::DnsPacket; const MAX_CONNECTIONS: usize = 512; const IDLE_TIMEOUT: Duration = Duration::from_secs(30); const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10); -const ACCEPT_ERROR_BACKOFF: Duration = Duration::from_millis(100); // Matches BytePacketBuffer::BUF_SIZE โ€” RFC 7858 allows up to 65535 but our // buffer would silently truncate anything larger. const MAX_MSG_LEN: usize = 4096; -// DNS header is 12 bytes; anything shorter cannot be a valid query. -const MIN_MSG_LEN: usize = 12; /// Build a TLS ServerConfig for DoT from user-provided cert/key PEM files. fn load_tls_config(cert_path: &Path, key_path: &Path) -> crate::Result> { @@ -60,9 +57,6 @@ fn fallback_tls(ctx: &ServerCtx) -> Option> { /// Start the DNS-over-TLS listener (RFC 7858). pub async fn start_dot(ctx: Arc, config: &DotConfig) { - if config.cert_path.is_some() != config.key_path.is_some() { - warn!("DoT: both cert_path and key_path must be set โ€” ignoring partial config, using self-signed"); - } let tls_config = match (&config.cert_path, &config.key_path) { (Some(cert), Some(key)) => match load_tls_config(cert, key) { Ok(cfg) => cfg, @@ -103,7 +97,7 @@ async fn accept_loop(listener: TcpListener, acceptor: TlsAcceptor, ctx: Arc { error!("DoT: TCP accept error: {}", e); // Back off to avoid tight-looping on persistent failures (e.g. fd exhaustion). - tokio::time::sleep(ACCEPT_ERROR_BACKOFF).await; + tokio::time::sleep(Duration::from_millis(100)).await; continue; } }; @@ -148,31 +142,22 @@ where loop { // Read 2-byte length prefix (RFC 1035 ยง4.2.2) with idle timeout let mut len_buf = [0u8; 2]; - match tokio::time::timeout(IDLE_TIMEOUT, stream.read_exact(&mut len_buf)).await { - Ok(Ok(_)) => {} - Ok(Err(_)) => break, // read error or EOF - Err(_) => break, // idle timeout - } + let Ok(Ok(_)) = tokio::time::timeout(IDLE_TIMEOUT, stream.read_exact(&mut len_buf)).await + else { + break; + }; let msg_len = u16::from_be_bytes(len_buf) as usize; - if !(MIN_MSG_LEN..=MAX_MSG_LEN).contains(&msg_len) { - debug!( - "DoT: invalid message length {} from {}", - msg_len, remote_addr - ); + if msg_len > MAX_MSG_LEN { + debug!("DoT: oversized message {} from {}", msg_len, remote_addr); break; } let mut buffer = BytePacketBuffer::new(); - match tokio::time::timeout(IDLE_TIMEOUT, stream.read_exact(&mut buffer.buf[..msg_len])) - .await - { - Ok(Ok(_)) => {} - Ok(Err(_)) => break, - Err(_) => { - debug!("DoT: payload read timeout from {}", remote_addr); - break; - } - } + let Ok(Ok(_)) = + tokio::time::timeout(IDLE_TIMEOUT, stream.read_exact(&mut buffer.buf[..msg_len])).await + else { + break; + }; // Parse query up-front so we can echo its question section in SERVFAIL // responses when resolve_query fails. @@ -180,7 +165,8 @@ where Ok(q) => q, Err(e) => { warn!("{} | PARSE ERROR | {}", remote_addr, e); - // msg_len >= MIN_MSG_LEN guarantees buf[0..2] is the client's query id. + // BytePacketBuffer is zero-initialized, so buf[0..2] reads as 0x0000 + // for sub-2-byte messages โ€” harmless FORMERR with id=0. let query_id = u16::from_be_bytes([buffer.buf[0], buffer.buf[1]]); let mut resp = DnsPacket::new(); resp.header.id = query_id; @@ -431,7 +417,6 @@ mod tests { let (addr, client_config) = spawn_dot_server().await; let mut stream = dot_connect(addr, &client_config).await; - // Send 3 queries on the same TLS connection for i in 0..3u16 { let query = DnsPacket::query(0xA000 + i, "dot-test.example", QueryType::A); let resp = dot_exchange(&mut stream, &query).await; @@ -479,21 +464,4 @@ mod tests { h.await.unwrap(); } } - - #[tokio::test] - async fn dot_localhost_resolution() { - let (addr, client_config) = spawn_dot_server().await; - let mut stream = dot_connect(addr, &client_config).await; - - let query = DnsPacket::query(0xD000, "localhost", QueryType::A); - let resp = dot_exchange(&mut stream, &query).await; - - assert_eq!(resp.header.id, 0xD000); - assert_eq!(resp.header.rescode, ResultCode::NOERROR); - assert_eq!(resp.answers.len(), 1); - match &resp.answers[0] { - DnsRecord::A { addr, .. } => assert_eq!(*addr, std::net::Ipv4Addr::LOCALHOST), - other => panic!("expected A record, got {:?}", other), - } - } }