feat: add DNS-over-TLS (DoT) listener #25

Merged
razvandimescu merged 19 commits from feat/dns-over-tls into main 2026-04-08 07:53:43 +08:00
Showing only changes of commit ee310fe7cb - Show all commits

View File

@@ -1,4 +1,4 @@
use std::net::SocketAddr; use std::net::{IpAddr, SocketAddr};
use std::path::Path; use std::path::Path;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
@@ -13,9 +13,14 @@ use tokio_rustls::TlsAcceptor;
use crate::buffer::BytePacketBuffer; use crate::buffer::BytePacketBuffer;
use crate::config::DotConfig; use crate::config::DotConfig;
use crate::ctx::{resolve_query, ServerCtx}; use crate::ctx::{resolve_query, ServerCtx};
use crate::header::ResultCode;
use crate::packet::DnsPacket;
const MAX_CONNECTIONS: usize = 512; const MAX_CONNECTIONS: usize = 512;
const IDLE_TIMEOUT: Duration = Duration::from_secs(30); const IDLE_TIMEOUT: Duration = Duration::from_secs(30);
// Matches BytePacketBuffer::BUF_SIZE — RFC 7858 allows up to 65535 but our
// buffer would silently truncate anything larger.
const MAX_MSG_LEN: usize = 4096;
/// Build a TLS ServerConfig for DoT from user-provided cert/key PEM files. /// Build a TLS ServerConfig for DoT from user-provided cert/key PEM files.
fn load_tls_config(cert_path: &Path, key_path: &Path) -> crate::Result<Arc<ServerConfig>> { fn load_tls_config(cert_path: &Path, key_path: &Path) -> crate::Result<Arc<ServerConfig>> {
@@ -26,8 +31,6 @@ fn load_tls_config(cert_path: &Path, key_path: &Path) -> crate::Result<Arc<Serve
let key = rustls_pemfile::private_key(&mut &key_pem[..])? let key = rustls_pemfile::private_key(&mut &key_pem[..])?
.ok_or("no private key found in key file")?; .ok_or("no private key found in key file")?;
let _ = rustls::crypto::ring::default_provider().install_default();
let config = ServerConfig::builder() let config = ServerConfig::builder()
.with_no_client_auth() .with_no_client_auth()
.with_single_cert(certs, key)?; .with_single_cert(certs, key)?;
@@ -35,6 +38,22 @@ fn load_tls_config(cert_path: &Path, key_path: &Path) -> crate::Result<Arc<Serve
Ok(Arc::new(config)) Ok(Arc::new(config))
} }
fn fallback_tls(ctx: &ServerCtx) -> Option<Arc<ServerConfig>> {
if let Some(arc_swap) = ctx.tls_config.as_ref() {
return Some(Arc::clone(&*arc_swap.load()));
}
match crate::tls::build_tls_config(&ctx.proxy_tld, &[]) {
Ok(cfg) => Some(cfg),
Err(e) => {
warn!(
"DoT: failed to generate self-signed TLS: {} — DoT disabled",
e
);
None
}
}
}
/// Start the DNS-over-TLS listener (RFC 7858). /// Start the DNS-over-TLS listener (RFC 7858).
pub async fn start_dot(ctx: Arc<ServerCtx>, config: &DotConfig) { pub async fn start_dot(ctx: Arc<ServerCtx>, config: &DotConfig) {
let tls_config = match (&config.cert_path, &config.key_path) { let tls_config = match (&config.cert_path, &config.key_path) {
@@ -45,26 +64,24 @@ pub async fn start_dot(ctx: Arc<ServerCtx>, config: &DotConfig) {
return; return;
} }
}, },
_ => match ctx.tls_config.as_ref() { (Some(_), None) | (None, Some(_)) => {
Some(arc_swap) => Arc::clone(&*arc_swap.load()), warn!("DoT: both cert_path and key_path must be set — ignoring partial config, using self-signed");
None => match crate::tls::build_tls_config(&ctx.proxy_tld, &[]) { match fallback_tls(&ctx) {
Ok(cfg) => cfg, Some(cfg) => cfg,
Err(e) => { None => return,
warn!(
"DoT: failed to generate self-signed TLS: {} — DoT disabled",
e
);
return;
} }
}, }
(None, None) => match fallback_tls(&ctx) {
Some(cfg) => cfg,
None => return,
}, },
}; };
let bind_addr: std::net::Ipv4Addr = config let bind_addr: IpAddr = config
.bind_addr .bind_addr
.parse() .parse()
.unwrap_or(std::net::Ipv4Addr::UNSPECIFIED); .unwrap_or(IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED));
let addr: SocketAddr = (bind_addr, config.port).into(); let addr = SocketAddr::new(bind_addr, config.port);
let listener = match TcpListener::bind(addr).await { let listener = match TcpListener::bind(addr).await {
Ok(l) => l, Ok(l) => l,
Err(e) => { Err(e) => {
@@ -99,7 +116,7 @@ pub async fn start_dot(ctx: Arc<ServerCtx>, config: &DotConfig) {
tokio::spawn(async move { tokio::spawn(async move {
let _permit = permit; // held until task exits let _permit = permit; // held until task exits
let mut tls_stream = match acceptor.accept(tcp_stream).await { let tls_stream = match acceptor.accept(tcp_stream).await {
Ok(s) => s, Ok(s) => s,
Err(e) => { Err(e) => {
debug!("DoT: TLS handshake failed from {}: {}", remote_addr, e); debug!("DoT: TLS handshake failed from {}: {}", remote_addr, e);
@@ -107,18 +124,27 @@ pub async fn start_dot(ctx: Arc<ServerCtx>, config: &DotConfig) {
} }
}; };
// RFC 7858: connection is persistent — read queries until EOF or idle timeout handle_dot_connection(tls_stream, remote_addr, &ctx).await;
});
}
}
/// Handle a single persistent DoT connection (RFC 7858).
/// Reads length-prefixed DNS queries until EOF, idle timeout, or error.
async fn handle_dot_connection<S>(mut stream: S, remote_addr: SocketAddr, ctx: &ServerCtx)
where
S: AsyncReadExt + AsyncWriteExt + Unpin,
{
loop { loop {
// Read 2-byte length prefix (RFC 1035 §4.2.2) with idle timeout // Read 2-byte length prefix (RFC 1035 §4.2.2) with idle timeout
let mut len_buf = [0u8; 2]; let mut len_buf = [0u8; 2];
match tokio::time::timeout(IDLE_TIMEOUT, tls_stream.read_exact(&mut len_buf)).await match tokio::time::timeout(IDLE_TIMEOUT, stream.read_exact(&mut len_buf)).await {
{
Ok(Ok(_)) => {} Ok(Ok(_)) => {}
Ok(Err(_)) => break, // read error or EOF Ok(Err(_)) => break, // read error or EOF
Err(_) => break, // idle timeout Err(_) => break, // idle timeout
} }
let msg_len = u16::from_be_bytes(len_buf) as usize; let msg_len = u16::from_be_bytes(len_buf) as usize;
if msg_len == 0 || msg_len > 4096 { if msg_len == 0 || msg_len > MAX_MSG_LEN {
debug!( debug!(
"DoT: invalid message length {} from {}", "DoT: invalid message length {} from {}",
msg_len, remote_addr msg_len, remote_addr
@@ -127,28 +153,43 @@ pub async fn start_dot(ctx: Arc<ServerCtx>, config: &DotConfig) {
} }
let mut data = vec![0u8; msg_len]; let mut data = vec![0u8; msg_len];
if tls_stream.read_exact(&mut data).await.is_err() { if stream.read_exact(&mut data).await.is_err() {
break; break;
} }
// Extract query ID before resolve_query consumes the buffer
let query_id = data
.get(..2)
.map(|b| u16::from_be_bytes([b[0], b[1]]))
.unwrap_or(0);
let buffer = BytePacketBuffer::from_bytes(&data); let buffer = BytePacketBuffer::from_bytes(&data);
match resolve_query(buffer, remote_addr, &ctx).await { let resp_buffer = match resolve_query(buffer, remote_addr, ctx).await {
Ok(resp_buffer) => { Ok(buf) => buf,
Err(e) => {
debug!("DoT: resolve error from {}: {}", remote_addr, e);
// Send SERVFAIL so the client doesn't hang
let mut resp = DnsPacket::new();
resp.header.id = query_id;
resp.header.response = true;
resp.header.rescode = ResultCode::SERVFAIL;
let mut buf = BytePacketBuffer::new();
if resp.write(&mut buf).is_err() {
break;
}
buf
}
};
let resp = resp_buffer.filled(); let resp = resp_buffer.filled();
// Coalesce length prefix + response into a single TLS write
let mut out = Vec::with_capacity(2 + resp.len()); let mut out = Vec::with_capacity(2 + resp.len());
out.extend_from_slice(&(resp.len() as u16).to_be_bytes()); out.extend_from_slice(&(resp.len() as u16).to_be_bytes());
out.extend_from_slice(resp); out.extend_from_slice(resp);
if tls_stream.write_all(&out).await.is_err() { if stream.write_all(&out).await.is_err() {
break; break;
} }
if stream.flush().await.is_err() {
break;
} }
Err(e) => {
debug!("DoT: resolve error from {}: {}", remote_addr, e);
}
}
}
});
} }
} }
@@ -270,43 +311,11 @@ mod tests {
let ctx = Arc::clone(&ctx); let ctx = Arc::clone(&ctx);
tokio::spawn(async move { tokio::spawn(async move {
let _permit = permit; let _permit = permit;
let mut tls_stream = match acceptor.accept(tcp_stream).await { let tls_stream = match acceptor.accept(tcp_stream).await {
Ok(s) => s, Ok(s) => s,
Err(_) => return, Err(_) => return,
}; };
loop { handle_dot_connection(tls_stream, remote_addr, &ctx).await;
let mut len_buf = [0u8; 2];
match tokio::time::timeout(
IDLE_TIMEOUT,
tls_stream.read_exact(&mut len_buf),
)
.await
{
Ok(Ok(_)) => {}
_ => break,
}
let msg_len = u16::from_be_bytes(len_buf) as usize;
if msg_len == 0 || msg_len > 4096 {
break;
}
let mut data = vec![0u8; msg_len];
if tls_stream.read_exact(&mut data).await.is_err() {
break;
}
let buffer = BytePacketBuffer::from_bytes(&data);
match resolve_query(buffer, remote_addr, &ctx).await {
Ok(resp_buffer) => {
let resp = resp_buffer.filled();
let mut out = Vec::with_capacity(2 + resp.len());
out.extend_from_slice(&(resp.len() as u16).to_be_bytes());
out.extend_from_slice(resp);
if tls_stream.write_all(&out).await.is_err() {
break;
}
}
Err(_) => {}
}
}
}); });
} }
}); });