feat: add DNS-over-TLS (DoT) listener #25
10
Cargo.lock
generated
10
Cargo.lock
generated
@@ -1159,6 +1159,7 @@ dependencies = [
|
|||||||
"reqwest",
|
"reqwest",
|
||||||
"ring",
|
"ring",
|
||||||
"rustls",
|
"rustls",
|
||||||
|
"rustls-pemfile",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"socket2 0.5.10",
|
"socket2 0.5.10",
|
||||||
@@ -1546,6 +1547,15 @@ dependencies = [
|
|||||||
"zeroize",
|
"zeroize",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rustls-pemfile"
|
||||||
|
version = "2.2.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50"
|
||||||
|
dependencies = [
|
||||||
|
"rustls-pki-types",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rustls-pki-types"
|
name = "rustls-pki-types"
|
||||||
version = "1.14.0"
|
version = "1.14.0"
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ rustls = "0.23"
|
|||||||
tokio-rustls = "0.26"
|
tokio-rustls = "0.26"
|
||||||
arc-swap = "1"
|
arc-swap = "1"
|
||||||
ring = "0.17"
|
ring = "0.17"
|
||||||
|
rustls-pemfile = "2.2.0"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
criterion = { version = "0.5", features = ["html_reports"] }
|
criterion = { version = "0.5", features = ["html_reports"] }
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::net::Ipv4Addr;
|
use std::net::Ipv4Addr;
|
||||||
use std::net::Ipv6Addr;
|
use std::net::Ipv6Addr;
|
||||||
use std::path::Path;
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
|
|
||||||
@@ -29,6 +29,8 @@ pub struct Config {
|
|||||||
pub lan: LanConfig,
|
pub lan: LanConfig,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub dnssec: DnssecConfig,
|
pub dnssec: DnssecConfig,
|
||||||
|
#[serde(default)]
|
||||||
|
pub dot: DotConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
@@ -370,6 +372,41 @@ pub struct DnssecConfig {
|
|||||||
pub strict: bool,
|
pub strict: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Clone)]
|
||||||
|
pub struct DotConfig {
|
||||||
|
#[serde(default)]
|
||||||
|
pub enabled: bool,
|
||||||
|
#[serde(default = "default_dot_port")]
|
||||||
|
pub port: u16,
|
||||||
|
#[serde(default = "default_dot_bind_addr")]
|
||||||
|
pub bind_addr: String,
|
||||||
|
/// Path to TLS certificate (PEM). If None, uses self-signed CA.
|
||||||
|
#[serde(default)]
|
||||||
|
pub cert_path: Option<PathBuf>,
|
||||||
|
/// Path to TLS private key (PEM). If None, uses self-signed CA.
|
||||||
|
#[serde(default)]
|
||||||
|
pub key_path: Option<PathBuf>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for DotConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
DotConfig {
|
||||||
|
enabled: false,
|
||||||
|
port: default_dot_port(),
|
||||||
|
bind_addr: default_dot_bind_addr(),
|
||||||
|
cert_path: None,
|
||||||
|
key_path: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_dot_port() -> u16 {
|
||||||
|
853
|
||||||
|
}
|
||||||
|
fn default_dot_bind_addr() -> String {
|
||||||
|
"0.0.0.0".to_string()
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|||||||
38
src/ctx.rs
38
src/ctx.rs
@@ -62,24 +62,27 @@ pub struct ServerCtx {
|
|||||||
pub dnssec_strict: bool,
|
pub dnssec_strict: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn handle_query(
|
/// Transport-agnostic DNS resolution. Runs the full pipeline (overrides, blocklist,
|
||||||
|
/// cache, upstream, DNSSEC) and returns the serialized response in a buffer.
|
||||||
|
/// Callers use `.filled()` to get the response bytes without heap allocation.
|
||||||
|
pub async fn resolve_query(
|
||||||
mut buffer: BytePacketBuffer,
|
mut buffer: BytePacketBuffer,
|
||||||
src_addr: SocketAddr,
|
src_addr: SocketAddr,
|
||||||
ctx: &ServerCtx,
|
ctx: &ServerCtx,
|
||||||
) -> crate::Result<()> {
|
) -> crate::Result<BytePacketBuffer> {
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
|
|
||||||
let query = match DnsPacket::from_buffer(&mut buffer) {
|
let query = match DnsPacket::from_buffer(&mut buffer) {
|
||||||
Ok(packet) => packet,
|
Ok(packet) => packet,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
warn!("{} | PARSE ERROR | {}", src_addr, e);
|
warn!("{} | PARSE ERROR | {}", src_addr, e);
|
||||||
return Ok(());
|
return Err(e);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let (qname, qtype) = match query.questions.first() {
|
let (qname, qtype) = match query.questions.first() {
|
||||||
Some(q) => (q.name.clone(), q.qtype),
|
Some(q) => (q.name.clone(), q.qtype),
|
||||||
None => return Ok(()),
|
None => return Err("empty question section".into()),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Pipeline: overrides -> .tld interception -> blocklist -> local zones -> cache -> upstream
|
// Pipeline: overrides -> .tld interception -> blocklist -> local zones -> cache -> upstream
|
||||||
@@ -306,17 +309,15 @@ pub async fn handle_query(
|
|||||||
response.resources.len(),
|
response.resources.len(),
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Serialize response
|
||||||
let mut resp_buffer = BytePacketBuffer::new();
|
let mut resp_buffer = BytePacketBuffer::new();
|
||||||
if response.write(&mut resp_buffer).is_err() {
|
if response.write(&mut resp_buffer).is_err() {
|
||||||
// Response too large for UDP — set TC bit and send header + question only
|
// Response too large — set TC bit and send header + question only
|
||||||
debug!("response too large, setting TC bit for {}", qname);
|
debug!("response too large, setting TC bit for {}", qname);
|
||||||
let mut tc_response = DnsPacket::response_from(&query, response.header.rescode);
|
let mut tc_response = DnsPacket::response_from(&query, response.header.rescode);
|
||||||
tc_response.header.truncated_message = true;
|
tc_response.header.truncated_message = true;
|
||||||
let mut tc_buffer = BytePacketBuffer::new();
|
resp_buffer = BytePacketBuffer::new();
|
||||||
tc_response.write(&mut tc_buffer)?;
|
tc_response.write(&mut resp_buffer)?;
|
||||||
ctx.socket.send_to(tc_buffer.filled(), src_addr).await?;
|
|
||||||
} else {
|
|
||||||
ctx.socket.send_to(resp_buffer.filled(), src_addr).await?;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Record stats and query log
|
// Record stats and query log
|
||||||
@@ -339,6 +340,23 @@ pub async fn handle_query(
|
|||||||
dnssec,
|
dnssec,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
Ok(resp_buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Handle a DNS query received over UDP. Thin wrapper around resolve_query.
|
||||||
|
pub async fn handle_query(
|
||||||
|
buffer: BytePacketBuffer,
|
||||||
|
src_addr: SocketAddr,
|
||||||
|
ctx: &ServerCtx,
|
||||||
|
) -> crate::Result<()> {
|
||||||
|
match resolve_query(buffer, src_addr, ctx).await {
|
||||||
|
Ok(resp_buffer) => {
|
||||||
|
ctx.socket.send_to(resp_buffer.filled(), src_addr).await?;
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!("{} | RESOLVE ERROR | {}", src_addr, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
444
src/dot.rs
Normal file
444
src/dot.rs
Normal file
@@ -0,0 +1,444 @@
|
|||||||
|
use std::net::SocketAddr;
|
||||||
|
use std::path::Path;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use log::{debug, error, info, warn};
|
||||||
|
use rustls::ServerConfig;
|
||||||
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||||
|
use tokio::net::TcpListener;
|
||||||
|
use tokio::sync::Semaphore;
|
||||||
|
use tokio_rustls::TlsAcceptor;
|
||||||
|
|
||||||
|
use crate::buffer::BytePacketBuffer;
|
||||||
|
use crate::config::DotConfig;
|
||||||
|
use crate::ctx::{resolve_query, ServerCtx};
|
||||||
|
|
||||||
|
const MAX_CONNECTIONS: usize = 512;
|
||||||
|
const IDLE_TIMEOUT: Duration = Duration::from_secs(30);
|
||||||
|
|
||||||
|
/// Build a TLS ServerConfig for DoT from user-provided cert/key PEM files.
|
||||||
|
fn load_tls_config(cert_path: &Path, key_path: &Path) -> crate::Result<Arc<ServerConfig>> {
|
||||||
|
let cert_pem = std::fs::read(cert_path)?;
|
||||||
|
let key_pem = std::fs::read(key_path)?;
|
||||||
|
|
||||||
|
let certs: Vec<_> = rustls_pemfile::certs(&mut &cert_pem[..]).collect::<Result<_, _>>()?;
|
||||||
|
let key = rustls_pemfile::private_key(&mut &key_pem[..])?
|
||||||
|
.ok_or("no private key found in key file")?;
|
||||||
|
|
||||||
|
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||||
|
|
||||||
|
let config = ServerConfig::builder()
|
||||||
|
.with_no_client_auth()
|
||||||
|
.with_single_cert(certs, key)?;
|
||||||
|
|
||||||
|
Ok(Arc::new(config))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Start the DNS-over-TLS listener (RFC 7858).
|
||||||
|
pub async fn start_dot(ctx: Arc<ServerCtx>, config: &DotConfig) {
|
||||||
|
let tls_config = match (&config.cert_path, &config.key_path) {
|
||||||
|
(Some(cert), Some(key)) => match load_tls_config(cert, key) {
|
||||||
|
Ok(cfg) => cfg,
|
||||||
|
Err(e) => {
|
||||||
|
warn!("DoT: failed to load TLS cert/key: {} — DoT disabled", e);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
_ => match ctx.tls_config.as_ref() {
|
||||||
|
Some(arc_swap) => Arc::clone(&*arc_swap.load()),
|
||||||
|
None => match crate::tls::build_tls_config(&ctx.proxy_tld, &[]) {
|
||||||
|
Ok(cfg) => cfg,
|
||||||
|
Err(e) => {
|
||||||
|
warn!(
|
||||||
|
"DoT: failed to generate self-signed TLS: {} — DoT disabled",
|
||||||
|
e
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
let bind_addr: std::net::Ipv4Addr = config
|
||||||
|
.bind_addr
|
||||||
|
.parse()
|
||||||
|
.unwrap_or(std::net::Ipv4Addr::UNSPECIFIED);
|
||||||
|
let addr: SocketAddr = (bind_addr, config.port).into();
|
||||||
|
let listener = match TcpListener::bind(addr).await {
|
||||||
|
Ok(l) => l,
|
||||||
|
Err(e) => {
|
||||||
|
warn!("DoT: could not bind {} ({}) — DoT disabled", addr, e);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
info!("DoT listening on {}", addr);
|
||||||
|
|
||||||
|
let acceptor = TlsAcceptor::from(tls_config);
|
||||||
|
let semaphore = Arc::new(Semaphore::new(MAX_CONNECTIONS));
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let (tcp_stream, remote_addr) = match listener.accept().await {
|
||||||
|
Ok(conn) => conn,
|
||||||
|
Err(e) => {
|
||||||
|
error!("DoT: TCP accept error: {}", e);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let permit = match semaphore.clone().try_acquire_owned() {
|
||||||
|
Ok(p) => p,
|
||||||
|
Err(_) => {
|
||||||
|
debug!("DoT: connection limit reached, rejecting {}", remote_addr);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let acceptor = acceptor.clone();
|
||||||
|
let ctx = Arc::clone(&ctx);
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let _permit = permit; // held until task exits
|
||||||
|
|
||||||
|
let mut tls_stream = match acceptor.accept(tcp_stream).await {
|
||||||
|
Ok(s) => s,
|
||||||
|
Err(e) => {
|
||||||
|
debug!("DoT: TLS handshake failed from {}: {}", remote_addr, e);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// RFC 7858: connection is persistent — read queries until EOF or idle timeout
|
||||||
|
loop {
|
||||||
|
// Read 2-byte length prefix (RFC 1035 §4.2.2) with idle timeout
|
||||||
|
let mut len_buf = [0u8; 2];
|
||||||
|
match tokio::time::timeout(IDLE_TIMEOUT, tls_stream.read_exact(&mut len_buf)).await
|
||||||
|
{
|
||||||
|
Ok(Ok(_)) => {}
|
||||||
|
Ok(Err(_)) => break, // read error or EOF
|
||||||
|
Err(_) => break, // idle timeout
|
||||||
|
}
|
||||||
|
let msg_len = u16::from_be_bytes(len_buf) as usize;
|
||||||
|
if msg_len == 0 || msg_len > 4096 {
|
||||||
|
debug!(
|
||||||
|
"DoT: invalid message length {} from {}",
|
||||||
|
msg_len, remote_addr
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut data = vec![0u8; msg_len];
|
||||||
|
if tls_stream.read_exact(&mut data).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let buffer = BytePacketBuffer::from_bytes(&data);
|
||||||
|
match resolve_query(buffer, remote_addr, &ctx).await {
|
||||||
|
Ok(resp_buffer) => {
|
||||||
|
let resp = resp_buffer.filled();
|
||||||
|
// Coalesce length prefix + response into a single TLS write
|
||||||
|
let mut out = Vec::with_capacity(2 + resp.len());
|
||||||
|
out.extend_from_slice(&(resp.len() as u16).to_be_bytes());
|
||||||
|
out.extend_from_slice(resp);
|
||||||
|
if tls_stream.write_all(&out).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
debug!("DoT: resolve error from {}: {}", remote_addr, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::{Mutex, RwLock};
|
||||||
|
|
||||||
|
use rcgen::{CertificateParams, DnType, KeyPair};
|
||||||
|
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer, ServerName};
|
||||||
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||||
|
|
||||||
|
use crate::buffer::BytePacketBuffer;
|
||||||
|
use crate::header::ResultCode;
|
||||||
|
use crate::packet::DnsPacket;
|
||||||
|
use crate::question::QueryType;
|
||||||
|
use crate::record::DnsRecord;
|
||||||
|
|
||||||
|
/// Generate a self-signed cert + key in memory, return (ServerConfig, ClientConfig).
|
||||||
|
fn test_tls_configs() -> (Arc<ServerConfig>, Arc<rustls::ClientConfig>) {
|
||||||
|
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||||
|
|
||||||
|
let key_pair = KeyPair::generate().unwrap();
|
||||||
|
let mut params = CertificateParams::default();
|
||||||
|
params
|
||||||
|
.distinguished_name
|
||||||
|
.push(DnType::CommonName, "localhost");
|
||||||
|
params.subject_alt_names = vec![rcgen::SanType::DnsName("localhost".try_into().unwrap())];
|
||||||
|
let cert = params.self_signed(&key_pair).unwrap();
|
||||||
|
|
||||||
|
let cert_der = CertificateDer::from(cert.der().to_vec());
|
||||||
|
let key_der = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key_pair.serialize_der()));
|
||||||
|
|
||||||
|
let server_config = ServerConfig::builder()
|
||||||
|
.with_no_client_auth()
|
||||||
|
.with_single_cert(vec![cert_der.clone()], key_der)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let mut root_store = rustls::RootCertStore::empty();
|
||||||
|
root_store.add(cert_der).unwrap();
|
||||||
|
let client_config = rustls::ClientConfig::builder()
|
||||||
|
.with_root_certificates(root_store)
|
||||||
|
.with_no_client_auth();
|
||||||
|
|
||||||
|
(Arc::new(server_config), Arc::new(client_config))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Spin up a DoT listener with a test TLS config. Returns (addr, client_config).
|
||||||
|
async fn spawn_dot_server() -> (SocketAddr, Arc<rustls::ClientConfig>) {
|
||||||
|
let (server_tls, client_tls) = test_tls_configs();
|
||||||
|
|
||||||
|
let socket = tokio::net::UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let ctx = Arc::new(ServerCtx {
|
||||||
|
socket,
|
||||||
|
zone_map: {
|
||||||
|
let mut m = HashMap::new();
|
||||||
|
let mut inner = HashMap::new();
|
||||||
|
inner.insert(
|
||||||
|
QueryType::A,
|
||||||
|
vec![DnsRecord::A {
|
||||||
|
domain: "dot-test.example".to_string(),
|
||||||
|
addr: std::net::Ipv4Addr::new(10, 0, 0, 1),
|
||||||
|
ttl: 300,
|
||||||
|
}],
|
||||||
|
);
|
||||||
|
m.insert("dot-test.example".to_string(), inner);
|
||||||
|
m
|
||||||
|
},
|
||||||
|
cache: RwLock::new(crate::cache::DnsCache::new(100, 60, 86400)),
|
||||||
|
stats: Mutex::new(crate::stats::ServerStats::new()),
|
||||||
|
overrides: RwLock::new(crate::override_store::OverrideStore::new()),
|
||||||
|
blocklist: RwLock::new(crate::blocklist::BlocklistStore::new()),
|
||||||
|
query_log: Mutex::new(crate::query_log::QueryLog::new(100)),
|
||||||
|
services: Mutex::new(crate::service_store::ServiceStore::new()),
|
||||||
|
lan_peers: Mutex::new(crate::lan::PeerStore::new(90)),
|
||||||
|
forwarding_rules: Vec::new(),
|
||||||
|
upstream: Mutex::new(crate::forward::Upstream::Udp(
|
||||||
|
"127.0.0.1:53".parse().unwrap(),
|
||||||
|
)),
|
||||||
|
upstream_auto: false,
|
||||||
|
upstream_port: 53,
|
||||||
|
lan_ip: Mutex::new(std::net::Ipv4Addr::LOCALHOST),
|
||||||
|
timeout: Duration::from_secs(3),
|
||||||
|
proxy_tld: "numa".to_string(),
|
||||||
|
proxy_tld_suffix: ".numa".to_string(),
|
||||||
|
lan_enabled: false,
|
||||||
|
config_path: String::new(),
|
||||||
|
config_found: false,
|
||||||
|
config_dir: std::path::PathBuf::from("/tmp"),
|
||||||
|
data_dir: std::path::PathBuf::from("/tmp"),
|
||||||
|
tls_config: Some(arc_swap::ArcSwap::from(server_tls)),
|
||||||
|
upstream_mode: crate::config::UpstreamMode::Forward,
|
||||||
|
root_hints: Vec::new(),
|
||||||
|
srtt: RwLock::new(crate::srtt::SrttCache::new(true)),
|
||||||
|
inflight: Mutex::new(HashMap::new()),
|
||||||
|
dnssec_enabled: false,
|
||||||
|
dnssec_strict: false,
|
||||||
|
});
|
||||||
|
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
|
||||||
|
let tls_config = Arc::clone(&*ctx.tls_config.as_ref().unwrap().load());
|
||||||
|
let acceptor = TlsAcceptor::from(tls_config);
|
||||||
|
let semaphore = Arc::new(Semaphore::new(MAX_CONNECTIONS));
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
loop {
|
||||||
|
let (tcp_stream, remote_addr) = match listener.accept().await {
|
||||||
|
Ok(conn) => conn,
|
||||||
|
Err(_) => return,
|
||||||
|
};
|
||||||
|
let permit = match semaphore.clone().try_acquire_owned() {
|
||||||
|
Ok(p) => p,
|
||||||
|
Err(_) => continue,
|
||||||
|
};
|
||||||
|
let acceptor = acceptor.clone();
|
||||||
|
let ctx = Arc::clone(&ctx);
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let _permit = permit;
|
||||||
|
let mut tls_stream = match acceptor.accept(tcp_stream).await {
|
||||||
|
Ok(s) => s,
|
||||||
|
Err(_) => return,
|
||||||
|
};
|
||||||
|
loop {
|
||||||
|
let mut len_buf = [0u8; 2];
|
||||||
|
match tokio::time::timeout(
|
||||||
|
IDLE_TIMEOUT,
|
||||||
|
tls_stream.read_exact(&mut len_buf),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(Ok(_)) => {}
|
||||||
|
_ => break,
|
||||||
|
}
|
||||||
|
let msg_len = u16::from_be_bytes(len_buf) as usize;
|
||||||
|
if msg_len == 0 || msg_len > 4096 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
let mut data = vec![0u8; msg_len];
|
||||||
|
if tls_stream.read_exact(&mut data).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
let buffer = BytePacketBuffer::from_bytes(&data);
|
||||||
|
match resolve_query(buffer, remote_addr, &ctx).await {
|
||||||
|
Ok(resp_buffer) => {
|
||||||
|
let resp = resp_buffer.filled();
|
||||||
|
let mut out = Vec::with_capacity(2 + resp.len());
|
||||||
|
out.extend_from_slice(&(resp.len() as u16).to_be_bytes());
|
||||||
|
out.extend_from_slice(resp);
|
||||||
|
if tls_stream.write_all(&out).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(_) => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
(addr, client_tls)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Open a TLS connection to the DoT server and return the stream.
|
||||||
|
async fn dot_connect(
|
||||||
|
addr: SocketAddr,
|
||||||
|
client_config: &Arc<rustls::ClientConfig>,
|
||||||
|
) -> tokio_rustls::client::TlsStream<tokio::net::TcpStream> {
|
||||||
|
let connector = tokio_rustls::TlsConnector::from(Arc::clone(client_config));
|
||||||
|
let tcp = tokio::net::TcpStream::connect(addr).await.unwrap();
|
||||||
|
connector
|
||||||
|
.connect(ServerName::try_from("localhost").unwrap(), tcp)
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send a DNS query over a DoT stream and read the response.
|
||||||
|
async fn dot_exchange(
|
||||||
|
stream: &mut tokio_rustls::client::TlsStream<tokio::net::TcpStream>,
|
||||||
|
query: &DnsPacket,
|
||||||
|
) -> DnsPacket {
|
||||||
|
let mut buf = BytePacketBuffer::new();
|
||||||
|
query.write(&mut buf).unwrap();
|
||||||
|
let msg = buf.filled();
|
||||||
|
|
||||||
|
let mut out = Vec::with_capacity(2 + msg.len());
|
||||||
|
out.extend_from_slice(&(msg.len() as u16).to_be_bytes());
|
||||||
|
out.extend_from_slice(msg);
|
||||||
|
stream.write_all(&out).await.unwrap();
|
||||||
|
|
||||||
|
let mut len_buf = [0u8; 2];
|
||||||
|
stream.read_exact(&mut len_buf).await.unwrap();
|
||||||
|
let resp_len = u16::from_be_bytes(len_buf) as usize;
|
||||||
|
|
||||||
|
let mut data = vec![0u8; resp_len];
|
||||||
|
stream.read_exact(&mut data).await.unwrap();
|
||||||
|
|
||||||
|
let mut resp_buf = BytePacketBuffer::from_bytes(&data);
|
||||||
|
DnsPacket::from_buffer(&mut resp_buf).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn dot_resolves_local_zone() {
|
||||||
|
let (addr, client_config) = spawn_dot_server().await;
|
||||||
|
let mut stream = dot_connect(addr, &client_config).await;
|
||||||
|
|
||||||
|
let query = DnsPacket::query(0x1234, "dot-test.example", QueryType::A);
|
||||||
|
let resp = dot_exchange(&mut stream, &query).await;
|
||||||
|
|
||||||
|
assert_eq!(resp.header.id, 0x1234);
|
||||||
|
assert!(resp.header.response);
|
||||||
|
assert_eq!(resp.header.rescode, ResultCode::NOERROR);
|
||||||
|
assert_eq!(resp.answers.len(), 1);
|
||||||
|
match &resp.answers[0] {
|
||||||
|
DnsRecord::A { domain, addr, ttl } => {
|
||||||
|
assert_eq!(domain, "dot-test.example");
|
||||||
|
assert_eq!(*addr, std::net::Ipv4Addr::new(10, 0, 0, 1));
|
||||||
|
assert_eq!(*ttl, 300);
|
||||||
|
}
|
||||||
|
other => panic!("expected A record, got {:?}", other),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn dot_multiple_queries_on_persistent_connection() {
|
||||||
|
let (addr, client_config) = spawn_dot_server().await;
|
||||||
|
let mut stream = dot_connect(addr, &client_config).await;
|
||||||
|
|
||||||
|
// Send 3 queries on the same TLS connection
|
||||||
|
for i in 0..3u16 {
|
||||||
|
let query = DnsPacket::query(0xA000 + i, "dot-test.example", QueryType::A);
|
||||||
|
let resp = dot_exchange(&mut stream, &query).await;
|
||||||
|
assert_eq!(resp.header.id, 0xA000 + i);
|
||||||
|
assert_eq!(resp.header.rescode, ResultCode::NOERROR);
|
||||||
|
assert_eq!(resp.answers.len(), 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn dot_nxdomain_for_unknown() {
|
||||||
|
let (addr, client_config) = spawn_dot_server().await;
|
||||||
|
let mut stream = dot_connect(addr, &client_config).await;
|
||||||
|
|
||||||
|
let query = DnsPacket::query(0xBEEF, "nonexistent.test", QueryType::A);
|
||||||
|
let resp = dot_exchange(&mut stream, &query).await;
|
||||||
|
|
||||||
|
assert_eq!(resp.header.id, 0xBEEF);
|
||||||
|
assert!(resp.header.response);
|
||||||
|
// Query goes to upstream (127.0.0.1:53), which will fail — expect SERVFAIL
|
||||||
|
assert_eq!(resp.header.rescode, ResultCode::SERVFAIL);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn dot_concurrent_connections() {
|
||||||
|
let (addr, client_config) = spawn_dot_server().await;
|
||||||
|
|
||||||
|
let mut handles = Vec::new();
|
||||||
|
for i in 0..5u16 {
|
||||||
|
let cfg = Arc::clone(&client_config);
|
||||||
|
handles.push(tokio::spawn(async move {
|
||||||
|
let mut stream = dot_connect(addr, &cfg).await;
|
||||||
|
let query = DnsPacket::query(0xC000 + i, "dot-test.example", QueryType::A);
|
||||||
|
let resp = dot_exchange(&mut stream, &query).await;
|
||||||
|
assert_eq!(resp.header.id, 0xC000 + i);
|
||||||
|
assert_eq!(resp.header.rescode, ResultCode::NOERROR);
|
||||||
|
assert_eq!(resp.answers.len(), 1);
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
for h in handles {
|
||||||
|
h.await.unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn dot_localhost_resolution() {
|
||||||
|
let (addr, client_config) = spawn_dot_server().await;
|
||||||
|
let mut stream = dot_connect(addr, &client_config).await;
|
||||||
|
|
||||||
|
let query = DnsPacket::query(0xD000, "localhost", QueryType::A);
|
||||||
|
let resp = dot_exchange(&mut stream, &query).await;
|
||||||
|
|
||||||
|
assert_eq!(resp.header.id, 0xD000);
|
||||||
|
assert_eq!(resp.header.rescode, ResultCode::NOERROR);
|
||||||
|
assert_eq!(resp.answers.len(), 1);
|
||||||
|
match &resp.answers[0] {
|
||||||
|
DnsRecord::A { addr, .. } => assert_eq!(*addr, std::net::Ipv4Addr::LOCALHOST),
|
||||||
|
other => panic!("expected A record, got {:?}", other),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -5,6 +5,7 @@ pub mod cache;
|
|||||||
pub mod config;
|
pub mod config;
|
||||||
pub mod ctx;
|
pub mod ctx;
|
||||||
pub mod dnssec;
|
pub mod dnssec;
|
||||||
|
pub mod dot;
|
||||||
pub mod forward;
|
pub mod forward;
|
||||||
pub mod header;
|
pub mod header;
|
||||||
pub mod lan;
|
pub mod lan;
|
||||||
|
|||||||
12
src/main.rs
12
src/main.rs
@@ -370,6 +370,9 @@ async fn main() -> numa::Result<()> {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if config.dot.enabled {
|
||||||
|
row("DoT", g, &format!("tls://:{}", config.dot.port));
|
||||||
|
}
|
||||||
if config.lan.enabled {
|
if config.lan.enabled {
|
||||||
row("LAN", g, "mDNS (_numa._tcp.local)");
|
row("LAN", g, "mDNS (_numa._tcp.local)");
|
||||||
}
|
}
|
||||||
@@ -477,6 +480,15 @@ async fn main() -> numa::Result<()> {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Spawn DNS-over-TLS listener (RFC 7858)
|
||||||
|
if config.dot.enabled {
|
||||||
|
let dot_ctx = Arc::clone(&ctx);
|
||||||
|
let dot_config = config.dot.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
numa::dot::start_dot(dot_ctx, &dot_config).await;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
// UDP DNS listener
|
// UDP DNS listener
|
||||||
#[allow(clippy::infinite_loop)]
|
#[allow(clippy::infinite_loop)]
|
||||||
loop {
|
loop {
|
||||||
|
|||||||
Reference in New Issue
Block a user