Files
wifi-densepose/vendor/ruvector/crates/rvf/rvf-solver-wasm/src/policy.rs

524 lines
18 KiB
Rust

//! PolicyKernel — Thompson Sampling two-signal model for skip-mode selection.
//!
//! Faithful WASM port of the PolicyKernel from the benchmarks crate.
//! Implements:
//! - Two-signal Thompson Sampling (safety Beta + cost EMA)
//! - 18 context buckets (3 range x 3 distractor x 2 noise)
//! - Speculative dual-path for Mode C
//! - KnowledgeCompiler with signature cache
extern crate alloc;
use alloc::collections::BTreeMap;
use alloc::format;
use alloc::string::String;
use alloc::vec::Vec;
use libm::{cos, log, pow, sqrt};
use serde::{Deserialize, Serialize};
use crate::types::{Constraint, Puzzle, constraint_type_name};
// ═════════════════════════════════════════════════════════════════════
// Skip / Prepass modes
// ═════════════════════════════════════════════════════════════════════
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum SkipMode {
None,
Weekday,
Hybrid,
}
impl SkipMode {
pub fn name(&self) -> &'static str {
match self {
SkipMode::None => "none",
SkipMode::Weekday => "weekday",
SkipMode::Hybrid => "hybrid",
}
}
}
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
pub enum PrepassMode {
#[default]
Off,
Light,
Full,
}
// ═════════════════════════════════════════════════════════════════════
// Policy context + outcome
// ═════════════════════════════════════════════════════════════════════
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PolicyContext {
pub posterior_range: usize,
pub distractor_count: usize,
pub has_day_of_week: bool,
pub noisy: bool,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SkipOutcome {
pub mode: SkipMode,
pub correct: bool,
pub steps: usize,
pub early_commit_wrong: bool,
pub initial_candidates: usize,
pub remaining_at_commit: usize,
}
// ═════════════════════════════════════════════════════════════════════
// Per-arm stats (two-signal Thompson Sampling)
// ═════════════════════════════════════════════════════════════════════
const THOMPSON_LAMBDA: f64 = 0.3;
const COST_EMA_ALPHA: f64 = 0.1;
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct SkipModeStats {
pub attempts: usize,
pub successes: usize,
pub total_steps: usize,
pub alpha_safety: f64,
pub beta_safety: f64,
pub cost_ema: f64,
pub early_commit_wrongs: usize,
pub early_commit_penalty_sum: f64,
}
impl SkipModeStats {
pub fn safety_beta(&self) -> (f64, f64) {
(self.alpha_safety + 1.0, self.beta_safety + 1.0)
}
pub fn safety_variance(&self) -> f64 {
let (a, b) = self.safety_beta();
let s = a + b;
(a * b) / (s * s * (s + 1.0))
}
pub fn update_safety(&mut self, correct: bool, early_wrong: bool) {
if correct && !early_wrong {
self.alpha_safety += 1.0;
} else {
self.beta_safety += 1.0;
if early_wrong {
self.beta_safety += 0.5;
}
}
}
pub fn update_cost(&mut self, normalized_steps: f64) {
if self.attempts <= 1 {
self.cost_ema = normalized_steps;
} else {
self.cost_ema = COST_EMA_ALPHA * normalized_steps
+ (1.0 - COST_EMA_ALPHA) * self.cost_ema;
}
}
pub fn reward(&self) -> f64 {
if self.attempts == 0 {
return 0.5;
}
let acc = self.successes as f64 / self.attempts as f64;
let cost = 0.3 * (1.0 - (self.total_steps as f64 / self.attempts as f64) / 200.0).max(0.0);
let penalty = 0.2 * (self.early_commit_penalty_sum / self.attempts as f64).min(1.0);
(acc * 0.5 + cost - penalty).max(0.0)
}
}
// ═════════════════════════════════════════════════════════════════════
// PolicyKernel
// ═════════════════════════════════════════════════════════════════════
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct PolicyKernel {
pub context_stats: BTreeMap<String, BTreeMap<String, SkipModeStats>>,
pub early_commit_penalties: f64,
pub early_commits_total: usize,
pub early_commits_wrong: usize,
pub prepass: PrepassMode,
pub speculative_attempts: usize,
pub speculative_arm2_wins: usize,
rng_state: u64,
}
impl PolicyKernel {
pub fn new() -> Self {
Self {
rng_state: 42,
..Default::default()
}
}
/// Mode A: fixed heuristic policy.
/// risk_score = R - 30*D, threshold 140.
const K: usize = 30;
const T: usize = 140;
pub fn fixed_policy(ctx: &PolicyContext) -> SkipMode {
if !ctx.has_day_of_week {
return SkipMode::None;
}
let eff = ctx.posterior_range.saturating_sub(Self::K * ctx.distractor_count);
if eff >= Self::T {
SkipMode::Weekday
} else {
SkipMode::None
}
}
/// Mode B: compiler-suggested policy.
pub fn compiled_policy(ctx: &PolicyContext, compiled: Option<SkipMode>) -> SkipMode {
compiled.unwrap_or_else(|| Self::fixed_policy(ctx))
}
/// Mode C: learned two-signal Thompson Sampling.
/// When no training data exists for a context bucket, defaults to
/// SkipMode::None (conservative linear scan). After training, the
/// learned policy discovers better skip modes — showing real cost
/// improvement between early and later cycles.
pub fn learned_policy(&mut self, ctx: &PolicyContext) -> SkipMode {
if !ctx.has_day_of_week {
return SkipMode::None;
}
let bucket = Self::context_bucket(ctx);
let modes = ["none", "weekday", "hybrid"];
// Conservative default: use None (linear scan) until the solver has
// accumulated enough training data. This ensures a meaningful baseline
// in early cycles that training can measurably improve upon.
{
let total_observations: usize = self.context_stats.values()
.flat_map(|m| m.values())
.map(|s| s.attempts)
.sum();
if total_observations < 100 {
return SkipMode::None;
}
}
let params: Vec<(SkipMode, f64, f64, f64)> = {
let map = self.context_stats.entry(bucket).or_default();
modes
.iter()
.map(|name| {
let s = map.get(*name).cloned().unwrap_or_default();
let (a, b) = s.safety_beta();
let mode = match *name {
"weekday" => SkipMode::Weekday,
"hybrid" => SkipMode::Hybrid,
_ => SkipMode::None,
};
(mode, a, b, s.cost_ema)
})
.collect()
};
let mut scored: Vec<(SkipMode, f64)> = params
.into_iter()
.map(|(mode, a, b, cost)| {
let sample = self.sample_beta(a, b);
(mode, sample - THOMPSON_LAMBDA * cost)
})
.collect();
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(core::cmp::Ordering::Equal));
scored.first().map(|(m, _)| m.clone()).unwrap_or(SkipMode::None)
}
/// Speculative dual-path check.
pub fn should_speculate(&mut self, ctx: &PolicyContext) -> Option<(SkipMode, SkipMode)> {
if !ctx.has_day_of_week || ctx.posterior_range < 61 {
return None;
}
let bucket = Self::context_bucket(ctx);
let modes = ["none", "weekday", "hybrid"];
let params: Vec<(SkipMode, f64, f64, f64, f64)> = {
let map = self.context_stats.entry(bucket).or_default();
modes
.iter()
.map(|name| {
let s = map.get(*name).cloned().unwrap_or_default();
let (a, b) = s.safety_beta();
let v = s.safety_variance();
let mode = match *name {
"weekday" => SkipMode::Weekday,
"hybrid" => SkipMode::Hybrid,
_ => SkipMode::None,
};
(mode, a, b, s.cost_ema, v)
})
.collect()
};
let mut scored: Vec<(SkipMode, f64, f64)> = params
.into_iter()
.map(|(mode, a, b, cost, var)| {
let sample = self.sample_beta(a, b);
(mode, sample - THOMPSON_LAMBDA * cost, var)
})
.collect();
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(core::cmp::Ordering::Equal));
if scored.len() >= 2 {
let (ref a1, s1, v1) = scored[0];
let (ref a2, s2, _) = scored[1];
if (s1 - s2).abs() < 0.15 && v1 > 0.02 {
return Some((a1.clone(), a2.clone()));
}
}
None
}
pub fn record_outcome(&mut self, ctx: &PolicyContext, outcome: &SkipOutcome) {
let bucket = Self::context_bucket(ctx);
let mode_name = outcome.mode.name();
let map = self.context_stats.entry(bucket).or_default();
let stats = map.entry(String::from(mode_name)).or_default();
stats.attempts += 1;
stats.total_steps += outcome.steps;
if outcome.correct {
stats.successes += 1;
}
stats.update_safety(outcome.correct, outcome.early_commit_wrong);
stats.update_cost((outcome.steps as f64 / 200.0).min(1.0));
if outcome.early_commit_wrong {
stats.early_commit_wrongs += 1;
self.early_commits_wrong += 1;
let penalty = if outcome.initial_candidates > 0 {
outcome.remaining_at_commit as f64 / outcome.initial_candidates as f64
} else {
1.0
};
self.early_commit_penalties += penalty;
stats.early_commit_penalty_sum += penalty;
}
self.early_commits_total += 1;
}
pub fn early_commit_rate(&self) -> f64 {
if self.early_commits_total == 0 {
0.0
} else {
self.early_commits_wrong as f64 / self.early_commits_total as f64
}
}
/// 3 range x 3 distractor x 2 noise = 18 buckets.
pub fn context_bucket(ctx: &PolicyContext) -> String {
let r = match ctx.posterior_range {
0..=60 => "small",
61..=180 => "medium",
_ => "large",
};
let d = match ctx.distractor_count {
0 => "clean",
1 => "some",
_ => "heavy",
};
let n = if ctx.noisy { "noisy" } else { "clean" };
format!("{}:{}:{}", r, d, n)
}
fn sample_beta(&mut self, alpha: f64, beta: f64) -> f64 {
let x = self.sample_gamma(alpha);
let y = self.sample_gamma(beta);
if x + y == 0.0 {
0.5
} else {
x / (x + y)
}
}
fn sample_gamma(&mut self, shape: f64) -> f64 {
if shape < 1.0 {
let u = self.next_f64().max(1e-10);
return self.sample_gamma(shape + 1.0) * pow(u, 1.0 / shape);
}
let d = shape - 1.0 / 3.0;
let c = 1.0 / sqrt(9.0 * d);
loop {
let x = self.next_normal();
let t = 1.0 + c * x;
let v = t * t * t;
if v <= 0.0 {
continue;
}
let u = self.next_f64().max(1e-10);
if u < 1.0 - 0.0331 * x * x * x * x {
return d * v;
}
if log(u) < 0.5 * x * x + d * (1.0 - v + log(v)) {
return d * v;
}
}
}
fn next_normal(&mut self) -> f64 {
let u1 = self.next_f64().max(1e-10);
let u2 = self.next_f64();
sqrt(-2.0 * log(u1)) * cos(2.0 * core::f64::consts::PI * u2)
}
fn next_f64(&mut self) -> f64 {
let mut x = self.rng_state.max(1);
x ^= x << 13;
x ^= x >> 7;
x ^= x << 17;
self.rng_state = x;
x as f64 / u64::MAX as f64
}
}
// ═════════════════════════════════════════════════════════════════════
// KnowledgeCompiler — constraint signature → compiled config
// ═════════════════════════════════════════════════════════════════════
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CompiledConfig {
pub max_steps: usize,
pub avg_steps: f64,
pub observations: usize,
pub expected_correct: bool,
pub stop_after_first: bool,
pub hit_count: usize,
pub counterexample_count: usize,
pub compiled_skip: SkipMode,
}
impl CompiledConfig {
pub fn confidence(&self) -> f64 {
let total = self.hit_count + self.counterexample_count;
if total == 0 {
0.5
} else {
(self.hit_count as f64 + 1.0) / (total as f64 + 2.0)
}
}
pub fn trial_budget(&self, external_limit: usize) -> usize {
let budget = if self.observations > 2 && self.avg_steps > 1.0 {
(self.avg_steps * 2.0) as usize
} else {
self.max_steps.max(10)
};
budget.max(10).min(external_limit / 4)
}
}
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct KnowledgeCompiler {
pub cache: BTreeMap<String, CompiledConfig>,
pub hits: usize,
pub misses: usize,
pub false_hits: usize,
pub steps_saved: i64,
pub confidence_threshold: f64,
}
impl KnowledgeCompiler {
pub fn new() -> Self {
Self {
confidence_threshold: 0.7,
..Default::default()
}
}
pub fn signature(puzzle: &Puzzle) -> String {
let mut parts: Vec<&str> = puzzle.constraints.iter().map(constraint_type_name).collect();
parts.sort();
format!("v1:{}:{}", puzzle.difficulty, parts.join(","))
}
pub fn lookup(&mut self, puzzle: &Puzzle) -> Option<&CompiledConfig> {
let sig = Self::signature(puzzle);
if self.cache.contains_key(&sig) {
self.hits += 1;
self.cache.get(&sig)
} else {
self.misses += 1;
None
}
}
pub fn record_success(&mut self, puzzle: &Puzzle, actual_steps: usize) {
let sig = Self::signature(puzzle);
if let Some(cfg) = self.cache.get_mut(&sig) {
cfg.hit_count += 1;
let est = if cfg.avg_steps > 0.0 {
(cfg.avg_steps * 2.0) as i64
} else {
cfg.max_steps as i64
};
self.steps_saved += est - actual_steps as i64;
}
}
pub fn record_failure(&mut self, puzzle: &Puzzle) {
self.false_hits += 1;
let sig = Self::signature(puzzle);
if let Some(cfg) = self.cache.get_mut(&sig) {
cfg.counterexample_count += 1;
if cfg.counterexample_count >= 2 {
cfg.expected_correct = false;
}
}
}
/// Compile knowledge from trajectories (simplified ReasoningBank integration).
pub fn compile_from_trajectories(&mut self, trajectories: &[(String, u8, Vec<&str>, usize, bool)]) {
for (_, difficulty, ctypes, steps, correct) in trajectories {
if !correct {
continue;
}
let mut parts = ctypes.clone();
parts.sort();
let sig = format!("v1:{}:{}", difficulty, parts.join(","));
let has_dow = parts.iter().any(|c| *c == "DayOfWeek");
let compiled_skip = if has_dow {
SkipMode::Weekday
} else {
SkipMode::None
};
let entry = self.cache.entry(sig).or_insert(CompiledConfig {
max_steps: *steps,
avg_steps: 0.0,
observations: 0,
expected_correct: true,
stop_after_first: true,
hit_count: 0,
counterexample_count: 0,
compiled_skip,
});
entry.max_steps = entry.max_steps.min(*steps);
let n = entry.observations as f64;
entry.avg_steps = (entry.avg_steps * n + *steps as f64) / (n + 1.0);
entry.observations += 1;
entry.hit_count = entry.observations;
}
}
}
/// Count distractor constraints in a puzzle.
pub fn count_distractors(puzzle: &Puzzle) -> usize {
let mut count = 0;
let (mut sb, mut sy, mut sd) = (false, false, false);
for c in &puzzle.constraints {
match c {
Constraint::Between(_, _) => {
if sb { count += 1; }
sb = true;
}
Constraint::InYear(_) => {
if sy { count += 1; }
sy = true;
}
Constraint::DayOfWeek(_) => {
if sd { count += 1; }
sd = true;
}
_ => {}
}
}
count
}