feat: add DNS-over-TLS (DoT) listener #25

Merged
razvandimescu merged 19 commits from feat/dns-over-tls into main 2026-04-08 07:53:43 +08:00
Showing only changes of commit e9836f1be7 - Show all commits

View File

@@ -19,12 +19,9 @@ use crate::packet::DnsPacket;
const MAX_CONNECTIONS: usize = 512;
const IDLE_TIMEOUT: Duration = Duration::from_secs(30);
const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
const ACCEPT_ERROR_BACKOFF: Duration = Duration::from_millis(100);
// Matches BytePacketBuffer::BUF_SIZE — RFC 7858 allows up to 65535 but our
// buffer would silently truncate anything larger.
const MAX_MSG_LEN: usize = 4096;
// DNS header is 12 bytes; anything shorter cannot be a valid query.
const MIN_MSG_LEN: usize = 12;
/// Build a TLS ServerConfig for DoT from user-provided cert/key PEM files.
fn load_tls_config(cert_path: &Path, key_path: &Path) -> crate::Result<Arc<ServerConfig>> {
@@ -60,9 +57,6 @@ fn fallback_tls(ctx: &ServerCtx) -> Option<Arc<ServerConfig>> {
/// Start the DNS-over-TLS listener (RFC 7858).
pub async fn start_dot(ctx: Arc<ServerCtx>, config: &DotConfig) {
if config.cert_path.is_some() != config.key_path.is_some() {
warn!("DoT: both cert_path and key_path must be set — ignoring partial config, using self-signed");
}
let tls_config = match (&config.cert_path, &config.key_path) {
(Some(cert), Some(key)) => match load_tls_config(cert, key) {
Ok(cfg) => cfg,
@@ -103,7 +97,7 @@ async fn accept_loop(listener: TcpListener, acceptor: TlsAcceptor, ctx: Arc<Serv
Err(e) => {
error!("DoT: TCP accept error: {}", e);
// Back off to avoid tight-looping on persistent failures (e.g. fd exhaustion).
tokio::time::sleep(ACCEPT_ERROR_BACKOFF).await;
tokio::time::sleep(Duration::from_millis(100)).await;
continue;
}
};
@@ -148,31 +142,22 @@ where
loop {
// Read 2-byte length prefix (RFC 1035 §4.2.2) with idle timeout
let mut len_buf = [0u8; 2];
match tokio::time::timeout(IDLE_TIMEOUT, stream.read_exact(&mut len_buf)).await {
Ok(Ok(_)) => {}
Ok(Err(_)) => break, // read error or EOF
Err(_) => break, // idle timeout
}
let Ok(Ok(_)) = tokio::time::timeout(IDLE_TIMEOUT, stream.read_exact(&mut len_buf)).await
else {
break;
};
let msg_len = u16::from_be_bytes(len_buf) as usize;
if !(MIN_MSG_LEN..=MAX_MSG_LEN).contains(&msg_len) {
debug!(
"DoT: invalid message length {} from {}",
msg_len, remote_addr
);
if msg_len > MAX_MSG_LEN {
debug!("DoT: oversized message {} from {}", msg_len, remote_addr);
break;
}
let mut buffer = BytePacketBuffer::new();
match tokio::time::timeout(IDLE_TIMEOUT, stream.read_exact(&mut buffer.buf[..msg_len]))
.await
{
Ok(Ok(_)) => {}
Ok(Err(_)) => break,
Err(_) => {
debug!("DoT: payload read timeout from {}", remote_addr);
let Ok(Ok(_)) =
tokio::time::timeout(IDLE_TIMEOUT, stream.read_exact(&mut buffer.buf[..msg_len])).await
else {
break;
}
}
};
// Parse query up-front so we can echo its question section in SERVFAIL
// responses when resolve_query fails.
@@ -180,7 +165,8 @@ where
Ok(q) => q,
Err(e) => {
warn!("{} | PARSE ERROR | {}", remote_addr, e);
// msg_len >= MIN_MSG_LEN guarantees buf[0..2] is the client's query id.
// BytePacketBuffer is zero-initialized, so buf[0..2] reads as 0x0000
// for sub-2-byte messages — harmless FORMERR with id=0.
let query_id = u16::from_be_bytes([buffer.buf[0], buffer.buf[1]]);
let mut resp = DnsPacket::new();
resp.header.id = query_id;
@@ -431,7 +417,6 @@ mod tests {
let (addr, client_config) = spawn_dot_server().await;
let mut stream = dot_connect(addr, &client_config).await;
// Send 3 queries on the same TLS connection
for i in 0..3u16 {
let query = DnsPacket::query(0xA000 + i, "dot-test.example", QueryType::A);
let resp = dot_exchange(&mut stream, &query).await;
@@ -479,21 +464,4 @@ mod tests {
h.await.unwrap();
}
}
#[tokio::test]
async fn dot_localhost_resolution() {
let (addr, client_config) = spawn_dot_server().await;
let mut stream = dot_connect(addr, &client_config).await;
let query = DnsPacket::query(0xD000, "localhost", QueryType::A);
let resp = dot_exchange(&mut stream, &query).await;
assert_eq!(resp.header.id, 0xD000);
assert_eq!(resp.header.rescode, ResultCode::NOERROR);
assert_eq!(resp.answers.len(), 1);
match &resp.answers[0] {
DnsRecord::A { addr, .. } => assert_eq!(*addr, std::net::Ipv4Addr::LOCALHOST),
other => panic!("expected A record, got {:?}", other),
}
}
}