simplify set_lan_enabled: fix config path, TOCTOU, double iteration

- Accept config path parameter (consistent with main's resolution)
- Read first, match on NotFound (eliminates TOCTOU race)
- Single position() call replaces any() + position()
- Precise key matching via split_once('=')
- Preserve original indentation on replacement
- Extract print_lan_status helper

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Razvan Dimescu
2026-03-23 10:59:35 +02:00
parent dddc10336c
commit 0dd7700665

View File

@@ -52,11 +52,14 @@ async fn main() -> numa::Result<()> {
}
"lan" => {
let sub = std::env::args().nth(2).unwrap_or_default();
let config_path = std::env::args()
.nth(3)
.unwrap_or_else(|| "numa.toml".to_string());
return match sub.as_str() {
"on" => set_lan_enabled(true),
"off" => set_lan_enabled(false),
"on" => set_lan_enabled(true, &config_path),
"off" => set_lan_enabled(false, &config_path),
_ => {
eprintln!("Usage: numa lan <on|off>");
eprintln!("Usage: numa lan <on|off> [config-path]");
Ok(())
}
};
@@ -351,11 +354,16 @@ async fn network_watch_loop(ctx: Arc<numa::ctx::ServerCtx>) {
}
}
fn set_lan_enabled(enabled: bool) -> numa::Result<()> {
let path = "numa.toml";
if std::path::Path::new(path).exists() {
let contents = std::fs::read_to_string(path)?;
fn set_lan_enabled(enabled: bool, path: &str) -> numa::Result<()> {
let contents = match std::fs::read_to_string(path) {
Ok(c) => c,
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
std::fs::write(path, format!("[lan]\nenabled = {}\n", enabled))?;
print_lan_status(enabled);
return Ok(());
}
Err(e) => return Err(e.into()),
};
// Track current TOML section while scanning lines
let mut in_lan = false;
@@ -367,36 +375,39 @@ fn set_lan_enabled(enabled: bool) -> numa::Result<()> {
if trimmed.starts_with('[') {
in_lan = trimmed == "[lan]";
}
if in_lan && !found && trimmed.starts_with("enabled") && trimmed.contains('=') {
if in_lan && !found {
if let Some((key, _)) = trimmed.split_once('=') {
if key.trim() == "enabled" {
found = true;
return format!("enabled = {}", enabled);
let indent = &line[..line.len() - trimmed.len()];
return format!("{}enabled = {}", indent, enabled);
}
}
}
line.to_string()
})
.collect();
let has_lan_section = lines.iter().any(|l| l.trim() == "[lan]");
if !found && has_lan_section {
// [lan] exists but no enabled line — insert after it
if !found {
if let Some(i) = lines.iter().position(|l| l.trim() == "[lan]") {
lines.insert(i + 1, format!("enabled = {}", enabled));
}
} else if !found {
// No [lan] section — append
} else {
lines.push(String::new());
lines.push("[lan]".to_string());
lines.push(format!("enabled = {}", enabled));
}
}
let mut result = lines.join("\n");
if contents.ends_with('\n') && !result.ends_with('\n') {
if !result.ends_with('\n') {
result.push('\n');
}
std::fs::write(path, result)?;
} else {
std::fs::write(path, format!("[lan]\nenabled = {}\n", enabled))?;
print_lan_status(enabled);
Ok(())
}
fn print_lan_status(enabled: bool) {
let label = if enabled { "enabled" } else { "disabled" };
let color = if enabled { "32" } else { "33" };
eprintln!(
@@ -406,7 +417,6 @@ fn set_lan_enabled(enabled: bool) -> numa::Result<()> {
if enabled {
eprintln!(" Restart Numa to start mDNS discovery");
}
Ok(())
}
async fn load_blocklists(ctx: &ServerCtx, lists: &[String]) {