diff --git a/Cargo.lock b/Cargo.lock index fb82483..3d45147 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -393,6 +393,12 @@ dependencies = [ "miniz_oxide", ] +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + [[package]] name = "form_urlencoded" version = "1.2.2" @@ -523,6 +529,25 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "h2" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f44da3a8150a6703ed5d34e164b875fd14c2cdab9af1252a9a1020bde2bdc54" +dependencies = [ + "atomic-waker", + "bytes", + "fnv", + "futures-core", + "futures-sink", + "http", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + [[package]] name = "hashbrown" version = "0.16.1" @@ -584,6 +609,7 @@ dependencies = [ "bytes", "futures-channel", "futures-core", + "h2", "http", "http-body", "httparse", @@ -1211,6 +1237,7 @@ dependencies = [ "base64", "bytes", "futures-core", + "h2", "http", "http-body", "http-body-util", diff --git a/Cargo.toml b/Cargo.toml index f9382cc..303d6ea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ serde_json = "1" toml = "0.8" log = "0.4" env_logger = "0.11" -reqwest = { version = "0.12", features = ["rustls-tls", "gzip"], default-features = false } +reqwest = { version = "0.12", features = ["rustls-tls", "gzip", "http2"], default-features = false } hyper = { version = "1", features = ["client", "http1", "server"] } hyper-util = { version = "0.1", features = ["client-legacy", "http1", "tokio"] } http-body-util = "0.1" diff --git a/numa.toml b/numa.toml index 1e0f851..09e8523 100644 --- a/numa.toml +++ b/numa.toml @@ -4,9 +4,11 @@ api_port = 5380 # api_bind_addr = "127.0.0.1" # default; set to "0.0.0.0" for LAN dashboard access # [upstream] -# address = "" # auto-detect from system resolver (default) -# address = "9.9.9.9" # or set explicitly -# port = 53 +# address = "" # auto-detect from system resolver (default) +# address = "https://dns.quad9.net/dns-query" # DNS-over-HTTPS (encrypted) +# address = "https://cloudflare-dns.com/dns-query" # Cloudflare DoH +# address = "9.9.9.9" # plain UDP +# port = 53 # only used for plain UDP # timeout_ms = 3000 # [blocking] diff --git a/src/api.rs b/src/api.rs index 8dde3a3..2df9d73 100644 --- a/src/api.rs +++ b/src/api.rs @@ -9,7 +9,7 @@ use axum::{Json, Router}; use serde::{Deserialize, Serialize}; use crate::ctx::ServerCtx; -use crate::forward::forward_query; +use crate::forward::{forward_query, Upstream}; use crate::query_log::QueryLogFilter; use crate::question::QueryType; use crate::stats::QueryPath; @@ -355,9 +355,9 @@ async fn diagnose( } // Check upstream (async, no locks held) - let upstream = *ctx.upstream.lock().unwrap(); + let upstream = ctx.upstream.lock().unwrap().clone(); let (upstream_matched, upstream_detail) = - forward_query_for_diagnose(&domain_lower, upstream, ctx.timeout).await; + forward_query_for_diagnose(&domain_lower, &upstream, ctx.timeout).await; steps.push(DiagnoseStep { source: "upstream".to_string(), matched: upstream_matched, @@ -373,7 +373,7 @@ async fn diagnose( async fn forward_query_for_diagnose( domain: &str, - upstream: std::net::SocketAddr, + upstream: &Upstream, timeout: std::time::Duration, ) -> (bool, String) { use crate::packet::DnsPacket; diff --git a/src/buffer.rs b/src/buffer.rs index 0c358e7..212bf92 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -21,6 +21,13 @@ impl BytePacketBuffer { } } + pub fn from_bytes(data: &[u8]) -> Self { + let mut buf = Self::new(); + let len = data.len().min(BUF_SIZE); + buf.buf[..len].copy_from_slice(&data[..len]); + buf + } + pub fn pos(&self) -> usize { self.pos } diff --git a/src/ctx.rs b/src/ctx.rs index a017eaf..925ab4a 100644 --- a/src/ctx.rs +++ b/src/ctx.rs @@ -12,7 +12,7 @@ use crate::blocklist::BlocklistStore; use crate::buffer::BytePacketBuffer; use crate::cache::DnsCache; use crate::config::ZoneMap; -use crate::forward::forward_query; +use crate::forward::{forward_query, Upstream}; use crate::header::ResultCode; use crate::lan::PeerStore; use crate::override_store::OverrideStore; @@ -35,7 +35,7 @@ pub struct ServerCtx { pub services: Mutex, pub lan_peers: Mutex, pub forwarding_rules: Vec, - pub upstream: Mutex, + pub upstream: Mutex, pub upstream_auto: bool, pub upstream_port: u16, pub lan_ip: Mutex, @@ -143,9 +143,11 @@ pub async fn handle_query( (resp, QueryPath::Cached) } else { let upstream = - crate::system_dns::match_forwarding_rule(&qname, &ctx.forwarding_rules) - .unwrap_or_else(|| *ctx.upstream.lock().unwrap()); - match forward_query(&query, upstream, ctx.timeout).await { + match crate::system_dns::match_forwarding_rule(&qname, &ctx.forwarding_rules) { + Some(addr) => Upstream::Udp(addr), + None => ctx.upstream.lock().unwrap().clone(), + }; + match forward_query(&query, &upstream, ctx.timeout).await { Ok(resp) => { ctx.cache.lock().unwrap().insert(&qname, qtype, &resp); (resp, QueryPath::Forwarded) diff --git a/src/forward.rs b/src/forward.rs index ff5c14f..b64c204 100644 --- a/src/forward.rs +++ b/src/forward.rs @@ -1,3 +1,4 @@ +use std::fmt; use std::net::SocketAddr; use std::time::Duration; @@ -8,7 +9,46 @@ use crate::buffer::BytePacketBuffer; use crate::packet::DnsPacket; use crate::Result; +#[derive(Clone)] +pub enum Upstream { + Udp(SocketAddr), + Doh { + url: String, + client: reqwest::Client, + }, +} + +impl PartialEq for Upstream { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Self::Udp(a), Self::Udp(b)) => a == b, + (Self::Doh { url: a, .. }, Self::Doh { url: b, .. }) => a == b, + _ => false, + } + } +} + +impl fmt::Display for Upstream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Upstream::Udp(addr) => write!(f, "{}", addr), + Upstream::Doh { url, .. } => f.write_str(url), + } + } +} + pub async fn forward_query( + query: &DnsPacket, + upstream: &Upstream, + timeout_duration: Duration, +) -> Result { + match upstream { + Upstream::Udp(addr) => forward_udp(query, *addr, timeout_duration).await, + Upstream::Doh { url, client } => forward_doh(query, url, client, timeout_duration).await, + } +} + +async fn forward_udp( query: &DnsPacket, upstream: SocketAddr, timeout_duration: Duration, @@ -33,3 +73,175 @@ pub async fn forward_query( DnsPacket::from_buffer(&mut recv_buffer) } + +async fn forward_doh( + query: &DnsPacket, + url: &str, + client: &reqwest::Client, + timeout_duration: Duration, +) -> Result { + let mut send_buffer = BytePacketBuffer::new(); + query.write(&mut send_buffer)?; + + let resp = timeout( + timeout_duration, + client + .post(url) + .header("content-type", "application/dns-message") + .header("accept", "application/dns-message") + .body(send_buffer.filled().to_vec()) + .send(), + ) + .await?? + .error_for_status()?; + + let bytes = resp.bytes().await?; + log::debug!("DoH response: {} bytes", bytes.len()); + + let mut recv_buffer = BytePacketBuffer::from_bytes(&bytes); + DnsPacket::from_buffer(&mut recv_buffer) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::future::IntoFuture; + + use crate::header::ResultCode; + use crate::question::{DnsQuestion, QueryType}; + use crate::record::DnsRecord; + + #[test] + fn upstream_display_udp() { + let u = Upstream::Udp("9.9.9.9:53".parse().unwrap()); + assert_eq!(u.to_string(), "9.9.9.9:53"); + } + + #[test] + fn upstream_display_doh() { + let u = Upstream::Doh { + url: "https://dns.quad9.net/dns-query".to_string(), + client: reqwest::Client::new(), + }; + assert_eq!(u.to_string(), "https://dns.quad9.net/dns-query"); + } + + fn make_query() -> DnsPacket { + let mut q = DnsPacket::new(); + q.header.id = 0xABCD; + q.header.recursion_desired = true; + q.questions + .push(DnsQuestion::new("example.com".to_string(), QueryType::A)); + q + } + + fn make_response(query: &DnsPacket) -> DnsPacket { + let mut resp = DnsPacket::response_from(query, ResultCode::NOERROR); + resp.answers.push(DnsRecord::A { + domain: "example.com".to_string(), + addr: "93.184.216.34".parse().unwrap(), + ttl: 300, + }); + resp + } + + fn to_wire(pkt: &DnsPacket) -> Vec { + let mut buf = BytePacketBuffer::new(); + pkt.write(&mut buf).unwrap(); + buf.filled().to_vec() + } + + #[tokio::test] + async fn doh_mock_server_resolves() { + let query = make_query(); + let response_bytes = to_wire(&make_response(&query)); + + let app = axum::Router::new().route( + "/dns-query", + axum::routing::post(move || { + let body = response_bytes.clone(); + async move { + ( + [( + axum::http::header::CONTENT_TYPE, + "application/dns-message", + )], + body, + ) + } + }), + ); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + tokio::spawn(axum::serve(listener, app).into_future()); + + let upstream = Upstream::Doh { + url: format!("http://{}/dns-query", addr), + client: reqwest::Client::new(), + }; + + let result = forward_query(&query, &upstream, Duration::from_secs(2)) + .await + .expect("DoH forward should succeed"); + + assert_eq!(result.header.id, 0xABCD); + assert!(result.header.response); + assert_eq!(result.header.rescode, ResultCode::NOERROR); + assert_eq!(result.answers.len(), 1); + match &result.answers[0] { + DnsRecord::A { domain, addr, ttl } => { + assert_eq!(domain, "example.com"); + assert_eq!(*addr, "93.184.216.34".parse::().unwrap()); + assert_eq!(*ttl, 300); + } + other => panic!("expected A record, got {:?}", other), + } + } + + #[tokio::test] + async fn doh_http_error_propagates() { + let app = axum::Router::new().route( + "/dns-query", + axum::routing::post(|| async { + (axum::http::StatusCode::INTERNAL_SERVER_ERROR, "bad") + }), + ); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + tokio::spawn(axum::serve(listener, app).into_future()); + + let upstream = Upstream::Doh { + url: format!("http://{}/dns-query", addr), + client: reqwest::Client::new(), + }; + + let result = forward_query(&make_query(), &upstream, Duration::from_secs(2)).await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn doh_timeout() { + let app = axum::Router::new().route( + "/dns-query", + axum::routing::post(|| async { + tokio::time::sleep(Duration::from_secs(10)).await; + "never" + }), + ); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + tokio::spawn(axum::serve(listener, app).into_future()); + + let upstream = Upstream::Doh { + url: format!("http://{}/dns-query", addr), + client: reqwest::Client::new(), + }; + + let result = + forward_query(&make_query(), &upstream, Duration::from_millis(100)).await; + assert!(result.is_err()); + } +} diff --git a/src/main.rs b/src/main.rs index 60d3a95..d17a821 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,6 +7,7 @@ use log::{error, info}; use tokio::net::UdpSocket; use numa::blocklist::{download_blocklists, parse_blocklist, BlocklistStore}; +use numa::forward::Upstream; use numa::buffer::BytePacketBuffer; use numa::cache::DnsCache; use numa::config::{build_zone_map, load_config, ConfigLoad}; @@ -111,13 +112,27 @@ async fn main() -> numa::Result<()> { .default_upstream .or_else(numa::system_dns::detect_dhcp_dns) .unwrap_or_else(|| { - info!("could not detect system DNS, falling back to 9.9.9.9 (Quad9)"); - "9.9.9.9".to_string() + info!("could not detect system DNS, falling back to Quad9 DoH"); + "https://dns.quad9.net/dns-query".to_string() }) } else { config.upstream.address.clone() }; - let upstream: SocketAddr = format!("{}:{}", upstream_addr, config.upstream.port).parse()?; + + let upstream: Upstream = if upstream_addr.starts_with("https://") { + let client = reqwest::Client::builder() + .use_rustls_tls() + .build() + .unwrap_or_default(); + Upstream::Doh { + url: upstream_addr, + client, + } + } else { + let addr: SocketAddr = format!("{}:{}", upstream_addr, config.upstream.port).parse()?; + Upstream::Udp(addr) + }; + let upstream_label = upstream.to_string(); let api_port = config.server.api_port; let mut blocklist = BlocklistStore::new(); @@ -217,7 +232,7 @@ async fn main() -> numa::Result<()> { let val_w = [ config.server.bind_addr.len(), api_url.len(), - upstream.to_string().len(), + upstream_label.len(), config_label.len(), data_label.len(), services_label.len(), @@ -261,7 +276,7 @@ async fn main() -> numa::Result<()> { row("DNS", g, &config.server.bind_addr); row("API", g, &api_url); row("Dashboard", g, &api_url); - row("Upstream", g, &upstream.to_string()); + row("Upstream", g, &upstream_label); row("Zones", g, &format!("{} records", zone_count)); row( "Cache", @@ -298,7 +313,7 @@ async fn main() -> numa::Result<()> { info!( "numa listening on {}, upstream {}, {} zone records, cache max {}, API on port {}", - config.server.bind_addr, upstream, zone_count, config.cache.max_entries, api_port, + config.server.bind_addr, upstream_label, zone_count, config.cache.max_entries, api_port, ); // Download blocklists on startup @@ -413,20 +428,24 @@ async fn network_watch_loop(ctx: Arc) { } } - // Check upstream change every 30s or immediately on LAN IP change - // (heavier — spawns scutil/ipconfig, only when auto-detected) - if ctx.upstream_auto && (changed || tick.is_multiple_of(6)) { + // Re-detect upstream every 30s or on LAN IP change (UDP only — + // DoH upstreams are explicitly configured via URL, not auto-detected) + if ctx.upstream_auto + && matches!(*ctx.upstream.lock().unwrap(), Upstream::Udp(_)) + && (changed || tick.is_multiple_of(6)) + { let dns_info = numa::system_dns::discover_system_dns(); let new_addr = dns_info .default_upstream .or_else(numa::system_dns::detect_dhcp_dns) .unwrap_or_else(|| "9.9.9.9".to_string()); - if let Ok(new_upstream) = + if let Ok(new_sock) = format!("{}:{}", new_addr, ctx.upstream_port).parse::() { + let new_upstream = Upstream::Udp(new_sock); let mut upstream = ctx.upstream.lock().unwrap(); - if new_upstream != *upstream { - info!("upstream changed: {} → {}", *upstream, new_upstream); + if *upstream != new_upstream { + info!("upstream changed: {} → {}", upstream, new_upstream); *upstream = new_upstream; changed = true; }