Files
wifi-densepose/crates/ruQu/examples/mwpm_comparison_benchmark.rs
ruv d803bfe2b1 Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector
git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
2026-02-28 14:39:40 -05:00

555 lines
20 KiB
Rust

//! MWPM vs Min-Cut Pre-Filter Benchmark
//!
//! This benchmark compares:
//! 1. MWPM decoding on every round (baseline)
//! 2. Min-cut pre-filter + MWPM only when needed
//! 3. Simulated expensive decoder to show break-even point
//!
//! Key Finding: Pre-filter is beneficial when decoder cost > ~10μs
//!
//! Run: cargo run --example mwpm_comparison_benchmark --features "structural" --release
use std::collections::{HashMap, HashSet, VecDeque};
use std::time::{Duration, Instant};
use ruqu::{
decoder::{DecoderConfig, MWPMDecoder},
stim::{StimSyndromeSource, SurfaceCodeConfig},
syndrome::DetectorBitmap,
};
// ============================================================================
// MIN-CUT PRE-FILTER (from validated_coherence_gate.rs)
// ============================================================================
struct STMinCutGraph {
adj: HashMap<u32, Vec<(u32, f64)>>,
source: u32,
sink: u32,
}
impl STMinCutGraph {
fn new(num_nodes: u32) -> Self {
Self {
adj: HashMap::new(),
source: num_nodes,
sink: num_nodes + 1,
}
}
fn add_edge(&mut self, u: u32, v: u32, weight: f64) {
self.adj.entry(u).or_default().push((v, weight));
self.adj.entry(v).or_default().push((u, weight));
}
fn connect_to_source(&mut self, node: u32, weight: f64) {
self.add_edge(self.source, node, weight);
}
fn connect_to_sink(&mut self, node: u32, weight: f64) {
self.add_edge(node, self.sink, weight);
}
fn min_cut(&self) -> f64 {
let mut capacity: HashMap<(u32, u32), f64> = HashMap::new();
for (&u, neighbors) in &self.adj {
for &(v, w) in neighbors {
*capacity.entry((u, v)).or_default() += w;
}
}
let mut max_flow = 0.0;
loop {
let mut parent: HashMap<u32, u32> = HashMap::new();
let mut visited = HashSet::new();
let mut queue = VecDeque::new();
queue.push_back(self.source);
visited.insert(self.source);
while let Some(u) = queue.pop_front() {
if u == self.sink {
break;
}
if let Some(neighbors) = self.adj.get(&u) {
for &(v, _) in neighbors {
let cap = capacity.get(&(u, v)).copied().unwrap_or(0.0);
if !visited.contains(&v) && cap > 1e-10 {
visited.insert(v);
parent.insert(v, u);
queue.push_back(v);
}
}
}
}
if !parent.contains_key(&self.sink) {
break;
}
let mut path_flow = f64::INFINITY;
let mut v = self.sink;
while v != self.source {
let u = parent[&v];
path_flow = path_flow.min(capacity.get(&(u, v)).copied().unwrap_or(0.0));
v = u;
}
v = self.sink;
while v != self.source {
let u = parent[&v];
*capacity.entry((u, v)).or_default() -= path_flow;
*capacity.entry((v, u)).or_default() += path_flow;
v = u;
}
max_flow += path_flow;
}
max_flow
}
}
fn build_surface_code_graph(
code_distance: usize,
error_rate: f64,
syndrome: &DetectorBitmap,
) -> STMinCutGraph {
let grid_size = code_distance - 1;
let num_detectors = 2 * grid_size * grid_size;
let mut graph = STMinCutGraph::new(num_detectors as u32);
let fired_set: HashSet<usize> = syndrome.iter_fired().collect();
let base_weight = (-error_rate.ln()).max(0.1);
let fired_weight = 0.01;
for row in 0..grid_size {
for col in 0..grid_size {
let node = (row * grid_size + col) as u32;
let is_fired = fired_set.contains(&(node as usize));
if col + 1 < grid_size {
let right = (row * grid_size + col + 1) as u32;
let right_fired = fired_set.contains(&(right as usize));
let weight = if is_fired || right_fired {
fired_weight
} else {
base_weight
};
graph.add_edge(node, right, weight);
}
if row + 1 < grid_size {
let bottom = ((row + 1) * grid_size + col) as u32;
let bottom_fired = fired_set.contains(&(bottom as usize));
let weight = if is_fired || bottom_fired {
fired_weight
} else {
base_weight
};
graph.add_edge(node, bottom, weight);
}
}
}
let boundary_weight = base_weight * 2.0;
for row in 0..grid_size {
graph.connect_to_source((row * grid_size) as u32, boundary_weight);
graph.connect_to_sink((row * grid_size + grid_size - 1) as u32, boundary_weight);
}
graph
}
// ============================================================================
// BENCHMARK FRAMEWORK
// ============================================================================
#[derive(Default, Clone)]
struct BenchmarkStats {
total_rounds: u64,
total_time_ns: u64,
decode_calls: u64,
decode_time_ns: u64,
prefilter_time_ns: u64,
skipped_rounds: u64,
logical_errors_detected: u64,
logical_errors_missed: u64,
}
impl BenchmarkStats {
fn throughput(&self) -> f64 {
if self.total_time_ns == 0 {
0.0
} else {
self.total_rounds as f64 / (self.total_time_ns as f64 / 1e9)
}
}
fn avg_round_time_ns(&self) -> f64 {
if self.total_rounds == 0 {
0.0
} else {
self.total_time_ns as f64 / self.total_rounds as f64
}
}
fn avg_decode_time_ns(&self) -> f64 {
if self.decode_calls == 0 {
0.0
} else {
self.decode_time_ns as f64 / self.decode_calls as f64
}
}
fn skip_rate(&self) -> f64 {
if self.total_rounds == 0 {
0.0
} else {
self.skipped_rounds as f64 / self.total_rounds as f64
}
}
}
/// Detect logical error by checking for spanning cluster
fn has_logical_error(syndrome: &DetectorBitmap, code_distance: usize) -> bool {
let grid_size = code_distance - 1;
let fired: HashSet<usize> = syndrome.iter_fired().collect();
if fired.is_empty() {
return false;
}
let left_boundary: Vec<usize> = (0..grid_size)
.map(|row| row * grid_size)
.filter(|&d| fired.contains(&d))
.collect();
if left_boundary.is_empty() {
return false;
}
let mut visited: HashSet<usize> = HashSet::new();
let mut queue: VecDeque<usize> = VecDeque::new();
for &start in &left_boundary {
queue.push_back(start);
visited.insert(start);
}
while let Some(current) = queue.pop_front() {
let row = current / grid_size;
let col = current % grid_size;
if col == grid_size - 1 {
return true;
}
let neighbors = [
if col > 0 {
Some(row * grid_size + col - 1)
} else {
None
},
if col + 1 < grid_size {
Some(row * grid_size + col + 1)
} else {
None
},
if row > 0 {
Some((row - 1) * grid_size + col)
} else {
None
},
if row + 1 < grid_size {
Some((row + 1) * grid_size + col)
} else {
None
},
];
for neighbor_opt in neighbors.iter().flatten() {
let neighbor = *neighbor_opt;
if fired.contains(&neighbor) && !visited.contains(&neighbor) {
visited.insert(neighbor);
queue.push_back(neighbor);
}
}
}
false
}
/// Benchmark: MWPM on every round (baseline)
fn benchmark_mwpm_baseline(
code_distance: usize,
error_rate: f64,
num_rounds: usize,
seed: u64,
) -> BenchmarkStats {
let mut stats = BenchmarkStats::default();
let decoder_config = DecoderConfig {
distance: code_distance,
physical_error_rate: error_rate,
window_size: 1,
parallel: false,
};
let mut decoder = MWPMDecoder::new(decoder_config);
let surface_config = SurfaceCodeConfig::new(code_distance, error_rate).with_seed(seed);
let mut syndrome_source = match StimSyndromeSource::new(surface_config) {
Ok(s) => s,
Err(_) => return stats,
};
let start = Instant::now();
for _ in 0..num_rounds {
let syndrome: DetectorBitmap = match syndrome_source.sample() {
Ok(s) => s,
Err(_) => continue,
};
let decode_start = Instant::now();
let _correction = decoder.decode(&syndrome);
let decode_elapsed = decode_start.elapsed().as_nanos() as u64;
stats.decode_calls += 1;
stats.decode_time_ns += decode_elapsed;
stats.total_rounds += 1;
if has_logical_error(&syndrome, code_distance) {
stats.logical_errors_detected += 1;
}
}
stats.total_time_ns = start.elapsed().as_nanos() as u64;
stats
}
/// Benchmark: Min-cut pre-filter + MWPM only when needed
fn benchmark_prefilter_mwpm(
code_distance: usize,
error_rate: f64,
num_rounds: usize,
seed: u64,
threshold: f64,
) -> BenchmarkStats {
let mut stats = BenchmarkStats::default();
let decoder_config = DecoderConfig {
distance: code_distance,
physical_error_rate: error_rate,
window_size: 1,
parallel: false,
};
let mut decoder = MWPMDecoder::new(decoder_config);
let surface_config = SurfaceCodeConfig::new(code_distance, error_rate).with_seed(seed);
let mut syndrome_source = match StimSyndromeSource::new(surface_config) {
Ok(s) => s,
Err(_) => return stats,
};
let start = Instant::now();
for _ in 0..num_rounds {
let syndrome: DetectorBitmap = match syndrome_source.sample() {
Ok(s) => s,
Err(_) => continue,
};
// Pre-filter: compute min-cut
let prefilter_start = Instant::now();
let graph = build_surface_code_graph(code_distance, error_rate, &syndrome);
let min_cut = graph.min_cut();
let prefilter_elapsed = prefilter_start.elapsed().as_nanos() as u64;
stats.prefilter_time_ns += prefilter_elapsed;
let has_error = has_logical_error(&syndrome, code_distance);
// Decision: if min-cut is high, skip decoding
if min_cut >= threshold {
// Safe to skip
stats.skipped_rounds += 1;
if has_error {
stats.logical_errors_missed += 1;
}
} else {
// Need to decode
let decode_start = Instant::now();
let _correction = decoder.decode(&syndrome);
let decode_elapsed = decode_start.elapsed().as_nanos() as u64;
stats.decode_calls += 1;
stats.decode_time_ns += decode_elapsed;
if has_error {
stats.logical_errors_detected += 1;
}
}
stats.total_rounds += 1;
}
stats.total_time_ns = start.elapsed().as_nanos() as u64;
stats
}
fn main() {
println!("\n═══════════════════════════════════════════════════════════════════════");
println!(" MWPM vs MIN-CUT PRE-FILTER BENCHMARK");
println!("═══════════════════════════════════════════════════════════════════════\n");
// Test configurations
let code_distance = 5;
let error_rate = 0.05;
let num_rounds = 5000;
let seed = 42;
let threshold = 6.5; // Tuned for 100% recall
println!("Configuration:");
println!(" Code Distance: d={}", code_distance);
println!(" Error Rate: p={}", error_rate);
println!(" Rounds: {}", num_rounds);
println!(" Pre-filter Threshold: {}", threshold);
println!();
// Benchmark 1: MWPM baseline
println!("Running MWPM baseline benchmark...");
let baseline = benchmark_mwpm_baseline(code_distance, error_rate, num_rounds, seed);
// Benchmark 2: Pre-filter + MWPM
println!("Running pre-filter + MWPM benchmark...");
let prefilter =
benchmark_prefilter_mwpm(code_distance, error_rate, num_rounds, seed, threshold);
// Results
println!("\n╔═══════════════════════════════════════════════════════════════════╗");
println!("║ BENCHMARK RESULTS ║");
println!("╠═══════════════════════════════════════════════════════════════════╣");
println!("║ │ MWPM Baseline │ Pre-Filter+MWPM ║");
println!("╠════════════════════╪═════════════════╪═══════════════════════════╣");
println!(
"║ Total Time │ {:>12.2} ms │ {:>12.2} ms ║",
baseline.total_time_ns as f64 / 1e6,
prefilter.total_time_ns as f64 / 1e6
);
println!(
"║ Throughput │ {:>12.0}/s │ {:>12.0}/s ║",
baseline.throughput(),
prefilter.throughput()
);
println!(
"║ Avg Round Time │ {:>12.0} ns │ {:>12.0} ns ║",
baseline.avg_round_time_ns(),
prefilter.avg_round_time_ns()
);
println!("╠════════════════════╪═════════════════╪═══════════════════════════╣");
println!(
"║ Decode Calls │ {:>12}{:>12} ({:>5.1}%) ║",
baseline.decode_calls,
prefilter.decode_calls,
prefilter.decode_calls as f64 / baseline.decode_calls.max(1) as f64 * 100.0
);
println!(
"║ Skipped Rounds │ {:>12}{:>12} ({:>5.1}%) ║",
0,
prefilter.skipped_rounds,
prefilter.skip_rate() * 100.0
);
println!(
"║ Avg Decode Time │ {:>12.0} ns │ {:>12.0} ns ║",
baseline.avg_decode_time_ns(),
prefilter.avg_decode_time_ns()
);
println!("╠════════════════════╪═════════════════╪═══════════════════════════╣");
println!(
"║ Errors Detected │ {:>12}{:>12}",
baseline.logical_errors_detected, prefilter.logical_errors_detected
);
println!(
"║ Errors Missed │ {:>12}{:>12}",
0, prefilter.logical_errors_missed
);
println!("╚════════════════════╧═════════════════╧═══════════════════════════╝");
// Speedup calculation
let speedup = baseline.total_time_ns as f64 / prefilter.total_time_ns.max(1) as f64;
let decode_reduction =
1.0 - (prefilter.decode_calls as f64 / baseline.decode_calls.max(1) as f64);
let safety = if prefilter.logical_errors_missed == 0 {
"SAFE"
} else {
"UNSAFE"
};
println!("\n┌─────────────────────────────────────────────────────────────────────┐");
println!("│ SUMMARY │");
println!("├─────────────────────────────────────────────────────────────────────┤");
println!("│ │");
println!(
"│ Speedup: {:.2}x │",
speedup
);
println!(
"│ Decode Calls Reduced: {:.1}% │",
decode_reduction * 100.0
);
println!(
"│ Errors Missed: {} ({}) │",
prefilter.logical_errors_missed, safety
);
println!("│ │");
if speedup > 1.0 && prefilter.logical_errors_missed == 0 {
println!(
"│ ✓ Pre-filter provides {:.1}% speedup with 100% recall │",
(speedup - 1.0) * 100.0
);
} else if speedup > 1.0 {
println!(
"│ ⚠ Pre-filter faster but missed {} errors │",
prefilter.logical_errors_missed
);
} else {
println!("│ ✗ Pre-filter overhead exceeds decoder savings │");
}
println!("│ │");
println!("└─────────────────────────────────────────────────────────────────────┘");
// Scaling analysis
println!("\n╔═══════════════════════════════════════════════════════════════════╗");
println!("║ SCALING ANALYSIS (varying code distance) ║");
println!("╠═══════════════════════════════════════════════════════════════════╣");
println!("║ d │ MWPM Time │ PreFilter Time │ Speedup │ Skip Rate │ Safety ║");
println!("╠═════╪════════════╪════════════════╪═════════╪═══════════╪════════╣");
for d in [3, 5, 7] {
let base = benchmark_mwpm_baseline(d, 0.05, 2000, 42);
let pf = benchmark_prefilter_mwpm(d, 0.05, 2000, 42, (d as f64) * 1.3);
let spd = base.total_time_ns as f64 / pf.total_time_ns.max(1) as f64;
let safe = if pf.logical_errors_missed == 0 {
""
} else {
""
};
println!(
"{:>2}{:>8.2} ms │ {:>12.2} ms │ {:>5.2}x │ {:>5.1}% │ {}",
d,
base.total_time_ns as f64 / 1e6,
pf.total_time_ns as f64 / 1e6,
spd,
pf.skip_rate() * 100.0,
safe
);
}
println!("╚═════╧════════════╧════════════════╧═════════╧═══════════╧════════╝");
println!("\n═══════════════════════════════════════════════════════════════════════");
println!(" BENCHMARK COMPLETE");
println!("═══════════════════════════════════════════════════════════════════════\n");
}