Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
This commit is contained in:
@@ -0,0 +1,116 @@
|
||||
//! Example demonstrating Mamba State Space Model usage.
|
||||
//!
|
||||
//! This example shows:
|
||||
//! 1. Creating and configuring a Mamba layer
|
||||
//! 2. Single-step (recurrent) inference
|
||||
//! 3. Sequence processing
|
||||
//! 4. State persistence across timesteps
|
||||
|
||||
use ruvector_mincut_gated_transformer::mamba::{MambaConfig, MambaLayer, MambaState, MambaWeights};
|
||||
|
||||
fn main() {
|
||||
println!("=== Mamba State Space Model Example ===\n");
|
||||
|
||||
// Create configuration
|
||||
let config = MambaConfig {
|
||||
d_model: 128,
|
||||
d_state: 16,
|
||||
d_conv: 4,
|
||||
expand: 2,
|
||||
dt_rank: 16,
|
||||
dt_min: 0.001,
|
||||
dt_max: 0.1,
|
||||
};
|
||||
|
||||
println!("Configuration:");
|
||||
println!(" Model dimension: {}", config.d_model);
|
||||
println!(" State dimension: {}", config.d_state);
|
||||
println!(" Inner dimension: {}", config.d_inner());
|
||||
println!(" Convolution width: {}", config.d_conv);
|
||||
println!();
|
||||
|
||||
// Create layer and initialize weights
|
||||
let layer = MambaLayer::new(config.clone());
|
||||
let weights = MambaWeights::empty(&config);
|
||||
|
||||
println!("Layer created with {} parameters", {
|
||||
let d_inner = config.d_inner();
|
||||
config.d_model * d_inner * 2 // in_proj
|
||||
+ d_inner * config.d_conv // conv1d
|
||||
+ d_inner * (config.dt_rank + config.d_state * 2) // x_proj
|
||||
+ config.dt_rank * d_inner // dt_proj
|
||||
+ d_inner * config.d_state // a_log
|
||||
+ d_inner // d
|
||||
+ d_inner * config.d_model // out_proj
|
||||
});
|
||||
println!();
|
||||
|
||||
// Example 1: Single-step inference
|
||||
println!("Example 1: Single-step inference");
|
||||
let mut state = MambaState::new(&config);
|
||||
let input = vec![0.1; config.d_model];
|
||||
|
||||
println!("Processing single token...");
|
||||
let output = layer.forward_step(&weights, &input, &mut state);
|
||||
println!(" Input shape: [{}]", input.len());
|
||||
println!(" Output shape: [{}]", output.len());
|
||||
println!(" State updated: {}", state.h.iter().any(|&x| x != 0.0));
|
||||
println!();
|
||||
|
||||
// Example 2: Sequential processing with state
|
||||
println!("Example 2: Sequential processing");
|
||||
let mut state = MambaState::new(&config);
|
||||
let sequence_length = 5;
|
||||
|
||||
for t in 0..sequence_length {
|
||||
let input = vec![0.1 * (t as f32 + 1.0); config.d_model];
|
||||
let output = layer.forward_step(&weights, &input, &mut state);
|
||||
println!(" Step {}: output[0] = {:.6}", t, output[0]);
|
||||
}
|
||||
println!();
|
||||
|
||||
// Example 3: Sequence mode
|
||||
println!("Example 3: Sequence mode (parallel)");
|
||||
let seq_len = 4;
|
||||
let input_seq = vec![0.2; seq_len * config.d_model];
|
||||
|
||||
println!("Processing sequence of length {}...", seq_len);
|
||||
let output_seq = layer.forward_sequence(&weights, &input_seq, seq_len);
|
||||
println!(" Input shape: [{}, {}]", seq_len, config.d_model);
|
||||
println!(" Output shape: [{}, {}]", seq_len, config.d_model);
|
||||
println!(" First output: {:.6}", output_seq[0]);
|
||||
println!();
|
||||
|
||||
// Example 4: State reset
|
||||
println!("Example 4: State persistence and reset");
|
||||
let mut state = MambaState::new(&config);
|
||||
let input1 = vec![0.5; config.d_model];
|
||||
let input2 = vec![0.3; config.d_model];
|
||||
|
||||
let out1 = layer.forward_step(&weights, &input1, &mut state);
|
||||
println!(" First forward: output[0] = {:.6}", out1[0]);
|
||||
|
||||
let out2 = layer.forward_step(&weights, &input2, &mut state);
|
||||
println!(" Second forward: output[0] = {:.6}", out2[0]);
|
||||
|
||||
state.reset();
|
||||
let out1_reset = layer.forward_step(&weights, &input1, &mut state);
|
||||
println!(" After reset: output[0] = {:.6}", out1_reset[0]);
|
||||
println!(
|
||||
" Matches first: {}",
|
||||
(out1[0] - out1_reset[0]).abs() < 1e-5
|
||||
);
|
||||
println!();
|
||||
|
||||
// Performance characteristics
|
||||
println!("Performance Characteristics:");
|
||||
println!(" Complexity per step: O(N) vs O(N²) for attention");
|
||||
println!(" Memory per step: O(1) vs O(N) for attention");
|
||||
println!(
|
||||
" State size: {} floats",
|
||||
state.h.len() + state.conv_state.len()
|
||||
);
|
||||
println!();
|
||||
|
||||
println!("=== Example Complete ===");
|
||||
}
|
||||
Reference in New Issue
Block a user