Files
wifi-densepose/vendor/ruvector/examples/benchmarks/src/reasoning_bank.rs

1314 lines
46 KiB
Rust

//! ReasoningBank - Adaptive Learning for Temporal Reasoning
//!
//! Implements trajectory tracking, verdict judgment, and strategy optimization
//! based on the lean-agentic design pattern.
//!
//! Key components:
//! - Trajectory tracking for solution attempts
//! - Verdict judgment (success/failure/suboptimal)
//! - Pattern learning from successful solutions
//! - Strategy optimization based on historical performance
//! - Confidence calibration from feedback
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Verdict for a solution trajectory.
///
/// `Success` and `Acceptable` are the two variants treated as a solved puzzle by
/// `Verdict::is_success`; every variant also maps to a numeric value via `Verdict::score`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum Verdict {
/// Solution was correct
Success,
/// Solution was acceptable but not optimal
Acceptable,
/// Solution was wrong but close.
/// `reason` is a free-form explanation; `delta` is subtracted from the base score
/// (the penalty is capped at 0.3).
Suboptimal { reason: String, delta: f64 },
/// Low confidence in solution
LowConfidence,
/// Complete failure
Failed,
}
impl Verdict {
    /// Returns `true` when the verdict counts as a solved puzzle,
    /// i.e. it is `Success` or `Acceptable`.
    pub fn is_success(&self) -> bool {
        matches!(self, Verdict::Success | Verdict::Acceptable)
    }

    /// Maps the verdict to a numeric quality score.
    ///
    /// | Variant         | Score                              |
    /// |-----------------|------------------------------------|
    /// | `Success`       | `1.0`                              |
    /// | `Acceptable`    | `0.8`                              |
    /// | `Suboptimal`    | `0.5 - min(|delta|, 0.3)` → `[0.2, 0.5]` |
    /// | `LowConfidence` | `0.3`                              |
    /// | `Failed`        | `0.0`                              |
    pub fn score(&self) -> f64 {
        match self {
            Verdict::Success => 1.0,
            Verdict::Acceptable => 0.8,
            // Penalise by the *magnitude* of the miss. A signed (negative) delta used to
            // be subtracted as-is, which pushed the score above 0.5 without bound and let
            // a wrong answer outrank `Acceptable`/`Success`. A NaN delta still yields the
            // full 0.3 penalty because `f64::min` ignores a NaN operand.
            Verdict::Suboptimal { delta, .. } => 0.5 - delta.abs().min(0.3),
            Verdict::LowConfidence => 0.3,
            Verdict::Failed => 0.0,
        }
    }
}
/// A single solution attempt recorded on a `Trajectory`.
///
/// `skip_mode` and `context_bucket` are left empty by `Trajectory::record_attempt`
/// and filled in by `Trajectory::record_attempt_witnessed`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SolutionAttempt {
/// The proposed solution
pub solution: String,
/// Confidence in this solution
pub confidence: f64,
/// Steps taken to reach this solution
pub steps: usize,
/// Tool calls made
pub tool_calls: usize,
/// Strategy used (name of the strategy, e.g. as accepted by `ReasoningBank::strategy_from_name`)
pub strategy: String,
/// Skip mode used (witness for policy audit: "none", "weekday", "hybrid")
pub skip_mode: String,
/// Context bucket key (witness for policy audit: "range:distractor")
pub context_bucket: String,
}
/// Trajectory tracking for a single puzzle: the attempts made, the final verdict,
/// and bookkeeping (difficulty, constraint types, latency).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Trajectory {
/// Puzzle identifier
pub puzzle_id: String,
/// Puzzle difficulty
pub difficulty: u8,
/// Constraint types encountered
pub constraint_types: Vec<String>,
/// Solution attempts made, in the order they were recorded
pub attempts: Vec<SolutionAttempt>,
/// Final verdict (`None` until `set_verdict` is called)
pub verdict: Option<Verdict>,
/// Correct solution (if known)
pub correct_solution: Option<String>,
/// Time taken in ms
pub latency_ms: u64,
}
impl Trajectory {
pub fn new(puzzle_id: &str, difficulty: u8) -> Self {
Self {
puzzle_id: puzzle_id.to_string(),
difficulty,
constraint_types: Vec::new(),
attempts: Vec::new(),
verdict: None,
correct_solution: None,
latency_ms: 0,
}
}
pub fn record_attempt(
&mut self,
solution: String,
confidence: f64,
steps: usize,
tool_calls: usize,
strategy: &str,
) {
self.attempts.push(SolutionAttempt {
solution,
confidence,
steps,
tool_calls,
strategy: strategy.to_string(),
skip_mode: String::new(),
context_bucket: String::new(),
});
}
/// Record attempt with full policy witness (skip_mode + context_bucket).
pub fn record_attempt_witnessed(
&mut self,
solution: String,
confidence: f64,
steps: usize,
tool_calls: usize,
strategy: &str,
skip_mode: &str,
context_bucket: &str,
) {
self.attempts.push(SolutionAttempt {
solution,
confidence,
steps,
tool_calls,
strategy: strategy.to_string(),
skip_mode: skip_mode.to_string(),
context_bucket: context_bucket.to_string(),
});
}
pub fn set_verdict(&mut self, verdict: Verdict, correct: Option<String>) {
self.verdict = Some(verdict);
self.correct_solution = correct;
}
}
/// Memory class for a learned pattern.
///
/// Patterns flow: Volatile → Trusted (requires counterexample) → Quarantined (on regression).
/// This three-class system prevents memory poisoning.
/// Newly learned patterns start as `Volatile`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum MemoryClass {
/// New pattern, low confidence, no external authority.
/// Cannot influence strategy selection unless no Trusted alternative exists.
Volatile,
/// Survived promotion: has enough observations AND at least one counterexample.
/// Used as primary strategy source.
Trusted,
/// Implicated in regressions or contradictions.
/// Cannot be selected by default; only used in explicit recovery mode.
Quarantined,
}
/// Learned pattern from successful solutions.
///
/// One pattern ties a constraint type and an inclusive difficulty range to the
/// strategy that solved puzzles in that range.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LearnedPattern {
/// Constraint type this pattern applies to
pub constraint_type: String,
/// Difficulty range as inclusive `(min, max)` bounds
pub difficulty_range: (u8, u8),
/// Best strategy for this pattern
pub best_strategy: String,
/// Success rate with this pattern (running average, 1.0 when first created)
pub success_rate: f64,
/// Average steps needed (running average over observations)
pub avg_steps: f64,
/// Number of observations
pub observations: usize,
/// Memory classification (Volatile → Trusted → Quarantined)
pub memory_class: MemoryClass,
/// Linked counterexample IDs (required for promotion to Trusted)
pub counterexample_ids: Vec<String>,
/// Number of holdout failures since promotion
pub holdout_failures: usize,
}
/// Strategy configuration: the search parameters associated with a named strategy.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Strategy {
/// Strategy name
pub name: String,
/// Whether to use calendar rewriting
pub use_rewriting: bool,
/// Maximum search steps
pub max_steps: usize,
/// Beam width for search
pub beam_width: usize,
/// Confidence threshold
pub confidence_threshold: f64,
}
impl Default for Strategy {
fn default() -> Self {
Self {
name: "default".to_string(),
use_rewriting: true,
max_steps: 50,
beam_width: 3,
confidence_threshold: 0.7,
}
}
}
/// Confidence calibration data: raw (confidence, outcome) observations plus the
/// per-difficulty thresholds derived from them.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct CalibrationData {
/// Reported confidence -> actual accuracy mapping, one `(confidence, was_correct)` pair per observation
pub calibration_points: Vec<(f64, bool)>,
/// Calibrated thresholds by difficulty (difficulties without an entry fall back to 0.7 in `get_threshold`)
pub thresholds: HashMap<u8, f64>,
}
impl CalibrationData {
    /// Stores one observation: the confidence that was reported and whether the
    /// answer turned out to be correct.
    pub fn record(&mut self, confidence: f64, correct: bool) {
        let point = (confidence, correct);
        self.calibration_points.push(point);
    }

