feat: DoH server endpoint + DoT enabled by default #79

Merged
razvandimescu merged 5 commits from feat/cache-warming into main 2026-04-11 09:06:17 +08:00
7 changed files with 283 additions and 9 deletions
Showing only changes of commit d725091642 - Show all commits

View File

@@ -110,6 +110,10 @@ pub async fn resolve_query(
300, 300,
)); ));
(resp, QueryPath::Local, DnssecStatus::Indeterminate) (resp, QueryPath::Local, DnssecStatus::Indeterminate)
} else if let Some(records) = ctx.zone_map.get(qname.as_str()).and_then(|m| m.get(&qtype)) {
let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR);
resp.answers = records.clone();
(resp, QueryPath::Local, DnssecStatus::Indeterminate)
} else if is_special_use_domain(&qname) { } else if is_special_use_domain(&qname) {
// RFC 6761/8880: private PTR, DDR, NAT64 — answer locally // RFC 6761/8880: private PTR, DDR, NAT64 — answer locally
let resp = special_use_response(&query, &qname, qtype); let resp = special_use_response(&query, &qname, qtype);
@@ -158,10 +162,6 @@ pub async fn resolve_query(
60, 60,
)); ));
(resp, QueryPath::Blocked, DnssecStatus::Indeterminate) (resp, QueryPath::Blocked, DnssecStatus::Indeterminate)
} else if let Some(records) = ctx.zone_map.get(qname.as_str()).and_then(|m| m.get(&qtype)) {
let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR);
resp.answers = records.clone();
(resp, QueryPath::Local, DnssecStatus::Indeterminate)
} else { } else {
let cached = ctx.cache.read().unwrap().lookup_with_status(&qname, qtype); let cached = ctx.cache.read().unwrap().lookup_with_status(&qname, qtype);
if let Some((cached, cached_dnssec)) = cached { if let Some((cached, cached_dnssec)) = cached {

189
src/doh.rs Normal file
View File

@@ -0,0 +1,189 @@
use std::net::SocketAddr;
use axum::body::Bytes;
use axum::extract::{Request, State};
use axum::response::{IntoResponse, Response};
use hyper::StatusCode;
use log::warn;
use crate::buffer::BytePacketBuffer;
use crate::ctx::{resolve_query, ServerCtx};
use crate::header::ResultCode;
use crate::packet::DnsPacket;
const MAX_DNS_MSG: usize = 4096;
const DOH_CONTENT_TYPE: &str = "application/dns-message";
pub async fn doh_post(
State(state): State<super::proxy::DohState>,
req: Request,
) -> Response {
let host = super::proxy::extract_host(&req);
if !is_doh_host(host.as_deref(), &state.ctx.proxy_tld) {
return StatusCode::NOT_FOUND.into_response();
}
let content_type = req
.headers()
.get(hyper::header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.unwrap_or("");
if !content_type.starts_with(DOH_CONTENT_TYPE) {
return StatusCode::UNSUPPORTED_MEDIA_TYPE.into_response();
}
let body = match axum::body::to_bytes(req.into_body(), MAX_DNS_MSG).await {
Ok(b) => b,
Err(_) => return (StatusCode::PAYLOAD_TOO_LARGE, "body exceeds 4096 bytes").into_response(),
};
if body.is_empty() {
return (StatusCode::BAD_REQUEST, "empty body").into_response();
}
let src = state
.remote_addr
.unwrap_or_else(|| SocketAddr::from(([127, 0, 0, 1], 0)));
resolve_doh(&body, src, &state.ctx).await
}
fn is_doh_host(host: Option<&str>, tld: &str) -> bool {
match host {
Some(h) if h == tld => true,
Some(h) => h.len() == 2 * tld.len() + 1
&& h.starts_with(tld)
&& h.as_bytes().get(tld.len()) == Some(&b'.')
&& h.ends_with(tld),
None => false,
}
}
async fn resolve_doh(dns_bytes: &[u8], src: SocketAddr, ctx: &ServerCtx) -> Response {
let mut buffer = BytePacketBuffer::from_bytes(dns_bytes);
let query = match DnsPacket::from_buffer(&mut buffer) {
Ok(q) => q,
Err(e) => {
warn!("DoH: parse error from {}: {}", src, e);
let query_id = u16::from_be_bytes([
dns_bytes.first().copied().unwrap_or(0),
dns_bytes.get(1).copied().unwrap_or(0),
]);
let mut resp = DnsPacket::new();
resp.header.id = query_id;
resp.header.response = true;
resp.header.rescode = ResultCode::FORMERR;
return serialize_response(&resp);
}
};
let query_id = query.header.id;
let query_rd = query.header.recursion_desired;
let questions = query.questions.clone();
match resolve_query(query, src, ctx).await {
Ok(resp_buffer) => {
let min_ttl = extract_min_ttl(resp_buffer.filled());
dns_response(resp_buffer.filled(), min_ttl)
}
Err(e) => {
warn!("DoH: resolve error for {}: {}", src, e);
let mut resp = DnsPacket::new();
resp.header.id = query_id;
resp.header.response = true;
resp.header.recursion_desired = query_rd;
resp.header.recursion_available = true;
resp.header.rescode = ResultCode::SERVFAIL;
resp.questions = questions;
serialize_response(&resp)
}
}
}
fn extract_min_ttl(wire: &[u8]) -> u32 {
let mut buf = BytePacketBuffer::from_bytes(wire);
match DnsPacket::from_buffer(&mut buf) {
Ok(pkt) => pkt
.answers
.iter()
.map(|r| r.ttl())
.min()
.unwrap_or(0),
Err(_) => 0,
}
}
fn dns_response(wire: &[u8], min_ttl: u32) -> Response {
(
StatusCode::OK,
[
(hyper::header::CONTENT_TYPE, DOH_CONTENT_TYPE),
(hyper::header::CACHE_CONTROL, &format!("max-age={}", min_ttl)),
],
Bytes::copy_from_slice(wire),
)
.into_response()
}
fn serialize_response(pkt: &DnsPacket) -> Response {
let mut buf = BytePacketBuffer::new();
match pkt.write(&mut buf) {
Ok(_) => dns_response(buf.filled(), 0),
Err(_) => StatusCode::INTERNAL_SERVER_ERROR.into_response(),
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::buffer::BytePacketBuffer;
use crate::header::ResultCode;
use crate::packet::DnsPacket;
use crate::record::DnsRecord;
#[test]
fn is_doh_host_matches_tld() {
assert!(is_doh_host(Some("numa"), "numa"));
assert!(is_doh_host(Some("numa.numa"), "numa"));
assert!(!is_doh_host(Some("foo.numa"), "numa"));
assert!(!is_doh_host(None, "numa"));
}
#[test]
fn extract_min_ttl_from_response() {
let mut pkt = DnsPacket::new();
pkt.header.response = true;
pkt.answers.push(DnsRecord::A {
domain: "example.com".to_string(),
addr: std::net::Ipv4Addr::new(1, 2, 3, 4),
ttl: 300,
});
pkt.answers.push(DnsRecord::A {
domain: "example.com".to_string(),
addr: std::net::Ipv4Addr::new(5, 6, 7, 8),
ttl: 60,
});
let mut buf = BytePacketBuffer::new();
pkt.write(&mut buf).unwrap();
assert_eq!(extract_min_ttl(buf.filled()), 60);
}
#[test]
fn extract_min_ttl_no_answers() {
let mut pkt = DnsPacket::new();
pkt.header.response = true;
let mut buf = BytePacketBuffer::new();
pkt.write(&mut buf).unwrap();
assert_eq!(extract_min_ttl(buf.filled()), 0);
}
#[test]
fn serialize_formerr_response() {
let mut pkt = DnsPacket::new();
pkt.header.id = 0xABCD;
pkt.header.response = true;
pkt.header.rescode = ResultCode::FORMERR;
let resp = serialize_response(&pkt);
assert_eq!(resp.status(), StatusCode::OK);
}
}

View File

@@ -73,11 +73,15 @@ impl HealthMeta {
recursive_enabled: bool, recursive_enabled: bool,
mdns_enabled: bool, mdns_enabled: bool,
blocking_enabled: bool, blocking_enabled: bool,
doh_enabled: bool,
) -> Self { ) -> Self {
let ca_path = data_dir.join("ca.pem"); let ca_path = data_dir.join("ca.pem");
let ca_fingerprint_sha256 = compute_ca_fingerprint(&ca_path); let ca_fingerprint_sha256 = compute_ca_fingerprint(&ca_path);
let mut features = Vec::new(); let mut features = Vec::new();
if doh_enabled {
features.push("doh".to_string());
}
if dot_enabled { if dot_enabled {
features.push("dot".to_string()); features.push("dot".to_string());
} }

View File

@@ -5,6 +5,7 @@ pub mod cache;
pub mod config; pub mod config;
pub mod ctx; pub mod ctx;
pub mod dnssec; pub mod dnssec;
pub mod doh;
pub mod dot; pub mod dot;
pub mod forward; pub mod forward;
pub mod header; pub mod header;

View File

@@ -243,6 +243,7 @@ async fn main() -> numa::Result<()> {
None None
}; };
let doh_enabled = initial_tls.is_some();
let health_meta = numa::health::HealthMeta::build( let health_meta = numa::health::HealthMeta::build(
&resolved_data_dir, &resolved_data_dir,
config.dot.enabled, config.dot.enabled,
@@ -252,6 +253,7 @@ async fn main() -> numa::Result<()> {
resolved_mode == numa::config::UpstreamMode::Recursive, resolved_mode == numa::config::UpstreamMode::Recursive,
config.lan.enabled, config.lan.enabled,
config.blocking.enabled, config.blocking.enabled,
doh_enabled,
); );
let ca_pem = std::fs::read_to_string(resolved_data_dir.join("ca.pem")).ok(); let ca_pem = std::fs::read_to_string(resolved_data_dir.join("ca.pem")).ok();
@@ -431,6 +433,13 @@ async fn main() -> numa::Result<()> {
if config.dot.enabled { if config.dot.enabled {
row("DoT", g, &format!("tls://:{}", config.dot.port)); row("DoT", g, &format!("tls://:{}", config.dot.port));
} }
if doh_enabled {
row(
"DoH",
g,
&format!("https://:{}/dns-query", config.proxy.tls_port),
);
}
if config.lan.enabled { if config.lan.enabled {
row("LAN", g, "mDNS (_numa._tcp.local)"); row("LAN", g, "mDNS (_numa._tcp.local)");
} }

View File

@@ -4,7 +4,7 @@ use std::sync::Arc;
use axum::body::Body; use axum::body::Body;
use axum::extract::{Request, State}; use axum::extract::{Request, State};
use axum::response::IntoResponse; use axum::response::IntoResponse;
use axum::routing::any; use axum::routing::{any, post};
use axum::Router; use axum::Router;
use http_body_util::BodyExt; use http_body_util::BodyExt;
use hyper::StatusCode; use hyper::StatusCode;
@@ -18,6 +18,14 @@ use crate::ctx::ServerCtx;
type HttpClient = Client<hyper_util::client::legacy::connect::HttpConnector, Body>; type HttpClient = Client<hyper_util::client::legacy::connect::HttpConnector, Body>;
/// State passed to the DoH handler. Includes the remote address so
/// `resolve_query` can log the client IP.
#[derive(Clone)]
pub struct DohState {
pub ctx: Arc<ServerCtx>,
pub remote_addr: Option<std::net::SocketAddr>,
}
#[derive(Clone)] #[derive(Clone)]
struct ProxyState { struct ProxyState {
ctx: Arc<ServerCtx>, ctx: Arc<ServerCtx>,
@@ -74,9 +82,17 @@ pub async fn start_proxy_tls(ctx: Arc<ServerCtx>, port: u16, bind_addr: Ipv4Addr
// Hold a separate Arc so we can access tls_config after ctx moves into ProxyState // Hold a separate Arc so we can access tls_config after ctx moves into ProxyState
let tls_holder = Arc::clone(&ctx); let tls_holder = Arc::clone(&ctx);
let state = ProxyState { ctx, client }; let proxy_state = ProxyState {
ctx: Arc::clone(&ctx),
client,
};
let app = Router::new().fallback(any(proxy_handler)).with_state(state); // DoH route (RFC 8484) served only on the TLS listener.
// DohState.remote_addr is set per-connection below.
let doh_state = DohState {
ctx,
remote_addr: None,
};
loop { loop {
let (tcp_stream, remote_addr) = match listener.accept().await { let (tcp_stream, remote_addr) = match listener.accept().await {
@@ -91,7 +107,14 @@ pub async fn start_proxy_tls(ctx: Arc<ServerCtx>, port: u16, bind_addr: Ipv4Addr
// unwrap safe: guarded by is_none() check above // unwrap safe: guarded by is_none() check above
let acceptor = let acceptor =
TlsAcceptor::from(Arc::clone(&*tls_holder.tls_config.as_ref().unwrap().load())); TlsAcceptor::from(Arc::clone(&*tls_holder.tls_config.as_ref().unwrap().load()));
let app = app.clone();
let mut conn_doh_state = doh_state.clone();
conn_doh_state.remote_addr = Some(remote_addr);
let app = Router::new()
.route("/dns-query", post(crate::doh::doh_post).with_state(conn_doh_state))
.fallback(any(proxy_handler))
.with_state(proxy_state.clone());
tokio::spawn(async move { tokio::spawn(async move {
let tls_stream = match acceptor.accept(tcp_stream).await { let tls_stream = match acceptor.accept(tcp_stream).await {
@@ -232,7 +255,7 @@ pre .str {{ color: #d48a5a }}
) )
} }
fn extract_host(req: &Request) -> Option<String> { pub fn extract_host(req: &Request) -> Option<String> {
req.headers() req.headers()
.get(hyper::header::HOST) .get(hyper::header::HOST)
.and_then(|v| v.to_str().ok()) .and_then(|v| v.to_str().ok())

View File

@@ -622,6 +622,54 @@ CONF
"10.0.0.1" \ "10.0.0.1" \
"$($KDIG +short dot-test.example A 2>/dev/null)" "$($KDIG +short dot-test.example A 2>/dev/null)"
echo ""
echo "=== DNS-over-HTTPS (RFC 8484) ==="
DOH_QUERY_FILE=/tmp/numa-doh-query.bin
DOH_RESP_FILE=/tmp/numa-doh-resp.bin
# Build DNS wire-format query for dot-test.example A
printf '\x00\x01\x01\x00\x00\x01\x00\x00\x00\x00\x00\x00\x08dot-test\x07example\x00\x00\x01\x00\x01' > "$DOH_QUERY_FILE"
# POST valid DoH query
DOH_CODE=$(curl -sk -X POST \
--resolve "numa.numa:$PROXY_HTTPS_PORT:127.0.0.1" \
-H "Content-Type: application/dns-message" \
--data-binary @"$DOH_QUERY_FILE" \
--cacert "$CA" \
-o "$DOH_RESP_FILE" \
-w "%{http_code}" \
"https://numa.numa:$PROXY_HTTPS_PORT/dns-query")
check "DoH POST returns HTTP 200" "200" "$DOH_CODE"
# Check response contains IP 10.0.0.1 (hex: 0a000001)
DOH_HEX=$(xxd -p "$DOH_RESP_FILE" | tr -d '\n')
if echo "$DOH_HEX" | grep -q "0a000001"; then
check "DoH response resolves dot-test.example → 10.0.0.1" "found" "found"
else
check "DoH response resolves dot-test.example → 10.0.0.1" "0a000001" "$DOH_HEX"
fi
# Wrong Content-Type → 415
DOH_CT_CODE=$(curl -sk -X POST \
-H "Host: numa.numa" \
-H "Content-Type: text/plain" \
--data-binary @"$DOH_QUERY_FILE" \
-o /dev/null -w "%{http_code}" \
"https://127.0.0.1:$PROXY_HTTPS_PORT/dns-query")
check "DoH wrong Content-Type → 415" "415" "$DOH_CT_CODE"
# Wrong host → 404 (DoH only serves numa.numa)
DOH_HOST_CODE=$(curl -sk -X POST \
-H "Host: foo.numa" \
-H "Content-Type: application/dns-message" \
--data-binary @"$DOH_QUERY_FILE" \
-o /dev/null -w "%{http_code}" \
"https://127.0.0.1:$PROXY_HTTPS_PORT/dns-query")
check "DoH wrong host → 404" "404" "$DOH_HOST_CODE"
rm -f "$DOH_QUERY_FILE" "$DOH_RESP_FILE"
echo "" echo ""
echo "=== Proxy TLS works with DoT enabled ===" echo "=== Proxy TLS works with DoT enabled ==="