Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
560
vendor/ruvector/crates/ruQu/src/attention.rs
vendored
Normal file
560
vendor/ruvector/crates/ruQu/src/attention.rs
vendored
Normal file
@@ -0,0 +1,560 @@
|
||||
//! Mincut-Gated Attention Integration
//!
//! This module bridges ruQu's coherence gate with the `ruvector-mincut-gated-transformer`
//! crate's attention optimization mechanisms:
//!
//! 1. **GatePacket Bridge** - Convert ruQu's `TileReport` aggregates into `GatePacket`
//! 2. **MincutDepthRouter** - λ-based Mixture-of-Depths routing for 50% FLOPs reduction
//! 3. **CoherenceEarlyExit** - Layer skipping based on coherence stability
//!
//! ## Usage
//!
//! ```rust,ignore
//! use ruqu::attention::{CoherenceAttention, AttentionConfig};
//! use ruqu::tile::{TileReport, GateThresholds};
//!
//! // Create attention optimizer
//! let config = AttentionConfig::default();
//! let mut attention = CoherenceAttention::new(config);
//!
//! // Process syndrome patterns with coherence-optimized attention
//! let reports: Vec<TileReport> = collect_worker_reports();
//! let (gate_packet, routing) = attention.optimize(&reports);
//!
//! // Use routing decisions for efficient syndrome analysis
//! for (i, route) in routing.iter().enumerate() {
//!     if route.requires_compute() {
//!         // Full analysis for this syndrome entry
//!     } else {
//!         // Skip - coherence is stable, use cached result
//!     }
//! }
//! ```
|
||||
|
||||
#[cfg(feature = "attention")]
|
||||
use ruvector_mincut_gated_transformer::{
|
||||
CoherenceEarlyExit, EarlyExitConfig, EarlyExitDecision, ExitReason, GatePacket,
|
||||
MincutDepthRouter, ModRoutingConfig, RoutingStats, TokenRoute,
|
||||
};
|
||||
|
||||
use crate::tile::{GateDecision, TileReport};
|
||||
|
||||
/// Tunable knobs for coherence-optimized attention.
///
/// Start from [`AttentionConfig::default`], or pick one of the presets
/// ([`AttentionConfig::realtime`], [`AttentionConfig::accurate`]).
#[derive(Clone, Debug)]
pub struct AttentionConfig {
    /// Target FLOPs reduction (0.0-0.9); the default of 0.5 targets 50%
    pub flops_reduction: f32,

    /// Lower bound on the number of entries processed in each round
    pub min_entries_per_round: u16,

    /// λ-delta threshold for skipping (Q15 scale).
    /// Lower = more aggressive skipping
    pub lambda_delta_skip_threshold: i32,

    /// Whether capacity adapts to coherence stability
    pub adaptive_capacity: bool,

    /// Whether early exit is allowed when coherence is very stable
    pub enable_early_exit: bool,

    /// Confidence required before an early exit is taken (0.0-1.0)
    pub early_exit_threshold: f32,
}

impl Default for AttentionConfig {
    /// Balanced preset: 50% target reduction with adaptive capacity and
    /// early exit both enabled.
    fn default() -> Self {
        Self {
            flops_reduction: 0.5,
            min_entries_per_round: 4,
            // ~10% of Q15 range
            lambda_delta_skip_threshold: 3276,
            adaptive_capacity: true,
            enable_early_exit: true,
            early_exit_threshold: 0.95,
        }
    }
}

impl AttentionConfig {
    /// Preset tuned for real-time coherence gating: skips more aggressively
    /// than the default and exits early at a lower confidence.
    pub fn realtime() -> Self {
        Self {
            flops_reduction: 0.6,
            min_entries_per_round: 2,
            lambda_delta_skip_threshold: 2000,
            early_exit_threshold: 0.9,
            // Adaptive capacity and early exit stay enabled, exactly as in
            // the default preset.
            ..Self::default()
        }
    }

