simplify set_lan_enabled: fix config path, TOCTOU, double iteration

- Accept config path parameter (consistent with main's resolution)
- Read first, match on NotFound (eliminates TOCTOU race)
- Single position() call replaces any() + position()
- Precise key matching via split_once('=')
- Preserve original indentation on replacement
- Extract print_lan_status helper

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Razvan Dimescu
2026-03-23 10:59:35 +02:00
parent 763ba1de91
commit 4020776b8e

View File

@@ -52,11 +52,14 @@ async fn main() -> numa::Result<()> {
} }
"lan" => { "lan" => {
let sub = std::env::args().nth(2).unwrap_or_default(); let sub = std::env::args().nth(2).unwrap_or_default();
let config_path = std::env::args()
.nth(3)
.unwrap_or_else(|| "numa.toml".to_string());
return match sub.as_str() { return match sub.as_str() {
"on" => set_lan_enabled(true), "on" => set_lan_enabled(true, &config_path),
"off" => set_lan_enabled(false), "off" => set_lan_enabled(false, &config_path),
_ => { _ => {
eprintln!("Usage: numa lan <on|off>"); eprintln!("Usage: numa lan <on|off> [config-path]");
Ok(()) Ok(())
} }
}; };
@@ -351,11 +354,16 @@ async fn network_watch_loop(ctx: Arc<numa::ctx::ServerCtx>) {
} }
} }
fn set_lan_enabled(enabled: bool) -> numa::Result<()> { fn set_lan_enabled(enabled: bool, path: &str) -> numa::Result<()> {
let path = "numa.toml"; let contents = match std::fs::read_to_string(path) {
Ok(c) => c,
if std::path::Path::new(path).exists() { Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
let contents = std::fs::read_to_string(path)?; std::fs::write(path, format!("[lan]\nenabled = {}\n", enabled))?;
print_lan_status(enabled);
return Ok(());
}
Err(e) => return Err(e.into()),
};
// Track current TOML section while scanning lines // Track current TOML section while scanning lines
let mut in_lan = false; let mut in_lan = false;
@@ -367,36 +375,39 @@ fn set_lan_enabled(enabled: bool) -> numa::Result<()> {
if trimmed.starts_with('[') { if trimmed.starts_with('[') {
in_lan = trimmed == "[lan]"; in_lan = trimmed == "[lan]";
} }
if in_lan && !found && trimmed.starts_with("enabled") && trimmed.contains('=') { if in_lan && !found {
if let Some((key, _)) = trimmed.split_once('=') {
if key.trim() == "enabled" {
found = true; found = true;
return format!("enabled = {}", enabled); let indent = &line[..line.len() - trimmed.len()];
return format!("{}enabled = {}", indent, enabled);
}
}
} }
line.to_string() line.to_string()
}) })
.collect(); .collect();
let has_lan_section = lines.iter().any(|l| l.trim() == "[lan]"); if !found {
if !found && has_lan_section {
// [lan] exists but no enabled line — insert after it
if let Some(i) = lines.iter().position(|l| l.trim() == "[lan]") { if let Some(i) = lines.iter().position(|l| l.trim() == "[lan]") {
lines.insert(i + 1, format!("enabled = {}", enabled)); lines.insert(i + 1, format!("enabled = {}", enabled));
} } else {
} else if !found {
// No [lan] section — append
lines.push(String::new()); lines.push(String::new());
lines.push("[lan]".to_string()); lines.push("[lan]".to_string());
lines.push(format!("enabled = {}", enabled)); lines.push(format!("enabled = {}", enabled));
} }
}
let mut result = lines.join("\n"); let mut result = lines.join("\n");
if contents.ends_with('\n') && !result.ends_with('\n') { if !result.ends_with('\n') {
result.push('\n'); result.push('\n');
} }
std::fs::write(path, result)?; std::fs::write(path, result)?;
} else { print_lan_status(enabled);
std::fs::write(path, format!("[lan]\nenabled = {}\n", enabled))?; Ok(())
} }
fn print_lan_status(enabled: bool) {
let label = if enabled { "enabled" } else { "disabled" }; let label = if enabled { "enabled" } else { "disabled" };
let color = if enabled { "32" } else { "33" }; let color = if enabled { "32" } else { "33" };
eprintln!( eprintln!(
@@ -406,7 +417,6 @@ fn set_lan_enabled(enabled: bool) -> numa::Result<()> {
if enabled { if enabled {
eprintln!(" Restart Numa to start mDNS discovery"); eprintln!(" Restart Numa to start mDNS discovery");
} }
Ok(())
} }
async fn load_blocklists(ctx: &ServerCtx, lists: &[String]) { async fn load_blocklists(ctx: &ServerCtx, lists: &[String]) {