//! Self-learning service for continuous improvement use crate::config::LearningConfig; use crate::error::{Error, Result}; use crate::memory::MemoryService; use crate::router::FastGRNNRouter; use crate::types::{Feedback, InteractionOutcome, RouterSample}; use parking_lot::RwLock; use std::sync::Arc; use tokio::sync::mpsc; use tokio::task::JoinHandle; /// Learning service managing continuous improvement pub struct LearningService { /// Configuration config: LearningConfig, /// Router reference router: Arc>, /// Memory reference memory: Arc, /// Embedding dimension for creating new vectors embedding_dim: usize, /// Replay buffer replay_buffer: RwLock, /// EWC state ewc: RwLock, /// Shutdown signal shutdown_tx: Option>, /// Background task handle task_handle: RwLock>>, } /// Replay buffer with reservoir sampling #[derive(Debug, Default)] struct ReplayBuffer { entries: Vec, capacity: usize, total_seen: u64, } /// Elastic Weight Consolidation state #[derive(Debug, Default)] struct EWCState { /// Fisher information diagonal fisher_info: Vec, /// Optimal weights from previous task optimal_weights: Vec, /// Lambda regularization strength lambda: f32, } impl LearningService { /// Create a new learning service pub fn new( config: &LearningConfig, router: Arc>, memory: Arc, embedding_dim: usize, ) -> Result { Ok(Self { config: config.clone(), router, memory, embedding_dim, replay_buffer: RwLock::new(ReplayBuffer { entries: Vec::new(), capacity: config.replay_capacity, total_seen: 0, }), ewc: RwLock::new(EWCState { fisher_info: Vec::new(), optimal_weights: Vec::new(), lambda: config.ewc_lambda, }), shutdown_tx: None, task_handle: RwLock::new(None), }) } /// Start background training loop pub async fn start_background_training(&self) { let (tx, mut rx) = mpsc::channel::<()>(1); let config = self.config.clone(); let router = self.router.clone(); let replay_buffer = Arc::new(RwLock::new(ReplayBuffer { entries: Vec::new(), capacity: config.replay_capacity, total_seen: 0, })); let handle = tokio::spawn(async move { let mut interval = tokio::time::interval(std::time::Duration::from_millis( config.training_interval_ms, )); loop { tokio::select! { _ = interval.tick() => { // Check if enough samples let buffer = replay_buffer.read(); if buffer.entries.len() < config.min_samples { continue; } drop(buffer); // Training step would go here tracing::debug!("Background training tick"); } _ = rx.recv() => { tracing::info!("Learning service shutting down"); break; } } } }); *self.task_handle.write() = Some(handle); } /// Called on each interaction pub async fn on_interaction( &self, query: &str, response: &str, context: &[String], ) -> Result { // Skip if learning is disabled if !self.config.enabled { return Ok(InteractionOutcome { quality_score: 0.0, used_nodes: vec![], task_success: true, user_rating: None, }); } // Evaluate quality (mock - in production use LLM judge) let quality_score = self.evaluate_quality(query, response, context); // Create outcome let outcome = InteractionOutcome { quality_score, used_nodes: vec![], task_success: quality_score > 0.5, user_rating: None, }; // Maybe write to memory if quality_score >= self.config.quality_threshold { self.writeback(query, response, quality_score).await?; } Ok(outcome) } /// Record explicit feedback pub async fn record_feedback(&self, feedback: Feedback) -> Result<()> { tracing::info!( request_id = %feedback.request_id, rating = ?feedback.rating, "Recording feedback" ); // Update memory edges based on feedback if let Some(rating) = feedback.rating { let delta = (rating as f32 - 3.0) / 10.0; // -0.2 to +0.2 // In production, look up the request and update edge weights tracing::debug!(delta = delta, "Would update edge weights"); } Ok(()) } /// Stop the learning service pub async fn stop(&self) { if let Some(tx) = &self.shutdown_tx { let _ = tx.send(()).await; } if let Some(handle) = self.task_handle.write().take() { let _ = handle.await; } } fn evaluate_quality(&self, query: &str, response: &str, _context: &[String]) -> f32 { // Simple heuristic quality evaluation (in production, use LLM judge) let mut score = 0.5; // Longer responses are typically better (up to a point) let word_count = response.split_whitespace().count(); if word_count > 10 { score += 0.1; } if word_count > 50 { score += 0.1; } // Response should relate to query let query_lower = query.to_lowercase(); let query_words: std::collections::HashSet<_> = query_lower .split_whitespace() .filter(|w| w.len() > 3) .collect(); let response_lower = response.to_lowercase(); let response_words: std::collections::HashSet<_> = response_lower .split_whitespace() .filter(|w| w.len() > 3) .collect(); let overlap = query_words.intersection(&response_words).count(); if overlap > 0 { score += 0.1 * (overlap as f32).min(3.0); } score.min(1.0) } async fn writeback(&self, query: &str, response: &str, quality: f32) -> Result<()> { use crate::types::{MemoryNode, NodeType}; use std::collections::HashMap; use uuid::Uuid; // Create combined Q&A node let text = format!("Q: {}\nA: {}", query, response); // Mock embedding using configured dimension let vector = vec![0.0f32; self.embedding_dim]; let node = MemoryNode { id: Uuid::new_v4().to_string(), vector, text, node_type: NodeType::QAPair, source: "self_learning".into(), metadata: { let mut m = HashMap::new(); m.insert("quality".into(), serde_json::json!(quality)); m.insert( "timestamp".into(), serde_json::json!(chrono::Utc::now().timestamp()), ); m }, }; self.memory.insert_node(node)?; tracing::debug!(quality = quality, "Wrote interaction to memory"); Ok(()) } } impl ReplayBuffer { fn add(&mut self, sample: RouterSample) { self.total_seen += 1; if self.entries.len() < self.capacity { self.entries.push(sample); } else { // Reservoir sampling use rand::Rng; let idx = rand::thread_rng().gen_range(0..self.total_seen) as usize; if idx < self.capacity { self.entries[idx] = sample; } } } fn sample(&self, batch_size: usize) -> Vec<&RouterSample> { use rand::seq::SliceRandom; let mut rng = rand::thread_rng(); self.entries.choose_multiple(&mut rng, batch_size).collect() } } impl EWCState { fn regularization_loss(&self, current_weights: &[f32]) -> f32 { if self.fisher_info.is_empty() || self.optimal_weights.is_empty() { return 0.0; } self.fisher_info .iter() .zip(current_weights.iter()) .zip(self.optimal_weights.iter()) .map(|((f, w), w_star)| f * (w - w_star).powi(2)) .sum::() * self.lambda / 2.0 } } #[cfg(test)] mod tests { use super::*; #[test] fn test_replay_buffer() { let mut buffer = ReplayBuffer { entries: Vec::new(), capacity: 10, total_seen: 0, }; for i in 0..20 { buffer.add(RouterSample { features: vec![i as f32], label_model: 0, label_context: 0, label_temperature: 0.7, label_top_p: 0.9, quality: 0.8, latency_ms: 100.0, }); } // Buffer should be at capacity assert_eq!(buffer.entries.len(), 10); assert_eq!(buffer.total_seen, 20); } #[test] fn test_ewc_regularization() { let ewc = EWCState { fisher_info: vec![1.0, 1.0, 1.0], optimal_weights: vec![0.0, 0.0, 0.0], lambda: 1.0, }; let current = vec![1.0, 1.0, 1.0]; let loss = ewc.regularization_loss(¤t); // Should penalize deviation from optimal assert!(loss > 0.0); } }