    /// Preset tuned for accuracy: skips less, keeps capacity fixed, and
    /// disables early exit.
    pub fn accurate() -> Self {
        Self {
            flops_reduction: 0.3,
            min_entries_per_round: 8,
            lambda_delta_skip_threshold: 5000,
            adaptive_capacity: false,
            enable_early_exit: false,
            early_exit_threshold: 0.99,
        }
    }
}
|
||||
|
||||
/// Bridge between ruQu's `TileReport` and `GatePacket`
///
/// Converts aggregated tile metrics into the format expected by
/// the mincut-gated-transformer system.
///
/// The bridge is stateful: it remembers the lambda it emitted last so the
/// next packet can carry it as `lambda_prev`, and it keeps a running average
/// of the boundary edge count.
#[derive(Clone, Copy, Debug, Default)]
pub struct GatePacketBridge {
    /// Lambda emitted by the previous conversion (0 before the first one)
    prev_lambda: u32,
    /// Smoothed boundary edge count carried between conversions
    smoothed_boundary: u16,
}

impl GatePacketBridge {
    /// Create a new bridge with zeroed history
    pub fn new() -> Self {
        Self::default()
    }

    /// Convert tile reports into a GatePacket
    ///
    /// # Arguments
    /// * `reports` - Aggregated worker tile reports
    ///
    /// # Returns
    /// A `GatePacket` suitable for mincut-gated-transformer. An empty
    /// `reports` slice yields `GatePacket::default()` and leaves the bridge's
    /// history untouched.
    ///
    /// Counts that do not fit the packet's 16-bit fields saturate at
    /// `u16::MAX` instead of wrapping.
    #[cfg(feature = "attention")]
    pub fn to_gate_packet(&mut self, reports: &[TileReport]) -> GatePacket {
        if reports.is_empty() {
            return GatePacket::default();
        }

        // Aggregate metrics from reports
        let mut min_cut = f64::MAX;
        let mut max_shift = 0.0f64;
        let mut total_boundary = 0u32;

        for report in reports {
            // Only strictly positive cuts participate in the minimum.
            // NOTE(review): when no report has a positive cut, `min_cut` stays
            // at `f64::MAX` and lambda lands at the top of its range - confirm
            // that "no cut observed" should read as maximally coherent.
            if report.local_cut > 0.0 && report.local_cut < min_cut {
                min_cut = report.local_cut;
            }
            if report.shift_score > max_shift {
                max_shift = report.shift_score;
            }

            // Use boundary candidate count as proxy for boundary edges
            let nonzero_candidates = report
                .boundary_candidates
                .iter()
                .filter(|&&c| c != 0)
                .count();
            total_boundary = total_boundary
                .saturating_add(u32::try_from(nonzero_candidates).unwrap_or(u32::MAX));
        }

        // Convert min_cut to lambda (Q15-ish scale)
        // Higher min_cut = more coherent = higher lambda
        let lambda = (min_cut.clamp(0.0, 1000.0) * 32.767) as u32;

        // Higher shift = more concentrated boundaries. The float-to-int cast
        // saturates (negative and NaN become 0), then the value is capped at
        // the Q15 maximum.
        let boundary_concentration_q15 = ((max_shift * 32767.0) as u32).min(32767) as u16;

        // Smooth boundary edges: average the fresh count with the previous
        // smoothed value. Widened to u64 so the sum cannot overflow.
        let averaged = (u64::from(total_boundary) + u64::from(self.smoothed_boundary)) / 2;
        let boundary_edges = u16::try_from(averaged).unwrap_or(u16::MAX);
        self.smoothed_boundary = boundary_edges;

        // Build packet
        let packet = GatePacket {
            lambda,
            lambda_prev: self.prev_lambda,
            boundary_edges,
            boundary_concentration_q15,
            partition_count: u16::try_from(reports.len()).unwrap_or(u16::MAX),
            flags: 0,
        };

        // Update history
        self.prev_lambda = lambda;

        packet
    }

    /// Convert a GatePacket back to approximate metrics
    ///
    /// # Returns
    /// Tuple of (min cut, shift score, partition count), obtained by undoing
    /// the scaling applied in `to_gate_packet`. The values are approximate
    /// because the forward conversion truncates to integers.
    #[cfg(feature = "attention")]
    pub fn from_gate_packet(packet: &GatePacket) -> (f64, f64, usize) {
        let min_cut = f64::from(packet.lambda) / 32.767;
        let shift_score = f64::from(packet.boundary_concentration_q15) / 32767.0;
        let partition_count = usize::from(packet.partition_count);
        (min_cut, shift_score, partition_count)
    }
}
|
||||
|
||||
/// Coherence-optimized attention processor
///
/// Uses mincut signals to dynamically route syndrome entries through
/// the analysis pipeline, achieving up to 50% FLOPs reduction while
/// maintaining accuracy on critical boundary patterns.
#[cfg(feature = "attention")]
pub struct CoherenceAttention {
    /// Settings supplied at construction; consulted for early-exit decisions
    config: AttentionConfig,
    /// Router built from `config` that assigns a route to each entry
    router: MincutDepthRouter,
    /// Stateful converter from tile reports to gate packets
    bridge: GatePacketBridge,
    /// Counters accumulated across `optimize` calls
    stats: AttentionStats,
}

