465 lines
12 KiB
Rust
465 lines
12 KiB
Rust
#![allow(
|
|
clippy::all,
|
|
unused_imports,
|
|
unused_variables,
|
|
dead_code,
|
|
unused_mut,
|
|
unused_assignments,
|
|
non_camel_case_types,
|
|
clippy::approx_constant,
|
|
unexpected_cfgs,
|
|
unused_must_use,
|
|
unused_parens
|
|
)]
|
|
//! Integration tests for speculative decoding
|
|
//!
|
|
//! These tests verify the speculative decoding implementation works correctly
|
|
//! with mock backends.
|
|
|
|
use ruvllm::speculative::{
|
|
log_softmax, softmax, top_k_filter, top_p_filter, AtomicSpeculativeStats, SpeculationTree,
|
|
SpeculativeConfig, SpeculativeStats, TreeNode, VerificationResult,
|
|
};
|
|
use std::time::Duration;
|
|
|
|
#[test]
|
|
fn test_speculative_config_defaults() {
|
|
let config = SpeculativeConfig::default();
|
|
|
|
assert_eq!(config.lookahead, 4);
|
|
assert!((config.acceptance_threshold - 0.5).abs() < 0.01);
|
|
assert!((config.draft_temperature - 0.0).abs() < 0.01);
|
|
assert!(!config.tree_speculation);
|
|
assert_eq!(config.max_tree_depth, 3);
|
|
assert_eq!(config.tree_branching_factor, 2);
|
|
assert!(config.adaptive_lookahead);
|
|
assert_eq!(config.min_lookahead, 2);
|
|
assert_eq!(config.max_lookahead, 8);
|
|
}
|
|
|
|
#[test]
|
|
fn test_speculative_config_custom() {
|
|
let config = SpeculativeConfig {
|
|
lookahead: 6,
|
|
acceptance_threshold: 0.7,
|
|
draft_temperature: 0.1,
|
|
tree_speculation: true,
|
|
max_tree_depth: 4,
|
|
tree_branching_factor: 3,
|
|
draft_top_p: 0.9,
|
|
min_acceptance_ratio: 0.2,
|
|
adaptive_lookahead: false,
|
|
min_lookahead: 3,
|
|
max_lookahead: 10,
|
|
};
|
|
|
|
assert_eq!(config.lookahead, 6);
|
|
assert!((config.acceptance_threshold - 0.7).abs() < 0.01);
|
|
assert!(config.tree_speculation);
|
|
}
|
|
|
|
#[test]
|
|
fn test_speculative_stats_empty() {
|
|
let stats = SpeculativeStats::new();
|
|
|
|
assert_eq!(stats.draft_tokens, 0);
|
|
assert_eq!(stats.accepted_tokens, 0);
|
|
assert!((stats.acceptance_rate - 0.0).abs() < 0.01);
|
|
assert!((stats.speedup - 0.0).abs() < 0.01);
|
|
assert_eq!(stats.main_forward_passes, 0);
|
|
}
|
|
|
|
#[test]
|
|
fn test_speculative_stats_record_round() {
|
|
let mut stats = SpeculativeStats::new();
|
|
|
|
// Simulate a round: 4 draft tokens, 3 accepted
|
|
stats.record_round(4, 3, 10.0);
|
|
|
|
assert_eq!(stats.draft_tokens, 4);
|
|
assert_eq!(stats.accepted_tokens, 3);
|
|
assert!((stats.acceptance_rate - 0.75).abs() < 0.01);
|
|
assert_eq!(stats.main_forward_passes, 1);
|
|
assert_eq!(stats.draft_forward_passes, 4);
|
|
assert_eq!(stats.total_tokens_generated, 4); // 3 accepted + 1 correction
|
|
|
|
// Simulate another round: 4 draft, 2 accepted
|
|
stats.record_round(4, 2, 12.0);
|
|
|
|
assert_eq!(stats.draft_tokens, 8);
|
|
assert_eq!(stats.accepted_tokens, 5);
|
|
assert!((stats.acceptance_rate - 0.625).abs() < 0.01);
|
|
assert_eq!(stats.main_forward_passes, 2);
|
|
assert_eq!(stats.total_tokens_generated, 7);
|
|
}
|
|
|
|
#[test]
|
|
fn test_speculative_stats_speedup_calculation() {
|
|
let mut stats = SpeculativeStats::new();
|
|
|
|
// Perfect speculation: all accepted
|
|
stats.record_round(4, 4, 10.0);
|
|
|
|
// 5 tokens per pass (4 accepted + 1 continuation)
|
|
assert!((stats.avg_tokens_per_main_pass - 5.0).abs() < 0.01);
|
|
assert!((stats.speedup - 5.0).abs() < 0.01);
|
|
}
|
|
|
|
#[test]
|
|
fn test_atomic_speculative_stats() {
|
|
let stats = AtomicSpeculativeStats::new();
|
|
|
|
// Record multiple rounds (simulating concurrent access)
|
|
stats.record_round(4, 3, Duration::from_millis(10));
|
|
stats.record_round(4, 4, Duration::from_millis(8));
|
|
stats.record_round(4, 2, Duration::from_millis(12));
|
|
|
|
let snapshot = stats.snapshot();
|
|
|
|
assert_eq!(snapshot.draft_tokens, 12);
|
|
assert_eq!(snapshot.accepted_tokens, 9);
|
|
assert_eq!(snapshot.main_forward_passes, 3);
|
|
assert!((snapshot.acceptance_rate - 0.75).abs() < 0.01);
|
|
}
|
|
|
|
#[test]
|
|
fn test_atomic_stats_reset() {
|
|
let stats = AtomicSpeculativeStats::new();
|
|
stats.record_round(4, 3, Duration::from_millis(10));
|
|
|
|
assert_eq!(stats.snapshot().draft_tokens, 4);
|
|
|
|
stats.reset();
|
|
|
|
assert_eq!(stats.snapshot().draft_tokens, 0);
|
|
assert_eq!(stats.snapshot().accepted_tokens, 0);
|
|
}
|
|
|
|
#[test]
|
|
fn test_tree_node_creation() {
|
|
let node = TreeNode::new(42, 0.8, 0);
|
|
|
|
assert_eq!(node.token, 42);
|
|
assert!((node.prob - 0.8).abs() < 0.01);
|
|
assert_eq!(node.depth, 0);
|
|
assert!(node.children.is_empty());
|
|
}
|
|
|
|
#[test]
|
|
fn test_tree_node_add_child() {
|
|
let mut root = TreeNode::new(0, 1.0, 0);
|
|
|
|
root.add_child(1, 0.6);
|
|
root.add_child(2, 0.3);
|
|
root.add_child(3, 0.1);
|
|
|
|
assert_eq!(root.children.len(), 3);
|
|
assert_eq!(root.children[0].token, 1);
|
|
assert_eq!(root.children[1].token, 2);
|
|
assert_eq!(root.children[2].token, 3);
|
|
assert_eq!(root.children[0].depth, 1);
|
|
}
|
|
|
|
#[test]
|
|
fn test_tree_node_best_path() {
|
|
let mut root = TreeNode::new(0, 1.0, 0);
|
|
|
|
// Build tree:
|
|
// 0
|
|
// / \
|
|
// 1 2
|
|
// / / \
|
|
// 3 4 5
|
|
|
|
let child1 = root.add_child(1, 0.6);
|
|
child1.add_child(3, 0.5);
|
|
|
|
let child2 = root.add_child(2, 0.3);
|
|
child2.add_child(4, 0.2);
|
|
child2.add_child(5, 0.1);
|
|
|
|
// Best path should follow highest probabilities
|
|
let path = root.best_path();
|
|
assert_eq!(path[0], 0);
|
|
assert_eq!(path[1], 1); // 0.6 > 0.3
|
|
assert_eq!(path[2], 3);
|
|
}
|
|
|
|
#[test]
|
|
fn test_tree_node_get_paths() {
|
|
let mut root = TreeNode::new(0, 1.0, 0);
|
|
|
|
let child1 = root.add_child(1, 0.6);
|
|
child1.add_child(3, 0.5);
|
|
|
|
let child2 = root.add_child(2, 0.3);
|
|
child2.add_child(4, 0.2);
|
|
child2.add_child(5, 0.1);
|
|
|
|
let paths = root.get_paths();
|
|
|
|
// Should have 3 paths:
|
|
// [0, 1, 3]
|
|
// [0, 2, 4]
|
|
// [0, 2, 5]
|
|
assert_eq!(paths.len(), 3);
|
|
}
|
|
|
|
#[test]
|
|
fn test_speculation_tree_creation() {
|
|
let tree = SpeculationTree::new(3, 2);
|
|
|
|
assert_eq!(tree.max_depth, 3);
|
|
assert_eq!(tree.branching_factor, 2);
|
|
assert_eq!(tree.node_count, 1);
|
|
}
|
|
|
|
#[test]
|
|
fn test_speculation_tree_best_path() {
|
|
let mut tree = SpeculationTree::new(3, 2);
|
|
|
|
// Build a linear path
|
|
let mut current = &mut tree.root;
|
|
current = current.add_child(10, 0.9);
|
|
tree.node_count += 1;
|
|
current = current.add_child(20, 0.8);
|
|
tree.node_count += 1;
|
|
current.add_child(30, 0.7);
|
|
tree.node_count += 1;
|
|
|
|
let best = tree.best_path();
|
|
|
|
// Should skip the root placeholder and return [10, 20, 30]
|
|
assert_eq!(best, vec![10, 20, 30]);
|
|
}
|
|
|
|
#[test]
|
|
fn test_speculation_tree_clear() {
|
|
let mut tree = SpeculationTree::new(3, 2);
|
|
|
|
tree.root.add_child(1, 0.5);
|
|
tree.node_count += 1;
|
|
|
|
assert_eq!(tree.node_count, 2);
|
|
|
|
tree.clear();
|
|
|
|
assert_eq!(tree.node_count, 1);
|
|
assert!(tree.root.children.is_empty());
|
|
}
|
|
|
|
#[test]
|
|
fn test_verification_result() {
|
|
let result = VerificationResult {
|
|
accepted_count: 3,
|
|
next_token: 100,
|
|
accepted_logprobs: vec![-0.1, -0.2, -0.3],
|
|
next_logprob: -0.5,
|
|
all_accepted: false,
|
|
};
|
|
|
|
assert_eq!(result.accepted_count, 3);
|
|
assert_eq!(result.next_token, 100);
|
|
assert!(!result.all_accepted);
|
|
assert_eq!(result.accepted_logprobs.len(), 3);
|
|
}
|
|
|
|
#[test]
|
|
fn test_softmax() {
|
|
let logits = vec![1.0, 2.0, 3.0];
|
|
let probs = softmax(&logits);
|
|
|
|
// Probabilities should sum to 1
|
|
let sum: f32 = probs.iter().sum();
|
|
assert!((sum - 1.0).abs() < 0.001);
|
|
|
|
// Probabilities should be ordered
|
|
assert!(probs[2] > probs[1]);
|
|
assert!(probs[1] > probs[0]);
|
|
|
|
// All probabilities should be positive
|
|
assert!(probs.iter().all(|&p| p > 0.0));
|
|
}
|
|
|
|
#[test]
|
|
fn test_softmax_with_large_values() {
|
|
// Test numerical stability with large values
|
|
let logits = vec![100.0, 200.0, 300.0];
|
|
let probs = softmax(&logits);
|
|
|
|
let sum: f32 = probs.iter().sum();
|
|
assert!((sum - 1.0).abs() < 0.001);
|
|
}
|
|
|
|
#[test]
|
|
fn test_softmax_with_negative_values() {
|
|
let logits = vec![-1.0, -2.0, -3.0];
|
|
let probs = softmax(&logits);
|
|
|
|
let sum: f32 = probs.iter().sum();
|
|
assert!((sum - 1.0).abs() < 0.001);
|
|
}
|
|
|
|
#[test]
|
|
fn test_log_softmax() {
|
|
let logits = vec![1.0, 2.0, 3.0];
|
|
let log_probs = log_softmax(&logits);
|
|
|
|
// All log probabilities should be negative (probabilities < 1)
|
|
assert!(log_probs.iter().all(|&lp| lp <= 0.0));
|
|
|
|
// exp(log_softmax) should equal softmax
|
|
let probs: Vec<f32> = log_probs.iter().map(|&lp: &f32| lp.exp()).collect();
|
|
let sum: f32 = probs.iter().sum();
|
|
assert!((sum - 1.0).abs() < 0.001);
|
|
}
|
|
|
|
#[test]
|
|
fn test_top_k_filter() {
|
|
let mut logits: Vec<f32> = vec![1.0, 5.0, 3.0, 4.0, 2.0];
|
|
top_k_filter(&mut logits, 2);
|
|
|
|
// Only top 2 (5.0 and 4.0) should remain finite
|
|
let finite_count = logits.iter().filter(|&&x| x.is_finite()).count();
|
|
assert_eq!(finite_count, 2);
|
|
|
|
// The top values should be unchanged
|
|
assert!((logits[1] - 5.0).abs() < 0.01);
|
|
assert!((logits[3] - 4.0).abs() < 0.01);
|
|
}
|
|
|
|
#[test]
|
|
fn test_top_k_filter_k_equals_len() {
|
|
let mut logits: Vec<f32> = vec![1.0, 2.0, 3.0];
|
|
top_k_filter(&mut logits, 3);
|
|
|
|
// All values should remain
|
|
let finite_count = logits.iter().filter(|&&x| x.is_finite()).count();
|
|
assert_eq!(finite_count, 3);
|
|
}
|
|
|
|
#[test]
|
|
fn test_top_k_filter_k_zero() {
|
|
let mut logits = vec![1.0, 2.0, 3.0];
|
|
let original = logits.clone();
|
|
top_k_filter(&mut logits, 0);
|
|
|
|
// No filtering when k=0
|
|
assert_eq!(logits, original);
|
|
}
|
|
|
|
#[test]
|
|
fn test_top_p_filter() {
|
|
let mut logits: Vec<f32> = vec![10.0, 5.0, 2.0, 1.0, 0.5];
|
|
top_p_filter(&mut logits, 0.9);
|
|
|
|
// Most probability mass should be preserved
|
|
let finite_count = logits.iter().filter(|&&x| x.is_finite()).count();
|
|
assert!(finite_count >= 1 && finite_count <= 4);
|
|
}
|
|
|
|
#[test]
|
|
fn test_top_p_filter_p_one() {
|
|
let mut logits = vec![1.0, 2.0, 3.0];
|
|
let original = logits.clone();
|
|
top_p_filter(&mut logits, 1.0);
|
|
|
|
// No filtering when p=1.0
|
|
assert_eq!(logits, original);
|
|
}
|
|
|
|
#[test]
|
|
fn test_top_p_filter_very_low_p() {
|
|
let mut logits: Vec<f32> = vec![10.0, 1.0, 0.5, 0.1];
|
|
top_p_filter(&mut logits, 0.01);
|
|
|
|
// Only the highest probability token should remain
|
|
let finite_count = logits.iter().filter(|&&x| x.is_finite()).count();
|
|
assert!(finite_count >= 1);
|
|
|
|
// The top value should be finite
|
|
assert!(logits[0].is_finite());
|
|
}
|
|
|
|
#[test]
|
|
fn test_config_serialization() {
|
|
let config = SpeculativeConfig {
|
|
lookahead: 6,
|
|
acceptance_threshold: 0.7,
|
|
draft_temperature: 0.1,
|
|
tree_speculation: true,
|
|
max_tree_depth: 4,
|
|
tree_branching_factor: 3,
|
|
draft_top_p: 0.9,
|
|
min_acceptance_ratio: 0.2,
|
|
adaptive_lookahead: true,
|
|
min_lookahead: 3,
|
|
max_lookahead: 10,
|
|
};
|
|
|
|
// Test JSON serialization
|
|
let json = serde_json::to_string(&config).unwrap();
|
|
let deserialized: SpeculativeConfig = serde_json::from_str(&json).unwrap();
|
|
|
|
assert_eq!(deserialized.lookahead, 6);
|
|
assert!(deserialized.tree_speculation);
|
|
}
|
|
|
|
#[test]
|
|
fn test_stats_serialization() {
|
|
let mut stats = SpeculativeStats::new();
|
|
stats.record_round(4, 3, 10.0);
|
|
|
|
let json = serde_json::to_string(&stats).unwrap();
|
|
let deserialized: SpeculativeStats = serde_json::from_str(&json).unwrap();
|
|
|
|
assert_eq!(deserialized.draft_tokens, 4);
|
|
assert_eq!(deserialized.accepted_tokens, 3);
|
|
}
|
|
|
|
#[test]
|
|
fn test_realistic_speculation_scenario() {
|
|
let mut stats = SpeculativeStats::new();
|
|
|
|
// Simulate 100 generation rounds with varying acceptance
|
|
for i in 0..100 {
|
|
let draft_count = 4;
|
|
// Acceptance varies: high at start, lower later (simulating diverse output)
|
|
let accepted = if i < 30 {
|
|
4 // 100% acceptance
|
|
} else if i < 60 {
|
|
3 // 75% acceptance
|
|
} else {
|
|
2 // 50% acceptance
|
|
};
|
|
|
|
stats.record_round(draft_count, accepted, (i as f64) * 0.1);
|
|
}
|
|
|
|
// Verify stats are reasonable
|
|
assert_eq!(stats.draft_tokens, 400);
|
|
assert!(stats.acceptance_rate > 0.5 && stats.acceptance_rate < 1.0);
|
|
assert!(stats.speedup > 1.0); // Should show speedup
|
|
assert_eq!(stats.main_forward_passes, 100);
|
|
}
|
|
|
|
#[test]
|
|
fn test_tree_with_deep_nesting() {
|
|
let mut tree = SpeculationTree::new(5, 2);
|
|
|
|
// Build a deep tree
|
|
fn build_recursive(node: &mut TreeNode, depth: usize, max_depth: usize) {
|
|
if depth >= max_depth {
|
|
return;
|
|
}
|
|
|
|
let child = node.add_child((depth * 10) as u32, 1.0 / (depth + 1) as f32);
|
|
build_recursive(child, depth + 1, max_depth);
|
|
}
|
|
|
|
build_recursive(&mut tree.root, 0, 5);
|
|
|
|
let best = tree.best_path();
|
|
assert_eq!(best.len(), 5);
|
|
}
|