Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
This commit is contained in:
489
crates/ruvector-mincut-gated-transformer/tests/early_exit.rs
Normal file
489
crates/ruvector-mincut-gated-transformer/tests/early_exit.rs
Normal file
@@ -0,0 +1,489 @@
|
||||
//! Early exit condition tests.
|
||||
//!
|
||||
//! Tests for tier-based early termination, speculation/verification,
|
||||
//! and fallback to full computation.
|
||||
|
||||
use ruvector_mincut_gated_transformer::{
|
||||
GateDecision, GatePacket, GatePolicy, GateReason, InferInput, InferOutput,
|
||||
MincutGatedTransformer, QuantizedWeights, TransformerConfig,
|
||||
};
|
||||
|
||||
fn create_transformer(config: TransformerConfig) -> MincutGatedTransformer {
|
||||
let policy = GatePolicy::default();
|
||||
let weights = QuantizedWeights::empty(&config);
|
||||
MincutGatedTransformer::new(config, policy, weights).unwrap()
|
||||
}
|
||||
|
||||
// ============ Early Exit Conditions ============
|
||||
|
||||
#[test]
|
||||
fn test_early_exit_on_low_lambda() {
|
||||
let config = TransformerConfig::baseline();
|
||||
let mut transformer = create_transformer(config.clone());
|
||||
|
||||
let tokens: Vec<u32> = (0..32).collect();
|
||||
|
||||
// Lambda below minimum triggers quarantine and early exit
|
||||
let gate = GatePacket {
|
||||
lambda: 20, // Below default min of 30
|
||||
lambda_prev: 100,
|
||||
boundary_edges: 5,
|
||||
boundary_concentration_q15: 8192,
|
||||
partition_count: 3,
|
||||
flags: 0,
|
||||
};
|
||||
|
||||
let input = InferInput::from_tokens(&tokens, gate);
|
||||
let mut logits = vec![0i32; config.logits as usize];
|
||||
let mut output = InferOutput::new(&mut logits);
|
||||
|
||||
transformer.infer(&input, &mut output).unwrap();
|
||||
|
||||
// Should trigger early intervention
|
||||
assert_eq!(output.witness.decision, GateDecision::QuarantineUpdates);
|
||||
assert_eq!(output.witness.reason, GateReason::LambdaBelowMin);
|
||||
assert!(output.stats.layers_executed < config.layers);
|
||||
assert_eq!(output.stats.tier, 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_early_exit_on_lambda_drop() {
|
||||
let config = TransformerConfig::baseline();
|
||||
let mut transformer = create_transformer(config.clone());
|
||||
|
||||
let tokens: Vec<u32> = (0..32).collect();
|
||||
|
||||
// Fast lambda drop triggers flush and reduced execution
|
||||
let gate = GatePacket {
|
||||
lambda: 35,
|
||||
lambda_prev: 100, // 65% drop
|
||||
boundary_edges: 5,
|
||||
boundary_concentration_q15: 8192,
|
||||
partition_count: 3,
|
||||
flags: 0,
|
||||
};
|
||||
|
||||
let input = InferInput::from_tokens(&tokens, gate);
|
||||
let mut logits = vec![0i32; config.logits as usize];
|
||||
let mut output = InferOutput::new(&mut logits);
|
||||
|
||||
transformer.infer(&input, &mut output).unwrap();
|
||||
|
||||
assert_eq!(output.witness.decision, GateDecision::FlushKv);
|
||||
assert_eq!(output.witness.reason, GateReason::LambdaDroppedFast);
|
||||
assert!(output.stats.layers_executed < config.layers);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_early_exit_tier_selection() {
|
||||
let config = TransformerConfig::baseline();
|
||||
let mut transformer = create_transformer(config.clone());
|
||||
|
||||
let tokens: Vec<u32> = (0..32).collect();
|
||||
|
||||
// Different conditions should select different tiers
|
||||
let test_cases = vec![
|
||||
// (lambda, lambda_prev, boundary_edges, expected_tier_range)
|
||||
(100, 95, 5, 0..=0), // Normal - tier 0
|
||||
(100, 95, 30, 1..=1), // Boundary spike - tier 1
|
||||
(20, 100, 5, 2..=2), // Low lambda - tier 2
|
||||
(100, 95, 5, 0..=0), // Normal again - tier 0
|
||||
];
|
||||
|
||||
for (lambda, lambda_prev, boundary_edges, expected_tier_range) in test_cases {
|
||||
let gate = GatePacket {
|
||||
lambda,
|
||||
lambda_prev,
|
||||
boundary_edges,
|
||||
boundary_concentration_q15: 8192,
|
||||
partition_count: 3,
|
||||
flags: 0,
|
||||
};
|
||||
|
||||
let input = InferInput::from_tokens(&tokens, gate);
|
||||
let mut logits = vec![0i32; config.logits as usize];
|
||||
let mut output = InferOutput::new(&mut logits);
|
||||
|
||||
transformer.infer(&input, &mut output).unwrap();
|
||||
|
||||
assert!(
|
||||
expected_tier_range.contains(&output.stats.tier),
|
||||
"Tier {} not in expected range {:?} for lambda={}, lambda_prev={}, boundary_edges={}",
|
||||
output.stats.tier,
|
||||
expected_tier_range,
|
||||
lambda,
|
||||
lambda_prev,
|
||||
boundary_edges
|
||||
);
|
||||
|
||||
transformer.reset();
|
||||
}
|
||||
}
|
||||
|
||||
// ============ Speculation and Verification ============
|
||||
|
||||
#[test]
|
||||
fn test_speculative_execution_with_stable_lambda() {
|
||||
let config = TransformerConfig::baseline();
|
||||
let mut transformer = create_transformer(config.clone());
|
||||
|
||||
let tokens: Vec<u32> = (0..32).collect();
|
||||
|
||||
// Stable lambda allows speculative full execution
|
||||
let gate = GatePacket {
|
||||
lambda: 100,
|
||||
lambda_prev: 98,
|
||||
boundary_edges: 5,
|
||||
boundary_concentration_q15: 8192,
|
||||
partition_count: 3,
|
||||
flags: 0,
|
||||
};
|
||||
|
||||
let input = InferInput::from_tokens(&tokens, gate);
|
||||
let mut logits = vec![0i32; config.logits as usize];
|
||||
let mut output = InferOutput::new(&mut logits);
|
||||
|
||||
transformer.infer(&input, &mut output).unwrap();
|
||||
|
||||
// Should use full layers (tier 0)
|
||||
assert_eq!(output.stats.tier, 0);
|
||||
assert_eq!(output.stats.layers_executed, config.layers);
|
||||
assert_eq!(output.witness.decision, GateDecision::Allow);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_speculation_fallback_on_instability() {
|
||||
let config = TransformerConfig::baseline();
|
||||
let mut transformer = create_transformer(config.clone());
|
||||
|
||||
let tokens: Vec<u32> = (0..32).collect();
|
||||
|
||||
// Start with stable state
|
||||
let gate_stable = GatePacket {
|
||||
lambda: 100,
|
||||
lambda_prev: 98,
|
||||
boundary_edges: 5,
|
||||
boundary_concentration_q15: 8192,
|
||||
partition_count: 3,
|
||||
flags: 0,
|
||||
};
|
||||
|
||||
let input = InferInput::from_tokens(&tokens, gate_stable);
|
||||
let mut logits = vec![0i32; config.logits as usize];
|
||||
{
|
||||
let mut output = InferOutput::new(&mut logits);
|
||||
transformer.infer(&input, &mut output).unwrap();
|
||||
assert_eq!(output.stats.tier, 0);
|
||||
}
|
||||
|
||||
// Sudden instability should fallback to reduced execution
|
||||
let gate_unstable = GatePacket {
|
||||
lambda: 40,
|
||||
lambda_prev: 100, // 60% drop
|
||||
boundary_edges: 5,
|
||||
boundary_concentration_q15: 8192,
|
||||
partition_count: 3,
|
||||
flags: 0,
|
||||
};
|
||||
|
||||
let input = InferInput::from_tokens(&tokens, gate_unstable);
|
||||
let mut logits = vec![0i32; config.logits as usize];
|
||||
{
|
||||
let mut output = InferOutput::new(&mut logits);
|
||||
transformer.infer(&input, &mut output).unwrap();
|
||||
assert_eq!(output.witness.decision, GateDecision::FlushKv);
|
||||
assert!(output.stats.layers_executed < config.layers);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_verification_prevents_invalid_cache() {
|
||||
let config = TransformerConfig::baseline();
|
||||
let mut transformer = create_transformer(config.clone());
|
||||
|
||||
let tokens: Vec<u32> = (0..32).collect();
|
||||
let signature = 12345u64;
|
||||
|
||||
// First run with stable conditions - cache result
|
||||
let gate_stable = GatePacket {
|
||||
lambda: 100,
|
||||
lambda_prev: 95,
|
||||
boundary_edges: 5,
|
||||
boundary_concentration_q15: 8192,
|
||||
partition_count: 3,
|
||||
flags: 0,
|
||||
};
|
||||
|
||||
let input = InferInput::from_tokens(&tokens, gate_stable).with_signature(signature);
|
||||
let mut logits = vec![0i32; config.logits as usize];
|
||||
{
|
||||
let mut output = InferOutput::new(&mut logits);
|
||||
transformer.infer(&input, &mut output).unwrap();
|
||||
}
|
||||
|
||||
// Second run with unstable conditions but skip flag
|
||||
// Should not use cached result from stable conditions
|
||||
let gate_unstable_skip = GatePacket {
|
||||
lambda: 20, // Unstable
|
||||
lambda_prev: 100,
|
||||
boundary_edges: 5,
|
||||
boundary_concentration_q15: 8192,
|
||||
partition_count: 3,
|
||||
flags: GatePacket::FLAG_SKIP,
|
||||
};
|
||||
|
||||
let input = InferInput::from_tokens(&tokens, gate_unstable_skip).with_signature(signature);
|
||||
let mut logits2 = vec![0i32; config.logits as usize];
|
||||
{
|
||||
let mut output = InferOutput::new(&mut logits2);
|
||||
transformer.infer(&input, &mut output).unwrap();
|
||||
assert_eq!(output.stats.skipped, 1);
|
||||
}
|
||||
}
|
||||
|
||||
// ============ Fallback to Full Computation ============
|
||||
|
||||
#[test]
|
||||
fn test_fallback_after_failed_early_exit() {
|
||||
let config = TransformerConfig::baseline();
|
||||
let mut transformer = create_transformer(config.clone());
|
||||
|
||||
let tokens: Vec<u32> = (0..32).collect();
|
||||
|
||||
// Trigger early exit with boundary spike
|
||||
let gate_exit = GatePacket {
|
||||
lambda: 100,
|
||||
lambda_prev: 95,
|
||||
boundary_edges: 35, // High boundary edges
|
||||
boundary_concentration_q15: 8192,
|
||||
partition_count: 3,
|
||||
flags: 0,
|
||||
};
|
||||
|
||||
let input = InferInput::from_tokens(&tokens, gate_exit);
|
||||
let mut logits = vec![0i32; config.logits as usize];
|
||||
let early_exit_layers;
|
||||
{
|
||||
let mut output = InferOutput::new(&mut logits);
|
||||
transformer.infer(&input, &mut output).unwrap();
|
||||
early_exit_layers = output.stats.layers_executed;
|
||||
assert!(output.stats.layers_executed < config.layers);
|
||||
}
|
||||
|
||||
transformer.reset();
|
||||
|
||||
// Return to stable conditions - should use full computation
|
||||
let gate_stable = GatePacket {
|
||||
lambda: 100,
|
||||
lambda_prev: 95,
|
||||
boundary_edges: 5,
|
||||
boundary_concentration_q15: 8192,
|
||||
partition_count: 3,
|
||||
flags: 0,
|
||||
};
|
||||
|
||||
let input = InferInput::from_tokens(&tokens, gate_stable);
|
||||
let mut logits = vec![0i32; config.logits as usize];
|
||||
{
|
||||
let mut output = InferOutput::new(&mut logits);
|
||||
transformer.infer(&input, &mut output).unwrap();
|
||||
assert!(output.stats.layers_executed > early_exit_layers);
|
||||
assert_eq!(output.stats.layers_executed, config.layers);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_force_safe_minimum_computation() {
|
||||
let config = TransformerConfig::baseline();
|
||||
let mut transformer = create_transformer(config.clone());
|
||||
|
||||
let tokens: Vec<u32> = (0..32).collect();
|
||||
|
||||
// Force safe mode
|
||||
let gate = GatePacket {
|
||||
lambda: 100,
|
||||
lambda_prev: 95,
|
||||
boundary_edges: 5,
|
||||
boundary_concentration_q15: 8192,
|
||||
partition_count: 3,
|
||||
flags: GatePacket::FLAG_FORCE_SAFE,
|
||||
};
|
||||
|
||||
let input = InferInput::from_tokens(&tokens, gate);
|
||||
let mut logits = vec![0i32; config.logits as usize];
|
||||
let mut output = InferOutput::new(&mut logits);
|
||||
|
||||
transformer.infer(&input, &mut output).unwrap();
|
||||
|
||||
// Should use minimal computation
|
||||
assert_eq!(output.witness.decision, GateDecision::FreezeWrites);
|
||||
assert_eq!(output.stats.tier, 2);
|
||||
assert_eq!(output.stats.layers_executed, 1);
|
||||
assert_eq!(output.witness.kv_writes_enabled, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_progressive_degradation() {
|
||||
let config = TransformerConfig::baseline();
|
||||
let mut transformer = create_transformer(config.clone());
|
||||
|
||||
let tokens: Vec<u32> = (0..32).collect();
|
||||
|
||||
let conditions = vec![
|
||||
(100, 95, 5, "stable"),
|
||||
(95, 100, 10, "slight_boundary_increase"),
|
||||
(90, 95, 20, "boundary_spike"),
|
||||
(85, 90, 35, "severe_boundary_spike"),
|
||||
(40, 85, 40, "lambda_drop_and_boundary"),
|
||||
];
|
||||
|
||||
let mut prev_layers = config.layers;
|
||||
|
||||
for (lambda, lambda_prev, boundary_edges, _desc) in conditions {
|
||||
let gate = GatePacket {
|
||||
lambda,
|
||||
lambda_prev,
|
||||
boundary_edges,
|
||||
boundary_concentration_q15: 8192,
|
||||
partition_count: 3,
|
||||
flags: 0,
|
||||
};
|
||||
|
||||
let input = InferInput::from_tokens(&tokens, gate);
|
||||
let mut logits = vec![0i32; config.logits as usize];
|
||||
let mut output = InferOutput::new(&mut logits);
|
||||
|
||||
transformer.infer(&input, &mut output).unwrap();
|
||||
|
||||
// Should see progressive degradation or maintenance
|
||||
assert!(output.stats.layers_executed <= prev_layers);
|
||||
prev_layers = output.stats.layers_executed;
|
||||
|
||||
transformer.reset();
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_layer_execution_counts() {
|
||||
let config = TransformerConfig::baseline();
|
||||
let mut transformer = create_transformer(config.clone());
|
||||
|
||||
let tokens: Vec<u32> = (0..32).collect();
|
||||
|
||||
// Tier 0 - full execution
|
||||
let gate_t0 = GatePacket {
|
||||
lambda: 100,
|
||||
lambda_prev: 95,
|
||||
boundary_edges: 5,
|
||||
boundary_concentration_q15: 8192,
|
||||
partition_count: 3,
|
||||
flags: 0,
|
||||
};
|
||||
|
||||
let input = InferInput::from_tokens(&tokens, gate_t0);
|
||||
let mut logits = vec![0i32; config.logits as usize];
|
||||
let t0_layers;
|
||||
{
|
||||
let mut output = InferOutput::new(&mut logits);
|
||||
transformer.infer(&input, &mut output).unwrap();
|
||||
t0_layers = output.stats.layers_executed;
|
||||
}
|
||||
|
||||
transformer.reset();
|
||||
|
||||
// Tier 1 - reduced execution
|
||||
let gate_t1 = GatePacket {
|
||||
lambda: 100,
|
||||
lambda_prev: 95,
|
||||
boundary_edges: 30,
|
||||
boundary_concentration_q15: 8192,
|
||||
partition_count: 3,
|
||||
flags: 0,
|
||||
};
|
||||
|
||||
let input = InferInput::from_tokens(&tokens, gate_t1);
|
||||
let mut logits = vec![0i32; config.logits as usize];
|
||||
let t1_layers;
|
||||
{
|
||||
let mut output = InferOutput::new(&mut logits);
|
||||
transformer.infer(&input, &mut output).unwrap();
|
||||
t1_layers = output.stats.layers_executed;
|
||||
}
|
||||
|
||||
transformer.reset();
|
||||
|
||||
// Tier 2 - minimal execution
|
||||
let gate_t2 = GatePacket {
|
||||
lambda: 100,
|
||||
lambda_prev: 95,
|
||||
boundary_edges: 5,
|
||||
boundary_concentration_q15: 8192,
|
||||
partition_count: 3,
|
||||
flags: GatePacket::FLAG_FORCE_SAFE,
|
||||
};
|
||||
|
||||
let input = InferInput::from_tokens(&tokens, gate_t2);
|
||||
let mut logits = vec![0i32; config.logits as usize];
|
||||
let t2_layers;
|
||||
{
|
||||
let mut output = InferOutput::new(&mut logits);
|
||||
transformer.infer(&input, &mut output).unwrap();
|
||||
t2_layers = output.stats.layers_executed;
|
||||
}
|
||||
|
||||
// Verify tier ordering
|
||||
assert!(t0_layers > t1_layers);
|
||||
assert!(t1_layers > t2_layers);
|
||||
assert_eq!(t2_layers, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_early_exit_operation_counts() {
|
||||
let config = TransformerConfig::baseline();
|
||||
let mut transformer = create_transformer(config.clone());
|
||||
|
||||
let tokens: Vec<u32> = (0..32).collect();
|
||||
|
||||
// Full computation
|
||||
let gate_full = GatePacket {
|
||||
lambda: 100,
|
||||
lambda_prev: 95,
|
||||
boundary_edges: 5,
|
||||
boundary_concentration_q15: 8192,
|
||||
partition_count: 3,
|
||||
flags: 0,
|
||||
};
|
||||
|
||||
let input = InferInput::from_tokens(&tokens, gate_full);
|
||||
let mut logits = vec![0i32; config.logits as usize];
|
||||
let full_ops;
|
||||
{
|
||||
let mut output = InferOutput::new(&mut logits);
|
||||
transformer.infer(&input, &mut output).unwrap();
|
||||
full_ops = output.stats.attn_dot_ops + output.stats.ffn_ops;
|
||||
}
|
||||
|
||||
transformer.reset();
|
||||
|
||||
// Early exit
|
||||
let gate_exit = GatePacket {
|
||||
lambda: 100,
|
||||
lambda_prev: 95,
|
||||
boundary_edges: 30,
|
||||
boundary_concentration_q15: 8192,
|
||||
partition_count: 3,
|
||||
flags: 0,
|
||||
};
|
||||
|
||||
let input = InferInput::from_tokens(&tokens, gate_exit);
|
||||
let mut logits = vec![0i32; config.logits as usize];
|
||||
let exit_ops;
|
||||
{
|
||||
let mut output = InferOutput::new(&mut logits);
|
||||
transformer.infer(&input, &mut output).unwrap();
|
||||
exit_ops = output.stats.attn_dot_ops + output.stats.ffn_ops;
|
||||
}
|
||||
|
||||
// Early exit should perform fewer operations
|
||||
assert!(exit_ops < full_ops);
|
||||
}
|
||||
Reference in New Issue
Block a user