#[cfg(feature = "attention")]
impl CoherenceAttention {
    /// Create a new coherence attention processor
    ///
    /// The routing configuration is derived from `config`. If the router
    /// rejects that configuration, the router's default is used instead of
    /// failing construction.
    pub fn new(config: AttentionConfig) -> Self {
        let mod_config = ModRoutingConfig {
            lambda_delta_skip_threshold: config.lambda_delta_skip_threshold,
            boundary_token_force_compute: true,
            // The capacity ratio is the share of work that is kept.
            layer_capacity_ratio: 1.0 - config.flops_reduction,
            min_tokens_per_layer: config.min_entries_per_round,
            adaptive_capacity: config.adaptive_capacity,
        };

        Self {
            config,
            router: MincutDepthRouter::new(mod_config).unwrap_or_default(),
            bridge: GatePacketBridge::new(),
            stats: AttentionStats::default(),
        }
    }

    /// Optimize syndrome entry processing based on coherence
    ///
    /// # Arguments
    /// * `reports` - Worker tile reports with syndrome data
    ///
    /// # Returns
    /// Tuple of (GatePacket, routing decisions for each entry)
    ///
    /// Positions are 16-bit, so at most `u16::MAX` entries are routed; a
    /// longer slice is capped rather than having its length wrap around.
    pub fn optimize(&mut self, reports: &[TileReport]) -> (GatePacket, Vec<TokenRoute>) {
        let gate = self.bridge.to_gate_packet(reports);

        // Generate position indices for routing
        let entry_count = u16::try_from(reports.len()).unwrap_or(u16::MAX);
        let positions: Vec<u16> = (0..entry_count).collect();

        // Route entries based on coherence
        let routes = self.router.route_tokens(&gate, &positions);

        // Update stats
        let routing_stats = self.router.routing_stats(&routes);
        self.stats.total_entries += routing_stats.total_tokens;
        self.stats.computed_entries += routing_stats.compute_tokens;
        self.stats.skipped_entries += routing_stats.skip_tokens;
        self.stats.boundary_entries += routing_stats.boundary_tokens;
        self.stats.decisions += 1;

        (gate, routes)
    }

    /// Check if early exit is warranted based on coherence stability
    ///
    /// # Arguments
    /// * `gate` - Current gate packet
    /// * `current_layer` - Current processing layer
    /// * `max_layers` - Maximum number of layers
    ///
    /// # Returns
    /// Early exit decision. When early exit is disabled in the configuration
    /// the decision is always "do not exit" with zero confidence.
    ///
    /// With `max_layers == 0` the progress ratio is NaN for
    /// `current_layer == 0` (never past the midpoint) and infinite otherwise
    /// (always past the midpoint).
    pub fn check_early_exit(
        &self,
        gate: &GatePacket,
        current_layer: usize,
        max_layers: usize,
    ) -> EarlyExitDecision {
        if !self.config.enable_early_exit {
            return EarlyExitDecision {
                should_exit: false,
                confidence: 0.0,
                reason: ExitReason::None,
            };
        }

        // Calculate coherence stability. `unsigned_abs` cannot overflow,
        // whereas `abs` on the minimum signed value panics in debug builds and
        // stays negative in release builds, which would read as "very stable".
        let lambda_delta_abs = gate.lambda_delta().unsigned_abs() as f32;
        let stability = 1.0 - (lambda_delta_abs / 32768.0).min(1.0);

        // Calculate progress through layers
        let progress = current_layer as f32 / max_layers as f32;

        // Exit if very stable AND past midpoint
        let should_exit = stability > self.config.early_exit_threshold && progress > 0.5;

        let reason = if should_exit {
            ExitReason::HighConfidence
        } else {
            ExitReason::None
        };

        EarlyExitDecision {
            should_exit,
            confidence: stability,
            reason,
        }
    }

    /// Get accumulated statistics
    pub fn stats(&self) -> &AttentionStats {
        &self.stats
    }

