simplify set_lan_enabled: fix config path, TOCTOU, double iteration
- Accept config path parameter (consistent with main's resolution)
- Read first, match on NotFound (eliminates TOCTOU race)
- Single position() call replaces any() + position()
- Precise key matching via split_once('=')
- Preserve original indentation on replacement
- Extract print_lan_status helper
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
94
src/main.rs
94
src/main.rs
@@ -52,11 +52,14 @@ async fn main() -> numa::Result<()> {
|
||||
}
|
||||
"lan" => {
|
||||
let sub = std::env::args().nth(2).unwrap_or_default();
|
||||
let config_path = std::env::args()
|
||||
.nth(3)
|
||||
.unwrap_or_else(|| "numa.toml".to_string());
|
||||
return match sub.as_str() {
|
||||
"on" => set_lan_enabled(true),
|
||||
"off" => set_lan_enabled(false),
|
||||
"on" => set_lan_enabled(true, &config_path),
|
||||
"off" => set_lan_enabled(false, &config_path),
|
||||
_ => {
|
||||
eprintln!("Usage: numa lan <on|off>");
|
||||
eprintln!("Usage: numa lan <on|off> [config-path]");
|
||||
Ok(())
|
||||
}
|
||||
};
|
||||
@@ -351,52 +354,60 @@ async fn network_watch_loop(ctx: Arc<numa::ctx::ServerCtx>) {
|
||||
}
|
||||
}
|
||||
|
||||
fn set_lan_enabled(enabled: bool) -> numa::Result<()> {
|
||||
let path = "numa.toml";
|
||||
fn set_lan_enabled(enabled: bool, path: &str) -> numa::Result<()> {
|
||||
let contents = match std::fs::read_to_string(path) {
|
||||
Ok(c) => c,
|
||||
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
|
||||
std::fs::write(path, format!("[lan]\nenabled = {}\n", enabled))?;
|
||||
print_lan_status(enabled);
|
||||
return Ok(());
|
||||
}
|
||||
Err(e) => return Err(e.into()),
|
||||
};
|
||||
|
||||
if std::path::Path::new(path).exists() {
|
||||
let contents = std::fs::read_to_string(path)?;
|
||||
|
||||
// Track current TOML section while scanning lines
|
||||
let mut in_lan = false;
|
||||
let mut found = false;
|
||||
let mut lines: Vec<String> = contents
|
||||
.lines()
|
||||
.map(|line| {
|
||||
let trimmed = line.trim();
|
||||
if trimmed.starts_with('[') {
|
||||
in_lan = trimmed == "[lan]";
|
||||
}
|
||||
if in_lan && !found && trimmed.starts_with("enabled") && trimmed.contains('=') {
|
||||
found = true;
|
||||
return format!("enabled = {}", enabled);
|
||||
}
|
||||
line.to_string()
|
||||
})
|
||||
.collect();
|
||||
|
||||
let has_lan_section = lines.iter().any(|l| l.trim() == "[lan]");
|
||||
if !found && has_lan_section {
|
||||
// [lan] exists but no enabled line — insert after it
|
||||
if let Some(i) = lines.iter().position(|l| l.trim() == "[lan]") {
|
||||
lines.insert(i + 1, format!("enabled = {}", enabled));
|
||||
// Track current TOML section while scanning lines
|
||||
let mut in_lan = false;
|
||||
let mut found = false;
|
||||
let mut lines: Vec<String> = contents
|
||||
.lines()
|
||||
.map(|line| {
|
||||
let trimmed = line.trim();
|
||||
if trimmed.starts_with('[') {
|
||||
in_lan = trimmed == "[lan]";
|
||||
}
|
||||
} else if !found {
|
||||
// No [lan] section — append
|
||||
if in_lan && !found {
|
||||
if let Some((key, _)) = trimmed.split_once('=') {
|
||||
if key.trim() == "enabled" {
|
||||
found = true;
|
||||
let indent = &line[..line.len() - trimmed.len()];
|
||||
return format!("{}enabled = {}", indent, enabled);
|
||||
}
|
||||
}
|
||||
}
|
||||
line.to_string()
|
||||
})
|
||||
.collect();
|
||||
|
||||
if !found {
|
||||
if let Some(i) = lines.iter().position(|l| l.trim() == "[lan]") {
|
||||
lines.insert(i + 1, format!("enabled = {}", enabled));
|
||||
} else {
|
||||
lines.push(String::new());
|
||||
lines.push("[lan]".to_string());
|
||||
lines.push(format!("enabled = {}", enabled));
|
||||
}
|
||||
|
||||
let mut result = lines.join("\n");
|
||||
if contents.ends_with('\n') && !result.ends_with('\n') {
|
||||
result.push('\n');
|
||||
}
|
||||
std::fs::write(path, result)?;
|
||||
} else {
|
||||
std::fs::write(path, format!("[lan]\nenabled = {}\n", enabled))?;
|
||||
}
|
||||
|
||||
let mut result = lines.join("\n");
|
||||
if !result.ends_with('\n') {
|
||||
result.push('\n');
|
||||
}
|
||||
std::fs::write(path, result)?;
|
||||
print_lan_status(enabled);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn print_lan_status(enabled: bool) {
|
||||
let label = if enabled { "enabled" } else { "disabled" };
|
||||
let color = if enabled { "32" } else { "33" };
|
||||
eprintln!(
|
||||
@@ -406,7 +417,6 @@ fn set_lan_enabled(enabled: bool) -> numa::Result<()> {
|
||||
if enabled {
|
||||
eprintln!(" Restart Numa to start mDNS discovery");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn load_blocklists(ctx: &ServerCtx, lists: &[String]) {
|
||||
|
||||
Reference in New Issue
Block a user