// wifi-densepose/vendor/ruvector/examples/wasm/ios/src/qlearning.rs

//! Q-Learning Module for iOS WASM
//!
//! Lightweight reinforcement learning for adaptive recommendations.
//! Uses tabular Q-learning over hashed state buckets, so similar embedding
//! states generalize to the same table entries.
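//!
//! A minimal usage sketch (hypothetical embedding values; marked `ignore`
//! so it is not compiled as a doctest, since the hosting crate path is not
//! known from this file alone):
//!
//! ```ignore
//! let mut learner = QLearner::new(10);
//! let state = vec![0.5_f32; 64];
//! let action = learner.select_action(&state, 42);
//! let interaction = UserInteraction {
//!     content_id: 1,
//!     interaction: InteractionType::Like,
//!     time_spent: 30.0,
//!     position: 0,
//! };
//! learner.update(&state, action, &interaction, &state);
//! ```
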
/// Maximum number of actions (content recommendations)
const MAX_ACTIONS: usize = 100;

/// State discretization buckets
const STATE_BUCKETS: usize = 16;

/// User interaction types
#[derive(Clone, Copy, Debug, PartialEq)]
#[repr(u8)]
pub enum InteractionType {
    /// User viewed content
    View = 0,
    /// User liked/saved content
    Like = 1,
    /// User shared content
    Share = 2,
    /// User skipped content
    Skip = 3,
    /// User completed content (video/audio)
    Complete = 4,
    /// User dismissed/hid content
    Dismiss = 5,
}

impl InteractionType {
    /// Convert an interaction to a reward signal
    #[inline]
    pub fn to_reward(self) -> f32 {
        match self {
            InteractionType::View => 0.1,
            InteractionType::Like => 0.8,
            InteractionType::Share => 1.0,
            InteractionType::Skip => -0.1,
            InteractionType::Complete => 0.6,
            InteractionType::Dismiss => -0.5,
        }
    }
}

/// User interaction event
#[derive(Clone, Debug)]
pub struct UserInteraction {
    /// Content ID that was interacted with
    pub content_id: u64,
    /// Type of interaction
    pub interaction: InteractionType,
    /// Time spent in seconds
    pub time_spent: f32,
    /// Position in the recommendation list (0-indexed)
    pub position: u8,
}

/// Q-learning agent for personalized recommendations
pub struct QLearner {
    /// Q-values: state_bucket x action -> value
    q_table: Vec<f32>,
    /// Learning rate (alpha)
    learning_rate: f32,
    /// Discount factor (gamma)
    discount: f32,
    /// Exploration rate (epsilon)
    exploration: f32,
    /// Number of state buckets
    state_dim: usize,
    /// Number of actions
    action_dim: usize,
    /// Visit counts for UCB exploration
    visit_counts: Vec<u32>,
    /// Total updates
    total_updates: u64,
}

impl QLearner {
    /// Create a new Q-learner
    pub fn new(action_dim: usize) -> Self {
        // Clamp to 1..=MAX_ACTIONS so the `% action_dim` in `select_action`
        // can never divide by zero.
        let action_dim = action_dim.clamp(1, MAX_ACTIONS);
        let state_dim = STATE_BUCKETS;
        let table_size = state_dim * action_dim;
        Self {
            q_table: vec![0.0; table_size],
            learning_rate: 0.1,
            discount: 0.95,
            exploration: 0.1,
            state_dim,
            action_dim,
            visit_counts: vec![0; table_size],
            total_updates: 0,
        }
    }

    /// Create with custom hyperparameters
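    ///
    /// Sketch with illustrative hyperparameters (alpha, gamma, epsilon), each
    /// clamped to a safe range below; marked `ignore` as a doctest:
    ///
    /// ```ignore
    /// let learner = QLearner::with_params(50, 0.2, 0.9, 0.3);
    /// ```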
    pub fn with_params(
        action_dim: usize,
        learning_rate: f32,
        discount: f32,
        exploration: f32,
    ) -> Self {
        let mut learner = Self::new(action_dim);
        learner.learning_rate = learning_rate.clamp(0.001, 1.0);
        learner.discount = discount.clamp(0.0, 1.0);
        learner.exploration = exploration.clamp(0.0, 1.0);
        learner
    }

    /// Discretize a state embedding to a bucket index
    #[inline]
    fn discretize_state(&self, state_embedding: &[f32]) -> usize {
        if state_embedding.is_empty() {
            return 0;
        }
        // Hash the first few dimensions; clamping keeps the quantization
        // well-defined for embedding values outside [-1.0, 1.0].
        let mut hash: u32 = 0;
        for (i, &val) in state_embedding.iter().take(8).enumerate() {
            let quantized = ((val.clamp(-1.0, 1.0) + 1.0) * 127.0) as u32;
            hash = hash.wrapping_add(quantized << (i * 4));
        }
        (hash as usize) % self.state_dim
    }

    /// Get the Q-value for a state-action pair
    #[inline]
    fn get_q(&self, state: usize, action: usize) -> f32 {
        let idx = state * self.action_dim + action;
        if idx < self.q_table.len() {
            self.q_table[idx]
        } else {
            0.0
        }
    }

    /// Set the Q-value for a state-action pair, bumping its visit count
    #[inline]
    fn set_q(&mut self, state: usize, action: usize, value: f32) {
        let idx = state * self.action_dim + action;
        if idx < self.q_table.len() {
            self.q_table[idx] = value;
            self.visit_counts[idx] += 1;
        }
    }

    /// Select an action using epsilon-greedy with a UCB exploration bonus
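    ///
    /// A minimal call sketch (hypothetical embedding and seed; marked
    /// `ignore` so it is not compiled as a doctest):
    ///
    /// ```ignore
    /// let learner = QLearner::new(10);
    /// let action = learner.select_action(&[0.2_f32; 8], 12345);
    /// assert!(action < 10);
    /// ```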
    pub fn select_action(&self, state_embedding: &[f32], rng_seed: u32) -> usize {
        let state = self.discretize_state(state_embedding);
        // Epsilon-greedy exploration; the caller-supplied seed drives both
        // the explore/exploit decision and the random action choice.
        let explore_threshold = (rng_seed % 1000) as f32 / 1000.0;
        if explore_threshold < self.exploration {
            // Random action
            return (rng_seed as usize) % self.action_dim;
        }
        // Greedy action with UCB bonus
        let mut best_action = 0;
        let mut best_value = f32::NEG_INFINITY;
        let total_visits = self.total_updates.max(1) as f32;
        for action in 0..self.action_dim {
            let q_val = self.get_q(state, action);
            let visits = self.visit_counts[state * self.action_dim + action].max(1) as f32;
            // UCB exploration bonus
            let ucb_bonus = (2.0 * total_visits.ln() / visits).sqrt() * 0.5;
            let value = q_val + ucb_bonus;
            if value > best_value {
                best_value = value;
                best_action = action;
            }
        }
        best_action
    }

    /// Update Q-value based on interaction
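    ///
    /// Applies the standard Q-learning (TD) update with a shaped reward:
    ///
    /// `Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))`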
    pub fn update(
        &mut self,
        state_embedding: &[f32],
        action: usize,
        interaction: &UserInteraction,
        next_state_embedding: &[f32],
    ) {
        let state = self.discretize_state(state_embedding);
        let next_state = self.discretize_state(next_state_embedding);
        // Compute shaped reward: base signal plus dwell-time and position bonuses
        let base_reward = interaction.interaction.to_reward();
        let time_bonus = (interaction.time_spent / 60.0).min(1.0) * 0.2;
        let position_bonus = (1.0 - interaction.position as f32 / 10.0).max(0.0) * 0.1;
        let reward = base_reward + time_bonus + position_bonus;
        // Find max Q-value for the next state
        let mut max_next_q = f32::NEG_INFINITY;
        for a in 0..self.action_dim {
            let q = self.get_q(next_state, a);
            if q > max_next_q {
                max_next_q = q;
            }
        }
        if max_next_q == f32::NEG_INFINITY {
            max_next_q = 0.0;
        }
        // Q-learning update
        let current_q = self.get_q(state, action);
        let td_target = reward + self.discount * max_next_q;
        let new_q = current_q + self.learning_rate * (td_target - current_q);
        self.set_q(state, action, new_q);
        self.total_updates += 1;
        // Decay exploration over time
        if self.total_updates % 100 == 0 {
            self.exploration = (self.exploration * 0.99).max(0.01);
        }
    }

    /// Get action rankings for a state (action indices sorted by Q-value, best first)
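    ///
    /// Sketch of picking a top-k slate (hypothetical `state`; marked `ignore`
    /// so it is not compiled as a doctest):
    ///
    /// ```ignore
    /// let ranked = learner.rank_actions(&state);
    /// let top5: Vec<usize> = ranked.into_iter().take(5).collect();
    /// ```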
    pub fn rank_actions(&self, state_embedding: &[f32]) -> Vec<usize> {
        let state = self.discretize_state(state_embedding);
        let mut action_values: Vec<(usize, f32)> = (0..self.action_dim)
            .map(|a| (a, self.get_q(state, a)))
            .collect();
        // Sort descending; NaN comparisons fall back to `Equal`
        action_values.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(core::cmp::Ordering::Equal));
        action_values.into_iter().map(|(a, _)| a).collect()
    }

    /// Serialize the Q-table to bytes for persistence
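    ///
    /// Byte layout (little-endian), mirrored by `deserialize`:
    ///
    /// | Offset | Size        | Field                  |
    /// |--------|-------------|------------------------|
    /// | 0      | 4           | `state_dim` (u32)      |
    /// | 4      | 4           | `action_dim` (u32)     |
    /// | 8      | 4           | `learning_rate` (f32)  |
    /// | 12     | 4           | `discount` (f32)       |
    /// | 16     | 4           | `exploration` (f32)    |
    /// | 20     | 8           | `total_updates` (u64)  |
    /// | 28     | 4 per entry | `q_table` (f32 each)   |
    ///
    /// Note: `visit_counts` are not persisted and restart from zero on load.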
    pub fn serialize(&self) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(self.q_table.len() * 4 + 32);
        // Header
        bytes.extend_from_slice(&(self.state_dim as u32).to_le_bytes());
        bytes.extend_from_slice(&(self.action_dim as u32).to_le_bytes());
        bytes.extend_from_slice(&self.learning_rate.to_le_bytes());
        bytes.extend_from_slice(&self.discount.to_le_bytes());
        bytes.extend_from_slice(&self.exploration.to_le_bytes());
        bytes.extend_from_slice(&self.total_updates.to_le_bytes());
        // Q-table
        for &q in &self.q_table {
            bytes.extend_from_slice(&q.to_le_bytes());
        }
        bytes
    }

    /// Deserialize a Q-table from bytes; returns `None` on malformed input
    pub fn deserialize(bytes: &[u8]) -> Option<Self> {
        // Header: 4 + 4 + 4 + 4 + 4 + 8 = 28 bytes
        const HEADER_SIZE: usize = 28;
        if bytes.len() < HEADER_SIZE {
            return None;
        }
        let state_dim = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
        let action_dim = u32::from_le_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]) as usize;
        let learning_rate = f32::from_le_bytes([bytes[8], bytes[9], bytes[10], bytes[11]]);
        let discount = f32::from_le_bytes([bytes[12], bytes[13], bytes[14], bytes[15]]);
        let exploration = f32::from_le_bytes([bytes[16], bytes[17], bytes[18], bytes[19]]);
        let total_updates = u64::from_le_bytes([
            bytes[20], bytes[21], bytes[22], bytes[23],
            bytes[24], bytes[25], bytes[26], bytes[27],
        ]);
        // Reject dimensions this module could not have produced; this also
        // guards against oversized allocations from corrupt input.
        if state_dim == 0
            || state_dim > STATE_BUCKETS
            || action_dim == 0
            || action_dim > MAX_ACTIONS
        {
            return None;
        }
        let table_size = state_dim * action_dim;
        let expected_len = HEADER_SIZE + table_size * 4;
        if bytes.len() < expected_len {
            return None;
        }
        let mut q_table = Vec::with_capacity(table_size);
        for i in 0..table_size {
            let offset = HEADER_SIZE + i * 4;
            let q = f32::from_le_bytes([
                bytes[offset], bytes[offset + 1], bytes[offset + 2], bytes[offset + 3],
            ]);
            q_table.push(q);
        }
        Some(Self {
            q_table,
            learning_rate,
            discount,
            exploration,
            state_dim,
            action_dim,
            // Visit counts are not persisted; they restart from zero
            visit_counts: vec![0; table_size],
            total_updates,
        })
    }

    /// Get the current exploration rate
    pub fn exploration_rate(&self) -> f32 {
        self.exploration
    }

    /// Get the total number of updates
    pub fn update_count(&self) -> u64 {
        self.total_updates
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_qlearner_creation() {
        let learner = QLearner::new(50);
        assert_eq!(learner.action_dim, 50);
    }

    #[test]
    fn test_action_selection() {
        let learner = QLearner::new(10);
        let state = vec![0.5; 64];
        let action = learner.select_action(&state, 42);
        assert!(action < 10);
    }

    #[test]
    fn test_serialization_roundtrip() {
        let mut learner = QLearner::with_params(10, 0.1, 0.9, 0.2);
        // Do some updates
        let state = vec![0.5; 64];
        let interaction = UserInteraction {
            content_id: 1,
            interaction: InteractionType::Like,
            time_spent: 30.0,
            position: 0,
        };
        learner.update(&state, 0, &interaction, &state);
        // Serialize and deserialize
        let bytes = learner.serialize();
        let restored = QLearner::deserialize(&bytes).unwrap();
        assert_eq!(restored.action_dim, learner.action_dim);
        assert_eq!(restored.total_updates, learner.total_updates);
    }
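
    // Additional checks sketched against the API defined above; the expected
    // values follow directly from the update and deserialize logic in this file.
    #[test]
    fn test_update_moves_q_toward_reward() {
        // gamma = 0 makes the TD target equal the shaped reward, so one update
        // must lift the chosen action's Q-value above its initial 0.0.
        let mut learner = QLearner::with_params(4, 0.5, 0.0, 0.0);
        let state = vec![0.25; 8];
        let interaction = UserInteraction {
            content_id: 7,
            interaction: InteractionType::Share,
            time_spent: 0.0,
            position: 0,
        };
        learner.update(&state, 1, &interaction, &state);
        let ranked = learner.rank_actions(&state);
        assert_eq!(ranked[0], 1);
    }

    #[test]
    fn test_deserialize_rejects_truncated_input() {
        let learner = QLearner::new(10);
        let bytes = learner.serialize();
        // Cutting the buffer short must fail cleanly rather than panic
        assert!(QLearner::deserialize(&bytes[..10]).is_none());
    }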
}