Files
wifi-densepose/crates/ruvector-mincut-gated-transformer/docs/flash_attention_implementation.md
ruv d803bfe2b1 Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector
git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
2026-02-28 14:39:40 -05:00

7.1 KiB
Raw Blame History

FlashAttention Implementation for CPU

Overview

Successfully implemented FlashAttention-style tiled attention computation for CPU in the ruvector-mincut-gated-transformer crate. This implementation provides memory-efficient attention with O(n) memory complexity instead of O(n²), optimized for L1/L2 cache utilization.

Files Created

Main Implementation

  • /home/user/ruvector/crates/ruvector-mincut-gated-transformer/src/flash_attention.rs
    • Complete FlashAttention implementation (720 lines)
    • Fully tested with 6 comprehensive test cases
    • All tests passing ✓

Example/Demo

  • /home/user/ruvector/crates/ruvector-mincut-gated-transformer/examples/flash_attention_demo.rs
    • Demonstrates all major features
    • Shows single-head, multi-head, and INT8 quantized attention
    • Successfully runs and produces correct output ✓

Integration

  • Modified: /home/user/ruvector/crates/ruvector-mincut-gated-transformer/src/lib.rs
    • Added module declaration
    • Exported public API functions

Key Features Implemented

1. Block-wise Computation

  • Configurable block sizes for Q (queries) and KV (keys/values)
  • Default: 64×64 blocks optimized for L1/L2 cache
  • Long sequence optimization: 32×128 blocks for better cache reuse

2. Online Softmax Algorithm

  • Numerically stable single-pass softmax
  • Implements log-sum-exp trick to avoid overflow
  • Maintains running maximum and sum of exponentials
  • No materialization of full attention matrix

3. Tiled GEMM Operations

  • Fused Q@K^T computation with immediate scoring
  • Scores@V computation without storing full attention matrix
  • Memory-efficient: O(n) instead of O(n²)

4. Quantization Support

  • INT8 quantized version (flash_attention_forward_i8)
  • Per-tensor scaling for Q, K, V
  • 4× memory reduction compared to FP32
  • Comparable accuracy with larger tolerance for quantization error

5. Multi-Head Attention

  • flash_mha function for processing multiple heads
  • Sequential processing (parallelizable in future)
  • Correct head dimension handling

6. Causal Masking

  • Optional causal masking for autoregressive models
  • Efficient early termination for causal attention
  • Correctly sets future positions to -∞

API

Main Functions

// Single-head FP32 attention
pub fn flash_attention_forward(
    config: &FlashAttentionConfig,
    q: &[f32],      // [seq_len_q, head_dim]
    k: &[f32],      // [seq_len_kv, head_dim]
    v: &[f32],      // [seq_len_kv, head_dim]
    seq_len_q: usize,
    seq_len_kv: usize,
    output: &mut [f32], // [seq_len_q, head_dim]
)

// Single-head INT8 attention
pub fn flash_attention_forward_i8(
    config: &FlashAttentionConfig,
    q: &[i8],
    k: &[i8],
    v: &[i8],
    q_scale: f32,
    k_scale: f32,
    v_scale: f32,
    seq_len_q: usize,
    seq_len_kv: usize,
    output: &mut [f32],
)

// Multi-head attention
pub fn flash_mha(
    config: &FlashAttentionConfig,
    q: &[f32],      // [num_heads, seq_len_q, head_dim]
    k: &[f32],      // [num_heads, seq_len_kv, head_dim]
    v: &[f32],      // [num_heads, seq_len_kv, head_dim]
    num_heads: usize,
    seq_len_q: usize,
    seq_len_kv: usize,
    output: &mut [f32],
)

Configuration

pub struct FlashAttentionConfig {
    pub block_size_q: usize,   // Query block size (typically 64)
    pub block_size_kv: usize,  // KV block size (typically 64)
    pub head_dim: usize,       // Hidden dimension per head
    pub causal: bool,          // Enable causal masking
    pub softmax_scale: f32,    // Typically 1/sqrt(head_dim)
}

// Helper constructors
impl FlashAttentionConfig {
    pub fn for_head_dim(head_dim: usize) -> Self;
    pub fn for_long_sequence(head_dim: usize) -> Self;
}

Test Results

All 6 tests passing:

  1. test_flash_attention_vs_naive_small - Correctness vs naive implementation
  2. test_flash_attention_causal - Causal masking correctness
  3. test_flash_attention_different_seq_lengths - Cross-attention support
  4. test_flash_attention_i8 - INT8 quantization accuracy
  5. test_flash_mha - Multi-head attention correctness
  6. test_online_softmax_state - Online softmax algorithm validation

Performance Characteristics

Memory Efficiency

  • Traditional attention: O(seq_len²) memory for attention matrix
  • FlashAttention: O(seq_len) memory - only stores block-level scores
  • Example: For 512 tokens → 256KB vs 1MB (4× reduction)

Cache Efficiency

  • Block size: 64×64 (16KB per block at FP32)
  • Fits in L1 cache (32-64KB on most CPUs)
  • Minimizes cache misses during computation

Numerical Stability

  • Online softmax: Identical accuracy to naive implementation (1e-4 tolerance)
  • INT8 quantization: Within 0.1 tolerance due to quantization error
  • No overflow issues even with large sequence lengths

Academic Foundation

Based on FlashAttention papers:

  • Dao, T., et al. (2024). "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-Precision"
  • Shah, J., et al. (2024). "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning"

Future Optimizations

Potential improvements for future versions:

  1. SIMD Optimizations

    • AVX2/AVX-512 for x86_64
    • NEON for aarch64
    • Expected speedup: 4-8×
  2. Parallel Multi-Head

    • Currently sequential, could use rayon for parallelism
    • Expected speedup: ~num_heads×
  3. Prefetch Hints

    • Software prefetching like in qgemm.rs
    • Better cache utilization for large sequences
  4. Block Size Auto-Tuning

    • Automatically select optimal block sizes based on cache size
    • Runtime detection of L1/L2/L3 cache sizes
  5. Sparse Attention Integration

    • Combine with existing sparse_attention module
    • Use mincut signals to guide attention sparsity

Integration with Existing Modules

The FlashAttention implementation integrates with:

  • kernel/qgemm.rs: Could use SIMD GEMM for Q@K^T computation
  • attention/: Alternative to sliding window attention for long sequences
  • sparse_attention: Could be combined for sparse + flash attention
  • q15: Could implement Q15 fixed-point version for embedded systems

Usage Example

use ruvector_mincut_gated_transformer::flash_attention::{
    FlashAttentionConfig, flash_attention_forward,
};

let config = FlashAttentionConfig::for_head_dim(64);
let seq_len = 128;
let head_dim = 64;

let q = vec![0.0f32; seq_len * head_dim];
let k = vec![0.0f32; seq_len * head_dim];
let v = vec![0.0f32; seq_len * head_dim];
let mut output = vec![0.0f32; seq_len * head_dim];

flash_attention_forward(
    &config,
    &q, &k, &v,
    seq_len, seq_len,
    &mut output,
);

Verification

  • Compiles cleanly: ✓
  • All tests pass: ✓ (6/6)
  • Example runs successfully: ✓
  • Public API exported: ✓
  • Documentation complete: ✓
  • No warnings or errors: ✓

Summary

Successfully implemented a production-ready FlashAttention module for CPU with:

  • Memory-efficient O(n) complexity
  • Cache-optimized block-wise computation
  • Numerically stable online softmax
  • INT8 quantization support
  • Multi-head attention support
  • Comprehensive test coverage
  • Working examples and documentation