Files
wifi-densepose/vendor/ruvector/crates/ruvllm/src/training/tests.rs

415 lines
12 KiB
Rust

//! Comprehensive tests for Claude task dataset generation
#[cfg(test)]
mod tests {
use super::super::*;
#[test]
fn test_basic_dataset_generation() {
let config = DatasetConfig {
examples_per_category: 5,
enable_augmentation: false,
..Default::default()
};
let mut generator = DatasetGenerator::new(config);
let dataset = generator.generate();
// 5 categories * 5 examples = 25 total
assert_eq!(dataset.examples.len(), 25);
assert_eq!(dataset.stats.total_examples, 25);
}
#[test]
fn test_category_distribution() {
let config = DatasetConfig {
examples_per_category: 10,
enable_augmentation: false,
..Default::default()
};
let mut generator = DatasetGenerator::new(config);
let dataset = generator.generate();
// Check each category has exactly 10 examples
for category in TaskCategory::all() {
let count = dataset
.stats
.examples_per_category
.get(category.name())
.unwrap_or(&0);
assert_eq!(
*count,
10,
"Category {} should have 10 examples",
category.name()
);
}
}
#[test]
fn test_augmentation_increases_dataset() {
let config_no_aug = DatasetConfig {
examples_per_category: 5,
enable_augmentation: false,
..Default::default()
};
let config_with_aug = DatasetConfig {
examples_per_category: 5,
enable_augmentation: true,
augmentation: AugmentationConfig {
paraphrases_per_example: 1,
complexity_variations: 1,
enable_domain_transfer: true,
},
..Default::default()
};
let mut gen_no_aug = DatasetGenerator::new(config_no_aug);
let dataset_no_aug = gen_no_aug.generate();
let mut gen_with_aug = DatasetGenerator::new(config_with_aug);
let dataset_with_aug = gen_with_aug.generate();
// Augmented dataset should be larger
assert!(
dataset_with_aug.examples.len() > dataset_no_aug.examples.len(),
"Augmented dataset should be larger: {} vs {}",
dataset_with_aug.examples.len(),
dataset_no_aug.examples.len()
);
}
#[test]
fn test_model_recommendation_logic() {
// Coder category
assert_eq!(
TaskCategory::Coder.recommended_model(ComplexityLevel::Simple),
"haiku"
);
assert_eq!(
TaskCategory::Coder.recommended_model(ComplexityLevel::Moderate),
"sonnet"
);
assert_eq!(
TaskCategory::Coder.recommended_model(ComplexityLevel::Complex),
"opus"
);
// Security category (always opus)
assert_eq!(
TaskCategory::Security.recommended_model(ComplexityLevel::Simple),
"opus"
);
assert_eq!(
TaskCategory::Security.recommended_model(ComplexityLevel::Moderate),
"opus"
);
assert_eq!(
TaskCategory::Security.recommended_model(ComplexityLevel::Complex),
"opus"
);
// Architecture category
assert_eq!(
TaskCategory::Architecture.recommended_model(ComplexityLevel::Simple),
"sonnet"
);
assert_eq!(
TaskCategory::Architecture.recommended_model(ComplexityLevel::Moderate),
"opus"
);
}
#[test]
fn test_quality_scores_in_range() {
let config = DatasetConfig {
examples_per_category: 20,
enable_augmentation: true,
..Default::default()
};
let mut generator = DatasetGenerator::new(config);
let dataset = generator.generate();
for example in &dataset.examples {
assert!(
example.metadata.quality_score >= 0.0 && example.metadata.quality_score <= 1.0,
"Quality score must be in [0, 1]: {}",
example.metadata.quality_score
);
}
// Average quality should be reasonable
assert!(
dataset.stats.avg_quality_score >= 0.7 && dataset.stats.avg_quality_score <= 1.0,
"Average quality should be good: {}",
dataset.stats.avg_quality_score
);
}
#[test]
fn test_dataset_split_ratios() {
let config = DatasetConfig {
examples_per_category: 20,
enable_augmentation: false,
..Default::default()
};
let mut generator = DatasetGenerator::new(config);
let dataset = generator.generate();
let (train, val, test) = dataset.split(0.7, 0.15, 0.15, 42);
let total = train.len() + val.len() + test.len();
assert_eq!(total, dataset.examples.len());
// Check approximate ratios (allow small rounding errors)
let train_ratio = train.len() as f32 / total as f32;
let val_ratio = val.len() as f32 / total as f32;
let test_ratio = test.len() as f32 / total as f32;
assert!(
(train_ratio - 0.7).abs() < 0.05,
"Train ratio should be ~0.7: {}",
train_ratio
);
assert!(
(val_ratio - 0.15).abs() < 0.05,
"Val ratio should be ~0.15: {}",
val_ratio
);
assert!(
(test_ratio - 0.15).abs() < 0.05,
"Test ratio should be ~0.15: {}",
test_ratio
);
}
#[test]
fn test_dataset_split_deterministic() {
let config = DatasetConfig {
examples_per_category: 10,
enable_augmentation: false,
seed: 42,
..Default::default()
};
let mut gen1 = DatasetGenerator::new(config.clone());
let dataset1 = gen1.generate();
let (train1, _, _) = dataset1.split(0.7, 0.15, 0.15, 42);
let mut gen2 = DatasetGenerator::new(config);
let dataset2 = gen2.generate();
let (train2, _, _) = dataset2.split(0.7, 0.15, 0.15, 42);
// Same seed should produce same split
assert_eq!(train1.len(), train2.len());
for (ex1, ex2) in train1.iter().zip(train2.iter()) {
assert_eq!(ex1.input, ex2.input);
}
}
#[test]
fn test_all_categories_present() {
let config = DatasetConfig {
examples_per_category: 10,
enable_augmentation: false,
..Default::default()
};
let mut generator = DatasetGenerator::new(config);
let dataset = generator.generate();
let mut categories_seen = std::collections::HashSet::new();
for example in &dataset.examples {
categories_seen.insert(example.metadata.category);
}
// Should see all 5 categories
assert_eq!(categories_seen.len(), 5);
assert!(categories_seen.contains(&TaskCategory::Coder));
assert!(categories_seen.contains(&TaskCategory::Researcher));
assert!(categories_seen.contains(&TaskCategory::Security));
assert!(categories_seen.contains(&TaskCategory::Architecture));
assert!(categories_seen.contains(&TaskCategory::Reviewer));
}
#[test]
fn test_complexity_levels_present() {
let config = DatasetConfig {
examples_per_category: 20,
enable_augmentation: true,
augmentation: AugmentationConfig {
paraphrases_per_example: 0,
complexity_variations: 2,
enable_domain_transfer: false,
},
..Default::default()
};
let mut generator = DatasetGenerator::new(config);
let dataset = generator.generate();
let mut complexities_seen = std::collections::HashSet::new();
for example in &dataset.examples {
complexities_seen.insert(example.metadata.complexity);
}
// Should see all 3 complexity levels due to variations
assert!(complexities_seen.contains(&ComplexityLevel::Simple));
assert!(complexities_seen.contains(&ComplexityLevel::Moderate));
assert!(complexities_seen.contains(&ComplexityLevel::Complex));
}
#[test]
fn test_domain_diversity() {
let config = DatasetConfig {
examples_per_category: 30,
enable_augmentation: false,
..Default::default()
};
let mut generator = DatasetGenerator::new(config);
let dataset = generator.generate();
let mut domains_seen = std::collections::HashSet::new();
for example in &dataset.examples {
domains_seen.insert(example.metadata.domain);
}
// Should see multiple domains
assert!(
domains_seen.len() >= 3,
"Should have at least 3 different domains"
);
}
#[test]
fn test_tags_not_empty() {
let config = DatasetConfig {
examples_per_category: 10,
enable_augmentation: false,
..Default::default()
};
let mut generator = DatasetGenerator::new(config);
let dataset = generator.generate();
for example in &dataset.examples {
assert!(
!example.metadata.tags.is_empty(),
"Examples should have tags"
);
}
}
#[test]
fn test_output_agent_matches_category() {
let config = DatasetConfig {
examples_per_category: 10,
enable_augmentation: false,
..Default::default()
};
let mut generator = DatasetGenerator::new(config);
let dataset = generator.generate();
for example in &dataset.examples {
assert_eq!(
example.output_agent,
example.metadata.category.name(),
"Output agent should match category"
);
}
}
#[test]
fn test_expected_model_is_valid() {
let config = DatasetConfig {
examples_per_category: 10,
enable_augmentation: false,
..Default::default()
};
let mut generator = DatasetGenerator::new(config);
let dataset = generator.generate();
for example in &dataset.examples {
let model = &example.metadata.expected_model;
assert!(
model == "haiku" || model == "sonnet" || model == "opus",
"Expected model should be haiku, sonnet, or opus: {}",
model
);
}
}
#[test]
fn test_reproducibility_with_seed() {
let config1 = DatasetConfig {
examples_per_category: 10,
enable_augmentation: false,
seed: 12345,
..Default::default()
};
let config2 = DatasetConfig {
examples_per_category: 10,
enable_augmentation: false,
seed: 12345,
..Default::default()
};
let mut gen1 = DatasetGenerator::new(config1);
let dataset1 = gen1.generate();
let mut gen2 = DatasetGenerator::new(config2);
let dataset2 = gen2.generate();
// Same seed should produce same examples
assert_eq!(dataset1.examples.len(), dataset2.examples.len());
for (ex1, ex2) in dataset1.examples.iter().zip(dataset2.examples.iter()) {
assert_eq!(ex1.input, ex2.input);
assert_eq!(ex1.output_agent, ex2.output_agent);
}
}
#[test]
fn test_different_seeds_produce_different_data() {
let config1 = DatasetConfig {
examples_per_category: 10,
enable_augmentation: false,
seed: 111,
..Default::default()
};
let config2 = DatasetConfig {
examples_per_category: 10,
enable_augmentation: false,
seed: 222,
..Default::default()
};
let mut gen1 = DatasetGenerator::new(config1);
let dataset1 = gen1.generate();
let mut gen2 = DatasetGenerator::new(config2);
let dataset2 = gen2.generate();
// Different seeds should produce different examples
let mut different_count = 0;
for (ex1, ex2) in dataset1.examples.iter().zip(dataset2.examples.iter()) {
if ex1.input != ex2.input {
different_count += 1;
}
}
assert!(
different_count > 0,
"Different seeds should produce at least some different examples"
);
}
}