    /// Returns the calibrated confidence threshold for `difficulty`,
    /// or `0.7` when none has been computed for it.
    pub fn get_threshold(&self, difficulty: u8) -> f64 {
        match self.thresholds.get(&difficulty) {
            Some(&threshold) => threshold,
            None => 0.7,
        }
    }

    /// Recomputes the thresholds from the recorded observations.
    ///
    /// Does nothing until at least 20 points exist. Observations are grouped into
    /// buckets of width 0.1 (`(confidence * 10) as u8`). Scanning buckets 10 down to 5,
    /// the first bucket with at least 5 observations and precision >= 0.9 supplies the
    /// threshold (its lower bound), which is then written for difficulties 1 through 10.
    /// If no bucket qualifies the existing thresholds are left untouched.
    pub fn recalibrate(&mut self) {
        if self.calibration_points.len() < 20 {
            return;
        }

        // bucket -> (observations, correct observations)
        let mut buckets: HashMap<u8, (usize, usize)> = HashMap::new();
        for &(conf, correct) in &self.calibration_points {
            let tally = buckets.entry((conf * 10.0) as u8).or_insert((0, 0));
            tally.0 += 1;
            tally.1 += usize::from(correct);
        }

        // Highest bucket first, so the strictest reliable threshold wins.
        let reliable = (5..=10u8).rev().find(|bucket| match buckets.get(bucket) {
            Some(&(total, hits)) => total >= 5 && hits as f64 / total as f64 >= 0.9,
            None => false,
        });

        if let Some(bucket) = reliable {
            let threshold = f64::from(bucket) / 10.0;
            for diff in 1..=10 {
                self.thresholds.insert(diff, threshold);
            }
        }
    }
}
/// Serializable memory checkpoint for rollback.
/// Captures the bank state at a point in time so bad learning can be undone.
/// Trajectories themselves are not copied; only their count is stored so the
/// trajectory list can be truncated back to that length.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct MemoryCheckpoint {
/// Unique checkpoint ID
pub id: usize,
/// Number of trajectories when checkpoint was taken
pub trajectories_len: usize,
/// Snapshot of learned patterns
pub patterns: HashMap<String, Vec<LearnedPattern>>,
/// Snapshot of strategy stats
pub strategy_stats: HashMap<String, StrategyStats>,
/// Snapshot of calibration data
pub calibration: CalibrationData,
/// Snapshot of best strategies
pub best_strategies: HashMap<u8, String>,
}
/// A trajectory held in quarantine pending review.
/// Entries are created by `ReasoningBank::quarantine_trajectory`, which starts
/// them with an empty `counterexample_ids` list.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct QuarantinedEntry {
/// The suspicious trajectory
pub trajectory: Trajectory,
/// Why it was quarantined
pub reason: String,
/// Counterexample IDs that contradict it
pub counterexample_ids: Vec<String>,
}
/// Counterexample: evidence of a pattern's failure mode.
/// Promotion to Trusted requires at least one of these linked to the pattern.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Counterexample {
/// Unique ID (overwritten with a generated `ce_<n>` value by
/// `ReasoningBank::record_structured_counterexample`)
pub id: String,
/// Constraint category signature (e.g. "d5_c3")
pub input_signature: String,
/// Type of failure
pub failure_type: String,
/// Seed that reproduces this failure
pub reproduction_seed: Option<u64>,
/// Expected outcome
pub expected: String,
/// Actual (wrong) outcome
pub observed: String,
/// Linked pattern ID: either a constraint type or `<constraint_type>_<best_strategy>`
pub pattern_id: Option<String>,
/// Grader verdict
pub verdict: Option<Verdict>,
}
/// Witness of a rollback event — proves rollback happened and what was discarded.
/// Appended to `ReasoningBank::rollback_witnesses` by `rollback_with_witness`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RollbackWitness {
/// Checkpoint restored to
pub checkpoint_id: usize,
/// Number of trajectories discarded
pub trajectories_discarded: usize,
/// Patterns that changed (pattern count before the rollback minus the count after, floored at zero)
pub patterns_restored: usize,
/// Reason for rollback
pub reason: String,
/// Accuracy before rollback (supplied by the caller)
pub pre_accuracy: f64,
/// Accuracy after rollback (supplied by the caller)
pub post_accuracy: f64,
}
/// ReasoningBank - Central learning and adaptation system.
///
/// Holds the recorded trajectories, everything learned from them (patterns,
/// strategy statistics, calibration), and the safety machinery around that
/// learning (checkpoints, quarantine, counterexamples, rollback witnesses).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ReasoningBank {
/// All recorded trajectories
pub trajectories: Vec<Trajectory>,
/// Learned patterns by constraint type
pub patterns: HashMap<String, Vec<LearnedPattern>>,
/// Strategy performance by name
pub strategy_stats: HashMap<String, StrategyStats>,
/// Confidence calibration data
pub calibration: CalibrationData,
/// Current best strategy by difficulty
pub best_strategies: HashMap<u8, String>,
/// Checkpoint stack for rollback (newest last)
pub checkpoints: Vec<MemoryCheckpoint>,
/// Quarantine pool: trajectories under review
pub quarantine: Vec<QuarantinedEntry>,
/// Counterexamples per constraint type (legacy: raw trajectories)
pub counterexamples: HashMap<String, Vec<Trajectory>>,
/// Structured counterexamples with full evidence chain
pub structured_counterexamples: Vec<Counterexample>,
/// Rollback witness chain — proof that rollbacks occurred
pub rollback_witnesses: Vec<RollbackWitness>,
/// Minimum observations to promote a pattern (evidence threshold)
pub evidence_threshold: usize,
/// Next checkpoint ID
checkpoint_counter: usize,
/// Next counterexample ID
counterexample_counter: usize,
/// Pattern index for O(1) lookups: (constraint_type, difficulty) -> pattern_idx.
/// Not serialized, so it is empty after deserialization until it is repopulated.
#[serde(skip)]
pattern_index: HashMap<(String, u8), usize>,
/// Constraint type frequency for prioritization.
/// Not serialized.
#[serde(skip)]
constraint_frequency: HashMap<String, usize>,
}
impl Default for ReasoningBank {
fn default() -> Self {
Self {
trajectories: Vec::new(),
patterns: HashMap::new(),
strategy_stats: HashMap::new(),
calibration: CalibrationData::default(),
best_strategies: HashMap::new(),
checkpoints: Vec::new(),
quarantine: Vec::new(),
counterexamples: HashMap::new(),
structured_counterexamples: Vec::new(),
rollback_witnesses: Vec::new(),
evidence_threshold: 5,
checkpoint_counter: 0,
counterexample_counter: 0,
pattern_index: HashMap::new(),
constraint_frequency: HashMap::new(),
}
}
}
/// Statistics for a strategy, accumulated once per recorded trajectory.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct StrategyStats {
/// Number of trajectories recorded for this strategy
pub attempts: usize,
/// How many of those trajectories ended with a successful verdict
pub successes: usize,
/// Sum of the step counts across those trajectories
pub total_steps: usize,
/// Sum of trajectory latencies in milliseconds
pub total_latency_ms: u64,
}
impl StrategyStats {
    /// Fraction of attempts that succeeded; `0.5` (an uninformative prior)
    /// when nothing has been recorded yet.
    pub fn success_rate(&self) -> f64 {
        match self.attempts {
            0 => 0.5, // Prior
            n => self.successes as f64 / n as f64,
        }
    }

