// wifi-densepose/vendor/ruvector/examples/refrag-pipeline/src/main.rs
//! REFRAG Pipeline Demo
//!
//! This example demonstrates the full REFRAG (Compress-Sense-Expand) pipeline
//! for ~30x latency reduction in RAG systems.
//!
//! Run with: cargo run --bin refrag-demo
use refrag_pipeline_example::{
compress::CompressionStrategy,
expand::ExpandLayer,
sense::PolicyNetwork,
store::RefragStoreBuilder,
types::{RefragEntry, RefragResponseType},
};
use rand::Rng;
use std::time::Instant;
/// REFRAG demo entry point.
///
/// Runs three stages:
/// 1. Throughput benchmark: inserts synthetic documents and runs hybrid
///    searches across several compression-policy thresholds, reporting
///    insert rate, query latency/QPS, and the COMPRESS vs EXPAND split.
/// 2. Qualitative demo: inserts a few real sentences and prints the top-3
///    hybrid results with the policy's per-result decision.
/// 3. Latency comparison: prints estimated payload sizes for text vs tensor
///    responses.
///
/// # Errors
/// Propagates any failure from store construction, insertion, or search
/// via `anyhow::Result`.
fn main() -> anyhow::Result<()> {
    // Initialize logging (env-filter style: debug for refrag, info elsewhere).
    tracing_subscriber::fmt()
        .with_env_filter("refrag=debug,info")
        .init();

    println!("=================================================");
    println!(" REFRAG Pipeline Demo - Compress-Sense-Expand ");
    println!("=================================================\n");

    // Configuration
    let search_dim = 384; // Sentence embedding dimension
    let tensor_dim = 768; // Representation tensor dimension (RoBERTa)
    let num_documents = 1000;
    let num_queries = 100;
    let k = 10;

    println!("Configuration:");
    println!(" - Search dimensions: {}", search_dim);
    println!(" - Tensor dimensions: {}", tensor_dim);
    println!(" - Documents: {}", num_documents);
    println!(" - Queries: {}", num_queries);
    println!(" - Top-K: {}\n", k);

    // Benchmark the pipeline at several policy thresholds: a higher threshold
    // should push more results toward the cheap COMPRESS (tensor) path.
    let thresholds = [0.3, 0.5, 0.7, 0.9];
    for threshold in thresholds {
        println!("--- Testing with threshold: {:.1} ---\n", threshold);
        let store = RefragStoreBuilder::new()
            .search_dimensions(search_dim)
            .tensor_dimensions(tensor_dim)
            .compress_threshold(threshold)
            .auto_project(false) // Disable projection for speed
            .build()?;

        // Generate and insert documents. Each carries a random search
        // embedding plus a representation tensor serialized as LE f32 bytes.
        println!("Inserting {} documents...", num_documents);
        let insert_start = Instant::now();
        let mut rng = rand::thread_rng();
        for i in 0..num_documents {
            let search_vec: Vec<f32> = (0..search_dim).map(|_| rng.gen_range(-1.0..1.0)).collect();
            let tensor_vec: Vec<f32> = (0..tensor_dim).map(|_| rng.gen_range(-1.0..1.0)).collect();
            let tensor_bytes: Vec<u8> = tensor_vec.iter().flat_map(|f| f.to_le_bytes()).collect();
            let entry = RefragEntry::new(
                format!("doc_{}", i),
                search_vec,
                format!("This is the text content for document {}. It contains important information that might be relevant to various queries.", i),
            )
            .with_tensor(tensor_bytes, "llama3-8b")
            .with_metadata("source", serde_json::json!("synthetic"))
            .with_metadata("index", serde_json::json!(i));
            store.insert(entry)?;
        }
        let insert_time = insert_start.elapsed();
        println!(
            " Inserted in {:.2}ms ({:.0} docs/sec)\n",
            insert_time.as_secs_f64() * 1000.0,
            num_documents as f64 / insert_time.as_secs_f64()
        );

        // Run queries and tally per-result policy decisions.
        println!("Running {} hybrid searches...", num_queries);
        let search_start = Instant::now();
        let mut total_results = 0;
        let mut compress_count = 0;
        let mut expand_count = 0;
        for _ in 0..num_queries {
            let query: Vec<f32> = (0..search_dim).map(|_| rng.gen_range(-1.0..1.0)).collect();
            let results = store.search_hybrid(&query, k, None)?;
            for result in &results {
                total_results += 1;
                match result.response_type {
                    RefragResponseType::Compress => compress_count += 1,
                    RefragResponseType::Expand => expand_count += 1,
                }
            }
        }
        let search_time = search_start.elapsed();
        let avg_query_time_us = search_time.as_micros() as f64 / num_queries as f64;
        println!(
            " Total search time: {:.2}ms",
            search_time.as_secs_f64() * 1000.0
        );
        println!(" Average query time: {:.1}us", avg_query_time_us);
        println!(
            " QPS: {:.0}",
            num_queries as f64 / search_time.as_secs_f64()
        );

        // Results breakdown. Guard the denominator: if every search came back
        // empty, the original computed 0/0 and printed NaN percentages.
        let compress_ratio = if total_results > 0 {
            compress_count as f64 / total_results as f64 * 100.0
        } else {
            0.0
        };
        println!("\nResults breakdown:");
        println!(
            " - COMPRESS (tensor): {} ({:.1}%)",
            compress_count, compress_ratio
        );
        println!(
            " - EXPAND (text): {} ({:.1}%)",
            expand_count,
            100.0 - compress_ratio
        );

        // Store-level statistics accumulated across all searches above.
        let stats = store.stats();
        println!("\nStore statistics:");
        println!(" - Total searches: {}", stats.total_searches);
        println!(" - Avg policy time: {:.1}us", stats.avg_policy_time_us);
        println!(
            " - Compression ratio: {:.1}%",
            stats.compression_ratio() * 100.0
        );
        println!();
    }

    // Demo: Show actual search results
    println!("=================================================");
    println!(" Example Search Results ");
    println!("=================================================\n");
    let demo_store = RefragStoreBuilder::new()
        .search_dimensions(search_dim)
        .tensor_dimensions(tensor_dim)
        .compress_threshold(0.5)
        .build()?;

    // Insert some demo documents (real sentences; embeddings stay synthetic).
    let demo_docs = [
        ("doc_ml", "Machine learning is a subset of artificial intelligence that enables systems to learn from data."),
        ("doc_dl", "Deep learning uses neural networks with multiple layers to model complex patterns."),
        ("doc_nlp", "Natural language processing allows computers to understand human language."),
        ("doc_cv", "Computer vision enables machines to interpret and understand visual information."),
        ("doc_rl", "Reinforcement learning trains agents through rewards and punishments."),
    ];
    let mut rng = rand::thread_rng();
    for (id, text) in demo_docs {
        let search_vec: Vec<f32> = (0..search_dim).map(|_| rng.gen_range(-1.0..1.0)).collect();
        let tensor_vec: Vec<f32> = (0..tensor_dim).map(|_| rng.gen_range(-1.0..1.0)).collect();
        let tensor_bytes: Vec<u8> = tensor_vec.iter().flat_map(|f| f.to_le_bytes()).collect();
        let entry = RefragEntry::new(id, search_vec, text).with_tensor(tensor_bytes, "llama3-8b");
        demo_store.insert(entry)?;
    }

    let query: Vec<f32> = (0..search_dim).map(|_| rng.gen_range(-1.0..1.0)).collect();
    let results = demo_store.search_hybrid(&query, 3, None)?;
    println!("Query: [synthetic vector]\n");
    println!("Results:");
    for (i, result) in results.iter().enumerate() {
        println!(
            " {}. ID: {} (score: {:.3})",
            i + 1,
            result.id,
            result.score
        );
        println!(" Type: {:?}", result.response_type);
        println!(" Confidence: {:.2}", result.policy_confidence);
        match result.response_type {
            RefragResponseType::Expand => {
                if let Some(content) = &result.content {
                    // Truncate on character boundaries: the original byte
                    // slice `&content[..min(len, 60)]` panics if byte 60
                    // falls inside a multi-byte UTF-8 character.
                    let preview: String = content.chars().take(60).collect();
                    println!(" Content: \"{}...\"", preview);
                }
            }
            RefragResponseType::Compress => {
                if let Some(dims) = result.tensor_dims {
                    println!(" Tensor: {} dimensions", dims);
                }
                if let Some(model) = &result.alignment_model_id {
                    println!(" Aligned to: {}", model);
                }
            }
        }
        println!();
    }

    // Latency comparison
    println!("=================================================");
    println!(" Latency Comparison: Text vs Tensor ");
    println!("=================================================\n");
    let text_sizes = [100, 500, 1000, 2000, 5000];
    let tensor_dims = [768, 1024, 2048, 4096];
    println!("Text response sizes (bytes):");
    for size in text_sizes {
        println!(" - {} chars = {} bytes", size, size);
    }
    println!("\nTensor response sizes (bytes):");
    for dim in tensor_dims {
        let bytes = dim * 4; // f32
        // Base64 encodes every 3-byte group (rounded up) as 4 output chars:
        // encoded = ceil(raw / 3) * 4. The original `(bytes * 4 + 2) / 3`
        // under-counts whenever `bytes` is not a multiple of 3
        // (e.g. 4096 raw bytes -> 5462 instead of the correct 5464).
        let b64_bytes = (bytes + 2) / 3 * 4;
        println!(
            " - {} dims = {} bytes (raw), ~{} bytes (base64)",
            dim, bytes, b64_bytes
        );
    }
    println!("\nEstimated latency savings:");
    println!(" - Network transfer: ~10-50x reduction");
    println!(" - LLM context window: Direct tensor injection vs tokenization");
    println!(" - Policy overhead: <50us per decision");
    println!("\nDone!");
    Ok(())
}