    /// Reset statistics
    pub fn reset_stats(&mut self) {
        self.stats = AttentionStats::default();
    }
}
|
||||
|
||||
/// Running counters describing how entries were routed.
#[derive(Clone, Copy, Debug, Default)]
pub struct AttentionStats {
    /// Number of entries seen so far
    pub total_entries: usize,
    /// Number of entries that went through full computation
    pub computed_entries: usize,
    /// Number of entries that were skipped
    pub skipped_entries: usize,
    /// Number of boundary entries (always computed)
    pub boundary_entries: usize,
    /// Number of routing decisions made
    pub decisions: usize,
}

impl AttentionStats {
    /// Fraction of all entries that were skipped (0.0 when nothing has been
    /// recorded yet).
    pub fn flops_reduction(&self) -> f32 {
        self.share_of_total(self.skipped_entries)
    }

    /// Fraction of all entries that were computed (0.0 when nothing has been
    /// recorded yet).
    pub fn compute_ratio(&self) -> f32 {
        self.share_of_total(self.computed_entries)
    }

    /// Divides `count` by the total entry count, reporting 0.0 for an empty
    /// total so the ratio is never the result of a division by zero.
    fn share_of_total(&self, count: usize) -> f32 {
        match self.total_entries {
            0 => 0.0,
            total => count as f32 / total as f32,
        }
    }
}
|
||||
|
||||
/// Fallback types when attention feature is disabled
|
||||
#[cfg(not(feature = "attention"))]
|
||||
pub mod fallback {
|
||||
use super::*;
|
||||
|
||||
/// Stub TokenRoute for when attention feature is disabled
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum TokenRoute {
|
||||
/// Process through full computation
|
||||
Compute,
|
||||
/// Skip - use cached result
|
||||
Skip,
|
||||
/// Boundary token - always compute
|
||||
Boundary,
|
||||
}
|
||||
|
||||
impl TokenRoute {
|
||||
/// Check if this route requires computation
|
||||
pub fn requires_compute(&self) -> bool {
|
||||
!matches!(self, TokenRoute::Skip)
|
||||
}
|
||||
}
|
||||
|
||||
/// Stub GatePacket for when attention feature is disabled
|
||||
#[derive(Clone, Copy, Debug, Default)]
|
||||
pub struct GatePacket {
|
||||
/// Current lambda (coherence metric)
|
||||
pub lambda: u32,
|
||||
/// Previous lambda for trend detection
|
||||
pub lambda_prev: u32,
|
||||
/// Number of boundary edges
|
||||
pub boundary_edges: u16,
|
||||
/// Boundary concentration (Q15 scale)
|
||||
pub boundary_concentration_q15: u16,
|
||||
/// Number of partitions
|
||||
pub partition_count: u16,
|
||||
/// Policy flags
|
||||
pub flags: u16,
|
||||
}
|
||||
|
||||
impl GatePacket {
|
||||
/// Calculate lambda delta
|
||||
pub fn lambda_delta(&self) -> i32 {
|
||||
(self.lambda as i32) - (self.lambda_prev as i32)
|
||||
}
|
||||
}
|
||||
|
||||
/// Simplified attention processor without transformer dependency
|
||||
pub struct CoherenceAttention {
|
||||
#[allow(dead_code)]
|
||||
config: AttentionConfig,
|
||||
bridge: GatePacketBridge,
|
||||
stats: AttentionStats,
|
||||
}
|
||||
|
||||
impl CoherenceAttention {
|
||||
/// Create a new coherence attention processor
|
||||
pub fn new(config: AttentionConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
bridge: GatePacketBridge::new(),
|
||||
stats: AttentionStats::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Optimize syndrome entry processing based on coherence
|
||||
pub fn optimize(&mut self, reports: &[TileReport]) -> (GatePacket, Vec<TokenRoute>) {
|
||||
let gate = self.bridge.to_gate_packet_fallback(reports);
|
||||
|
||||
// Simple heuristic routing without transformer
|
||||
let routes: Vec<TokenRoute> = reports
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, report)| {
|
||||
// Boundary tokens always compute
|
||||
if report.boundary_candidates.iter().any(|&c| c != 0) {
|
||||
return TokenRoute::Boundary;
|
||||
}
|
||||
|
||||
// Skip if shift score is low (stable)
|
||||
if report.shift_score < 0.1 && i % 2 == 0 {
|
||||
return TokenRoute::Skip;
|
||||
}
|
||||
|
||||
TokenRoute::Compute
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Update stats
|
||||
self.stats.total_entries += routes.len();
|
||||
self.stats.computed_entries += routes.iter().filter(|r| r.requires_compute()).count();
|
||||
self.stats.skipped_entries += routes
|
||||
.iter()
|
||||
.filter(|r| matches!(r, TokenRoute::Skip))
|
||||
.count();
|
||||
self.stats.boundary_entries += routes
|
||||
.iter()
|
||||
.filter(|r| matches!(r, TokenRoute::Boundary))
|
||||
.count();
|
||||
self.stats.decisions += 1;
|
||||
|
||||
(gate, routes)
|
||||
}
|
||||
|
||||
/// Get accumulated statistics
|
||||
pub fn stats(&self) -> &AttentionStats {
|
||||
&self.stats
|
||||
}
|
||||
|
||||
/// Reset statistics
|
||||
pub fn reset_stats(&mut self) {
|
||||
self.stats = AttentionStats::default();
|
||||
}
|
||||
}
|
||||
|
||||
impl GatePacketBridge {
|
||||
/// Convert tile reports to gate packet (fallback implementation)
|
||||
pub fn to_gate_packet_fallback(&mut self, reports: &[TileReport]) -> GatePacket {
|
||||
if reports.is_empty() {
|
||||
return GatePacket::default();
|
||||
}
|
||||
|
||||
let mut min_cut = f64::MAX;
|
||||
let mut max_shift = 0.0f64;
|
||||
|
||||
for report in reports {
|
||||
if report.local_cut < min_cut && report.local_cut > 0.0 {
|
||||
min_cut = report.local_cut;
|
||||
}
|
||||
if report.shift_score > max_shift {
|
||||
max_shift = report.shift_score;
|
||||
}
|
||||
}
|
||||
|
||||
let lambda = (min_cut.clamp(0.0, 1000.0) * 32.767) as u32;
|
||||
|
||||
let packet = GatePacket {
|
||||
lambda,
|
||||
lambda_prev: self.prev_lambda,
|
||||
boundary_edges: 0,
|
||||
boundary_concentration_q15: (max_shift * 32767.0) as u16,
|
||||
partition_count: reports.len() as u16,
|
||||
flags: 0,
|
||||
};
|
||||
|
||||
self.prev_lambda = lambda;
|
||||
packet
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "attention"))]
|
||||
pub use fallback::*;
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// The default preset targets a 50% reduction and allows early exit.
    #[test]
    fn test_attention_config_default() {
        let defaults = AttentionConfig::default();
        assert_eq!(defaults.flops_reduction, 0.5);
        assert!(defaults.enable_early_exit);
    }

