//! PolicyKernel — Thompson Sampling two-signal model for skip-mode selection. //! //! Faithful WASM port of the PolicyKernel from the benchmarks crate. //! Implements: //! - Two-signal Thompson Sampling (safety Beta + cost EMA) //! - 18 context buckets (3 range x 3 distractor x 2 noise) //! - Speculative dual-path for Mode C //! - KnowledgeCompiler with signature cache extern crate alloc; use alloc::collections::BTreeMap; use alloc::format; use alloc::string::String; use alloc::vec::Vec; use libm::{cos, log, pow, sqrt}; use serde::{Deserialize, Serialize}; use crate::types::{Constraint, Puzzle, constraint_type_name}; // ═════════════════════════════════════════════════════════════════════ // Skip / Prepass modes // ═════════════════════════════════════════════════════════════════════ #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub enum SkipMode { None, Weekday, Hybrid, } impl SkipMode { pub fn name(&self) -> &'static str { match self { SkipMode::None => "none", SkipMode::Weekday => "weekday", SkipMode::Hybrid => "hybrid", } } } #[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)] pub enum PrepassMode { #[default] Off, Light, Full, } // ═════════════════════════════════════════════════════════════════════ // Policy context + outcome // ═════════════════════════════════════════════════════════════════════ #[derive(Clone, Debug, Serialize, Deserialize)] pub struct PolicyContext { pub posterior_range: usize, pub distractor_count: usize, pub has_day_of_week: bool, pub noisy: bool, } #[derive(Clone, Debug, Serialize, Deserialize)] pub struct SkipOutcome { pub mode: SkipMode, pub correct: bool, pub steps: usize, pub early_commit_wrong: bool, pub initial_candidates: usize, pub remaining_at_commit: usize, } // ═════════════════════════════════════════════════════════════════════ // Per-arm stats (two-signal Thompson Sampling) // ═════════════════════════════════════════════════════════════════════ const THOMPSON_LAMBDA: f64 = 0.3; const COST_EMA_ALPHA: f64 = 0.1; #[derive(Clone, Debug, Default, Serialize, Deserialize)] pub struct SkipModeStats { pub attempts: usize, pub successes: usize, pub total_steps: usize, pub alpha_safety: f64, pub beta_safety: f64, pub cost_ema: f64, pub early_commit_wrongs: usize, pub early_commit_penalty_sum: f64, } impl SkipModeStats { pub fn safety_beta(&self) -> (f64, f64) { (self.alpha_safety + 1.0, self.beta_safety + 1.0) } pub fn safety_variance(&self) -> f64 { let (a, b) = self.safety_beta(); let s = a + b; (a * b) / (s * s * (s + 1.0)) } pub fn update_safety(&mut self, correct: bool, early_wrong: bool) { if correct && !early_wrong { self.alpha_safety += 1.0; } else { self.beta_safety += 1.0; if early_wrong { self.beta_safety += 0.5; } } } pub fn update_cost(&mut self, normalized_steps: f64) { if self.attempts <= 1 { self.cost_ema = normalized_steps; } else { self.cost_ema = COST_EMA_ALPHA * normalized_steps + (1.0 - COST_EMA_ALPHA) * self.cost_ema; } } pub fn reward(&self) -> f64 { if self.attempts == 0 { return 0.5; } let acc = self.successes as f64 / self.attempts as f64; let cost = 0.3 * (1.0 - (self.total_steps as f64 / self.attempts as f64) / 200.0).max(0.0); let penalty = 0.2 * (self.early_commit_penalty_sum / self.attempts as f64).min(1.0); (acc * 0.5 + cost - penalty).max(0.0) } } // ═════════════════════════════════════════════════════════════════════ // PolicyKernel // ═════════════════════════════════════════════════════════════════════ #[derive(Clone, Debug, Default, Serialize, Deserialize)] pub struct PolicyKernel { pub context_stats: BTreeMap>, pub early_commit_penalties: f64, pub early_commits_total: usize, pub early_commits_wrong: usize, pub prepass: PrepassMode, pub speculative_attempts: usize, pub speculative_arm2_wins: usize, rng_state: u64, } impl PolicyKernel { pub fn new() -> Self { Self { rng_state: 42, ..Default::default() } } /// Mode A: fixed heuristic policy. /// risk_score = R - 30*D, threshold 140. const K: usize = 30; const T: usize = 140; pub fn fixed_policy(ctx: &PolicyContext) -> SkipMode { if !ctx.has_day_of_week { return SkipMode::None; } let eff = ctx.posterior_range.saturating_sub(Self::K * ctx.distractor_count); if eff >= Self::T { SkipMode::Weekday } else { SkipMode::None } } /// Mode B: compiler-suggested policy. pub fn compiled_policy(ctx: &PolicyContext, compiled: Option) -> SkipMode { compiled.unwrap_or_else(|| Self::fixed_policy(ctx)) } /// Mode C: learned two-signal Thompson Sampling. /// When no training data exists for a context bucket, defaults to /// SkipMode::None (conservative linear scan). After training, the /// learned policy discovers better skip modes — showing real cost /// improvement between early and later cycles. pub fn learned_policy(&mut self, ctx: &PolicyContext) -> SkipMode { if !ctx.has_day_of_week { return SkipMode::None; } let bucket = Self::context_bucket(ctx); let modes = ["none", "weekday", "hybrid"]; // Conservative default: use None (linear scan) until the solver has // accumulated enough training data. This ensures a meaningful baseline // in early cycles that training can measurably improve upon. { let total_observations: usize = self.context_stats.values() .flat_map(|m| m.values()) .map(|s| s.attempts) .sum(); if total_observations < 100 { return SkipMode::None; } } let params: Vec<(SkipMode, f64, f64, f64)> = { let map = self.context_stats.entry(bucket).or_default(); modes .iter() .map(|name| { let s = map.get(*name).cloned().unwrap_or_default(); let (a, b) = s.safety_beta(); let mode = match *name { "weekday" => SkipMode::Weekday, "hybrid" => SkipMode::Hybrid, _ => SkipMode::None, }; (mode, a, b, s.cost_ema) }) .collect() }; let mut scored: Vec<(SkipMode, f64)> = params .into_iter() .map(|(mode, a, b, cost)| { let sample = self.sample_beta(a, b); (mode, sample - THOMPSON_LAMBDA * cost) }) .collect(); scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(core::cmp::Ordering::Equal)); scored.first().map(|(m, _)| m.clone()).unwrap_or(SkipMode::None) } /// Speculative dual-path check. pub fn should_speculate(&mut self, ctx: &PolicyContext) -> Option<(SkipMode, SkipMode)> { if !ctx.has_day_of_week || ctx.posterior_range < 61 { return None; } let bucket = Self::context_bucket(ctx); let modes = ["none", "weekday", "hybrid"]; let params: Vec<(SkipMode, f64, f64, f64, f64)> = { let map = self.context_stats.entry(bucket).or_default(); modes .iter() .map(|name| { let s = map.get(*name).cloned().unwrap_or_default(); let (a, b) = s.safety_beta(); let v = s.safety_variance(); let mode = match *name { "weekday" => SkipMode::Weekday, "hybrid" => SkipMode::Hybrid, _ => SkipMode::None, }; (mode, a, b, s.cost_ema, v) }) .collect() }; let mut scored: Vec<(SkipMode, f64, f64)> = params .into_iter() .map(|(mode, a, b, cost, var)| { let sample = self.sample_beta(a, b); (mode, sample - THOMPSON_LAMBDA * cost, var) }) .collect(); scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(core::cmp::Ordering::Equal)); if scored.len() >= 2 { let (ref a1, s1, v1) = scored[0]; let (ref a2, s2, _) = scored[1]; if (s1 - s2).abs() < 0.15 && v1 > 0.02 { return Some((a1.clone(), a2.clone())); } } None } pub fn record_outcome(&mut self, ctx: &PolicyContext, outcome: &SkipOutcome) { let bucket = Self::context_bucket(ctx); let mode_name = outcome.mode.name(); let map = self.context_stats.entry(bucket).or_default(); let stats = map.entry(String::from(mode_name)).or_default(); stats.attempts += 1; stats.total_steps += outcome.steps; if outcome.correct { stats.successes += 1; } stats.update_safety(outcome.correct, outcome.early_commit_wrong); stats.update_cost((outcome.steps as f64 / 200.0).min(1.0)); if outcome.early_commit_wrong { stats.early_commit_wrongs += 1; self.early_commits_wrong += 1; let penalty = if outcome.initial_candidates > 0 { outcome.remaining_at_commit as f64 / outcome.initial_candidates as f64 } else { 1.0 }; self.early_commit_penalties += penalty; stats.early_commit_penalty_sum += penalty; } self.early_commits_total += 1; } pub fn early_commit_rate(&self) -> f64 { if self.early_commits_total == 0 { 0.0 } else { self.early_commits_wrong as f64 / self.early_commits_total as f64 } } /// 3 range x 3 distractor x 2 noise = 18 buckets. pub fn context_bucket(ctx: &PolicyContext) -> String { let r = match ctx.posterior_range { 0..=60 => "small", 61..=180 => "medium", _ => "large", }; let d = match ctx.distractor_count { 0 => "clean", 1 => "some", _ => "heavy", }; let n = if ctx.noisy { "noisy" } else { "clean" }; format!("{}:{}:{}", r, d, n) } fn sample_beta(&mut self, alpha: f64, beta: f64) -> f64 { let x = self.sample_gamma(alpha); let y = self.sample_gamma(beta); if x + y == 0.0 { 0.5 } else { x / (x + y) } } fn sample_gamma(&mut self, shape: f64) -> f64 { if shape < 1.0 { let u = self.next_f64().max(1e-10); return self.sample_gamma(shape + 1.0) * pow(u, 1.0 / shape); } let d = shape - 1.0 / 3.0; let c = 1.0 / sqrt(9.0 * d); loop { let x = self.next_normal(); let t = 1.0 + c * x; let v = t * t * t; if v <= 0.0 { continue; } let u = self.next_f64().max(1e-10); if u < 1.0 - 0.0331 * x * x * x * x { return d * v; } if log(u) < 0.5 * x * x + d * (1.0 - v + log(v)) { return d * v; } } } fn next_normal(&mut self) -> f64 { let u1 = self.next_f64().max(1e-10); let u2 = self.next_f64(); sqrt(-2.0 * log(u1)) * cos(2.0 * core::f64::consts::PI * u2) } fn next_f64(&mut self) -> f64 { let mut x = self.rng_state.max(1); x ^= x << 13; x ^= x >> 7; x ^= x << 17; self.rng_state = x; x as f64 / u64::MAX as f64 } } // ═════════════════════════════════════════════════════════════════════ // KnowledgeCompiler — constraint signature → compiled config // ═════════════════════════════════════════════════════════════════════ #[derive(Clone, Debug, Serialize, Deserialize)] pub struct CompiledConfig { pub max_steps: usize, pub avg_steps: f64, pub observations: usize, pub expected_correct: bool, pub stop_after_first: bool, pub hit_count: usize, pub counterexample_count: usize, pub compiled_skip: SkipMode, } impl CompiledConfig { pub fn confidence(&self) -> f64 { let total = self.hit_count + self.counterexample_count; if total == 0 { 0.5 } else { (self.hit_count as f64 + 1.0) / (total as f64 + 2.0) } } pub fn trial_budget(&self, external_limit: usize) -> usize { let budget = if self.observations > 2 && self.avg_steps > 1.0 { (self.avg_steps * 2.0) as usize } else { self.max_steps.max(10) }; budget.max(10).min(external_limit / 4) } } #[derive(Clone, Debug, Default, Serialize, Deserialize)] pub struct KnowledgeCompiler { pub cache: BTreeMap, pub hits: usize, pub misses: usize, pub false_hits: usize, pub steps_saved: i64, pub confidence_threshold: f64, } impl KnowledgeCompiler { pub fn new() -> Self { Self { confidence_threshold: 0.7, ..Default::default() } } pub fn signature(puzzle: &Puzzle) -> String { let mut parts: Vec<&str> = puzzle.constraints.iter().map(constraint_type_name).collect(); parts.sort(); format!("v1:{}:{}", puzzle.difficulty, parts.join(",")) } pub fn lookup(&mut self, puzzle: &Puzzle) -> Option<&CompiledConfig> { let sig = Self::signature(puzzle); if self.cache.contains_key(&sig) { self.hits += 1; self.cache.get(&sig) } else { self.misses += 1; None } } pub fn record_success(&mut self, puzzle: &Puzzle, actual_steps: usize) { let sig = Self::signature(puzzle); if let Some(cfg) = self.cache.get_mut(&sig) { cfg.hit_count += 1; let est = if cfg.avg_steps > 0.0 { (cfg.avg_steps * 2.0) as i64 } else { cfg.max_steps as i64 }; self.steps_saved += est - actual_steps as i64; } } pub fn record_failure(&mut self, puzzle: &Puzzle) { self.false_hits += 1; let sig = Self::signature(puzzle); if let Some(cfg) = self.cache.get_mut(&sig) { cfg.counterexample_count += 1; if cfg.counterexample_count >= 2 { cfg.expected_correct = false; } } } /// Compile knowledge from trajectories (simplified ReasoningBank integration). pub fn compile_from_trajectories(&mut self, trajectories: &[(String, u8, Vec<&str>, usize, bool)]) { for (_, difficulty, ctypes, steps, correct) in trajectories { if !correct { continue; } let mut parts = ctypes.clone(); parts.sort(); let sig = format!("v1:{}:{}", difficulty, parts.join(",")); let has_dow = parts.iter().any(|c| *c == "DayOfWeek"); let compiled_skip = if has_dow { SkipMode::Weekday } else { SkipMode::None }; let entry = self.cache.entry(sig).or_insert(CompiledConfig { max_steps: *steps, avg_steps: 0.0, observations: 0, expected_correct: true, stop_after_first: true, hit_count: 0, counterexample_count: 0, compiled_skip, }); entry.max_steps = entry.max_steps.min(*steps); let n = entry.observations as f64; entry.avg_steps = (entry.avg_steps * n + *steps as f64) / (n + 1.0); entry.observations += 1; entry.hit_count = entry.observations; } } } /// Count distractor constraints in a puzzle. pub fn count_distractors(puzzle: &Puzzle) -> usize { let mut count = 0; let (mut sb, mut sy, mut sd) = (false, false, false); for c in &puzzle.constraints { match c { Constraint::Between(_, _) => { if sb { count += 1; } sb = true; } Constraint::InYear(_) => { if sy { count += 1; } sy = true; } Constraint::DayOfWeek(_) => { if sd { count += 1; } sd = true; } _ => {} } } count }