diff --git a/src/lib.rs b/src/lib.rs index 346c739..0370c37 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,6 +20,7 @@ pub mod query_log; pub mod question; pub mod record; pub mod recursive; +pub mod serve; pub mod service_store; pub mod setup_phone; pub mod srtt; diff --git a/src/main.rs b/src/main.rs index 0459005..88f2128 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,30 +1,6 @@ -use std::net::SocketAddr; -use std::sync::{Arc, Mutex, RwLock}; -use std::time::Duration; +use numa::system_dns::{install_service, restart_service, service_status, uninstall_service}; -use arc_swap::ArcSwap; -use log::{error, info}; -use tokio::net::UdpSocket; - -use numa::blocklist::{download_blocklists, parse_blocklist, BlocklistStore}; -use numa::buffer::BytePacketBuffer; -use numa::cache::DnsCache; -use numa::config::{build_zone_map, load_config, ConfigLoad}; -use numa::ctx::{handle_query, ServerCtx}; -use numa::forward::{parse_upstream, Upstream, UpstreamPool}; -use numa::override_store::OverrideStore; -use numa::query_log::QueryLog; -use numa::service_store::ServiceStore; -use numa::stats::{ServerStats, Transport}; -use numa::system_dns::{ - discover_system_dns, install_service, restart_service, service_status, uninstall_service, -}; - -const QUAD9_IP: &str = "9.9.9.9"; -const DOH_FALLBACK: &str = "https://9.9.9.9/dns-query"; - -#[tokio::main] -async fn main() -> numa::Result<()> { +fn main() -> numa::Result<()> { env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")) .format_timestamp_millis() .init(); @@ -35,7 +11,7 @@ async fn main() -> numa::Result<()> { #[cfg(windows)] "--service" => { // Entry point used by Windows SCM (`sc create … binPath="numa.exe --service"`). - // Hands control to the service dispatcher and blocks until Stop. + // Blocks until SCM sends Stop; never returns normally. 
numa::windows_service::run_as_service() .map_err(|e| format!("windows service dispatcher failed: {}", e))?; return Ok(()); @@ -63,7 +39,12 @@ async fn main() -> numa::Result<()> { }; } "setup-phone" => { - return numa::setup_phone::run().await.map_err(|e| e.into()); + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build()?; + return runtime + .block_on(numa::setup_phone::run()) + .map_err(|e| e.into()); } "lan" => { let sub = std::env::args().nth(2).unwrap_or_default(); @@ -126,552 +107,11 @@ async fn main() -> numa::Result<()> { } else { arg1 // treat as config path for backwards compatibility }; - let ConfigLoad { - config, - path: resolved_config_path, - found: config_found, - } = load_config(&config_path)?; - // Discover system DNS in a single pass (upstream + forwarding rules) - let system_dns = discover_system_dns(); - - let root_hints = numa::recursive::parse_root_hints(&config.upstream.root_hints); - - let recursive_pool = || { - let dummy = UpstreamPool::new(vec![Upstream::Udp("0.0.0.0:0".parse().unwrap())], vec![]); - (dummy, "recursive (root hints)".to_string()) - }; - - let (resolved_mode, upstream_auto, pool, upstream_label) = match config.upstream.mode { - numa::config::UpstreamMode::Auto => { - info!("auto mode: probing recursive resolution..."); - if numa::recursive::probe_recursive(&root_hints).await { - info!("recursive probe succeeded — self-sovereign mode"); - let (pool, label) = recursive_pool(); - (numa::config::UpstreamMode::Recursive, false, pool, label) - } else { - log::warn!("recursive probe failed — falling back to Quad9 DoH"); - let client = reqwest::Client::builder() - .use_rustls_tls() - .build() - .unwrap_or_default(); - let url = DOH_FALLBACK.to_string(); - let label = url.clone(); - let pool = UpstreamPool::new(vec![Upstream::Doh { url, client }], vec![]); - (numa::config::UpstreamMode::Forward, false, pool, label) - } - } - numa::config::UpstreamMode::Recursive => { - let (pool, label) = 
recursive_pool(); - (numa::config::UpstreamMode::Recursive, false, pool, label) - } - numa::config::UpstreamMode::Forward => { - let addrs = if config.upstream.address.is_empty() { - let detected = system_dns - .default_upstream - .or_else(numa::system_dns::detect_dhcp_dns) - .unwrap_or_else(|| { - info!("could not detect system DNS, falling back to Quad9 DoH"); - DOH_FALLBACK.to_string() - }); - vec![detected] - } else { - config.upstream.address.clone() - }; - - let primary: Vec = addrs - .iter() - .map(|s| parse_upstream(s, config.upstream.port)) - .collect::>>()?; - let fallback: Vec = config - .upstream - .fallback - .iter() - .map(|s| parse_upstream(s, config.upstream.port)) - .collect::>>()?; - - let pool = UpstreamPool::new(primary, fallback); - let label = pool.label(); - ( - numa::config::UpstreamMode::Forward, - config.upstream.address.is_empty(), - pool, - label, - ) - } - }; - let api_port = config.server.api_port; - - let mut blocklist = BlocklistStore::new(); - for domain in &config.blocking.allowlist { - blocklist.add_to_allowlist(domain); - } - if !config.blocking.enabled { - blocklist.set_enabled(false); - } - - // Build service store: config services + persisted user services - let mut service_store = ServiceStore::new(); - service_store.insert_from_config("numa", config.server.api_port, Vec::new()); - for svc in &config.services { - service_store.insert_from_config(&svc.name, svc.target_port, svc.routes.clone()); - } - service_store.load_persisted(); - - for fwd in &config.forwarding { - for suffix in &fwd.suffix { - info!("forwarding .{} to {} (config rule)", suffix, fwd.upstream); - } - } - let forwarding_rules = - numa::config::merge_forwarding_rules(&config.forwarding, system_dns.forwarding_rules)?; - - // Resolve data_dir from config, falling back to the platform default. - // Used for TLS CA storage below and stored on ServerCtx for runtime use. 
- let resolved_data_dir = config - .server - .data_dir - .clone() - .unwrap_or_else(numa::data_dir); - - // Build initial TLS config before ServerCtx (so ArcSwap is ready at construction) - let initial_tls = if config.proxy.enabled && config.proxy.tls_port > 0 { - let service_names = service_store.names(); - match numa::tls::build_tls_config( - &config.proxy.tld, - &service_names, - Vec::new(), - &resolved_data_dir, - ) { - Ok(tls_config) => Some(ArcSwap::from(tls_config)), - Err(e) => { - if let Some(advisory) = numa::tls::try_data_dir_advisory(&e, &resolved_data_dir) { - eprint!("{}", advisory); - } else { - log::warn!("TLS setup failed, HTTPS proxy disabled: {}", e); - } - None - } - } - } else { - None - }; - - let doh_enabled = initial_tls.is_some(); - let health_meta = numa::health::HealthMeta::build( - &resolved_data_dir, - config.dot.enabled, - config.dot.port, - config.mobile.port, - config.dnssec.enabled, - resolved_mode == numa::config::UpstreamMode::Recursive, - config.lan.enabled, - config.blocking.enabled, - doh_enabled, - ); - - let ca_pem = std::fs::read_to_string(resolved_data_dir.join("ca.pem")).ok(); - - let socket = match UdpSocket::bind(&config.server.bind_addr).await { - Ok(s) => s, - Err(e) => { - if let Some(advisory) = - numa::system_dns::try_port53_advisory(&config.server.bind_addr, &e) - { - eprint!("{}", advisory); - std::process::exit(1); - } - return Err(e.into()); - } - }; - - let ctx = Arc::new(ServerCtx { - socket, - zone_map: build_zone_map(&config.zones)?, - cache: RwLock::new(DnsCache::new( - config.cache.max_entries, - config.cache.min_ttl, - config.cache.max_ttl, - )), - refreshing: Mutex::new(std::collections::HashSet::new()), - stats: Mutex::new(ServerStats::new()), - overrides: RwLock::new(OverrideStore::new()), - blocklist: RwLock::new(blocklist), - query_log: Mutex::new(QueryLog::new(1000)), - services: Mutex::new(service_store), - lan_peers: Mutex::new(numa::lan::PeerStore::new(config.lan.peer_timeout_secs)), - 
forwarding_rules, - upstream_pool: Mutex::new(pool), - upstream_auto, - upstream_port: config.upstream.port, - lan_ip: Mutex::new(numa::lan::detect_lan_ip().unwrap_or(std::net::Ipv4Addr::LOCALHOST)), - timeout: Duration::from_millis(config.upstream.timeout_ms), - hedge_delay: Duration::from_millis(config.upstream.hedge_ms), - proxy_tld_suffix: if config.proxy.tld.is_empty() { - String::new() - } else { - format!(".{}", config.proxy.tld) - }, - proxy_tld: config.proxy.tld.clone(), - lan_enabled: config.lan.enabled, - config_path: resolved_config_path, - config_found, - config_dir: numa::config_dir(), - data_dir: resolved_data_dir, - tls_config: initial_tls, - upstream_mode: resolved_mode, - root_hints, - srtt: std::sync::RwLock::new(numa::srtt::SrttCache::new(config.upstream.srtt)), - inflight: std::sync::Mutex::new(std::collections::HashMap::new()), - dnssec_enabled: config.dnssec.enabled, - dnssec_strict: config.dnssec.strict, - health_meta, - ca_pem, - mobile_enabled: config.mobile.enabled, - mobile_port: config.mobile.port, - }); - - let zone_count: usize = ctx.zone_map.values().map(|m| m.len()).sum(); - // Build banner rows, then size the box to fit the longest value - let api_url = format!("http://localhost:{}", api_port); - let proxy_label = if config.proxy.enabled { - if config.proxy.tls_port > 0 { - Some(format!( - "http://:{} https://:{}", - config.proxy.port, config.proxy.tls_port - )) - } else { - Some(format!( - "http://*.{} on :{}", - config.proxy.tld, config.proxy.port - )) - } - } else { - None - }; - let config_label = if ctx.config_found { - ctx.config_path.clone() - } else { - format!("{} (defaults)", ctx.config_path) - }; - let data_label = ctx.data_dir.display().to_string(); - let services_label = ctx.config_dir.join("services.json").display().to_string(); - - // label (10) + value + padding (2) = inner width; minimum 40 for the title row - let val_w = [ - config.server.bind_addr.len(), - api_url.len(), - upstream_label.len(), - 
config_label.len(), - data_label.len(), - services_label.len(), - ] - .into_iter() - .chain(proxy_label.as_ref().map(|s| s.len())) - .max() - .unwrap_or(30); - let w = (val_w + 12).max(42); // 10 label + 2 padding, min 42 for title - - let o = "\x1b[38;2;192;98;58m"; // orange - let g = "\x1b[38;2;107;124;78m"; // green - let d = "\x1b[38;2;163;152;136m"; // dim - let r = "\x1b[0m"; // reset - let b = "\x1b[1;38;2;192;98;58m"; // bold orange - let it = "\x1b[3;38;2;163;152;136m"; // italic dim - - let bar_top = "═".repeat(w); - let bar_mid = "─".repeat(w); - let row = |label: &str, color: &str, value: &str| { - eprintln!( - "{o} ║{r} {color}{:<9}{r} {: 0 && ctx.tls_config.is_some() { - let proxy_ctx = Arc::clone(&ctx); - let tls_port = config.proxy.tls_port; - tokio::spawn(async move { - numa::proxy::start_proxy_tls(proxy_ctx, tls_port, proxy_bind).await; - }); - } - - // Spawn network change watcher (upstream re-detection, LAN IP update, peer flush) - { - let watch_ctx = Arc::clone(&ctx); - tokio::spawn(async move { - network_watch_loop(watch_ctx).await; - }); - } - - // Spawn LAN service discovery - if config.lan.enabled { - let lan_ctx = Arc::clone(&ctx); - let lan_config = config.lan.clone(); - tokio::spawn(async move { - numa::lan::start_lan_discovery(lan_ctx, &lan_config).await; - }); - } - - // Spawn DNS-over-TLS listener (RFC 7858) - if config.dot.enabled { - let dot_ctx = Arc::clone(&ctx); - let dot_config = config.dot.clone(); - tokio::spawn(async move { - numa::dot::start_dot(dot_ctx, &dot_config).await; - }); - } - - // UDP DNS listener - #[allow(clippy::infinite_loop)] - loop { - let mut buffer = BytePacketBuffer::new(); - let (len, src_addr) = match ctx.socket.recv_from(&mut buffer.buf).await { - Ok(r) => r, - Err(e) if e.kind() == std::io::ErrorKind::ConnectionReset => { - // Windows delivers ICMP port-unreachable as ConnectionReset on UDP sockets - continue; - } - Err(e) => return Err(e.into()), - }; - let ctx = Arc::clone(&ctx); - 
tokio::spawn(async move { - if let Err(e) = handle_query(buffer, len, src_addr, &ctx, Transport::Udp).await { - error!("{} | HANDLER ERROR | {}", src_addr, e); - } - }); - } -} - -async fn network_watch_loop(ctx: Arc) { - let mut tick: u64 = 0; - - let mut interval = tokio::time::interval(Duration::from_secs(5)); - interval.tick().await; // skip immediate tick - - loop { - interval.tick().await; - tick += 1; - let mut changed = false; - - // Check LAN IP change (every 5s — cheap, one UDP socket call) - if let Some(new_ip) = numa::lan::detect_lan_ip() { - let mut current_ip = ctx.lan_ip.lock().unwrap(); - if new_ip != *current_ip { - info!("LAN IP changed: {} → {}", current_ip, new_ip); - *current_ip = new_ip; - changed = true; - numa::recursive::reset_udp_state(); - } - } - - // Re-detect upstream every 30s or on LAN IP change (auto-detect only) - if ctx.upstream_auto && (changed || tick.is_multiple_of(6)) { - let dns_info = numa::system_dns::discover_system_dns(); - let new_addr = dns_info - .default_upstream - .or_else(numa::system_dns::detect_dhcp_dns) - .unwrap_or_else(|| QUAD9_IP.to_string()); - let mut pool = ctx.upstream_pool.lock().unwrap(); - if pool.maybe_update_primary(&new_addr, ctx.upstream_port) { - info!("upstream changed → {}", pool.label()); - changed = true; - } - } - - // Flush stale LAN peers on any network change - if changed { - ctx.lan_peers.lock().unwrap().clear(); - info!("flushed LAN peers after network change"); - } - - // Re-probe UDP every 5 minutes when disabled - if tick.is_multiple_of(60) { - numa::recursive::probe_udp(&ctx.root_hints).await; - } - } + let runtime = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build()?; + runtime.block_on(numa::serve::run(config_path)) } fn set_lan_enabled(enabled: bool, path: &str) -> numa::Result<()> { @@ -738,71 +178,3 @@ fn print_lan_status(enabled: bool) { eprintln!(" Restart Numa to start mDNS discovery"); } } - -async fn load_blocklists(ctx: &ServerCtx, lists: &[String]) { - 
let downloaded = download_blocklists(lists).await; - - // Parse outside the lock to avoid blocking DNS queries during parse (~100ms) - let mut all_domains = std::collections::HashSet::new(); - let mut sources = Vec::new(); - for (source, text) in &downloaded { - let domains = parse_blocklist(text); - info!("blocklist: {} domains from {}", domains.len(), source); - all_domains.extend(domains); - sources.push(source.clone()); - } - let total = all_domains.len(); - - // Swap under lock — sub-microsecond - ctx.blocklist - .write() - .unwrap() - .swap_domains(all_domains, sources); - info!( - "blocking enabled: {} unique domains from {} lists", - total, - downloaded.len() - ); -} - -async fn warm_domain(ctx: &ServerCtx, domain: &str) { - for qtype in [ - numa::question::QueryType::A, - numa::question::QueryType::AAAA, - ] { - numa::ctx::refresh_entry(ctx, domain, qtype).await; - } -} - -async fn doh_keepalive_loop(ctx: Arc) { - let mut interval = tokio::time::interval(Duration::from_secs(25)); - interval.tick().await; // skip first immediate tick - loop { - interval.tick().await; - let pool = ctx.upstream_pool.lock().unwrap().clone(); - if let Some(upstream) = pool.preferred() { - numa::forward::keepalive_doh(upstream).await; - } - } -} - -async fn cache_warm_loop(ctx: Arc, domains: Vec) { - tokio::time::sleep(Duration::from_secs(2)).await; - - for domain in &domains { - warm_domain(&ctx, domain).await; - } - info!("cache warm: {} domains resolved at startup", domains.len()); - - let mut interval = tokio::time::interval(Duration::from_secs(30)); - interval.tick().await; - loop { - interval.tick().await; - for domain in &domains { - let refresh = ctx.cache.read().unwrap().needs_warm(domain); - if refresh { - warm_domain(&ctx, domain).await; - } - } - } -} diff --git a/src/serve.rs b/src/serve.rs new file mode 100644 index 0000000..db0465b --- /dev/null +++ b/src/serve.rs @@ -0,0 +1,646 @@ +//! The main DNS-server runtime. +//! +//! 
Extracted from `main.rs` so both the interactive CLI entry and the +//! Windows service dispatcher (`windows_service` module) can drive the +//! same startup/serve loop. + +use std::net::SocketAddr; +use std::sync::{Arc, Mutex, RwLock}; +use std::time::Duration; + +use arc_swap::ArcSwap; +use log::{error, info}; +use tokio::net::UdpSocket; + +use crate::blocklist::{download_blocklists, parse_blocklist, BlocklistStore}; +use crate::buffer::BytePacketBuffer; +use crate::cache::DnsCache; +use crate::config::{build_zone_map, load_config, ConfigLoad}; +use crate::ctx::{handle_query, ServerCtx}; +use crate::forward::{parse_upstream, Upstream, UpstreamPool}; +use crate::override_store::OverrideStore; +use crate::query_log::QueryLog; +use crate::service_store::ServiceStore; +use crate::stats::{ServerStats, Transport}; +use crate::system_dns::discover_system_dns; + +const QUAD9_IP: &str = "9.9.9.9"; +const DOH_FALLBACK: &str = "https://9.9.9.9/dns-query"; + +/// Boot the DNS server and run until the UDP listener errors out. 
+pub async fn run(config_path: String) -> crate::Result<()> { + let ConfigLoad { + config, + path: resolved_config_path, + found: config_found, + } = load_config(&config_path)?; + + // Discover system DNS in a single pass (upstream + forwarding rules) + let system_dns = discover_system_dns(); + + let root_hints = crate::recursive::parse_root_hints(&config.upstream.root_hints); + + let recursive_pool = || { + let dummy = UpstreamPool::new(vec![Upstream::Udp("0.0.0.0:0".parse().unwrap())], vec![]); + (dummy, "recursive (root hints)".to_string()) + }; + + let (resolved_mode, upstream_auto, pool, upstream_label) = match config.upstream.mode { + crate::config::UpstreamMode::Auto => { + info!("auto mode: probing recursive resolution..."); + if crate::recursive::probe_recursive(&root_hints).await { + info!("recursive probe succeeded — self-sovereign mode"); + let (pool, label) = recursive_pool(); + (crate::config::UpstreamMode::Recursive, false, pool, label) + } else { + log::warn!("recursive probe failed — falling back to Quad9 DoH"); + let client = reqwest::Client::builder() + .use_rustls_tls() + .build() + .unwrap_or_default(); + let url = DOH_FALLBACK.to_string(); + let label = url.clone(); + let pool = UpstreamPool::new(vec![Upstream::Doh { url, client }], vec![]); + (crate::config::UpstreamMode::Forward, false, pool, label) + } + } + crate::config::UpstreamMode::Recursive => { + let (pool, label) = recursive_pool(); + (crate::config::UpstreamMode::Recursive, false, pool, label) + } + crate::config::UpstreamMode::Forward => { + let addrs = if config.upstream.address.is_empty() { + let detected = system_dns + .default_upstream + .or_else(crate::system_dns::detect_dhcp_dns) + .unwrap_or_else(|| { + info!("could not detect system DNS, falling back to Quad9 DoH"); + DOH_FALLBACK.to_string() + }); + vec![detected] + } else { + config.upstream.address.clone() + }; + + let primary: Vec = addrs + .iter() + .map(|s| parse_upstream(s, config.upstream.port)) + .collect::>>()?; 
+ let fallback: Vec = config + .upstream + .fallback + .iter() + .map(|s| parse_upstream(s, config.upstream.port)) + .collect::>>()?; + + let pool = UpstreamPool::new(primary, fallback); + let label = pool.label(); + ( + crate::config::UpstreamMode::Forward, + config.upstream.address.is_empty(), + pool, + label, + ) + } + }; + let api_port = config.server.api_port; + + let mut blocklist = BlocklistStore::new(); + for domain in &config.blocking.allowlist { + blocklist.add_to_allowlist(domain); + } + if !config.blocking.enabled { + blocklist.set_enabled(false); + } + + // Build service store: config services + persisted user services + let mut service_store = ServiceStore::new(); + service_store.insert_from_config("numa", config.server.api_port, Vec::new()); + for svc in &config.services { + service_store.insert_from_config(&svc.name, svc.target_port, svc.routes.clone()); + } + service_store.load_persisted(); + + for fwd in &config.forwarding { + for suffix in &fwd.suffix { + info!("forwarding .{} to {} (config rule)", suffix, fwd.upstream); + } + } + let forwarding_rules = + crate::config::merge_forwarding_rules(&config.forwarding, system_dns.forwarding_rules)?; + + // Resolve data_dir from config, falling back to the platform default. + // Used for TLS CA storage below and stored on ServerCtx for runtime use. 
+ let resolved_data_dir = config + .server + .data_dir + .clone() + .unwrap_or_else(crate::data_dir); + + // Build initial TLS config before ServerCtx (so ArcSwap is ready at construction) + let initial_tls = if config.proxy.enabled && config.proxy.tls_port > 0 { + let service_names = service_store.names(); + match crate::tls::build_tls_config( + &config.proxy.tld, + &service_names, + Vec::new(), + &resolved_data_dir, + ) { + Ok(tls_config) => Some(ArcSwap::from(tls_config)), + Err(e) => { + if let Some(advisory) = crate::tls::try_data_dir_advisory(&e, &resolved_data_dir) { + eprint!("{}", advisory); + } else { + log::warn!("TLS setup failed, HTTPS proxy disabled: {}", e); + } + None + } + } + } else { + None + }; + + let doh_enabled = initial_tls.is_some(); + let health_meta = crate::health::HealthMeta::build( + &resolved_data_dir, + config.dot.enabled, + config.dot.port, + config.mobile.port, + config.dnssec.enabled, + resolved_mode == crate::config::UpstreamMode::Recursive, + config.lan.enabled, + config.blocking.enabled, + doh_enabled, + ); + + let ca_pem = std::fs::read_to_string(resolved_data_dir.join("ca.pem")).ok(); + + let socket = match UdpSocket::bind(&config.server.bind_addr).await { + Ok(s) => s, + Err(e) => { + if let Some(advisory) = + crate::system_dns::try_port53_advisory(&config.server.bind_addr, &e) + { + eprint!("{}", advisory); + std::process::exit(1); + } + return Err(e.into()); + } + }; + + let ctx = Arc::new(ServerCtx { + socket, + zone_map: build_zone_map(&config.zones)?, + cache: RwLock::new(DnsCache::new( + config.cache.max_entries, + config.cache.min_ttl, + config.cache.max_ttl, + )), + refreshing: Mutex::new(std::collections::HashSet::new()), + stats: Mutex::new(ServerStats::new()), + overrides: RwLock::new(OverrideStore::new()), + blocklist: RwLock::new(blocklist), + query_log: Mutex::new(QueryLog::new(1000)), + services: Mutex::new(service_store), + lan_peers: Mutex::new(crate::lan::PeerStore::new(config.lan.peer_timeout_secs)), + 
forwarding_rules, + upstream_pool: Mutex::new(pool), + upstream_auto, + upstream_port: config.upstream.port, + lan_ip: Mutex::new(crate::lan::detect_lan_ip().unwrap_or(std::net::Ipv4Addr::LOCALHOST)), + timeout: Duration::from_millis(config.upstream.timeout_ms), + hedge_delay: Duration::from_millis(config.upstream.hedge_ms), + proxy_tld_suffix: if config.proxy.tld.is_empty() { + String::new() + } else { + format!(".{}", config.proxy.tld) + }, + proxy_tld: config.proxy.tld.clone(), + lan_enabled: config.lan.enabled, + config_path: resolved_config_path, + config_found, + config_dir: crate::config_dir(), + data_dir: resolved_data_dir, + tls_config: initial_tls, + upstream_mode: resolved_mode, + root_hints, + srtt: std::sync::RwLock::new(crate::srtt::SrttCache::new(config.upstream.srtt)), + inflight: std::sync::Mutex::new(std::collections::HashMap::new()), + dnssec_enabled: config.dnssec.enabled, + dnssec_strict: config.dnssec.strict, + health_meta, + ca_pem, + mobile_enabled: config.mobile.enabled, + mobile_port: config.mobile.port, + }); + + let zone_count: usize = ctx.zone_map.values().map(|m| m.len()).sum(); + // Build banner rows, then size the box to fit the longest value + let api_url = format!("http://localhost:{}", api_port); + let proxy_label = if config.proxy.enabled { + if config.proxy.tls_port > 0 { + Some(format!( + "http://:{} https://:{}", + config.proxy.port, config.proxy.tls_port + )) + } else { + Some(format!( + "http://*.{} on :{}", + config.proxy.tld, config.proxy.port + )) + } + } else { + None + }; + let config_label = if ctx.config_found { + ctx.config_path.clone() + } else { + format!("{} (defaults)", ctx.config_path) + }; + let data_label = ctx.data_dir.display().to_string(); + let services_label = ctx.config_dir.join("services.json").display().to_string(); + + // label (10) + value + padding (2) = inner width; minimum 40 for the title row + let val_w = [ + config.server.bind_addr.len(), + api_url.len(), + upstream_label.len(), + 
config_label.len(), + data_label.len(), + services_label.len(), + ] + .into_iter() + .chain(proxy_label.as_ref().map(|s| s.len())) + .max() + .unwrap_or(30); + let w = (val_w + 12).max(42); // 10 label + 2 padding, min 42 for title + + let o = "\x1b[38;2;192;98;58m"; // orange + let g = "\x1b[38;2;107;124;78m"; // green + let d = "\x1b[38;2;163;152;136m"; // dim + let r = "\x1b[0m"; // reset + let b = "\x1b[1;38;2;192;98;58m"; // bold orange + let it = "\x1b[3;38;2;163;152;136m"; // italic dim + + let bar_top = "═".repeat(w); + let bar_mid = "─".repeat(w); + let row = |label: &str, color: &str, value: &str| { + eprintln!( + "{o} ║{r} {color}{:<9}{r} {: 0 && ctx.tls_config.is_some() { + let proxy_ctx = Arc::clone(&ctx); + let tls_port = config.proxy.tls_port; + tokio::spawn(async move { + crate::proxy::start_proxy_tls(proxy_ctx, tls_port, proxy_bind).await; + }); + } + + // Spawn network change watcher (upstream re-detection, LAN IP update, peer flush) + { + let watch_ctx = Arc::clone(&ctx); + tokio::spawn(async move { + network_watch_loop(watch_ctx).await; + }); + } + + // Spawn LAN service discovery + if config.lan.enabled { + let lan_ctx = Arc::clone(&ctx); + let lan_config = config.lan.clone(); + tokio::spawn(async move { + crate::lan::start_lan_discovery(lan_ctx, &lan_config).await; + }); + } + + // Spawn DNS-over-TLS listener (RFC 7858) + if config.dot.enabled { + let dot_ctx = Arc::clone(&ctx); + let dot_config = config.dot.clone(); + tokio::spawn(async move { + crate::dot::start_dot(dot_ctx, &dot_config).await; + }); + } + + // UDP DNS listener + #[allow(clippy::infinite_loop)] + loop { + let mut buffer = BytePacketBuffer::new(); + let (len, src_addr) = match ctx.socket.recv_from(&mut buffer.buf).await { + Ok(r) => r, + Err(e) if e.kind() == std::io::ErrorKind::ConnectionReset => { + // Windows delivers ICMP port-unreachable as ConnectionReset on UDP sockets + continue; + } + Err(e) => return Err(e.into()), + }; + let ctx = Arc::clone(&ctx); + 
tokio::spawn(async move { + if let Err(e) = handle_query(buffer, len, src_addr, &ctx, Transport::Udp).await { + error!("{} | HANDLER ERROR | {}", src_addr, e); + } + }); + } +} + +async fn network_watch_loop(ctx: Arc) { + let mut tick: u64 = 0; + + let mut interval = tokio::time::interval(Duration::from_secs(5)); + interval.tick().await; // skip immediate tick + + loop { + interval.tick().await; + tick += 1; + let mut changed = false; + + // Check LAN IP change (every 5s — cheap, one UDP socket call) + if let Some(new_ip) = crate::lan::detect_lan_ip() { + let mut current_ip = ctx.lan_ip.lock().unwrap(); + if new_ip != *current_ip { + info!("LAN IP changed: {} → {}", current_ip, new_ip); + *current_ip = new_ip; + changed = true; + crate::recursive::reset_udp_state(); + } + } + + // Re-detect upstream every 30s or on LAN IP change (auto-detect only) + if ctx.upstream_auto && (changed || tick.is_multiple_of(6)) { + let dns_info = crate::system_dns::discover_system_dns(); + let new_addr = dns_info + .default_upstream + .or_else(crate::system_dns::detect_dhcp_dns) + .unwrap_or_else(|| QUAD9_IP.to_string()); + let mut pool = ctx.upstream_pool.lock().unwrap(); + if pool.maybe_update_primary(&new_addr, ctx.upstream_port) { + info!("upstream changed → {}", pool.label()); + changed = true; + } + } + + // Flush stale LAN peers on any network change + if changed { + ctx.lan_peers.lock().unwrap().clear(); + info!("flushed LAN peers after network change"); + } + + // Re-probe UDP every 5 minutes when disabled + if tick.is_multiple_of(60) { + crate::recursive::probe_udp(&ctx.root_hints).await; + } + } +} + +async fn load_blocklists(ctx: &ServerCtx, lists: &[String]) { + let downloaded = download_blocklists(lists).await; + + // Parse outside the lock to avoid blocking DNS queries during parse (~100ms) + let mut all_domains = std::collections::HashSet::new(); + let mut sources = Vec::new(); + for (source, text) in &downloaded { + let domains = parse_blocklist(text); + 
info!("blocklist: {} domains from {}", domains.len(), source); + all_domains.extend(domains); + sources.push(source.clone()); + } + let total = all_domains.len(); + + // Swap under lock — sub-microsecond + ctx.blocklist + .write() + .unwrap() + .swap_domains(all_domains, sources); + info!( + "blocking enabled: {} unique domains from {} lists", + total, + downloaded.len() + ); +} + +async fn warm_domain(ctx: &ServerCtx, domain: &str) { + for qtype in [ + crate::question::QueryType::A, + crate::question::QueryType::AAAA, + ] { + crate::ctx::refresh_entry(ctx, domain, qtype).await; + } +} + +async fn doh_keepalive_loop(ctx: Arc) { + let mut interval = tokio::time::interval(Duration::from_secs(25)); + interval.tick().await; // skip first immediate tick + loop { + interval.tick().await; + let pool = ctx.upstream_pool.lock().unwrap().clone(); + if let Some(upstream) = pool.preferred() { + crate::forward::keepalive_doh(upstream).await; + } + } +} + +async fn cache_warm_loop(ctx: Arc, domains: Vec) { + tokio::time::sleep(Duration::from_secs(2)).await; + + for domain in &domains { + warm_domain(&ctx, domain).await; + } + info!("cache warm: {} domains resolved at startup", domains.len()); + + let mut interval = tokio::time::interval(Duration::from_secs(30)); + interval.tick().await; + loop { + interval.tick().await; + for domain in &domains { + let refresh = ctx.cache.read().unwrap().needs_warm(domain); + if refresh { + warm_domain(&ctx, domain).await; + } + } + } +} diff --git a/src/system_dns.rs b/src/system_dns.rs index 96ae372..b39f661 100644 --- a/src/system_dns.rs +++ b/src/system_dns.rs @@ -697,7 +697,23 @@ fn install_windows() -> Result<(), String> { } let needs_reboot = disable_dnscache()?; - register_autostart(); + + // Copy the binary to a stable path under ProgramData and register it + // as a real Windows service (SCM-managed, boot-time, auto-restart). 
+    let service_exe = install_service_binary()?;
+    register_service_scm(&service_exe)?;
+
+    // If no reboot is pending (Dnscache wasn't running, port 53 free),
+    // start the service immediately. Otherwise it'll launch on next boot.
+    if !needs_reboot {
+        match start_service_scm() {
+            Ok(_) => eprintln!(" Service started."),
+            Err(e) => eprintln!(
+                " warning: service registered but could not start now: {}",
+                e
+            ),
+        }
+    }
 
     eprintln!();
     if !has_useful_existing {
@@ -707,51 +723,182 @@ fn install_windows() -> Result<(), String> {
     if needs_reboot {
         eprintln!(" *** Reboot required. Numa will start automatically. ***\n");
     } else {
-        eprintln!(" Numa will start automatically on next boot.\n");
+        eprintln!(" Numa is running.\n");
     }
     print_recursive_hint();
     Ok(())
 }
 
-/// Register numa to auto-start on boot via registry Run key.
 #[cfg(windows)]
-fn register_autostart() {
-    let exe = std::env::current_exe()
-        .map(|p| p.to_string_lossy().to_string())
-        .unwrap_or_else(|_| "numa".into());
-    let _ = std::process::Command::new("reg")
-        .args([
-            "add",
-            "HKLM\\SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\Run",
-            "/v",
-            "Numa",
-            "/t",
-            "REG_SZ",
-            "/d",
-            &exe,
-            "/f",
-        ])
-        .status();
-    eprintln!(" Registered auto-start on boot.");
+const WINDOWS_SERVICE_NAME: &str = "Numa";
+
+/// Stable install location for the service binary. SCM is given this path
+/// as the service's `binPath`; the user's Downloads folder (where
+/// `current_exe()` points at install time) is not durable.
+#[cfg(windows)]
+fn windows_service_exe_path() -> std::path::PathBuf {
+    std::path::PathBuf::from(
+        std::env::var("PROGRAMDATA").unwrap_or_else(|_| "C:\\ProgramData".into()),
+    )
+    .join("numa")
+    .join("bin")
+    .join("numa.exe")
 }
 
-/// Remove numa auto-start registry key.
+/// Copy the currently-running binary to the service install location. SCM
+/// launches the service from this path, so it must be stable across user
+/// sessions. Returns the installed path, or a message naming the failed step.
 #[cfg(windows)]
-fn remove_autostart() {
-    let _ = std::process::Command::new("reg")
+fn install_service_binary() -> Result<std::path::PathBuf, String> {
+    let src = std::env::current_exe().map_err(|e| format!("current_exe(): {}", e))?;
+    let dst = windows_service_exe_path();
+    if let Some(parent) = dst.parent() {
+        std::fs::create_dir_all(parent)
+            .map_err(|e| format!("failed to create {}: {}", parent.display(), e))?;
+    }
+    // Copy only if source and destination differ; running the binary from
+    // its install location is a supported (re-install) case. Compare the
+    // canonical paths when both resolve, so a differently-spelled path to
+    // the same file is not copied onto itself.
+    let same_file = match (std::fs::canonicalize(&src), std::fs::canonicalize(&dst)) {
+        (Ok(a), Ok(b)) => a == b,
+        _ => src == dst,
+    };
+    if !same_file {
+        std::fs::copy(&src, &dst).map_err(|e| {
+            format!(
+                "failed to copy {} -> {}: {}",
+                src.display(),
+                dst.display(),
+                e
+            )
+        })?;
+    }
+    Ok(dst)
+}
+
+/// Remove the service binary on uninstall. Best effort — the service is
+/// already deleted; a leftover file in ProgramData is not a hard error.
+/// `sc stop` only requests the stop, so the exiting process can keep the
+/// image locked for a moment: retry for up to ~2s before giving up.
+#[cfg(windows)]
+fn remove_service_binary() {
+    let path = windows_service_exe_path();
+    for attempt in 0..10 {
+        if attempt > 0 {
+            std::thread::sleep(std::time::Duration::from_millis(200));
+        }
+        match std::fs::remove_file(&path) {
+            // Gone (or never there): nothing left to do.
+            Ok(()) => return,
+            Err(e) if e.kind() == std::io::ErrorKind::NotFound => return,
+            Err(_) => {}
+        }
+    }
+}
+
+/// Register numa with the Service Control Manager, boot-time auto-start,
+/// LocalSystem context, with a failure policy of restart-after-5s.
+#[cfg(windows)]
+fn register_service_scm(exe: &std::path::Path) -> Result<(), String> {
+    let bin_path = format!("\"{}\" --service", exe.display());
+
+    // sc.exe takes each option as two tokens: the name with its trailing
+    // `=` (e.g. `binPath=`), then the value — hence the separate args below.
+    let create = std::process::Command::new("sc")
         .args([
-            "delete",
-            "HKLM\\SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\Run",
-            "/v",
-            "Numa",
-            "/f",
+            "create",
+            WINDOWS_SERVICE_NAME,
+            "binPath=",
+            &bin_path,
+            "DisplayName=",
+            "Numa DNS",
+            "start=",
+            "auto",
+            "obj=",
+            "LocalSystem",
         ])
+        .output()
+        .map_err(|e| format!("failed to run sc create: {}", e))?;
+    if !create.status.success() {
+        let out = String::from_utf8_lossy(&create.stdout);
+        // "service already exists" is 1073 — treat as idempotent success.
+        if !out.contains("1073") {
+            return Err(format!("sc create failed: {}", out.trim()));
+        }
+    }
+
+    let _ = std::process::Command::new("sc")
+        .args([
+            "description",
+            WINDOWS_SERVICE_NAME,
+            "Self-sovereign DNS resolver (ad blocking, DoH/DoT, local zones).",
+        ])
+        .status();
+
+    // Restart on crash: 5s, 5s, 10s; reset failure counter after 60s.
+    let _ = std::process::Command::new("sc")
+        .args([
+            "failure",
+            WINDOWS_SERVICE_NAME,
+            "reset=",
+            "60",
+            "actions=",
+            "restart/5000/restart/5000/restart/10000",
+        ])
+        .status();
+
+    eprintln!(
+        " Registered service '{}' (boot-time).",
+        WINDOWS_SERVICE_NAME
+    );
+    Ok(())
+}
+
+/// Start the service. Idempotent: SCM error 1056 ("already running") is
+/// treated as success. Any other failure (e.g. 1058, "disabled") is
+/// returned with sc's own output rather than masked.
+#[cfg(windows)]
+fn start_service_scm() -> Result<(), String> {
+    let out = std::process::Command::new("sc")
+        .args(["start", WINDOWS_SERVICE_NAME])
+        .output()
+        .map_err(|e| format!("failed to run sc start: {}", e))?;
+    if !out.status.success() {
+        let text = String::from_utf8_lossy(&out.stdout);
+        if text.contains("1056") {
+            return Ok(()); // already running
+        }
+        return Err(format!("sc start failed: {}", text.trim()));
+    }
+    Ok(())
+}
+
+/// Ask SCM to stop the service. Best effort: the outcome of `sc stop` is
+/// ignored, so an already-stopped or missing service is not an error. This
+/// does not wait for the service process to exit.
+#[cfg(windows)]
+fn stop_service_scm() {
+    let _ = std::process::Command::new("sc")
+        .args(["stop", WINDOWS_SERVICE_NAME])
+        .status();
+}
+
+/// Remove the service from SCM. Safe if already absent.
+#[cfg(windows)]
+fn delete_service_scm() {
+    let _ = std::process::Command::new("sc")
+        .args(["delete", WINDOWS_SERVICE_NAME])
         .status();
 }
 
 #[cfg(windows)]
 fn uninstall_windows() -> Result<(), String> {
-    remove_autostart();
+    // Stop + remove the service before touching DNS, so port 53 is released
+    // cleanly and the failure-restart policy doesn't resurrect it.
+    stop_service_scm();
+    delete_service_scm();
+    remove_service_binary();
     let path = windows_backup_path();
     let json = std::fs::read_to_string(&path)
         .map_err(|e| format!("no backup found at {}: {}", path.display(), e))?;
diff --git a/src/windows_service.rs b/src/windows_service.rs
index 8751f23..c51339c 100644
--- a/src/windows_service.rs
+++ b/src/windows_service.rs
@@ -57,12 +57,54 @@ fn run_service() -> windows_service::Result<()> {
         process_id: None,
     })?;
 
-    // TODO(windows-service): call numa's async serve loop here once main.rs's
-    // server body is extracted into `numa::serve(config_path)`. For now the
-    // service registers, reports Running, and blocks until SCM sends Stop —
-    // useful for verifying the SCM plumbing end to end with `sc start Numa`
-    // and `sc stop Numa`.
-    let _ = shutdown_rx.recv();
+    // Spin up a multi-threaded tokio runtime and run the server on it. A
+    // dedicated thread runs the runtime so this function can return cleanly
+    // once the SCM tells us to stop — we can't block the dispatcher thread
+    // forever without preventing graceful shutdown.
+    let config_path = service_config_path();
+    let (runtime_stop_tx, runtime_stop_rx) = mpsc::channel::<()>();
+
+    let server_thread = std::thread::spawn(move || {
+        let runtime = match tokio::runtime::Builder::new_multi_thread()
+            .enable_all()
+            .build()
+        {
+            Ok(rt) => rt,
+            Err(e) => {
+                log::error!("failed to build tokio runtime: {}", e);
+                let _ = runtime_stop_tx.send(());
+                return;
+            }
+        };
+
+        // block_on returns when serve::run's future completes (its UDP loop
+        // erroring out). The runtime is owned by this thread, so it cannot
+        // be dropped from elsewhere; the send below tells the wait loop.
+        if let Err(e) = runtime.block_on(crate::serve::run(config_path)) {
+            log::error!("numa serve exited with error: {}", e);
+        }
+        let _ = runtime_stop_tx.send(());
+    });
+
+    // Wait for either SCM stop or server termination. A disconnected
+    // channel counts as a signal too: if the server thread panics, its
+    // sender is dropped without ever sending, and polling only for `Ok`
+    // would leave the service reported as Running forever.
+    loop {
+        if !matches!(shutdown_rx.try_recv(), Err(mpsc::TryRecvError::Empty)) {
+            break;
+        }
+        if !matches!(runtime_stop_rx.try_recv(), Err(mpsc::TryRecvError::Empty)) {
+            break;
+        }
+        std::thread::sleep(Duration::from_millis(200));
+    }
+
+    // The server's tokio runtime runs detached inside server_thread. Abandon
+    // it — the process is about to report Stopped and the SCM will terminate
+    // us if we linger. Future work: plumb a cancellation signal into
+    // serve::run() for a clean teardown of listeners and in-flight queries.
+    drop(server_thread);
 
     status_handle.set_service_status(ServiceStatus {
         service_type: ServiceType::OWN_PROCESS,
@@ -83,3 +125,12 @@ fn run_service() -> windows_service::Result<()> {
 pub fn run_as_service() -> windows_service::Result<()> {
     service_dispatcher::start(SERVICE_NAME, ffi_service_main)
 }
+
+/// Path to the config file used when running under SCM. SCM launches the
+/// service with SYSTEM's working directory (usually `C:\Windows\System32`),
+/// so a relative `numa.toml` lookup won't find anything meaningful — use an
+/// absolute path under `%PROGRAMDATA%` instead.
+fn service_config_path() -> String {
+    let base = std::env::var("PROGRAMDATA").unwrap_or_else(|_| "C:\\ProgramData".into());
+    format!("{}\\numa\\numa.toml", base)
+}