diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0ad7e45..e116744 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -56,6 +56,8 @@ jobs: runs-on: windows-latest steps: - uses: actions/checkout@v6 + with: + fetch-depth: 0 - uses: dtolnay/rust-toolchain@stable - uses: Swatinem/rust-cache@v2 - name: build @@ -69,3 +71,62 @@ jobs: with: name: numa-windows-x86_64 path: target/debug/numa.exe + + integration-linux: + needs: [check] + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + - uses: dtolnay/rust-toolchain@stable + - uses: Swatinem/rust-cache@v2 + - name: build + run: cargo build --release + - name: install / verify / re-install / uninstall + run: | + sudo ./target/release/numa install + sleep 2 + curl -sf http://127.0.0.1:5380/health + dig @127.0.0.1 example.com +short +timeout=5 | grep -q '.' + sudo ./target/release/numa install + sleep 2 + curl -sf http://127.0.0.1:5380/health + sudo ./target/release/numa uninstall + sleep 1 + ! curl -sf http://127.0.0.1:5380/health 2>/dev/null + - name: cleanup + if: always() + run: | + sudo ./target/release/numa uninstall 2>/dev/null || true + # systemd-resolved has a ~40s DNS reconfiguration stall after + # restart (systemd issue #22521) that breaks the runner agent's + # connection to GitHub. Bridge it by replacing the stub-resolv + # symlink with a direct upstream — DNS works instantly and the + # runner can phone home for post-job steps. + sudo rm -f /etc/resolv.conf + echo "nameserver 8.8.8.8" | sudo tee /etc/resolv.conf > /dev/null + getent hosts github.com >/dev/null + + integration-macos: + needs: [check-macos] + runs-on: macos-latest + steps: + - uses: actions/checkout@v6 + - uses: dtolnay/rust-toolchain@stable + - uses: Swatinem/rust-cache@v2 + - name: build + run: cargo build --release + - name: install / verify / re-install / uninstall + run: | + sudo ./target/release/numa install + sleep 2 + curl -sf http://127.0.0.1:5380/health + dig @127.0.0.1 example.com +short +timeout=5 | grep -q '.' + sudo ./target/release/numa install + sleep 2 + curl -sf http://127.0.0.1:5380/health + sudo ./target/release/numa uninstall + sleep 1 + ! curl -sf http://127.0.0.1:5380/health 2>/dev/null + - name: cleanup + if: always() + run: sudo ./target/release/numa uninstall 2>/dev/null || true diff --git a/Cargo.lock b/Cargo.lock index 9cd1b7d..cf25b3a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1359,6 +1359,7 @@ dependencies = [ "toml", "tower", "webpki-roots 1.0.6", + "windows-service", "x509-parser", ] @@ -2583,6 +2584,17 @@ dependencies = [ "windows-link", ] +[[package]] +name = "windows-service" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d24d6bcc7f734a4091ecf8d7a64c5f7d7066f45585c1861eba06449909609c8a" +dependencies = [ + "bitflags", + "widestring", + "windows-sys 0.52.0", +] + [[package]] name = "windows-strings" version = "0.5.1" diff --git a/Cargo.toml b/Cargo.toml index 0b13af2..3b3234f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,6 +33,9 @@ rustls-pemfile = "2.2.0" qrcode = { version = "0.14", default-features = false, features = ["svg"] } webpki-roots = "1" +[target.'cfg(windows)'.dependencies] +windows-service = "0.7" + [dev-dependencies] criterion = { version = "0.8", features = ["html_reports"] } tower = { version = "0.5", features = ["util"] } diff --git a/site/dashboard.html b/site/dashboard.html index d3b1820..fa2d965 100644 --- a/site/dashboard.html +++ b/site/dashboard.html @@ -1150,9 +1150,12 @@ async function refresh() { document.getElementById('footerSrtt').textContent = stats.srtt ? 'on' : 'off'; document.getElementById('footerSrtt').style.color = stats.srtt ? 'var(--emerald)' : 'var(--text-dim)'; if (!document.getElementById('footerLogs').textContent) { + const isWin = stats.data_dir && stats.data_dir.includes(':\\'); const isMac = stats.data_dir && stats.data_dir.includes('/usr/local/'); - document.getElementById('footerLogs').textContent = isMac - ? '/usr/local/var/log/numa.log' + const logsEl = document.getElementById('footerLogs'); + logsEl.textContent = isWin + ? stats.data_dir + '\\numa.log' + : isMac ? '/usr/local/var/log/numa.log' : 'journalctl -u numa -f'; } diff --git a/src/lib.rs b/src/lib.rs index a9d38fc..a16568b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,6 +20,7 @@ pub mod query_log; pub mod question; pub mod record; pub mod recursive; +pub mod serve; pub mod service_store; pub mod setup_phone; pub mod srtt; @@ -28,6 +29,9 @@ pub mod system_dns; pub mod tls; pub mod wire; +#[cfg(windows)] +pub mod windows_service; + #[cfg(test)] pub(crate) mod testutil; @@ -97,14 +101,11 @@ where /// Linux root daemon: /var/lib/numa (FHS) — falls back to /usr/local/var/numa /// if a pre-v0.10.1 install already lives there. /// macOS root daemon: /usr/local/var/numa (Homebrew prefix) -/// Windows: %APPDATA%\numa +/// Windows: %PROGRAMDATA%\numa (same as data_dir — no per-user config on Windows) pub fn config_dir() -> std::path::PathBuf { #[cfg(windows)] { - std::path::PathBuf::from( - std::env::var("APPDATA").unwrap_or_else(|_| "C:\\ProgramData".into()), - ) - .join("numa") + data_dir() } #[cfg(not(windows))] { diff --git a/src/main.rs b/src/main.rs index faf2e22..34bf747 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,36 +1,34 @@ -use std::net::SocketAddr; -use std::sync::{Arc, Mutex, RwLock}; -use std::time::Duration; - -use arc_swap::ArcSwap; -use log::{error, info}; -use tokio::net::UdpSocket; - -use numa::blocklist::{download_blocklists, parse_blocklist, BlocklistStore}; -use numa::buffer::BytePacketBuffer; -use numa::cache::DnsCache; -use numa::config::{build_zone_map, load_config, ConfigLoad}; -use numa::ctx::{handle_query, ServerCtx}; -use numa::forward::{parse_upstream, Upstream, UpstreamPool}; -use numa::override_store::OverrideStore; -use numa::query_log::QueryLog; -use numa::service_store::ServiceStore; -use numa::stats::{ServerStats, Transport}; use numa::system_dns::{ - discover_system_dns, install_service, restart_service, service_status, uninstall_service, + install_service, restart_service, service_status, start_service, stop_service, + uninstall_service, }; -const QUAD9_IP: &str = "9.9.9.9"; -const DOH_FALLBACK: &str = "https://9.9.9.9/dns-query"; +fn main() -> numa::Result<()> { + // Handle CLI subcommands + let arg1 = std::env::args().nth(1).unwrap_or_default(); + + #[cfg(windows)] + if arg1 == "--service" { + // Running under SCM — stderr goes nowhere. Redirect logs to a file. + let log_path = numa::data_dir().join("numa.log"); + let log_file = std::fs::OpenOptions::new() + .create(true) + .append(true) + .open(&log_path) + .expect("failed to open log file"); + env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")) + .format_timestamp_millis() + .target(env_logger::Target::Pipe(Box::new(log_file))) + .init(); + numa::windows_service::run_as_service() + .map_err(|e| format!("windows service dispatcher failed: {}", e))?; + return Ok(()); + } -#[tokio::main] -async fn main() -> numa::Result<()> { env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")) .format_timestamp_millis() .init(); - // Handle CLI subcommands - let arg1 = std::env::args().nth(1).unwrap_or_default(); match arg1.as_str() { "install" => { eprintln!("\x1b[1;38;2;192;98;58mNuma\x1b[0m — installing\n"); @@ -44,8 +42,8 @@ async fn main() -> numa::Result<()> { let sub = std::env::args().nth(2).unwrap_or_default(); eprintln!("\x1b[1;38;2;192;98;58mNuma\x1b[0m — service management\n"); return match sub.as_str() { - "start" => install_service().map_err(|e| e.into()), - "stop" => uninstall_service().map_err(|e| e.into()), + "start" => start_service().map_err(|e| e.into()), + "stop" => stop_service().map_err(|e| e.into()), "restart" => restart_service().map_err(|e| e.into()), "status" => service_status().map_err(|e| e.into()), _ => { @@ -55,7 +53,12 @@ async fn main() -> numa::Result<()> { }; } "setup-phone" => { - return numa::setup_phone::run().await.map_err(|e| e.into()); + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build()?; + return runtime + .block_on(numa::setup_phone::run()) + .map_err(|e| e.into()); } "lan" => { let sub = std::env::args().nth(2).unwrap_or_default(); @@ -72,7 +75,7 @@ async fn main() -> numa::Result<()> { }; } "version" | "--version" | "-V" => { - eprintln!("numa {}", numa::version()); + eprintln!("numa {}", env!("CARGO_PKG_VERSION")); return Ok(()); } "help" | "--help" | "-h" => { @@ -118,550 +121,11 @@ async fn main() -> numa::Result<()> { } else { arg1 // treat as config path for backwards compatibility }; - let ConfigLoad { - config, - path: resolved_config_path, - found: config_found, - } = load_config(&config_path)?; - // Discover system DNS in a single pass (upstream + forwarding rules) - let system_dns = discover_system_dns(); - - let root_hints = numa::recursive::parse_root_hints(&config.upstream.root_hints); - - let recursive_pool = || { - let dummy = UpstreamPool::new(vec![Upstream::Udp("0.0.0.0:0".parse().unwrap())], vec![]); - (dummy, "recursive (root hints)".to_string()) - }; - - let (resolved_mode, upstream_auto, pool, upstream_label) = match config.upstream.mode { - numa::config::UpstreamMode::Auto => { - info!("auto mode: probing recursive resolution..."); - if numa::recursive::probe_recursive(&root_hints).await { - info!("recursive probe succeeded — self-sovereign mode"); - let (pool, label) = recursive_pool(); - (numa::config::UpstreamMode::Recursive, false, pool, label) - } else { - log::warn!("recursive probe failed — falling back to Quad9 DoH"); - let client = reqwest::Client::builder() - .use_rustls_tls() - .build() - .unwrap_or_default(); - let url = DOH_FALLBACK.to_string(); - let label = url.clone(); - let pool = UpstreamPool::new(vec![Upstream::Doh { url, client }], vec![]); - (numa::config::UpstreamMode::Forward, false, pool, label) - } - } - numa::config::UpstreamMode::Recursive => { - let (pool, label) = recursive_pool(); - (numa::config::UpstreamMode::Recursive, false, pool, label) - } - numa::config::UpstreamMode::Forward => { - let addrs = if config.upstream.address.is_empty() { - let detected = system_dns - .default_upstream - .or_else(numa::system_dns::detect_dhcp_dns) - .unwrap_or_else(|| { - info!("could not detect system DNS, falling back to Quad9 DoH"); - DOH_FALLBACK.to_string() - }); - vec![detected] - } else { - config.upstream.address.clone() - }; - - let primary: Vec = addrs - .iter() - .map(|s| parse_upstream(s, config.upstream.port)) - .collect::>>()?; - let fallback: Vec = config - .upstream - .fallback - .iter() - .map(|s| parse_upstream(s, config.upstream.port)) - .collect::>>()?; - - let pool = UpstreamPool::new(primary, fallback); - let label = pool.label(); - ( - numa::config::UpstreamMode::Forward, - config.upstream.address.is_empty(), - pool, - label, - ) - } - }; - let api_port = config.server.api_port; - - let mut blocklist = BlocklistStore::new(); - for domain in &config.blocking.allowlist { - blocklist.add_to_allowlist(domain); - } - if !config.blocking.enabled { - blocklist.set_enabled(false); - } - - // Build service store: config services + persisted user services - let mut service_store = ServiceStore::new(); - service_store.insert_from_config("numa", config.server.api_port, Vec::new()); - for svc in &config.services { - service_store.insert_from_config(&svc.name, svc.target_port, svc.routes.clone()); - } - service_store.load_persisted(); - - for fwd in &config.forwarding { - for suffix in &fwd.suffix { - info!("forwarding .{} to {} (config rule)", suffix, fwd.upstream); - } - } - let forwarding_rules = - numa::config::merge_forwarding_rules(&config.forwarding, system_dns.forwarding_rules)?; - - // Resolve data_dir from config, falling back to the platform default. - // Used for TLS CA storage below and stored on ServerCtx for runtime use. - let resolved_data_dir = config - .server - .data_dir - .clone() - .unwrap_or_else(numa::data_dir); - - // Build initial TLS config before ServerCtx (so ArcSwap is ready at construction) - let initial_tls = if config.proxy.enabled && config.proxy.tls_port > 0 { - let service_names = service_store.names(); - match numa::tls::build_tls_config( - &config.proxy.tld, - &service_names, - Vec::new(), - &resolved_data_dir, - ) { - Ok(tls_config) => Some(ArcSwap::from(tls_config)), - Err(e) => { - if let Some(advisory) = numa::tls::try_data_dir_advisory(&e, &resolved_data_dir) { - eprint!("{}", advisory); - } else { - log::warn!("TLS setup failed, HTTPS proxy disabled: {}", e); - } - None - } - } - } else { - None - }; - - let doh_enabled = initial_tls.is_some(); - let health_meta = numa::health::HealthMeta::build( - &resolved_data_dir, - config.dot.enabled, - config.dot.port, - config.mobile.port, - config.dnssec.enabled, - resolved_mode == numa::config::UpstreamMode::Recursive, - config.lan.enabled, - config.blocking.enabled, - doh_enabled, - ); - - let ca_pem = std::fs::read_to_string(resolved_data_dir.join("ca.pem")).ok(); - - let socket = match UdpSocket::bind(&config.server.bind_addr).await { - Ok(s) => s, - Err(e) => { - if let Some(advisory) = - numa::system_dns::try_port53_advisory(&config.server.bind_addr, &e) - { - eprint!("{}", advisory); - std::process::exit(1); - } - return Err(e.into()); - } - }; - - let ctx = Arc::new(ServerCtx { - socket, - zone_map: build_zone_map(&config.zones)?, - cache: RwLock::new(DnsCache::new( - config.cache.max_entries, - config.cache.min_ttl, - config.cache.max_ttl, - )), - refreshing: Mutex::new(std::collections::HashSet::new()), - stats: Mutex::new(ServerStats::new()), - overrides: RwLock::new(OverrideStore::new()), - blocklist: RwLock::new(blocklist), - query_log: Mutex::new(QueryLog::new(1000)), - services: Mutex::new(service_store), - lan_peers: Mutex::new(numa::lan::PeerStore::new(config.lan.peer_timeout_secs)), - forwarding_rules, - upstream_pool: Mutex::new(pool), - upstream_auto, - upstream_port: config.upstream.port, - lan_ip: Mutex::new(numa::lan::detect_lan_ip().unwrap_or(std::net::Ipv4Addr::LOCALHOST)), - timeout: Duration::from_millis(config.upstream.timeout_ms), - hedge_delay: Duration::from_millis(config.upstream.hedge_ms), - proxy_tld_suffix: if config.proxy.tld.is_empty() { - String::new() - } else { - format!(".{}", config.proxy.tld) - }, - proxy_tld: config.proxy.tld.clone(), - lan_enabled: config.lan.enabled, - config_path: resolved_config_path, - config_found, - config_dir: numa::config_dir(), - data_dir: resolved_data_dir, - tls_config: initial_tls, - upstream_mode: resolved_mode, - root_hints, - srtt: std::sync::RwLock::new(numa::srtt::SrttCache::new(config.upstream.srtt)), - inflight: std::sync::Mutex::new(std::collections::HashMap::new()), - dnssec_enabled: config.dnssec.enabled, - dnssec_strict: config.dnssec.strict, - health_meta, - ca_pem, - mobile_enabled: config.mobile.enabled, - mobile_port: config.mobile.port, - }); - - let zone_count: usize = ctx.zone_map.values().map(|m| m.len()).sum(); - // Build banner rows, then size the box to fit the longest value - let api_url = format!("http://localhost:{}", api_port); - let proxy_label = if config.proxy.enabled { - if config.proxy.tls_port > 0 { - Some(format!( - "http://:{} https://:{}", - config.proxy.port, config.proxy.tls_port - )) - } else { - Some(format!( - "http://*.{} on :{}", - config.proxy.tld, config.proxy.port - )) - } - } else { - None - }; - let config_label = if ctx.config_found { - ctx.config_path.clone() - } else { - format!("{} (defaults)", ctx.config_path) - }; - let data_label = ctx.data_dir.display().to_string(); - let services_label = ctx.config_dir.join("services.json").display().to_string(); - - // label (10) + value + padding (2) = inner width; minimum 40 for the title row - let val_w = [ - config.server.bind_addr.len(), - api_url.len(), - upstream_label.len(), - config_label.len(), - data_label.len(), - services_label.len(), - ] - .into_iter() - .chain(proxy_label.as_ref().map(|s| s.len())) - .max() - .unwrap_or(30); - let w = (val_w + 12).max(42); // 10 label + 2 padding, min 42 for title - - let o = "\x1b[38;2;192;98;58m"; // orange - let g = "\x1b[38;2;107;124;78m"; // green - let d = "\x1b[38;2;163;152;136m"; // dim - let r = "\x1b[0m"; // reset - let b = "\x1b[1;38;2;192;98;58m"; // bold orange - let it = "\x1b[3;38;2;163;152;136m"; // italic dim - - let bar_top = "═".repeat(w); - let bar_mid = "─".repeat(w); - let row = |label: &str, color: &str, value: &str| { - eprintln!( - "{o} ║{r} {color}{:<9}{r} {: 0 && ctx.tls_config.is_some() { - let proxy_ctx = Arc::clone(&ctx); - let tls_port = config.proxy.tls_port; - tokio::spawn(async move { - numa::proxy::start_proxy_tls(proxy_ctx, tls_port, proxy_bind).await; - }); - } - - // Spawn network change watcher (upstream re-detection, LAN IP update, peer flush) - { - let watch_ctx = Arc::clone(&ctx); - tokio::spawn(async move { - network_watch_loop(watch_ctx).await; - }); - } - - // Spawn LAN service discovery - if config.lan.enabled { - let lan_ctx = Arc::clone(&ctx); - let lan_config = config.lan.clone(); - tokio::spawn(async move { - numa::lan::start_lan_discovery(lan_ctx, &lan_config).await; - }); - } - - // Spawn DNS-over-TLS listener (RFC 7858) - if config.dot.enabled { - let dot_ctx = Arc::clone(&ctx); - let dot_config = config.dot.clone(); - tokio::spawn(async move { - numa::dot::start_dot(dot_ctx, &dot_config).await; - }); - } - - // UDP DNS listener - #[allow(clippy::infinite_loop)] - loop { - let mut buffer = BytePacketBuffer::new(); - let (len, src_addr) = match ctx.socket.recv_from(&mut buffer.buf).await { - Ok(r) => r, - Err(e) if e.kind() == std::io::ErrorKind::ConnectionReset => { - // Windows delivers ICMP port-unreachable as ConnectionReset on UDP sockets - continue; - } - Err(e) => return Err(e.into()), - }; - let ctx = Arc::clone(&ctx); - tokio::spawn(async move { - if let Err(e) = handle_query(buffer, len, src_addr, &ctx, Transport::Udp).await { - error!("{} | HANDLER ERROR | {}", src_addr, e); - } - }); - } -} - -async fn network_watch_loop(ctx: Arc) { - let mut tick: u64 = 0; - - let mut interval = tokio::time::interval(Duration::from_secs(5)); - interval.tick().await; // skip immediate tick - - loop { - interval.tick().await; - tick += 1; - let mut changed = false; - - // Check LAN IP change (every 5s — cheap, one UDP socket call) - if let Some(new_ip) = numa::lan::detect_lan_ip() { - let mut current_ip = ctx.lan_ip.lock().unwrap(); - if new_ip != *current_ip { - info!("LAN IP changed: {} → {}", current_ip, new_ip); - *current_ip = new_ip; - changed = true; - numa::recursive::reset_udp_state(); - } - } - - // Re-detect upstream every 30s or on LAN IP change (auto-detect only) - if ctx.upstream_auto && (changed || tick.is_multiple_of(6)) { - let dns_info = numa::system_dns::discover_system_dns(); - let new_addr = dns_info - .default_upstream - .or_else(numa::system_dns::detect_dhcp_dns) - .unwrap_or_else(|| QUAD9_IP.to_string()); - let mut pool = ctx.upstream_pool.lock().unwrap(); - if pool.maybe_update_primary(&new_addr, ctx.upstream_port) { - info!("upstream changed → {}", pool.label()); - changed = true; - } - } - - // Flush stale LAN peers on any network change - if changed { - ctx.lan_peers.lock().unwrap().clear(); - info!("flushed LAN peers after network change"); - } - - // Re-probe UDP every 5 minutes when disabled - if tick.is_multiple_of(60) { - numa::recursive::probe_udp(&ctx.root_hints).await; - } - } + let runtime = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build()?; + runtime.block_on(numa::serve::run(config_path)) } fn set_lan_enabled(enabled: bool, path: &str) -> numa::Result<()> { @@ -728,71 +192,3 @@ fn print_lan_status(enabled: bool) { eprintln!(" Restart Numa to start mDNS discovery"); } } - -async fn load_blocklists(ctx: &ServerCtx, lists: &[String]) { - let downloaded = download_blocklists(lists).await; - - // Parse outside the lock to avoid blocking DNS queries during parse (~100ms) - let mut all_domains = std::collections::HashSet::new(); - let mut sources = Vec::new(); - for (source, text) in &downloaded { - let domains = parse_blocklist(text); - info!("blocklist: {} domains from {}", domains.len(), source); - all_domains.extend(domains); - sources.push(source.clone()); - } - let total = all_domains.len(); - - // Swap under lock — sub-microsecond - ctx.blocklist - .write() - .unwrap() - .swap_domains(all_domains, sources); - info!( - "blocking enabled: {} unique domains from {} lists", - total, - downloaded.len() - ); -} - -async fn warm_domain(ctx: &ServerCtx, domain: &str) { - for qtype in [ - numa::question::QueryType::A, - numa::question::QueryType::AAAA, - ] { - numa::ctx::refresh_entry(ctx, domain, qtype).await; - } -} - -async fn doh_keepalive_loop(ctx: Arc) { - let mut interval = tokio::time::interval(Duration::from_secs(25)); - interval.tick().await; // skip first immediate tick - loop { - interval.tick().await; - let pool = ctx.upstream_pool.lock().unwrap().clone(); - if let Some(upstream) = pool.preferred() { - numa::forward::keepalive_doh(upstream).await; - } - } -} - -async fn cache_warm_loop(ctx: Arc, domains: Vec) { - tokio::time::sleep(Duration::from_secs(2)).await; - - for domain in &domains { - warm_domain(&ctx, domain).await; - } - info!("cache warm: {} domains resolved at startup", domains.len()); - - let mut interval = tokio::time::interval(Duration::from_secs(30)); - interval.tick().await; - loop { - interval.tick().await; - for domain in &domains { - let refresh = ctx.cache.read().unwrap().needs_warm(domain); - if refresh { - warm_domain(&ctx, domain).await; - } - } - } -} diff --git a/src/serve.rs b/src/serve.rs new file mode 100644 index 0000000..db0465b --- /dev/null +++ b/src/serve.rs @@ -0,0 +1,646 @@ +//! The main DNS-server runtime. +//! +//! Extracted from `main.rs` so both the interactive CLI entry and the +//! Windows service dispatcher (`windows_service` module) can drive the +//! same startup/serve loop. + +use std::net::SocketAddr; +use std::sync::{Arc, Mutex, RwLock}; +use std::time::Duration; + +use arc_swap::ArcSwap; +use log::{error, info}; +use tokio::net::UdpSocket; + +use crate::blocklist::{download_blocklists, parse_blocklist, BlocklistStore}; +use crate::buffer::BytePacketBuffer; +use crate::cache::DnsCache; +use crate::config::{build_zone_map, load_config, ConfigLoad}; +use crate::ctx::{handle_query, ServerCtx}; +use crate::forward::{parse_upstream, Upstream, UpstreamPool}; +use crate::override_store::OverrideStore; +use crate::query_log::QueryLog; +use crate::service_store::ServiceStore; +use crate::stats::{ServerStats, Transport}; +use crate::system_dns::discover_system_dns; + +const QUAD9_IP: &str = "9.9.9.9"; +const DOH_FALLBACK: &str = "https://9.9.9.9/dns-query"; + +/// Boot the DNS server and run until the UDP listener errors out. +pub async fn run(config_path: String) -> crate::Result<()> { + let ConfigLoad { + config, + path: resolved_config_path, + found: config_found, + } = load_config(&config_path)?; + + // Discover system DNS in a single pass (upstream + forwarding rules) + let system_dns = discover_system_dns(); + + let root_hints = crate::recursive::parse_root_hints(&config.upstream.root_hints); + + let recursive_pool = || { + let dummy = UpstreamPool::new(vec![Upstream::Udp("0.0.0.0:0".parse().unwrap())], vec![]); + (dummy, "recursive (root hints)".to_string()) + }; + + let (resolved_mode, upstream_auto, pool, upstream_label) = match config.upstream.mode { + crate::config::UpstreamMode::Auto => { + info!("auto mode: probing recursive resolution..."); + if crate::recursive::probe_recursive(&root_hints).await { + info!("recursive probe succeeded — self-sovereign mode"); + let (pool, label) = recursive_pool(); + (crate::config::UpstreamMode::Recursive, false, pool, label) + } else { + log::warn!("recursive probe failed — falling back to Quad9 DoH"); + let client = reqwest::Client::builder() + .use_rustls_tls() + .build() + .unwrap_or_default(); + let url = DOH_FALLBACK.to_string(); + let label = url.clone(); + let pool = UpstreamPool::new(vec![Upstream::Doh { url, client }], vec![]); + (crate::config::UpstreamMode::Forward, false, pool, label) + } + } + crate::config::UpstreamMode::Recursive => { + let (pool, label) = recursive_pool(); + (crate::config::UpstreamMode::Recursive, false, pool, label) + } + crate::config::UpstreamMode::Forward => { + let addrs = if config.upstream.address.is_empty() { + let detected = system_dns + .default_upstream + .or_else(crate::system_dns::detect_dhcp_dns) + .unwrap_or_else(|| { + info!("could not detect system DNS, falling back to Quad9 DoH"); + DOH_FALLBACK.to_string() + }); + vec![detected] + } else { + config.upstream.address.clone() + }; + + let primary: Vec = addrs + .iter() + .map(|s| parse_upstream(s, config.upstream.port)) + .collect::>>()?; + let fallback: Vec = config + .upstream + .fallback + .iter() + .map(|s| parse_upstream(s, config.upstream.port)) + .collect::>>()?; + + let pool = UpstreamPool::new(primary, fallback); + let label = pool.label(); + ( + crate::config::UpstreamMode::Forward, + config.upstream.address.is_empty(), + pool, + label, + ) + } + }; + let api_port = config.server.api_port; + + let mut blocklist = BlocklistStore::new(); + for domain in &config.blocking.allowlist { + blocklist.add_to_allowlist(domain); + } + if !config.blocking.enabled { + blocklist.set_enabled(false); + } + + // Build service store: config services + persisted user services + let mut service_store = ServiceStore::new(); + service_store.insert_from_config("numa", config.server.api_port, Vec::new()); + for svc in &config.services { + service_store.insert_from_config(&svc.name, svc.target_port, svc.routes.clone()); + } + service_store.load_persisted(); + + for fwd in &config.forwarding { + for suffix in &fwd.suffix { + info!("forwarding .{} to {} (config rule)", suffix, fwd.upstream); + } + } + let forwarding_rules = + crate::config::merge_forwarding_rules(&config.forwarding, system_dns.forwarding_rules)?; + + // Resolve data_dir from config, falling back to the platform default. + // Used for TLS CA storage below and stored on ServerCtx for runtime use. + let resolved_data_dir = config + .server + .data_dir + .clone() + .unwrap_or_else(crate::data_dir); + + // Build initial TLS config before ServerCtx (so ArcSwap is ready at construction) + let initial_tls = if config.proxy.enabled && config.proxy.tls_port > 0 { + let service_names = service_store.names(); + match crate::tls::build_tls_config( + &config.proxy.tld, + &service_names, + Vec::new(), + &resolved_data_dir, + ) { + Ok(tls_config) => Some(ArcSwap::from(tls_config)), + Err(e) => { + if let Some(advisory) = crate::tls::try_data_dir_advisory(&e, &resolved_data_dir) { + eprint!("{}", advisory); + } else { + log::warn!("TLS setup failed, HTTPS proxy disabled: {}", e); + } + None + } + } + } else { + None + }; + + let doh_enabled = initial_tls.is_some(); + let health_meta = crate::health::HealthMeta::build( + &resolved_data_dir, + config.dot.enabled, + config.dot.port, + config.mobile.port, + config.dnssec.enabled, + resolved_mode == crate::config::UpstreamMode::Recursive, + config.lan.enabled, + config.blocking.enabled, + doh_enabled, + ); + + let ca_pem = std::fs::read_to_string(resolved_data_dir.join("ca.pem")).ok(); + + let socket = match UdpSocket::bind(&config.server.bind_addr).await { + Ok(s) => s, + Err(e) => { + if let Some(advisory) = + crate::system_dns::try_port53_advisory(&config.server.bind_addr, &e) + { + eprint!("{}", advisory); + std::process::exit(1); + } + return Err(e.into()); + } + }; + + let ctx = Arc::new(ServerCtx { + socket, + zone_map: build_zone_map(&config.zones)?, + cache: RwLock::new(DnsCache::new( + config.cache.max_entries, + config.cache.min_ttl, + config.cache.max_ttl, + )), + refreshing: Mutex::new(std::collections::HashSet::new()), + stats: Mutex::new(ServerStats::new()), + overrides: RwLock::new(OverrideStore::new()), + blocklist: RwLock::new(blocklist), + query_log: Mutex::new(QueryLog::new(1000)), + services: Mutex::new(service_store), + lan_peers: Mutex::new(crate::lan::PeerStore::new(config.lan.peer_timeout_secs)), + forwarding_rules, + upstream_pool: Mutex::new(pool), + upstream_auto, + upstream_port: config.upstream.port, + lan_ip: Mutex::new(crate::lan::detect_lan_ip().unwrap_or(std::net::Ipv4Addr::LOCALHOST)), + timeout: Duration::from_millis(config.upstream.timeout_ms), + hedge_delay: Duration::from_millis(config.upstream.hedge_ms), + proxy_tld_suffix: if config.proxy.tld.is_empty() { + String::new() + } else { + format!(".{}", config.proxy.tld) + }, + proxy_tld: config.proxy.tld.clone(), + lan_enabled: config.lan.enabled, + config_path: resolved_config_path, + config_found, + config_dir: crate::config_dir(), + data_dir: resolved_data_dir, + tls_config: initial_tls, + upstream_mode: resolved_mode, + root_hints, + srtt: std::sync::RwLock::new(crate::srtt::SrttCache::new(config.upstream.srtt)), + inflight: std::sync::Mutex::new(std::collections::HashMap::new()), + dnssec_enabled: config.dnssec.enabled, + dnssec_strict: config.dnssec.strict, + health_meta, + ca_pem, + mobile_enabled: config.mobile.enabled, + mobile_port: config.mobile.port, + }); + + let zone_count: usize = ctx.zone_map.values().map(|m| m.len()).sum(); + // Build banner rows, then size the box to fit the longest value + let api_url = format!("http://localhost:{}", api_port); + let proxy_label = if config.proxy.enabled { + if config.proxy.tls_port > 0 { + Some(format!( + "http://:{} https://:{}", + config.proxy.port, config.proxy.tls_port + )) + } else { + Some(format!( + "http://*.{} on :{}", + config.proxy.tld, config.proxy.port + )) + } + } else { + None + }; + let config_label = if ctx.config_found { + ctx.config_path.clone() + } else { + format!("{} (defaults)", ctx.config_path) + }; + let data_label = ctx.data_dir.display().to_string(); + let services_label = ctx.config_dir.join("services.json").display().to_string(); + + // label (10) + value + padding (2) = inner width; minimum 40 for the title row + let val_w = [ + config.server.bind_addr.len(), + api_url.len(), + upstream_label.len(), + config_label.len(), + data_label.len(), + services_label.len(), + ] + .into_iter() + .chain(proxy_label.as_ref().map(|s| s.len())) + .max() + .unwrap_or(30); + let w = (val_w + 12).max(42); // 10 label + 2 padding, min 42 for title + + let o = "\x1b[38;2;192;98;58m"; // orange + let g = "\x1b[38;2;107;124;78m"; // green + let d = "\x1b[38;2;163;152;136m"; // dim + let r = "\x1b[0m"; // reset + let b = "\x1b[1;38;2;192;98;58m"; // bold orange + let it = "\x1b[3;38;2;163;152;136m"; // italic dim + + let bar_top = "═".repeat(w); + let bar_mid = "─".repeat(w); + let row = |label: &str, color: &str, value: &str| { + eprintln!( + "{o} ║{r} {color}{:<9}{r} {: 0 && ctx.tls_config.is_some() { + let proxy_ctx = Arc::clone(&ctx); + let tls_port = config.proxy.tls_port; + tokio::spawn(async move { + crate::proxy::start_proxy_tls(proxy_ctx, tls_port, proxy_bind).await; + }); + } + + // Spawn network change watcher (upstream re-detection, LAN IP update, peer flush) + { + let watch_ctx = Arc::clone(&ctx); + tokio::spawn(async move { + network_watch_loop(watch_ctx).await; + }); + } + + // Spawn LAN service discovery + if config.lan.enabled { + let lan_ctx = Arc::clone(&ctx); + let lan_config = config.lan.clone(); + tokio::spawn(async move { + crate::lan::start_lan_discovery(lan_ctx, &lan_config).await; + }); + } + + // Spawn DNS-over-TLS listener (RFC 7858) + if config.dot.enabled { + let dot_ctx = Arc::clone(&ctx); + let dot_config = config.dot.clone(); + tokio::spawn(async move { + crate::dot::start_dot(dot_ctx, &dot_config).await; + }); + } + + // UDP DNS listener + #[allow(clippy::infinite_loop)] + loop { + let mut buffer = BytePacketBuffer::new(); + let (len, src_addr) = match ctx.socket.recv_from(&mut buffer.buf).await { + Ok(r) => r, + Err(e) if e.kind() == std::io::ErrorKind::ConnectionReset => { + // Windows delivers ICMP port-unreachable as ConnectionReset on UDP sockets + continue; + } + Err(e) => return Err(e.into()), + }; + let ctx = Arc::clone(&ctx); + tokio::spawn(async move { + if let Err(e) = handle_query(buffer, len, src_addr, &ctx, Transport::Udp).await { + error!("{} | HANDLER ERROR | {}", src_addr, e); + } + }); + } +} + +async fn network_watch_loop(ctx: Arc) { + let mut tick: u64 = 0; + + let mut interval = tokio::time::interval(Duration::from_secs(5)); + interval.tick().await; // skip immediate tick + + loop { + interval.tick().await; + tick += 1; + let mut changed = false; + + // Check LAN IP change (every 5s — cheap, one UDP socket call) + if let Some(new_ip) = crate::lan::detect_lan_ip() { + let mut current_ip = ctx.lan_ip.lock().unwrap(); + if new_ip != *current_ip { + info!("LAN IP changed: {} → {}", current_ip, new_ip); + *current_ip = new_ip; + changed = true; + crate::recursive::reset_udp_state(); + } + } + + // Re-detect upstream every 30s or on LAN IP change (auto-detect only) + if ctx.upstream_auto && (changed || tick.is_multiple_of(6)) { + let dns_info = crate::system_dns::discover_system_dns(); + let new_addr = dns_info + .default_upstream + .or_else(crate::system_dns::detect_dhcp_dns) + .unwrap_or_else(|| QUAD9_IP.to_string()); + let mut pool = ctx.upstream_pool.lock().unwrap(); + if pool.maybe_update_primary(&new_addr, ctx.upstream_port) { + info!("upstream changed → {}", pool.label()); + changed = true; + } + } + + // Flush stale LAN peers on any network change + if changed { + ctx.lan_peers.lock().unwrap().clear(); + info!("flushed LAN peers after network change"); + } + + // Re-probe UDP every 5 minutes when disabled + if tick.is_multiple_of(60) { + crate::recursive::probe_udp(&ctx.root_hints).await; + } + } +} + +async fn load_blocklists(ctx: &ServerCtx, lists: &[String]) { + let downloaded = download_blocklists(lists).await; + + // Parse outside the lock to avoid blocking DNS queries during parse (~100ms) + let mut all_domains = std::collections::HashSet::new(); + let mut sources = Vec::new(); + for (source, text) in &downloaded { + let domains = parse_blocklist(text); + info!("blocklist: {} domains from {}", domains.len(), source); + all_domains.extend(domains); + sources.push(source.clone()); + } + let total = all_domains.len(); + + // Swap under lock — sub-microsecond + ctx.blocklist + .write() + .unwrap() + .swap_domains(all_domains, sources); + info!( + "blocking enabled: {} unique domains from {} lists", + total, + downloaded.len() + ); +} + +async fn warm_domain(ctx: &ServerCtx, domain: &str) { + for qtype in [ + crate::question::QueryType::A, + crate::question::QueryType::AAAA, + ] { + crate::ctx::refresh_entry(ctx, domain, qtype).await; + } +} + +async fn doh_keepalive_loop(ctx: Arc) { + let mut interval = tokio::time::interval(Duration::from_secs(25)); + interval.tick().await; // skip first immediate tick + loop { + interval.tick().await; + let pool = ctx.upstream_pool.lock().unwrap().clone(); + if let Some(upstream) = pool.preferred() { + crate::forward::keepalive_doh(upstream).await; + } + } +} + +async fn cache_warm_loop(ctx: Arc, domains: Vec) { + tokio::time::sleep(Duration::from_secs(2)).await; + + for domain in &domains { + warm_domain(&ctx, domain).await; + } + info!("cache warm: {} domains resolved at startup", domains.len()); + + let mut interval = tokio::time::interval(Duration::from_secs(30)); + interval.tick().await; + loop { + interval.tick().await; + for domain in &domains { + let refresh = ctx.cache.read().unwrap().needs_warm(domain); + if refresh { + warm_domain(&ctx, domain).await; + } + } + } +} diff --git a/src/system_dns.rs b/src/system_dns.rs index 96ae372..941c053 100644 --- a/src/system_dns.rs +++ b/src/system_dns.rs @@ -211,7 +211,7 @@ fn discover_macos() -> SystemDnsInfo { } // Sort longest suffix first for most-specific matching - rules.sort_by(|a, b| b.suffix.len().cmp(&a.suffix.len())); + rules.sort_by_key(|r| std::cmp::Reverse(r.suffix.len())); for rule in &rules { info!( @@ -572,7 +572,7 @@ fn windows_backup_path() -> std::path::PathBuf { #[cfg(windows)] fn disable_dnscache() -> Result { - // Check if Dnscache is running (it holds port 53 at kernel level) + // Check if Dnscache is running (it can hold port 53) let output = std::process::Command::new("sc") .args(["query", "Dnscache"]) .output() @@ -603,8 +603,16 @@ fn disable_dnscache() -> Result { return Err("failed to disable Dnscache via registry (run as Administrator?)".into()); } - eprintln!(" Dnscache disabled. A reboot is required to free port 53."); - Ok(true) + // Dnscache is disabled for next boot. Check whether port 53 is + // actually blocked right now — on many Windows configurations + // Dnscache doesn't bind port 53 even while running. + let port_blocked = std::net::UdpSocket::bind("127.0.0.1:53").is_err(); + if port_blocked { + eprintln!(" Dnscache disabled. A reboot is required to free port 53."); + } else { + eprintln!(" Dnscache disabled. Port 53 is free."); + } + Ok(port_blocked) } #[cfg(windows)] @@ -671,6 +679,83 @@ fn install_windows() -> Result<(), String> { std::fs::write(&path, json).map_err(|e| format!("failed to write backup: {}", e))?; } + // On re-install, stop the running service first so the binary can be + // overwritten and port 53 is released for the Dnscache probe. + if is_service_registered() { + eprintln!(" Stopping existing service..."); + stop_service_scm(); + } + + let needs_reboot = disable_dnscache()?; + + // Copy the binary to a stable path under ProgramData and register it + // as a real Windows service (SCM-managed, boot-time, auto-restart). + let service_exe = install_service_binary()?; + register_service_scm(&service_exe)?; + + if needs_reboot { + // Dnscache still holds port 53 until reboot. Do NOT redirect DNS + // yet — nothing is listening on 127.0.0.1:53, so redirecting now + // would kill DNS. The service will call redirect_dns_to_localhost() + // on its first startup after reboot. + } else { + redirect_dns_with_interfaces(&interfaces)?; + + match start_service_scm() { + Ok(_) => eprintln!(" Service started."), + Err(e) => eprintln!( + " warning: service registered but could not start now: {}", + e + ), + } + } + + eprintln!(); + if !has_useful_existing { + eprintln!(" Original DNS saved to {}", path.display()); + } + eprintln!(" Run 'numa uninstall' to restore.\n"); + if needs_reboot { + eprintln!(" *** Reboot required. Numa will start automatically. ***\n"); + } else { + eprintln!(" Numa is running.\n"); + } + print_recursive_hint(); + Ok(()) +} + +/// Stable install location for the service binary. SCM keeps a handle to +/// this path; the user's Downloads folder (where `current_exe()` points at +/// install time) is not durable. +#[cfg(windows)] +fn windows_service_exe_path() -> std::path::PathBuf { + crate::data_dir().join("bin").join("numa.exe") +} + +/// Run `sc.exe` with the given args and return its merged stdout/stderr on +/// failure. `sc` emits errors on stdout (not stderr) on Windows, so the +/// caller reads stdout to format a useful error. +#[cfg(windows)] +fn run_sc(args: &[&str]) -> Result { + let out = std::process::Command::new("sc") + .args(args) + .output() + .map_err(|e| format!("failed to run sc {}: {}", args.first().unwrap_or(&""), e))?; + Ok(out) +} + +/// Point all active network interfaces at 127.0.0.1 so Numa handles DNS. +/// Called from the service on first boot after a reboot that freed Dnscache. +#[cfg(windows)] +pub fn redirect_dns_to_localhost() -> Result<(), String> { + let interfaces = get_windows_interfaces()?; + redirect_dns_with_interfaces(&interfaces) +} + +#[cfg(windows)] +fn redirect_dns_with_interfaces( + interfaces: &std::collections::HashMap, +) -> Result<(), String> { for name in interfaces.keys() { let status = std::process::Command::new("netsh") .args([ @@ -695,63 +780,184 @@ fn install_windows() -> Result<(), String> { ); } } - - let needs_reboot = disable_dnscache()?; - register_autostart(); - - eprintln!(); - if !has_useful_existing { - eprintln!(" Original DNS saved to {}", path.display()); - } - eprintln!(" Run 'numa uninstall' to restore.\n"); - if needs_reboot { - eprintln!(" *** Reboot required. Numa will start automatically. ***\n"); - } else { - eprintln!(" Numa will start automatically on next boot.\n"); - } - print_recursive_hint(); Ok(()) } -/// Register numa to auto-start on boot via registry Run key. +/// Copy the currently-running binary to the service install location. SCM +/// keeps a handle to this path, so it must be stable across user sessions. #[cfg(windows)] -fn register_autostart() { - let exe = std::env::current_exe() - .map(|p| p.to_string_lossy().to_string()) - .unwrap_or_else(|_| "numa".into()); - let _ = std::process::Command::new("reg") - .args([ - "add", - "HKLM\\SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\Run", - "/v", - "Numa", - "/t", - "REG_SZ", - "/d", - &exe, - "/f", - ]) - .status(); - eprintln!(" Registered auto-start on boot."); +fn install_service_binary() -> Result { + let src = std::env::current_exe().map_err(|e| format!("current_exe(): {}", e))?; + let dst = windows_service_exe_path(); + if let Some(parent) = dst.parent() { + std::fs::create_dir_all(parent) + .map_err(|e| format!("failed to create {}: {}", parent.display(), e))?; + } + // Copy only if source and destination differ; running the binary from + // its install location is a supported (re-install) case. + if src != dst { + std::fs::copy(&src, &dst).map_err(|e| { + format!( + "failed to copy {} -> {}: {}", + src.display(), + dst.display(), + e + ) + })?; + } + Ok(dst) } -/// Remove numa auto-start registry key. +/// Remove the service binary on uninstall. Ignore failures — the service +/// is already deleted; a leftover file in ProgramData is not a hard error. #[cfg(windows)] -fn remove_autostart() { - let _ = std::process::Command::new("reg") - .args([ - "delete", - "HKLM\\SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\Run", - "/v", - "Numa", - "/f", - ]) - .status(); +fn remove_service_binary() { + let _ = std::fs::remove_file(windows_service_exe_path()); +} + +/// Register numa with the Service Control Manager, boot-time auto-start, +/// LocalSystem context, with a failure policy of restart-after-5s. +#[cfg(windows)] +fn register_service_scm(exe: &std::path::Path) -> Result<(), String> { + let bin_path = format!("\"{}\" --service", exe.display()); + let name = crate::windows_service::SERVICE_NAME; + + // sc.exe uses a leading space as its `name= value` delimiter; the space + // after `=` is mandatory. + let create = run_sc(&[ + "create", + name, + "binPath=", + &bin_path, + "DisplayName=", + "Numa DNS", + "start=", + "auto", + "obj=", + "LocalSystem", + ])?; + if !create.status.success() { + let out = String::from_utf8_lossy(&create.stdout); + // "service already exists" is 1073 — treat as idempotent success. + if !out.contains("1073") { + return Err(format!("sc create failed: {}", out.trim())); + } + } + + let _ = run_sc(&[ + "description", + name, + "Self-sovereign DNS resolver (ad blocking, DoH/DoT, local zones).", + ]); + + // Restart on crash: 5s, 5s, 10s; reset failure counter after 60s. + let _ = run_sc(&[ + "failure", + name, + "reset=", + "60", + "actions=", + "restart/5000/restart/5000/restart/10000", + ]); + + eprintln!(" Registered service '{}' (boot-time).", name); + Ok(()) +} + +/// Start the service. Safe to call on a freshly-registered service — SCM +/// will fail with 1056 ("already running") or 1058 ("disabled") and we +/// return the underlying error string rather than masking it. +#[cfg(windows)] +fn start_service_scm() -> Result<(), String> { + let out = run_sc(&["start", crate::windows_service::SERVICE_NAME])?; + if !out.status.success() { + let text = String::from_utf8_lossy(&out.stdout); + if text.contains("1056") { + return Ok(()); // already running + } + return Err(format!("sc start failed: {}", text.trim())); + } + Ok(()) +} + +/// Stop the service and wait for it to fully exit. Idempotent — +/// already-stopped or missing service is not an error. +#[cfg(windows)] +fn stop_service_scm() { + let name = crate::windows_service::SERVICE_NAME; + let _ = run_sc(&["stop", name]); + // Wait up to 10s for the service to reach STOPPED state so the + // binary file handle is released before we try to overwrite it. + for _ in 0..20 { + if let Ok(out) = run_sc(&["query", name]) { + let text = String::from_utf8_lossy(&out.stdout); + if text.contains("STOPPED") || text.contains("1060") { + return; + } + } + std::thread::sleep(std::time::Duration::from_millis(500)); + } + eprintln!(" warning: service did not stop within 10s"); +} + +/// Remove the service from SCM. Idempotent — see `stop_service_scm`. +#[cfg(windows)] +fn delete_service_scm() { + if let Err(e) = run_sc(&["delete", crate::windows_service::SERVICE_NAME]) { + log::warn!("sc delete failed: {}", e); + } +} + +/// Check whether the service is registered with SCM (regardless of state). +#[cfg(windows)] +fn is_service_registered() -> bool { + run_sc(&["query", crate::windows_service::SERVICE_NAME]) + .map(|o| parse_sc_registered(o.status.success(), &String::from_utf8_lossy(&o.stdout))) + .unwrap_or(false) +} + +/// Parse `sc query` output to determine if a service is registered. +/// Extracted for testability — the actual `sc` call is in `is_service_registered`. +#[cfg(any(windows, test))] +fn parse_sc_registered(exit_success: bool, stdout: &str) -> bool { + if exit_success { + return true; + } + // Error 1060 = "The specified service does not exist as an installed service." + !stdout.contains("1060") +} + +/// Print service state from SCM. +#[cfg(windows)] +fn service_status_windows() -> Result<(), String> { + let out = run_sc(&["query", crate::windows_service::SERVICE_NAME])?; + let text = String::from_utf8_lossy(&out.stdout); + let display = parse_sc_state(&text); + eprintln!(" {}\n", display); + Ok(()) +} + +/// Parse the STATE line from `sc query` output. Returns a human-readable +/// string like "STATE : 4 RUNNING" or "Service is not installed." +#[cfg(any(windows, test))] +fn parse_sc_state(sc_output: &str) -> String { + if sc_output.contains("1060") { + return "Service is not installed.".to_string(); + } + sc_output + .lines() + .find(|l| l.contains("STATE")) + .map(|l| l.trim().to_string()) + .unwrap_or_else(|| "unknown".to_string()) } #[cfg(windows)] fn uninstall_windows() -> Result<(), String> { - remove_autostart(); + // Stop + remove the service before touching DNS, so port 53 is released + // cleanly and the failure-restart policy doesn't resurrect it. + stop_service_scm(); + delete_service_scm(); + remove_service_binary(); let path = windows_backup_path(); let json = std::fs::read_to_string(&path) .map_err(|e| format!("no backup found at {}: {}", path.display(), e))?; @@ -1048,6 +1254,62 @@ pub fn install_service() -> Result<(), String> { result } +/// Start the service. If already installed, just starts it via the platform +/// service manager. If not installed, falls through to a full install. +pub fn start_service() -> Result<(), String> { + #[cfg(target_os = "macos")] + { + install_service() + } + #[cfg(target_os = "linux")] + { + install_service() + } + #[cfg(windows)] + { + if is_service_registered() { + start_service_scm()?; + eprintln!(" Service started.\n"); + Ok(()) + } else { + install_service() + } + } + #[cfg(not(any(target_os = "macos", target_os = "linux", windows)))] + { + Err("service start not supported on this OS".to_string()) + } +} + +/// Stop the service without uninstalling it. +pub fn stop_service() -> Result<(), String> { + #[cfg(target_os = "macos")] + { + uninstall_service() + } + #[cfg(target_os = "linux")] + { + uninstall_service() + } + #[cfg(windows)] + { + let out = run_sc(&["stop", crate::windows_service::SERVICE_NAME])?; + if !out.status.success() { + let text = String::from_utf8_lossy(&out.stdout); + // 1062 = not started, 1060 = does not exist + if !text.contains("1062") && !text.contains("1060") { + return Err(format!("sc stop failed: {}", text.trim())); + } + } + eprintln!(" Service stopped.\n"); + Ok(()) + } + #[cfg(not(any(target_os = "macos", target_os = "linux", windows)))] + { + Err("service stop not supported on this OS".to_string()) + } +} + /// Uninstall the Numa system service. pub fn uninstall_service() -> Result<(), String> { let _ = untrust_ca(); @@ -1117,7 +1379,14 @@ pub fn restart_service() -> Result<(), String> { eprintln!(" Service restarted → {}\n", version); Ok(()) } - #[cfg(not(any(target_os = "macos", target_os = "linux")))] + #[cfg(windows)] + { + stop_service_scm(); + start_service_scm()?; + eprintln!(" Service restarted.\n"); + Ok(()) + } + #[cfg(not(any(target_os = "macos", target_os = "linux", windows)))] { Err("service restart not supported on this OS".to_string()) } @@ -1133,7 +1402,11 @@ pub fn service_status() -> Result<(), String> { { service_status_linux() } - #[cfg(not(any(target_os = "macos", target_os = "linux")))] + #[cfg(windows)] + { + service_status_windows() + } + #[cfg(not(any(target_os = "macos", target_os = "linux", windows)))] { Err("service status not supported on this OS".to_string()) } @@ -1867,4 +2140,57 @@ Wireless LAN adapter Wi-Fi: let err = std::io::Error::from(std::io::ErrorKind::AddrInUse); assert!(try_port53_advisory("not-an-address", &err).is_none()); } + + #[test] + fn sc_query_running_service_is_registered() { + assert!(parse_sc_registered(true, "")); + } + + #[test] + fn sc_query_stopped_service_is_registered() { + let output = "SERVICE_NAME: Numa\n TYPE: 10 WIN32_OWN\n STATE: 1 STOPPED\n"; + assert!(parse_sc_registered(true, output)); + } + + #[test] + fn sc_query_missing_service_not_registered() { + let output = "[SC] EnumQueryServicesStatus:OpenService FAILED 1060:\n\nThe specified service does not exist as an installed service.\n"; + assert!(!parse_sc_registered(false, output)); + } + + #[test] + fn sc_query_other_error_assumes_registered() { + // Permission denied or other errors — don't assume unregistered. + let output = "[SC] OpenService FAILED 5:\n\nAccess is denied.\n"; + assert!(parse_sc_registered(false, output)); + } + + #[test] + fn parse_sc_state_running() { + let output = "SERVICE_NAME: Numa\n TYPE : 10 WIN32_OWN_PROCESS\n STATE : 4 RUNNING\n WIN32_EXIT_CODE : 0\n"; + assert!(parse_sc_state(output).contains("RUNNING")); + } + + #[test] + fn parse_sc_state_stopped() { + let output = "SERVICE_NAME: Numa\n TYPE : 10 WIN32_OWN_PROCESS\n STATE : 1 STOPPED\n"; + assert!(parse_sc_state(output).contains("STOPPED")); + } + + #[test] + fn parse_sc_state_not_installed() { + let output = "[SC] EnumQueryServicesStatus:OpenService FAILED 1060:\n\n"; + assert_eq!(parse_sc_state(output), "Service is not installed."); + } + + #[test] + fn parse_sc_state_empty_output() { + assert_eq!(parse_sc_state(""), "unknown"); + } + + #[cfg(windows)] + #[test] + fn windows_config_dir_equals_data_dir() { + assert_eq!(crate::config_dir(), crate::data_dir()); + } } diff --git a/src/windows_service.rs b/src/windows_service.rs new file mode 100644 index 0000000..a363359 --- /dev/null +++ b/src/windows_service.rs @@ -0,0 +1,147 @@ +//! Windows service wrapper. +//! +//! Lets the `numa.exe` binary act as a real Windows service registered with +//! the Service Control Manager (SCM). Invoked via `numa.exe --service` (the +//! form that `sc create … binPath=` uses). +//! +//! Interactive runs (`numa.exe`, `numa.exe run`, `numa.exe install`) do not +//! go through this module — they keep their existing console-attached +//! behaviour. + +use std::ffi::OsString; +use std::sync::mpsc; +use std::time::Duration; + +use windows_service::service::{ + ServiceControl, ServiceControlAccept, ServiceExitCode, ServiceState, ServiceStatus, ServiceType, +}; +use windows_service::service_control_handler::{self, ServiceControlHandlerResult}; +use windows_service::{define_windows_service, service_dispatcher}; + +pub const SERVICE_NAME: &str = "Numa"; + +define_windows_service!(ffi_service_main, service_main); + +/// Entry point the SCM hands control to after `StartServiceCtrlDispatcherW`. +/// Any panic here vanishes silently into the service host — log instead of +/// unwrapping. +fn service_main(_arguments: Vec) { + if let Err(e) = run_service() { + log::error!("numa service exited with error: {:?}", e); + } +} + +fn run_service() -> windows_service::Result<()> { + let (shutdown_tx, shutdown_rx) = mpsc::channel::<()>(); + + let event_handler = move |control_event| -> ServiceControlHandlerResult { + match control_event { + ServiceControl::Stop | ServiceControl::Shutdown => { + let _ = shutdown_tx.send(()); + ServiceControlHandlerResult::NoError + } + ServiceControl::Interrogate => ServiceControlHandlerResult::NoError, + _ => ServiceControlHandlerResult::NotImplemented, + } + }; + + let status_handle = service_control_handler::register(SERVICE_NAME, event_handler)?; + + status_handle.set_service_status(ServiceStatus { + service_type: ServiceType::OWN_PROCESS, + current_state: ServiceState::Running, + controls_accepted: ServiceControlAccept::STOP | ServiceControlAccept::SHUTDOWN, + exit_code: ServiceExitCode::Win32(0), + checkpoint: 0, + wait_hint: Duration::default(), + process_id: None, + })?; + + // Spin up a multi-threaded tokio runtime and run the server on it. A + // dedicated thread runs the runtime so this function can return cleanly + // once the SCM tells us to stop — we can't block the dispatcher thread + // forever without preventing graceful shutdown. + let config_path = service_config_path(); + let (server_done_tx, server_done_rx) = mpsc::channel::<()>(); + + let server_thread = std::thread::spawn(move || { + let runtime = match tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + { + Ok(rt) => rt, + Err(e) => { + log::error!("failed to build tokio runtime: {}", e); + let _ = server_done_tx.send(()); + return; + } + }; + + if let Err(e) = runtime.block_on(crate::serve::run(config_path)) { + log::error!("numa serve exited with error: {}", e); + } + let _ = server_done_tx.send(()); + }); + + // Wait for the API to be ready, then ensure DNS points at localhost. + // On first boot after install (Dnscache was disabled, reboot freed + // port 53), the installer deferred the DNS redirect — do it now. + let api_up = (0..20).any(|i| { + if i > 0 { + std::thread::sleep(Duration::from_millis(500)); + } + std::net::TcpStream::connect(("127.0.0.1", crate::config::DEFAULT_API_PORT)).is_ok() + }); + if api_up { + if let Err(e) = crate::system_dns::redirect_dns_to_localhost() { + log::warn!("could not redirect DNS to localhost: {}", e); + } + } else { + log::error!("numa API did not start within 10s — DNS not redirected"); + } + + // Wait for either SCM stop or server termination. + loop { + if shutdown_rx.recv_timeout(Duration::from_millis(500)).is_ok() { + break; + } + if server_done_rx.try_recv().is_ok() { + break; + } + } + + // The server's tokio runtime runs detached inside server_thread. Abandon + // it — the process is about to report Stopped and the SCM will terminate + // us if we linger. Future work: plumb a cancellation signal into + // serve::run() for a clean teardown of listeners and in-flight queries. + drop(server_thread); + + status_handle.set_service_status(ServiceStatus { + service_type: ServiceType::OWN_PROCESS, + current_state: ServiceState::Stopped, + controls_accepted: ServiceControlAccept::empty(), + exit_code: ServiceExitCode::Win32(0), + checkpoint: 0, + wait_hint: Duration::default(), + process_id: None, + })?; + + Ok(()) +} + +/// Hand control to the SCM dispatcher. Blocks until the service stops. +/// Call only from the `--service` command path — interactive invocations +/// will hang here waiting for an SCM that isn't talking to them. +pub fn run_as_service() -> windows_service::Result<()> { + service_dispatcher::start(SERVICE_NAME, ffi_service_main) +} + +/// Path to the config file used when running under SCM. SCM launches the +/// service with SYSTEM's working directory (usually `C:\Windows\System32`), +/// so a relative `numa.toml` lookup won't find anything meaningful. +fn service_config_path() -> String { + crate::data_dir() + .join("numa.toml") + .to_string_lossy() + .into_owned() +}