// Provenance (from the repository browser listing):
//   wifi-densepose/examples/ruvLLM/task_specific_adapters.rs
//   Squashed 'vendor/ruvector/' content from commit b64c2172
//   git-subtree-dir: vendor/ruvector
//   git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900

//! Task-Specific LoRA Adapters Example
//!
//! This example demonstrates:
//! 1. Using pre-defined adapters for different agent types
//! 2. Training adapters from synthetic datasets
//! 3. Merging multiple adapters
//! 4. Hot-swapping adapters at runtime
//!
//! Run with:
//! ```bash
//! cargo run --example task_specific_adapters --features ruvllm
//! ```
use ruvllm::lora::{
RuvLtraAdapters, AdapterTrainer, AdapterTrainingConfig, SyntheticDataGenerator,
AdapterMerger, MergeConfig, MergeStrategy, HotSwapManager, AdaptFeedback,
};
use std::collections::HashMap;
fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("🚀 Task-Specific LoRA Adapters Demo\n");
// 1. Explore available adapters
println!("📋 Available Adapters:");
println!("═══════════════════════\n");
let adapters = RuvLtraAdapters::new();
for name in adapters.list_names() {
if let Some(config) = adapters.get(&name) {
println!(" 🔧 {}", name);
println!(" Description: {}", config.description);
println!(" Rank: {}, Alpha: {}", config.rank, config.alpha);
println!(" Target modules: {} modules", config.target_modules.len());
println!(" Memory (768d): {:.2} KB", config.estimate_memory(768) as f32 / 1024.0);
println!(" Tags: {}", config.domain_tags.join(", "));
println!();
}
}
// 2. Create and train adapters
println!("\n🎓 Training Adapters");
println!("═══════════════════════\n");
let hidden_dim = 768;
let generator = SyntheticDataGenerator::new(hidden_dim, 42);
// Train coder adapter
println!(" Training 'coder' adapter...");
let coder_dataset = generator.generate("coder", 1000);
println!(" Dataset: {} train, {} val examples",
coder_dataset.examples.len(),
coder_dataset.validation.len());
let coder_lora = adapters.create_lora("coder", hidden_dim)?;
let mut coder_trainer = AdapterTrainer::new(AdapterTrainingConfig::quick());
let coder_result = coder_trainer.train(&coder_lora, &coder_dataset)?;
println!(" ✓ Completed {} epochs in {} steps",
coder_result.epochs_completed,
coder_result.total_steps);
println!(" Final loss: {:.4}", coder_result.final_loss);
// Train security adapter
println!("\n Training 'security' adapter...");
let security_dataset = generator.generate("security", 1000);
let security_lora = adapters.create_lora("security", hidden_dim)?;
let mut security_trainer = AdapterTrainer::new(AdapterTrainingConfig::quick());
let security_result = security_trainer.train(&security_lora, &security_dataset)?;
println!(" ✓ Completed {} epochs in {} steps",
security_result.epochs_completed,
security_result.total_steps);
// 3. Use adapters for inference
println!("\n\n🔮 Adapter Inference");
println!("═══════════════════════\n");
let test_input = vec![0.5; hidden_dim];
println!(" Coder adapter output:");
let coder_output = coder_lora.forward(&test_input, &ruvllm::lora::TargetModule::QProj);
println!(" Output dim: {}", coder_output.len());
println!(" Mean activation: {:.4}", coder_output.iter().sum::<f32>() / coder_output.len() as f32);
println!("\n Security adapter output:");
let security_output = security_lora.forward(&test_input, &ruvllm::lora::TargetModule::QProj);
println!(" Output dim: {}", security_output.len());
println!(" Mean activation: {:.4}", security_output.iter().sum::<f32>() / security_output.len() as f32);
// 4. Merge adapters
println!("\n\n🔀 Adapter Merging");
println!("═══════════════════════\n");
// Average merge
println!(" Average merge (coder + security):");
let merge_config = MergeConfig::average();
let merger = AdapterMerger::new(merge_config);
let adapters_to_merge = vec![
("coder".to_string(), coder_lora.clone()),
("security".to_string(), security_lora.clone()),
];
let merged = merger.merge(&adapters_to_merge, &adapters.coder, hidden_dim)?;
let merged_output = merged.forward(&test_input, &ruvllm::lora::TargetModule::QProj);
println!(" Mean activation: {:.4}", merged_output.iter().sum::<f32>() / merged_output.len() as f32);
// Weighted merge
println!("\n Weighted merge (70% coder, 30% security):");
let mut weights = HashMap::new();
weights.insert("coder".to_string(), 0.7);
weights.insert("security".to_string(), 0.3);
let weighted_config = MergeConfig::weighted(weights);
let weighted_merger = AdapterMerger::new(weighted_config);
let weighted_merged = weighted_merger.merge(&adapters_to_merge, &adapters.coder, hidden_dim)?;
let weighted_output = weighted_merged.forward(&test_input, &ruvllm::lora::TargetModule::QProj);
println!(" Mean activation: {:.4}", weighted_output.iter().sum::<f32>() / weighted_output.len() as f32);
// SLERP interpolation
println!("\n SLERP interpolation (t=0.5):");
let slerp_config = MergeConfig::slerp(0.5);
let slerp_merger = AdapterMerger::new(slerp_config);
let slerp_merged = slerp_merger.merge(&adapters_to_merge, &adapters.coder, hidden_dim)?;
let slerp_output = slerp_merged.forward(&test_input, &ruvllm::lora::TargetModule::QProj);
println!(" Mean activation: {:.4}", slerp_output.iter().sum::<f32>() / slerp_output.len() as f32);
// 5. Hot-swapping demonstration
println!("\n\n🔄 Hot-Swap Demo");
println!("═══════════════════════\n");
let mut swap_manager = HotSwapManager::new();
println!(" Setting coder as active adapter...");
swap_manager.set_active(coder_lora.clone());
if let Some(active) = swap_manager.active() {
let output = active.forward(&test_input, &ruvllm::lora::TargetModule::QProj);
println!(" Active adapter mean: {:.4}", output.iter().sum::<f32>() / output.len() as f32);
}
println!("\n Preparing security adapter in standby...");
swap_manager.prepare_standby(security_lora.clone());
println!(" Performing hot-swap...");
swap_manager.swap()?;
if let Some(active) = swap_manager.active() {
let output = active.forward(&test_input, &ruvllm::lora::TargetModule::QProj);
println!(" New active adapter mean: {:.4}", output.iter().sum::<f32>() / output.len() as f32);
}
// 6. Adapter composition (multi-task)
println!("\n\n🧩 Multi-Task Composition");
println!("═══════════════════════\n");
println!(" Creating researcher adapter...");
let researcher_dataset = generator.generate("researcher", 1000);
let researcher_lora = adapters.create_lora("researcher", hidden_dim)?;
let mut researcher_trainer = AdapterTrainer::new(AdapterTrainingConfig::quick());
researcher_trainer.train(&researcher_lora, &researcher_dataset)?;
println!("\n TIES merge (coder + security + researcher):");
let ties_adapters = vec![
("coder".to_string(), coder_lora.clone()),
("security".to_string(), security_lora.clone()),
("researcher".to_string(), researcher_lora.clone()),
];
let ties_config = MergeConfig::ties(0.6);
let ties_merger = AdapterMerger::new(ties_config);
let ties_merged = ties_merger.merge(&ties_adapters, &adapters.coder, hidden_dim)?;
let ties_output = ties_merged.forward(&test_input, &ruvllm::lora::TargetModule::QProj);
println!(" Mean activation: {:.4}", ties_output.iter().sum::<f32>() / ties_output.len() as f32);
// 7. Per-request adaptation
println!("\n\n⚡ Per-Request Adaptation");
println!("═══════════════════════\n");
println!(" Baseline output:");
let baseline = coder_lora.forward(&test_input, &ruvllm::lora::TargetModule::QProj);
println!(" Mean: {:.4}", baseline.iter().sum::<f32>() / baseline.len() as f32);
println!("\n Adapting with high-quality feedback...");
let feedback = AdaptFeedback::from_quality(0.95);
coder_lora.adapt(&test_input, feedback)?;
coder_lora.apply_updates(0.01);
let adapted = coder_lora.forward(&test_input, &ruvllm::lora::TargetModule::QProj);
println!(" Mean after adaptation: {:.4}", adapted.iter().sum::<f32>() / adapted.len() as f32);
println!(" Change: {:.4}",
(adapted.iter().sum::<f32>() - baseline.iter().sum::<f32>()) / baseline.len() as f32);
// 8. Save and load adapters
println!("\n\n💾 Persistence");
println!("═══════════════════════\n");
let save_path = "/tmp/coder_adapter.bin";
println!(" Saving coder adapter to {}...", save_path);
coder_lora.save(save_path)?;
println!(" ✓ Saved");
println!("\n Loading adapter...");
let loaded_lora = ruvllm::lora::MicroLoRA::load(save_path)?;
println!(" ✓ Loaded");
println!(" Params: {}", loaded_lora.param_count());
println!(" Memory: {:.2} KB", loaded_lora.memory_bytes() as f32 / 1024.0);
// 9. Performance summary
println!("\n\n📊 Performance Summary");
println!("═══════════════════════\n");
println!(" Coder Adapter:");
println!(" Rank: {}", adapters.coder.rank);
println!(" Parameters: {}", coder_lora.param_count());
println!(" Memory: {:.2} KB", coder_lora.memory_bytes() as f32 / 1024.0);
println!(" Forward passes: {}", coder_lora.forward_count());
println!(" Adaptations: {}", coder_lora.adaptation_count());
println!("\n Security Adapter:");
println!(" Rank: {}", adapters.security.rank);
println!(" Parameters: {}", security_lora.param_count());
println!(" Memory: {:.2} KB", security_lora.memory_bytes() as f32 / 1024.0);
println!("\n✨ Demo Complete!\n");
Ok(())
}