Files
wifi-densepose/crates/ruvector-core/tests/concurrent_tests.rs
ruv d803bfe2b1 Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector
git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
2026-02-28 14:39:40 -05:00

441 lines
12 KiB
Rust

//! Concurrent access tests with multiple threads
//!
//! These tests verify thread-safety and correct behavior under concurrent access.
use ruvector_core::types::{DbOptions, HnswConfig, SearchQuery};
use ruvector_core::{VectorDB, VectorEntry};
use std::collections::HashSet;
use std::sync::{Arc, Mutex};
use std::thread;
use tempfile::tempdir;
/// Many reader threads fetch pre-populated entries through one shared handle.
///
/// 100 vectors are inserted up front; 10 threads then perform 50 `get` calls
/// each, and every lookup must find its entry.
#[test]
fn test_concurrent_reads() {
    // Keep the temp-dir handle bound for the whole test.
    let scratch = tempdir().unwrap();
    let db_file = scratch.path().join("concurrent_reads.db");

    let mut options = DbOptions::default();
    options.storage_path = db_file.to_string_lossy().to_string();
    options.dimensions = 32;
    let db = Arc::new(VectorDB::new(options).unwrap());

    // Seed the database before any reader starts.
    for n in 0..100 {
        let vector: Vec<f32> = (0..32).map(|offset| ((n + offset) as f32) * 0.1).collect();
        db.insert(VectorEntry {
            id: Some(format!("vec_{}", n)),
            vector,
            metadata: None,
        })
        .unwrap();
    }

    // Launch all readers, then wait for each of them.
    let num_threads = 10;
    let readers: Vec<_> = (0..num_threads)
        .map(|thread_id| {
            let shared = Arc::clone(&db);
            thread::spawn(move || {
                for step in 0..50 {
                    // The modulo keeps every lookup inside the 100 seeded ids.
                    let id = format!("vec_{}", (thread_id * 10 + step) % 100);
                    let found = shared.get(&id).unwrap();
                    assert!(found.is_some(), "Failed to get {}", id);
                }
            })
        })
        .collect();

    for reader in readers {
        reader.join().unwrap();
    }
}
/// Writer threads insert disjoint id ranges concurrently.
///
/// Each of the 10 threads inserts 20 entries under ids prefixed with its own
/// thread number, so the final count must equal `threads * entries`.
#[test]
fn test_concurrent_writes_no_collision() {
    // Keep the temp-dir handle bound for the whole test.
    let scratch = tempdir().unwrap();
    let db_file = scratch.path().join("concurrent_writes.db");

    let mut options = DbOptions::default();
    options.storage_path = db_file.to_string_lossy().to_string();
    options.dimensions = 32;
    let db = Arc::new(VectorDB::new(options).unwrap());

    let num_threads = 10;
    let vectors_per_thread = 20;

    // One writer per thread id; ids embed the thread id so they never overlap.
    let writers: Vec<_> = (0..num_threads)
        .map(|thread_id| {
            let shared = Arc::clone(&db);
            thread::spawn(move || {
                for seq in 0..vectors_per_thread {
                    let entry = VectorEntry {
                        id: Some(format!("thread_{}_{}", thread_id, seq)),
                        vector: vec![thread_id as f32; 32],
                        metadata: None,
                    };
                    shared.insert(entry).unwrap();
                }
            })
        })
        .collect();

    writers
        .into_iter()
        .for_each(|writer| writer.join().unwrap());

    // Every insert must have landed.
    assert_eq!(db.len().unwrap(), num_threads * vectors_per_thread);
}
/// Deleters and inserters run side by side on disjoint id sets.
///
/// 5 threads delete `vec_0..vec_49` (10 ids each) while 5 other threads insert
/// 10 fresh `new_*` entries each; the entry count must end where it started.
#[test]
fn test_concurrent_delete_and_insert() {
    // Keep the temp-dir handle bound for the whole test.
    let scratch = tempdir().unwrap();
    let db_file = scratch.path().join("concurrent_delete_insert.db");

    let mut options = DbOptions::default();
    options.storage_path = db_file.to_string_lossy().to_string();
    options.dimensions = 16;
    let db = Arc::new(VectorDB::new(options).unwrap());

    // Seed 100 entries for the deleters to work on.
    for n in 0..100 {
        let entry = VectorEntry {
            id: Some(format!("vec_{}", n)),
            vector: vec![n as f32; 16],
            metadata: None,
        };
        db.insert(entry).unwrap();
    }

    let num_threads = 5;

    // Deleters are spawned first; each removes its own block of ten ids.
    let deleters: Vec<_> = (0..num_threads)
        .map(|thread_id| {
            let shared = Arc::clone(&db);
            thread::spawn(move || {
                for slot in 0..10 {
                    let id = format!("vec_{}", thread_id * 10 + slot);
                    shared.delete(&id).unwrap();
                }
            })
        })
        .collect();

    // Inserters add ids that never clash with the seeded ones.
    let inserters: Vec<_> = (0..num_threads)
        .map(|thread_id| {
            let shared = Arc::clone(&db);
            thread::spawn(move || {
                for slot in 0..10 {
                    let entry = VectorEntry {
                        id: Some(format!("new_{}_{}", thread_id, slot)),
                        vector: vec![(thread_id * 100 + slot) as f32; 16],
                        metadata: None,
                    };
                    shared.insert(entry).unwrap();
                }
            })
        })
        .collect();

    // Join in spawn order: deleters, then inserters.
    for worker in deleters.into_iter().chain(inserters) {
        worker.join().unwrap();
    }

    // 100 original - 50 deleted + 50 inserted
    let final_len = db.len().unwrap();
    assert_eq!(final_len, 100);
}
/// Searches run while other threads keep inserting.
///
/// With the default `HnswConfig` set on the options, 5 threads issue 20
/// `k = 5` searches each while 2 threads insert 50 new entries each. Every
/// search must return at least one hit, and the final count must be 200.
#[test]
fn test_concurrent_search_and_insert() {
    // Keep the temp-dir handle bound for the whole test.
    let scratch = tempdir().unwrap();
    let db_file = scratch.path().join("concurrent_search_insert.db");

    let mut options = DbOptions::default();
    options.storage_path = db_file.to_string_lossy().to_string();
    options.dimensions = 64;
    options.hnsw_config = Some(HnswConfig::default());
    let db = Arc::new(VectorDB::new(options).unwrap());

    // Seed data so that searches always have something to return.
    for n in 0..100 {
        let vector: Vec<f32> = (0..64).map(|j| ((n + j) as f32) * 0.01).collect();
        db.insert(VectorEntry {
            id: Some(format!("vec_{}", n)),
            vector,
            metadata: None,
        })
        .unwrap();
    }

    let num_search_threads = 5;
    let num_insert_threads = 2;

    // Searchers are spawned first.
    let searchers: Vec<_> = (0..num_search_threads)
        .map(|search_id| {
            let shared = Arc::clone(&db);
            thread::spawn(move || {
                for round in 0..20 {
                    let query: Vec<f32> = (0..64)
                        .map(|j| ((search_id * 10 + round + j) as f32) * 0.01)
                        .collect();
                    let request = SearchQuery {
                        vector: query,
                        k: 5,
                        filter: None,
                        ef_search: None,
                    };
                    let results = shared.search(request).unwrap();
                    // Should always return some results (at least from initial data)
                    assert!(!results.is_empty());
                }
            })
        })
        .collect();

    // Inserters add 50 new entries per thread under fresh ids.
    let inserters: Vec<_> = (0..num_insert_threads)
        .map(|insert_id| {
            let shared = Arc::clone(&db);
            thread::spawn(move || {
                for seq in 0..50 {
                    let vector: Vec<f32> = (0..64)
                        .map(|j| ((insert_id * 1000 + seq + j) as f32) * 0.01)
                        .collect();
                    shared
                        .insert(VectorEntry {
                            id: Some(format!("new_{}_{}", insert_id, seq)),
                            vector,
                            metadata: None,
                        })
                        .unwrap();
                }
            })
        })
        .collect();

    // Join in spawn order: searchers, then inserters.
    for worker in searchers.into_iter().chain(inserters) {
        worker.join().unwrap();
    }

    // 100 initial + 100 new
    assert_eq!(db.len().unwrap(), 200);
}
/// Concurrent `insert_batch` calls must each land in full.
///
/// 5 threads submit 10 batches of 10 uniquely named entries. The ids returned
/// by every batch are collected into a shared set; its size and the database
/// length must both equal the total number of submitted entries.
#[test]
fn test_atomicity_of_batch_insert() {
    // Keep the temp-dir handle bound for the whole test.
    let scratch = tempdir().unwrap();
    let db_file = scratch.path().join("atomic_batch.db");

    let mut options = DbOptions::default();
    options.storage_path = db_file.to_string_lossy().to_string();
    options.dimensions = 16;
    let db = Arc::new(VectorDB::new(options).unwrap());

    // Ids reported back by successful batch inserts, shared across threads.
    let inserted_ids = Arc::new(Mutex::new(HashSet::new()));
    let num_threads = 5;

    let workers: Vec<_> = (0..num_threads)
        .map(|thread_id| {
            let shared = Arc::clone(&db);
            let recorded = Arc::clone(&inserted_ids);
            thread::spawn(move || {
                for batch_idx in 0..10 {
                    // Build one batch of ten entries with globally unique ids.
                    let mut batch: Vec<VectorEntry> = Vec::new();
                    for i in 0..10 {
                        batch.push(VectorEntry {
                            id: Some(format!("t{}_b{}_v{}", thread_id, batch_idx, i)),
                            vector: vec![(thread_id * 100 + batch_idx * 10 + i) as f32; 16],
                            metadata: None,
                        });
                    }
                    let ids = shared.insert_batch(batch).unwrap();
                    // Take the lock only after the batch has been written.
                    recorded.lock().unwrap().extend(ids);
                }
            })
        })
        .collect();

    for worker in workers {
        worker.join().unwrap();
    }

    // threads * batches * vectors_per_batch
    let total_inserted = inserted_ids.lock().unwrap().len();
    assert_eq!(total_inserted, num_threads * 10 * 10);
    assert_eq!(db.len().unwrap(), total_inserted);
}
/// Readers and writers contend on a single key; no read may see a torn vector.
///
/// The entry `"test"` is seeded with a uniform vector of `1.0`. Ten threads
/// then repeatedly read it, and the even-numbered threads also overwrite it
/// with a uniform vector of their own thread id. Because every value ever
/// written is uniform, any read whose elements are not all equal would mean
/// the reader observed a partially applied write.
#[test]
fn test_read_write_consistency() {
    let dir = tempdir().unwrap();
    let mut options = DbOptions::default();
    options.storage_path = dir
        .path()
        .join("consistency.db")
        .to_string_lossy()
        .to_string();
    options.dimensions = 32;
    let db = Arc::new(VectorDB::new(options).unwrap());

    // Insert initial vector
    db.insert(VectorEntry {
        id: Some("test".to_string()),
        vector: vec![1.0; 32],
        metadata: None,
    })
    .unwrap();

    let num_threads = 10;
    let mut handles = vec![];
    for thread_id in 0..num_threads {
        let db_clone = Arc::clone(&db);
        let handle = thread::spawn(move || {
            for _ in 0..100 {
                // Read: the key is seeded before any thread starts and is never deleted.
                let vector = db_clone
                    .get("test")
                    .unwrap()
                    .expect("entry `test` should always be present")
                    .vector;
                assert_eq!(vector.len(), 32);

                // Every element must match the first one. The previous check also
                // accepted any vector starting with 1.0 and any element equal to
                // this thread's id, so it could not detect a mixed vector.
                let first_val = vector[0];
                assert!(
                    vector.iter().all(|&v| v == first_val),
                    "thread {} observed a non-uniform vector: {:?}",
                    thread_id,
                    vector
                );

                // Write (update) - this creates a race condition intentionally.
                // The result is deliberately ignored: only read integrity is under test.
                if thread_id % 2 == 0 {
                    let _ = db_clone.insert(VectorEntry {
                        id: Some("test".to_string()),
                        vector: vec![thread_id as f32; 32],
                        metadata: None,
                    });
                }
            }
        });
        handles.push(handle);
    }
    for handle in handles {
        handle.join().unwrap();
    }

    // Verify database is still consistent once all threads have finished.
    let vector = db
        .get("test")
        .unwrap()
        .expect("entry `test` should still be present after all threads finish")
        .vector;
    assert_eq!(vector.len(), 32);

    // Check no corruption (all values should be the same)
    let first_val = vector[0];
    assert!(
        vector.iter().all(|&v| v == first_val),
        "final vector is not uniform: {:?}",
        vector
    );
}
/// Threads concurrently re-insert existing ids with metadata attached.
///
/// 50 entries are seeded without metadata. Five threads then each re-insert
/// ten of them under the id `vec_{i * 5 + thread_id}`, attaching a metadata
/// map keyed `thread_{thread_id}`. Those ids cover `0..50` exactly once, so
/// after the threads finish every entry `vec_n` is expected to carry metadata
/// with the key `thread_{n % 5}`.
#[test]
fn test_concurrent_metadata_updates() {
    use std::collections::HashMap;

    let dir = tempdir().unwrap();
    let mut options = DbOptions::default();
    options.storage_path = dir.path().join("metadata.db").to_string_lossy().to_string();
    options.dimensions = 16;
    let db = Arc::new(VectorDB::new(options).unwrap());

    // Insert initial vectors
    for i in 0..50 {
        db.insert(VectorEntry {
            id: Some(format!("vec_{}", i)),
            vector: vec![i as f32; 16],
            metadata: None,
        })
        .unwrap();
    }

    let num_threads = 5;
    let mut handles = vec![];
    for thread_id in 0..num_threads {
        let db_clone = Arc::clone(&db);
        let handle = thread::spawn(move || {
            for i in 0..10 {
                let mut metadata = HashMap::new();
                metadata.insert(format!("thread_{}", thread_id), serde_json::json!(i));

                // Update vector with metadata; ids are interleaved so that no two
                // threads ever write the same id.
                let id = format!("vec_{}", i * 5 + thread_id);
                db_clone
                    .insert(VectorEntry {
                        id: Some(id),
                        vector: vec![thread_id as f32; 16],
                        metadata: Some(metadata),
                    })
                    .unwrap();
            }
        });
        handles.push(handle);
    }
    for handle in handles {
        handle.join().unwrap();
    }

    // Verify the updates landed. Checking mere existence would pass on the
    // seeded data alone, so check the metadata written by the owning thread.
    for n in 0..50 {
        let id = format!("vec_{}", n);
        let entry = db
            .get(&id)
            .unwrap()
            .unwrap_or_else(|| panic!("{} is missing after concurrent updates", id));
        let metadata = entry
            .metadata
            .unwrap_or_else(|| panic!("{} has no metadata after concurrent updates", id));
        let expected_key = format!("thread_{}", n % 5);
        assert!(
            metadata.contains_key(&expected_key),
            "{} is missing metadata key {}",
            id,
            expected_key
        );
    }
}