//! Q-Learning Module for iOS WASM
//!
//! Lightweight reinforcement learning for adaptive recommendations.
//! Uses tabular Q-learning: each state embedding is quantized and hashed into one of a
//! fixed number of state buckets, and one Q-value is kept per (bucket, action) pair.

/// Maximum number of actions (content recommendations)
const MAX_ACTIONS: usize = 100;

/// State discretization buckets
const STATE_BUCKETS: usize = 16;

/// User interaction types
#[derive(Clone, Copy, Debug, PartialEq)]
#[repr(u8)]
pub enum InteractionType {
    /// User viewed content
    View = 0,
    /// User liked/saved content
    Like = 1,
    /// User shared content
    Share = 2,
    /// User skipped content
    Skip = 3,
    /// User completed content (video/audio)
    Complete = 4,
    /// User dismissed/hid content
    Dismiss = 5,
}

impl InteractionType {
    /// Convert an interaction to its base reward signal (between -0.5 and 1.0).
    #[inline]
    pub fn to_reward(self) -> f32 {
        match self {
            InteractionType::View => 0.1,
            InteractionType::Like => 0.8,
            InteractionType::Share => 1.0,
            InteractionType::Skip => -0.1,
            InteractionType::Complete => 0.6,
            InteractionType::Dismiss => -0.5,
        }
    }
}

/// User interaction event
#[derive(Clone, Debug)]
pub struct UserInteraction {
    /// Content ID that was interacted with (not read by [`QLearner`] itself)
    pub content_id: u64,
    /// Type of interaction
    pub interaction: InteractionType,
    /// Time spent in seconds; only the `[0, 60]` range contributes to the reward bonus
    pub time_spent: f32,
    /// Position in recommendation list (0-indexed)
    pub position: u8,
}

/// Q-Learning agent for personalized recommendations.
///
/// Q-values live in a flat table indexed by `state * action_dim + action`, where `state`
/// is the bucket produced by hashing a state embedding. Actions are picked
/// epsilon-greedily; the greedy path adds a UCB-style bonus based on visit counts.
pub struct QLearner {
    /// Q-values: state_bucket x action -> value (row-major by state)
    q_table: Vec<f32>,
    /// Learning rate (alpha)
    learning_rate: f32,
    /// Discount factor (gamma)
    discount: f32,
    /// Exploration rate (epsilon); decayed every 100 updates, never below 0.01 unless it started lower
    exploration: f32,
    /// Number of state buckets
    state_dim: usize,
    /// Number of actions
    action_dim: usize,
    /// Per-(state, action) update counts for the UCB bonus (not persisted by `serialize`)
    visit_counts: Vec<u32>,
    /// Total updates
    total_updates: u64,
}

impl QLearner {
    /// Create a new Q-learner with default hyperparameters
    /// (alpha = 0.1, gamma = 0.95, epsilon = 0.1).
    ///
    /// `action_dim` is capped at `MAX_ACTIONS`.
    pub fn new(action_dim: usize) -> Self {
        let action_dim = action_dim.min(MAX_ACTIONS);
        let state_dim = STATE_BUCKETS;
        let table_size = state_dim * action_dim;

        Self {
            q_table: vec![0.0; table_size],
            learning_rate: 0.1,
            discount: 0.95,
            exploration: 0.1,
            state_dim,
            action_dim,
            visit_counts: vec![0; table_size],
            total_updates: 0,
        }
    }

    /// Create with custom hyperparameters.
    ///
    /// `learning_rate` is clamped to `[0.001, 1.0]`; `discount` and `exploration`
    /// are clamped to `[0.0, 1.0]`.
    pub fn with_params(
        action_dim: usize,
        learning_rate: f32,
        discount: f32,
        exploration: f32,
    ) -> Self {
        let mut learner = Self::new(action_dim);
        learner.learning_rate = learning_rate.clamp(0.001, 1.0);
        learner.discount = discount.clamp(0.0, 1.0);
        learner.exploration = exploration.clamp(0.0, 1.0);
        learner
    }

    /// Discretize a state embedding to a bucket index in `0..state_dim`.
    ///
    /// Only the first eight dimensions are used. Each is mapped via `(v + 1) * 127`
    /// (presumably expecting values in `[-1, 1]` — confirm with callers). The result is
    /// truncated to `u32`, with negatives and NaN saturating to 0, then mixed into a hash.
    /// An empty embedding maps to bucket 0.
    #[inline]
    fn discretize_state(&self, state_embedding: &[f32]) -> usize {
        if state_embedding.is_empty() {
            return 0;
        }

        // Use first few dimensions to compute hash
        let mut hash: u32 = 0;
        for (i, &val) in state_embedding.iter().take(8).enumerate() {
            let quantized = ((val + 1.0) * 127.0) as u32;
            hash = hash.wrapping_add(quantized << (i * 4));
        }

        (hash as usize) % self.state_dim
    }

