From 4020776b8e30ec628ed8725de6ad727eaa432efc Mon Sep 17 00:00:00 2001 From: Razvan Dimescu Date: Mon, 23 Mar 2026 10:59:35 +0200 Subject: [PATCH] simplify set_lan_enabled: fix config path, TOCTOU, double iteration - Accept config path parameter (consistent with main's resolution) - Read first, match on NotFound (eliminates TOCTOU race) - Single position() call replaces any() + position() - Precise key matching via split_once('=') - Preserve original indentation on replacement - Extract print_lan_status helper Co-Authored-By: Claude Opus 4.6 (1M context) --- src/main.rs | 94 +++++++++++++++++++++++++++++------------------------ 1 file changed, 52 insertions(+), 42 deletions(-) diff --git a/src/main.rs b/src/main.rs index 678a8bf..7f60419 100644 --- a/src/main.rs +++ b/src/main.rs @@ -52,11 +52,14 @@ async fn main() -> numa::Result<()> { } "lan" => { let sub = std::env::args().nth(2).unwrap_or_default(); + let config_path = std::env::args() + .nth(3) + .unwrap_or_else(|| "numa.toml".to_string()); return match sub.as_str() { - "on" => set_lan_enabled(true), - "off" => set_lan_enabled(false), + "on" => set_lan_enabled(true, &config_path), + "off" => set_lan_enabled(false, &config_path), _ => { - eprintln!("Usage: numa lan "); + eprintln!("Usage: numa lan [config-path]"); Ok(()) } }; @@ -351,52 +354,60 @@ async fn network_watch_loop(ctx: Arc) { } } -fn set_lan_enabled(enabled: bool) -> numa::Result<()> { - let path = "numa.toml"; +fn set_lan_enabled(enabled: bool, path: &str) -> numa::Result<()> { + let contents = match std::fs::read_to_string(path) { + Ok(c) => c, + Err(e) if e.kind() == std::io::ErrorKind::NotFound => { + std::fs::write(path, format!("[lan]\nenabled = {}\n", enabled))?; + print_lan_status(enabled); + return Ok(()); + } + Err(e) => return Err(e.into()), + }; - if std::path::Path::new(path).exists() { - let contents = std::fs::read_to_string(path)?; - - // Track current TOML section while scanning lines - let mut in_lan = false; - let mut found = false; - let mut lines: Vec = contents - .lines() - .map(|line| { - let trimmed = line.trim(); - if trimmed.starts_with('[') { - in_lan = trimmed == "[lan]"; - } - if in_lan && !found && trimmed.starts_with("enabled") && trimmed.contains('=') { - found = true; - return format!("enabled = {}", enabled); - } - line.to_string() - }) - .collect(); - - let has_lan_section = lines.iter().any(|l| l.trim() == "[lan]"); - if !found && has_lan_section { - // [lan] exists but no enabled line — insert after it - if let Some(i) = lines.iter().position(|l| l.trim() == "[lan]") { - lines.insert(i + 1, format!("enabled = {}", enabled)); + // Track current TOML section while scanning lines + let mut in_lan = false; + let mut found = false; + let mut lines: Vec = contents + .lines() + .map(|line| { + let trimmed = line.trim(); + if trimmed.starts_with('[') { + in_lan = trimmed == "[lan]"; } - } else if !found { - // No [lan] section — append + if in_lan && !found { + if let Some((key, _)) = trimmed.split_once('=') { + if key.trim() == "enabled" { + found = true; + let indent = &line[..line.len() - trimmed.len()]; + return format!("{}enabled = {}", indent, enabled); + } + } + } + line.to_string() + }) + .collect(); + + if !found { + if let Some(i) = lines.iter().position(|l| l.trim() == "[lan]") { + lines.insert(i + 1, format!("enabled = {}", enabled)); + } else { lines.push(String::new()); lines.push("[lan]".to_string()); lines.push(format!("enabled = {}", enabled)); } - - let mut result = lines.join("\n"); - if contents.ends_with('\n') && !result.ends_with('\n') { - result.push('\n'); - } - std::fs::write(path, result)?; - } else { - std::fs::write(path, format!("[lan]\nenabled = {}\n", enabled))?; } + let mut result = lines.join("\n"); + if !result.ends_with('\n') { + result.push('\n'); + } + std::fs::write(path, result)?; + print_lan_status(enabled); + Ok(()) +} + +fn print_lan_status(enabled: bool) { let label = if enabled { "enabled" } else { "disabled" }; let color = if enabled { "32" } else { "33" }; eprintln!( @@ -406,7 +417,6 @@ fn set_lan_enabled(enabled: bool) -> numa::Result<()> { if enabled { eprintln!(" Restart Numa to start mDNS discovery"); } - Ok(()) } async fn load_blocklists(ctx: &ServerCtx, lists: &[String]) {