diff --git a/src/ctx.rs b/src/ctx.rs index d3d4eb0..17a4979 100644 --- a/src/ctx.rs +++ b/src/ctx.rs @@ -65,21 +65,15 @@ pub struct ServerCtx { /// Transport-agnostic DNS resolution. Runs the full pipeline (overrides, blocklist, /// cache, upstream, DNSSEC) and returns the serialized response in a buffer. /// Callers use `.filled()` to get the response bytes without heap allocation. +/// Callers are responsible for parsing the incoming buffer into a `DnsPacket` +/// (and logging parse errors) before calling this function. pub async fn resolve_query( - mut buffer: BytePacketBuffer, + query: DnsPacket, src_addr: SocketAddr, ctx: &ServerCtx, ) -> crate::Result { let start = Instant::now(); - let query = match DnsPacket::from_buffer(&mut buffer) { - Ok(packet) => packet, - Err(e) => { - warn!("{} | PARSE ERROR | {}", src_addr, e); - return Err(e); - } - }; - let (qname, qtype) = match query.questions.first() { Some(q) => (q.name.clone(), q.qtype), None => return Err("empty question section".into()), @@ -347,11 +341,18 @@ pub async fn resolve_query( /// Handle a DNS query received over UDP. Thin wrapper around resolve_query. pub async fn handle_query( - buffer: BytePacketBuffer, + mut buffer: BytePacketBuffer, src_addr: SocketAddr, ctx: &ServerCtx, ) -> crate::Result<()> { - match resolve_query(buffer, src_addr, ctx).await { + let query = match DnsPacket::from_buffer(&mut buffer) { + Ok(packet) => packet, + Err(e) => { + warn!("{} | PARSE ERROR | {}", src_addr, e); + return Ok(()); + } + }; + match resolve_query(query, src_addr, ctx).await { Ok(resp_buffer) => { ctx.socket.send_to(resp_buffer.filled(), src_addr).await?; } diff --git a/src/dot.rs b/src/dot.rs index d9c1180..2178c26 100644 --- a/src/dot.rs +++ b/src/dot.rs @@ -19,9 +19,12 @@ use crate::packet::DnsPacket; const MAX_CONNECTIONS: usize = 512; const IDLE_TIMEOUT: Duration = Duration::from_secs(30); const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10); +const ACCEPT_ERROR_BACKOFF: Duration = Duration::from_millis(100); // Matches BytePacketBuffer::BUF_SIZE — RFC 7858 allows up to 65535 but our // buffer would silently truncate anything larger. const MAX_MSG_LEN: usize = 4096; +// DNS header is 12 bytes; anything shorter cannot be a valid query. +const MIN_MSG_LEN: usize = 12; /// Build a TLS ServerConfig for DoT from user-provided cert/key PEM files. fn load_tls_config(cert_path: &Path, key_path: &Path) -> crate::Result> { @@ -103,6 +106,8 @@ async fn accept_loop(listener: TcpListener, acceptor: TlsAcceptor, ctx: Arc conn, Err(e) => { error!("DoT: TCP accept error: {}", e); + // Back off to avoid tight-looping on persistent failures (e.g. fd exhaustion). + tokio::time::sleep(ACCEPT_ERROR_BACKOFF).await; continue; } }; @@ -153,7 +158,7 @@ where Err(_) => break, // idle timeout } let msg_len = u16::from_be_bytes(len_buf) as usize; - if msg_len == 0 || msg_len > MAX_MSG_LEN { + if !(MIN_MSG_LEN..=MAX_MSG_LEN).contains(&msg_len) { debug!( "DoT: invalid message length {} from {}", msg_len, remote_addr @@ -173,37 +178,66 @@ where } } - let query_id = u16::from_be_bytes([buffer.buf[0], buffer.buf[1]]); - let resp_buffer = match resolve_query(buffer, remote_addr, ctx).await { - Ok(buf) => buf, + // Parse query up-front so we can echo its question section in SERVFAIL + // responses when resolve_query fails. + let query = match DnsPacket::from_buffer(&mut buffer) { + Ok(q) => q, Err(e) => { - debug!("DoT: resolve error from {}: {}", remote_addr, e); - // Send SERVFAIL so the client doesn't hang + warn!("{} | PARSE ERROR | {}", remote_addr, e); + // msg_len >= MIN_MSG_LEN guarantees buf[0..2] is the client's query id. + let query_id = u16::from_be_bytes([buffer.buf[0], buffer.buf[1]]); let mut resp = DnsPacket::new(); resp.header.id = query_id; resp.header.response = true; - resp.header.rescode = ResultCode::SERVFAIL; - let mut buf = BytePacketBuffer::new(); - if resp.write(&mut buf).is_err() { + resp.header.rescode = ResultCode::FORMERR; + let mut out_buf = BytePacketBuffer::new(); + if resp.write(&mut out_buf).is_err() { + debug!("DoT: failed to serialize FORMERR for {}", remote_addr); + break; + } + if write_framed(&mut stream, out_buf.filled()).await.is_err() { + break; + } + continue; + } + }; + + let resp_buffer = match resolve_query(query.clone(), remote_addr, ctx).await { + Ok(buf) => buf, + Err(e) => { + warn!("{} | RESOLVE ERROR | {}", remote_addr, e); + // Build SERVFAIL that echoes the original question section. + let resp = DnsPacket::response_from(&query, ResultCode::SERVFAIL); + let mut out_buf = BytePacketBuffer::new(); + if resp.write(&mut out_buf).is_err() { debug!("DoT: failed to serialize SERVFAIL for {}", remote_addr); break; } - buf + out_buf } }; - let resp = resp_buffer.filled(); - let mut out = Vec::with_capacity(2 + resp.len()); - out.extend_from_slice(&(resp.len() as u16).to_be_bytes()); - out.extend_from_slice(resp); - if stream.write_all(&out).await.is_err() { - break; - } - if stream.flush().await.is_err() { + if write_framed(&mut stream, resp_buffer.filled()) + .await + .is_err() + { break; } } } +/// Write a DNS message with its 2-byte length prefix, coalesced into one syscall. +async fn write_framed(stream: &mut S, msg: &[u8]) -> std::io::Result<()> +where + S: AsyncWriteExt + Unpin, +{ + let mut out = Vec::with_capacity(2 + msg.len()); + out.extend_from_slice(&(msg.len() as u16).to_be_bytes()); + out.extend_from_slice(msg); + stream.write_all(&out).await?; + stream.flush().await?; + Ok(()) +} + #[cfg(test)] mod tests { use super::*; @@ -250,10 +284,16 @@ mod tests { } /// Spin up a DoT listener with a test TLS config. Returns (addr, client_config). + /// The upstream is pointed at a bound-but-unresponsive UDP socket we own, so + /// any query that escapes to the upstream path times out deterministically + /// (SERVFAIL) regardless of what the host has running on port 53. async fn spawn_dot_server() -> (SocketAddr, Arc) { let (server_tls, client_tls) = test_tls_configs(); let socket = tokio::net::UdpSocket::bind("127.0.0.1:0").await.unwrap(); + // Bind an unresponsive upstream and leak it so it lives for the test duration. + let blackhole = Box::leak(Box::new(std::net::UdpSocket::bind("127.0.0.1:0").unwrap())); + let upstream_addr = blackhole.local_addr().unwrap(); let ctx = Arc::new(ServerCtx { socket, zone_map: { @@ -278,13 +318,11 @@ mod tests { services: Mutex::new(crate::service_store::ServiceStore::new()), lan_peers: Mutex::new(crate::lan::PeerStore::new(90)), forwarding_rules: Vec::new(), - upstream: Mutex::new(crate::forward::Upstream::Udp( - "127.0.0.1:53".parse().unwrap(), - )), + upstream: Mutex::new(crate::forward::Upstream::Udp(upstream_addr)), upstream_auto: false, upstream_port: 53, lan_ip: Mutex::new(std::net::Ipv4Addr::LOCALHOST), - timeout: Duration::from_secs(3), + timeout: Duration::from_millis(200), proxy_tld: "numa".to_string(), proxy_tld_suffix: ".numa".to_string(), lan_enabled: false, @@ -397,8 +435,11 @@ mod tests { assert_eq!(resp.header.id, 0xBEEF); assert!(resp.header.response); - // Query goes to upstream (127.0.0.1:53), which will fail — expect SERVFAIL + // Query goes to the blackhole upstream which never replies → SERVFAIL. + // The SERVFAIL response echoes the question section. assert_eq!(resp.header.rescode, ResultCode::SERVFAIL); + assert_eq!(resp.questions.len(), 1); + assert_eq!(resp.questions[0].name, "nonexistent.test"); } #[tokio::test]