wifi-densepose/crates/ruvector-mincut-gated-transformer/examples/flash_attention_demo.rs
//! FlashAttention demonstration
//!
//! Shows how to use FlashAttention-style tiled attention for CPU inference.

use ruvector_mincut_gated_transformer::flash_attention::{
    flash_attention_forward, flash_attention_forward_i8, flash_mha, FlashAttentionConfig,
};

fn main() {
println!("=== FlashAttention CPU Demo ===\n");
// Configuration for 64-dim attention head
let config = FlashAttentionConfig::for_head_dim(64);
println!("Configuration:");
println!(" Block size (Q): {}", config.block_size_q);
println!(" Block size (KV): {}", config.block_size_kv);
println!(" Head dimension: {}", config.head_dim);
println!(" Causal masking: {}", config.causal);
println!(" Softmax scale: {:.4}\n", config.softmax_scale);

    // Example 1: Single-head attention
    {
        println!("Example 1: Single-head attention (128 tokens, 64 dims)");
        let seq_len = 128;
        let head_dim = 64;

        // Create random-like input (deterministic for demo)
        let q: Vec<f32> = (0..seq_len * head_dim)
            .map(|i| ((i % 100) as f32) * 0.01)
            .collect();
        let k: Vec<f32> = (0..seq_len * head_dim)
            .map(|i| ((i % 100) as f32) * 0.01)
            .collect();
        let v: Vec<f32> = (0..seq_len * head_dim)
            .map(|i| ((i % 100) as f32) * 0.01)
            .collect();
        let mut output = vec![0.0f32; seq_len * head_dim];

        flash_attention_forward(&config, &q, &k, &v, seq_len, seq_len, &mut output);

        println!(" ✓ Computed attention output: {} elements", output.len());
        println!(" ✓ First 5 output values: {:?}\n", &output[0..5]);
    }

    // Example 2: Multi-head attention
    {
        println!("Example 2: Multi-head attention (8 heads, 64 tokens, 64 dims)");
        let num_heads = 8;
        let seq_len = 64;
        let head_dim = 64;
        let total_size = num_heads * seq_len * head_dim;

        let q: Vec<f32> = (0..total_size).map(|i| ((i % 100) as f32) * 0.01).collect();
        let k: Vec<f32> = (0..total_size).map(|i| ((i % 100) as f32) * 0.01).collect();
        let v: Vec<f32> = (0..total_size).map(|i| ((i % 100) as f32) * 0.01).collect();
        let mut output = vec![0.0f32; total_size];

        flash_mha(
            &config,
            &q,
            &k,
            &v,
            num_heads,
            seq_len,
            seq_len,
            &mut output,
        );

        println!(
            " ✓ Computed multi-head attention: {} elements",
            output.len()
        );
        println!(" ✓ Output per head: {} elements", seq_len * head_dim);
        println!(" ✓ First 5 output values: {:?}\n", &output[0..5]);
    }

    // Example 3: INT8 quantized attention
    {
        println!("Example 3: INT8 quantized attention (64 tokens, 64 dims)");
        let seq_len = 64;
        let head_dim = 64;

        // Create FP32 data and quantize to INT8
        let q_f32: Vec<f32> = (0..seq_len * head_dim)
            .map(|i| ((i % 100) as f32) * 0.01)
            .collect();
        let k_f32: Vec<f32> = (0..seq_len * head_dim)
            .map(|i| ((i % 100) as f32) * 0.01)
            .collect();
        let v_f32: Vec<f32> = (0..seq_len * head_dim)
            .map(|i| ((i % 100) as f32) * 0.01)
            .collect();

        // Quantization scales
        let q_scale = 0.01f32;
        let k_scale = 0.01f32;
        let v_scale = 0.01f32;
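        // With inputs in [0.0, 0.99], a scale of 0.01 maps values onto the
        // integer range [0, 99], comfortably inside i8's [-128, 127]; dividing
        // by the scale and rounding is the symmetric quantization used below.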
        // Quantize to INT8
        let q_i8: Vec<i8> = q_f32
            .iter()
            .map(|&x| (x / q_scale).round().clamp(-128.0, 127.0) as i8)
            .collect();
        let k_i8: Vec<i8> = k_f32
            .iter()
            .map(|&x| (x / k_scale).round().clamp(-128.0, 127.0) as i8)
            .collect();
        let v_i8: Vec<i8> = v_f32
            .iter()
            .map(|&x| (x / v_scale).round().clamp(-128.0, 127.0) as i8)
            .collect();
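        // Round-trip sketch: dequantizing (int8 value * scale) should recover
        // the FP32 input to within half a quantization step (0.005 here).
        let max_q_err = q_f32
            .iter()
            .zip(&q_i8)
            .map(|(&x, &xi)| (x - xi as f32 * q_scale).abs())
            .fold(0.0f32, f32::max);
        println!(" ✓ Max Q round-trip error: {:.4}", max_q_err);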
        let mut output = vec![0.0f32; seq_len * head_dim];

        flash_attention_forward_i8(
            &config,
            &q_i8,
            &k_i8,
            &v_i8,
            q_scale,
            k_scale,
            v_scale,
            seq_len,
            seq_len,
            &mut output,
        );

        println!(" ✓ Computed INT8 quantized attention");
        println!(" ✓ Memory savings: 4× (INT8 vs FP32)");
        println!(" ✓ First 5 output values: {:?}\n", &output[0..5]);
    }

    // Example 4: Configuration for long sequences
    {
        println!("Example 4: Optimized config for long sequences (512 tokens)");
        let long_config = FlashAttentionConfig::for_long_sequence(64);
        println!(
            " Block size (Q): {} (smaller for cache reuse)",
            long_config.block_size_q
        );
        println!(
            " Block size (KV): {} (larger for efficiency)",
            long_config.block_size_kv
        );

        let seq_len = 512;
        let head_dim = 64;
        let q: Vec<f32> = (0..seq_len * head_dim)
            .map(|i| ((i % 100) as f32) * 0.01)
            .collect();
        let k: Vec<f32> = (0..seq_len * head_dim)
            .map(|i| ((i % 100) as f32) * 0.01)
            .collect();
        let v: Vec<f32> = (0..seq_len * head_dim)
            .map(|i| ((i % 100) as f32) * 0.01)
            .collect();
        let mut output = vec![0.0f32; seq_len * head_dim];

        flash_attention_forward(&long_config, &q, &k, &v, seq_len, seq_len, &mut output);

        println!(" ✓ Computed attention for {} tokens", seq_len);
        println!(" ✓ Memory efficient: O(n) instead of O(n²)");
        println!(" ✓ Cache efficient: Tiled for L1/L2 cache\n");
    }
println!("=== All examples completed successfully! ===");
}