Files
wifi-densepose/crates/ruvector-mincut-gated-transformer/tests/early_exit.rs
ruv d803bfe2b1 Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector
git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
2026-02-28 14:39:40 -05:00

490 lines
15 KiB
Rust

//! Early exit condition tests.
//!
//! Tests for tier-based early termination, speculation/verification,
//! and fallback to full computation.
use ruvector_mincut_gated_transformer::{
GateDecision, GatePacket, GatePolicy, GateReason, InferInput, InferOutput,
MincutGatedTransformer, QuantizedWeights, TransformerConfig,
};
fn create_transformer(config: TransformerConfig) -> MincutGatedTransformer {
let policy = GatePolicy::default();
let weights = QuantizedWeights::empty(&config);
MincutGatedTransformer::new(config, policy, weights).unwrap()
}
// ============ Early Exit Conditions ============
#[test]
fn test_early_exit_on_low_lambda() {
let config = TransformerConfig::baseline();
let mut transformer = create_transformer(config.clone());
let tokens: Vec<u32> = (0..32).collect();
// Lambda below minimum triggers quarantine and early exit
let gate = GatePacket {
lambda: 20, // Below default min of 30
lambda_prev: 100,
boundary_edges: 5,
boundary_concentration_q15: 8192,
partition_count: 3,
flags: 0,
};
let input = InferInput::from_tokens(&tokens, gate);
let mut logits = vec![0i32; config.logits as usize];
let mut output = InferOutput::new(&mut logits);
transformer.infer(&input, &mut output).unwrap();
// Should trigger early intervention
assert_eq!(output.witness.decision, GateDecision::QuarantineUpdates);
assert_eq!(output.witness.reason, GateReason::LambdaBelowMin);
assert!(output.stats.layers_executed < config.layers);
assert_eq!(output.stats.tier, 2);
}
#[test]
fn test_early_exit_on_lambda_drop() {
let config = TransformerConfig::baseline();
let mut transformer = create_transformer(config.clone());
let tokens: Vec<u32> = (0..32).collect();
// Fast lambda drop triggers flush and reduced execution
let gate = GatePacket {
lambda: 35,
lambda_prev: 100, // 65% drop
boundary_edges: 5,
boundary_concentration_q15: 8192,
partition_count: 3,
flags: 0,
};
let input = InferInput::from_tokens(&tokens, gate);
let mut logits = vec![0i32; config.logits as usize];
let mut output = InferOutput::new(&mut logits);
transformer.infer(&input, &mut output).unwrap();
assert_eq!(output.witness.decision, GateDecision::FlushKv);
assert_eq!(output.witness.reason, GateReason::LambdaDroppedFast);
assert!(output.stats.layers_executed < config.layers);
}
#[test]
fn test_early_exit_tier_selection() {
let config = TransformerConfig::baseline();
let mut transformer = create_transformer(config.clone());
let tokens: Vec<u32> = (0..32).collect();
// Different conditions should select different tiers
let test_cases = vec![
// (lambda, lambda_prev, boundary_edges, expected_tier_range)
(100, 95, 5, 0..=0), // Normal - tier 0
(100, 95, 30, 1..=1), // Boundary spike - tier 1
(20, 100, 5, 2..=2), // Low lambda - tier 2
(100, 95, 5, 0..=0), // Normal again - tier 0
];
for (lambda, lambda_prev, boundary_edges, expected_tier_range) in test_cases {
let gate = GatePacket {
lambda,
lambda_prev,
boundary_edges,
boundary_concentration_q15: 8192,
partition_count: 3,
flags: 0,
};
let input = InferInput::from_tokens(&tokens, gate);
let mut logits = vec![0i32; config.logits as usize];
let mut output = InferOutput::new(&mut logits);
transformer.infer(&input, &mut output).unwrap();
assert!(
expected_tier_range.contains(&output.stats.tier),
"Tier {} not in expected range {:?} for lambda={}, lambda_prev={}, boundary_edges={}",
output.stats.tier,
expected_tier_range,
lambda,
lambda_prev,
boundary_edges
);
transformer.reset();
}
}
// ============ Speculation and Verification ============
#[test]
fn test_speculative_execution_with_stable_lambda() {
let config = TransformerConfig::baseline();
let mut transformer = create_transformer(config.clone());
let tokens: Vec<u32> = (0..32).collect();
// Stable lambda allows speculative full execution
let gate = GatePacket {
lambda: 100,
lambda_prev: 98,
boundary_edges: 5,
boundary_concentration_q15: 8192,
partition_count: 3,
flags: 0,
};
let input = InferInput::from_tokens(&tokens, gate);
let mut logits = vec![0i32; config.logits as usize];
let mut output = InferOutput::new(&mut logits);
transformer.infer(&input, &mut output).unwrap();
// Should use full layers (tier 0)
assert_eq!(output.stats.tier, 0);
assert_eq!(output.stats.layers_executed, config.layers);
assert_eq!(output.witness.decision, GateDecision::Allow);
}
#[test]
fn test_speculation_fallback_on_instability() {
let config = TransformerConfig::baseline();
let mut transformer = create_transformer(config.clone());
let tokens: Vec<u32> = (0..32).collect();
// Start with stable state
let gate_stable = GatePacket {
lambda: 100,
lambda_prev: 98,
boundary_edges: 5,
boundary_concentration_q15: 8192,
partition_count: 3,
flags: 0,
};
let input = InferInput::from_tokens(&tokens, gate_stable);
let mut logits = vec![0i32; config.logits as usize];
{
let mut output = InferOutput::new(&mut logits);
transformer.infer(&input, &mut output).unwrap();
assert_eq!(output.stats.tier, 0);
}
// Sudden instability should fallback to reduced execution
let gate_unstable = GatePacket {
lambda: 40,
lambda_prev: 100, // 60% drop
boundary_edges: 5,
boundary_concentration_q15: 8192,
partition_count: 3,
flags: 0,
};
let input = InferInput::from_tokens(&tokens, gate_unstable);
let mut logits = vec![0i32; config.logits as usize];
{
let mut output = InferOutput::new(&mut logits);
transformer.infer(&input, &mut output).unwrap();
assert_eq!(output.witness.decision, GateDecision::FlushKv);
assert!(output.stats.layers_executed < config.layers);
}
}
#[test]
fn test_verification_prevents_invalid_cache() {
let config = TransformerConfig::baseline();
let mut transformer = create_transformer(config.clone());
let tokens: Vec<u32> = (0..32).collect();
let signature = 12345u64;
// First run with stable conditions - cache result
let gate_stable = GatePacket {
lambda: 100,
lambda_prev: 95,
boundary_edges: 5,
boundary_concentration_q15: 8192,
partition_count: 3,
flags: 0,
};
let input = InferInput::from_tokens(&tokens, gate_stable).with_signature(signature);
let mut logits = vec![0i32; config.logits as usize];
{
let mut output = InferOutput::new(&mut logits);
transformer.infer(&input, &mut output).unwrap();
}
// Second run with unstable conditions but skip flag
// Should not use cached result from stable conditions
let gate_unstable_skip = GatePacket {
lambda: 20, // Unstable
lambda_prev: 100,
boundary_edges: 5,
boundary_concentration_q15: 8192,
partition_count: 3,
flags: GatePacket::FLAG_SKIP,
};
let input = InferInput::from_tokens(&tokens, gate_unstable_skip).with_signature(signature);
let mut logits2 = vec![0i32; config.logits as usize];
{
let mut output = InferOutput::new(&mut logits2);
transformer.infer(&input, &mut output).unwrap();
assert_eq!(output.stats.skipped, 1);
}
}
// ============ Fallback to Full Computation ============
#[test]
fn test_fallback_after_failed_early_exit() {
let config = TransformerConfig::baseline();
let mut transformer = create_transformer(config.clone());
let tokens: Vec<u32> = (0..32).collect();
// Trigger early exit with boundary spike
let gate_exit = GatePacket {
lambda: 100,
lambda_prev: 95,
boundary_edges: 35, // High boundary edges
boundary_concentration_q15: 8192,
partition_count: 3,
flags: 0,
};
let input = InferInput::from_tokens(&tokens, gate_exit);
let mut logits = vec![0i32; config.logits as usize];
let early_exit_layers;
{
let mut output = InferOutput::new(&mut logits);
transformer.infer(&input, &mut output).unwrap();
early_exit_layers = output.stats.layers_executed;
assert!(output.stats.layers_executed < config.layers);
}
transformer.reset();
// Return to stable conditions - should use full computation
let gate_stable = GatePacket {
lambda: 100,
lambda_prev: 95,
boundary_edges: 5,
boundary_concentration_q15: 8192,
partition_count: 3,
flags: 0,
};
let input = InferInput::from_tokens(&tokens, gate_stable);
let mut logits = vec![0i32; config.logits as usize];
{
let mut output = InferOutput::new(&mut logits);
transformer.infer(&input, &mut output).unwrap();
assert!(output.stats.layers_executed > early_exit_layers);
assert_eq!(output.stats.layers_executed, config.layers);
}
}
#[test]
fn test_force_safe_minimum_computation() {
let config = TransformerConfig::baseline();
let mut transformer = create_transformer(config.clone());
let tokens: Vec<u32> = (0..32).collect();
// Force safe mode
let gate = GatePacket {
lambda: 100,
lambda_prev: 95,
boundary_edges: 5,
boundary_concentration_q15: 8192,
partition_count: 3,
flags: GatePacket::FLAG_FORCE_SAFE,
};
let input = InferInput::from_tokens(&tokens, gate);
let mut logits = vec![0i32; config.logits as usize];
let mut output = InferOutput::new(&mut logits);
transformer.infer(&input, &mut output).unwrap();
// Should use minimal computation
assert_eq!(output.witness.decision, GateDecision::FreezeWrites);
assert_eq!(output.stats.tier, 2);
assert_eq!(output.stats.layers_executed, 1);
assert_eq!(output.witness.kv_writes_enabled, 0);
}
#[test]
fn test_progressive_degradation() {
let config = TransformerConfig::baseline();
let mut transformer = create_transformer(config.clone());
let tokens: Vec<u32> = (0..32).collect();
let conditions = vec![
(100, 95, 5, "stable"),
(95, 100, 10, "slight_boundary_increase"),
(90, 95, 20, "boundary_spike"),
(85, 90, 35, "severe_boundary_spike"),
(40, 85, 40, "lambda_drop_and_boundary"),
];
let mut prev_layers = config.layers;
for (lambda, lambda_prev, boundary_edges, _desc) in conditions {
let gate = GatePacket {
lambda,
lambda_prev,
boundary_edges,
boundary_concentration_q15: 8192,
partition_count: 3,
flags: 0,
};
let input = InferInput::from_tokens(&tokens, gate);
let mut logits = vec![0i32; config.logits as usize];
let mut output = InferOutput::new(&mut logits);
transformer.infer(&input, &mut output).unwrap();
// Should see progressive degradation or maintenance
assert!(output.stats.layers_executed <= prev_layers);
prev_layers = output.stats.layers_executed;
transformer.reset();
}
}
#[test]
fn test_layer_execution_counts() {
let config = TransformerConfig::baseline();
let mut transformer = create_transformer(config.clone());
let tokens: Vec<u32> = (0..32).collect();
// Tier 0 - full execution
let gate_t0 = GatePacket {
lambda: 100,
lambda_prev: 95,
boundary_edges: 5,
boundary_concentration_q15: 8192,
partition_count: 3,
flags: 0,
};
let input = InferInput::from_tokens(&tokens, gate_t0);
let mut logits = vec![0i32; config.logits as usize];
let t0_layers;
{
let mut output = InferOutput::new(&mut logits);
transformer.infer(&input, &mut output).unwrap();
t0_layers = output.stats.layers_executed;
}
transformer.reset();
// Tier 1 - reduced execution
let gate_t1 = GatePacket {
lambda: 100,
lambda_prev: 95,
boundary_edges: 30,
boundary_concentration_q15: 8192,
partition_count: 3,
flags: 0,
};
let input = InferInput::from_tokens(&tokens, gate_t1);
let mut logits = vec![0i32; config.logits as usize];
let t1_layers;
{
let mut output = InferOutput::new(&mut logits);
transformer.infer(&input, &mut output).unwrap();
t1_layers = output.stats.layers_executed;
}
transformer.reset();
// Tier 2 - minimal execution
let gate_t2 = GatePacket {
lambda: 100,
lambda_prev: 95,
boundary_edges: 5,
boundary_concentration_q15: 8192,
partition_count: 3,
flags: GatePacket::FLAG_FORCE_SAFE,
};
let input = InferInput::from_tokens(&tokens, gate_t2);
let mut logits = vec![0i32; config.logits as usize];
let t2_layers;
{
let mut output = InferOutput::new(&mut logits);
transformer.infer(&input, &mut output).unwrap();
t2_layers = output.stats.layers_executed;
}
// Verify tier ordering
assert!(t0_layers > t1_layers);
assert!(t1_layers > t2_layers);
assert_eq!(t2_layers, 1);
}
#[test]
fn test_early_exit_operation_counts() {
let config = TransformerConfig::baseline();
let mut transformer = create_transformer(config.clone());
let tokens: Vec<u32> = (0..32).collect();
// Full computation
let gate_full = GatePacket {
lambda: 100,
lambda_prev: 95,
boundary_edges: 5,
boundary_concentration_q15: 8192,
partition_count: 3,
flags: 0,
};
let input = InferInput::from_tokens(&tokens, gate_full);
let mut logits = vec![0i32; config.logits as usize];
let full_ops;
{
let mut output = InferOutput::new(&mut logits);
transformer.infer(&input, &mut output).unwrap();
full_ops = output.stats.attn_dot_ops + output.stats.ffn_ops;
}
transformer.reset();
// Early exit
let gate_exit = GatePacket {
lambda: 100,
lambda_prev: 95,
boundary_edges: 30,
boundary_concentration_q15: 8192,
partition_count: 3,
flags: 0,
};
let input = InferInput::from_tokens(&tokens, gate_exit);
let mut logits = vec![0i32; config.logits as usize];
let exit_ops;
{
let mut output = InferOutput::new(&mut logits);
transformer.infer(&input, &mut output).unwrap();
exit_ops = output.stats.attn_dot_ops + output.stats.ffn_ops;
}
// Early exit should perform fewer operations
assert!(exit_ops < full_ops);
}