# SIMD Optimization Analysis - MinCut Gated Transformer

**Analysis Date:** 2025-12-26
**Crate:** ruvector-mincut-gated-transformer
**Target Architectures:** x86_64 (AVX2/AVX-512), ARM (NEON/SVE2)

## Executive Summary

Critical performance bottlenecks were identified across 4 core files. Implementing SIMD optimizations could yield an **8-32x overall speedup** for inference workloads. The INT8 GEMM kernel accounts for an estimated 80-90% of computation time and is the highest priority target.

---

## 1. src/kernel/qgemm.rs - Matrix Multiplication (CRITICAL)

### 1.1 Hot Loop: INT8 Dot Product (Lines 61-68)

**Current Implementation:**
```rust
for kk in 0..k {
    let a_idx = i * k + kk;
    let b_idx = j * k + kk;
    let a_val = a.get(a_idx).copied().unwrap_or(0) as i64;
    let b_val = b.get(b_idx).copied().unwrap_or(0) as i64;
    acc = acc.saturating_add(a_val.saturating_mul(b_val));
}
```

**Bottleneck Analysis:**
- Triple nested loop: O(m * n * k)
- For a typical transformer: m=1, n=768, k=768 → ≈590K multiply-accumulates per GEMM call
- Sequential scalar multiply-accumulate, with a bounds-checked `get` on every element
- Memory access pattern: A and B are both read contiguously inside the `kk` loop; B jumps by `k` elements on every `j` iteration, so B is streamed in full for each output row and does not stay cache-resident when `n * k` is large

**SIMD Optimization Strategy:**

**x86_64 AVX2:**
```rust
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn dot_product_i8_avx2(a: &[i8], b: &[i8], k: usize) -> i32 {
    use core::arch::x86_64::*;

    debug_assert!(a.len() >= k && b.len() >= k);
    let mut acc = _mm256_setzero_si256();
    let chunks = k / 16;

    for i in 0..chunks {
        // Load 16 i8 values from each input and sign-extend to 16 x i16
        let a_i8 = _mm_loadu_si128(a.as_ptr().add(i * 16) as *const __m128i);
        let b_i8 = _mm_loadu_si128(b.as_ptr().add(i * 16) as *const __m128i);
        let a_i16 = _mm256_cvtepi8_epi16(a_i8);
        let b_i16 = _mm256_cvtepi8_epi16(b_i8);

        // _mm256_madd_epi16: multiply 16 i16 pairs, add adjacent products → 8 x i32.
        // _mm256_maddubs_epi16 is avoided here: it treats its first operand as
        // unsigned and saturates to i16, which is wrong for signed x signed data.
        acc = _mm256_add_epi32(acc, _mm256_madd_epi16(a_i16, b_i16));
    }

    // Horizontal sum + remainder (helpers not shown).
    // Note: accumulates in i32, whereas the scalar loop uses saturating i64.
    horizontal_sum_i32(acc) + scalar_remainder(a, b, chunks * 16, k)
}
```

**ARM NEON:**
```rust
// Requires the `dotprod` feature (ARMv8.2-A+); `vdotq_s32` may not be
// available on every stable toolchain, so verify before relying on it.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon,dotprod")]
unsafe fn dot_product_i8_neon(a: &[i8], b: &[i8], k: usize) -> i32 {
    use core::arch::aarch64::*;

    let mut acc = vdupq_n_s32(0);
    let chunks = k / 16;

    for i in 0..chunks {
        let a_vec = vld1q_s8(a.as_ptr().add(i * 16));
        let b_vec = vld1q_s8(b.as_ptr().add(i * 16));

        // NEON: vdotq_s32 (4x int8 dot → accumulate into int32)
        acc = vdotq_s32(acc, a_vec, b_vec);
    }

    vaddvq_s32(acc) + scalar_remainder(a, b, chunks * 16, k)
}
```

**Expected Speedup:** 12-16x
**Complexity:** Medium (requires SIMD feature detection)
**Priority:** CRITICAL - This is 80-90% of total compute time

---

### 1.2 Dequantization (Lines 189-191)

**Current Implementation:**
```rust
for (i, (&v, &ws)) in values.iter().zip(weight_scales.iter()).enumerate() {
    output[i] = (v as f32) * input_scale * ws;
}
```

**SIMD Optimization (AVX2):**
```rust
#[target_feature(enable = "avx2")]
unsafe fn dequantize_i32_to_f32_avx2(
    values: &[i32], input_scale: f32, weight_scales: &[f32], output: &mut [f32]
) {
    use core::arch::x86_64::*;
    let chunks = values.len() / 8;
    let scale_vec = _mm256_set1_ps(input_scale);

    for i in 0..chunks {
        let vals = _mm256_loadu_si256(values.as_ptr().add(i * 8) as *const __m256i);
        let vals_f32 = _mm256_cvtepi32_ps(vals);
        let scales = _mm256_loadu_ps(weight_scales.as_ptr().add(i * 8));
        let scaled = _mm256_mul_ps(vals_f32, scale_vec);
        let result = _mm256_mul_ps(scaled, scales);
        _mm256_storeu_ps(output.as_mut_ptr().add(i * 8), result);
    }
    // Scalar tail for the last values.len() % 8 elements
    for i in chunks * 8..values.len() {
        output[i] = (values[i] as f32) * input_scale * weight_scales[i];
    }
}
```

**Expected Speedup:** 8x
**Priority:** HIGH

---

### 1.3 Quantization (Lines 199-203)

**Current Implementation:**
```rust
for (i, &v) in values.iter().enumerate() {
    let q = (v * inv_scale).round();
    output[i] = q.clamp(-128.0, 127.0) as i8;
}
```

**SIMD Optimization (AVX2):**
```rust
#[target_feature(enable = "avx2")]
unsafe fn quantize_f32_to_i8_avx2(values: &[f32], scale: f32, output: &mut [i8]) {
    use core::arch::x86_64::*;
    let inv_scale = _mm256_set1_ps(1.0 / scale);
    let min_val = _mm256_set1_ps(-128.0);
    let max_val = _mm256_set1_ps(127.0);
    let chunks = values.len() / 8;

    for i in 0..chunks {
        let v = _mm256_loadu_ps(values.as_ptr().add(i * 8));
        let scaled = _mm256_mul_ps(v, inv_scale);
        // Rounds ties to even; scalar `f32::round` rounds ties away from zero,
        // so results can differ by 1 on exact .5 inputs.
        let rounded =
            _mm256_round_ps::<{ _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC }>(scaled);
        let clamped = _mm256_max_ps(_mm256_min_ps(rounded, max_val), min_val);
        let as_i32 = _mm256_cvtps_epi32(clamped);

        // Pack i32 → i16 → i8 (signed saturation) and store the low 8 bytes
        let as_i16 = _mm_packs_epi32(
            _mm256_castsi256_si128(as_i32),
            _mm256_extracti128_si256::<1>(as_i32),
        );
        let as_i8 = _mm_packs_epi16(as_i16, as_i16);
        _mm_storel_epi64(output.as_mut_ptr().add(i * 8) as *mut __m128i, as_i8);
    }
    // Remaining values.len() % 8 elements: fall back to the scalar loop
}
```

**Expected Speedup:** 8x
**Priority:** HIGH

---

### 1.4 Scale Computation (Line 209)

**Current Implementation:**
```rust
let max_abs = values.iter().map(|&v| v.abs()).fold(0.0f32, f32::max);
```

**SIMD Optimization (AVX2):**
```rust
#[target_feature(enable = "avx2")]
unsafe fn compute_scale_avx2(values: &[f32]) -> f32 {
    use core::arch::x86_64::*;
    let mut max_vec = _mm256_setzero_ps();
    let chunks = values.len() / 8;

    for i in 0..chunks {
        let v = _mm256_loadu_ps(values.as_ptr().add(i * 8));
        let abs_v = _mm256_andnot_ps(_mm256_set1_ps(-0.0), v); // Clear sign bit
        max_vec = _mm256_max_ps(max_vec, abs_v);
    }

    // Horizontal max reduction (helper not shown)
    let max_val = horizontal_max_f32(max_vec);
    let remainder_max = values[chunks * 8..].iter().map(|v| v.abs()).fold(0.0f32, f32::max);
    max_val.max(remainder_max) / 127.0
}
```

**Expected Speedup:** 8x
**Priority:** MEDIUM

---

### Memory Access Pattern Issues

**Current Pattern:**
- A matrix: `a[i * k + kk]` - sequential access ✓ (cache-friendly)
- B matrix: `b[j * k + kk]` - sequential inside the `kk` loop, but a stride of `k` elements across the j-loop ✗ (all of B is re-streamed for every row of A)

**Optimization:** Consider B matrix layout transformation
- Store B in column-major for better cache locality
- Or use blocking/tiling: Process in 32x32 or 64x64 blocks

---

## 2. src/ffn.rs - Feed-Forward Network

### 2.1 Activation Functions (Lines 60-76)

**Current Implementation:**
```rust
match activation {
    ActivationType::Gelu => {
        for (i, &x) in input.iter().enumerate() {
            let x_f32 = (x as f32) * scale;
            output[i] = gelu_approx(x_f32);
        }
    }
    // ...
}
```

**GELU Bottleneck (Lines 21-28):**
```rust
pub fn gelu_approx(x: f32) -> f32 {
    const SQRT_2_OVER_PI: f32 = 0.7978845608;
    const COEFF: f32 = 0.044715;
    let x3 = x * x * x;
    let inner = SQRT_2_OVER_PI * (x + COEFF * x3);
    0.5 * x * (1.0 + fast_tanh(inner))
}
```

**SIMD Optimization (AVX2):**
```rust
#[target_feature(enable = "avx2")]
unsafe fn apply_gelu_avx2(input: &[i32], scale: f32, output: &mut [f32]) {
    use core::arch::x86_64::*;
    let scale_vec = _mm256_set1_ps(scale);
    let sqrt_2_pi = _mm256_set1_ps(0.7978845608);
    let coeff = _mm256_set1_ps(0.044715);
    let half = _mm256_set1_ps(0.5);
    let one = _mm256_set1_ps(1.0);

    let chunks = input.len() / 8;
    for i in 0..chunks {
        // Load and convert to f32
        let x_i32 = _mm256_loadu_si256(input.as_ptr().add(i * 8) as *const __m256i);
        let x = _mm256_mul_ps(_mm256_cvtepi32_ps(x_i32), scale_vec);

        // Compute x^3
        let x2 = _mm256_mul_ps(x, x);
        let x3 = _mm256_mul_ps(x2, x);

        // inner = sqrt(2/pi) * (x + 0.044715 * x^3)
        let term = _mm256_mul_ps(coeff, x3);
        let sum = _mm256_add_ps(x, term);
        let inner = _mm256_mul_ps(sqrt_2_pi, sum);

        // fast_tanh(inner) - vectorized Pade approximation (helper not shown)
        let tanh_val = fast_tanh_avx2(inner);

        // 0.5 * x * (1 + tanh(inner))
        let one_plus_tanh = _mm256_add_ps(one, tanh_val);
        let result = _mm256_mul_ps(_mm256_mul_ps(half, x), one_plus_tanh);

        _mm256_storeu_ps(output.as_mut_ptr().add(i * 8), result);
    }
    // Remaining input.len() % 8 elements: fall back to scalar gelu_approx
}
```

**Expected Speedup:** 6-8x
**Priority:** HIGH (GELU is compute-intensive)

---

### 2.2 Residual Addition (Lines 269-275)

**Current Implementation:**
```rust
for i in 0..residual.len() {
    let res = residual[i] as f32 * output_scale;
    let ffn = ffn_output[i] as f32 * ffn_scale;
    let sum = res + ffn;
    let q = (sum * inv_out_scale).round();
    output[i] = q.clamp(-128.0, 127.0) as i8;
}
```

**SIMD Optimization (AVX2):**
```rust
#[target_feature(enable = "avx2")]
unsafe fn residual_ffn_avx2(
    residual: &[i8], ffn_output: &[i32], ffn_scale: f32,
    output: &mut [i8], output_scale: f32
) {
    use core::arch::x86_64::*;
    let res_scale_vec = _mm256_set1_ps(output_scale);
    let ffn_scale_vec = _mm256_set1_ps(ffn_scale);
    let inv_out_scale_vec = _mm256_set1_ps(1.0 / output_scale);
    let min_val = _mm256_set1_ps(-128.0);
    let max_val = _mm256_set1_ps(127.0);

    // Process 8 elements at a time
    let chunks = residual.len() / 8;
    for i in 0..chunks {
        // Load residual (i8) and convert to f32
        let res_i8 = _mm_loadl_epi64(residual.as_ptr().add(i * 8) as *const __m128i);
        let res_i32 = _mm256_cvtepi8_epi32(res_i8);
        let res_f32 = _mm256_mul_ps(_mm256_cvtepi32_ps(res_i32), res_scale_vec);

        // Load ffn_output (i32) and convert to f32
        let ffn_i32 = _mm256_loadu_si256(ffn_output.as_ptr().add(i * 8) as *const __m256i);
        let ffn_f32 = _mm256_mul_ps(_mm256_cvtepi32_ps(ffn_i32), ffn_scale_vec);

        // Add and quantize (ties round to even, unlike scalar `f32::round`)
        let sum = _mm256_add_ps(res_f32, ffn_f32);
        let scaled = _mm256_mul_ps(sum, inv_out_scale_vec);
        let rounded =
            _mm256_round_ps::<{ _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC }>(scaled);

        // Clamp and pack to i8 (i32 → i16 → i8 with signed saturation)
        let clamped = _mm256_max_ps(_mm256_min_ps(rounded, max_val), min_val);
        let as_i32 = _mm256_cvtps_epi32(clamped);
        let as_i16 = _mm_packs_epi32(
            _mm256_castsi256_si128(as_i32),
            _mm256_extracti128_si256::<1>(as_i32),
        );
        let as_i8 = _mm_packs_epi16(as_i16, as_i16);
        _mm_storel_epi64(output.as_mut_ptr().add(i * 8) as *mut __m128i, as_i8);
    }
    // Remaining residual.len() % 8 elements: fall back to the scalar loop
}
```

**Expected Speedup:** 8x
**Priority:** MEDIUM

---

## 3. src/q15.rs - Fixed-Point Arithmetic

### 3.1 Missing Batch Operations (NEW FEATURE)

**Current Limitation:**
The Q15 type only provides scalar operations. Real-world usage likely involves arrays of Q15 values, but they are processed one at a time.
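Before adding SIMD paths, a portable slice-level baseline helps pin down the semantics that any vectorized version should reproduce. The following is a minimal sketch under stated assumptions: `Q15` is a local stand-in newtype over `u16`, the values `MAX_RAW = 0x7FFF` and `SCALE = 32768.0` are hypothetical and chosen only for illustration, and the function names `batch_mul_scalar` and `from_f32_batch_scalar` are not part of the crate. The multiply reuses the `(a * b) >> 15` formula of the scalar `saturating_mul` quoted in section 3.2.

```rust
/// Stand-in for the crate's Q15 type (assumed layout: transparent u16).
#[repr(transparent)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Q15(pub u16);

impl Q15 {
    /// Hypothetical constants for this sketch only.
    pub const MAX_RAW: u16 = 0x7FFF;
    pub const SCALE: f32 = 32768.0;

    /// Same formula as the scalar `saturating_mul` in section 3.2.
    pub fn saturating_mul(self, rhs: Self) -> Self {
        let product = (self.0 as u32 * rhs.0 as u32) >> 15;
        Self(product.min(Self::MAX_RAW as u32) as u16)
    }
}

/// Element-wise Q15 multiply over slices; processes the common prefix.
pub fn batch_mul_scalar(a: &[Q15], b: &[Q15], output: &mut [Q15]) {
    for ((o, &x), &y) in output.iter_mut().zip(a).zip(b) {
        *o = x.saturating_mul(y);
    }
}

/// Convert f32 values to Q15, saturating inputs outside the representable range.
pub fn from_f32_batch_scalar(values: &[f32], output: &mut [Q15]) {
    for (o, &v) in output.iter_mut().zip(values) {
        let raw = (v * Q15::SCALE).round().clamp(0.0, Q15::MAX_RAW as f32);
        *o = Q15(raw as u16);
    }
}
```

A baseline like this can double as the reference side of a property test: run the SIMD kernel and the scalar batch function on the same random slices and compare the outputs element by element.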
**SIMD Batch Operations to Add:**

```rust
use core::arch::x86_64::*;

/// Batch convert f32 array to Q15
#[cfg(target_feature = "avx2")]
pub fn from_f32_batch_avx2(values: &[f32], output: &mut [Q15]) {
    unsafe {
        let scale_vec = _mm256_set1_ps(Q15::SCALE);
        let chunks = values.len() / 8;

        for i in 0..chunks {
            let v = _mm256_loadu_ps(values.as_ptr().add(i * 8));
            let scaled = _mm256_mul_ps(v, scale_vec);
            let as_i32 = _mm256_cvtps_epi32(scaled);

            // Pack i32 → u16 (unsigned saturation, per 128-bit lane)
            let as_i16 = _mm256_packus_epi32(as_i32, _mm256_setzero_si256());
            // Gather both packed halves into the low 128 bits
            let as_u16 = _mm256_permute4x64_epi64::<0b11011000>(as_i16);

            // Store as Q15 (assumes Q15 is a transparent wrapper over u16)
            let out_ptr = output.as_mut_ptr().add(i * 8) as *mut __m128i;
            _mm_storeu_si128(out_ptr, _mm256_castsi256_si128(as_u16));
        }
    }
}

/// Batch Q15 multiplication using PMULHUW + PMULLW
#[cfg(target_feature = "avx2")]
pub fn batch_mul_avx2(a: &[Q15], b: &[Q15], output: &mut [Q15]) {
    unsafe {
        let max_raw = _mm256_set1_epi16(Q15::MAX_RAW as i16);
        let chunks = a.len() / 16;

        for i in 0..chunks {
            let a_vec = _mm256_loadu_si256(a.as_ptr().add(i * 16) as *const __m256i);
            let b_vec = _mm256_loadu_si256(b.as_ptr().add(i * 16) as *const __m256i);

            // Scalar Q15 multiply is (a * b) >> 15, but PMULHUW alone yields
            // (a * b) >> 16. Rebuild bits 15..30 of the product from the high
            // and low halves. Assumes raw values <= 0x8000 so the result fits
            // in 16 bits.
            let hi = _mm256_mulhi_epu16(a_vec, b_vec);
            let lo = _mm256_mullo_epi16(a_vec, b_vec);
            let shifted = _mm256_or_si256(
                _mm256_slli_epi16::<1>(hi),
                _mm256_srli_epi16::<15>(lo),
            );
            // Saturate to Q15::MAX_RAW, matching the scalar saturating_mul
            let result = _mm256_min_epu16(shifted, max_raw);

            _mm256_storeu_si256(
                output.as_mut_ptr().add(i * 16) as *mut __m256i,
                result
            );
        }
    }
}
```

**Expected Speedup:** 16x (16 Q15 values per 256-bit register)
**Priority:** HIGH (enables vectorized spike attention)

---

### 3.2 Saturating Multiply Optimization (Lines 246-250)

**Current Implementation:**
```rust
pub fn saturating_mul(self, rhs: Self) -> Self {
    let product = (self.0 as u32 * rhs.0 as u32) >> 15;
    Self(product.min(Self::MAX_RAW as u32) as u16)
}
```

**Issue:** Good implementation, but called in scalar context
**Optimization:** Use batch operations above when processing arrays

**Expected Speedup:** N/A (use batch operations instead)
**Priority:** LOW (batch ops supersede this)

---

## 4. src/attention/spike_driven.rs - Spike Processing

### 4.1 Spike Encoding - Membrane Potential (Lines 164-180)

**Current Implementation:**
```rust
for step in 0..steps {
    if refractory_counter > 0 {
        refractory_counter -= 1;
        continue;
    }

    membrane_potential = membrane_potential.saturating_add(rate_q15 as u32);

    if membrane_potential >= self.config.spike_threshold_q15 as u32 {
        train.add_spike(step, polarity);
        membrane_potential = 0;
        refractory_counter = self.config.refractory_period;
    }
}
```

**Bottleneck:** Sequential per-neuron processing

**SIMD Optimization Strategy:**

Process multiple neurons in parallel using SIMD for membrane accumulation:

```rust
#[target_feature(enable = "avx2")]
unsafe fn encode_spikes_batch_avx2(
    values: &[i8], config: &SpikeDrivenConfig, output: &mut [SpikeTrain]
) {
    use core::arch::x86_64::*;
    let batch_size = 8; // Process 8 neurons at once
    // cmpgt implements `>`, so compare against threshold - 1 to get `>=`
    let threshold = _mm256_set1_epi32(config.spike_threshold_q15 as i32 - 1);

    for (batch_idx, batch) in values.chunks(batch_size).enumerate() {
        // Vectorize membrane potential accumulation
        let mut membrane = _mm256_setzero_si256();
        // Rates for 8 neurons, constant across steps (helper not shown; it must
        // produce the same rate_q15 values as the scalar path and pad short batches)
        let rates = load_and_convert_i8_to_i32(batch);

        for step in 0..config.temporal_coding_steps {
            // Accumulate: membrane += rate
            membrane = _mm256_add_epi32(membrane, rates);

            // Compare with threshold
            let spike_mask = _mm256_cmpgt_epi32(membrane, threshold);
            let spike_bits = _mm256_movemask_ps(_mm256_castsi256_ps(spike_mask));

            // For each bit set, add spike to corresponding train
            for bit in 0..batch.len() {
                if spike_bits & (1 << bit) != 0 {
                    output[batch_idx * batch_size + bit].add_spike(step, batch[bit].signum());
                }
            }
            // Reset the membrane potential of neurons that fired
            membrane = _mm256_andnot_si256(spike_mask, membrane);
        }
    }
    // NOTE: the refractory period of the scalar loop is not modelled in this sketch.
}
```

**Expected Speedup:** 6-8x
**Priority:** MEDIUM (benefits from batched processing)

---

### 4.2 Spike Coincidence Detection (Lines 228-234)

**Current Implementation:**
```rust
for (&q_time, &q_pol) in q_train.times.iter().zip(q_train.polarities.iter()) {
    for (&k_time, &k_pol) in k_train.times.iter().zip(k_train.polarities.iter()) {
        if q_time == k_time {
            coincidence_score += (q_pol as i32) * (k_pol as i32);
        }
    }
}
```

**Bottleneck:** O(n_q * n_k) comparison for each query-key pair
**Memory Access:** Random sparse access - cache-unfriendly

**SIMD Optimization Strategy:**

**Option 1: Dense Bitset Representation**
```rust
use core::arch::x86_64::*;

// Convert sparse spike times to dense bitset
// For temporal_steps=8: use single u8 as bitset
struct DenseSpikeTrain {
    spike_bits: u8,        // Bit i set if spike at time i
    polarities: [i8; 8],   // Polarity at each time (0 if no spike)
}

#[target_feature(enable = "sse4.1")]
unsafe fn coincidence_simd(q: &DenseSpikeTrain, k: &DenseSpikeTrain) -> i32 {
    // Find coincident times: bitwise AND
    let coincident = q.spike_bits & k.spike_bits;
    if coincident == 0 { return 0; }

    // Load polarities and multiply where coincident
    let q_pols = _mm_loadl_epi64(&q.polarities as *const _ as *const __m128i);
    let k_pols = _mm_loadl_epi64(&k.polarities as *const _ as *const __m128i);

    // Multiply polarities (i8 * i8 → i16)
    let products = _mm_mullo_epi16(
        _mm_cvtepi8_epi16(q_pols),
        _mm_cvtepi8_epi16(k_pols)
    );

    // Mask out non-coincident positions (helper not shown)
    let mask = expand_bitset_to_mask(coincident);
    let masked = _mm_and_si128(products, mask);

    // Horizontal sum (helper not shown)
    horizontal_sum_i16(masked)
}
```

**Expected Speedup:** 4-8x (requires data restructuring)
**Priority:** MEDIUM-HIGH (complex refactor)

---

### 4.3 Value Contribution Accumulation (Lines 276-280)

**Current Implementation:**
```rust
for &polarity in &v_train.polarities {
    contrib = contrib.saturating_add(
        (polarity as i32).saturating_mul(attention_weight)
    );
}
```

**SIMD Optimization:**
```rust
#[target_feature(enable = "avx2")]
unsafe fn spike_value_contribution_avx2(
    polarities: &[i8], attention_weight: i32
) -> i32 {
    use core::arch::x86_64::*;
    let weight_vec = _mm256_set1_epi32(attention_weight);
    let mut acc = _mm256_setzero_si256();

    let chunks = polarities.len() / 8;
    for i in 0..chunks {
        // Load 8 polarities (i8) and extend to i32
        let pols_i8 = _mm_loadl_epi64(polarities.as_ptr().add(i * 8) as *const __m128i);
        let pols_i32 = _mm256_cvtepi8_epi32(pols_i8);

        // Multiply by attention weight (wrapping, unlike the scalar saturating_mul)
        let prod = _mm256_mullo_epi32(pols_i32, weight_vec);

        // Accumulate
        acc = _mm256_add_epi32(acc, prod);
    }

    horizontal_sum_i32(acc) + scalar_remainder(...)
}
```

**Expected Speedup:** 8x
**Priority:** MEDIUM

---

## Overall Bottleneck Summary

### Computation Time Distribution (Estimated)

1. **qgemm_i8 inner loop (lines 61-68):** 75-85% of total time
2. **Activation functions (GELU):** 5-10%
3. **Quantization/dequantization:** 3-5%
4. **Spike encoding:** 2-4%
5. **Spike coincidence detection:** 1-3%
6. **Other operations:** 1-5%

### Memory Bottlenecks

1. **B matrix access in GEMM (stride of `k` across the j-loop)** - estimated 30-40% cache miss rate
2. **Sparse spike train access** - Unpredictable cache behavior
3. **Dynamic Vec allocations** - Heap fragmentation

---

## Implementation Roadmap

### Phase 1: Critical Path (Week 1)
**Priority:** CRITICAL
**Expected Overall Speedup:** 10-15x

- [ ] `qgemm.rs:61-68` - SIMD INT8 dot product (AVX2 + NEON)
- [ ] `qgemm.rs:189-191` - SIMD dequantization
- [ ] `ffn.rs:60-76` - SIMD GELU activation

### Phase 2: High-Impact Optimizations (Week 2)
**Priority:** HIGH
**Expected Overall Speedup:** Additional 1.5-2x

- [ ] `q15.rs` - Add batch operations with PMULHUW/PMULLW
- [ ] `qgemm.rs:199-203` - SIMD quantization
- [ ] `ffn.rs:269-275` - SIMD residual addition

### Phase 3: Spike Processing (Week 3)
**Priority:** MEDIUM
**Expected Overall Speedup:** Additional 1.2-1.5x

- [ ] `spike_driven.rs:164-180` - SIMD membrane potential
- [ ] `spike_driven.rs:228-234` - Dense bitset + SIMD coincidence
- [ ] `spike_driven.rs:276-280` - SIMD value accumulation

### Phase 4: Advanced Optimizations (Week 4)
**Priority:** LOW
**Expected Overall Speedup:** Additional 1.1-1.3x

- [ ] GEMM blocking/tiling for cache optimization
- [ ] B matrix layout transformation (column-major option)
- [ ] Loop unrolling and prefetch hints

---

## Architecture-Specific Recommendations

### x86_64 Targets

**Minimum:** SSE4.2
- Basic SIMD support
- Expected speedup: 4-8x

**Recommended:** AVX2
- 256-bit vectors (8x f32, 32x i8)
- FMA instructions
- Expected speedup: 8-16x

**Optimal:** AVX-512 with VNNI
- 512-bit vectors (16x f32, 64x i8)
- INT8 dot product instructions (`vpdpbusd`)
- Expected speedup: 16-32x

**Feature Detection:**
```rust
#[cfg(target_arch = "x86_64")]
fn select_kernel() -> GemmKernel {
    if is_x86_feature_detected!("avx512vnni") {
        GemmKernel::Avx512Vnni
    } else if is_x86_feature_detected!("avx2") {
        GemmKernel::Avx2
    } else if is_x86_feature_detected!("sse4.2") {
        GemmKernel::Sse42
    } else {
        GemmKernel::Scalar
    }
}
```

### ARM Targets

**Minimum:** NEON (ARMv7/ARMv8)
- 128-bit vectors (4x f32, 16x i8)
- Expected speedup: 4-8x

**Recommended:** NEON with dot product (ARMv8.2-A+)
- `vdotq_s32` instruction for INT8 dot products
- Expected speedup: 8-12x

**Optimal:** SVE2
- Scalable vectors (128-2048 bits)
- Advanced predication
- Expected speedup: 12-24x

---

## Concrete Code Locations

### File: /home/user/ruvector/crates/ruvector-mincut-gated-transformer/src/kernel/qgemm.rs

**Line 61-68:** INT8 dot product inner loop
- **Optimization:** AVX2 `_mm256_madd_epi16` (after sign-extending i8 to i16) or NEON `vdotq_s32`
- **Expected speedup:** 12-16x
- **Complexity:** Medium

**Line 104-108:** SIMD function stub
- **Current:** Just delegates to scalar
- **Action:** Implement actual SIMD kernels here
- **Priority:** CRITICAL

**Line 189-191:** Dequantization loop
- **Optimization:** `_mm256_cvtepi32_ps` + `_mm256_mul_ps`
- **Expected speedup:** 8x
- **Complexity:** Low

**Line 199-203:** Quantization loop
- **Optimization:** `_mm256_cvtps_epi32` + pack instructions
- **Expected speedup:** 8x
- **Complexity:** Low

**Line 209:** Max absolute value fold
- **Optimization:** `_mm256_max_ps` with horizontal reduction
- **Expected speedup:** 8x
- **Complexity:** Low

### File: /home/user/ruvector/crates/ruvector-mincut-gated-transformer/src/ffn.rs

**Line 60-76:** Activation application
- **Optimization:** Vectorized GELU polynomial evaluation
- **Expected speedup:** 6-8x
- **Complexity:** Medium

**Line 21-28:** GELU approximation
- **Optimization:** SIMD polynomial operations
- **Expected speedup:** 6-8x
- **Complexity:** Medium

**Line 269-275:** Residual addition
- **Optimization:** SIMD add + quantize
- **Expected speedup:** 8x
- **Complexity:** Low

### File: /home/user/ruvector/crates/ruvector-mincut-gated-transformer/src/q15.rs

**NEW:** Batch operations (to be added)
- **Location:** Add new module `q15::batch`
- **Optimization:** PMULHUW/PMULLW for Q15 multiply
- **Expected speedup:** 16x
- **Complexity:** Medium

**Line 246-250:** Saturating multiply
- **Optimization:** Use batch operations instead
- **Priority:** LOW (superseded by batch ops)

### File: /home/user/ruvector/crates/ruvector-mincut-gated-transformer/src/attention/spike_driven.rs

**Line 164-180:** Membrane potential loop
- **Optimization:** SIMD accumulation across neurons
- **Expected speedup:** 6-8x
- **Complexity:** Medium-High

**Line 228-234:** Spike coincidence detection
- **Optimization:** Dense bitset + SIMD compare
- **Expected speedup:** 4-8x
- **Complexity:** High (requires data restructuring)

**Line 276-280:** Polarity accumulation
- **Optimization:** SIMD multiply-add
- **Expected speedup:** 8x
- **Complexity:** Low

---

## Testing Strategy

### Correctness Tests
- [ ] Implement SIMD kernels with reference scalar fallback
- [ ] Property-based testing: SIMD results match scalar (within float tolerance)
- [ ] Fuzz testing with random inputs
- [ ] Edge cases: empty, single element, odd lengths, alignment

### Performance Benchmarks
- [ ] Criterion.rs benchmarks for each optimization
- [ ] Compare against scalar baseline
- [ ] Test various input sizes (small: 64, medium: 512, large: 2048)
- [ ] Profile with `perf` to verify IPC and cache hit rates

### Cross-Platform Validation
- [ ] CI tests on x86_64 (AVX2, SSE4.2)
- [ ] CI tests on ARM (NEON)
- [ ] Fallback to scalar when SIMD unavailable

---

## Risk Assessment

### Low Risk (Can implement immediately)
- Dequantization/quantization SIMD
- Scale computation SIMD
- Residual addition SIMD

### Medium Risk (Requires careful testing)
- INT8 GEMM SIMD (critical path - needs extensive validation)
- GELU SIMD (accuracy sensitive)
- Q15 batch operations (new API)

### High Risk (Significant refactoring)
- Spike coincidence dense bitset representation
- GEMM matrix layout changes
- Blocking/tiling strategies

---

## Estimated Total Speedup

### Conservative Estimate
- Phase 1: 10x
- Phase 2: 12x
- Phase 3: 15x
- Phase 4: 18x

### Optimistic Estimate
- Phase 1: 15x
- Phase 2: 20x
- Phase 3: 25x
- Phase 4: 32x

**Realistic Target:** 15-20x end-to-end speedup for typical transformer inference workload.

---

## Next Steps

1. **Benchmark baseline** - Establish current performance metrics
2. **Implement Phase 1** - Focus on critical GEMM kernel
3. **Validate correctness** - Ensure bit-exact results (or within tolerance)
4. **Measure improvements** - Quantify actual vs. expected speedup
5. **Iterate** - Proceed to Phase 2 based on results

---

**Analysis Complete** - Ready for implementation.
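
As a cross-check on the end-to-end estimates, Amdahl's law can combine the estimated time distribution with per-component speedups. The sketch below is illustrative only: `overall_speedup` and `example_components` are hypothetical helpers, and the fractions and speedups are midpoints picked from the estimated ranges in this analysis, not measurements.

```rust
/// Amdahl's law for several components:
/// overall = 1 / sum(fraction_i / speedup_i), with fractions summing to 1.
fn overall_speedup(components: &[(f64, f64)]) -> f64 {
    let total: f64 = components.iter().map(|&(f, _)| f).sum();
    assert!((total - 1.0).abs() < 1e-9, "time fractions must sum to 1");
    let new_time: f64 = components.iter().map(|&(f, s)| f / s).sum();
    1.0 / new_time
}

/// (fraction of baseline time, assumed speedup), using illustrative midpoints.
fn example_components() -> Vec<(f64, f64)> {
    vec![
        (0.80, 14.0), // INT8 GEMM inner loop
        (0.075, 7.0), // GELU activation
        (0.04, 8.0),  // quantization / dequantization
        (0.03, 7.0),  // spike encoding
        (0.02, 6.0),  // spike coincidence detection
        (0.035, 1.0), // everything else, left scalar
    ]
}
```

Calling `overall_speedup(&example_components())` gives the combined figure for these assumed inputs. Any component left at 1x with time fraction f caps the overall speedup at 1/f, so the measured time split from the baseline benchmark matters as much as the per-kernel speedups.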