diff --git a/src/system_dns.rs b/src/system_dns.rs index fc02393..643b9d0 100644 --- a/src/system_dns.rs +++ b/src/system_dns.rs @@ -214,7 +214,18 @@ fn discover_linux() -> SystemDnsInfo { } } -/// Parse resolv.conf in a single pass, extracting both the first non-loopback +/// Yield each `nameserver` address from resolv.conf content. No filtering — +/// callers decide what counts as a real upstream. +#[cfg(any(target_os = "linux", test))] +fn iter_nameservers(content: &str) -> impl Iterator { + content.lines().filter_map(|line| { + let mut parts = line.split_whitespace(); + (parts.next() == Some("nameserver")).then_some(())?; + parts.next() + }) +} + +/// Parse resolv.conf in a single pass, extracting the first non-loopback /// nameserver and all search domains. #[cfg(target_os = "linux")] fn parse_resolv_conf(path: &str) -> (Option, Vec) { @@ -222,19 +233,13 @@ fn parse_resolv_conf(path: &str) -> (Option, Vec) { Ok(t) => t, Err(_) => return (None, Vec::new()), }; - let mut upstream = None; + let upstream = iter_nameservers(&text) + .find(|ns| !is_loopback_or_stub(ns)) + .map(str::to_string); let mut search_domains = Vec::new(); for line in text.lines() { let line = line.trim(); - if line.starts_with("nameserver") { - if upstream.is_none() { - if let Some(ns) = line.split_whitespace().nth(1) { - if !is_loopback_or_stub(ns) { - upstream = Some(ns.to_string()); - } - } - } - } else if line.starts_with("search") || line.starts_with("domain") { + if line.starts_with("search") || line.starts_with("domain") { for domain in line.split_whitespace().skip(1) { search_domains.push(domain.to_string()); } @@ -243,6 +248,21 @@ fn parse_resolv_conf(path: &str) -> (Option, Vec) { (upstream, search_domains) } +/// True if the resolv.conf *content* appears to be written by numa itself, +/// or has no real upstream — either way, it's not a safe source of truth +/// for a backup. +#[cfg(any(target_os = "linux", test))] +fn resolv_conf_is_numa_managed(content: &str) -> bool { + content.contains("Generated by Numa") || !resolv_conf_has_real_upstream(content) +} + +/// True if the resolv.conf content has at least one non-loopback, non-stub +/// nameserver. An all-loopback resolv.conf is self-referential. +#[cfg(any(target_os = "linux", test))] +fn resolv_conf_has_real_upstream(content: &str) -> bool { + iter_nameservers(content).any(|ns| !is_loopback_or_stub(ns)) +} + /// Query resolvectl for the real upstream DNS server (e.g. VPC resolver on AWS). #[cfg(target_os = "linux")] fn resolvectl_dns_server() -> Option { @@ -526,9 +546,19 @@ fn enable_dnscache() { .status(); } +/// True if the backup map has at least one real upstream (non-loopback, non-stub). +#[cfg(any(windows, test))] +fn backup_has_real_upstream_windows( + interfaces: &std::collections::HashMap, +) -> bool { + interfaces + .values() + .any(|iface| iface.servers.iter().any(|s| !is_loopback_or_stub(s))) +} + #[cfg(windows)] fn install_windows() -> Result<(), String> { - let interfaces = get_windows_interfaces()?; + let mut interfaces = get_windows_interfaces()?; if interfaces.is_empty() { return Err("no active network interfaces found".to_string()); } @@ -538,9 +568,30 @@ fn install_windows() -> Result<(), String> { std::fs::create_dir_all(parent) .map_err(|e| format!("failed to create {}: {}", parent.display(), e))?; } - let json = serde_json::to_string_pretty(&interfaces) - .map_err(|e| format!("failed to serialize backup: {}", e))?; - std::fs::write(&path, json).map_err(|e| format!("failed to write backup: {}", e))?; + + // Preserve an existing useful backup rather than overwriting it with + // numa-managed state (which would be self-referential after uninstall). + let existing: Option> = + std::fs::read_to_string(&path) + .ok() + .and_then(|json| serde_json::from_str(&json).ok()); + let has_useful_existing = existing + .as_ref() + .map(backup_has_real_upstream_windows) + .unwrap_or(false); + + if has_useful_existing { + eprintln!(" Existing DNS backup preserved at {}", path.display()); + } else { + // Filter loopback/stub addresses before saving so a fresh backup + // captured from already-numa-managed state isn't self-referential. + for iface in interfaces.values_mut() { + iface.servers.retain(|s| !is_loopback_or_stub(s)); + } + let json = serde_json::to_string_pretty(&interfaces) + .map_err(|e| format!("failed to serialize backup: {}", e))?; + std::fs::write(&path, json).map_err(|e| format!("failed to write backup: {}", e))?; + } for name in interfaces.keys() { let status = std::process::Command::new("netsh") @@ -570,7 +621,10 @@ fn install_windows() -> Result<(), String> { let needs_reboot = disable_dnscache()?; register_autostart(); - eprintln!("\n Original DNS saved to {}", path.display()); + eprintln!(); + if !has_useful_existing { + eprintln!(" Original DNS saved to {}", path.display()); + } eprintln!(" Run 'numa uninstall' to restore.\n"); if needs_reboot { eprintln!(" *** Reboot required. Numa will start automatically. ***\n"); @@ -754,27 +808,60 @@ fn get_dns_servers(service: &str) -> Result, String> { } } +/// True if the backup map has at least one real upstream (non-loopback, non-stub). +/// An all-loopback backup is self-referential — restoring it is a no-op. +#[cfg(any(target_os = "macos", test))] +fn backup_has_real_upstream_macos( + servers: &std::collections::HashMap>, +) -> bool { + servers + .values() + .any(|list| list.iter().any(|s| !is_loopback_or_stub(s))) +} + #[cfg(target_os = "macos")] fn install_macos() -> Result<(), String> { use std::collections::HashMap; let services = get_network_services()?; - let mut original: HashMap> = HashMap::new(); - - // Save current DNS for each service - for service in &services { - let servers = get_dns_servers(service)?; - original.insert(service.clone(), servers); - } - - // Save backup let dir = numa_data_dir(); std::fs::create_dir_all(&dir) .map_err(|e| format!("failed to create {}: {}", dir.display(), e))?; - let json = serde_json::to_string_pretty(&original) - .map_err(|e| format!("failed to serialize backup: {}", e))?; - std::fs::write(backup_path(), json).map_err(|e| format!("failed to write backup: {}", e))?; + // If a useful backup already exists (at least one non-loopback upstream), + // preserve it — overwriting would destroy the original DNS state when + // re-installing on top of a numa-managed configuration. + let existing_backup: Option>> = + std::fs::read_to_string(backup_path()) + .ok() + .and_then(|json| serde_json::from_str(&json).ok()); + let has_useful_existing = existing_backup + .as_ref() + .map(backup_has_real_upstream_macos) + .unwrap_or(false); + + if has_useful_existing { + eprintln!( + " Existing DNS backup preserved at {}", + backup_path().display() + ); + } else { + // Capture fresh, filtering out loopback and stub addresses so we + // never record a self-referential backup. + let mut original: HashMap> = HashMap::new(); + for service in &services { + let servers: Vec = get_dns_servers(service)? + .into_iter() + .filter(|s| !is_loopback_or_stub(s)) + .collect(); + original.insert(service.clone(), servers); + } + + let json = serde_json::to_string_pretty(&original) + .map_err(|e| format!("failed to serialize backup: {}", e))?; + std::fs::write(backup_path(), json) + .map_err(|e| format!("failed to write backup: {}", e))?; + } // Set DNS to 127.0.0.1 and add "numa" search domain for each service for service in &services { @@ -795,7 +882,10 @@ fn install_macos() -> Result<(), String> { .status(); } - eprintln!("\n Original DNS saved to {}", backup_path().display()); + eprintln!(); + if !has_useful_existing { + eprintln!(" Original DNS saved to {}", backup_path().display()); + } eprintln!(" Run 'sudo numa uninstall' to restore.\n"); Ok(()) @@ -1132,11 +1222,31 @@ fn install_linux() -> Result<(), String> { .map_err(|e| format!("failed to create {}: {}", parent.display(), e))?; } - // Back up current resolv.conf (ignore NotFound) - match std::fs::copy(resolv, &backup) { - Ok(_) => eprintln!(" Saved /etc/resolv.conf to {}", backup.display()), - Err(e) if e.kind() == std::io::ErrorKind::NotFound => {} - Err(e) => return Err(format!("failed to backup /etc/resolv.conf: {}", e)), + // Back up current resolv.conf, but never overwrite a useful existing + // backup with a numa-managed file — that would leave uninstall with + // nothing to restore to. + let current = std::fs::read_to_string(resolv).ok(); + let current_is_numa_managed = current + .as_deref() + .map(resolv_conf_is_numa_managed) + .unwrap_or(false); + let existing_backup_is_useful = std::fs::read_to_string(&backup) + .ok() + .as_deref() + .map(resolv_conf_has_real_upstream) + .unwrap_or(false); + + if existing_backup_is_useful { + eprintln!( + " Existing resolv.conf backup preserved at {}", + backup.display() + ); + } else if current_is_numa_managed { + eprintln!(" warning: /etc/resolv.conf is already numa-managed; no fresh backup written"); + } else if let Some(content) = current.as_deref() { + std::fs::write(&backup, content) + .map_err(|e| format!("failed to backup /etc/resolv.conf: {}", e))?; + eprintln!(" Saved /etc/resolv.conf to {}", backup.display()); } if resolv @@ -1539,6 +1649,82 @@ Wireless LAN adapter Wi-Fi: assert!(!result.contains("{{exe_path}}")); } + #[test] + fn macos_backup_real_upstream_detection() { + use std::collections::HashMap; + let mut map: HashMap> = HashMap::new(); + + // Empty backup → no real upstream + assert!(!backup_has_real_upstream_macos(&map)); + + // All-loopback backup → still no real upstream (the bug case) + map.insert("Wi-Fi".into(), vec!["127.0.0.1".into()]); + map.insert("Ethernet".into(), vec!["::1".into()]); + assert!(!backup_has_real_upstream_macos(&map)); + + // One real entry → useful + map.insert("Tailscale".into(), vec!["192.168.1.1".into()]); + assert!(backup_has_real_upstream_macos(&map)); + } + + #[test] + fn windows_backup_filters_loopback() { + use std::collections::HashMap; + let mut map: HashMap = HashMap::new(); + + // Empty backup → no real upstream + assert!(!backup_has_real_upstream_windows(&map)); + + // All-loopback backup → still no real upstream (the bug case) + map.insert( + "Wi-Fi".into(), + WindowsInterfaceDns { + dhcp: false, + servers: vec!["127.0.0.1".into()], + }, + ); + map.insert( + "Ethernet".into(), + WindowsInterfaceDns { + dhcp: false, + servers: vec!["::1".into(), "0.0.0.0".into()], + }, + ); + assert!(!backup_has_real_upstream_windows(&map)); + + // One real entry alongside loopback → useful + map.insert( + "Ethernet 2".into(), + WindowsInterfaceDns { + dhcp: false, + servers: vec!["192.168.1.1".into()], + }, + ); + assert!(backup_has_real_upstream_windows(&map)); + } + + #[test] + fn resolv_conf_real_upstream_detection() { + let real = "nameserver 192.168.1.1\nsearch lan\n"; + assert!(resolv_conf_has_real_upstream(real)); + assert!(!resolv_conf_is_numa_managed(real)); + + let self_ref = "nameserver 127.0.0.1\nsearch numa\n"; + assert!(!resolv_conf_has_real_upstream(self_ref)); + assert!(resolv_conf_is_numa_managed(self_ref)); + + let numa_marker = + "# Generated by Numa — run 'sudo numa uninstall' to restore\nnameserver 127.0.0.1\nsearch numa\n"; + assert!(resolv_conf_is_numa_managed(numa_marker)); + + let systemd_stub = "nameserver 127.0.0.53\noptions edns0\n"; + assert!(!resolv_conf_has_real_upstream(systemd_stub)); + + let mixed = "nameserver 127.0.0.1\nnameserver 1.1.1.1\n"; + assert!(resolv_conf_has_real_upstream(mixed)); + assert!(!resolv_conf_is_numa_managed(mixed)); + } + #[test] fn parse_ipconfig_skips_disconnected() { let sample = "\