    /// Mean number of steps per attempt; `50.0` when nothing has been recorded yet.
    pub fn avg_steps(&self) -> f64 {
        match self.attempts {
            0 => 50.0,
            n => self.total_steps as f64 / n as f64,
        }
    }
}
impl ReasoningBank {
/// Create a new ReasoningBank
pub fn new() -> Self {
Self::default()
}
/// Record a completed trajectory
pub fn record_trajectory(&mut self, trajectory: Trajectory) {
// Update strategy stats
if let Some(attempt) = trajectory.attempts.first() {
let stats = self
.strategy_stats
.entry(attempt.strategy.clone())
.or_default();
stats.attempts += 1;
stats.total_steps += attempt.steps;
stats.total_latency_ms += trajectory.latency_ms;
if trajectory
.verdict
.as_ref()
.map(|v| v.is_success())
.unwrap_or(false)
{
stats.successes += 1;
}
}
// Update calibration
if let Some(attempt) = trajectory.attempts.first() {
let correct = trajectory
.verdict
.as_ref()
.map(|v| v.is_success())
.unwrap_or(false);
self.calibration.record(attempt.confidence, correct);
}
// Learn patterns from successful trajectories
if trajectory
.verdict
.as_ref()
.map(|v| v.is_success())
.unwrap_or(false)
{
self.learn_from_success(&trajectory);
}
// Update best strategies
self.update_best_strategies();
// Store trajectory
self.trajectories.push(trajectory);
// Recalibrate periodically
if self.trajectories.len() % 50 == 0 {
self.calibration.recalibrate();
}
}
/// Folds a successful trajectory into the learned patterns.
///
/// Uses the first attempt's strategy and step count. For every constraint type on
/// the trajectory, the frequency counter is bumped and then either
/// * the first existing pattern with the same strategy whose difficulty range
///   contains the trajectory's difficulty gets its running averages updated, or
/// * a new `Volatile` pattern covering `difficulty ± 2` (saturating) is created.
/// The `(constraint_type, difficulty)` index is kept pointing at the touched pattern.
/// Trajectories without attempts are ignored.
fn learn_from_success(&mut self, trajectory: &Trajectory) {
    let attempt = match trajectory.attempts.first() {
        Some(first) => first,
        None => return,
    };
    let difficulty = trajectory.difficulty;
    let steps = attempt.steps as f64;

    for constraint_type in &trajectory.constraint_types {
        // Track how often each constraint type shows up in successes.
        *self
            .constraint_frequency
            .entry(constraint_type.clone())
            .or_insert(0) += 1;

        let bucket = self.patterns.entry(constraint_type.clone()).or_default();
        let existing = bucket.iter().position(|p| {
            p.best_strategy == attempt.strategy
                && (p.difficulty_range.0..=p.difficulty_range.1).contains(&difficulty)
        });

        match existing {
            Some(idx) => {
                // Incremental mean update with this success counted as 1.0.
                let pattern = &mut bucket[idx];
                let seen = pattern.observations as f64;
                pattern.success_rate = (pattern.success_rate * seen + 1.0) / (seen + 1.0);
                pattern.avg_steps = (pattern.avg_steps * seen + steps) / (seen + 1.0);
                pattern.observations += 1;

                self.pattern_index
                    .insert((constraint_type.clone(), difficulty), idx);
            }
            None => {
                let lo = difficulty.saturating_sub(2);
                let hi = difficulty.saturating_add(2);
                let new_idx = bucket.len();
                bucket.push(LearnedPattern {
                    constraint_type: constraint_type.clone(),
                    difficulty_range: (lo, hi),
                    best_strategy: attempt.strategy.clone(),
                    success_rate: 1.0,
                    avg_steps: steps,
                    observations: 1,
                    memory_class: MemoryClass::Volatile,
                    counterexample_ids: Vec::new(),
                    holdout_failures: 0,
                });

                // Make the new pattern reachable for every difficulty it covers.
                for d in lo..=hi {
                    self.pattern_index
                        .insert((constraint_type.clone(), d), new_idx);
                }
            }
        }
    }
}
/// Records each trajectory of a batch, in order, via `record_trajectory`.
pub fn record_trajectories_batch(&mut self, trajectories: Vec<Trajectory>) {
    trajectories
        .into_iter()
        .for_each(|trajectory| self.record_trajectory(trajectory));
}
/// Recomputes the best strategy and stores it for difficulties 1 through 10.
///
/// Every strategy with recorded statistics is scored as
/// `success_rate - avg_steps / 100`. Strategies that are the `best_strategy` of any
/// `Quarantined` pattern are excluded. The highest strictly positive score wins;
/// if no strategy scores above zero the name `"default"` is used.
///
/// The score does not depend on difficulty, so the winner is computed once and
/// written to every difficulty slot. Exact ties are broken by the lexicographically
/// smallest name so the result does not depend on `HashMap` iteration order.
fn update_best_strategies(&mut self) {
    // Names of strategies backing at least one quarantined pattern.
    let quarantined: std::collections::HashSet<&str> = self
        .patterns
        .values()
        .flatten()
        .filter(|p| p.memory_class == MemoryClass::Quarantined)
        .map(|p| p.best_strategy.as_str())
        .collect();

    let mut best: Option<(&str, f64)> = None;
    for (name, stats) in &self.strategy_stats {
        if quarantined.contains(name.as_str()) {
            continue;
        }
        // Score = success rate minus a penalty for long searches.
        let score = stats.success_rate() - (stats.avg_steps() / 100.0);
        let is_better = match best {
            // A candidate must beat the implicit "default" baseline of 0.0.
            None => score > 0.0,
            Some((best_name, best_score)) => {
                score > best_score || (score == best_score && name.as_str() < best_name)
            }
        };
        if is_better {
            best = Some((name.as_str(), score));
        }
    }

    let winner = best.map_or("default", |(name, _)| name).to_string();
    for difficulty in 1..=10 {
        self.best_strategies.insert(difficulty, winner.clone());
    }
}
/// Get recommended strategy for a puzzle (optimized with index).
/// Prefers Trusted patterns. Falls back to Volatile. Never uses Quarantined.
pub fn get_strategy(&self, difficulty: u8, constraint_types: &[String]) -> Strategy {
// Phase 1: Look for Trusted patterns (fast path via index)
for ct in constraint_types {
if let Some(&idx) = self.pattern_index.get(&(ct.clone(), difficulty)) {
if let Some(patterns) = self.patterns.get(ct) {
if let Some(pattern) = patterns.get(idx) {
if pattern.memory_class == MemoryClass::Trusted
&& pattern.success_rate > 0.7
&& pattern.observations >= 3
{
return self.strategy_from_name(&pattern.best_strategy, difficulty);
}
}
}
}
}
// Phase 2: Linear search for Trusted patterns
for ct in constraint_types {
if let Some(patterns) = self.patterns.get(ct) {
let best = patterns
.iter()
.filter(|p| {
p.memory_class == MemoryClass::Trusted
&& difficulty >= p.difficulty_range.0
&& difficulty <= p.difficulty_range.1
})
.max_by(|a, b| a.success_rate.partial_cmp(&b.success_rate).unwrap());
if let Some(pattern) = best {
if pattern.success_rate > 0.7 && pattern.observations >= 3 {
return self.strategy_from_name(&pattern.best_strategy, difficulty);
}
}
}
}
// Phase 3: Fall back to Volatile patterns (not yet promoted)
for ct in constraint_types {
if let Some(patterns) = self.patterns.get(ct) {
let best = patterns
.iter()
.filter(|p| {
p.memory_class != MemoryClass::Quarantined
&& difficulty >= p.difficulty_range.0
&& difficulty <= p.difficulty_range.1
&& p.success_rate > 0.7
&& p.observations >= 3
})
.max_by(|a, b| a.success_rate.partial_cmp(&b.success_rate).unwrap());
if let Some(pattern) = best {
return self.strategy_from_name(&pattern.best_strategy, difficulty);
}
}
}
// Phase 4: Fall back to global best strategy for difficulty
let strategy_name = self
.best_strategies
.get(&difficulty)
.cloned()
.unwrap_or_else(|| "default".to_string());
self.strategy_from_name(&strategy_name, difficulty)
}
/// Builds the `Strategy` configuration for a strategy name.
///
/// * `"aggressive"`: 30 steps, beam 5, threshold 0.6
/// * `"conservative"`: 100 steps, beam 2, threshold 0.85
/// * `"adaptive"`: `50 + 5 * difficulty` steps, beam 3, threshold taken from the
///   calibration data for `difficulty`
/// * anything else: `Strategy::default()`
pub fn strategy_from_name(&self, name: &str, difficulty: u8) -> Strategy {
    let (label, max_steps, beam_width, confidence_threshold) = match name {
        "aggressive" => ("aggressive", 30, 5, 0.6),
        "conservative" => ("conservative", 100, 2, 0.85),
        "adaptive" => (
            "adaptive",
            50 + (difficulty as usize * 5),
            3,
            self.calibration.get_threshold(difficulty),
        ),
        _ => return Strategy::default(),
    };

    Strategy {
        name: label.to_string(),
        use_rewriting: true,
        max_steps,
        beam_width,
        confidence_threshold,
    }
}
/// Produces human-readable hints from learned patterns.
///
/// For each requested constraint type (in the given order), every pattern with at
/// least 5 observations yields one sentence naming its strategy and success rate.
pub fn get_hints(&self, constraint_types: &[String]) -> Vec<String> {
    constraint_types
        .iter()
        .filter_map(|ct| self.patterns.get(ct).map(|patterns| (ct, patterns)))
        .flat_map(|(ct, patterns)| {
            patterns
                .iter()
                .filter(|pattern| pattern.observations >= 5)
                .map(move |pattern| {
                    format!(
                        "For {} constraints, {} strategy has {:.0}% success",
                        ct,
                        pattern.best_strategy,
                        pattern.success_rate * 100.0
                    )
                })
        })
        .collect()
}
/// Calculate learning progress metrics
pub fn learning_progress(&self) -> LearningProgress {
let total = self.trajectories.len();
if total == 0 {
return LearningProgress::default();
}
let successes = self
.trajectories
.iter()
.filter(|t| t.verdict.as_ref().map(|v| v.is_success()).unwrap_or(false))
.count();
// Calculate improvement over time (compare first half vs second half)
let half = total / 2;
let first_half_success = self.trajectories[..half]
.iter()
.filter(|t| t.verdict.as_ref().map(|v| v.is_success()).unwrap_or(false))
.count() as f64
/ half as f64;
let second_half_success = self.trajectories[half..]
.iter()
.filter(|t| t.verdict.as_ref().map(|v| v.is_success()).unwrap_or(false))
.count() as f64
/ (total - half) as f64;
let improvement = second_half_success - first_half_success;
// Calculate pattern coverage
let unique_patterns: usize = self.patterns.values().map(|ps| ps.len()).sum();
LearningProgress {
total_trajectories: total,
success_rate: successes as f64 / total as f64,
improvement_rate: improvement,
patterns_learned: unique_patterns,
strategies_tried: self.strategy_stats.len(),
is_improving: improvement > 0.0,
}
}
// ═══════════════════════════════════════════════════════════════════
// Memory Checkpoint & Rollback (Caveat 2: memory poisoning defense)
// ═══════════════════════════════════════════════════════════════════
/// Snapshots the current learned state and returns the new checkpoint's ID.
///
/// The snapshot holds the trajectory count plus copies of the patterns, strategy
/// statistics, calibration data and best-strategy table. Pass the returned ID to
/// `rollback_to(id)` to restore this state if subsequent learning is bad.
pub fn checkpoint(&mut self) -> usize {
    let id = self.checkpoint_counter;
    let snapshot = MemoryCheckpoint {
        id,
        trajectories_len: self.trajectories.len(),
        patterns: self.patterns.clone(),
        strategy_stats: self.strategy_stats.clone(),
        calibration: self.calibration.clone(),
        best_strategies: self.best_strategies.clone(),
    };
    self.checkpoints.push(snapshot);
    self.checkpoint_counter = id + 1;
    id
}
/// Roll back to a previous checkpoint. Returns true if the checkpoint was found.
/// Discards all learning after the checkpoint: trajectories are truncated,
/// patterns/strategies/calibration are restored from the snapshot.
pub fn rollback_to(&mut self, checkpoint_id: usize) -> bool {
let pos = self.checkpoints.iter().position(|c| c.id == checkpoint_id);
if let Some(idx) = pos {
let cp = self.checkpoints[idx].clone();
// Truncate trajectories to checkpoint length
self.trajectories.truncate(cp.trajectories_len);
// Restore learned state
self.patterns = cp.patterns;
self.strategy_stats = cp.strategy_stats;
self.calibration = cp.calibration;
self.best_strategies = cp.best_strategies;
// Discard checkpoints after this one
self.checkpoints.truncate(idx + 1);
// Rebuild index
self.rebuild_pattern_index();
true
} else {
false
}
}
/// How many checkpoints are currently held on the checkpoint stack.
pub fn checkpoint_count(&self) -> usize {
    let stack = &self.checkpoints;
    stack.len()
}
/// Recreates the `(constraint_type, difficulty) -> pattern position` index from the
/// current patterns (used after a rollback replaces them).
///
/// Every difficulty in a pattern's inclusive range is mapped to that pattern; when
/// ranges overlap, the pattern stored later in the list wins.
fn rebuild_pattern_index(&mut self) {
    self.pattern_index.clear();
    for (ct, patterns) in &self.patterns {
        for (idx, pattern) in patterns.iter().enumerate() {
            let (lo, hi) = pattern.difficulty_range;
            self.pattern_index
                .extend((lo..=hi).map(|d| ((ct.clone(), d), idx)));
        }
    }
}
// ═══════════════════════════════════════════════════════════════════
// Quarantine & Counterexamples (evidence binding)
// ═══════════════════════════════════════════════════════════════════
/// Places a trajectory in the quarantine pool together with the reason.
///
/// Nothing else is touched: patterns, strategy statistics and the main trajectory
/// list are not updated. The entry starts without linked counterexample IDs.
pub fn quarantine_trajectory(&mut self, trajectory: Trajectory, reason: &str) {
    let entry = QuarantinedEntry {
        reason: reason.to_owned(),
        counterexample_ids: Vec::new(),
        trajectory,
    };
    self.quarantine.push(entry);
}
/// Stores a trajectory as a counterexample for `constraint_type`.
///
/// Once the number of counterexamples for that constraint type reaches
/// `evidence_threshold`, every pattern of the type that has fewer observations than
/// the threshold is demoted: its success rate is halved and its observation count
/// is reset to zero.
pub fn record_counterexample(&mut self, constraint_type: &str, trajectory: Trajectory) {
    let threshold = self.evidence_threshold;

    let recorded = {
        let examples = self
            .counterexamples
            .entry(constraint_type.to_owned())
            .or_default();
        examples.push(trajectory);
        examples.len()
    };
    if recorded < threshold {
        return;
    }

    // Weak patterns contradicted by enough evidence are demoted.
    if let Some(patterns) = self.patterns.get_mut(constraint_type) {
        patterns
            .iter_mut()
            .filter(|pattern| pattern.observations < threshold)
            .for_each(|pattern| {
                pattern.success_rate *= 0.5;
                pattern.observations = 0;
            });
    }
}
/// Reports whether some pattern for `constraint_type` covering `difficulty` has
/// enough evidence behind it.
///
/// A pattern qualifies when it has at least `evidence_threshold` observations,
/// strictly more observations than there are recorded counterexamples for the
/// constraint type, and a success rate above 0.7. This check looks only at those
/// numbers: it neither inspects `memory_class` nor requires a counterexample to exist.
pub fn is_pattern_promoted(&self, constraint_type: &str, difficulty: u8) -> bool {
    let counter_count = self
        .counterexamples
        .get(constraint_type)
        .map_or(0, |examples| examples.len());

    self.patterns
        .get(constraint_type)
        .map_or(false, |patterns| {
            patterns.iter().any(|pattern| {
                let (lo, hi) = pattern.difficulty_range;
                difficulty >= lo
                    && difficulty <= hi
                    && pattern.observations >= self.evidence_threshold
                    && pattern.observations > counter_count
                    && pattern.success_rate > 0.7
            })
        })
}
/// Promotes every eligible `Volatile` pattern to `Trusted` and returns how many
/// were promoted.
///
/// "Counterexample-first" rule: a constraint type with no recorded counterexamples
/// cannot have any pattern promoted. Beyond that, a pattern needs at least
/// `evidence_threshold` observations, more observations than counterexamples, and a
/// success rate above 0.7.
pub fn promote_patterns(&mut self) -> usize {
    let threshold = self.evidence_threshold;
    let mut promoted = 0;

    for (ct, patterns) in self.patterns.iter_mut() {
        let counter_count = self.counterexamples.get(ct).map_or(0, |v| v.len());
        if counter_count == 0 {
            // No counterexample on record: nothing of this type may be promoted.
            continue;
        }

        let eligible = patterns.iter_mut().filter(|p| {
            p.memory_class == MemoryClass::Volatile
                && p.observations >= threshold
                && p.success_rate > 0.7
                && p.observations > counter_count
        });
        for pattern in eligible {
            pattern.memory_class = MemoryClass::Trusted;
            promoted += 1;
        }
    }

    promoted
}
/// Moves every `Trusted` pattern of `constraint_type` that covers `difficulty` to
/// `Quarantined`, incrementing its holdout-failure counter.
///
/// Returns `true` if at least one pattern was demoted; in that case the
/// best-strategy table is refreshed so quarantined strategies are excluded.
pub fn demote_to_quarantined(&mut self, constraint_type: &str, difficulty: u8) -> bool {
    let mut demoted = 0usize;

    if let Some(patterns) = self.patterns.get_mut(constraint_type) {
        let affected = patterns.iter_mut().filter(|p| {
            p.memory_class == MemoryClass::Trusted
                && difficulty >= p.difficulty_range.0
                && difficulty <= p.difficulty_range.1
        });
        for pattern in affected {
            pattern.memory_class = MemoryClass::Quarantined;
            pattern.holdout_failures += 1;
            demoted += 1;
        }
    }

    let found = demoted > 0;
    if found {
        self.update_best_strategies();
    }
    found
}
/// Record a structured counterexample with full evidence chain.
pub fn record_structured_counterexample(&mut self, ce: Counterexample) -> String {
let id = format!("ce_{}", self.counterexample_counter);
self.counterexample_counter += 1;
let mut ce = ce;
ce.id = id.clone();
// Link to pattern if pattern_id provided
if let Some(ref pid) = ce.pattern_id {
// Find and link
for patterns in self.patterns.values_mut() {
for pattern in patterns.iter_mut() {
if pattern.constraint_type == *pid
|| format!("{}_{}", pattern.constraint_type, pattern.best_strategy) == *pid
{
pattern.counterexample_ids.push(id.clone());
}
}
}
}
self.structured_counterexamples.push(ce);
id
}
/// Rolls back to `checkpoint_id` and, on success, appends a `RollbackWitness`
/// describing what was discarded.
///
/// The witness records how many trajectories were dropped, how far the total
/// pattern count shrank (floored at zero), the given reason, and the caller-supplied
/// accuracies. Returns `false` without recording anything if the checkpoint does
/// not exist.
pub fn rollback_with_witness(
    &mut self,
    checkpoint_id: usize,
    reason: &str,
    pre_accuracy: f64,
    post_accuracy: f64,
) -> bool {
    let trajectories_before = self.trajectories.len();
    let patterns_before: usize = self.patterns.values().map(Vec::len).sum();

    if !self.rollback_to(checkpoint_id) {
        return false;
    }

    let trajectories_after = self.trajectories.len();
    let patterns_after: usize = self.patterns.values().map(Vec::len).sum();

    let witness = RollbackWitness {
        checkpoint_id,
        trajectories_discarded: trajectories_before.saturating_sub(trajectories_after),
        patterns_restored: patterns_before.saturating_sub(patterns_after),
        reason: reason.to_string(),
        pre_accuracy,
        post_accuracy,
    };
    self.rollback_witnesses.push(witness);
    true
}
/// Number of learned patterns, across all constraint types, classified `Volatile`.
pub fn volatile_count(&self) -> usize {
    self.patterns
        .values()
        .flatten()
        .filter(|pattern| matches!(pattern.memory_class, MemoryClass::Volatile))
        .count()
}
/// Number of learned patterns, across all constraint types, classified `Trusted`.
pub fn trusted_count(&self) -> usize {
    self.patterns
        .values()
        .flatten()
        .filter(|pattern| matches!(pattern.memory_class, MemoryClass::Trusted))
        .count()
}
/// Number of learned patterns, across all constraint types, classified `Quarantined`.
pub fn quarantined_pattern_count(&self) -> usize {
    self.patterns
        .values()
        .flatten()
        .filter(|pattern| matches!(pattern.memory_class, MemoryClass::Quarantined))
        .count()
}
/// Records a trajectory, diverting contradictions to quarantine.
///
/// A trajectory is a contradiction when its verdict is missing or not successful
/// *and* at least one attempt proposed a solution (non-empty and not `"none"`).
/// Contradictions are stored as a counterexample for each of their constraint types
/// and then quarantined; they are not learned from. Everything else goes through
/// `record_trajectory`.
pub fn record_trajectory_gated(&mut self, trajectory: Trajectory) {
    let not_solved = !trajectory
        .verdict
        .as_ref()
        .map_or(false, |verdict| verdict.is_success());
    let is_contradiction = not_solved
        && trajectory
            .attempts
            .iter()
            .any(|attempt| !attempt.solution.is_empty() && attempt.solution != "none");

    if !is_contradiction {
        self.record_trajectory(trajectory);
        return;
    }

    // Keep the evidence, but do not let it influence patterns or statistics.
    for ct in &trajectory.constraint_types {
        self.record_counterexample(ct, trajectory.clone());
    }
    self.quarantine_trajectory(trajectory, "contradiction: solved but incorrect");
}
/// Number of trajectories currently held in the quarantine pool.
pub fn quarantine_len(&self) -> usize {
    let pool = &self.quarantine;
    pool.len()
}
/// Total number of (legacy, raw-trajectory) counterexamples summed over all
/// constraint types.
pub fn counterexample_count(&self) -> usize {
    self.counterexamples.values().map(Vec::len).sum()
}
}
/// Learning progress summary, as produced by `ReasoningBank::learning_progress`.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct LearningProgress {
/// Number of trajectories recorded
pub total_trajectories: usize,
/// Fraction of recorded trajectories with a successful verdict
pub success_rate: f64,
/// Success rate of the second half of the trajectories minus that of the first half
pub improvement_rate: f64,
/// Total number of learned patterns across all constraint types
pub patterns_learned: usize,
/// Number of distinct strategies with recorded statistics
pub strategies_tried: usize,
/// `true` when `improvement_rate` is greater than zero
pub is_improving: bool,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_reasoning_bank_learning() {
    let mut bank = ReasoningBank::new();

    // Ten identical successes, all solved with the "adaptive" strategy.
    (0..10).for_each(|i| {
        let puzzle_id = format!("puzzle_{}", i);
        let mut traj = Trajectory::new(&puzzle_id, 5);
        traj.constraint_types.push("RelativeDate".to_string());
        traj.record_attempt("2024-01-15".to_string(), 0.8, 20, 5, "adaptive");
        traj.set_verdict(Verdict::Success, Some("2024-01-15".to_string()));
        traj.latency_ms = 100;
        bank.record_trajectory(traj);
    });

    // A pattern and per-strategy statistics must now exist.
    assert!(bank.patterns.contains_key("RelativeDate"));
    assert!(bank.strategy_stats.contains_key("adaptive"));

    let adaptive = &bank.strategy_stats["adaptive"];
    assert_eq!(adaptive.successes, 10);
    assert!(adaptive.success_rate() > 0.9);
}
#[test]
fn test_strategy_selection() {
    let mut bank = ReasoningBank::new();

    // Twenty easy (difficulty 3) puzzles, each solved with "aggressive".
    (0..20).for_each(|i| {
        let puzzle_id = format!("easy_{}", i);
        let mut traj = Trajectory::new(&puzzle_id, 3);
        traj.constraint_types.push("Before".to_string());
        traj.record_attempt("2024-01-10".to_string(), 0.9, 10, 2, "aggressive");
        traj.set_verdict(Verdict::Success, Some("2024-01-10".to_string()));
        bank.record_trajectory(traj);
    });

    // The learned pattern should drive the recommendation for the same kind of puzzle.
    let recommended = bank.get_strategy(3, &["Before".to_string()]);
    assert_eq!(recommended.name, "aggressive");
}
#[test]
fn test_checkpoint_and_rollback() {
    let mut bank = ReasoningBank::new();

    // Phase 1: five good trajectories.
    for i in 0..5 {
        let puzzle_id = format!("good_{}", i);
        let mut traj = Trajectory::new(&puzzle_id, 5);
        traj.constraint_types.push("Before".to_string());
        traj.record_attempt("2024-01-10".into(), 0.9, 10, 1, "default");
        traj.set_verdict(Verdict::Success, None);
        bank.record_trajectory(traj);
    }
    assert_eq!(bank.trajectories.len(), 5);

    // Snapshot the healthy state.
    let cp_id = bank.checkpoint();
    assert_eq!(bank.checkpoint_count(), 1);

    // Phase 2: five more trajectories that represent bad learning.
    for i in 0..5 {
        let puzzle_id = format!("bad_{}", i);
        let mut traj = Trajectory::new(&puzzle_id, 5);
        traj.constraint_types.push("Before".to_string());
        traj.record_attempt("wrong".into(), 0.1, 50, 1, "default");
        traj.set_verdict(Verdict::Failed, None);
        bank.record_trajectory(traj);
    }
    assert_eq!(bank.trajectories.len(), 10);

    // Rolling back must restore the pre-checkpoint trajectory list.
    let restored = bank.rollback_to(cp_id);
    assert!(restored);
    assert_eq!(bank.trajectories.len(), 5);

    // Only the good trajectories may remain.
    let only_good = bank
        .trajectories
        .iter()
        .all(|t| t.puzzle_id.starts_with("good_"));
    assert!(only_good);
}
#[test]
fn test_quarantine_gating() {
    let mut bank = ReasoningBank::new();

    // A contradiction: an answer was proposed, yet the verdict is Failed.
    let contradiction = {
        let mut traj = Trajectory::new("contra_1", 5);
        traj.constraint_types.push("InMonth".to_string());
        traj.record_attempt("2024-06-15".into(), 0.5, 20, 1, "default");
        traj.set_verdict(Verdict::Failed, None);
        traj
    };
    bank.record_trajectory_gated(contradiction);

    // It must land in quarantine (and as a counterexample), not in the main list.
    assert_eq!(bank.quarantine_len(), 1);
    assert_eq!(bank.trajectories.len(), 0);
    assert_eq!(bank.counterexample_count(), 1);
}
#[test]
fn test_evidence_binding() {
    let mut bank = ReasoningBank::new();
    bank.evidence_threshold = 5;

    // Feeds successful "DayOfWeek" trajectories with the given index range.
    let mut feed = |bank: &mut ReasoningBank, range: std::ops::Range<usize>| {
        for i in range {
            let puzzle_id = format!("ev_{}", i);
            let mut traj = Trajectory::new(&puzzle_id, 5);
            traj.constraint_types.push("DayOfWeek".to_string());
            traj.record_attempt("2024-01-15".into(), 0.9, 10, 1, "default");
            traj.set_verdict(Verdict::Success, None);
            bank.record_trajectory(traj);
        }
    };

    // Three successes: still below the threshold of five observations.
    feed(&mut bank, 0..3);
    assert!(!bank.is_pattern_promoted("DayOfWeek", 5));

    // Three more (six in total): the evidence threshold is now met.
    feed(&mut bank, 3..6);
    assert!(bank.is_pattern_promoted("DayOfWeek", 5));
}
/// Counterexamples recorded against a weakly supported pattern are expected to
/// lower that pattern's stored success rate.
#[test]
fn test_counterexample_demotion() {
    let mut bank = ReasoningBank::new();
    bank.evidence_threshold = 3;

    // Seed a weak "Between" pattern: 2 successes, below the threshold of 3.
    for n in 0..2 {
        let puzzle_id = format!("weak_{}", n);
        let mut win = Trajectory::new(&puzzle_id, 5);
        win.constraint_types.push(String::from("Between"));
        win.record_attempt("2024-01-10".into(), 0.8, 15, 1, "default");
        win.set_verdict(Verdict::Success, None);
        bank.record_trajectory(win);
    }

    // Baseline success rate of the first "Between" pattern; a missing pattern
    // reads as 0.0 so the following assertion fails loudly.
    let orig_rate = bank
        .patterns
        .get("Between")
        .and_then(|ps| ps.first())
        .map_or(0.0, |p| p.success_rate);
    assert!(orig_rate > 0.9);

    // Feed in 3 counterexamples, which meets the evidence threshold.
    for n in 0..3 {
        let puzzle_id = format!("counter_{}", n);
        let mut miss = Trajectory::new(&puzzle_id, 5);
        miss.constraint_types.push(String::from("Between"));
        miss.record_attempt("wrong".into(), 0.2, 30, 1, "default");
        miss.set_verdict(Verdict::Failed, None);
        bank.record_counterexample("Between", miss);
    }

    // The rate must have dropped; a missing pattern reads as 1.0 here so it
    // cannot satisfy the comparison by accident.
    let new_rate = bank
        .patterns
        .get("Between")
        .and_then(|ps| ps.first())
        .map_or(1.0, |p| p.success_rate);
    assert!(new_rate < orig_rate);
}
/// Walks one "Month" pattern through the three classes the bank counts:
/// volatile, then trusted (only once a counterexample exists), then
/// quarantined, and checks that strategy selection no longer uses it.
#[test]
fn test_three_class_promotion() {
    let mut bank = ReasoningBank::new();
    bank.evidence_threshold = 3;

    // Five successful observations, all solved with the "adaptive" strategy.
    for n in 0..5 {
        let puzzle_id = format!("promo_{}", n);
        let mut seen = Trajectory::new(&puzzle_id, 5);
        seen.constraint_types.push(String::from("Month"));
        seen.record_attempt("2024-06-15".into(), 0.9, 10, 1, "adaptive");
        seen.set_verdict(Verdict::Success, None);
        bank.record_trajectory(seen);
    }

    // A freshly learned pattern is counted as volatile, not trusted.
    assert_eq!(bank.volatile_count(), 1);
    assert_eq!(bank.trusted_count(), 0);

    // With no counterexample on record, a promotion pass changes nothing.
    assert_eq!(bank.promote_patterns(), 0);
    assert_eq!(bank.trusted_count(), 0);

    // Supply the missing counterexample.
    bank.record_counterexample("Month", Trajectory::new("fail_1", 5));

    // Counterexample plus enough observations: the pattern is promoted.
    assert_eq!(bank.promote_patterns(), 1);
    assert_eq!(bank.trusted_count(), 1);
    assert_eq!(bank.volatile_count(), 0);

    // Simulate a holdout failure by demoting the pattern to quarantined.
    assert!(bank.demote_to_quarantined("Month", 5));
    assert_eq!(bank.quarantined_pattern_count(), 1);
    assert_eq!(bank.trusted_count(), 0);

    // A quarantined pattern must not drive strategy choice, so the selected
    // strategy is expected to be something other than "adaptive".
    let chosen = bank.get_strategy(5, &[String::from("Month")]);
    assert_ne!(chosen.name, "adaptive");
}
/// Rolling back to a checkpoint via `rollback_with_witness` should restore the
/// earlier trajectory list and leave one witness entry describing the rollback.
#[test]
fn test_rollback_witness() {
    let mut bank = ReasoningBank::new();

    // Three good trajectories recorded before the checkpoint.
    for n in 0..3 {
        let puzzle_id = format!("w_{}", n);
        let mut good = Trajectory::new(&puzzle_id, 5);
        good.constraint_types.push(String::from("Year"));
        good.record_attempt("2024-01-01".into(), 0.9, 10, 1, "default");
        good.set_verdict(Verdict::Success, None);
        bank.record_trajectory(good);
    }

    let cp = bank.checkpoint();

    // Three failing trajectories recorded after the checkpoint.
    for n in 3..6 {
        let puzzle_id = format!("w_{}", n);
        let mut bad = Trajectory::new(&puzzle_id, 5);
        bad.constraint_types.push(String::from("Year"));
        bad.record_attempt("wrong".into(), 0.1, 50, 1, "default");
        bad.set_verdict(Verdict::Failed, None);
        bank.record_trajectory(bad);
    }
    assert_eq!(bank.trajectories.len(), 6);

    // Roll back, passing a reason and two accuracy figures for the witness.
    assert!(bank.rollback_with_witness(cp, "accuracy regression", 0.90, 0.50));
    assert_eq!(bank.trajectories.len(), 3);
    assert_eq!(bank.rollback_witnesses.len(), 1);

    // The single witness must report the 3 discarded trajectories and the reason.
    let record = &bank.rollback_witnesses[0];
    assert_eq!(record.trajectories_discarded, 3);
    assert_eq!(record.reason, "accuracy regression");
}
/// Recording a structured counterexample with an empty id should return a
/// `ce_`-prefixed id and store the entry under that same id.
#[test]
fn test_structured_counterexample() {
    let mut bank = ReasoningBank::new();

    // The id is left blank on purpose; the test expects the bank to fill it in.
    let assigned_id = bank.record_structured_counterexample(Counterexample {
        id: String::new(),
        input_signature: String::from("d5_c3"),
        failure_type: String::from("contradiction"),
        reproduction_seed: Some(42),
        expected: String::from("2024-06-15"),
        observed: String::from("2024-07-15"),
        pattern_id: None,
        verdict: Some(Verdict::Failed),
    });

    assert!(assigned_id.starts_with("ce_"));
    assert_eq!(bank.structured_counterexamples.len(), 1);
    assert_eq!(bank.structured_counterexamples[0].id, assigned_id);
}
}