//! Mamba SSM (Selective State Space Model) Attention Mechanism //! //! Implements the Mamba architecture's selective scan mechanism for efficient //! sequence modeling with linear time complexity O(n). //! //! Key Features: //! - **Selective Scan**: Input-dependent state transitions //! - **Linear Complexity**: O(n) vs O(n^2) for standard attention //! - **Hardware Efficient**: Optimized for parallel scan operations //! - **Long Context**: Handles very long sequences efficiently //! //! ## Architecture //! //! Mamba uses a selective state space model: //! ```text //! h_t = A_t * h_{t-1} + B_t * x_t //! y_t = C_t * h_t //! ``` //! //! Where A_t, B_t, C_t are input-dependent (selective), computed from x_t. //! //! ## References //! //! - Mamba: Linear-Time Sequence Modeling with Selective State Spaces (Gu & Dao, 2023) use serde::{Deserialize, Serialize}; use wasm_bindgen::prelude::*; // ============================================================================ // Configuration // ============================================================================ /// Configuration for Mamba SSM attention #[derive(Debug, Clone, Serialize, Deserialize)] #[wasm_bindgen] pub struct MambaConfig { /// Model dimension (d_model) pub dim: usize, /// State space dimension (n) pub state_dim: usize, /// Expansion factor for inner dimension pub expand_factor: usize, /// Convolution kernel size pub conv_kernel_size: usize, /// Delta (discretization step) range minimum pub dt_min: f32, /// Delta range maximum pub dt_max: f32, /// Whether to use learnable D skip connection pub use_d_skip: bool, } #[wasm_bindgen] impl MambaConfig { /// Create a new Mamba configuration #[wasm_bindgen(constructor)] pub fn new(dim: usize) -> MambaConfig { MambaConfig { dim, state_dim: 16, expand_factor: 2, conv_kernel_size: 4, dt_min: 0.001, dt_max: 0.1, use_d_skip: true, } } /// Set state space dimension #[wasm_bindgen(js_name = withStateDim)] pub fn with_state_dim(mut self, state_dim: usize) -> MambaConfig { self.state_dim = state_dim; self } /// Set expansion factor #[wasm_bindgen(js_name = withExpandFactor)] pub fn with_expand_factor(mut self, factor: usize) -> MambaConfig { self.expand_factor = factor; self } /// Set convolution kernel size #[wasm_bindgen(js_name = withConvKernelSize)] pub fn with_conv_kernel_size(mut self, size: usize) -> MambaConfig { self.conv_kernel_size = size; self } } impl Default for MambaConfig { fn default() -> Self { MambaConfig::new(256) } } // ============================================================================ // State Space Parameters // ============================================================================ /// Selective state space parameters (input-dependent) #[derive(Debug, Clone, Serialize, Deserialize)] struct SelectiveSSMParams { /// Discretized A matrix diagonal (batch, seq_len, state_dim) a_bar: Vec>>, /// Discretized B matrix (batch, seq_len, state_dim) b_bar: Vec>>, /// Output projection C (batch, seq_len, state_dim) c: Vec>>, /// Discretization step delta (batch, seq_len, inner_dim) delta: Vec>>, } // ============================================================================ // Mamba SSM Attention // ============================================================================ /// Mamba Selective State Space Model for sequence attention /// /// Provides O(n) attention-like mechanism using selective state spaces #[wasm_bindgen] pub struct MambaSSMAttention { config: MambaConfig, /// Inner dimension after expansion inner_dim: usize, /// A parameter (state_dim,) - diagonal of continuous A a_log: Vec, /// D skip connection (inner_dim,) d_skip: Vec, /// Projection weights (simplified for WASM) in_proj: Vec>, out_proj: Vec>, } #[wasm_bindgen] impl MambaSSMAttention { /// Create a new Mamba SSM attention layer #[wasm_bindgen(constructor)] pub fn new(config: MambaConfig) -> MambaSSMAttention { let inner_dim = config.dim * config.expand_factor; // Initialize A as negative values (for stability) - log of eigenvalues let a_log: Vec = (0..config.state_dim) .map(|i| -((i + 1) as f32).ln()) .collect(); // D skip connection let d_skip = vec![1.0; inner_dim]; // Simplified projection matrices (identity-like for stub) let in_proj: Vec> = (0..inner_dim) .map(|i| { let mut row = vec![0.0; config.dim]; if i < config.dim { row[i] = 1.0; } row }) .collect(); let out_proj: Vec> = (0..config.dim) .map(|i| { let mut row = vec![0.0; inner_dim]; if i < inner_dim { row[i] = 1.0; } row }) .collect(); MambaSSMAttention { config, inner_dim, a_log, d_skip, in_proj, out_proj, } } /// Create with default configuration #[wasm_bindgen(js_name = withDefaults)] pub fn with_defaults(dim: usize) -> MambaSSMAttention { MambaSSMAttention::new(MambaConfig::new(dim)) } /// Forward pass through Mamba SSM /// /// # Arguments /// * `input` - Input sequence (seq_len, dim) flattened to 1D /// * `seq_len` - Sequence length /// /// # Returns /// Output sequence (seq_len, dim) flattened to 1D #[wasm_bindgen] pub fn forward(&self, input: Vec, seq_len: usize) -> Result, JsError> { let dim = self.config.dim; if input.len() != seq_len * dim { return Err(JsError::new(&format!( "Input size mismatch: expected {} ({}x{}), got {}", seq_len * dim, seq_len, dim, input.len() ))); } // Reshape input to 2D let input_2d: Vec> = (0..seq_len) .map(|t| input[t * dim..(t + 1) * dim].to_vec()) .collect(); // Step 1: Input projection to inner_dim let projected = self.project_in(&input_2d); // Step 2: Compute selective SSM parameters from input let ssm_params = self.compute_selective_params(&projected); // Step 3: Run selective scan let ssm_output = self.selective_scan(&projected, &ssm_params); // Step 4: Apply D skip connection let with_skip: Vec> = ssm_output .iter() .zip(projected.iter()) .map(|(y, x)| { y.iter() .zip(x.iter()) .zip(self.d_skip.iter()) .map(|((yi, xi), di)| yi + di * xi) .collect() }) .collect(); // Step 5: Output projection let output = self.project_out(&with_skip); // Flatten output Ok(output.into_iter().flatten().collect()) } /// Get the configuration #[wasm_bindgen(getter)] pub fn config(&self) -> MambaConfig { self.config.clone() } /// Get the inner dimension #[wasm_bindgen(getter, js_name = innerDim)] pub fn inner_dim(&self) -> usize { self.inner_dim } /// Compute attention-like scores (for visualization/analysis) /// /// Returns pseudo-attention scores showing which positions influence output #[wasm_bindgen(js_name = getAttentionScores)] pub fn get_attention_scores( &self, input: Vec, seq_len: usize, ) -> Result, JsError> { let dim = self.config.dim; if input.len() != seq_len * dim { return Err(JsError::new(&format!( "Input size mismatch: expected {}, got {}", seq_len * dim, input.len() ))); } // Compute approximate attention scores based on state decay // This shows how much each position can "attend to" previous positions let mut scores = vec![0.0f32; seq_len * seq_len]; for t in 0..seq_len { for s in 0..=t { // Exponential decay based on distance and A parameters let distance = (t - s) as f32; let decay: f32 = self .a_log .iter() .map(|&a| (a * distance).exp()) .sum::() / self.config.state_dim as f32; scores[t * seq_len + s] = decay; } } Ok(scores) } } // Internal implementation methods impl MambaSSMAttention { /// Project input from dim to inner_dim fn project_in(&self, input: &[Vec]) -> Vec> { input .iter() .map(|x| { self.in_proj .iter() .map(|row| row.iter().zip(x.iter()).map(|(w, xi)| w * xi).sum()) .collect() }) .collect() } /// Project from inner_dim back to dim fn project_out(&self, input: &[Vec]) -> Vec> { input .iter() .map(|x| { self.out_proj .iter() .map(|row| row.iter().zip(x.iter()).map(|(w, xi)| w * xi).sum()) .collect() }) .collect() } /// Compute selective SSM parameters from input fn compute_selective_params(&self, input: &[Vec]) -> SelectiveSSMParams { let seq_len = input.len(); let state_dim = self.config.state_dim; // Compute input-dependent delta, B, C // Simplified: use sigmoid/tanh of input projections let mut a_bar = vec![vec![vec![0.0; state_dim]; self.inner_dim]; seq_len]; let mut b_bar = vec![vec![vec![0.0; state_dim]; self.inner_dim]; seq_len]; let mut c = vec![vec![vec![0.0; state_dim]; self.inner_dim]; seq_len]; let mut delta = vec![vec![vec![0.0; self.inner_dim]; 1]; seq_len]; for (t, x) in input.iter().enumerate() { // Compute delta from input (softplus of projection) let dt: Vec = x .iter() .map(|&xi| { let raw = xi * 0.1; // Simple scaling let dt_val = (1.0 + raw.exp()).ln(); // Softplus dt_val.clamp(self.config.dt_min, self.config.dt_max) }) .collect(); delta[t][0] = dt.clone(); for d in 0..self.inner_dim.min(x.len()) { let dt_d = dt[d.min(dt.len() - 1)]; for n in 0..state_dim { // Discretize A: A_bar = exp(delta * A) let a_continuous = self.a_log[n].exp(); // Negative a_bar[t][d][n] = (dt_d * a_continuous).exp(); // Discretize B: B_bar = delta * B (simplified) // B is input-dependent let b_input = if d < x.len() { x[d] } else { 0.0 }; b_bar[t][d][n] = dt_d * Self::sigmoid(b_input * 0.1); // C is input-dependent c[t][d][n] = Self::tanh(b_input * 0.1); } } } SelectiveSSMParams { a_bar, b_bar, c, delta, } } /// Run selective scan (parallel associative scan in practice) fn selective_scan(&self, input: &[Vec], params: &SelectiveSSMParams) -> Vec> { let seq_len = input.len(); let state_dim = self.config.state_dim; // Initialize hidden state let mut hidden = vec![vec![0.0f32; state_dim]; self.inner_dim]; let mut output = vec![vec![0.0f32; self.inner_dim]; seq_len]; for t in 0..seq_len { for d in 0..self.inner_dim { let x_d = if d < input[t].len() { input[t][d] } else { 0.0 }; // Update hidden state: h_t = A_bar * h_{t-1} + B_bar * x_t for n in 0..state_dim { hidden[d][n] = params.a_bar[t][d][n] * hidden[d][n] + params.b_bar[t][d][n] * x_d; } // Compute output: y_t = C * h_t output[t][d] = hidden[d] .iter() .zip(params.c[t][d].iter()) .map(|(h, c)| h * c) .sum(); } } output } #[inline] fn sigmoid(x: f32) -> f32 { 1.0 / (1.0 + (-x).exp()) } #[inline] fn tanh(x: f32) -> f32 { x.tanh() } } // ============================================================================ // Hybrid Mamba-Attention // ============================================================================ /// Hybrid layer combining Mamba SSM with standard attention /// /// Uses Mamba for long-range dependencies and attention for local patterns #[wasm_bindgen] pub struct HybridMambaAttention { mamba: MambaSSMAttention, local_window: usize, use_attention_for_local: bool, } #[wasm_bindgen] impl HybridMambaAttention { /// Create a new hybrid Mamba-Attention layer #[wasm_bindgen(constructor)] pub fn new(config: MambaConfig, local_window: usize) -> HybridMambaAttention { HybridMambaAttention { mamba: MambaSSMAttention::new(config), local_window, use_attention_for_local: true, } } /// Forward pass #[wasm_bindgen] pub fn forward(&self, input: Vec, seq_len: usize) -> Result, JsError> { let dim = self.mamba.config.dim; // Run Mamba for global context let mamba_output = self.mamba.forward(input.clone(), seq_len)?; // Apply local attention mixing (simplified) let mut output = mamba_output.clone(); if self.use_attention_for_local { for t in 0..seq_len { let start = t.saturating_sub(self.local_window / 2); let end = (t + self.local_window / 2 + 1).min(seq_len); // Simple local averaging for d in 0..dim { let mut local_sum = 0.0; let mut count = 0; for s in start..end { local_sum += input[s * dim + d]; count += 1; } // Mix global (Mamba) and local let local_avg = local_sum / count as f32; output[t * dim + d] = 0.7 * output[t * dim + d] + 0.3 * local_avg; } } } Ok(output) } /// Get local window size #[wasm_bindgen(getter, js_name = localWindow)] pub fn local_window(&self) -> usize { self.local_window } } // ============================================================================ // Tests // ============================================================================ #[cfg(test)] mod tests { use super::*; use wasm_bindgen_test::*; wasm_bindgen_test_configure!(run_in_browser); #[wasm_bindgen_test] fn test_mamba_config() { let config = MambaConfig::new(256); assert_eq!(config.dim, 256); assert_eq!(config.state_dim, 16); assert_eq!(config.expand_factor, 2); } #[wasm_bindgen_test] fn test_mamba_creation() { let config = MambaConfig::new(64); let mamba = MambaSSMAttention::new(config); assert_eq!(mamba.inner_dim(), 128); // 64 * 2 } #[wasm_bindgen_test] fn test_mamba_forward() { let config = MambaConfig::new(8); let mamba = MambaSSMAttention::new(config); // Input: 4 tokens of dimension 8 let input = vec![0.1f32; 32]; let output = mamba.forward(input, 4); assert!(output.is_ok()); let out = output.unwrap(); assert_eq!(out.len(), 32); // Same shape as input } #[wasm_bindgen_test] fn test_attention_scores() { let config = MambaConfig::new(8); let mamba = MambaSSMAttention::new(config); let input = vec![0.1f32; 24]; // 3 tokens let scores = mamba.get_attention_scores(input, 3); assert!(scores.is_ok()); let s = scores.unwrap(); assert_eq!(s.len(), 9); // 3x3 attention matrix // Causal: upper triangle should be 0 assert_eq!(s[0 * 3 + 1], 0.0); // t=0 cannot attend to t=1 assert_eq!(s[0 * 3 + 2], 0.0); // t=0 cannot attend to t=2 } #[wasm_bindgen_test] fn test_hybrid_mamba() { let config = MambaConfig::new(8); let hybrid = HybridMambaAttention::new(config, 4); let input = vec![0.5f32; 40]; // 5 tokens let output = hybrid.forward(input, 5); assert!(output.is_ok()); assert_eq!(output.unwrap().len(), 40); } }