Files
wifi-densepose/crates/ruvllm/src/claude_flow/reasoning_bank.rs
ruv d803bfe2b1 Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector
git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
2026-02-28 14:39:40 -05:00

1520 lines
48 KiB
Rust

//! ReasoningBank Integration for RuvLTRA
//!
//! Implements intelligent pattern learning for Claude Flow agent routing:
//!
//! - **Trajectory Storage**: Records task executions with verdict judgments
//! - **Memory Distillation**: Extracts key patterns from multiple trajectories
//! - **EWC++ Consolidation**: Prevents catastrophic forgetting of learned patterns
//! - **Pattern-based Routing**: Recommends agents based on learned successes
//!
//! ## Architecture
//!
//! ```text
//! +-------------------+ +-------------------+
//! | Task Execution |---->| record_trajectory |
//! | (verdict, steps) | | - Store in buffer |
//! +-------------------+ | - Update quality |
//! +--------+----------+
//! |
//! v (threshold reached)
//! +--------+----------+
//! | distill_patterns |
//! | - Cluster similar |
//! | - Extract patterns|
//! | - Compute routing |
//! +--------+----------+
//! |
//! v (periodic)
//! +--------+----------+
//! | consolidate() |
//! | - EWC++ update |
//! | - Prune stale |
//! | - Merge similar |
//! +-------------------+
//! ```
//!
//! ## Example
//!
//! ```rust,ignore
//! use ruvllm::claude_flow::reasoning_bank::{
//! ReasoningBankIntegration, ReasoningBankConfig, Verdict, TrajectoryStep
//! };
//!
//! let config = ReasoningBankConfig::default();
//! let mut bank = ReasoningBankIntegration::new(config);
//!
//! // Record a successful coder task
//! let steps = vec![
//! TrajectoryStep::new("analyze_requirements", 0.8),
//! TrajectoryStep::new("implement_code", 0.9),
//! TrajectoryStep::new("run_tests", 0.95),
//! ];
//! bank.record_trajectory(
//! "task-123",
//! &embedding,
//! steps,
//! Verdict::Success { reason: "All tests passed".into() },
//! ).unwrap();
//!
//! // Get routing recommendation for new task
//! let rec = bank.get_recommendation(&new_embedding);
//! println!("Suggested agent: {:?} (confidence: {:.2})", rec.agent, rec.confidence);
//! ```
use super::AgentType;
use crate::error::{Result, RuvLLMError};
use crate::sona::{SonaConfig, SonaIntegration, Trajectory as SonaTrajectory};
use parking_lot::RwLock;
use ruvector_sona::{
EwcConfig, EwcPlusPlus, LearnedPattern, PatternConfig, PatternType, ReasoningBank,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
/// Verdict judgment for a trajectory
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum Verdict {
/// Task completed successfully
Success {
/// Reason for success
reason: String,
},
/// Task failed
Failure {
/// Reason for failure
reason: String,
/// Optional error code
error_code: Option<String>,
},
/// Task partially completed
Partial {
/// Completion percentage (0.0 - 1.0)
completion: f32,
/// Reason for partial completion
reason: String,
},
/// Task recovered via self-reflection
///
/// This variant is used when a task initially failed but was successfully
/// recovered through the reflection system. It tracks the original error
/// and the recovery strategy that worked.
RecoveredViaReflection {
/// Original error that was encountered
original_error: String,
/// Recovery strategy that worked
recovery_strategy: String,
/// Number of attempts before successful recovery
attempts: u32,
},
}
impl Verdict {
/// Get quality score for this verdict
#[inline]
pub fn quality_score(&self) -> f32 {
match self {
Verdict::Success { .. } => 1.0,
Verdict::Failure { .. } => 0.0,
Verdict::Partial { completion, .. } => *completion,
// Recovered tasks get a slightly lower score than pure success
// to reflect that they required extra effort
Verdict::RecoveredViaReflection { attempts, .. } => {
// More attempts = lower score, but still successful
// 1 attempt = 0.95, 2 = 0.90, 3 = 0.85, etc.
(1.0 - (*attempts as f32 - 1.0) * 0.05).clamp(0.7, 0.95)
}
}
}
/// Check if verdict is successful (>= 0.5 quality)
#[inline]
pub fn is_successful(&self) -> bool {
self.quality_score() >= 0.5
}
/// Get verdict reason
#[inline]
pub fn reason(&self) -> &str {
match self {
Verdict::Success { reason } => reason,
Verdict::Failure { reason, .. } => reason,
Verdict::Partial { reason, .. } => reason,
Verdict::RecoveredViaReflection {
recovery_strategy, ..
} => recovery_strategy,
}
}
/// Check if this verdict involved recovery via reflection
#[inline]
pub fn is_recovered(&self) -> bool {
matches!(self, Verdict::RecoveredViaReflection { .. })
}
/// Get the original error if this was a recovered verdict
#[inline]
pub fn original_error(&self) -> Option<&str> {
match self {
Verdict::RecoveredViaReflection { original_error, .. } => Some(original_error),
_ => None,
}
}
/// Get the number of recovery attempts if applicable
#[inline]
pub fn recovery_attempts(&self) -> Option<u32> {
match self {
Verdict::RecoveredViaReflection { attempts, .. } => Some(*attempts),
_ => None,
}
}
}
impl Default for Verdict {
fn default() -> Self {
Verdict::Partial {
completion: 0.5,
reason: "Unknown".to_string(),
}
}
}
/// A single step in a trajectory
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrajectoryStep {
/// Step name/identifier
pub name: String,
/// Quality score for this step (0.0 - 1.0)
pub quality: f32,
/// Optional agent that performed this step
pub agent: Option<AgentType>,
/// Step duration in milliseconds
pub duration_ms: Option<u64>,
/// Additional metadata
pub metadata: HashMap<String, String>,
}
impl TrajectoryStep {
/// Create a new trajectory step
pub fn new(name: impl Into<String>, quality: f32) -> Self {
Self {
name: name.into(),
quality: quality.clamp(0.0, 1.0),
agent: None,
duration_ms: None,
metadata: HashMap::new(),
}
}
/// Set the agent for this step
pub fn with_agent(mut self, agent: AgentType) -> Self {
self.agent = Some(agent);
self
}
/// Set the duration for this step
pub fn with_duration(mut self, duration_ms: u64) -> Self {
self.duration_ms = Some(duration_ms);
self
}
/// Add metadata to this step
pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.metadata.insert(key.into(), value.into());
self
}
}
/// A complete trajectory record
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Trajectory {
/// Unique task identifier
pub task_id: String,
/// Task embedding vector
pub embedding: Vec<f32>,
/// Execution steps
pub steps: Vec<TrajectoryStep>,
/// Final verdict
pub verdict: Verdict,
/// Computed quality score (from steps and verdict)
pub quality_score: f32,
/// Primary agent used
pub primary_agent: Option<AgentType>,
/// Task type classification
pub task_type: Option<String>,
/// Timestamp (Unix seconds)
pub timestamp: u64,
/// Total duration in milliseconds
pub total_duration_ms: Option<u64>,
}
impl Trajectory {
/// Create a new trajectory
pub fn new(
task_id: impl Into<String>,
embedding: Vec<f32>,
steps: Vec<TrajectoryStep>,
verdict: Verdict,
) -> Self {
let quality_score = Self::compute_quality(&steps, &verdict);
let primary_agent = steps.iter().filter_map(|s| s.agent).next();
let total_duration_ms = steps.iter().filter_map(|s| s.duration_ms).sum::<u64>();
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0);
Self {
task_id: task_id.into(),
embedding,
steps,
verdict,
quality_score,
primary_agent,
task_type: None,
timestamp: now,
total_duration_ms: if total_duration_ms > 0 {
Some(total_duration_ms)
} else {
None
},
}
}
/// Compute quality score from steps and verdict
fn compute_quality(steps: &[TrajectoryStep], verdict: &Verdict) -> f32 {
if steps.is_empty() {
return verdict.quality_score();
}
// Weighted average: 70% steps, 30% verdict
let step_avg = steps.iter().map(|s| s.quality).sum::<f32>() / steps.len() as f32;
step_avg * 0.7 + verdict.quality_score() * 0.3
}
/// Set task type
pub fn with_task_type(mut self, task_type: impl Into<String>) -> Self {
self.task_type = Some(task_type.into());
self
}
/// Check if trajectory is high quality
pub fn is_high_quality(&self, threshold: f32) -> bool {
self.quality_score >= threshold
}
}
/// Configuration for ReasoningBank integration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReasoningBankConfig {
/// Maximum trajectory capacity
pub capacity: usize,
/// Quality threshold for distillation
pub distillation_threshold: f32,
/// EWC++ lambda (regularization strength)
pub ewc_lambda: f32,
/// Minimum trajectories before distillation
pub min_trajectories_for_distillation: usize,
/// Pattern similarity threshold for consolidation
pub consolidation_similarity: f32,
/// Embedding dimension
pub embedding_dim: usize,
/// Number of pattern clusters
pub num_clusters: usize,
/// Minimum pattern quality
pub min_pattern_quality: f32,
/// Pattern decay factor (for aging)
pub pattern_decay: f32,
/// Maximum pattern age in seconds
pub max_pattern_age_secs: u64,
/// Enable automatic distillation
pub auto_distill: bool,
/// Distillation interval (trajectory count)
pub distill_interval: usize,
}
impl Default for ReasoningBankConfig {
fn default() -> Self {
Self {
capacity: 10000,
distillation_threshold: 0.6,
ewc_lambda: 2000.0,
min_trajectories_for_distillation: 50,
consolidation_similarity: 0.85,
embedding_dim: 384,
num_clusters: 100,
min_pattern_quality: 0.3,
pattern_decay: 0.99,
max_pattern_age_secs: 604800, // 1 week
auto_distill: true,
distill_interval: 100,
}
}
}
impl ReasoningBankConfig {
/// Create configuration optimized for RuvLTRA-Small
pub fn for_ruvltra_small() -> Self {
Self {
capacity: 5000,
distillation_threshold: 0.6,
ewc_lambda: 500.0,
min_trajectories_for_distillation: 30,
consolidation_similarity: 0.9,
embedding_dim: 384,
num_clusters: 50,
min_pattern_quality: 0.4,
pattern_decay: 0.995,
max_pattern_age_secs: 259200, // 3 days
auto_distill: true,
distill_interval: 50,
}
}
/// Create configuration for edge deployment
pub fn for_edge() -> Self {
Self {
capacity: 1000,
distillation_threshold: 0.7,
ewc_lambda: 1000.0,
min_trajectories_for_distillation: 20,
consolidation_similarity: 0.95,
embedding_dim: 256,
num_clusters: 20,
min_pattern_quality: 0.5,
pattern_decay: 0.99,
max_pattern_age_secs: 86400, // 1 day
auto_distill: true,
distill_interval: 30,
}
}
}
/// Routing recommendation based on learned patterns
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingRecommendation {
/// Recommended agent type
pub agent: AgentType,
/// Confidence score (0.0 - 1.0)
pub confidence: f32,
/// Number of patterns used for recommendation
pub patterns_used: usize,
/// Average quality of matching patterns
pub avg_pattern_quality: f32,
/// Alternative agent suggestions
pub alternatives: Vec<(AgentType, f32)>,
/// Reasoning for recommendation
pub reasoning: String,
}
impl Default for RoutingRecommendation {
fn default() -> Self {
Self {
agent: AgentType::Coder,
confidence: 0.3,
patterns_used: 0,
avg_pattern_quality: 0.0,
alternatives: Vec::new(),
reasoning: "No patterns available, using default agent".to_string(),
}
}
}
/// Statistics for ReasoningBank
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ReasoningBankStats {
/// Total trajectories recorded
pub total_trajectories: u64,
/// Successful trajectories
pub successful_trajectories: u64,
/// Failed trajectories
pub failed_trajectories: u64,
/// Partial trajectories
pub partial_trajectories: u64,
/// Current trajectory buffer size
pub buffer_size: usize,
/// Number of learned patterns
pub patterns_learned: usize,
/// Distillation runs
pub distillation_runs: u64,
/// Consolidation runs
pub consolidation_runs: u64,
/// Average quality score
pub avg_quality: f32,
/// EWC task count
pub ewc_tasks: usize,
}
/// Distilled pattern from multiple trajectories
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistilledPattern {
/// Pattern identifier
pub id: u64,
/// Centroid embedding
pub centroid: Vec<f32>,
/// Primary agent association
pub primary_agent: AgentType,
/// Agent score distribution
pub agent_scores: HashMap<AgentType, f32>,
/// Average quality
pub avg_quality: f32,
/// Number of trajectories distilled
pub trajectory_count: usize,
/// Task type association
pub task_type: Option<String>,
/// Created timestamp
pub created_at: u64,
/// Last accessed timestamp
pub last_accessed: u64,
/// Access count
pub access_count: u32,
}
impl DistilledPattern {
/// Compute similarity with embedding using optimized dot product
#[inline]
pub fn similarity(&self, embedding: &[f32]) -> f32 {
let len = self.centroid.len();
if len != embedding.len() {
return 0.0;
}
// Compute all in single pass for cache efficiency
let mut dot: f32 = 0.0;
let mut norm_a_sq: f32 = 0.0;
let mut norm_b_sq: f32 = 0.0;
for i in 0..len {
let a = self.centroid[i];
let b = embedding[i];
dot += a * b;
norm_a_sq += a * a;
norm_b_sq += b * b;
}
let norm_a = norm_a_sq.sqrt();
let norm_b = norm_b_sq.sqrt();
if norm_a > 1e-8 && norm_b > 1e-8 {
dot / (norm_a * norm_b)
} else {
0.0
}
}
/// Get best agent from this pattern
#[inline]
pub fn best_agent(&self) -> AgentType {
self.agent_scores
.iter()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
.map(|(agent, _)| *agent)
.unwrap_or(self.primary_agent)
}
/// Check if pattern should be pruned
#[inline]
pub fn should_prune(&self, min_quality: f32, max_age_secs: u64) -> bool {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0);
let age = now.saturating_sub(self.last_accessed);
self.avg_quality < min_quality && age > max_age_secs && self.access_count < 5
}
}
/// ReasoningBank integration for Claude Flow
pub struct ReasoningBankIntegration {
/// Configuration
config: ReasoningBankConfig,
/// Trajectory buffer
trajectory_buffer: Arc<RwLock<Vec<Trajectory>>>,
/// Distilled patterns
patterns: Arc<RwLock<HashMap<u64, DistilledPattern>>>,
/// EWC++ for consolidation
ewc: Arc<RwLock<EwcPlusPlus>>,
/// Core reasoning bank (from ruvector_sona)
core_bank: Arc<RwLock<ReasoningBank>>,
/// SONA integration for trajectory recording
sona: Option<Arc<RwLock<SonaIntegration>>>,
/// Next pattern ID
next_pattern_id: AtomicU64,
/// Statistics
stats: RwLock<ReasoningBankStats>,
/// Trajectories since last distillation
trajectories_since_distill: AtomicU64,
}
impl ReasoningBankIntegration {
/// Create a new ReasoningBank integration
pub fn new(config: ReasoningBankConfig) -> Self {
let ewc_config = EwcConfig {
param_count: config.embedding_dim,
initial_lambda: config.ewc_lambda,
max_lambda: config.ewc_lambda * 5.0,
..Default::default()
};
let pattern_config = PatternConfig {
k_clusters: config.num_clusters,
embedding_dim: config.embedding_dim.min(256),
max_trajectories: config.capacity,
quality_threshold: config.min_pattern_quality,
..Default::default()
};
Self {
config,
trajectory_buffer: Arc::new(RwLock::new(Vec::new())),
patterns: Arc::new(RwLock::new(HashMap::new())),
ewc: Arc::new(RwLock::new(EwcPlusPlus::new(ewc_config))),
core_bank: Arc::new(RwLock::new(ReasoningBank::new(pattern_config))),
sona: None,
next_pattern_id: AtomicU64::new(0),
stats: RwLock::new(ReasoningBankStats::default()),
trajectories_since_distill: AtomicU64::new(0),
}
}
/// Create with SONA integration
pub fn with_sona(config: ReasoningBankConfig, sona_config: SonaConfig) -> Self {
let mut bank = Self::new(config);
bank.sona = Some(Arc::new(RwLock::new(SonaIntegration::new(sona_config))));
bank
}
/// Record a trajectory
pub fn record_trajectory(
&self,
task_id: impl Into<String>,
embedding: &[f32],
steps: Vec<TrajectoryStep>,
verdict: Verdict,
) -> Result<()> {
let trajectory = Trajectory::new(task_id, embedding.to_vec(), steps, verdict.clone());
// Update statistics
{
let mut stats = self.stats.write();
stats.total_trajectories += 1;
match &verdict {
Verdict::Success { .. } => stats.successful_trajectories += 1,
Verdict::Failure { .. } => stats.failed_trajectories += 1,
Verdict::Partial { .. } => stats.partial_trajectories += 1,
Verdict::RecoveredViaReflection { .. } => {
// Count recovered as successful since task completed
stats.successful_trajectories += 1;
}
}
// Update running average quality
let n = stats.total_trajectories as f32;
stats.avg_quality = stats.avg_quality * (n - 1.0) / n + trajectory.quality_score / n;
}
// Add to buffer
{
let mut buffer = self.trajectory_buffer.write();
// Enforce capacity
if buffer.len() >= self.config.capacity {
buffer.remove(0);
}
buffer.push(trajectory.clone());
}
// Record to SONA if available
if let Some(ref sona) = self.sona {
let sona_trajectory = SonaTrajectory {
request_id: trajectory.task_id.clone(),
session_id: "reasoning-bank".to_string(),
query_embedding: embedding.to_vec(),
response_embedding: embedding.to_vec(),
quality_score: trajectory.quality_score,
routing_features: vec![
trajectory.quality_score,
verdict.quality_score(),
trajectory.steps.len() as f32 / 10.0,
],
model_index: trajectory.primary_agent.map(|a| a as usize).unwrap_or(0),
timestamp: chrono::Utc::now(),
};
let sona_guard = sona.read();
let _ = sona_guard.record_trajectory(sona_trajectory);
}
// Record to core bank
{
let mut core = self.core_bank.write();
let query_traj =
ruvector_sona::QueryTrajectory::new(trajectory.timestamp, embedding.to_vec());
core.add_trajectory(&query_traj);
}
// Check for auto-distillation
let count = self
.trajectories_since_distill
.fetch_add(1, Ordering::SeqCst)
+ 1;
if self.config.auto_distill && count >= self.config.distill_interval as u64 {
self.distill_patterns()?;
self.trajectories_since_distill.store(0, Ordering::SeqCst);
}
Ok(())
}
/// Distill patterns from trajectories
pub fn distill_patterns(&self) -> Result<Vec<DistilledPattern>> {
let trajectories: Vec<Trajectory> = {
let buffer = self.trajectory_buffer.read();
buffer
.iter()
.filter(|t| t.quality_score >= self.config.distillation_threshold)
.cloned()
.collect()
};
if trajectories.len() < self.config.min_trajectories_for_distillation {
return Ok(Vec::new());
}
// Extract patterns from core bank
{
let mut core = self.core_bank.write();
core.extract_patterns();
}
// Group trajectories by similarity
let clusters = self.cluster_trajectories(&trajectories);
let mut new_patterns = Vec::new();
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0);
for cluster in clusters {
if cluster.is_empty() {
continue;
}
// Compute centroid
let dim = cluster[0].embedding.len();
let mut centroid = vec![0.0f32; dim];
for traj in &cluster {
for (i, &e) in traj.embedding.iter().enumerate() {
if i < dim {
centroid[i] += e;
}
}
}
for c in &mut centroid {
*c /= cluster.len() as f32;
}
// Normalize centroid
let norm: f32 = centroid.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 1e-8 {
for c in &mut centroid {
*c /= norm;
}
}
// Compute agent scores
let mut agent_scores: HashMap<AgentType, f32> = HashMap::new();
let mut total_quality = 0.0f32;
let mut task_type: Option<String> = None;
for traj in &cluster {
if let Some(agent) = traj.primary_agent {
*agent_scores.entry(agent).or_insert(0.0) += traj.quality_score;
}
total_quality += traj.quality_score;
if task_type.is_none() {
task_type = traj.task_type.clone();
}
}
// Normalize agent scores
let total_agent_score: f32 = agent_scores.values().sum();
if total_agent_score > 0.0 {
for score in agent_scores.values_mut() {
*score /= total_agent_score;
}
}
// Determine primary agent
let primary_agent = agent_scores
.iter()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
.map(|(agent, _)| *agent)
.unwrap_or(AgentType::Coder);
let pattern_id = self.next_pattern_id.fetch_add(1, Ordering::SeqCst);
let pattern = DistilledPattern {
id: pattern_id,
centroid,
primary_agent,
agent_scores,
avg_quality: total_quality / cluster.len() as f32,
trajectory_count: cluster.len(),
task_type,
created_at: now,
last_accessed: now,
access_count: 0,
};
// Store pattern
{
let mut patterns = self.patterns.write();
patterns.insert(pattern_id, pattern.clone());
}
new_patterns.push(pattern);
}
// Update EWC with new patterns
self.update_ewc_from_patterns(&new_patterns);
// Update statistics
{
let mut stats = self.stats.write();
stats.distillation_runs += 1;
stats.patterns_learned = self.patterns.read().len();
}
Ok(new_patterns)
}
/// Cluster trajectories by embedding similarity
fn cluster_trajectories(&self, trajectories: &[Trajectory]) -> Vec<Vec<Trajectory>> {
if trajectories.is_empty() {
return Vec::new();
}
// Simple K-means style clustering
let k = self.config.num_clusters.min(trajectories.len() / 3).max(1);
let dim = trajectories[0].embedding.len();
// Initialize centroids with first k trajectories
let mut centroids: Vec<Vec<f32>> = trajectories
.iter()
.take(k)
.map(|t| t.embedding.clone())
.collect();
// Run clustering iterations
let mut assignments = vec![0usize; trajectories.len()];
for _ in 0..10 {
// Assign each trajectory to nearest centroid
let mut changed = false;
for (i, traj) in trajectories.iter().enumerate() {
let nearest = centroids
.iter()
.enumerate()
.map(|(j, c)| (j, self.cosine_similarity(&traj.embedding, c)))
.max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
.map(|(j, _)| j)
.unwrap_or(0);
if assignments[i] != nearest {
assignments[i] = nearest;
changed = true;
}
}
if !changed {
break;
}
// Recompute centroids
let mut new_centroids = vec![vec![0.0f32; dim]; k];
let mut counts = vec![0usize; k];
for (i, traj) in trajectories.iter().enumerate() {
let cluster = assignments[i];
counts[cluster] += 1;
for (j, &e) in traj.embedding.iter().enumerate() {
if j < dim {
new_centroids[cluster][j] += e;
}
}
}
for (i, centroid) in new_centroids.iter_mut().enumerate() {
if counts[i] > 0 {
for c in centroid.iter_mut() {
*c /= counts[i] as f32;
}
}
}
centroids = new_centroids;
}
// Group trajectories by assignment
let mut clusters: Vec<Vec<Trajectory>> = vec![Vec::new(); k];
for (i, traj) in trajectories.iter().enumerate() {
clusters[assignments[i]].push(traj.clone());
}
// Filter out small clusters
clusters.into_iter().filter(|c| c.len() >= 2).collect()
}
/// Cosine similarity between two vectors
/// Optimized to compute all norms in a single pass
#[inline]
fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
let len = a.len();
if len != b.len() {
return 0.0;
}
// Single-pass computation for cache efficiency
let mut dot: f32 = 0.0;
let mut norm_a_sq: f32 = 0.0;
let mut norm_b_sq: f32 = 0.0;
for i in 0..len {
let x = a[i];
let y = b[i];
dot += x * y;
norm_a_sq += x * x;
norm_b_sq += y * y;
}
let norm_a = norm_a_sq.sqrt();
let norm_b = norm_b_sq.sqrt();
if norm_a > 1e-8 && norm_b > 1e-8 {
dot / (norm_a * norm_b)
} else {
0.0
}
}
/// Update EWC from new patterns
fn update_ewc_from_patterns(&self, patterns: &[DistilledPattern]) {
let mut ewc = self.ewc.write();
for pattern in patterns {
// Use centroid as pseudo-gradients
let gradients: Vec<f32> = pattern
.centroid
.iter()
.take(self.config.embedding_dim)
.copied()
.chain(std::iter::repeat(0.0))
.take(self.config.embedding_dim)
.collect();
ewc.update_fisher(&gradients);
}
// Start new task periodically
if patterns.len() >= 10 {
ewc.start_new_task();
}
}
/// Get routing recommendation for an embedding
pub fn get_recommendation(&self, embedding: &[f32]) -> RoutingRecommendation {
let patterns = self.patterns.read();
if patterns.is_empty() {
return RoutingRecommendation::default();
}
// Find most similar patterns with pre-allocated capacity
let mut scored: Vec<(&DistilledPattern, f32)> = Vec::with_capacity(patterns.len());
for pattern in patterns.values() {
scored.push((pattern, pattern.similarity(embedding)));
}
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
let top_patterns: Vec<_> = scored.into_iter().take(5).collect();
if top_patterns.is_empty() {
return RoutingRecommendation::default();
}
// Update access counts for top patterns
{
let mut patterns_mut = self.patterns.write();
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0);
for (pattern, _) in &top_patterns {
if let Some(p) = patterns_mut.get_mut(&pattern.id) {
p.access_count += 1;
p.last_accessed = now;
}
}
}
// Weighted vote for best agent - pre-allocate for typical agent count
let mut agent_votes: HashMap<AgentType, f32> = HashMap::with_capacity(16);
let mut total_weight = 0.0f32;
let mut total_quality = 0.0f32;
for (pattern, similarity) in &top_patterns {
let weight = similarity * pattern.avg_quality;
total_weight += weight;
total_quality += pattern.avg_quality;
for (agent, score) in &pattern.agent_scores {
*agent_votes.entry(*agent).or_insert(0.0) += weight * score;
}
// Also vote for primary agent
*agent_votes.entry(pattern.primary_agent).or_insert(0.0) += weight * 0.5;
}
// Normalize votes
if total_weight > 0.0 {
for vote in agent_votes.values_mut() {
*vote /= total_weight;
}
}
// Find best agent
let (best_agent, best_score) = agent_votes
.iter()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
.map(|(agent, score)| (*agent, *score))
.unwrap_or((AgentType::Coder, 0.0));
// Get alternatives
let mut alternatives: Vec<(AgentType, f32)> = agent_votes
.into_iter()
.filter(|(agent, _)| *agent != best_agent)
.collect();
alternatives.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
alternatives.truncate(3);
// Compute confidence
let confidence = if top_patterns.is_empty() {
0.0
} else {
let max_similarity = top_patterns[0].1;
(best_score * max_similarity).min(1.0)
};
let avg_pattern_quality = if top_patterns.is_empty() {
0.0
} else {
total_quality / top_patterns.len() as f32
};
let reasoning = format!(
"Based on {} similar patterns with avg quality {:.2}; best match similarity: {:.2}",
top_patterns.len(),
avg_pattern_quality,
top_patterns.first().map(|(_, s)| *s).unwrap_or(0.0)
);
RoutingRecommendation {
agent: best_agent,
confidence,
patterns_used: top_patterns.len(),
avg_pattern_quality,
alternatives,
reasoning,
}
}
/// Consolidate patterns with EWC++
pub fn consolidate(&self) -> Result<()> {
// Prune old/low-quality patterns
{
let mut patterns = self.patterns.write();
let to_remove: Vec<u64> = patterns
.iter()
.filter(|(_, p)| {
p.should_prune(
self.config.min_pattern_quality,
self.config.max_pattern_age_secs,
)
})
.map(|(id, _)| *id)
.collect();
for id in to_remove {
patterns.remove(&id);
}
}
// Merge similar patterns
{
let mut patterns = self.patterns.write();
let pattern_ids: Vec<u64> = patterns.keys().copied().collect();
let mut merged_ids = Vec::new();
for i in 0..pattern_ids.len() {
for j in i + 1..pattern_ids.len() {
let id1 = pattern_ids[i];
let id2 = pattern_ids[j];
if merged_ids.contains(&id1) || merged_ids.contains(&id2) {
continue;
}
if let (Some(p1), Some(p2)) = (patterns.get(&id1), patterns.get(&id2)) {
let similarity = p1.similarity(&p2.centroid);
if similarity > self.config.consolidation_similarity {
// Merge p2 into p1
let merged = self.merge_patterns(p1, p2);
patterns.insert(id1, merged);
merged_ids.push(id2);
}
}
}
}
for id in merged_ids {
patterns.remove(&id);
}
}
// Consolidate EWC tasks
{
let mut ewc = self.ewc.write();
ewc.consolidate_all_tasks();
}
// Update statistics
{
let mut stats = self.stats.write();
stats.consolidation_runs += 1;
stats.patterns_learned = self.patterns.read().len();
stats.ewc_tasks = self.ewc.read().task_count();
}
Ok(())
}
/// Merge two patterns
fn merge_patterns(&self, p1: &DistilledPattern, p2: &DistilledPattern) -> DistilledPattern {
let total_count = p1.trajectory_count + p2.trajectory_count;
let w1 = p1.trajectory_count as f32 / total_count as f32;
let w2 = p2.trajectory_count as f32 / total_count as f32;
// Merge centroids
let centroid: Vec<f32> = p1
.centroid
.iter()
.zip(&p2.centroid)
.map(|(&a, &b)| a * w1 + b * w2)
.collect();
// Merge agent scores
let mut agent_scores: HashMap<AgentType, f32> = p1.agent_scores.clone();
for (agent, score) in &p2.agent_scores {
*agent_scores.entry(*agent).or_insert(0.0) += score * w2;
}
// Normalize
let total: f32 = agent_scores.values().sum();
if total > 0.0 {
for score in agent_scores.values_mut() {
*score /= total;
}
}
let primary_agent = agent_scores
.iter()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
.map(|(agent, _)| *agent)
.unwrap_or(p1.primary_agent);
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0);
DistilledPattern {
id: p1.id,
centroid,
primary_agent,
agent_scores,
avg_quality: p1.avg_quality * w1 + p2.avg_quality * w2,
trajectory_count: total_count,
task_type: p1.task_type.clone().or_else(|| p2.task_type.clone()),
created_at: p1.created_at.min(p2.created_at),
last_accessed: now,
access_count: p1.access_count + p2.access_count,
}
}
/// Get statistics
pub fn stats(&self) -> ReasoningBankStats {
let mut stats = self.stats.read().clone();
stats.buffer_size = self.trajectory_buffer.read().len();
stats.patterns_learned = self.patterns.read().len();
stats.ewc_tasks = self.ewc.read().task_count();
stats
}
/// Get all patterns
pub fn get_patterns(&self) -> Vec<DistilledPattern> {
self.patterns.read().values().cloned().collect()
}
/// Get trajectory count
pub fn trajectory_count(&self) -> usize {
self.trajectory_buffer.read().len()
}
/// Get pattern count
pub fn pattern_count(&self) -> usize {
self.patterns.read().len()
}
/// Clear all data
pub fn clear(&self) {
self.trajectory_buffer.write().clear();
self.patterns.write().clear();
*self.stats.write() = ReasoningBankStats::default();
self.trajectories_since_distill.store(0, Ordering::SeqCst);
}
/// Export patterns for persistence
pub fn export_patterns(&self) -> Vec<DistilledPattern> {
self.patterns.read().values().cloned().collect()
}
/// Import patterns
pub fn import_patterns(&self, patterns: Vec<DistilledPattern>) {
let mut pattern_map = self.patterns.write();
for pattern in patterns {
let id = pattern.id.max(self.next_pattern_id.load(Ordering::SeqCst));
self.next_pattern_id.fetch_max(id + 1, Ordering::SeqCst);
pattern_map.insert(pattern.id, pattern);
}
self.stats.write().patterns_learned = pattern_map.len();
}
}
impl std::fmt::Debug for ReasoningBankIntegration {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ReasoningBankIntegration")
.field("config", &self.config)
.field("trajectory_count", &self.trajectory_count())
.field("pattern_count", &self.pattern_count())
.field("stats", &self.stats())
.finish()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_verdict_quality_scores() {
assert_eq!(
Verdict::Success {
reason: "ok".into()
}
.quality_score(),
1.0
);
assert_eq!(
Verdict::Failure {
reason: "err".into(),
error_code: None
}
.quality_score(),
0.0
);
assert_eq!(
Verdict::Partial {
completion: 0.7,
reason: "partial".into()
}
.quality_score(),
0.7
);
}
#[test]
fn test_trajectory_step_creation() {
let step = TrajectoryStep::new("test_step", 0.8)
.with_agent(AgentType::Coder)
.with_duration(100)
.with_metadata("key", "value");
assert_eq!(step.name, "test_step");
assert_eq!(step.quality, 0.8);
assert_eq!(step.agent, Some(AgentType::Coder));
assert_eq!(step.duration_ms, Some(100));
assert_eq!(step.metadata.get("key"), Some(&"value".to_string()));
}
#[test]
fn test_trajectory_creation() {
let steps = vec![
TrajectoryStep::new("step1", 0.7).with_agent(AgentType::Researcher),
TrajectoryStep::new("step2", 0.9).with_agent(AgentType::Coder),
];
let traj = Trajectory::new(
"task-1",
vec![0.1, 0.2, 0.3],
steps,
Verdict::Success {
reason: "done".into(),
},
);
assert_eq!(traj.task_id, "task-1");
assert_eq!(traj.steps.len(), 2);
// Quality = 0.7 * ((0.7 + 0.9) / 2) + 0.3 * 1.0 = 0.56 + 0.3 = 0.86
assert!((traj.quality_score - 0.86).abs() < 0.01);
}
#[test]
fn test_reasoning_bank_creation() {
let config = ReasoningBankConfig::default();
let bank = ReasoningBankIntegration::new(config);
assert_eq!(bank.trajectory_count(), 0);
assert_eq!(bank.pattern_count(), 0);
}
#[test]
fn test_record_trajectory() {
let config = ReasoningBankConfig {
auto_distill: false,
..Default::default()
};
let bank = ReasoningBankIntegration::new(config);
let steps = vec![TrajectoryStep::new("step1", 0.8).with_agent(AgentType::Coder)];
bank.record_trajectory(
"task-1",
&vec![0.1; 384],
steps,
Verdict::Success {
reason: "done".into(),
},
)
.unwrap();
assert_eq!(bank.trajectory_count(), 1);
let stats = bank.stats();
assert_eq!(stats.total_trajectories, 1);
assert_eq!(stats.successful_trajectories, 1);
}
#[test]
fn test_distill_patterns() {
let config = ReasoningBankConfig {
min_trajectories_for_distillation: 5,
distillation_threshold: 0.0, // Accept all
num_clusters: 2,
auto_distill: false,
..Default::default()
};
let bank = ReasoningBankIntegration::new(config);
// Add trajectories
for i in 0..10 {
let embedding: Vec<f32> = if i < 5 {
vec![1.0, 0.0, 0.0]
.into_iter()
.chain(std::iter::repeat(0.0))
.take(384)
.collect()
} else {
vec![0.0, 1.0, 0.0]
.into_iter()
.chain(std::iter::repeat(0.0))
.take(384)
.collect()
};
let steps = vec![TrajectoryStep::new("step", 0.8).with_agent(AgentType::Coder)];
bank.record_trajectory(
format!("task-{}", i),
&embedding,
steps,
Verdict::Success {
reason: "done".into(),
},
)
.unwrap();
}
let patterns = bank.distill_patterns().unwrap();
assert!(!patterns.is_empty());
}
#[test]
fn test_get_recommendation() {
let config = ReasoningBankConfig {
min_trajectories_for_distillation: 2,
distillation_threshold: 0.0,
num_clusters: 1,
auto_distill: false,
..Default::default()
};
let bank = ReasoningBankIntegration::new(config);
// Add similar trajectories
for i in 0..5 {
let embedding: Vec<f32> = vec![1.0, 0.0, 0.0]
.into_iter()
.chain(std::iter::repeat(0.0))
.take(384)
.collect();
let steps = vec![TrajectoryStep::new("step", 0.9).with_agent(AgentType::Tester)];
bank.record_trajectory(
format!("task-{}", i),
&embedding,
steps,
Verdict::Success {
reason: "done".into(),
},
)
.unwrap();
}
bank.distill_patterns().unwrap();
// Get recommendation for similar embedding
let query: Vec<f32> = vec![0.9, 0.1, 0.0]
.into_iter()
.chain(std::iter::repeat(0.0))
.take(384)
.collect();
let rec = bank.get_recommendation(&query);
assert!(rec.patterns_used > 0);
assert!(rec.confidence > 0.0);
}
#[test]
fn test_consolidate() {
let config = ReasoningBankConfig {
min_trajectories_for_distillation: 2,
distillation_threshold: 0.0,
num_clusters: 2,
consolidation_similarity: 0.99, // High threshold for testing
auto_distill: false,
..Default::default()
};
let bank = ReasoningBankIntegration::new(config);
// Add trajectories
for i in 0..6 {
let embedding: Vec<f32> = vec![1.0 + (i as f32 * 0.001), 0.0, 0.0]
.into_iter()
.chain(std::iter::repeat(0.0))
.take(384)
.collect();
let steps = vec![TrajectoryStep::new("step", 0.8).with_agent(AgentType::Coder)];
bank.record_trajectory(
format!("task-{}", i),
&embedding,
steps,
Verdict::Success {
reason: "done".into(),
},
)
.unwrap();
}
bank.distill_patterns().unwrap();
let before = bank.pattern_count();
bank.consolidate().unwrap();
let after = bank.pattern_count();
assert!(after <= before);
}
#[test]
fn test_distilled_pattern_similarity() {
let pattern = DistilledPattern {
id: 1,
centroid: vec![1.0, 0.0, 0.0, 0.0],
primary_agent: AgentType::Coder,
agent_scores: HashMap::new(),
avg_quality: 0.9,
trajectory_count: 10,
task_type: None,
created_at: 0,
last_accessed: 0,
access_count: 0,
};
let same = vec![1.0, 0.0, 0.0, 0.0];
let orthogonal = vec![0.0, 1.0, 0.0, 0.0];
assert!((pattern.similarity(&same) - 1.0).abs() < 0.01);
assert!(pattern.similarity(&orthogonal).abs() < 0.01);
}
#[test]
fn test_export_import_patterns() {
let config = ReasoningBankConfig::default();
let bank = ReasoningBankIntegration::new(config.clone());
// Create some patterns manually
let pattern = DistilledPattern {
id: 42,
centroid: vec![0.5; 384],
primary_agent: AgentType::Researcher,
agent_scores: HashMap::from([(AgentType::Researcher, 0.8), (AgentType::Coder, 0.2)]),
avg_quality: 0.85,
trajectory_count: 50,
task_type: Some("research".to_string()),
created_at: 1000,
last_accessed: 2000,
access_count: 10,
};
bank.import_patterns(vec![pattern.clone()]);
assert_eq!(bank.pattern_count(), 1);
let exported = bank.export_patterns();
assert_eq!(exported.len(), 1);
assert_eq!(exported[0].id, 42);
assert_eq!(exported[0].primary_agent, AgentType::Researcher);
}
#[test]
fn test_config_presets() {
let default = ReasoningBankConfig::default();
let small = ReasoningBankConfig::for_ruvltra_small();
let edge = ReasoningBankConfig::for_edge();
assert!(default.capacity > small.capacity);
assert!(small.capacity > edge.capacity);
assert!(edge.num_clusters < small.num_clusters);
}
}