    /// Flat index of a (state, action) pair, or `None` when either is out of range.
    ///
    /// Checking `action` separately matters: an oversized action would otherwise alias
    /// a slot belonging to the next state.
    #[inline]
    fn index(&self, state: usize, action: usize) -> Option<usize> {
        if state < self.state_dim && action < self.action_dim {
            Some(state * self.action_dim + action)
        } else {
            None
        }
    }

    /// Get Q-value for state-action pair (0.0 when out of range).
    #[inline]
    fn get_q(&self, state: usize, action: usize) -> f32 {
        match self.index(state, action) {
            Some(idx) => self.q_table[idx],
            None => 0.0,
        }
    }

    /// Set Q-value for state-action pair and bump its visit count (no-op when out of range).
    #[inline]
    fn set_q(&mut self, state: usize, action: usize, value: f32) {
        if let Some(idx) = self.index(state, action) {
            self.q_table[idx] = value;
            self.visit_counts[idx] = self.visit_counts[idx].saturating_add(1);
        }
    }

    /// Scramble a seed with the MurmurHash3 32-bit finalizer so that values derived
    /// from it are decorrelated from `seed % 1000`.
    #[inline]
    fn mix_seed(seed: u32) -> u32 {
        let mut x = seed;
        x ^= x >> 16;
        x = x.wrapping_mul(0x85eb_ca6b);
        x ^= x >> 13;
        x = x.wrapping_mul(0xc2b2_ae35);
        x ^= x >> 16;
        x
    }

    /// Select action using epsilon-greedy with UCB exploration bonus.
    ///
    /// `rng_seed` supplies the randomness: `rng_seed % 1000` decides whether to explore,
    /// and a scrambled copy of it picks the random action. Returns 0 when the learner
    /// has no actions.
    pub fn select_action(&self, state_embedding: &[f32], rng_seed: u32) -> usize {
        if self.action_dim == 0 {
            return 0;
        }

        let state = self.discretize_state(state_embedding);

        // Epsilon-greedy exploration
        let explore_threshold = (rng_seed % 1000) as f32 / 1000.0;
        if explore_threshold < self.exploration {
            // Reusing `rng_seed % action_dim` here would tie the action to the
            // explore test above. For example, at epsilon 0.01 with 100 actions,
            // only actions 0..9 could ever be explored.
            return Self::mix_seed(rng_seed) as usize % self.action_dim;
        }

        // Greedy action with UCB bonus
        let mut best_action = 0;
        let mut best_value = f32::NEG_INFINITY;
        let log_total_visits = (self.total_updates.max(1) as f32).ln();

        for action in 0..self.action_dim {
            let q_val = self.get_q(state, action);
            let visits = self.visit_counts[state * self.action_dim + action].max(1) as f32;

            // UCB exploration bonus
            let ucb_bonus = (2.0 * log_total_visits / visits).sqrt() * 0.5;
            let value = q_val + ucb_bonus;

            if value > best_value {
                best_value = value;
                best_action = action;
            }
        }

        best_action
    }

    /// Reward for an interaction: base reward + time bonus (up to 0.2) + position bonus (up to 0.1).
    fn reward(interaction: &UserInteraction) -> f32 {
        let base_reward = interaction.interaction.to_reward();
        // `f32::max` returns the non-NaN operand, so NaN and negative durations add no bonus.
        let time_bonus = (interaction.time_spent / 60.0).max(0.0).min(1.0) * 0.2;
        let position_bonus = (1.0 - f32::from(interaction.position) / 10.0).max(0.0) * 0.1;
        base_reward + time_bonus + position_bonus
    }

    /// Update the Q-value of `(state, action)` from an observed interaction.
    ///
    /// The update is `Q(s,a) += alpha * (r + gamma * max_a' Q(s',a') - Q(s,a))`.
    /// An `action` outside `0..action_dim` is ignored entirely: no Q-value, visit count,
    /// update counter or exploration rate changes.
    pub fn update(
        &mut self,
        state_embedding: &[f32],
        action: usize,
        interaction: &UserInteraction,
        next_state_embedding: &[f32],
    ) {
        if action >= self.action_dim {
            return;
        }

        let state = self.discretize_state(state_embedding);
        let next_state = self.discretize_state(next_state_embedding);
        let reward = Self::reward(interaction);

        // Find max Q-value for next state
        let mut max_next_q = (0..self.action_dim)
            .map(|a| self.get_q(next_state, a))
            .fold(f32::NEG_INFINITY, f32::max);
        if max_next_q == f32::NEG_INFINITY {
            max_next_q = 0.0;
        }

        // Q-learning update
        let current_q = self.get_q(state, action);
        let td_target = reward + self.discount * max_next_q;
        let new_q = current_q + self.learning_rate * (td_target - current_q);

        self.set_q(state, action, new_q);
        self.total_updates += 1;

        // Decay exploration over time. The floor never exceeds the current rate, so
        // an epsilon configured below 0.01 (e.g. 0.0 for pure greedy) is not raised.
        if self.total_updates % 100 == 0 {
            let floor = self.exploration.min(0.01);
            self.exploration = (self.exploration * 0.99).max(floor);
        }
    }