    /// The real-time preset skips more aggressively than the default.
    #[test]
    fn test_attention_config_realtime() {
        let realtime = AttentionConfig::realtime();
        assert!(realtime.flops_reduction > 0.5);
    }

    /// The bridge turns a pair of reports into a packet with a positive
    /// lambda and one partition per report, whichever build flavor is active.
    #[test]
    fn test_gate_packet_bridge() {
        let mut bridge = GatePacketBridge::new();

        // Builds a report with the given tile id, local cut, and shift score.
        let make_report = |tile_id, local_cut, shift_score| {
            let mut report = TileReport::new(tile_id);
            report.local_cut = local_cut;
            report.shift_score = shift_score;
            report
        };

        // First call establishes baseline
        let reports = [make_report(1, 10.0, 0.2), make_report(2, 15.0, 0.1)];

        #[cfg(feature = "attention")]
        let packet = bridge.to_gate_packet(&reports);

        #[cfg(not(feature = "attention"))]
        let packet = bridge.to_gate_packet_fallback(&reports);

        assert!(packet.lambda > 0);
        assert_eq!(packet.partition_count, 2);
    }

    /// Ratios are the skipped/computed counts over the total.
    #[test]
    fn test_attention_stats() {
        let stats = AttentionStats {
            total_entries: 100,
            computed_entries: 60,
            skipped_entries: 40,
            ..AttentionStats::default()
        };

        assert_eq!(stats.flops_reduction(), 0.4);
        assert_eq!(stats.compute_ratio(), 0.6);
    }
}
|
||||
Reference in New Issue
Block a user