Files
wifi-densepose/crates/ruvllm/examples/generate_claude_dataset.rs
ruv d803bfe2b1 Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector
git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
2026-02-28 14:39:40 -05:00

238 lines
7.9 KiB
Rust
Raw Blame History

This file contains invisible Unicode characters
This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#![allow(
clippy::all,
unused_imports,
unused_variables,
dead_code,
unused_mut,
unused_assignments,
non_camel_case_types,
clippy::approx_constant,
unexpected_cfgs,
unused_must_use,
unused_parens
)]
//! # Claude Task Dataset Generation Example
//!
//! This example demonstrates how to generate a comprehensive fine-tuning dataset
//! for RuvLTRA models trained on Claude Flow agent tasks.
//!
//! ## Usage
//!
//! ```bash
//! cargo run --example generate_claude_dataset --release
//! ```
//!
//! This will generate:
//! - `claude_training_full.jsonl` - Full dataset in JSONL format
//! - `claude_training_full.json` - Full dataset in JSON format
//! - `claude_training_train.jsonl` - Training split (70%)
//! - `claude_training_val.jsonl` - Validation split (15%)
//! - `claude_training_test.jsonl` - Test split (15%)
//! - `claude_training_stats.json` - Dataset statistics
use ruvllm::training::{
AugmentationConfig, ClaudeTaskDataset, DatasetConfig, DatasetGenerator, TaskCategory,
};
use std::error::Error;
/// Entry point of the example.
///
/// Builds a `DatasetConfig`, generates the dataset, prints a summary, and
/// writes the full dataset, the train/validation/test splits and the
/// statistics file to the current directory. Any export error is propagated
/// to the caller through `?`.
fn main() -> Result<(), Box<dyn Error>> {
    println!("🚀 Claude Task Dataset Generator");
    println!("═══════════════════════════════════════════════════\n");

    // Assemble the generation settings; the augmentation part is built first
    // and then embedded in the top-level configuration.
    let augmentation = AugmentationConfig {
        paraphrases_per_example: 2,
        complexity_variations: 2,
        enable_domain_transfer: true,
    };
    let config = DatasetConfig {
        examples_per_category: 100,
        enable_augmentation: true,
        augmentation,
        seed: 42,
    };

    // Echo the configuration before it is handed over to the generator.
    let aug_settings = &config.augmentation;
    println!("📋 Configuration:");
    println!(
        " • Examples per category: {}",
        config.examples_per_category
    );
    println!(" • Augmentation enabled: {}", config.enable_augmentation);
    println!(
        " • Paraphrases per example: {}",
        aug_settings.paraphrases_per_example
    );
    println!(
        " • Complexity variations: {}",
        aug_settings.complexity_variations
    );
    println!(
        " • Domain transfer: {}\n",
        aug_settings.enable_domain_transfer
    );

    // Run the generator.
    println!("⚙️ Generating dataset...");
    let mut generator = DatasetGenerator::new(config);
    let dataset = generator.generate();
    println!("✅ Dataset generated!\n");

    print_statistics(&dataset);

    // Write the complete dataset in both supported formats.
    println!("\n💾 Exporting datasets...");
    let full_count = dataset.examples.len();
    dataset.export_jsonl("claude_training_full.jsonl")?;
    println!(
        " ✓ Full dataset: claude_training_full.jsonl ({} examples)",
        full_count
    );
    dataset.export_json("claude_training_full.json")?;
    println!(" ✓ Full dataset JSON: claude_training_full.json");

    // Split 70/15/15 and write each part to its own JSONL file, in order.
    let (train_examples, val_examples, test_examples) = dataset.split(0.7, 0.15, 0.15, 42);

    let train_set = ClaudeTaskDataset::new(train_examples);
    train_set.export_jsonl("claude_training_train.jsonl")?;
    let train_count = train_set.examples.len();
    println!(
        " ✓ Training set: claude_training_train.jsonl ({} examples)",
        train_count
    );

    let val_set = ClaudeTaskDataset::new(val_examples);
    val_set.export_jsonl("claude_training_val.jsonl")?;
    let val_count = val_set.examples.len();
    println!(
        " ✓ Validation set: claude_training_val.jsonl ({} examples)",
        val_count
    );

    let test_set = ClaudeTaskDataset::new(test_examples);
    test_set.export_jsonl("claude_training_test.jsonl")?;
    let test_count = test_set.examples.len();
    println!(
        " ✓ Test set: claude_training_test.jsonl ({} examples)",
        test_count
    );

    // Statistics file for the full dataset.
    dataset.export_stats("claude_training_stats.json")?;
    println!(" ✓ Statistics: claude_training_stats.json\n");

    // Human-readable reports on the full dataset.
    print_sample_examples(&dataset);
    print_model_routing_analysis(&dataset);

    println!("\n✨ Dataset generation complete!");
    println!(" Total examples: {}", dataset.examples.len());
    println!(" Ready for fine-tuning RuvLTRA models\n");

    Ok(())
}
/// Prints the aggregate figures held in `dataset.stats`.
///
/// Shows the total example count and average quality score, followed by three
/// breakdowns (by category, by complexity, by domain). Every breakdown row
/// lists the count and its share of `total_examples` as a percentage.
fn print_statistics(dataset: &ClaudeTaskDataset) {
    let stats = &dataset.stats;
    // Common denominator for every percentage printed below.
    let total = stats.total_examples as f32;

    println!("📊 Dataset Statistics:");
    println!(" ═══════════════════════════════════════════════════");
    println!(" Total examples: {}", stats.total_examples);
    println!(
        " Average quality score: {:.2}",
        stats.avg_quality_score
    );

    println!("\n 📂 Examples by Category:");
    for category in TaskCategory::all() {
        let label = category.name();
        // Categories missing from the map are reported with a count of zero.
        let count = stats.examples_per_category.get(label).unwrap_or(&0);
        let share = (*count as f32 / total) * 100.0;
        println!("{:12} {:4} ({:5.1}%)", label, count, share);
    }

    println!("\n 📈 Examples by Complexity:");
    for (complexity, count) in &stats.examples_per_complexity {
        let share = (*count as f32 / total) * 100.0;
        println!("{:12} {:4} ({:5.1}%)", complexity, count, share);
    }

    println!("\n 🏷️ Examples by Domain:");
    for (domain, count) in &stats.examples_per_domain {
        let share = (*count as f32 / total) * 100.0;
        println!("{:12} {:4} ({:5.1}%)", domain, count, share);
    }
}
/// Prints one sample example for each task category.
///
/// For every category returned by `TaskCategory::all()`, the first example in
/// `dataset.examples` whose metadata carries that category is shown with its
/// expected model, complexity, domain, quality score, and the input/context
/// text shortened via `truncate(.., 80)`. Categories with no matching example
/// are skipped silently.
fn print_sample_examples(dataset: &ClaudeTaskDataset) {
    println!("📝 Sample Examples:");
    println!(" ═══════════════════════════════════════════════════");

    for category in TaskCategory::all() {
        let first_match = dataset
            .examples
            .iter()
            .find(|e| e.metadata.category == category);

        let example = match first_match {
            Some(found) => found,
            // Nothing to show for this category.
            None => continue,
        };
        let meta = &example.metadata;

        println!(
            "\n 🔹 {} ({})",
            category.name(),
            meta.expected_model
        );
        println!(
            " Complexity: {:?}, Domain: {:?}",
            meta.complexity, meta.domain
        );
        println!(" Input: {}", truncate(&example.input, 80));
        println!(" Context: {}", truncate(&example.context, 80));
        println!(" Quality: {:.2}", meta.quality_score);
    }
}
/// Prints how the examples are distributed over their expected models.
///
/// Counts the examples per `metadata.expected_model`, then prints one row per
/// model with its count, its share of `dataset.stats.total_examples` as a
/// percentage, and a cost hint for the known model names (`haiku`, `sonnet`,
/// `opus`; any other name gets no hint). A fixed selection guide follows.
///
/// Rows are printed in ascending order of model name so that the report is
/// identical from run to run.
fn print_model_routing_analysis(dataset: &ClaudeTaskDataset) {
    println!("\n🎯 Model Routing Analysis:");
    println!(" ═══════════════════════════════════════════════════");

    let mut model_counts = std::collections::HashMap::new();
    for example in &dataset.examples {
        *model_counts
            .entry(&example.metadata.expected_model)
            .or_insert(0) += 1;
    }

    // `HashMap` iteration order is arbitrary and differs between runs, so the
    // rows are sorted by model name before printing.
    let mut rows: Vec<_> = model_counts.iter().collect();
    rows.sort_by(|(left, _), (right, _)| left.as_str().cmp(right.as_str()));

    let total = dataset.stats.total_examples as f32;
    for (model, count) in rows {
        let percentage = (*count as f32 / total) * 100.0;
        let cost_indicator = match model.as_str() {
            "haiku" => "💰 (cheapest)",
            "sonnet" => "💰💰 (balanced)",
            "opus" => "💰💰💰 (most capable)",
            _ => "",
        };
        println!(
            "{:8} {:4} ({:5.1}%) {}",
            model, count, percentage, cost_indicator
        );
    }

    println!("\n Model Selection Guide:");
    println!(" • Haiku: Simple tasks, fast responses, low cost");
    println!(" • Sonnet: Balanced complexity, moderate cost");
    println!(" • Opus: Complex reasoning, highest quality");
}
/// Shortens `s` to at most `max_len` bytes for display.
///
/// * Strings that already fit (`s.len() <= max_len`) are returned unchanged.
/// * Longer strings are cut and suffixed with `...`, the suffix counting
///   towards `max_len`.
/// * If `max_len` is too small to hold the ellipsis, the string is cut at
///   `max_len` with no suffix.
///
/// The cut point is moved back to the nearest UTF-8 character boundary, so
/// text containing multi-byte characters (accents, emoji) never triggers a
/// slicing panic; in that case the result may be a few bytes shorter than
/// `max_len`.
fn truncate(s: &str, max_len: usize) -> String {
    const ELLIPSIS: &str = "...";

    if s.len() <= max_len {
        return s.to_string();
    }

    // `checked_sub` avoids the underflow that `max_len - 3` hits for tiny limits.
    let (budget, suffix) = match max_len.checked_sub(ELLIPSIS.len()) {
        Some(room) => (room, ELLIPSIS),
        None => (max_len, ""),
    };

    // `budget < s.len()` at this point and index 0 is always a boundary, so
    // the walk backwards is in range and terminates.
    let mut end = budget;
    while !s.is_char_boundary(end) {
        end -= 1;
    }

    format!("{}{}", &s[..end], suffix)
}