    /// Get action rankings for a state: action indices sorted by descending Q-value
    /// (ties and NaN keep their original relative order).
    pub fn rank_actions(&self, state_embedding: &[f32]) -> Vec<usize> {
        let state = self.discretize_state(state_embedding);

        let mut action_values: Vec<(usize, f32)> = (0..self.action_dim)
            .map(|a| (a, self.get_q(state, a)))
            .collect();

        action_values.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        action_values.into_iter().map(|(a, _)| a).collect()
    }

    /// Serialize Q-table to bytes for persistence.
    ///
    /// Layout is little-endian: a 28-byte header followed by the Q-table as `f32`s.
    /// The header holds state_dim (u32), action_dim (u32), learning_rate, discount and
    /// exploration (f32 each), then total_updates (u64).
    /// Visit counts are not stored, so the UCB statistics restart after [`QLearner::deserialize`].
    pub fn serialize(&self) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(self.q_table.len() * 4 + 32);

        // Header
        bytes.extend_from_slice(&(self.state_dim as u32).to_le_bytes());
        bytes.extend_from_slice(&(self.action_dim as u32).to_le_bytes());
        bytes.extend_from_slice(&self.learning_rate.to_le_bytes());
        bytes.extend_from_slice(&self.discount.to_le_bytes());
        bytes.extend_from_slice(&self.exploration.to_le_bytes());
        bytes.extend_from_slice(&self.total_updates.to_le_bytes());

        // Q-table
        for &q in &self.q_table {
            bytes.extend_from_slice(&q.to_le_bytes());
        }

        bytes
    }

    /// Read a little-endian `u32` at `offset` (caller guarantees `offset + 4 <= bytes.len()`).
    #[inline]
    fn read_u32_le(bytes: &[u8], offset: usize) -> u32 {
        u32::from_le_bytes([
            bytes[offset],
            bytes[offset + 1],
            bytes[offset + 2],
            bytes[offset + 3],
        ])
    }

    /// Deserialize Q-table from bytes produced by [`QLearner::serialize`].
    ///
    /// Returns `None` in three cases:
    /// - the input is shorter than the header plus the table it declares;
    /// - the declared `state_dim` is zero;
    /// - the declared table size overflows `usize`.
    ///
    /// Trailing bytes are ignored. Visit counts start at zero.
    pub fn deserialize(bytes: &[u8]) -> Option<Self> {
        // Header: 4+4+4+4+4+8 = 28 bytes
        const HEADER_SIZE: usize = 28;

        if bytes.len() < HEADER_SIZE {
            return None;
        }

        let state_dim = Self::read_u32_le(bytes, 0) as usize;
        let action_dim = Self::read_u32_le(bytes, 4) as usize;
        let learning_rate = f32::from_bits(Self::read_u32_le(bytes, 8));
        let discount = f32::from_bits(Self::read_u32_le(bytes, 12));
        let exploration = f32::from_bits(Self::read_u32_le(bytes, 16));
        let total_updates = u64::from(Self::read_u32_le(bytes, 20))
            | (u64::from(Self::read_u32_le(bytes, 24)) << 32);

        // A zero bucket count would make `discretize_state` divide by zero.
        if state_dim == 0 {
            return None;
        }

        // The dimensions come from untrusted bytes; use checked math so a corrupt header
        // cannot overflow (usize is only 32 bits on wasm32).
        let table_size = state_dim.checked_mul(action_dim)?;
        let expected_len = table_size.checked_mul(4)?.checked_add(HEADER_SIZE)?;
        if bytes.len() < expected_len {
            return None;
        }

        let q_table: Vec<f32> = bytes[HEADER_SIZE..expected_len]
            .chunks_exact(4)
            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect();

        Some(Self {
            q_table,
            learning_rate,
            discount,
            exploration,
            state_dim,
            action_dim,
            visit_counts: vec![0; table_size],
            total_updates,
        })
    }

    /// Get current exploration rate
    pub fn exploration_rate(&self) -> f32 {
        self.exploration
    }

    /// Get total number of updates
    pub fn update_count(&self) -> u64 {
        self.total_updates
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_qlearner_creation() {
        let learner = QLearner::new(50);
        assert_eq!(learner.action_dim, 50);
    }

    #[test]
    fn test_action_selection() {
        let learner = QLearner::new(10);
        let state = vec![0.5; 64];
        let action = learner.select_action(&state, 42);
        assert!(action < 10);
    }

    #[test]
    fn test_serialization_roundtrip() {
        let mut learner = QLearner::with_params(10, 0.1, 0.9, 0.2);

        // Do some updates
        let state = vec![0.5; 64];
        let interaction = UserInteraction {
            content_id: 1,
            interaction: InteractionType::Like,
            time_spent: 30.0,
            position: 0,
        };
        learner.update(&state, 0, &interaction, &state);

        // Serialize and deserialize
        let bytes = learner.serialize();
        let restored = QLearner::deserialize(&bytes).unwrap();

        assert_eq!(restored.action_dim, learner.action_dim);
        assert_eq!(restored.total_updates, learner.total_updates);
    }
}