feat: add DNS-over-TLS (DoT) listener #25
60
src/dot.rs
60
src/dot.rs
@@ -19,12 +19,9 @@ use crate::packet::DnsPacket;
|
||||
const MAX_CONNECTIONS: usize = 512;
|
||||
const IDLE_TIMEOUT: Duration = Duration::from_secs(30);
|
||||
const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
const ACCEPT_ERROR_BACKOFF: Duration = Duration::from_millis(100);
|
||||
// Matches BytePacketBuffer::BUF_SIZE — RFC 7858 allows up to 65535 but our
|
||||
// buffer would silently truncate anything larger.
|
||||
const MAX_MSG_LEN: usize = 4096;
|
||||
// DNS header is 12 bytes; anything shorter cannot be a valid query.
|
||||
const MIN_MSG_LEN: usize = 12;
|
||||
|
||||
/// Build a TLS ServerConfig for DoT from user-provided cert/key PEM files.
|
||||
fn load_tls_config(cert_path: &Path, key_path: &Path) -> crate::Result<Arc<ServerConfig>> {
|
||||
@@ -60,9 +57,6 @@ fn fallback_tls(ctx: &ServerCtx) -> Option<Arc<ServerConfig>> {
|
||||
|
||||
/// Start the DNS-over-TLS listener (RFC 7858).
|
||||
pub async fn start_dot(ctx: Arc<ServerCtx>, config: &DotConfig) {
|
||||
if config.cert_path.is_some() != config.key_path.is_some() {
|
||||
warn!("DoT: both cert_path and key_path must be set — ignoring partial config, using self-signed");
|
||||
}
|
||||
let tls_config = match (&config.cert_path, &config.key_path) {
|
||||
(Some(cert), Some(key)) => match load_tls_config(cert, key) {
|
||||
Ok(cfg) => cfg,
|
||||
@@ -103,7 +97,7 @@ async fn accept_loop(listener: TcpListener, acceptor: TlsAcceptor, ctx: Arc<Serv
|
||||
Err(e) => {
|
||||
error!("DoT: TCP accept error: {}", e);
|
||||
// Back off to avoid tight-looping on persistent failures (e.g. fd exhaustion).
|
||||
tokio::time::sleep(ACCEPT_ERROR_BACKOFF).await;
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
@@ -148,31 +142,22 @@ where
|
||||
loop {
|
||||
// Read 2-byte length prefix (RFC 1035 §4.2.2) with idle timeout
|
||||
let mut len_buf = [0u8; 2];
|
||||
match tokio::time::timeout(IDLE_TIMEOUT, stream.read_exact(&mut len_buf)).await {
|
||||
Ok(Ok(_)) => {}
|
||||
Ok(Err(_)) => break, // read error or EOF
|
||||
Err(_) => break, // idle timeout
|
||||
}
|
||||
let Ok(Ok(_)) = tokio::time::timeout(IDLE_TIMEOUT, stream.read_exact(&mut len_buf)).await
|
||||
else {
|
||||
break;
|
||||
};
|
||||
let msg_len = u16::from_be_bytes(len_buf) as usize;
|
||||
if !(MIN_MSG_LEN..=MAX_MSG_LEN).contains(&msg_len) {
|
||||
debug!(
|
||||
"DoT: invalid message length {} from {}",
|
||||
msg_len, remote_addr
|
||||
);
|
||||
if msg_len > MAX_MSG_LEN {
|
||||
debug!("DoT: oversized message {} from {}", msg_len, remote_addr);
|
||||
break;
|
||||
}
|
||||
|
||||
let mut buffer = BytePacketBuffer::new();
|
||||
match tokio::time::timeout(IDLE_TIMEOUT, stream.read_exact(&mut buffer.buf[..msg_len]))
|
||||
.await
|
||||
{
|
||||
Ok(Ok(_)) => {}
|
||||
Ok(Err(_)) => break,
|
||||
Err(_) => {
|
||||
debug!("DoT: payload read timeout from {}", remote_addr);
|
||||
break;
|
||||
}
|
||||
}
|
||||
let Ok(Ok(_)) =
|
||||
tokio::time::timeout(IDLE_TIMEOUT, stream.read_exact(&mut buffer.buf[..msg_len])).await
|
||||
else {
|
||||
break;
|
||||
};
|
||||
|
||||
// Parse query up-front so we can echo its question section in SERVFAIL
|
||||
// responses when resolve_query fails.
|
||||
@@ -180,7 +165,8 @@ where
|
||||
Ok(q) => q,
|
||||
Err(e) => {
|
||||
warn!("{} | PARSE ERROR | {}", remote_addr, e);
|
||||
// msg_len >= MIN_MSG_LEN guarantees buf[0..2] is the client's query id.
|
||||
// BytePacketBuffer is zero-initialized, so buf[0..2] reads as 0x0000
|
||||
// for sub-2-byte messages — harmless FORMERR with id=0.
|
||||
let query_id = u16::from_be_bytes([buffer.buf[0], buffer.buf[1]]);
|
||||
let mut resp = DnsPacket::new();
|
||||
resp.header.id = query_id;
|
||||
@@ -431,7 +417,6 @@ mod tests {
|
||||
let (addr, client_config) = spawn_dot_server().await;
|
||||
let mut stream = dot_connect(addr, &client_config).await;
|
||||
|
||||
// Send 3 queries on the same TLS connection
|
||||
for i in 0..3u16 {
|
||||
let query = DnsPacket::query(0xA000 + i, "dot-test.example", QueryType::A);
|
||||
let resp = dot_exchange(&mut stream, &query).await;
|
||||
@@ -479,21 +464,4 @@ mod tests {
|
||||
h.await.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn dot_localhost_resolution() {
|
||||
let (addr, client_config) = spawn_dot_server().await;
|
||||
let mut stream = dot_connect(addr, &client_config).await;
|
||||
|
||||
let query = DnsPacket::query(0xD000, "localhost", QueryType::A);
|
||||
let resp = dot_exchange(&mut stream, &query).await;
|
||||
|
||||
assert_eq!(resp.header.id, 0xD000);
|
||||
assert_eq!(resp.header.rescode, ResultCode::NOERROR);
|
||||
assert_eq!(resp.answers.len(), 1);
|
||||
match &resp.answers[0] {
|
||||
DnsRecord::A { addr, .. } => assert_eq!(*addr, std::net::Ipv4Addr::LOCALHOST),
|
||||
other => panic!("expected A record, got {:?}", other),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user