//! REFRAG Pipeline Criterion Benchmarks
//!
//! Measures tensor compression, policy decisions, projection, and hybrid
//! store search.
|
|
|
|
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
|
|
use rand::Rng;
|
|
|
|
use refrag_pipeline_example::{
|
|
compress::{CompressionStrategy, TensorCompressor},
|
|
expand::Projector,
|
|
sense::{LinearPolicy, MLPPolicy, PolicyModel, ThresholdPolicy},
|
|
store::RefragStoreBuilder,
|
|
types::RefragEntry,
|
|
};
|
|
|
|
fn bench_compression(c: &mut Criterion) {
|
|
let mut group = c.benchmark_group("compression");
|
|
|
|
for dim in [384, 768, 1024, 2048] {
|
|
let mut rng = rand::thread_rng();
|
|
let vector: Vec<f32> = (0..dim).map(|_| rng.gen_range(-1.0..1.0)).collect();
|
|
|
|
for (name, strategy) in [
|
|
("none", CompressionStrategy::None),
|
|
("float16", CompressionStrategy::Float16),
|
|
("int8", CompressionStrategy::Int8),
|
|
("binary", CompressionStrategy::Binary),
|
|
] {
|
|
let compressor = TensorCompressor::new(dim).with_strategy(strategy);
|
|
|
|
group.throughput(Throughput::Elements(1));
|
|
group.bench_with_input(BenchmarkId::new(name, dim), &vector, |b, v| {
|
|
b.iter(|| compressor.compress(black_box(v)))
|
|
});
|
|
}
|
|
}
|
|
|
|
group.finish();
|
|
}
|
|
|
|
fn bench_policy(c: &mut Criterion) {
|
|
let mut group = c.benchmark_group("policy");
|
|
|
|
for dim in [384, 768] {
|
|
let mut rng = rand::thread_rng();
|
|
let chunk: Vec<f32> = (0..dim).map(|_| rng.gen_range(-1.0..1.0)).collect();
|
|
let query: Vec<f32> = (0..dim).map(|_| rng.gen_range(-1.0..1.0)).collect();
|
|
|
|
// Threshold policy
|
|
let threshold = ThresholdPolicy::new(0.5);
|
|
group.bench_with_input(
|
|
BenchmarkId::new("threshold", dim),
|
|
&(&chunk, &query),
|
|
|b, (c, q)| b.iter(|| threshold.decide(black_box(c), black_box(q))),
|
|
);
|
|
|
|
// Linear policy
|
|
let linear = LinearPolicy::new(dim, 0.5);
|
|
group.bench_with_input(
|
|
BenchmarkId::new("linear", dim),
|
|
&(&chunk, &query),
|
|
|b, (c, q)| b.iter(|| linear.decide(black_box(c), black_box(q))),
|
|
);
|
|
|
|
// MLP policy
|
|
let mlp = MLPPolicy::new(dim, 32, 0.5);
|
|
group.bench_with_input(
|
|
BenchmarkId::new("mlp_32", dim),
|
|
&(&chunk, &query),
|
|
|b, (c, q)| b.iter(|| mlp.decide(black_box(c), black_box(q))),
|
|
);
|
|
}
|
|
|
|
group.finish();
|
|
}
|
|
|
|
fn bench_projection(c: &mut Criterion) {
|
|
let mut group = c.benchmark_group("projection");
|
|
|
|
for (source, target) in [(768, 4096), (768, 8192), (1536, 8192)] {
|
|
let mut rng = rand::thread_rng();
|
|
let input: Vec<f32> = (0..source).map(|_| rng.gen_range(-1.0..1.0)).collect();
|
|
let projector = Projector::new(source, target, "test");
|
|
|
|
group.throughput(Throughput::Elements(1));
|
|
group.bench_with_input(
|
|
BenchmarkId::new(format!("{}->{}", source, target), source),
|
|
&input,
|
|
|b, v| b.iter(|| projector.project(black_box(v))),
|
|
);
|
|
}
|
|
|
|
group.finish();
|
|
}
|
|
|
|
fn bench_search(c: &mut Criterion) {
|
|
let mut group = c.benchmark_group("search");
|
|
|
|
let search_dim = 384;
|
|
let tensor_dim = 768;
|
|
|
|
for num_docs in [100, 1000, 10000] {
|
|
let store = RefragStoreBuilder::new()
|
|
.search_dimensions(search_dim)
|
|
.tensor_dimensions(tensor_dim)
|
|
.compress_threshold(0.5)
|
|
.auto_project(false)
|
|
.build()
|
|
.unwrap();
|
|
|
|
let mut rng = rand::thread_rng();
|
|
|
|
// Insert documents
|
|
for i in 0..num_docs {
|
|
let search_vec: Vec<f32> = (0..search_dim).map(|_| rng.gen_range(-1.0..1.0)).collect();
|
|
let tensor_vec: Vec<f32> = (0..tensor_dim).map(|_| rng.gen_range(-1.0..1.0)).collect();
|
|
let tensor_bytes: Vec<u8> = tensor_vec.iter().flat_map(|f| f.to_le_bytes()).collect();
|
|
|
|
let entry = RefragEntry::new(format!("doc_{}", i), search_vec, format!("Text {}", i))
|
|
.with_tensor(tensor_bytes, "llama3-8b");
|
|
store.insert(entry).unwrap();
|
|
}
|
|
|
|
let query: Vec<f32> = (0..search_dim).map(|_| rng.gen_range(-1.0..1.0)).collect();
|
|
|
|
group.throughput(Throughput::Elements(1));
|
|
group.bench_with_input(BenchmarkId::new("hybrid_k10", num_docs), &query, |b, q| {
|
|
b.iter(|| store.search_hybrid(black_box(q), 10, None))
|
|
});
|
|
}
|
|
|
|
group.finish();
|
|
}
|
|
|
|
criterion_group!(
|
|
benches,
|
|
bench_compression,
|
|
bench_policy,
|
|
bench_projection,
|
|
bench_search,
|
|
);
|
|
criterion_main!(benches);
|