feat: add DNS-over-TLS (DoT) listener #25
197
src/dot.rs
197
src/dot.rs
@@ -1,4 +1,4 @@
|
|||||||
use std::net::SocketAddr;
|
use std::net::{IpAddr, SocketAddr};
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
@@ -13,9 +13,14 @@ use tokio_rustls::TlsAcceptor;
|
|||||||
use crate::buffer::BytePacketBuffer;
|
use crate::buffer::BytePacketBuffer;
|
||||||
use crate::config::DotConfig;
|
use crate::config::DotConfig;
|
||||||
use crate::ctx::{resolve_query, ServerCtx};
|
use crate::ctx::{resolve_query, ServerCtx};
|
||||||
|
use crate::header::ResultCode;
|
||||||
|
use crate::packet::DnsPacket;
|
||||||
|
|
||||||
const MAX_CONNECTIONS: usize = 512;
|
const MAX_CONNECTIONS: usize = 512;
|
||||||
const IDLE_TIMEOUT: Duration = Duration::from_secs(30);
|
const IDLE_TIMEOUT: Duration = Duration::from_secs(30);
|
||||||
|
// Matches BytePacketBuffer::BUF_SIZE — RFC 7858 allows up to 65535 but our
|
||||||
|
// buffer would silently truncate anything larger.
|
||||||
|
const MAX_MSG_LEN: usize = 4096;
|
||||||
|
|
||||||
/// Build a TLS ServerConfig for DoT from user-provided cert/key PEM files.
|
/// Build a TLS ServerConfig for DoT from user-provided cert/key PEM files.
|
||||||
fn load_tls_config(cert_path: &Path, key_path: &Path) -> crate::Result<Arc<ServerConfig>> {
|
fn load_tls_config(cert_path: &Path, key_path: &Path) -> crate::Result<Arc<ServerConfig>> {
|
||||||
@@ -26,8 +31,6 @@ fn load_tls_config(cert_path: &Path, key_path: &Path) -> crate::Result<Arc<Serve
|
|||||||
let key = rustls_pemfile::private_key(&mut &key_pem[..])?
|
let key = rustls_pemfile::private_key(&mut &key_pem[..])?
|
||||||
.ok_or("no private key found in key file")?;
|
.ok_or("no private key found in key file")?;
|
||||||
|
|
||||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
|
||||||
|
|
||||||
let config = ServerConfig::builder()
|
let config = ServerConfig::builder()
|
||||||
.with_no_client_auth()
|
.with_no_client_auth()
|
||||||
.with_single_cert(certs, key)?;
|
.with_single_cert(certs, key)?;
|
||||||
@@ -35,6 +38,22 @@ fn load_tls_config(cert_path: &Path, key_path: &Path) -> crate::Result<Arc<Serve
|
|||||||
Ok(Arc::new(config))
|
Ok(Arc::new(config))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn fallback_tls(ctx: &ServerCtx) -> Option<Arc<ServerConfig>> {
|
||||||
|
if let Some(arc_swap) = ctx.tls_config.as_ref() {
|
||||||
|
return Some(Arc::clone(&*arc_swap.load()));
|
||||||
|
}
|
||||||
|
match crate::tls::build_tls_config(&ctx.proxy_tld, &[]) {
|
||||||
|
Ok(cfg) => Some(cfg),
|
||||||
|
Err(e) => {
|
||||||
|
warn!(
|
||||||
|
"DoT: failed to generate self-signed TLS: {} — DoT disabled",
|
||||||
|
e
|
||||||
|
);
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Start the DNS-over-TLS listener (RFC 7858).
|
/// Start the DNS-over-TLS listener (RFC 7858).
|
||||||
pub async fn start_dot(ctx: Arc<ServerCtx>, config: &DotConfig) {
|
pub async fn start_dot(ctx: Arc<ServerCtx>, config: &DotConfig) {
|
||||||
let tls_config = match (&config.cert_path, &config.key_path) {
|
let tls_config = match (&config.cert_path, &config.key_path) {
|
||||||
@@ -45,26 +64,24 @@ pub async fn start_dot(ctx: Arc<ServerCtx>, config: &DotConfig) {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
_ => match ctx.tls_config.as_ref() {
|
(Some(_), None) | (None, Some(_)) => {
|
||||||
Some(arc_swap) => Arc::clone(&*arc_swap.load()),
|
warn!("DoT: both cert_path and key_path must be set — ignoring partial config, using self-signed");
|
||||||
None => match crate::tls::build_tls_config(&ctx.proxy_tld, &[]) {
|
match fallback_tls(&ctx) {
|
||||||
Ok(cfg) => cfg,
|
Some(cfg) => cfg,
|
||||||
Err(e) => {
|
None => return,
|
||||||
warn!(
|
}
|
||||||
"DoT: failed to generate self-signed TLS: {} — DoT disabled",
|
}
|
||||||
e
|
(None, None) => match fallback_tls(&ctx) {
|
||||||
);
|
Some(cfg) => cfg,
|
||||||
return;
|
None => return,
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
let bind_addr: std::net::Ipv4Addr = config
|
let bind_addr: IpAddr = config
|
||||||
.bind_addr
|
.bind_addr
|
||||||
.parse()
|
.parse()
|
||||||
.unwrap_or(std::net::Ipv4Addr::UNSPECIFIED);
|
.unwrap_or(IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED));
|
||||||
let addr: SocketAddr = (bind_addr, config.port).into();
|
let addr = SocketAddr::new(bind_addr, config.port);
|
||||||
let listener = match TcpListener::bind(addr).await {
|
let listener = match TcpListener::bind(addr).await {
|
||||||
Ok(l) => l,
|
Ok(l) => l,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
@@ -99,7 +116,7 @@ pub async fn start_dot(ctx: Arc<ServerCtx>, config: &DotConfig) {
|
|||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let _permit = permit; // held until task exits
|
let _permit = permit; // held until task exits
|
||||||
|
|
||||||
let mut tls_stream = match acceptor.accept(tcp_stream).await {
|
let tls_stream = match acceptor.accept(tcp_stream).await {
|
||||||
Ok(s) => s,
|
Ok(s) => s,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
debug!("DoT: TLS handshake failed from {}: {}", remote_addr, e);
|
debug!("DoT: TLS handshake failed from {}: {}", remote_addr, e);
|
||||||
@@ -107,51 +124,75 @@ pub async fn start_dot(ctx: Arc<ServerCtx>, config: &DotConfig) {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// RFC 7858: connection is persistent — read queries until EOF or idle timeout
|
handle_dot_connection(tls_stream, remote_addr, &ctx).await;
|
||||||
loop {
|
|
||||||
// Read 2-byte length prefix (RFC 1035 §4.2.2) with idle timeout
|
|
||||||
let mut len_buf = [0u8; 2];
|
|
||||||
match tokio::time::timeout(IDLE_TIMEOUT, tls_stream.read_exact(&mut len_buf)).await
|
|
||||||
{
|
|
||||||
Ok(Ok(_)) => {}
|
|
||||||
Ok(Err(_)) => break, // read error or EOF
|
|
||||||
Err(_) => break, // idle timeout
|
|
||||||
}
|
|
||||||
let msg_len = u16::from_be_bytes(len_buf) as usize;
|
|
||||||
if msg_len == 0 || msg_len > 4096 {
|
|
||||||
debug!(
|
|
||||||
"DoT: invalid message length {} from {}",
|
|
||||||
msg_len, remote_addr
|
|
||||||
);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut data = vec![0u8; msg_len];
|
|
||||||
if tls_stream.read_exact(&mut data).await.is_err() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
let buffer = BytePacketBuffer::from_bytes(&data);
|
|
||||||
match resolve_query(buffer, remote_addr, &ctx).await {
|
|
||||||
Ok(resp_buffer) => {
|
|
||||||
let resp = resp_buffer.filled();
|
|
||||||
// Coalesce length prefix + response into a single TLS write
|
|
||||||
let mut out = Vec::with_capacity(2 + resp.len());
|
|
||||||
out.extend_from_slice(&(resp.len() as u16).to_be_bytes());
|
|
||||||
out.extend_from_slice(resp);
|
|
||||||
if tls_stream.write_all(&out).await.is_err() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
debug!("DoT: resolve error from {}: {}", remote_addr, e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Handle a single persistent DoT connection (RFC 7858).
|
||||||
|
/// Reads length-prefixed DNS queries until EOF, idle timeout, or error.
|
||||||
|
async fn handle_dot_connection<S>(mut stream: S, remote_addr: SocketAddr, ctx: &ServerCtx)
|
||||||
|
where
|
||||||
|
S: AsyncReadExt + AsyncWriteExt + Unpin,
|
||||||
|
{
|
||||||
|
loop {
|
||||||
|
// Read 2-byte length prefix (RFC 1035 §4.2.2) with idle timeout
|
||||||
|
let mut len_buf = [0u8; 2];
|
||||||
|
match tokio::time::timeout(IDLE_TIMEOUT, stream.read_exact(&mut len_buf)).await {
|
||||||
|
Ok(Ok(_)) => {}
|
||||||
|
Ok(Err(_)) => break, // read error or EOF
|
||||||
|
Err(_) => break, // idle timeout
|
||||||
|
}
|
||||||
|
let msg_len = u16::from_be_bytes(len_buf) as usize;
|
||||||
|
if msg_len == 0 || msg_len > MAX_MSG_LEN {
|
||||||
|
debug!(
|
||||||
|
"DoT: invalid message length {} from {}",
|
||||||
|
msg_len, remote_addr
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut data = vec![0u8; msg_len];
|
||||||
|
if stream.read_exact(&mut data).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract query ID before resolve_query consumes the buffer
|
||||||
|
let query_id = data
|
||||||
|
.get(..2)
|
||||||
|
.map(|b| u16::from_be_bytes([b[0], b[1]]))
|
||||||
|
.unwrap_or(0);
|
||||||
|
|
||||||
|
let buffer = BytePacketBuffer::from_bytes(&data);
|
||||||
|
let resp_buffer = match resolve_query(buffer, remote_addr, ctx).await {
|
||||||
|
Ok(buf) => buf,
|
||||||
|
Err(e) => {
|
||||||
|
debug!("DoT: resolve error from {}: {}", remote_addr, e);
|
||||||
|
// Send SERVFAIL so the client doesn't hang
|
||||||
|
let mut resp = DnsPacket::new();
|
||||||
|
resp.header.id = query_id;
|
||||||
|
resp.header.response = true;
|
||||||
|
resp.header.rescode = ResultCode::SERVFAIL;
|
||||||
|
let mut buf = BytePacketBuffer::new();
|
||||||
|
if resp.write(&mut buf).is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
buf
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let resp = resp_buffer.filled();
|
||||||
|
let mut out = Vec::with_capacity(2 + resp.len());
|
||||||
|
out.extend_from_slice(&(resp.len() as u16).to_be_bytes());
|
||||||
|
out.extend_from_slice(resp);
|
||||||
|
if stream.write_all(&out).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if stream.flush().await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
@@ -270,43 +311,11 @@ mod tests {
|
|||||||
let ctx = Arc::clone(&ctx);
|
let ctx = Arc::clone(&ctx);
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let _permit = permit;
|
let _permit = permit;
|
||||||
let mut tls_stream = match acceptor.accept(tcp_stream).await {
|
let tls_stream = match acceptor.accept(tcp_stream).await {
|
||||||
Ok(s) => s,
|
Ok(s) => s,
|
||||||
Err(_) => return,
|
Err(_) => return,
|
||||||
};
|
};
|
||||||
loop {
|
handle_dot_connection(tls_stream, remote_addr, &ctx).await;
|
||||||
let mut len_buf = [0u8; 2];
|
|
||||||
match tokio::time::timeout(
|
|
||||||
IDLE_TIMEOUT,
|
|
||||||
tls_stream.read_exact(&mut len_buf),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(Ok(_)) => {}
|
|
||||||
_ => break,
|
|
||||||
}
|
|
||||||
let msg_len = u16::from_be_bytes(len_buf) as usize;
|
|
||||||
if msg_len == 0 || msg_len > 4096 {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
let mut data = vec![0u8; msg_len];
|
|
||||||
if tls_stream.read_exact(&mut data).await.is_err() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
let buffer = BytePacketBuffer::from_bytes(&data);
|
|
||||||
match resolve_query(buffer, remote_addr, &ctx).await {
|
|
||||||
Ok(resp_buffer) => {
|
|
||||||
let resp = resp_buffer.filled();
|
|
||||||
let mut out = Vec::with_capacity(2 + resp.len());
|
|
||||||
out.extend_from_slice(&(resp.len() as u16).to_be_bytes());
|
|
||||||
out.extend_from_slice(resp);
|
|
||||||
if tls_stream.write_all(&out).await.is_err() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(_) => {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|||||||
Reference in New Issue
Block a user