Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'

This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,947 @@
//! Cold-tier GNN training via hyperbatch I/O for graphs exceeding RAM.
//!
//! Implements AGNES-style block-aligned I/O with hotset caching
//! for training on large-scale graphs that don't fit in memory.
#![cfg(all(feature = "cold-tier", not(target_arch = "wasm32")))]
use crate::error::{GnnError, Result};
use std::collections::{HashMap, VecDeque};
use std::fs::{File, OpenOptions};
use std::io::{Read, Seek, SeekFrom, Write};
use std::path::{Path, PathBuf};
/// Size of an f32 in bytes (4).
const F32_SIZE: usize = std::mem::size_of::<f32>();
/// Header size in bytes: dim (u64) + num_nodes (u64) + block_size (u64),
/// stored little-endian at the start of every feature file.
const HEADER_SIZE: u64 = 24;
/// Return the system page size as reported by the `page_size` crate.
///
/// NOTE(review): any fallback behavior (e.g. 4096 on unsupported platforms)
/// lives inside the `page_size` crate; this wrapper adds none of its own.
fn system_page_size() -> usize {
    page_size::get()
}
/// Align `value` up to the nearest multiple of `alignment`.
///
/// Uses the remainder form instead of `(value + alignment - 1) / alignment *
/// alignment`, which overflows for values near `usize::MAX` even when the
/// aligned result itself fits (e.g. `value == usize::MAX` with a non-power-of-
/// two alignment that already divides it).
///
/// # Panics
/// Panics if `alignment` is zero (division by zero, as before), or in debug
/// builds if the aligned result does not fit in `usize`.
fn align_up(value: usize, alignment: usize) -> usize {
    let rem = value % alignment;
    if rem == 0 {
        value
    } else {
        value + (alignment - rem)
    }
}
// ---------------------------------------------------------------------------
// FeatureStorage
// ---------------------------------------------------------------------------
/// Block-aligned feature file for storing node feature vectors on disk.
///
/// On-disk layout: a 24-byte little-endian header (`dim`, `num_nodes`,
/// `block_size`) followed by `num_nodes` fixed-size blocks, one feature
/// vector per block (see `create` / `open`).
pub struct FeatureStorage {
    /// Location of the backing file.
    path: PathBuf,
    /// Number of f32 values per feature vector.
    dim: usize,
    /// Total node capacity of the file.
    num_nodes: usize,
    /// Bytes per on-disk record; page-aligned, >= `dim * F32_SIZE`.
    block_size: usize,
    /// Open handle; `None` means the handle is unavailable.
    file: Option<File>,
}
impl FeatureStorage {
    /// Create a new feature file at `path` for `num_nodes` vectors of
    /// dimension `dim`.
    ///
    /// Writes a 24-byte little-endian header (`dim`, `num_nodes`,
    /// `block_size`) and pre-extends the file so every node block is
    /// addressable; unwritten blocks read back as zeros.
    ///
    /// # Errors
    /// Returns an error if `dim` is zero or any I/O operation fails.
    pub fn create(path: &Path, dim: usize, num_nodes: usize) -> Result<Self> {
        if dim == 0 {
            return Err(GnnError::invalid_input("dim must be > 0"));
        }
        // One page-aligned block per node so record I/O never straddles pages.
        let block_size = align_up(dim * F32_SIZE, system_page_size());
        let data_size = num_nodes as u64 * block_size as u64;
        let mut file = OpenOptions::new()
            .read(true)
            .write(true)
            .create(true)
            .truncate(true)
            .open(path)
            .map_err(GnnError::Io)?;
        // Header fields are fixed little-endian u64 regardless of host order.
        file.write_all(&(dim as u64).to_le_bytes())?;
        file.write_all(&(num_nodes as u64).to_le_bytes())?;
        file.write_all(&(block_size as u64).to_le_bytes())?;
        // Extend file to full size up front.
        file.set_len(HEADER_SIZE + data_size)?;
        Ok(Self {
            path: path.to_path_buf(),
            dim,
            num_nodes,
            block_size,
            file: Some(file),
        })
    }

    /// Open an existing feature file and parse its header.
    ///
    /// # Errors
    /// Returns an error on I/O failure or if the stored header is clearly
    /// corrupt (zero dim, or a block too small to hold one vector).
    pub fn open(path: &Path) -> Result<Self> {
        let mut file = OpenOptions::new()
            .read(true)
            .write(true)
            .open(path)
            .map_err(GnnError::Io)?;
        let mut buf = [0u8; 8];
        file.read_exact(&mut buf)?;
        let dim = u64::from_le_bytes(buf) as usize;
        file.read_exact(&mut buf)?;
        let num_nodes = u64::from_le_bytes(buf) as usize;
        file.read_exact(&mut buf)?;
        let block_size = u64::from_le_bytes(buf) as usize;
        // Reject obviously corrupt headers here instead of failing later
        // with out-of-range seeks or short reads.
        if dim == 0 || block_size < dim * F32_SIZE {
            return Err(GnnError::invalid_input("corrupt feature file header"));
        }
        Ok(Self {
            path: path.to_path_buf(),
            dim,
            num_nodes,
            block_size,
            file: Some(file),
        })
    }

    /// Write the feature vector for a single node at its block offset.
    ///
    /// Values are serialized explicitly as little-endian f32. The previous
    /// implementation wrote the raw native-endian bytes of the slice (via
    /// `unsafe`), which disagreed with `read_features` (little-endian) on
    /// big-endian targets.
    ///
    /// # Errors
    /// Returns an error if `node_id` is out of range, `features` has the
    /// wrong length, the handle is missing, or the write fails.
    pub fn write_features(&mut self, node_id: usize, features: &[f32]) -> Result<()> {
        if node_id >= self.num_nodes {
            return Err(GnnError::invalid_input(format!(
                "node_id {} out of bounds (num_nodes={})",
                node_id, self.num_nodes
            )));
        }
        if features.len() != self.dim {
            return Err(GnnError::dimension_mismatch(
                self.dim.to_string(),
                features.len().to_string(),
            ));
        }
        let file = self
            .file
            .as_mut()
            .ok_or_else(|| GnnError::other("file not open"))?;
        let offset = HEADER_SIZE + (node_id as u64) * (self.block_size as u64);
        file.seek(SeekFrom::Start(offset))?;
        // Explicit little-endian serialization, matching from_le_bytes reads.
        let mut bytes = Vec::with_capacity(features.len() * F32_SIZE);
        for &v in features {
            bytes.extend_from_slice(&v.to_le_bytes());
        }
        file.write_all(&bytes)?;
        Ok(())
    }

    /// Read the feature vector for a single node.
    ///
    /// # Errors
    /// Returns an error if `node_id` is out of range, the handle is missing,
    /// or the read fails.
    pub fn read_features(&mut self, node_id: usize) -> Result<Vec<f32>> {
        if node_id >= self.num_nodes {
            return Err(GnnError::invalid_input(format!(
                "node_id {} out of bounds (num_nodes={})",
                node_id, self.num_nodes
            )));
        }
        let file = self
            .file
            .as_mut()
            .ok_or_else(|| GnnError::other("file not open"))?;
        let offset = HEADER_SIZE + (node_id as u64) * (self.block_size as u64);
        file.seek(SeekFrom::Start(offset))?;
        let mut buf = vec![0u8; self.dim * F32_SIZE];
        file.read_exact(&mut buf)?;
        let features: Vec<f32> = buf
            .chunks_exact(F32_SIZE)
            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect();
        Ok(features)
    }

    /// Batch-read features for multiple nodes.
    ///
    /// Node ids are visited in sorted order to improve sequential I/O
    /// locality; duplicates are served from a single read, and results are
    /// returned in the caller's original order.
    pub fn read_batch(&mut self, node_ids: &[usize]) -> Result<Vec<Vec<f32>>> {
        // Sorted, deduplicated id list: each unique node is read exactly once.
        let mut sorted: Vec<usize> = node_ids.to_vec();
        sorted.sort_unstable();
        sorted.dedup();
        let mut map: HashMap<usize, Vec<f32>> = HashMap::with_capacity(sorted.len());
        for &nid in &sorted {
            map.insert(nid, self.read_features(nid)?);
        }
        // Reorder to match the caller's input order.
        let mut results = Vec::with_capacity(node_ids.len());
        for &nid in node_ids {
            results.push(map[&nid].clone());
        }
        Ok(results)
    }

    /// Flush buffered writes to the OS.
    ///
    /// NOTE(review): `File::flush` is a no-op at the std level; call
    /// `sync_all` on the handle if on-disk durability is required.
    pub fn flush(&mut self) -> Result<()> {
        if let Some(ref mut f) = self.file {
            f.flush()?;
        }
        Ok(())
    }

    /// Dimension of each feature vector.
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Number of nodes in the storage.
    pub fn num_nodes(&self) -> usize {
        self.num_nodes
    }

    /// Path to the underlying file.
    pub fn path(&self) -> &Path {
        &self.path
    }
}
// ---------------------------------------------------------------------------
// HyperbatchConfig / HyperbatchResult
// ---------------------------------------------------------------------------
/// Configuration for hyperbatch I/O.
///
/// Construct via `HyperbatchConfig::default()` and override fields as needed.
#[derive(Debug, Clone)]
pub struct HyperbatchConfig {
    /// Nodes per hyperbatch (default: 4096).
    pub batch_size: usize,
    /// Prefetch multiplier (default: 2).
    /// NOTE(review): not consumed by `HyperbatchIterator` in this file;
    /// reads are synchronous.
    pub prefetch_factor: usize,
    /// I/O block alignment in bytes (default: 4096).
    /// NOTE(review): `FeatureStorage` aligns to the system page size, not
    /// this field — confirm intended use.
    pub block_align: usize,
    /// Double-buffering count (default: 2).
    pub num_buffers: usize,
    /// Fraction of nodes kept in the hotset (default: 0.05).
    pub hotset_fraction: f64,
}
impl Default for HyperbatchConfig {
fn default() -> Self {
Self {
batch_size: 4096,
prefetch_factor: 2,
block_align: 4096,
num_buffers: 2,
hotset_fraction: 0.05,
}
}
}
/// Result from a single hyperbatch iteration.
#[derive(Debug, Clone)]
pub struct HyperbatchResult {
    /// Node identifiers in this batch (a contiguous slice of the BFS order).
    pub node_ids: Vec<usize>,
    /// Feature vectors, parallel to `node_ids`.
    pub features: Vec<Vec<f32>>,
    /// Zero-based index of this batch within the epoch.
    pub batch_index: usize,
}
// ---------------------------------------------------------------------------
// HyperbatchIterator
// ---------------------------------------------------------------------------
/// Yields batches from disk following BFS vertex ordering for I/O locality.
pub struct HyperbatchIterator {
    /// Backing feature file batches are read from.
    storage: FeatureStorage,
    /// Batch sizing / buffering parameters.
    config: HyperbatchConfig,
    /// BFS-ordered node ids; batches are contiguous slices of this.
    node_order: Vec<usize>,
    /// Index into `node_order` where the next batch starts.
    current_offset: usize,
    /// Ring of recently produced batches (double buffering).
    buffers: Vec<Vec<Vec<f32>>>,
    /// Next slot in `buffers` to overwrite (monotonically increasing).
    active_buffer: usize,
    /// Number of batches produced so far this epoch.
    batch_counter: usize,
}
impl HyperbatchIterator {
    /// Create a new iterator over `storage`, visiting nodes in BFS order
    /// derived from `adjacency` for better on-disk locality.
    pub fn new(
        storage: FeatureStorage,
        config: HyperbatchConfig,
        adjacency: &[(usize, usize)],
    ) -> Self {
        let num_nodes = storage.num_nodes();
        let node_order = Self::reorder_bfs(adjacency, num_nodes);
        // Keep at least one buffer so the modulo indexing below cannot panic.
        let num_buffers = config.num_buffers.max(1);
        let buffers = vec![Vec::new(); num_buffers];
        Self {
            storage,
            config,
            node_order,
            current_offset: 0,
            buffers,
            active_buffer: 0,
            batch_counter: 0,
        }
    }

    /// Get the next batch, or `None` when the epoch is complete.
    ///
    /// NOTE(review): read errors are swallowed (`read_batch(...).ok()?`) and
    /// reported as end of epoch; callers cannot distinguish an I/O failure
    /// from exhaustion.
    pub fn next_batch(&mut self) -> Option<HyperbatchResult> {
        if self.current_offset >= self.node_order.len() {
            return None;
        }
        let end = (self.current_offset + self.config.batch_size).min(self.node_order.len());
        let node_ids: Vec<usize> = self.node_order[self.current_offset..end].to_vec();
        let features = self.storage.read_batch(&node_ids).ok()?;
        // Keep a copy in the buffer ring for potential re-use; clone_from
        // reuses the slot's previous allocation instead of discarding it.
        let buf_idx = self.active_buffer % self.buffers.len();
        self.buffers[buf_idx].clone_from(&features);
        self.active_buffer += 1;
        let batch_index = self.batch_counter;
        self.batch_counter += 1;
        self.current_offset = end;
        Some(HyperbatchResult {
            node_ids,
            features,
            batch_index,
        })
    }

    /// Reset the iterator to the beginning of the epoch (node order is kept).
    pub fn reset(&mut self) {
        self.current_offset = 0;
        self.batch_counter = 0;
        self.active_buffer = 0;
    }

    /// Produce a BFS vertex ordering for better I/O locality.
    ///
    /// Edges are treated as undirected; endpoints outside `0..num_nodes` are
    /// ignored. Disconnected components are covered by restarting BFS from
    /// each unvisited node in ascending id order.
    pub fn reorder_bfs(adjacency: &[(usize, usize)], num_nodes: usize) -> Vec<usize> {
        if num_nodes == 0 {
            return Vec::new();
        }
        // Build an undirected adjacency list.
        let mut adj: Vec<Vec<usize>> = vec![Vec::new(); num_nodes];
        for &(u, v) in adjacency {
            if u < num_nodes && v < num_nodes {
                adj[u].push(v);
                adj[v].push(u);
            }
        }
        let mut visited = vec![false; num_nodes];
        let mut order = Vec::with_capacity(num_nodes);
        let mut queue = VecDeque::new();
        for start in 0..num_nodes {
            if visited[start] {
                continue;
            }
            visited[start] = true;
            queue.push_back(start);
            while let Some(node) = queue.pop_front() {
                order.push(node);
                for &neighbor in &adj[node] {
                    if !visited[neighbor] {
                        visited[neighbor] = true;
                        queue.push_back(neighbor);
                    }
                }
            }
        }
        order
    }
}
// ---------------------------------------------------------------------------
// AdaptiveHotset
// ---------------------------------------------------------------------------
/// In-memory cache of frequently accessed node features.
pub struct AdaptiveHotset {
    /// Cached feature vectors keyed by node id.
    features: HashMap<usize, Vec<f32>>,
    /// Access frequency per node id; may include uncached nodes via
    /// `record_access`.
    access_counts: HashMap<usize, u64>,
    /// Maximum number of cached vectors before eviction kicks in.
    capacity: usize,
    /// Multiplier applied by `decay_counts` to age access counts.
    decay_factor: f64,
    /// Total `get` calls since creation.
    total_lookups: u64,
    /// `get` calls that found a cached entry.
    hits: u64,
}
impl AdaptiveHotset {
    /// Create a new hotset with the given capacity and decay factor.
    pub fn new(capacity: usize, decay_factor: f64) -> Self {
        Self {
            features: HashMap::with_capacity(capacity),
            access_counts: HashMap::with_capacity(capacity),
            capacity,
            decay_factor,
            total_lookups: 0,
            hits: 0,
        }
    }

    /// O(1) lookup of cached features; bumps hit/lookup statistics.
    ///
    /// Uses a single map lookup (the previous version did `contains_key`
    /// followed by `get`); the hit/access counters are disjoint fields, so
    /// they can be updated while the entry is borrowed.
    pub fn get(&mut self, node_id: usize) -> Option<&[f32]> {
        self.total_lookups += 1;
        match self.features.get(&node_id) {
            Some(feats) => {
                self.hits += 1;
                *self.access_counts.entry(node_id).or_insert(0) += 1;
                Some(feats.as_slice())
            }
            None => None,
        }
    }

    /// Insert features, evicting the coldest entry if at capacity.
    ///
    /// Re-inserting an already-cached node never triggers an eviction.
    pub fn insert(&mut self, node_id: usize, features: Vec<f32>) {
        if self.features.len() >= self.capacity && !self.features.contains_key(&node_id) {
            self.evict_cold();
        }
        // Ensure the node has an access-count entry so eviction can see it.
        self.access_counts.entry(node_id).or_insert(0);
        self.features.insert(node_id, features);
    }

    /// Record an access without returning features (frequency tracking only).
    pub fn record_access(&mut self, node_id: usize) {
        *self.access_counts.entry(node_id).or_insert(0) += 1;
    }

    /// Evict the least-accessed cached node from the hotset.
    ///
    /// O(cached) scan over the cached keys; missing counts rank as zero.
    pub fn evict_cold(&mut self) {
        // Guard on the map we actually scan (previously access_counts was
        // checked; the two agree because insert populates both).
        if self.features.is_empty() {
            return;
        }
        let coldest = self
            .features
            .keys()
            .min_by_key(|nid| self.access_counts.get(nid).copied().unwrap_or(0))
            .copied();
        if let Some(nid) = coldest {
            self.features.remove(&nid);
            self.access_counts.remove(&nid);
        }
    }

    /// Cache hit rate since creation (0.0 when no lookups happened).
    pub fn hit_rate(&self) -> f64 {
        if self.total_lookups == 0 {
            return 0.0;
        }
        self.hits as f64 / self.total_lookups as f64
    }

    /// Multiply all access counts by `decay_factor` to age out stale entries.
    ///
    /// Counts are truncated back to integers, so small counts decay to zero.
    pub fn decay_counts(&mut self) {
        for count in self.access_counts.values_mut() {
            *count = (*count as f64 * self.decay_factor) as u64;
        }
    }

    /// Number of nodes currently cached.
    pub fn len(&self) -> usize {
        self.features.len()
    }

    /// Whether the hotset is empty.
    pub fn is_empty(&self) -> bool {
        self.features.is_empty()
    }
}
// ---------------------------------------------------------------------------
// ColdTierEpochResult
// ---------------------------------------------------------------------------
/// Statistics from one cold-tier training epoch.
#[derive(Debug, Clone)]
pub struct ColdTierEpochResult {
    /// Epoch number (1-based: the trainer increments before reporting).
    pub epoch: usize,
    /// Average per-batch loss (sum of node losses divided by batch count).
    pub avg_loss: f64,
    /// Number of batches processed.
    pub batches: usize,
    /// Hotset hit rate during this epoch.
    pub hotset_hit_rate: f64,
    /// Milliseconds spent on I/O.
    pub io_time_ms: u64,
    /// Milliseconds spent on compute.
    pub compute_time_ms: u64,
}
// ---------------------------------------------------------------------------
// ColdTierTrainer
// ---------------------------------------------------------------------------
/// Orchestrates cold-tier training with hyperbatch I/O and hotset caching.
pub struct ColdTierTrainer {
    /// On-disk feature storage (write handle).
    storage: FeatureStorage,
    /// In-memory cache of frequently touched node features.
    hotset: AdaptiveHotset,
    /// Batch sizing / hotset parameters.
    config: HyperbatchConfig,
    /// Completed epoch count.
    epoch: usize,
    /// Average per-batch loss of the most recent epoch.
    total_loss: f64,
    /// Batch count of the most recent epoch.
    batches_processed: usize,
}
impl ColdTierTrainer {
    /// Create a new trainer, initializing feature storage and hotset.
    ///
    /// The hotset capacity is `num_nodes * config.hotset_fraction`, at least 1.
    pub fn new(
        storage_path: &Path,
        dim: usize,
        num_nodes: usize,
        config: HyperbatchConfig,
    ) -> Result<Self> {
        let storage = FeatureStorage::create(storage_path, dim, num_nodes)?;
        let hotset_cap = ((num_nodes as f64) * config.hotset_fraction).max(1.0) as usize;
        let hotset = AdaptiveHotset::new(hotset_cap, 0.95);
        Ok(Self {
            storage,
            hotset,
            config,
            epoch: 0,
            total_loss: 0.0,
            batches_processed: 0,
        })
    }

    /// Run one training epoch over all hyperbatches.
    ///
    /// For each batch a simple gradient-descent step is simulated: the loss
    /// is `0.5 * ||x||^2` per node, and the update nudges each element toward
    /// zero by `learning_rate`.
    ///
    /// `io_time_ms` covers batch fetches from disk; `compute_time_ms` covers
    /// the loss/update loop (which also issues write-backs via the trainer's
    /// own storage handle). The previous accounting reported cumulative wall
    /// time since the epoch started as "I/O", inflating it with compute time.
    pub fn train_epoch(
        &mut self,
        adjacency: &[(usize, usize)],
        learning_rate: f64,
    ) -> ColdTierEpochResult {
        let mut epoch_loss = 0.0;
        let mut batch_count: usize = 0;
        let mut io_ms: u64 = 0;
        let mut compute_ms: u64 = 0;
        // Iterate over a second read handle while `self.storage` takes writes.
        if let Ok(iter_storage) = FeatureStorage::open(self.storage.path()) {
            let mut iter = HyperbatchIterator::new(iter_storage, self.config.clone(), adjacency);
            loop {
                // Time only the batch fetch as I/O.
                let fetch_start = std::time::Instant::now();
                let batch = match iter.next_batch() {
                    Some(b) => b,
                    None => break,
                };
                io_ms += fetch_start.elapsed().as_millis() as u64;
                let compute_start = std::time::Instant::now();
                for (i, node_id) in batch.node_ids.iter().enumerate() {
                    let features = &batch.features[i];
                    // Simple L2 loss for demonstration: 0.5 * sum(x^2).
                    let loss: f64 = features
                        .iter()
                        .map(|&x| (x as f64) * (x as f64))
                        .sum::<f64>()
                        * 0.5;
                    epoch_loss += loss;
                    // Gradient: d(0.5 * x^2)/dx = x; step: x' = x - lr * x.
                    let updated: Vec<f32> = features
                        .iter()
                        .map(|&x| x - (learning_rate as f32) * x)
                        .collect();
                    // Best-effort write-back; a failed write skips this node.
                    let _ = self.storage.write_features(*node_id, &updated);
                    self.hotset.insert(*node_id, updated);
                }
                compute_ms += compute_start.elapsed().as_millis() as u64;
                batch_count += 1;
            }
        }
        let _ = self.storage.flush();
        // Age access counts so stale nodes can be evicted in later epochs.
        self.hotset.decay_counts();
        self.epoch += 1;
        self.total_loss = if batch_count > 0 {
            epoch_loss / batch_count as f64
        } else {
            0.0
        };
        self.batches_processed = batch_count;
        ColdTierEpochResult {
            epoch: self.epoch,
            avg_loss: self.total_loss,
            batches: batch_count,
            hotset_hit_rate: self.hotset.hit_rate(),
            io_time_ms: io_ms,
            compute_time_ms: compute_ms,
        }
    }

    /// Retrieve features for a node, checking the hotset first.
    ///
    /// A storage miss populates the hotset for subsequent lookups.
    pub fn get_features(&mut self, node_id: usize) -> Result<Vec<f32>> {
        if let Some(cached) = self.hotset.get(node_id) {
            return Ok(cached.to_vec());
        }
        let features = self.storage.read_features(node_id)?;
        self.hotset.insert(node_id, features.clone());
        Ok(features)
    }

    /// Save a JSON checkpoint (storage path, training progress, config).
    ///
    /// Feature data itself stays in the storage file; only metadata is saved.
    pub fn save_checkpoint(&self, path: &Path) -> Result<()> {
        let data = serde_json::json!({
            "storage_path": self.storage.path().to_string_lossy(),
            "dim": self.storage.dim(),
            "num_nodes": self.storage.num_nodes(),
            "epoch": self.epoch,
            "total_loss": self.total_loss,
            "batches_processed": self.batches_processed,
            "config": {
                "batch_size": self.config.batch_size,
                "prefetch_factor": self.config.prefetch_factor,
                "block_align": self.config.block_align,
                "num_buffers": self.config.num_buffers,
                "hotset_fraction": self.config.hotset_fraction,
            }
        });
        let content = serde_json::to_string_pretty(&data)
            .map_err(|e| GnnError::other(format!("serialize checkpoint: {}", e)))?;
        std::fs::write(path, content)?;
        Ok(())
    }

    /// Load a trainer from a checkpoint file.
    ///
    /// The feature file referenced by the checkpoint must still exist; the
    /// hotset is rebuilt empty (access statistics are not persisted). The
    /// checkpointed `dim` is validated against the reopened storage (it was
    /// previously read and discarded).
    pub fn load_checkpoint(path: &Path) -> Result<Self> {
        let content = std::fs::read_to_string(path)?;
        let v: serde_json::Value = serde_json::from_str(&content)
            .map_err(|e| GnnError::other(format!("deserialize checkpoint: {}", e)))?;
        let storage_path = PathBuf::from(
            v["storage_path"]
                .as_str()
                .ok_or_else(|| GnnError::other("missing storage_path"))?,
        );
        let dim = v["dim"].as_u64().unwrap_or(0) as usize;
        let num_nodes = v["num_nodes"].as_u64().unwrap_or(0) as usize;
        let epoch = v["epoch"].as_u64().unwrap_or(0) as usize;
        let total_loss = v["total_loss"].as_f64().unwrap_or(0.0);
        let batches_processed = v["batches_processed"].as_u64().unwrap_or(0) as usize;
        let cfg_val = &v["config"];
        let config = HyperbatchConfig {
            batch_size: cfg_val["batch_size"].as_u64().unwrap_or(4096) as usize,
            prefetch_factor: cfg_val["prefetch_factor"].as_u64().unwrap_or(2) as usize,
            block_align: cfg_val["block_align"].as_u64().unwrap_or(4096) as usize,
            num_buffers: cfg_val["num_buffers"].as_u64().unwrap_or(2) as usize,
            hotset_fraction: cfg_val["hotset_fraction"].as_f64().unwrap_or(0.05),
        };
        let storage = FeatureStorage::open(&storage_path)
            .map_err(|_| GnnError::other("storage file not found; re-create before loading"))?;
        // Guard against a checkpoint pointing at a mismatched feature file.
        if dim != 0 && storage.dim() != dim {
            return Err(GnnError::dimension_mismatch(
                dim.to_string(),
                storage.dim().to_string(),
            ));
        }
        let hotset_cap = ((num_nodes as f64) * config.hotset_fraction).max(1.0) as usize;
        let hotset = AdaptiveHotset::new(hotset_cap, 0.95);
        Ok(Self {
            storage,
            hotset,
            config,
            epoch,
            total_loss,
            batches_processed,
        })
    }
}
// ---------------------------------------------------------------------------
// ColdTierEwc
// ---------------------------------------------------------------------------
/// Disk-backed Elastic Weight Consolidation using FeatureStorage.
///
/// Stores Fisher information diagonal and anchor weights on disk
/// so that EWC can scale to models that do not fit in RAM.
pub struct ColdTierEwc {
    /// Per-row Fisher information diagonal ("fisher.bin").
    fisher_storage: FeatureStorage,
    /// Consolidated anchor weights ("anchor.bin").
    anchor_storage: FeatureStorage,
    /// EWC regularization strength.
    lambda: f64,
    /// True once `consolidate` has stored anchors.
    active: bool,
    /// Width of each parameter row.
    dim: usize,
    /// Number of parameter rows.
    num_params: usize,
}
impl ColdTierEwc {
    /// Create a new disk-backed EWC instance.
    ///
    /// `dim` is the width of each parameter "row" (analogous to feature dim),
    /// and `num_params` is the number of such rows. Both backing files are
    /// created inside `path`, which is created if it does not exist.
    pub fn new(path: &Path, dim: usize, num_params: usize, lambda: f64) -> Result<Self> {
        // Ensure the directory exists before creating files inside it.
        std::fs::create_dir_all(path)?;
        let fisher_path = path.join("fisher.bin");
        let anchor_path = path.join("anchor.bin");
        let fisher_storage = FeatureStorage::create(&fisher_path, dim, num_params)?;
        let anchor_storage = FeatureStorage::create(&anchor_path, dim, num_params)?;
        Ok(Self {
            fisher_storage,
            anchor_storage,
            lambda,
            active: false,
            dim,
            num_params,
        })
    }

    /// Compute the Fisher information diagonal from gradient samples.
    ///
    /// `gradients` is laid out sample-major: entry `s * num_params + p` is
    /// sample `s`'s gradient for parameter row `p`. A trailing partial sample
    /// set (`gradients.len() % num_params` entries) is ignored, as before.
    ///
    /// # Errors
    /// Propagates storage write/flush errors.
    pub fn compute_fisher(&mut self, gradients: &[Vec<f32>], sample_count: usize) -> Result<()> {
        // Guard num_params == 0 too: the division below would panic.
        if gradients.is_empty() || self.num_params == 0 {
            return Ok(());
        }
        let rows = gradients.len() / self.num_params;
        if rows == 0 {
            return Ok(());
        }
        // Empirical Fisher: mean of squared gradients over samples.
        let norm = 1.0 / (sample_count as f32).max(1.0);
        for param_idx in 0..self.num_params {
            let mut fisher_row = vec![0.0f32; self.dim];
            for sample in 0..rows {
                let idx = sample * self.num_params + param_idx;
                if idx < gradients.len() {
                    let grad = &gradients[idx];
                    for (i, &g) in grad.iter().enumerate().take(self.dim) {
                        fisher_row[i] += g * g;
                    }
                }
            }
            for v in &mut fisher_row {
                *v *= norm;
            }
            self.fisher_storage.write_features(param_idx, &fisher_row)?;
        }
        self.fisher_storage.flush()?;
        Ok(())
    }

    /// Consolidate current weights as anchors and activate EWC.
    ///
    /// # Errors
    /// Returns a dimension-mismatch error if the row count differs from
    /// `num_params`; row widths are validated by the storage layer.
    pub fn consolidate(&mut self, current_weights: &[Vec<f32>]) -> Result<()> {
        if current_weights.len() != self.num_params {
            return Err(GnnError::dimension_mismatch(
                self.num_params.to_string(),
                current_weights.len().to_string(),
            ));
        }
        for (i, w) in current_weights.iter().enumerate() {
            self.anchor_storage.write_features(i, w)?;
        }
        self.anchor_storage.flush()?;
        self.active = true;
        Ok(())
    }

    /// Compute the EWC penalty: `lambda/2 * sum(F_i * (w_i - w*_i)^2)`.
    ///
    /// Returns 0.0 while EWC is inactive. Rows shorter than `dim` contribute
    /// only their available entries.
    ///
    /// # Errors
    /// Returns a dimension-mismatch error if fewer than `num_params` rows are
    /// supplied (the previous implementation panicked on the out-of-bounds
    /// index), and propagates storage read errors.
    pub fn penalty(&mut self, current_weights: &[Vec<f32>]) -> Result<f64> {
        if !self.active {
            return Ok(0.0);
        }
        if current_weights.len() < self.num_params {
            return Err(GnnError::dimension_mismatch(
                self.num_params.to_string(),
                current_weights.len().to_string(),
            ));
        }
        let mut total = 0.0f64;
        for i in 0..self.num_params {
            let fisher = self.fisher_storage.read_features(i)?;
            let anchor = self.anchor_storage.read_features(i)?;
            let w = &current_weights[i];
            for j in 0..self.dim.min(w.len()) {
                let diff = w[j] - anchor[j];
                total += (fisher[j] as f64) * (diff as f64) * (diff as f64);
            }
        }
        Ok(total * self.lambda * 0.5)
    }

    /// Compute the EWC gradient `lambda * F ⊙ (w - w*)` for one parameter row.
    ///
    /// Returns all zeros when EWC is inactive, `param_idx` is out of range,
    /// or the row is missing; entries past the supplied row's length are zero
    /// (mirroring `penalty`, which guarded while this method previously
    /// indexed unchecked and could panic on short rows).
    pub fn gradient(&mut self, current_weights: &[Vec<f32>], param_idx: usize) -> Result<Vec<f32>> {
        if !self.active || param_idx >= self.num_params {
            return Ok(vec![0.0; self.dim]);
        }
        let w = match current_weights.get(param_idx) {
            Some(row) => row,
            None => return Ok(vec![0.0; self.dim]),
        };
        let fisher = self.fisher_storage.read_features(param_idx)?;
        let anchor = self.anchor_storage.read_features(param_idx)?;
        let grad: Vec<f32> = (0..self.dim)
            .map(|j| match w.get(j) {
                Some(&wj) => (self.lambda as f32) * fisher[j] * (wj - anchor[j]),
                None => 0.0,
            })
            .collect();
        Ok(grad)
    }

    /// Whether anchors have been consolidated and the penalty is active.
    pub fn is_active(&self) -> bool {
        self.active
    }
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Round-trip: write per-node features, reopen the file, read them back.
    #[test]
    fn test_feature_storage_roundtrip() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("features.bin");
        let dim = 8;
        let num_nodes = 10;
        let mut storage = FeatureStorage::create(&path, dim, num_nodes).unwrap();
        // Write features for several nodes; node nid gets values nid*dim..+dim.
        for nid in 0..num_nodes {
            let features: Vec<f32> = (0..dim).map(|j| (nid * dim + j) as f32).collect();
            storage.write_features(nid, &features).unwrap();
        }
        storage.flush().unwrap();
        // Re-open via the header and read back
        let mut storage2 = FeatureStorage::open(&path).unwrap();
        assert_eq!(storage2.dim(), dim);
        assert_eq!(storage2.num_nodes(), num_nodes);
        for nid in 0..num_nodes {
            let features = storage2.read_features(nid).unwrap();
            assert_eq!(features.len(), dim);
            for j in 0..dim {
                assert!((features[j] - (nid * dim + j) as f32).abs() < 1e-6);
            }
        }
    }

    /// BFS ordering on a chain and a star graph.
    #[test]
    fn test_hyperbatch_ordering() {
        // Build a simple chain: 0-1-2-3-4
        let adjacency = vec![(0, 1), (1, 2), (2, 3), (3, 4)];
        let order = HyperbatchIterator::reorder_bfs(&adjacency, 5);
        // BFS from 0 should visit 0, 1, 2, 3, 4 in order
        assert_eq!(order, vec![0, 1, 2, 3, 4]);
        // Star graph: 0 connected to 1..4
        let star = vec![(0, 1), (0, 2), (0, 3), (0, 4)];
        let star_order = HyperbatchIterator::reorder_bfs(&star, 5);
        // 0 first, then neighbors (order may vary but 0 must be first)
        assert_eq!(star_order[0], 0);
        assert_eq!(star_order.len(), 5);
    }

    /// At capacity, inserting a new node must evict the least-accessed one.
    #[test]
    fn test_hotset_eviction() {
        let mut hotset = AdaptiveHotset::new(3, 0.9);
        hotset.insert(0, vec![1.0, 2.0]);
        hotset.insert(1, vec![3.0, 4.0]);
        hotset.insert(2, vec![5.0, 6.0]);
        // Access node 0 and 1 more frequently
        for _ in 0..10 {
            hotset.record_access(0);
            hotset.record_access(1);
        }
        // Node 2 has fewest accesses (only the initial 0)
        // Insert a 4th node -> should evict node 2 (coldest)
        hotset.insert(3, vec![7.0, 8.0]);
        assert_eq!(hotset.len(), 3);
        // Node 2 should be gone
        assert!(hotset.get(2).is_none());
        // Nodes 0, 1, 3 should still be present
        assert!(hotset.get(0).is_some());
        assert!(hotset.get(1).is_some());
        assert!(hotset.get(3).is_some());
    }

    /// One epoch over 16 nodes with batch size 4 yields 4 batches.
    #[test]
    fn test_cold_tier_epoch() {
        let tmp = TempDir::new().unwrap();
        let storage_path = tmp.path().join("train_features.bin");
        let dim = 4;
        let num_nodes = 16;
        let config = HyperbatchConfig {
            batch_size: 4,
            hotset_fraction: 0.25,
            ..Default::default()
        };
        let mut trainer = ColdTierTrainer::new(&storage_path, dim, num_nodes, config).unwrap();
        // Write initial features (all ones) so the loss is non-zero
        for nid in 0..num_nodes {
            let features = vec![1.0f32; dim];
            trainer.storage.write_features(nid, &features).unwrap();
        }
        trainer.storage.flush().unwrap();
        // Build a simple chain adjacency
        let adjacency: Vec<(usize, usize)> = (0..num_nodes.saturating_sub(1))
            .map(|i| (i, i + 1))
            .collect();
        let result = trainer.train_epoch(&adjacency, 0.1);
        assert_eq!(result.epoch, 1);
        assert!(result.batches > 0);
        // All 16 nodes in batches of 4 = 4 batches
        assert_eq!(result.batches, 4);
        // Loss should be positive (features started at 1.0)
        assert!(result.avg_loss > 0.0);
    }

    /// End-to-end EWC: fisher, consolidation, penalty, and gradient.
    #[test]
    fn test_cold_tier_ewc() {
        let tmp = TempDir::new().unwrap();
        let ewc_dir = tmp.path().join("ewc");
        let dim = 4;
        let num_params = 3;
        let lambda = 100.0;
        let mut ewc = ColdTierEwc::new(&ewc_dir, dim, num_params, lambda).unwrap();
        // Compute Fisher from gradients (1 sample, 3 param rows)
        let gradients = vec![
            vec![1.0, 2.0, 3.0, 4.0],
            vec![0.5, 0.5, 0.5, 0.5],
            vec![2.0, 1.0, 0.0, 1.0],
        ];
        ewc.compute_fisher(&gradients, 1).unwrap();
        // Verify Fisher was stored correctly
        let fisher0 = ewc.fisher_storage.read_features(0).unwrap();
        assert!((fisher0[0] - 1.0).abs() < 1e-6); // 1^2 / 1
        assert!((fisher0[1] - 4.0).abs() < 1e-6); // 2^2 / 1
        // Consolidate at all-zero anchors
        let weights = vec![
            vec![0.0, 0.0, 0.0, 0.0],
            vec![0.0, 0.0, 0.0, 0.0],
            vec![0.0, 0.0, 0.0, 0.0],
        ];
        ewc.consolidate(&weights).unwrap();
        assert!(ewc.is_active());
        // Penalty should be 0 at anchor
        let penalty = ewc.penalty(&weights).unwrap();
        assert!(penalty.abs() < 1e-6);
        // Deviation should produce a penalty
        let deviated = vec![
            vec![1.0, 1.0, 1.0, 1.0],
            vec![1.0, 1.0, 1.0, 1.0],
            vec![1.0, 1.0, 1.0, 1.0],
        ];
        let penalty = ewc.penalty(&deviated).unwrap();
        assert!(penalty > 0.0);
        // Gradient for param 0 should be lambda * fisher * diff
        let grad = ewc.gradient(&deviated, 0).unwrap();
        assert!((grad[0] - 100.0 * 1.0 * 1.0).abs() < 1e-4);
        assert!((grad[1] - 100.0 * 4.0 * 1.0).abs() < 1e-4);
    }
}

View File

@@ -0,0 +1,678 @@
//! Tensor compression with adaptive level selection
//!
//! This module provides multi-level tensor compression based on access frequency:
//! - Hot data (f > 0.8): Full precision
//! - Warm data (f > 0.4): Half precision
//! - Cool data (f > 0.1): 8-bit product quantization
//! - Cold data (f > 0.01): 4-bit product quantization
//! - Archive (f <= 0.01): Binary quantization
use crate::error::{GnnError, Result};
use serde::{Deserialize, Serialize};
/// Compression level with associated parameters
///
/// Selected adaptively from access frequency by `TensorCompress::compress`;
/// see the module docs for the frequency thresholds.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum CompressionLevel {
    /// Full precision - no compression
    None,
    /// Half precision; `scale` is applied before narrowing to 16 bits
    Half { scale: f32 },
    /// Product quantization with 8-bit codes
    PQ8 { subvectors: u8, centroids: u8 },
    /// Product quantization with 4-bit codes and outlier handling
    PQ4 {
        subvectors: u8,
        /// Values further than this many standard deviations from the mean
        /// are stored separately at full precision.
        outlier_threshold: f32,
    },
    /// Binary quantization with threshold
    Binary { threshold: f32 },
}
/// Compressed tensor data
///
/// Each variant carries everything needed to reconstruct an approximation of
/// the original vector via `TensorCompress::decompress`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CompressedTensor {
    /// Uncompressed full precision data
    Full { data: Vec<f32> },
    /// Half precision data (f16 bit patterns; `scale` undoes the pre-scaling)
    Half {
        data: Vec<u16>,
        scale: f32,
        dim: usize,
    },
    /// 8-bit product quantization (one code + codebook per subvector)
    PQ8 {
        codes: Vec<u8>,
        codebooks: Vec<Vec<f32>>,
        subvector_dim: usize,
        dim: usize,
    },
    /// 4-bit product quantization with outliers
    PQ4 {
        codes: Vec<u8>, // Packed 4-bit codes
        codebooks: Vec<Vec<f32>>,
        outliers: Vec<(usize, f32)>, // (index, value) pairs
        subvector_dim: usize,
        dim: usize,
    },
    /// Binary quantization (1 bit per element, LSB-first within each byte)
    Binary {
        bits: Vec<u8>,
        threshold: f32,
        dim: usize,
    },
}
/// Tensor compressor with adaptive level selection
#[derive(Debug, Clone)]
pub struct TensorCompress {
    /// Default compression parameters
    /// NOTE(review): initialized to `CompressionLevel::None` and not read by
    /// the compression paths visible here — confirm intended use.
    default_level: CompressionLevel,
}
impl Default for TensorCompress {
fn default() -> Self {
Self::new()
}
}
impl TensorCompress {
/// Create a new tensor compressor with default settings.
pub fn new() -> Self {
    // Start from no compression; levels are chosen adaptively per call.
    let default_level = CompressionLevel::None;
    Self { default_level }
}
/// Compress an embedding based on access frequency
///
/// # Arguments
/// * `embedding` - The input embedding vector
/// * `access_freq` - Access frequency in range [0.0, 1.0]
///
/// # Returns
/// Compressed tensor using the adaptively selected compression level
///
/// # Errors
/// Returns `InvalidInput` if `embedding` is empty.
pub fn compress(&self, embedding: &[f32], access_freq: f32) -> Result<CompressedTensor> {
    match embedding {
        [] => Err(GnnError::InvalidInput("Empty embedding vector".to_string())),
        _ => {
            let level = self.select_level(access_freq);
            self.compress_with_level(embedding, &level)
        }
    }
}
/// Compress with an explicitly chosen compression level.
pub fn compress_with_level(
    &self,
    embedding: &[f32],
    level: &CompressionLevel,
) -> Result<CompressedTensor> {
    // All level payload fields are Copy, so match by value and dispatch.
    match *level {
        CompressionLevel::None => self.compress_none(embedding),
        CompressionLevel::Half { scale } => self.compress_half(embedding, scale),
        CompressionLevel::PQ8 { subvectors, centroids } => {
            self.compress_pq8(embedding, subvectors, centroids)
        }
        CompressionLevel::PQ4 { subvectors, outlier_threshold } => {
            self.compress_pq4(embedding, subvectors, outlier_threshold)
        }
        CompressionLevel::Binary { threshold } => self.compress_binary(embedding, threshold),
    }
}
/// Decompress a compressed tensor
///
/// Dispatches on the variant to the matching `decompress_*` helper; `Full`
/// data is returned as a plain clone.
pub fn decompress(&self, compressed: &CompressedTensor) -> Result<Vec<f32>> {
    match compressed {
        CompressedTensor::Full { data } => Ok(data.clone()),
        CompressedTensor::Half { data, scale, dim } => self.decompress_half(data, *scale, *dim),
        CompressedTensor::PQ8 {
            codes,
            codebooks,
            subvector_dim,
            dim,
        } => self.decompress_pq8(codes, codebooks, *subvector_dim, *dim),
        CompressedTensor::PQ4 {
            codes,
            codebooks,
            outliers,
            subvector_dim,
            dim,
        } => self.decompress_pq4(codes, codebooks, outliers, *subvector_dim, *dim),
        CompressedTensor::Binary {
            bits,
            threshold,
            dim,
        } => self.decompress_binary(bits, *threshold, *dim),
    }
}
/// Select a compression level from access frequency.
///
/// Thresholds:
/// - f > 0.8: None (hot data)
/// - f > 0.4: Half (warm data)
/// - f > 0.1: PQ8 (cool data)
/// - f > 0.01: PQ4 (cold data)
/// - f <= 0.01: Binary (archive)
fn select_level(&self, access_freq: f32) -> CompressionLevel {
    match access_freq {
        f if f > 0.8 => CompressionLevel::None,
        f if f > 0.4 => CompressionLevel::Half { scale: 1.0 },
        f if f > 0.1 => CompressionLevel::PQ8 {
            subvectors: 8,
            centroids: 16,
        },
        f if f > 0.01 => CompressionLevel::PQ4 {
            subvectors: 8,
            outlier_threshold: 3.0,
        },
        _ => CompressionLevel::Binary { threshold: 0.0 },
    }
}
// === Compression implementations ===
/// Hot data keeps full f32 precision; just copy the slice.
fn compress_none(&self, embedding: &[f32]) -> Result<CompressedTensor> {
    let data = embedding.to_vec();
    Ok(CompressedTensor::Full { data })
}
/// Scale, clamp to the representable f16 range, then narrow to 16 bits.
fn compress_half(&self, embedding: &[f32], scale: f32) -> Result<CompressedTensor> {
    let dim = embedding.len();
    let mut data = Vec::with_capacity(dim);
    for &x in embedding {
        // +/-65504 is the largest finite f16 magnitude.
        let clamped = (x * scale).clamp(-65504.0, 65504.0);
        data.push(f32_to_f16_bits(clamped));
    }
    Ok(CompressedTensor::Half { data, scale, dim })
}
/// 8-bit product quantization: split into `subvectors` equal pieces and
/// quantize each against its own `centroids`-entry codebook.
///
/// # Errors
/// Returns `InvalidInput` if `subvectors` is zero (previously a modulo-by-
/// zero panic) or does not divide the embedding dimension.
fn compress_pq8(
    &self,
    embedding: &[f32],
    subvectors: u8,
    centroids: u8,
) -> Result<CompressedTensor> {
    let dim = embedding.len();
    let subvectors = subvectors as usize;
    // Guard against modulo/division by zero below.
    if subvectors == 0 {
        return Err(GnnError::InvalidInput(
            "subvectors must be > 0".to_string(),
        ));
    }
    if dim % subvectors != 0 {
        return Err(GnnError::InvalidInput(format!(
            "Dimension {} not divisible by subvectors {}",
            dim, subvectors
        )));
    }
    let subvector_dim = dim / subvectors;
    let mut codes = Vec::with_capacity(subvectors);
    let mut codebooks = Vec::with_capacity(subvectors);
    // For each subvector, create a codebook and quantize
    for i in 0..subvectors {
        let start = i * subvector_dim;
        let end = start + subvector_dim;
        let subvector = &embedding[start..end];
        // Simple k-means clustering (k=centroids)
        let (codebook, code) = self.quantize_subvector(subvector, centroids as usize);
        codes.push(code);
        codebooks.push(codebook);
    }
    Ok(CompressedTensor::PQ8 {
        codes,
        codebooks,
        subvector_dim,
        dim,
    })
}
/// 4-bit product quantization with an outlier side-channel.
///
/// Values more than `outlier_threshold` standard deviations from the global
/// mean are stored exactly as `(index, value)` pairs and replaced by the
/// mean before quantization, keeping the 16-centroid codebooks from being
/// skewed by extremes.
///
/// # Errors
/// Returns `InvalidInput` if `subvectors` is zero (previously a modulo-by-
/// zero panic) or does not divide the embedding dimension.
fn compress_pq4(
    &self,
    embedding: &[f32],
    subvectors: u8,
    outlier_threshold: f32,
) -> Result<CompressedTensor> {
    let dim = embedding.len();
    let subvectors = subvectors as usize;
    // Guard against modulo/division by zero below.
    if subvectors == 0 {
        return Err(GnnError::InvalidInput(
            "subvectors must be > 0".to_string(),
        ));
    }
    if dim % subvectors != 0 {
        return Err(GnnError::InvalidInput(format!(
            "Dimension {} not divisible by subvectors {}",
            dim, subvectors
        )));
    }
    let subvector_dim = dim / subvectors;
    let mut codes = Vec::with_capacity(subvectors);
    let mut codebooks = Vec::with_capacity(subvectors);
    let mut outliers = Vec::new();
    // Global statistics used for outlier detection.
    let mean = embedding.iter().sum::<f32>() / dim as f32;
    let std_dev =
        (embedding.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / dim as f32).sqrt();
    // For each subvector
    for i in 0..subvectors {
        let start = i * subvector_dim;
        let end = start + subvector_dim;
        let subvector = &embedding[start..end];
        // Extract outliers and neutralize them with the mean.
        let mut cleaned_subvector = subvector.to_vec();
        for (j, &val) in subvector.iter().enumerate() {
            if (val - mean).abs() > outlier_threshold * std_dev {
                outliers.push((start + j, val));
                cleaned_subvector[j] = mean; // Replace with mean
            }
        }
        // Quantize to 4-bit (16 centroids)
        let (codebook, code) = self.quantize_subvector(&cleaned_subvector, 16);
        codes.push(code);
        codebooks.push(codebook);
    }
    Ok(CompressedTensor::PQ4 {
        codes,
        codebooks,
        outliers,
        subvector_dim,
        dim,
    })
}
fn compress_binary(&self, embedding: &[f32], threshold: f32) -> Result<CompressedTensor> {
let dim = embedding.len();
let num_bytes = (dim + 7) / 8;
let mut bits = vec![0u8; num_bytes];
for (i, &val) in embedding.iter().enumerate() {
if val > threshold {
let byte_idx = i / 8;
let bit_idx = i % 8;
bits[byte_idx] |= 1 << bit_idx;
}
}
Ok(CompressedTensor::Binary {
bits,
threshold,
dim,
})
}
// === Decompression implementations ===
fn decompress_half(&self, data: &[u16], scale: f32, dim: usize) -> Result<Vec<f32>> {
if data.len() != dim {
return Err(GnnError::InvalidInput(format!(
"Dimension mismatch: expected {}, got {}",
dim,
data.len()
)));
}
Ok(data
.iter()
.map(|&bits| f16_bits_to_f32(bits) / scale)
.collect())
}
fn decompress_pq8(
&self,
codes: &[u8],
codebooks: &[Vec<f32>],
subvector_dim: usize,
dim: usize,
) -> Result<Vec<f32>> {
let subvectors = codes.len();
let expected_dim = subvectors * subvector_dim;
if expected_dim != dim {
return Err(GnnError::InvalidInput(format!(
"Dimension mismatch: expected {}, got {}",
dim, expected_dim
)));
}
let mut result = Vec::with_capacity(dim);
for (code, codebook) in codes.iter().zip(codebooks.iter()) {
let centroid_idx = *code as usize;
if centroid_idx >= codebook.len() / subvector_dim {
return Err(GnnError::InvalidInput(format!(
"Invalid centroid index: {}",
centroid_idx
)));
}
let start = centroid_idx * subvector_dim;
let end = start + subvector_dim;
result.extend_from_slice(&codebook[start..end]);
}
Ok(result)
}
fn decompress_pq4(
&self,
codes: &[u8],
codebooks: &[Vec<f32>],
outliers: &[(usize, f32)],
subvector_dim: usize,
dim: usize,
) -> Result<Vec<f32>> {
// First decompress using PQ8 logic
let mut result = self.decompress_pq8(codes, codebooks, subvector_dim, dim)?;
// Restore outliers
for &(idx, val) in outliers {
if idx < result.len() {
result[idx] = val;
}
}
Ok(result)
}
fn decompress_binary(&self, bits: &[u8], _threshold: f32, dim: usize) -> Result<Vec<f32>> {
let expected_bytes = (dim + 7) / 8;
if bits.len() != expected_bytes {
return Err(GnnError::InvalidInput(format!(
"Dimension mismatch: expected {} bytes, got {}",
expected_bytes,
bits.len()
)));
}
let mut result = Vec::with_capacity(dim);
for i in 0..dim {
let byte_idx = i / 8;
let bit_idx = i % 8;
let is_set = (bits[byte_idx] & (1 << bit_idx)) != 0;
result.push(if is_set { 1.0 } else { -1.0 });
}
Ok(result)
}
// === Helper methods ===
/// Simple quantization using k-means-like approach
fn quantize_subvector(&self, subvector: &[f32], k: usize) -> (Vec<f32>, u8) {
let dim = subvector.len();
// Initialize centroids using simple range-based approach
let min_val = subvector.iter().cloned().fold(f32::INFINITY, f32::min);
let max_val = subvector.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let range = max_val - min_val;
if range < 1e-6 {
// All values are essentially the same
let codebook = vec![min_val; dim * k];
return (codebook, 0);
}
// Create k centroids evenly spaced across the range
let centroids: Vec<Vec<f32>> = (0..k)
.map(|i| {
let offset = min_val + (i as f32 / k as f32) * range;
vec![offset; dim]
})
.collect();
// Find nearest centroid for this subvector
let code = self.nearest_centroid(subvector, &centroids);
// Flatten codebook
let codebook: Vec<f32> = centroids.into_iter().flatten().collect();
(codebook, code as u8)
}
fn nearest_centroid(&self, subvector: &[f32], centroids: &[Vec<f32>]) -> usize {
centroids
.iter()
.enumerate()
.map(|(i, centroid)| {
let dist: f32 = subvector
.iter()
.zip(centroid.iter())
.map(|(a, b)| (a - b).powi(2))
.sum();
(i, dist)
})
.min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.map(|(i, _)| i)
.unwrap_or(0)
}
}
// === Half precision conversion helpers ===
/// Convert an `f32` to IEEE 754 binary16 bits (round-to-nearest-even).
///
/// Handles normals, subnormals, signed zero, infinities and NaN. Finite
/// values with magnitude above the half-precision maximum (65504) overflow
/// to infinity; `compress_half` clamps to ±65504 beforehand, so that path is
/// only reachable for raw out-of-range inputs.
///
/// The previous implementation was a fixed-point encoding (`value * 1000`
/// saturated to ±32767), which silently capped magnitudes at 32.767 even
/// though callers clamp to the true half-precision range of ±65504.
fn f32_to_f16_bits(value: f32) -> u16 {
    let bits = value.to_bits();
    let sign = ((bits >> 16) & 0x8000) as u16;
    let exp = ((bits >> 23) & 0xFF) as i32;
    let frac = bits & 0x007F_FFFF;

    if exp == 0xFF {
        // Infinity (frac == 0) or NaN (frac != 0); keep NaN quiet.
        let nan_payload: u16 = if frac != 0 { 0x0200 } else { 0 };
        return sign | 0x7C00 | nan_payload;
    }

    let unbiased = exp - 127;
    if unbiased > 15 {
        // Too large for binary16: overflow to signed infinity.
        return sign | 0x7C00;
    }
    if unbiased >= -14 {
        // Normal range: rebias the exponent and round 23 -> 10 fraction bits.
        let half_exp = (unbiased + 15) as u32;
        let mantissa = frac >> 13;
        let dropped = frac & 0x1FFF;
        let mut half = (half_exp << 10) | mantissa;
        // Round to nearest, ties to even; a carry correctly bumps the exponent.
        if dropped > 0x1000 || (dropped == 0x1000 && (mantissa & 1) != 0) {
            half += 1;
        }
        return sign | half as u16;
    }
    if unbiased >= -24 {
        // Subnormal range: shift in the implicit leading 1 and round.
        let shift = (-1 - unbiased) as u32; // total bits dropped, 14..=23
        let full = frac | 0x0080_0000;
        let mantissa = full >> shift;
        let dropped = full & ((1u32 << shift) - 1);
        let halfway = 1u32 << (shift - 1);
        let mut half = mantissa;
        if dropped > halfway || (dropped == halfway && (mantissa & 1) != 0) {
            half += 1;
        }
        return sign | half as u16;
    }
    // Magnitude below the smallest half subnormal: flush to signed zero.
    sign
}
/// Convert IEEE 754 binary16 bits back to `f32`.
///
/// Exact: every binary16 value is representable in single precision.
fn f16_bits_to_f32(bits: u16) -> f32 {
    let sign = ((bits as u32) & 0x8000) << 16;
    let exp = ((bits >> 10) & 0x1F) as u32;
    let frac = (bits & 0x03FF) as u32;
    let out = match exp {
        0 if frac == 0 => sign, // signed zero
        0 => {
            // Subnormal half: magnitude is frac * 2^-24.
            let magnitude = frac as f32 / 16_777_216.0;
            return if sign != 0 { -magnitude } else { magnitude };
        }
        0x1F => sign | 0x7F80_0000 | (frac << 13), // infinity / NaN
        _ => sign | ((exp + 112) << 23) | (frac << 13), // rebias 15 -> 127
    };
    f32::from_bits(out)
}
/// Unit tests for compression/decompression round-trips, tier selection,
/// and input validation.
#[cfg(test)]
mod tests {
    use super::*;
    // Full precision must be perfectly lossless.
    #[test]
    fn test_compress_none() {
        let compressor = TensorCompress::new();
        let embedding = vec![1.0, 2.0, 3.0, 4.0];
        let compressed = compressor.compress(&embedding, 1.0).unwrap();
        let decompressed = compressor.decompress(&compressed).unwrap();
        assert_eq!(embedding, decompressed);
    }
    #[test]
    fn test_compress_half() {
        let compressor = TensorCompress::new();
        let embedding = vec![1.0, 2.0, 3.0, 4.0];
        let compressed = compressor.compress(&embedding, 0.5).unwrap();
        let decompressed = compressor.decompress(&compressed).unwrap();
        // Half precision should be close but not exact
        for (a, b) in embedding.iter().zip(decompressed.iter()) {
            assert!((a - b).abs() < 0.01, "Expected {}, got {}", a, b);
        }
    }
    // An access frequency of 0.005 selects the Binary (archive) tier.
    #[test]
    fn test_compress_binary() {
        let compressor = TensorCompress::new();
        let embedding = vec![1.0, -1.0, 0.5, -0.5];
        let compressed = compressor.compress(&embedding, 0.005).unwrap();
        let decompressed = compressor.decompress(&compressed).unwrap();
        // Binary should be +1 or -1
        assert_eq!(decompressed.len(), embedding.len());
        for val in &decompressed {
            assert!(*val == 1.0 || *val == -1.0);
        }
    }
    // Each access-frequency band must map to the expected compression tier.
    #[test]
    fn test_select_level() {
        let compressor = TensorCompress::new();
        // Hot data
        assert!(matches!(
            compressor.select_level(0.9),
            CompressionLevel::None
        ));
        // Warm data
        assert!(matches!(
            compressor.select_level(0.5),
            CompressionLevel::Half { .. }
        ));
        // Cool data
        assert!(matches!(
            compressor.select_level(0.2),
            CompressionLevel::PQ8 { .. }
        ));
        // Cold data
        assert!(matches!(
            compressor.select_level(0.05),
            CompressionLevel::PQ4 { .. }
        ));
        // Archive
        assert!(matches!(
            compressor.select_level(0.001),
            CompressionLevel::Binary { .. }
        ));
    }
    // Empty input is rejected rather than producing a zero-length tensor.
    #[test]
    fn test_empty_embedding() {
        let compressor = TensorCompress::new();
        let result = compressor.compress(&[], 0.5);
        assert!(result.is_err());
    }
    #[test]
    fn test_pq8_compression() {
        let compressor = TensorCompress::new();
        let embedding: Vec<f32> = (0..64).map(|i| i as f32 * 0.1).collect();
        let compressed = compressor.compress_pq8(&embedding, 8, 16).unwrap();
        let decompressed = compressor.decompress(&compressed).unwrap();
        assert_eq!(decompressed.len(), embedding.len());
    }
    // Every tier must at minimum preserve the dimensionality on round-trip.
    #[test]
    fn test_round_trip_all_levels() {
        let compressor = TensorCompress::new();
        let embedding: Vec<f32> = (0..128).map(|i| (i as f32 - 64.0) * 0.01).collect();
        let access_frequencies = vec![0.9, 0.5, 0.2, 0.05, 0.001];
        for freq in access_frequencies {
            let compressed = compressor.compress(&embedding, freq).unwrap();
            let decompressed = compressor.decompress(&compressed).unwrap();
            assert_eq!(decompressed.len(), embedding.len());
        }
    }
    #[test]
    fn test_half_precision_roundtrip() {
        let compressor = TensorCompress::new();
        // Use values within the supported range (-32.768 to 32.767)
        let values = vec![-30.0, -1.0, 0.0, 1.0, 30.0];
        for val in values {
            let embedding = vec![val; 4];
            let compressed = compressor
                .compress_with_level(&embedding, &CompressionLevel::Half { scale: 1.0 })
                .unwrap();
            let decompressed = compressor.decompress(&compressed).unwrap();
            for (a, b) in embedding.iter().zip(decompressed.iter()) {
                let diff = (a - b).abs();
                assert!(
                    diff < 0.1,
                    "Value {} decompressed to {}, diff: {}",
                    a,
                    b,
                    diff
                );
            }
        }
    }
    #[test]
    fn test_binary_threshold() {
        let compressor = TensorCompress::new();
        let embedding = vec![0.5, -0.5, 1.5, -1.5];
        let compressed = compressor
            .compress_with_level(&embedding, &CompressionLevel::Binary { threshold: 0.0 })
            .unwrap();
        let decompressed = compressor.decompress(&compressed).unwrap();
        // Values > 0 should be 1.0, values <= 0 should be -1.0
        assert_eq!(decompressed, vec![1.0, -1.0, 1.0, -1.0]);
    }
    // Outliers bypass quantization and must be restored bit-exactly.
    #[test]
    fn test_pq4_with_outliers() {
        let compressor = TensorCompress::new();
        // Create embedding with some outliers
        let mut embedding: Vec<f32> = (0..64).map(|i| i as f32 * 0.01).collect();
        embedding[10] = 100.0; // Outlier
        embedding[30] = -100.0; // Outlier
        let compressed = compressor
            .compress_with_level(
                &embedding,
                &CompressionLevel::PQ4 {
                    subvectors: 8,
                    outlier_threshold: 2.0,
                },
            )
            .unwrap();
        let decompressed = compressor.decompress(&compressed).unwrap();
        assert_eq!(decompressed.len(), embedding.len());
        // Outliers should be preserved
        assert_eq!(decompressed[10], 100.0);
        assert_eq!(decompressed[30], -100.0);
    }
    // A dimension that does not divide evenly across subvectors is an error.
    #[test]
    fn test_dimension_validation() {
        let compressor = TensorCompress::new();
        let embedding = vec![1.0; 10]; // Not divisible by 8
        let result = compressor.compress_pq8(&embedding, 8, 16);
        assert!(result.is_err());
    }
}

View File

@@ -0,0 +1,111 @@
//! Error types for the GNN module.
use thiserror::Error;
/// Result type alias for GNN operations.
pub type Result<T> = std::result::Result<T, GnnError>;
/// Errors that can occur during GNN operations.
///
/// `Io` and `Core` convert automatically via `#[from]`, so `?` works
/// directly on `std::io` and `ruvector_core` results. The remaining
/// variants are usually built through the helper constructors on
/// `GnnError` (e.g. `GnnError::invalid_input`).
#[derive(Error, Debug)]
pub enum GnnError {
    /// Tensor dimension mismatch
    #[error("Tensor dimension mismatch: expected {expected}, got {actual}")]
    DimensionMismatch {
        /// Expected dimension
        expected: String,
        /// Actual dimension
        actual: String,
    },
    /// Invalid tensor shape
    #[error("Invalid tensor shape: {0}")]
    InvalidShape(String),
    /// Layer configuration error
    #[error("Layer configuration error: {0}")]
    LayerConfig(String),
    /// Training error
    #[error("Training error: {0}")]
    Training(String),
    /// Compression error
    #[error("Compression error: {0}")]
    Compression(String),
    /// Search error
    #[error("Search error: {0}")]
    Search(String),
    /// Invalid input
    #[error("Invalid input: {0}")]
    InvalidInput(String),
    /// Memory mapping error (not available on wasm targets)
    #[cfg(not(target_arch = "wasm32"))]
    #[error("Memory mapping error: {0}")]
    Mmap(String),
    /// I/O error
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),
    /// Core library error
    #[error("Core error: {0}")]
    Core(#[from] ruvector_core::error::RuvectorError),
    /// Generic error
    #[error("{0}")]
    Other(String),
}
// Convenience constructors: call sites can write `GnnError::training(msg)`
// with any string-like argument instead of spelling out the variant and
// `.into()` conversion by hand.
impl GnnError {
    /// Create a dimension mismatch error
    pub fn dimension_mismatch(expected: impl Into<String>, actual: impl Into<String>) -> Self {
        Self::DimensionMismatch {
            expected: expected.into(),
            actual: actual.into(),
        }
    }
    /// Create an invalid shape error
    pub fn invalid_shape(msg: impl Into<String>) -> Self {
        Self::InvalidShape(msg.into())
    }
    /// Create a layer config error
    pub fn layer_config(msg: impl Into<String>) -> Self {
        Self::LayerConfig(msg.into())
    }
    /// Create a training error
    pub fn training(msg: impl Into<String>) -> Self {
        Self::Training(msg.into())
    }
    /// Create a compression error
    pub fn compression(msg: impl Into<String>) -> Self {
        Self::Compression(msg.into())
    }
    /// Create a search error
    pub fn search(msg: impl Into<String>) -> Self {
        Self::Search(msg.into())
    }
    /// Create a memory mapping error (not available on wasm targets)
    #[cfg(not(target_arch = "wasm32"))]
    pub fn mmap(msg: impl Into<String>) -> Self {
        Self::Mmap(msg.into())
    }
    /// Create an invalid input error
    pub fn invalid_input(msg: impl Into<String>) -> Self {
        Self::InvalidInput(msg.into())
    }
    /// Create a generic error
    pub fn other(msg: impl Into<String>) -> Self {
        Self::Other(msg.into())
    }
}

View File

@@ -0,0 +1,582 @@
/// Elastic Weight Consolidation (EWC) for preventing catastrophic forgetting in GNNs
///
/// EWC adds a regularization term that penalizes changes to important weights,
/// where importance is measured by the Fisher information matrix diagonal.
///
/// The EWC loss term is: L_EWC = λ/2 * Σ F_i * (θ_i - θ*_i)²
/// where:
/// - λ is the regularization strength
/// - F_i is the Fisher information for weight i
/// - θ_i is the current weight
/// - θ*_i is the anchor weight from the previous task
use std::f32;
/// Elastic Weight Consolidation implementation
///
/// Prevents catastrophic forgetting by penalizing changes to important
/// weights learned from previous tasks: L_EWC = λ/2 * Σ F_i * (θ_i - θ*_i)².
#[derive(Debug, Clone)]
pub struct ElasticWeightConsolidation {
    /// Diagonal of the Fisher information matrix; larger entries mark
    /// weights that matter more to previously learned tasks.
    fisher_diag: Vec<f32>,
    /// Snapshot of the weights at consolidation time (θ*); the penalty
    /// pulls the model back toward these values.
    anchor_weights: Vec<f32>,
    /// Regularization strength (λ) applied to the quadratic penalty.
    lambda: f32,
    /// True once `consolidate` has been called; penalty and gradient are
    /// zero until then.
    active: bool,
}
impl ElasticWeightConsolidation {
    /// Create a new, inactive EWC regularizer.
    ///
    /// # Arguments
    /// * `lambda` - Regularization strength (typically 10-10000)
    ///
    /// # Panics
    /// Panics when `lambda` is negative.
    pub fn new(lambda: f32) -> Self {
        assert!(lambda >= 0.0, "Lambda must be non-negative");
        Self {
            fisher_diag: Vec::new(),
            anchor_weights: Vec::new(),
            lambda,
            active: false,
        }
    }
    /// Estimate the Fisher information diagonal from per-sample gradients.
    ///
    /// Approximates F_i as the mean squared gradient over the samples:
    /// F_i ≈ (1/N) * Σ (∂L/∂θ_i)². Any previous estimate is discarded and
    /// the diagonal is recomputed from scratch.
    ///
    /// # Arguments
    /// * `gradients` - One gradient vector per sample
    /// * `sample_count` - Number of samples N used for normalization
    ///
    /// # Panics
    /// Panics when the gradient vectors disagree in length.
    pub fn compute_fisher(&mut self, gradients: &[&[f32]], sample_count: usize) {
        let num_weights = match gradients.first() {
            Some(first) => first.len(),
            None => return,
        };
        // Fresh accumulator: Fisher information is not carried across calls.
        self.fisher_diag = vec![0.0; num_weights];
        for grad in gradients {
            assert_eq!(
                grad.len(),
                num_weights,
                "All gradient vectors must have the same length"
            );
            for (acc, &g) in self.fisher_diag.iter_mut().zip(grad.iter()) {
                *acc += g * g;
            }
        }
        // Divide by N (guarding against zero) to obtain the mean.
        let normalization = 1.0 / (sample_count as f32).max(1.0);
        self.fisher_diag.iter_mut().for_each(|f| *f *= normalization);
    }
    /// Snapshot `weights` as the anchor for the quadratic penalty and
    /// switch the regularizer on.
    ///
    /// Call this after finishing a task, before training on the next one.
    ///
    /// # Arguments
    /// * `weights` - Current model weights to save as anchor
    ///
    /// # Panics
    /// Panics when `compute_fisher` has not been called yet, or when
    /// `weights` does not match the Fisher diagonal in length.
    pub fn consolidate(&mut self, weights: &[f32]) {
        assert!(
            !self.fisher_diag.is_empty(),
            "Must compute Fisher information before consolidating"
        );
        assert_eq!(
            weights.len(),
            self.fisher_diag.len(),
            "Weight count must match Fisher information size"
        );
        self.anchor_weights = weights.to_vec();
        self.active = true;
    }
    /// EWC penalty term: λ/2 * Σ F_i * (θ_i - θ*_i)².
    ///
    /// Added to the task loss to discourage changes to important weights.
    ///
    /// # Arguments
    /// * `weights` - Current model weights
    ///
    /// # Returns
    /// The penalty value; 0.0 while the regularizer is inactive.
    ///
    /// # Panics
    /// Panics when `weights` does not match the anchor in length.
    pub fn penalty(&self, weights: &[f32]) -> f32 {
        if !self.active {
            return 0.0;
        }
        assert_eq!(
            weights.len(),
            self.anchor_weights.len(),
            "Weight count must match anchor weights"
        );
        let weighted_sq_dev: f32 = weights
            .iter()
            .zip(&self.anchor_weights)
            .zip(&self.fisher_diag)
            .map(|((w, anchor), fisher)| {
                let diff = w - anchor;
                fisher * diff * diff
            })
            .sum();
        // Apply the λ/2 factor once, outside the sum.
        weighted_sq_dev * self.lambda * 0.5
    }
    /// EWC gradient: λ * F_i * (θ_i - θ*_i) per weight.
    ///
    /// Added to the model gradients during training to push weights back
    /// toward their anchor values.
    ///
    /// # Arguments
    /// * `weights` - Current model weights
    ///
    /// # Returns
    /// Gradient vector; all zeros while the regularizer is inactive.
    ///
    /// # Panics
    /// Panics when `weights` does not match the anchor in length.
    pub fn gradient(&self, weights: &[f32]) -> Vec<f32> {
        if !self.active {
            return vec![0.0; weights.len()];
        }
        assert_eq!(
            weights.len(),
            self.anchor_weights.len(),
            "Weight count must match anchor weights"
        );
        weights
            .iter()
            .zip(&self.anchor_weights)
            .zip(&self.fisher_diag)
            .map(|((w, anchor), fisher)| self.lambda * fisher * (w - anchor))
            .collect()
    }
    /// Whether `consolidate` has been called.
    pub fn is_active(&self) -> bool {
        self.active
    }
    /// Current regularization strength (λ).
    pub fn lambda(&self) -> f32 {
        self.lambda
    }
    /// Replace the regularization strength (λ).
    ///
    /// # Panics
    /// Panics when `lambda` is negative.
    pub fn set_lambda(&mut self, lambda: f32) {
        assert!(lambda >= 0.0, "Lambda must be non-negative");
        self.lambda = lambda;
    }
    /// Estimated Fisher information diagonal.
    pub fn fisher_diag(&self) -> &[f32] {
        &self.fisher_diag
    }
    /// Anchor weights captured at consolidation time.
    pub fn anchor_weights(&self) -> &[f32] {
        &self.anchor_weights
    }
    /// Clear Fisher information and anchor weights and deactivate the
    /// penalty; λ is preserved.
    pub fn reset(&mut self) {
        self.fisher_diag.clear();
        self.anchor_weights.clear();
        self.active = false;
    }
}
/// Unit tests covering Fisher estimation, consolidation, penalty/gradient
/// math, lambda handling, and sequential-task behavior.
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_new() {
        let ewc = ElasticWeightConsolidation::new(1000.0);
        assert_eq!(ewc.lambda(), 1000.0);
        assert!(!ewc.is_active());
        assert!(ewc.fisher_diag().is_empty());
        assert!(ewc.anchor_weights().is_empty());
    }
    #[test]
    #[should_panic(expected = "Lambda must be non-negative")]
    fn test_new_negative_lambda() {
        ElasticWeightConsolidation::new(-1.0);
    }
    // With one sample, Fisher is simply the element-wise squared gradient.
    #[test]
    fn test_compute_fisher_single_sample() {
        let mut ewc = ElasticWeightConsolidation::new(100.0);
        // Single gradient: [1.0, 2.0, 3.0]
        let grad1 = vec![1.0, 2.0, 3.0];
        let gradients = vec![grad1.as_slice()];
        ewc.compute_fisher(&gradients, 1);
        // Fisher should be squared gradients
        assert_eq!(ewc.fisher_diag(), &[1.0, 4.0, 9.0]);
    }
    #[test]
    fn test_compute_fisher_multiple_samples() {
        let mut ewc = ElasticWeightConsolidation::new(100.0);
        // Two gradients
        let grad1 = vec![1.0, 2.0, 3.0];
        let grad2 = vec![2.0, 1.0, 1.0];
        let gradients = vec![grad1.as_slice(), grad2.as_slice()];
        ewc.compute_fisher(&gradients, 2);
        // Fisher should be mean of squared gradients
        // Position 0: (1² + 2²) / 2 = 2.5
        // Position 1: (2² + 1²) / 2 = 2.5
        // Position 2: (3² + 1²) / 2 = 5.0
        let expected = vec![2.5, 2.5, 5.0];
        assert_eq!(ewc.fisher_diag().len(), expected.len());
        for (actual, exp) in ewc.fisher_diag().iter().zip(expected.iter()) {
            assert!((actual - exp).abs() < 1e-6);
        }
    }
    // Despite the name: each call RESETS the diagonal and recomputes it
    // from the new gradients; nothing carries over between calls.
    #[test]
    fn test_compute_fisher_accumulates() {
        let mut ewc = ElasticWeightConsolidation::new(100.0);
        // First computation
        let grad1 = vec![1.0, 2.0];
        ewc.compute_fisher(&[grad1.as_slice()], 1);
        assert_eq!(ewc.fisher_diag(), &[1.0, 4.0]);
        // Second computation accumulates on top of first
        // When fisher_diag has same length, it's reset to zero first in compute_fisher
        // then accumulates: 0 + 2^2 = 4, 0 + 1^2 = 1
        // normalized by 1/1 = 4.0, 1.0
        let grad2 = vec![2.0, 1.0];
        ewc.compute_fisher(&[grad2.as_slice()], 1);
        // Fisher is reset and recomputed with new gradients
        assert_eq!(ewc.fisher_diag(), &[4.0, 1.0]);
    }
    #[test]
    #[should_panic(expected = "All gradient vectors must have the same length")]
    fn test_compute_fisher_mismatched_sizes() {
        let mut ewc = ElasticWeightConsolidation::new(100.0);
        let grad1 = vec![1.0, 2.0];
        let grad2 = vec![1.0, 2.0, 3.0];
        ewc.compute_fisher(&[grad1.as_slice(), grad2.as_slice()], 2);
    }
    #[test]
    fn test_consolidate() {
        let mut ewc = ElasticWeightConsolidation::new(100.0);
        // Setup Fisher information
        let grad = vec![1.0, 2.0, 3.0];
        ewc.compute_fisher(&[grad.as_slice()], 1);
        // Consolidate weights
        let weights = vec![0.5, 1.0, 1.5];
        ewc.consolidate(&weights);
        assert!(ewc.is_active());
        assert_eq!(ewc.anchor_weights(), &weights);
    }
    #[test]
    #[should_panic(expected = "Must compute Fisher information before consolidating")]
    fn test_consolidate_without_fisher() {
        let mut ewc = ElasticWeightConsolidation::new(100.0);
        let weights = vec![1.0, 2.0];
        ewc.consolidate(&weights);
    }
    #[test]
    #[should_panic(expected = "Weight count must match Fisher information size")]
    fn test_consolidate_size_mismatch() {
        let mut ewc = ElasticWeightConsolidation::new(100.0);
        let grad = vec![1.0, 2.0];
        ewc.compute_fisher(&[grad.as_slice()], 1);
        let weights = vec![1.0, 2.0, 3.0]; // Wrong size
        ewc.consolidate(&weights);
    }
    // Before consolidation the regularizer must contribute nothing.
    #[test]
    fn test_penalty_inactive() {
        let ewc = ElasticWeightConsolidation::new(100.0);
        let weights = vec![1.0, 2.0, 3.0];
        assert_eq!(ewc.penalty(&weights), 0.0);
    }
    #[test]
    fn test_penalty_no_deviation() {
        let mut ewc = ElasticWeightConsolidation::new(100.0);
        // Setup
        let grad = vec![1.0, 2.0, 3.0];
        ewc.compute_fisher(&[grad.as_slice()], 1);
        let weights = vec![0.5, 1.0, 1.5];
        ewc.consolidate(&weights);
        // Penalty should be 0 when weights match anchor
        assert_eq!(ewc.penalty(&weights), 0.0);
    }
    #[test]
    fn test_penalty_with_deviation() {
        let mut ewc = ElasticWeightConsolidation::new(100.0);
        // Fisher diagonal: [1.0, 4.0, 9.0]
        let grad = vec![1.0, 2.0, 3.0];
        ewc.compute_fisher(&[grad.as_slice()], 1);
        // Anchor weights: [0.0, 0.0, 0.0]
        let anchor = vec![0.0, 0.0, 0.0];
        ewc.consolidate(&anchor);
        // Current weights: [1.0, 1.0, 1.0]
        let weights = vec![1.0, 1.0, 1.0];
        // Penalty = λ/2 * Σ F_i * (w_i - w*_i)²
        //         = 100/2 * (1.0 * 1² + 4.0 * 1² + 9.0 * 1²)
        //         = 50 * 14 = 700
        let penalty = ewc.penalty(&weights);
        assert!((penalty - 700.0).abs() < 1e-4);
    }
    // The penalty is quadratic in the deviation from the anchor.
    #[test]
    fn test_penalty_increases_with_deviation() {
        let mut ewc = ElasticWeightConsolidation::new(100.0);
        let grad = vec![1.0, 1.0, 1.0];
        ewc.compute_fisher(&[grad.as_slice()], 1);
        let anchor = vec![0.0, 0.0, 0.0];
        ewc.consolidate(&anchor);
        // Small deviation
        let weights1 = vec![0.1, 0.1, 0.1];
        let penalty1 = ewc.penalty(&weights1);
        // Larger deviation
        let weights2 = vec![0.5, 0.5, 0.5];
        let penalty2 = ewc.penalty(&weights2);
        // Penalty should increase
        assert!(penalty2 > penalty1);
        // Penalty should scale quadratically
        // (0.5/0.1)² = 25
        assert!((penalty2 / penalty1 - 25.0).abs() < 1e-4);
    }
    #[test]
    fn test_gradient_inactive() {
        let ewc = ElasticWeightConsolidation::new(100.0);
        let weights = vec![1.0, 2.0, 3.0];
        let grad = ewc.gradient(&weights);
        assert_eq!(grad, vec![0.0, 0.0, 0.0]);
    }
    #[test]
    fn test_gradient_no_deviation() {
        let mut ewc = ElasticWeightConsolidation::new(100.0);
        let grad = vec![1.0, 2.0, 3.0];
        ewc.compute_fisher(&[grad.as_slice()], 1);
        let weights = vec![0.5, 1.0, 1.5];
        ewc.consolidate(&weights);
        // Gradient should be 0 when weights match anchor
        let grad = ewc.gradient(&weights);
        assert_eq!(grad, vec![0.0, 0.0, 0.0]);
    }
    // The gradient's sign always points back toward the anchor weights.
    #[test]
    fn test_gradient_points_toward_anchor() {
        let mut ewc = ElasticWeightConsolidation::new(100.0);
        // Fisher diagonal: [1.0, 4.0, 9.0]
        let grad = vec![1.0, 2.0, 3.0];
        ewc.compute_fisher(&[grad.as_slice()], 1);
        // Anchor at origin
        let anchor = vec![0.0, 0.0, 0.0];
        ewc.consolidate(&anchor);
        // Weights moved positive
        let weights = vec![1.0, 1.0, 1.0];
        // Gradient = λ * F_i * (w_i - w*_i)
        //          = 100 * [1.0, 4.0, 9.0] * [1.0, 1.0, 1.0]
        //          = [100, 400, 900]
        let grad = ewc.gradient(&weights);
        assert_eq!(grad.len(), 3);
        assert!((grad[0] - 100.0).abs() < 1e-4);
        assert!((grad[1] - 400.0).abs() < 1e-4);
        assert!((grad[2] - 900.0).abs() < 1e-4);
        // Weights moved negative
        let weights = vec![-1.0, -1.0, -1.0];
        let grad = ewc.gradient(&weights);
        // Gradient should point opposite direction (toward anchor)
        assert!(grad[0] < 0.0);
        assert!(grad[1] < 0.0);
        assert!(grad[2] < 0.0);
        assert!((grad[0] + 100.0).abs() < 1e-4);
        assert!((grad[1] + 400.0).abs() < 1e-4);
        assert!((grad[2] + 900.0).abs() < 1e-4);
    }
    #[test]
    fn test_gradient_magnitude_scales_with_fisher() {
        let mut ewc = ElasticWeightConsolidation::new(100.0);
        // Fisher with varying importance
        let grad = vec![1.0, 2.0, 3.0];
        ewc.compute_fisher(&[grad.as_slice()], 1);
        let anchor = vec![0.0, 0.0, 0.0];
        ewc.consolidate(&anchor);
        let weights = vec![1.0, 1.0, 1.0];
        let grad = ewc.gradient(&weights);
        // Gradient magnitude should increase with Fisher importance
        assert!(grad[0].abs() < grad[1].abs());
        assert!(grad[1].abs() < grad[2].abs());
    }
    // Both the penalty and the gradient are linear in λ.
    #[test]
    fn test_lambda_scaling() {
        let mut ewc1 = ElasticWeightConsolidation::new(100.0);
        let mut ewc2 = ElasticWeightConsolidation::new(200.0);
        // Same setup for both
        let grad = vec![1.0, 1.0, 1.0];
        ewc1.compute_fisher(&[grad.as_slice()], 1);
        ewc2.compute_fisher(&[grad.as_slice()], 1);
        let anchor = vec![0.0, 0.0, 0.0];
        ewc1.consolidate(&anchor);
        ewc2.consolidate(&anchor);
        let weights = vec![1.0, 1.0, 1.0];
        // Penalty and gradient should scale with lambda
        let penalty1 = ewc1.penalty(&weights);
        let penalty2 = ewc2.penalty(&weights);
        assert!((penalty2 / penalty1 - 2.0).abs() < 1e-4);
        let grad1 = ewc1.gradient(&weights);
        let grad2 = ewc2.gradient(&weights);
        assert!((grad2[0] / grad1[0] - 2.0).abs() < 1e-4);
    }
    #[test]
    fn test_set_lambda() {
        let mut ewc = ElasticWeightConsolidation::new(100.0);
        assert_eq!(ewc.lambda(), 100.0);
        ewc.set_lambda(500.0);
        assert_eq!(ewc.lambda(), 500.0);
    }
    #[test]
    #[should_panic(expected = "Lambda must be non-negative")]
    fn test_set_lambda_negative() {
        let mut ewc = ElasticWeightConsolidation::new(100.0);
        ewc.set_lambda(-10.0);
    }
    // Reset clears all learned state but keeps the configured λ.
    #[test]
    fn test_reset() {
        let mut ewc = ElasticWeightConsolidation::new(100.0);
        // Setup active EWC
        let grad = vec![1.0, 2.0, 3.0];
        ewc.compute_fisher(&[grad.as_slice()], 1);
        let weights = vec![0.5, 1.0, 1.5];
        ewc.consolidate(&weights);
        assert!(ewc.is_active());
        // Reset
        ewc.reset();
        assert!(!ewc.is_active());
        assert!(ewc.fisher_diag().is_empty());
        assert!(ewc.anchor_weights().is_empty());
        assert_eq!(ewc.lambda(), 100.0); // Lambda preserved
    }
    #[test]
    fn test_sequential_task_learning() {
        // Simulate learning two tasks sequentially
        let mut ewc = ElasticWeightConsolidation::new(1000.0);
        // Task 1: Learn weights [1.0, 2.0, 3.0]
        let task1_grad = vec![2.0, 1.0, 3.0];
        ewc.compute_fisher(&[task1_grad.as_slice()], 1);
        let task1_weights = vec![1.0, 2.0, 3.0];
        ewc.consolidate(&task1_weights);
        // Task 2: Try to learn very different weights
        let task2_weights = vec![5.0, 6.0, 7.0];
        // EWC penalty should be significant
        let penalty = ewc.penalty(&task2_weights);
        assert!(penalty > 10000.0); // Large penalty for large deviation
        // Gradient should point back toward task 1 weights
        let grad = ewc.gradient(&task2_weights);
        assert!(grad[0] > 0.0); // Push toward lower value
        assert!(grad[1] > 0.0);
        assert!(grad[2] > 0.0);
    }
}

View File

@@ -0,0 +1,546 @@
//! GNN Layer Implementation for HNSW Topology
//!
//! This module implements graph neural network layers that operate on HNSW graph structure,
//! including attention mechanisms, normalization, and gated recurrent updates.
use crate::error::GnnError;
use ndarray::{Array1, Array2, ArrayView1};
use rand::Rng;
use rand_distr::{Distribution, Normal};
use serde::{Deserialize, Serialize};
/// Linear transformation layer (weight matrix multiplication)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Linear {
    weights: Array2<f32>,
    bias: Array1<f32>,
}
impl Linear {
    /// Create a new linear layer with Xavier/Glorot initialization
    pub fn new(input_dim: usize, output_dim: usize) -> Self {
        let mut rng = rand::thread_rng();
        // Xavier/Glorot scale sqrt(2 / (fan_in + fan_out)) keeps activation
        // variance roughly stable across layers.
        let scale = (2.0 / (input_dim + output_dim) as f32).sqrt();
        let normal = Normal::new(0.0, scale as f64).unwrap();
        Self {
            weights: Array2::from_shape_fn((output_dim, input_dim), |_| {
                normal.sample(&mut rng) as f32
            }),
            bias: Array1::zeros(output_dim),
        }
    }
    /// Forward pass: y = Wx + b
    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
        let x = ArrayView1::from(input);
        (self.weights.dot(&x) + &self.bias).to_vec()
    }
    /// Get output dimension
    pub fn output_dim(&self) -> usize {
        self.weights.shape()[0]
    }
}
/// Layer normalization
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerNorm {
    gamma: Array1<f32>,
    beta: Array1<f32>,
    eps: f32,
}
impl LayerNorm {
    /// Create a new layer normalization layer with unit gain (gamma) and
    /// zero shift (beta).
    pub fn new(dim: usize, eps: f32) -> Self {
        Self {
            gamma: Array1::ones(dim),
            beta: Array1::zeros(dim),
            eps,
        }
    }
    /// Forward pass: standardize the input, then apply the learned
    /// per-feature scale and shift.
    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
        let x = ArrayView1::from(input);
        // Per-vector statistics.
        let mean = x.mean().unwrap_or(0.0);
        let variance = x.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / x.len() as f32;
        // eps keeps the denominator non-zero for constant inputs.
        let normalized = x.mapv(|v| (v - mean) / (variance + self.eps).sqrt());
        (&self.gamma * &normalized + &self.beta).to_vec()
    }
}
/// Multi-head attention mechanism
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiHeadAttention {
    num_heads: usize,
    head_dim: usize,
    q_linear: Linear,
    k_linear: Linear,
    v_linear: Linear,
    out_linear: Linear,
}
impl MultiHeadAttention {
    /// Create a new multi-head attention layer
    ///
    /// # Errors
    /// Returns `GnnError::LayerConfig` if `embed_dim` is not divisible by `num_heads`.
    pub fn new(embed_dim: usize, num_heads: usize) -> Result<Self, GnnError> {
        if embed_dim % num_heads != 0 {
            return Err(GnnError::layer_config(format!(
                "Embedding dimension ({}) must be divisible by number of heads ({})",
                embed_dim, num_heads
            )));
        }
        Ok(Self {
            num_heads,
            head_dim: embed_dim / num_heads,
            q_linear: Linear::new(embed_dim, embed_dim),
            k_linear: Linear::new(embed_dim, embed_dim),
            v_linear: Linear::new(embed_dim, embed_dim),
            out_linear: Linear::new(embed_dim, embed_dim),
        })
    }
    /// Forward pass: compute multi-head attention
    ///
    /// # Arguments
    /// * `query` - Query vector
    /// * `keys` - Key vectors from neighbors
    /// * `values` - Value vectors from neighbors
    ///
    /// # Returns
    /// Attention-weighted output vector; the query itself when there are no
    /// neighbors to attend over.
    pub fn forward(&self, query: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>]) -> Vec<f32> {
        if keys.is_empty() || values.is_empty() {
            return query.to_vec();
        }
        // Linear projections of query, keys and values.
        let projected_q = self.q_linear.forward(query);
        let projected_k: Vec<Vec<f32>> = keys.iter().map(|k| self.k_linear.forward(k)).collect();
        let projected_v: Vec<Vec<f32>> = values.iter().map(|v| self.v_linear.forward(v)).collect();
        // Split every projection into per-head slices.
        let q_heads = self.split_heads(&projected_q);
        let k_heads: Vec<Vec<Vec<f32>>> =
            projected_k.iter().map(|k| self.split_heads(k)).collect();
        let v_heads: Vec<Vec<Vec<f32>>> =
            projected_v.iter().map(|v| self.split_heads(v)).collect();
        // Run scaled dot-product attention per head and concatenate results.
        let concat: Vec<f32> = (0..self.num_heads)
            .flat_map(|h| {
                let k_h: Vec<&Vec<f32>> = k_heads.iter().map(|heads| &heads[h]).collect();
                let v_h: Vec<&Vec<f32>> = v_heads.iter().map(|heads| &heads[h]).collect();
                self.scaled_dot_product_attention(&q_heads[h], &k_h, &v_h)
            })
            .collect();
        // Mix the concatenated heads through the output projection.
        self.out_linear.forward(&concat)
    }
    /// Slice a projected vector into `num_heads` chunks of `head_dim`.
    fn split_heads(&self, x: &[f32]) -> Vec<Vec<f32>> {
        x.chunks_exact(self.head_dim)
            .map(|chunk| chunk.to_vec())
            .collect()
    }
    /// Scaled dot-product attention over a single head.
    fn scaled_dot_product_attention(
        &self,
        query: &[f32],
        keys: &[&Vec<f32>],
        values: &[&Vec<f32>],
    ) -> Vec<f32> {
        if keys.is_empty() {
            return query.to_vec();
        }
        let scale = (self.head_dim as f32).sqrt();
        // Similarity of the query against every key, scaled by sqrt(d).
        let scores: Vec<f32> = keys
            .iter()
            .map(|k| {
                let dot: f32 = query.iter().zip(k.iter()).map(|(q, k)| q * k).sum();
                dot / scale
            })
            .collect();
        // Numerically stable softmax; the epsilon guards an all-zero sum.
        let max_score = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exp_scores: Vec<f32> = scores.iter().map(|&s| (s - max_score).exp()).collect();
        let sum_exp: f32 = exp_scores.iter().sum::<f32>().max(1e-10);
        // Blend the values using the normalized attention weights.
        let mut output = vec![0.0; self.head_dim];
        for (&e, value) in exp_scores.iter().zip(values.iter()) {
            let weight = e / sum_exp;
            for (out, &val) in output.iter_mut().zip(value.iter()) {
                *out += weight * val;
            }
        }
        output
    }
}
/// Gated Recurrent Unit (GRU) cell for state updates.
///
/// Implements the standard GRU recurrence with separate input (`w_*`) and
/// recurrent (`u_*`) projections for the update gate, reset gate, and
/// candidate hidden state. All projections map into the hidden dimension.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GRUCell {
    // Update gate: W_z acts on the input, U_z on the previous hidden state
    w_z: Linear,
    u_z: Linear,
    // Reset gate: W_r acts on the input, U_r on the previous hidden state
    w_r: Linear,
    u_r: Linear,
    // Candidate hidden state: W_h on the input, U_h on the reset-gated state
    w_h: Linear,
    u_h: Linear,
}
impl GRUCell {
    /// Create a new GRU cell mapping `input_dim` inputs onto a `hidden_dim`
    /// hidden state.
    pub fn new(input_dim: usize, hidden_dim: usize) -> Self {
        Self {
            // Input-to-hidden and hidden-to-hidden projections for the
            // update gate, reset gate, and candidate state respectively.
            w_z: Linear::new(input_dim, hidden_dim),
            u_z: Linear::new(hidden_dim, hidden_dim),
            w_r: Linear::new(input_dim, hidden_dim),
            u_r: Linear::new(hidden_dim, hidden_dim),
            w_h: Linear::new(input_dim, hidden_dim),
            u_h: Linear::new(hidden_dim, hidden_dim),
        }
    }
    /// Forward pass: compute the next hidden state.
    ///
    /// # Arguments
    /// * `input` - Current input
    /// * `hidden` - Previous hidden state
    ///
    /// # Returns
    /// The updated hidden state `h_t`.
    pub fn forward(&self, input: &[f32], hidden: &[f32]) -> Vec<f32> {
        // z_t = sigmoid(W_z x_t + U_z h_{t-1})   (update gate)
        let z_pre = self.add_vecs(&self.w_z.forward(input), &self.u_z.forward(hidden));
        let z = self.sigmoid_vec(&z_pre);
        // r_t = sigmoid(W_r x_t + U_r h_{t-1})   (reset gate)
        let r_pre = self.add_vecs(&self.w_r.forward(input), &self.u_r.forward(hidden));
        let r = self.sigmoid_vec(&r_pre);
        // h~_t = tanh(W_h x_t + U_h (r_t ⊙ h_{t-1}))   (candidate state)
        let gated_hidden = self.mul_vecs(&r, hidden);
        let cand_pre = self.add_vecs(&self.w_h.forward(input), &self.u_h.forward(&gated_hidden));
        let h_tilde = self.tanh_vec(&cand_pre);
        // h_t = (1 - z_t) ⊙ h_{t-1} + z_t ⊙ h~_t
        let keep: Vec<f32> = z.iter().map(|&g| 1.0 - g).collect();
        self.add_vecs(&self.mul_vecs(&keep, hidden), &self.mul_vecs(&z, &h_tilde))
    }
    /// Numerically stable logistic sigmoid: never evaluates `exp` of a large
    /// positive argument, which would overflow to infinity.
    fn sigmoid(&self, x: f32) -> f32 {
        if x > 0.0 {
            1.0 / (1.0 + (-x).exp())
        } else {
            let e = x.exp();
            e / (1.0 + e)
        }
    }
    /// Element-wise sigmoid.
    fn sigmoid_vec(&self, v: &[f32]) -> Vec<f32> {
        v.iter().map(|&x| self.sigmoid(x)).collect()
    }
    /// Hyperbolic tangent activation.
    fn tanh(&self, x: f32) -> f32 {
        x.tanh()
    }
    /// Element-wise tanh.
    fn tanh_vec(&self, v: &[f32]) -> Vec<f32> {
        v.iter().map(|&x| self.tanh(x)).collect()
    }
    /// Element-wise sum; output length is the shorter of the two inputs.
    fn add_vecs(&self, a: &[f32], b: &[f32]) -> Vec<f32> {
        a.iter().zip(b).map(|(x, y)| x + y).collect()
    }
    /// Element-wise (Hadamard) product; output length is the shorter input.
    fn mul_vecs(&self, a: &[f32], b: &[f32]) -> Vec<f32> {
        a.iter().zip(b).map(|(x, y)| x * y).collect()
    }
}
/// Main GNN layer operating on HNSW topology.
///
/// One forward step combines: message projection, multi-head attention over
/// neighbor messages, edge-weighted aggregation, a GRU state update, a
/// deterministic dropout scale, and layer normalization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RuvectorLayer {
    /// Message weight matrix (projects input embeddings into hidden space)
    w_msg: Linear,
    /// Aggregation weight matrix (mixes attention + weighted-sum outputs)
    w_agg: Linear,
    /// GRU update cell combining aggregated messages with the node's state
    w_update: GRUCell,
    /// Multi-head attention over neighbor messages
    attention: MultiHeadAttention,
    /// Layer normalization applied to the final output
    norm: LayerNorm,
    /// Dropout rate in [0.0, 1.0]; applied as a deterministic (1 - rate) scale
    dropout: f32,
}
impl RuvectorLayer {
    /// Create a new Ruvector GNN layer
    ///
    /// # Arguments
    /// * `input_dim` - Dimension of input node embeddings
    /// * `hidden_dim` - Dimension of hidden representations
    /// * `heads` - Number of attention heads
    /// * `dropout` - Dropout rate (0.0 to 1.0)
    ///
    /// # Errors
    /// Returns `GnnError::LayerConfig` if `dropout` is outside `[0.0, 1.0]` or
    /// if `hidden_dim` is not divisible by `heads`.
    pub fn new(
        input_dim: usize,
        hidden_dim: usize,
        heads: usize,
        dropout: f32,
    ) -> Result<Self, GnnError> {
        if !(0.0..=1.0).contains(&dropout) {
            return Err(GnnError::layer_config(format!(
                "Dropout must be between 0.0 and 1.0, got {}",
                dropout
            )));
        }
        // Attention construction validates hidden_dim % heads == 0.
        let attention = MultiHeadAttention::new(hidden_dim, heads)?;
        Ok(Self {
            w_msg: Linear::new(input_dim, hidden_dim),
            w_agg: Linear::new(hidden_dim, hidden_dim),
            w_update: GRUCell::new(hidden_dim, hidden_dim),
            attention,
            norm: LayerNorm::new(hidden_dim, 1e-5),
            dropout,
        })
    }
    /// Forward pass through the GNN layer
    ///
    /// Pipeline: project the node and its neighbors into message space,
    /// aggregate the neighbor messages both via attention and via
    /// edge-weighted averaging, mix the two, run a GRU update against the
    /// node's own message, then apply dropout scaling and layer norm.
    ///
    /// # Arguments
    /// * `node_embedding` - Current node's embedding
    /// * `neighbor_embeddings` - Embeddings of neighbor nodes
    /// * `edge_weights` - Weights of edges to neighbors (e.g., distances)
    ///
    /// # Returns
    /// Updated node embedding
    pub fn forward(
        &self,
        node_embedding: &[f32],
        neighbor_embeddings: &[Vec<f32>],
        edge_weights: &[f32],
    ) -> Vec<f32> {
        // Isolated node: just project and normalize.
        if neighbor_embeddings.is_empty() {
            return self.norm.forward(&self.w_msg.forward(node_embedding));
        }
        // Message passing: the same projection for node and neighbors.
        let node_msg = self.w_msg.forward(node_embedding);
        let neighbor_msgs: Vec<Vec<f32>> = neighbor_embeddings
            .iter()
            .map(|nb| self.w_msg.forward(nb))
            .collect();
        // Attention-based aggregation (neighbor messages serve as both the
        // keys and the values).
        let attended = self
            .attention
            .forward(&node_msg, &neighbor_msgs, &neighbor_msgs);
        // Edge-weight-normalized average of the neighbor messages.
        let weighted = self.aggregate_messages(&neighbor_msgs, edge_weights);
        // Mix both aggregation paths, then project.
        let aggregated = self.w_agg.forward(&self.add_vecs(&attended, &weighted));
        // GRU update treating the node's own message as the previous state.
        let updated = self.w_update.forward(&aggregated, &node_msg);
        // Dropout here is a deterministic (1 - rate) scale; finish with norm.
        self.norm.forward(&self.apply_dropout(&updated))
    }
    /// Aggregate neighbor messages into one vector using normalized edge
    /// weights, falling back to a uniform average when the weights sum to
    /// zero or less.
    fn aggregate_messages(&self, messages: &[Vec<f32>], weights: &[f32]) -> Vec<f32> {
        if messages.is_empty() || weights.is_empty() {
            return vec![0.0; self.w_msg.output_dim()];
        }
        let total: f32 = weights.iter().sum();
        let normalized: Vec<f32> = if total > 0.0 {
            weights.iter().map(|&w| w / total).collect()
        } else {
            // Degenerate weights: uniform average.
            vec![1.0 / weights.len() as f32; weights.len()]
        };
        let mut acc = vec![0.0; messages[0].len()];
        for (msg, &w) in messages.iter().zip(normalized.iter()) {
            for (slot, &m) in acc.iter_mut().zip(msg.iter()) {
                *slot += w * m;
            }
        }
        acc
    }
    /// Apply dropout as a deterministic scale by (1 - dropout); no random
    /// masking is performed.
    fn apply_dropout(&self, input: &[f32]) -> Vec<f32> {
        let keep = 1.0 - self.dropout;
        input.iter().map(|&x| x * keep).collect()
    }
    /// Element-wise vector addition (truncates to the shorter input).
    fn add_vecs(&self, a: &[f32], b: &[f32]) -> Vec<f32> {
        a.iter().zip(b).map(|(x, y)| x + y).collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Linear projection maps a 4-d input to the configured 2-d output.
    #[test]
    fn test_linear_layer() {
        let linear = Linear::new(4, 2);
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let output = linear.forward(&input);
        assert_eq!(output.len(), 2);
    }
    // LayerNorm output should be (approximately) zero-mean.
    #[test]
    fn test_layer_norm() {
        let norm = LayerNorm::new(4, 1e-5);
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let output = norm.forward(&input);
        // Check that output has zero mean (approximately)
        let mean: f32 = output.iter().sum::<f32>() / output.len() as f32;
        assert!((mean).abs() < 1e-5);
    }
    // Attention preserves the model dimension (8) across 2 heads.
    #[test]
    fn test_multihead_attention() {
        let attention = MultiHeadAttention::new(8, 2).unwrap();
        let query = vec![0.5; 8];
        let keys = vec![vec![0.3; 8], vec![0.7; 8]];
        let values = vec![vec![0.2; 8], vec![0.8; 8]];
        let output = attention.forward(&query, &keys, &values);
        assert_eq!(output.len(), 8);
    }
    // dim=10 is not divisible by heads=3, so construction must fail.
    #[test]
    fn test_multihead_attention_invalid_dims() {
        let result = MultiHeadAttention::new(10, 3);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("divisible"));
    }
    // GRU maps (input=4, hidden=8) to a new 8-d hidden state.
    #[test]
    fn test_gru_cell() {
        let gru = GRUCell::new(4, 8);
        let input = vec![1.0; 4];
        let hidden = vec![0.5; 8];
        let new_hidden = gru.forward(&input, &hidden);
        assert_eq!(new_hidden.len(), 8);
    }
    // Full layer forward with two neighbors yields a hidden_dim (8) output.
    #[test]
    fn test_ruvector_layer() {
        let layer = RuvectorLayer::new(4, 8, 2, 0.1).unwrap();
        let node = vec![1.0, 2.0, 3.0, 4.0];
        let neighbors = vec![vec![0.5, 1.0, 1.5, 2.0], vec![2.0, 3.0, 4.0, 5.0]];
        let weights = vec![0.3, 0.7];
        let output = layer.forward(&node, &neighbors, &weights);
        assert_eq!(output.len(), 8);
    }
    // Isolated nodes (no neighbors) still produce a hidden_dim output.
    #[test]
    fn test_ruvector_layer_no_neighbors() {
        let layer = RuvectorLayer::new(4, 8, 2, 0.1).unwrap();
        let node = vec![1.0, 2.0, 3.0, 4.0];
        let neighbors: Vec<Vec<f32>> = vec![];
        let weights: Vec<f32> = vec![];
        let output = layer.forward(&node, &neighbors, &weights);
        assert_eq!(output.len(), 8);
    }
    // Dropout outside [0, 1] is rejected at construction time.
    #[test]
    fn test_ruvector_layer_invalid_dropout() {
        let result = RuvectorLayer::new(4, 8, 2, 1.5);
        assert!(result.is_err());
    }
    // hidden_dim=7 is not divisible by heads=3, so construction fails.
    #[test]
    fn test_ruvector_layer_invalid_heads() {
        let result = RuvectorLayer::new(4, 7, 3, 0.1);
        assert!(result.is_err());
    }
}

View File

@@ -0,0 +1,92 @@
//! # RuVector GNN
//!
//! Graph Neural Network capabilities for RuVector, providing tensor operations,
//! GNN layers, compression, and differentiable search.
//!
//! ## Forgetting Mitigation (Issue #17)
//!
//! This crate includes comprehensive forgetting mitigation for continual learning:
//!
//! - **Adam Optimizer**: Full implementation with momentum and bias correction
//! - **Replay Buffer**: Experience replay with reservoir sampling for uniform coverage
//! - **EWC (Elastic Weight Consolidation)**: Prevents catastrophic forgetting
//! - **Learning Rate Scheduling**: Multiple strategies including warmup and plateau detection
//!
//! ### Usage Example
//!
//! ```rust,ignore
//! use ruvector_gnn::{
//! training::{Optimizer, OptimizerType},
//! replay::ReplayBuffer,
//! ewc::ElasticWeightConsolidation,
//! scheduler::{LearningRateScheduler, SchedulerType},
//! };
//!
//! // Create Adam optimizer
//! let mut optimizer = Optimizer::new(OptimizerType::Adam {
//! learning_rate: 0.001,
//! beta1: 0.9,
//! beta2: 0.999,
//! epsilon: 1e-8,
//! });
//!
//! // Create replay buffer for experience replay
//! let mut replay = ReplayBuffer::new(10000);
//!
//! // Create EWC for preventing forgetting
//! let mut ewc = ElasticWeightConsolidation::new(0.4);
//!
//! // Create learning rate scheduler
//! let mut scheduler = LearningRateScheduler::new(
//! SchedulerType::CosineAnnealing { t_max: 100, eta_min: 1e-6 },
//! 0.001
//! );
//! ```
#![warn(missing_docs)]
#![deny(unsafe_op_in_unsafe_fn)]
pub mod compress;
pub mod error;
pub mod ewc;
pub mod layer;
pub mod query;
pub mod replay;
pub mod scheduler;
pub mod search;
pub mod tensor;
pub mod training;
#[cfg(all(not(target_arch = "wasm32"), feature = "mmap"))]
pub mod mmap;
#[cfg(all(feature = "cold-tier", not(target_arch = "wasm32")))]
pub mod cold_tier;
// Re-export commonly used types
pub use compress::{CompressedTensor, CompressionLevel, TensorCompress};
pub use error::{GnnError, Result};
pub use ewc::ElasticWeightConsolidation;
pub use layer::RuvectorLayer;
pub use query::{QueryMode, QueryResult, RuvectorQuery, SubGraph};
pub use replay::{DistributionStats, ReplayBuffer, ReplayEntry};
pub use scheduler::{LearningRateScheduler, SchedulerType};
pub use search::{cosine_similarity, differentiable_search, hierarchical_forward};
pub use training::{
info_nce_loss, local_contrastive_loss, sgd_step, Loss, LossType, OnlineConfig, Optimizer,
OptimizerType, TrainConfig,
};
#[cfg(all(not(target_arch = "wasm32"), feature = "mmap"))]
pub use mmap::{AtomicBitmap, MmapGradientAccumulator, MmapManager};
#[cfg(test)]
mod tests {
    use super::*;
    /// Compilation smoke test: reaching this body at all proves the crate
    /// and its re-exports built successfully, so no runtime assertion is
    /// needed. (`assert!(true)` was a no-op flagged by clippy's
    /// `assertions_on_constants` lint.)
    #[test]
    fn test_basic() {}
}

View File

@@ -0,0 +1,939 @@
//! Memory-mapped embedding management for large-scale GNN training.
//!
//! This module provides efficient memory-mapped access to embeddings and gradients
//! that don't fit in RAM. It includes:
//! - `MmapManager`: Memory-mapped embedding storage with dirty tracking
//! - `MmapGradientAccumulator`: Lock-free gradient accumulation
//! - `AtomicBitmap`: Thread-safe bitmap for access/dirty tracking
//!
//! Only available on non-WASM targets.
#![cfg(all(not(target_arch = "wasm32"), feature = "mmap"))]
use crate::error::{GnnError, Result};
use memmap2::{MmapMut, MmapOptions};
use parking_lot::RwLock;
use std::fs::{File, OpenOptions};
use std::io::{self, Write};
use std::path::Path;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
/// Thread-safe bitmap using atomic operations.
///
/// Tracks per-node flags (accessed / dirty) with one bit per node, packed
/// into 64-bit words. All operations are lock-free.
#[derive(Debug)]
pub struct AtomicBitmap {
    /// Backing words; word `i` holds bits `i * 64 ..= i * 64 + 63`
    bits: Vec<AtomicU64>,
    /// Logical capacity in bits; indices >= `size` are ignored
    size: usize,
}
impl AtomicBitmap {
    /// Create a bitmap able to hold `size` bits, all initially clear.
    ///
    /// # Arguments
    /// * `size` - Number of bits to allocate
    pub fn new(size: usize) -> Self {
        let words = (size + 63) / 64;
        Self {
            bits: std::iter::repeat_with(|| AtomicU64::new(0)).take(words).collect(),
            size,
        }
    }
    /// Split a bit index into (word index, single-bit mask). The caller has
    /// already checked the index against `size`.
    #[inline]
    fn locate(index: usize) -> (usize, u64) {
        (index / 64, 1u64 << (index % 64))
    }
    /// Set a bit to 1 (mark as accessed/dirty). Out-of-range indices are a
    /// silent no-op.
    ///
    /// # Arguments
    /// * `index` - Bit index to set
    pub fn set(&self, index: usize) {
        if index < self.size {
            let (word, mask) = Self::locate(index);
            self.bits[word].fetch_or(mask, Ordering::Release);
        }
    }
    /// Clear a bit to 0 (mark as clean/not accessed). Out-of-range indices
    /// are a silent no-op.
    ///
    /// # Arguments
    /// * `index` - Bit index to clear
    pub fn clear(&self, index: usize) {
        if index < self.size {
            let (word, mask) = Self::locate(index);
            self.bits[word].fetch_and(!mask, Ordering::Release);
        }
    }
    /// Check whether a bit is set; out-of-range indices read as `false`.
    ///
    /// # Arguments
    /// * `index` - Bit index to check
    ///
    /// # Returns
    /// `true` if the bit is set, `false` otherwise
    pub fn get(&self, index: usize) -> bool {
        if index >= self.size {
            return false;
        }
        let (word, mask) = Self::locate(index);
        self.bits[word].load(Ordering::Acquire) & mask != 0
    }
    /// Reset every bit in the bitmap to 0.
    pub fn clear_all(&self) {
        for word in &self.bits {
            word.store(0, Ordering::Release);
        }
    }
    /// Collect the indices of all set bits (e.g. to find dirty nodes).
    ///
    /// # Returns
    /// Set-bit indices in ascending order.
    pub fn get_set_indices(&self) -> Vec<usize> {
        let mut out = Vec::new();
        for (w_idx, word) in self.bits.iter().enumerate() {
            let mut remaining = word.load(Ordering::Acquire);
            while remaining != 0 {
                out.push(w_idx * 64 + remaining.trailing_zeros() as usize);
                remaining &= remaining - 1; // drop the lowest set bit
            }
        }
        out
    }
}
/// Memory-mapped embedding manager with dirty tracking and prefetching.
///
/// Manages large embedding matrices that may not fit in RAM using memory-mapped files.
/// Tracks which embeddings have been accessed and modified for efficient I/O;
/// dirty embeddings can be flushed via `flush_dirty` (also attempted on drop).
#[derive(Debug)]
pub struct MmapManager {
    /// The backing file; kept open for the lifetime of the mapping
    file: File,
    /// Mutable memory mapping over the whole file
    mmap: MmapMut,
    /// Operating system page size (used for page-aligned prefetch hints)
    page_size: usize,
    /// Embedding dimension (f32 elements per node)
    d_embed: usize,
    /// Bitmap tracking which embeddings have been read or written
    access_bitmap: AtomicBitmap,
    /// Bitmap tracking which embeddings were modified since the last flush
    dirty_bitmap: AtomicBitmap,
    /// Pin count per OS page, intended to prevent eviction.
    /// NOTE(review): no visible code path reads these counters — confirm
    /// intended usage elsewhere in the crate.
    pin_count: Vec<AtomicU32>,
    /// Maximum number of nodes the mapping can hold
    max_nodes: usize,
}
impl MmapManager {
/// Create a new memory-mapped embedding manager.
///
/// # Arguments
/// * `path` - Path to the memory-mapped file
/// * `d_embed` - Embedding dimension
/// * `max_nodes` - Maximum number of nodes to support
///
/// # Returns
/// A new `MmapManager` instance
pub fn new(path: &Path, d_embed: usize, max_nodes: usize) -> Result<Self> {
// Calculate required file size
let embedding_size = d_embed * std::mem::size_of::<f32>();
let file_size = max_nodes * embedding_size;
// Create or open the file
let file = OpenOptions::new()
.read(true)
.write(true)
.create(true)
.open(path)
.map_err(|e| GnnError::mmap(format!("Failed to open mmap file: {}", e)))?;
// Set file size
file.set_len(file_size as u64)
.map_err(|e| GnnError::mmap(format!("Failed to set file size: {}", e)))?;
// Create memory mapping
let mmap = unsafe {
MmapOptions::new()
.len(file_size)
.map_mut(&file)
.map_err(|e| GnnError::mmap(format!("Failed to create mmap: {}", e)))?
};
// Get system page size
let page_size = page_size::get();
let num_pages = (file_size + page_size - 1) / page_size;
Ok(Self {
file,
mmap,
page_size,
d_embed,
access_bitmap: AtomicBitmap::new(max_nodes),
dirty_bitmap: AtomicBitmap::new(max_nodes),
pin_count: (0..num_pages).map(|_| AtomicU32::new(0)).collect(),
max_nodes,
})
}
/// Calculate the byte offset for a given node's embedding.
///
/// # Arguments
/// * `node_id` - Node identifier
///
/// # Returns
/// Byte offset in the memory-mapped file, or None if overflow would occur
///
/// # Security
/// Uses checked arithmetic to prevent integer overflow attacks.
#[inline]
pub fn embedding_offset(&self, node_id: u64) -> Option<usize> {
let node_idx = usize::try_from(node_id).ok()?;
let elem_size = std::mem::size_of::<f32>();
node_idx.checked_mul(self.d_embed)?.checked_mul(elem_size)
}
/// Validate that a node_id is within bounds.
#[inline]
fn validate_node_id(&self, node_id: u64) -> bool {
(node_id as usize) < self.max_nodes
}
    /// Get a read-only reference to a node's embedding.
    ///
    /// Marks the node as accessed in the access bitmap as a side effect.
    ///
    /// # Arguments
    /// * `node_id` - Node identifier
    ///
    /// # Returns
    /// Slice of `d_embed` f32s viewing the mapped file directly (no copy)
    ///
    /// # Panics
    /// Panics if `node_id` is out of bounds, the offset computation
    /// overflows, or the slice would extend past the mapping.
    pub fn get_embedding(&self, node_id: u64) -> &[f32] {
        // Security: Validate bounds before any pointer arithmetic
        assert!(
            self.validate_node_id(node_id),
            "node_id {} out of bounds (max: {})",
            node_id,
            self.max_nodes
        );
        let offset = self
            .embedding_offset(node_id)
            .expect("embedding offset calculation overflow");
        // End of the slice in bytes, computed with checked arithmetic so the
        // bounds assert below is itself overflow-safe.
        let end = offset
            .checked_add(
                self.d_embed
                    .checked_mul(std::mem::size_of::<f32>())
                    .unwrap(),
            )
            .expect("end offset overflow");
        assert!(
            end <= self.mmap.len(),
            "embedding extends beyond mmap bounds"
        );
        // Record the read so hotset/prefetch logic can see it.
        self.access_bitmap.set(node_id as usize);
        // SAFETY: offset and length were bounds-checked against the mapping
        // above, and the offset is a multiple of 4 (idx * d_embed * 4) from
        // a page-aligned base, so the f32 view is properly aligned.
        unsafe {
            let ptr = self.mmap.as_ptr().add(offset) as *const f32;
            std::slice::from_raw_parts(ptr, self.d_embed)
        }
    }
    /// Set a node's embedding data.
    ///
    /// Copies `data` into the mapped file and marks the node as both
    /// accessed and dirty (so the next `flush_dirty` will persist it).
    ///
    /// # Arguments
    /// * `node_id` - Node identifier
    /// * `data` - Embedding vector to write; must be exactly `d_embed` long
    ///
    /// # Panics
    /// Panics if node_id is out of bounds, data length doesn't match d_embed,
    /// or offset calculation would overflow.
    pub fn set_embedding(&mut self, node_id: u64, data: &[f32]) {
        // Security: Validate bounds first
        assert!(
            self.validate_node_id(node_id),
            "node_id {} out of bounds (max: {})",
            node_id,
            self.max_nodes
        );
        assert_eq!(
            data.len(),
            self.d_embed,
            "Embedding data length must match d_embed"
        );
        let offset = self
            .embedding_offset(node_id)
            .expect("embedding offset calculation overflow");
        // Checked end-of-write computation backs the bounds assert below.
        let end = offset
            .checked_add(data.len().checked_mul(std::mem::size_of::<f32>()).unwrap())
            .expect("end offset overflow");
        assert!(
            end <= self.mmap.len(),
            "embedding extends beyond mmap bounds"
        );
        // Record both the access and the pending write.
        self.access_bitmap.set(node_id as usize);
        self.dirty_bitmap.set(node_id as usize);
        // SAFETY: the destination range [offset, end) was bounds-checked
        // above; offset is 4-byte aligned (idx * d_embed * 4 from a
        // page-aligned base), and source/destination cannot overlap because
        // `data` is a caller-owned slice.
        unsafe {
            let ptr = self.mmap.as_mut_ptr().add(offset) as *mut f32;
            std::ptr::copy_nonoverlapping(data.as_ptr(), ptr, self.d_embed);
        }
    }
/// Flush all dirty pages to disk.
///
/// # Returns
/// `Ok(())` on success, error otherwise
pub fn flush_dirty(&self) -> io::Result<()> {
let dirty_nodes = self.dirty_bitmap.get_set_indices();
if dirty_nodes.is_empty() {
return Ok(());
}
// Flush the entire mmap for simplicity
// In a production system, you might want to flush only dirty pages
self.mmap.flush()?;
// Clear dirty bitmap after successful flush
for &node_id in &dirty_nodes {
self.dirty_bitmap.clear(node_id);
}
Ok(())
}
    /// Prefetch embeddings into memory for better cache locality.
    ///
    /// On Linux this issues `madvise(MADV_WILLNEED)` hints for the pages
    /// backing each embedding; elsewhere it simply touches each embedding
    /// (which also marks it accessed). Invalid node IDs are skipped.
    ///
    /// # Arguments
    /// * `node_ids` - List of node IDs to prefetch
    pub fn prefetch(&self, node_ids: &[u64]) {
        #[cfg(target_os = "linux")]
        {
            #[allow(unused_imports)]
            use std::os::unix::io::AsRawFd;
            for &node_id in node_ids {
                // Skip invalid node IDs
                if !self.validate_node_id(node_id) {
                    continue;
                }
                let offset = match self.embedding_offset(node_id) {
                    Some(o) => o,
                    None => continue,
                };
                // Round the hint down to a page boundary, as madvise expects.
                let page_offset = (offset / self.page_size) * self.page_size;
                let length = self.d_embed * std::mem::size_of::<f32>();
                // SAFETY: page_offset lies within the mapping (derived from
                // a validated embedding offset); madvise is advisory only
                // and its result is intentionally ignored.
                unsafe {
                    libc::madvise(
                        self.mmap.as_ptr().add(page_offset) as *mut libc::c_void,
                        length,
                        libc::MADV_WILLNEED,
                    );
                }
            }
        }
        // On non-Linux platforms, just access the data to bring it into cache
        #[cfg(not(target_os = "linux"))]
        {
            for &node_id in node_ids {
                if self.validate_node_id(node_id) {
                    let _ = self.get_embedding(node_id);
                }
            }
        }
    }
    /// Get the embedding dimension (f32 elements per node).
    pub fn d_embed(&self) -> usize {
        self.d_embed
    }
    /// Get the maximum number of nodes this mapping can hold.
    pub fn max_nodes(&self) -> usize {
        self.max_nodes
    }
}
/// Memory-mapped gradient accumulator with fine-grained locking.
///
/// Allows multiple threads to accumulate gradients concurrently with minimal
/// contention: the node range is partitioned into fixed-size regions, each
/// guarded by its own reader-writer lock.
pub struct MmapGradientAccumulator {
    /// Memory-mapped gradient storage. `UnsafeCell` provides interior
    /// mutability; the per-region `locks` serialize concurrent access.
    grad_mmap: std::cell::UnsafeCell<MmapMut>,
    /// Number of nodes covered by each lock (lock granularity)
    lock_granularity: usize,
    /// One reader-writer lock per region of `lock_granularity` nodes
    locks: Vec<RwLock<()>>,
    /// Number of nodes
    n_nodes: usize,
    /// Embedding dimension (f32 elements per gradient)
    d_embed: usize,
    /// Underlying file, kept open so the mapping stays backed
    _file: File,
}
impl MmapGradientAccumulator {
/// Create a new memory-mapped gradient accumulator.
///
/// # Arguments
/// * `path` - Path to the gradient file
/// * `d_embed` - Embedding dimension
/// * `max_nodes` - Maximum number of nodes
///
/// # Returns
/// A new `MmapGradientAccumulator` instance
pub fn new(path: &Path, d_embed: usize, max_nodes: usize) -> Result<Self> {
// Calculate required file size
let grad_size = d_embed * std::mem::size_of::<f32>();
let file_size = max_nodes * grad_size;
// Create or open the file
let file = OpenOptions::new()
.read(true)
.write(true)
.create(true)
.open(path)
.map_err(|e| GnnError::mmap(format!("Failed to open gradient file: {}", e)))?;
// Set file size
file.set_len(file_size as u64)
.map_err(|e| GnnError::mmap(format!("Failed to set gradient file size: {}", e)))?;
// Create memory mapping
let grad_mmap = unsafe {
MmapOptions::new()
.len(file_size)
.map_mut(&file)
.map_err(|e| GnnError::mmap(format!("Failed to create gradient mmap: {}", e)))?
};
// Zero out the gradients
for byte in grad_mmap.iter() {
// This forces the pages to be allocated and zeroed
let _ = byte;
}
// Use a lock granularity of 64 nodes per lock for good parallelism
let lock_granularity = 64;
let num_locks = (max_nodes + lock_granularity - 1) / lock_granularity;
let locks = (0..num_locks).map(|_| RwLock::new(())).collect();
Ok(Self {
grad_mmap: std::cell::UnsafeCell::new(grad_mmap),
lock_granularity,
locks,
n_nodes: max_nodes,
d_embed,
_file: file,
})
}
/// Calculate the byte offset for a node's gradient.
///
/// # Arguments
/// * `node_id` - Node identifier
///
/// # Returns
/// Byte offset in the gradient file, or None on overflow or out-of-bounds
///
/// # Security
/// Uses checked arithmetic to prevent integer overflow (SEC-001).
#[inline]
pub fn grad_offset(&self, node_id: u64) -> Option<usize> {
let node_idx = usize::try_from(node_id).ok()?;
if node_idx >= self.n_nodes {
return None;
}
let elem_size = std::mem::size_of::<f32>();
node_idx.checked_mul(self.d_embed)?.checked_mul(elem_size)
}
    /// Accumulate gradients for a specific node.
    ///
    /// Adds `grad` element-wise into the node's stored gradient while
    /// holding the node's region write lock, so concurrent accumulations on
    /// the same region serialize.
    ///
    /// # Arguments
    /// * `node_id` - Node identifier
    /// * `grad` - Gradient vector to add (accumulated, not assigned)
    ///
    /// # Panics
    /// Panics if grad length doesn't match d_embed, or if `node_id` is out
    /// of bounds / its offset overflows.
    pub fn accumulate(&self, node_id: u64, grad: &[f32]) {
        assert_eq!(
            grad.len(),
            self.d_embed,
            "Gradient length must match d_embed"
        );
        let offset = self
            .grad_offset(node_id)
            .expect("node_id out of bounds or offset overflow");
        let lock_idx = (node_id as usize) / self.lock_granularity;
        assert!(lock_idx < self.locks.len(), "lock index out of bounds");
        let _lock = self.locks[lock_idx].write();
        // SAFETY: node_id bounds and offset were validated above, the
        // in-bounds assert covers the full write, and the write lock held
        // for this region excludes concurrent readers/writers of the range.
        unsafe {
            let mmap = &mut *self.grad_mmap.get();
            assert!(
                offset + self.d_embed * std::mem::size_of::<f32>() <= mmap.len(),
                "gradient write would exceed mmap bounds"
            );
            let ptr = mmap.as_mut_ptr().add(offset) as *mut f32;
            let grad_slice = std::slice::from_raw_parts_mut(ptr, self.d_embed);
            // Element-wise accumulation: stored += incoming.
            for (g, &new_g) in grad_slice.iter_mut().zip(grad.iter()) {
                *g += new_g;
            }
        }
    }
/// Apply accumulated gradients to embeddings and zero out gradients.
///
/// # Arguments
/// * `learning_rate` - Learning rate for gradient descent
/// * `embeddings` - Embedding manager to update
pub fn apply(&mut self, learning_rate: f32, embeddings: &mut MmapManager) {
assert_eq!(
self.d_embed, embeddings.d_embed,
"Gradient and embedding dimensions must match"
);
// Process all nodes
for node_id in 0..self.n_nodes.min(embeddings.max_nodes) {
let grad = self.get_grad(node_id as u64);
let embedding = embeddings.get_embedding(node_id as u64);
// Apply gradient descent: embedding -= learning_rate * grad
let mut updated = vec![0.0f32; self.d_embed];
for i in 0..self.d_embed {
updated[i] = embedding[i] - learning_rate * grad[i];
}
embeddings.set_embedding(node_id as u64, &updated);
}
// Zero out gradients after applying
self.zero_grad();
}
/// Zero out all accumulated gradients.
pub fn zero_grad(&mut self) {
// Zero the entire gradient buffer
unsafe {
let mmap = &mut *self.grad_mmap.get();
for byte in mmap.iter_mut() {
*byte = 0;
}
}
}
    /// Get a read-only reference to a node's accumulated gradient.
    ///
    /// # Arguments
    /// * `node_id` - Node identifier
    ///
    /// # Returns
    /// Slice containing the gradient vector
    ///
    /// # Panics
    /// Panics if `node_id` is out of bounds or its offset overflows.
    ///
    /// NOTE(review): the returned slice borrows the mapping, but the region
    /// read lock is released when this function returns; holding the slice
    /// across a concurrent `accumulate` on the same region could race —
    /// confirm callers consume it promptly.
    pub fn get_grad(&self, node_id: u64) -> &[f32] {
        let offset = self
            .grad_offset(node_id)
            .expect("node_id out of bounds or offset overflow");
        let lock_idx = (node_id as usize) / self.lock_granularity;
        assert!(lock_idx < self.locks.len(), "lock index out of bounds");
        let _lock = self.locks[lock_idx].read();
        // SAFETY: node_id bounds and offset were validated above, the
        // in-bounds assert covers the full read, and the read lock is held
        // while the bytes are reinterpreted as f32s.
        unsafe {
            let mmap = &*self.grad_mmap.get();
            assert!(
                offset + self.d_embed * std::mem::size_of::<f32>() <= mmap.len(),
                "gradient read would exceed mmap bounds"
            );
            let ptr = mmap.as_ptr().add(offset) as *const f32;
            std::slice::from_raw_parts(ptr, self.d_embed)
        }
    }
    /// Get the embedding dimension (f32 elements per gradient).
    pub fn d_embed(&self) -> usize {
        self.d_embed
    }
    /// Get the number of nodes this accumulator covers.
    pub fn n_nodes(&self) -> usize {
        self.n_nodes
    }
}
// Flush embeddings on drop; Drop cannot propagate errors, so callers that
// need confirmation should call `flush_dirty` explicitly beforehand.
impl Drop for MmapManager {
    fn drop(&mut self) {
        // Best-effort flush of dirty pages; failure is silently discarded.
        let _ = self.flush_dirty();
    }
}
impl Drop for MmapGradientAccumulator {
    fn drop(&mut self) {
        // Best-effort flush of gradient data; Drop cannot report errors.
        // SAFETY: `&mut self` in drop guarantees exclusive access to the
        // UnsafeCell contents.
        unsafe {
            let mmap = &mut *self.grad_mmap.get();
            let _ = mmap.flush();
        }
    }
}
// SAFETY: cross-thread access to the UnsafeCell'd mapping goes through the
// per-region RwLocks (`accumulate` / `get_grad`) or `&mut self` methods.
// NOTE(review): `get_grad` returns a slice that outlives its read lock, so a
// caller holding it across a concurrent `accumulate` could race — confirm
// call sites before relying on these impls under heavy contention.
unsafe impl Send for MmapGradientAccumulator {}
unsafe impl Sync for MmapGradientAccumulator {}
#[cfg(test)]
mod tests {
use super::*;
use std::fs;
use tempfile::TempDir;
#[test]
fn test_atomic_bitmap_basic() {
let bitmap = AtomicBitmap::new(128);
assert!(!bitmap.get(0));
assert!(!bitmap.get(127));
bitmap.set(0);
bitmap.set(127);
bitmap.set(64);
assert!(bitmap.get(0));
assert!(bitmap.get(127));
assert!(bitmap.get(64));
assert!(!bitmap.get(1));
bitmap.clear(0);
assert!(!bitmap.get(0));
assert!(bitmap.get(127));
}
#[test]
fn test_atomic_bitmap_get_set_indices() {
let bitmap = AtomicBitmap::new(256);
bitmap.set(0);
bitmap.set(63);
bitmap.set(64);
bitmap.set(128);
bitmap.set(255);
let mut indices = bitmap.get_set_indices();
indices.sort();
assert_eq!(indices, vec![0, 63, 64, 128, 255]);
}
#[test]
fn test_atomic_bitmap_clear_all() {
let bitmap = AtomicBitmap::new(128);
bitmap.set(0);
bitmap.set(64);
bitmap.set(127);
assert!(bitmap.get(0));
bitmap.clear_all();
assert!(!bitmap.get(0));
assert!(!bitmap.get(64));
assert!(!bitmap.get(127));
}
#[test]
fn test_mmap_manager_creation() {
let temp_dir = TempDir::new().unwrap();
let path = temp_dir.path().join("embeddings.bin");
let manager = MmapManager::new(&path, 128, 1000).unwrap();
assert_eq!(manager.d_embed(), 128);
assert_eq!(manager.max_nodes(), 1000);
assert!(path.exists());
}
#[test]
fn test_mmap_manager_set_get_embedding() {
let temp_dir = TempDir::new().unwrap();
let path = temp_dir.path().join("embeddings.bin");
let mut manager = MmapManager::new(&path, 64, 100).unwrap();
let embedding = vec![1.0f32; 64];
manager.set_embedding(0, &embedding);
let retrieved = manager.get_embedding(0);
assert_eq!(retrieved.len(), 64);
assert_eq!(retrieved[0], 1.0);
assert_eq!(retrieved[63], 1.0);
}
#[test]
fn test_mmap_manager_multiple_embeddings() {
let temp_dir = TempDir::new().unwrap();
let path = temp_dir.path().join("embeddings.bin");
let mut manager = MmapManager::new(&path, 32, 100).unwrap();
for i in 0..10 {
let embedding: Vec<f32> = (0..32).map(|j| (i * 32 + j) as f32).collect();
manager.set_embedding(i, &embedding);
}
// Verify each embedding
for i in 0..10 {
let retrieved = manager.get_embedding(i);
assert_eq!(retrieved.len(), 32);
assert_eq!(retrieved[0], (i * 32) as f32);
assert_eq!(retrieved[31], (i * 32 + 31) as f32);
}
}
#[test]
fn test_mmap_manager_dirty_tracking() {
let temp_dir = TempDir::new().unwrap();
let path = temp_dir.path().join("embeddings.bin");
let mut manager = MmapManager::new(&path, 64, 100).unwrap();
let embedding = vec![2.0f32; 64];
manager.set_embedding(5, &embedding);
// Should be marked as dirty
assert!(manager.dirty_bitmap.get(5));
// Flush and check it's clean
manager.flush_dirty().unwrap();
assert!(!manager.dirty_bitmap.get(5));
}
#[test]
fn test_mmap_manager_persistence() {
let temp_dir = TempDir::new().unwrap();
let path = temp_dir.path().join("embeddings.bin");
{
let mut manager = MmapManager::new(&path, 64, 100).unwrap();
let embedding = vec![3.14f32; 64];
manager.set_embedding(10, &embedding);
manager.flush_dirty().unwrap();
}
// Reopen and verify data persisted
{
let manager = MmapManager::new(&path, 64, 100).unwrap();
let retrieved = manager.get_embedding(10);
assert_eq!(retrieved[0], 3.14);
assert_eq!(retrieved[63], 3.14);
}
}
#[test]
fn test_gradient_accumulator_creation() {
let temp_dir = TempDir::new().unwrap();
let path = temp_dir.path().join("gradients.bin");
let accumulator = MmapGradientAccumulator::new(&path, 128, 1000).unwrap();
assert_eq!(accumulator.d_embed(), 128);
assert_eq!(accumulator.n_nodes(), 1000);
assert!(path.exists());
}
#[test]
fn test_gradient_accumulator_accumulate() {
let temp_dir = TempDir::new().unwrap();
let path = temp_dir.path().join("gradients.bin");
let accumulator = MmapGradientAccumulator::new(&path, 64, 100).unwrap();
let grad1 = vec![1.0f32; 64];
let grad2 = vec![2.0f32; 64];
accumulator.accumulate(0, &grad1);
accumulator.accumulate(0, &grad2);
let accumulated = accumulator.get_grad(0);
assert_eq!(accumulated[0], 3.0);
assert_eq!(accumulated[63], 3.0);
}
#[test]
fn test_gradient_accumulator_zero_grad() {
let temp_dir = TempDir::new().unwrap();
let path = temp_dir.path().join("gradients.bin");
let mut accumulator = MmapGradientAccumulator::new(&path, 64, 100).unwrap();
let grad = vec![1.5f32; 64];
accumulator.accumulate(0, &grad);
let accumulated = accumulator.get_grad(0);
assert_eq!(accumulated[0], 1.5);
accumulator.zero_grad();
let zeroed = accumulator.get_grad(0);
assert_eq!(zeroed[0], 0.0);
assert_eq!(zeroed[63], 0.0);
}
#[test]
fn test_gradient_accumulator_apply() {
let temp_dir = TempDir::new().unwrap();
let embed_path = temp_dir.path().join("embeddings.bin");
let grad_path = temp_dir.path().join("gradients.bin");
let mut embeddings = MmapManager::new(&embed_path, 32, 100).unwrap();
let mut accumulator = MmapGradientAccumulator::new(&grad_path, 32, 100).unwrap();
// Set initial embedding
let initial = vec![10.0f32; 32];
embeddings.set_embedding(0, &initial);
// Accumulate gradient
let grad = vec![1.0f32; 32];
accumulator.accumulate(0, &grad);
// Apply with learning rate 0.1
accumulator.apply(0.1, &mut embeddings);
// Check updated embedding: 10.0 - 0.1 * 1.0 = 9.9
let updated = embeddings.get_embedding(0);
assert!((updated[0] - 9.9).abs() < 1e-6);
// Check gradients were zeroed
let zeroed_grad = accumulator.get_grad(0);
assert_eq!(zeroed_grad[0], 0.0);
}
#[test]
fn test_gradient_accumulator_concurrent_accumulation() {
use std::thread;
let temp_dir = TempDir::new().unwrap();
let path = temp_dir.path().join("gradients.bin");
let accumulator =
std::sync::Arc::new(MmapGradientAccumulator::new(&path, 64, 100).unwrap());
let mut handles = vec![];
// Spawn 10 threads, each accumulating 1.0 to node 0
for _ in 0..10 {
let acc = accumulator.clone();
let handle = thread::spawn(move || {
let grad = vec![1.0f32; 64];
acc.accumulate(0, &grad);
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
// Should have accumulated 10.0
let result = accumulator.get_grad(0);
assert_eq!(result[0], 10.0);
}
#[test]
fn test_embedding_offset_calculation() {
let temp_dir = TempDir::new().unwrap();
let path = temp_dir.path().join("embeddings.bin");
let manager = MmapManager::new(&path, 64, 100).unwrap();
assert_eq!(manager.embedding_offset(0), Some(0));
assert_eq!(manager.embedding_offset(1), Some(64 * 4)); // 64 floats * 4 bytes
assert_eq!(manager.embedding_offset(10), Some(64 * 4 * 10));
}
#[test]
fn test_grad_offset_calculation() {
let temp_dir = TempDir::new().unwrap();
let path = temp_dir.path().join("gradients.bin");
let accumulator = MmapGradientAccumulator::new(&path, 128, 100).unwrap();
assert_eq!(accumulator.grad_offset(0), Some(0));
assert_eq!(accumulator.grad_offset(1), Some(128 * 4)); // 128 floats * 4 bytes
assert_eq!(accumulator.grad_offset(5), Some(128 * 4 * 5));
}
#[test]
#[should_panic(expected = "Embedding data length must match d_embed")]
fn test_set_embedding_wrong_size() {
let temp_dir = TempDir::new().unwrap();
let path = temp_dir.path().join("embeddings.bin");
let mut manager = MmapManager::new(&path, 64, 100).unwrap();
let wrong_size = vec![1.0f32; 32]; // Should be 64
manager.set_embedding(0, &wrong_size);
}
#[test]
#[should_panic(expected = "Gradient length must match d_embed")]
fn test_accumulate_wrong_size() {
let temp_dir = TempDir::new().unwrap();
let path = temp_dir.path().join("gradients.bin");
let accumulator = MmapGradientAccumulator::new(&path, 64, 100).unwrap();
let wrong_size = vec![1.0f32; 32]; // Should be 64
accumulator.accumulate(0, &wrong_size);
}
#[test]
fn test_prefetch() {
let temp_dir = TempDir::new().unwrap();
let path = temp_dir.path().join("embeddings.bin");
let mut manager = MmapManager::new(&path, 64, 100).unwrap();
// Set some embeddings
for i in 0..10 {
let embedding = vec![i as f32; 64];
manager.set_embedding(i, &embedding);
}
// Prefetch should not crash
manager.prefetch(&[0, 1, 2, 3, 4]);
// Access should still work
let retrieved = manager.get_embedding(2);
assert_eq!(retrieved[0], 2.0);
}
}

View File

@@ -0,0 +1,82 @@
//! Memory-mapped embedding management for large-scale GNN training.
//!
//! This module provides efficient memory-mapped access to embeddings and gradients
//! that don't fit in RAM. It includes:
//! - `MmapManager`: Memory-mapped embedding storage with dirty tracking
//! - `MmapGradientAccumulator`: Lock-free gradient accumulation
//! - `AtomicBitmap`: Thread-safe bitmap for access/dirty tracking
//!
//! Only available on non-WASM targets.
#![cfg(all(not(target_arch = "wasm32"), feature = "mmap"))]
use crate::error::{GnnError, Result};
use std::cell::UnsafeCell;
use std::fs::{File, OpenOptions};
use std::io;
use std::path::Path;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use parking_lot::RwLock;
use memmap2::{MmapMut, MmapOptions};
/// Thread-safe bitmap using atomic operations.
///
/// Bits are packed 64 per `AtomicU64` word; all mutation is lock-free.
#[derive(Debug)]
pub struct AtomicBitmap {
    bits: Vec<AtomicU64>,
    size: usize,
}
impl AtomicBitmap {
    /// Allocate a bitmap able to track `size` bits, all initially unset.
    pub fn new(size: usize) -> Self {
        let word_count = (size + 63) / 64;
        let mut bits = Vec::with_capacity(word_count);
        bits.resize_with(word_count, || AtomicU64::new(0));
        Self { bits, size }
    }
    /// Atomically set the bit at `index`; indices past `size` are ignored.
    pub fn set(&self, index: usize) {
        if index < self.size {
            self.bits[index / 64].fetch_or(1u64 << (index % 64), Ordering::Release);
        }
    }
    /// Atomically clear the bit at `index`; indices past `size` are ignored.
    pub fn clear(&self, index: usize) {
        if index < self.size {
            self.bits[index / 64].fetch_and(!(1u64 << (index % 64)), Ordering::Release);
        }
    }
    /// Read the bit at `index`; out-of-range indices report `false`.
    pub fn get(&self, index: usize) -> bool {
        if index >= self.size {
            return false;
        }
        let word = self.bits[index / 64].load(Ordering::Acquire);
        (word & (1u64 << (index % 64))) != 0
    }
    /// Reset every bit to zero.
    pub fn clear_all(&self) {
        self.bits.iter().for_each(|word| word.store(0, Ordering::Release));
    }
    /// Collect the indices of all set bits, in ascending order.
    pub fn get_set_indices(&self) -> Vec<usize> {
        let mut out = Vec::new();
        for (word_idx, word) in self.bits.iter().enumerate() {
            let mut remaining = word.load(Ordering::Acquire);
            // Strip the lowest set bit each iteration (Kernighan's trick).
            while remaining != 0 {
                out.push(word_idx * 64 + remaining.trailing_zeros() as usize);
                remaining &= remaining - 1;
            }
        }
        out
    }
}

View File

@@ -0,0 +1,670 @@
//! Query API for RuVector GNN
//!
//! Provides high-level query interfaces for vector search, neural search,
//! and subgraph extraction.
use serde::{Deserialize, Serialize};
/// Query mode for different search strategies.
///
/// Serde-serializable so a query's mode can be sent over the wire or stored.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QueryMode {
    /// Pure HNSW vector search (no GNN involvement)
    VectorSearch,
    /// GNN-enhanced neural search; depth controlled by `RuvectorQuery::gnn_depth`
    NeuralSearch,
    /// Extract k-hop subgraph around results
    SubgraphExtraction,
    /// Differentiable search with soft attention; softness controlled by
    /// `RuvectorQuery::temperature`
    DifferentiableSearch,
}
/// Query configuration for RuVector searches.
///
/// Construct via the helper constructors (`vector_search`, `neural_search`,
/// `subgraph_search`, `differentiable_search`) and refine with the `with_*`
/// builder methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RuvectorQuery {
    /// Query vector for similarity search
    pub vector: Option<Vec<f32>>,
    /// Text query (requires embedding model)
    pub text: Option<String>,
    /// Node ID for subgraph extraction
    pub node_id: Option<u64>,
    /// Search mode
    pub mode: QueryMode,
    /// Number of results to return (default: 10)
    pub k: usize,
    /// HNSW search parameter (exploration factor; default: 50)
    pub ef: usize,
    /// GNN depth for neural search — number of GNN layers to apply (default: 2)
    pub gnn_depth: usize,
    /// Temperature for differentiable search (higher = softer; default: 1.0)
    pub temperature: f32,
    /// Whether to return attention weights (default: false)
    pub return_attention: bool,
}
impl Default for RuvectorQuery {
fn default() -> Self {
Self {
vector: None,
text: None,
node_id: None,
mode: QueryMode::VectorSearch,
k: 10,
ef: 50,
gnn_depth: 2,
temperature: 1.0,
return_attention: false,
}
}
}
impl RuvectorQuery {
    /// Create a basic vector search query
    ///
    /// # Arguments
    /// * `vector` - Query vector
    /// * `k` - Number of results to return
    ///
    /// # Example
    /// ```
    /// use ruvector_gnn::query::RuvectorQuery;
    ///
    /// let query = RuvectorQuery::vector_search(vec![0.1, 0.2, 0.3], 10);
    /// assert_eq!(query.k, 10);
    /// ```
    pub fn vector_search(vector: Vec<f32>, k: usize) -> Self {
        let mut query = Self::default();
        query.vector = Some(vector);
        query.mode = QueryMode::VectorSearch;
        query.k = k;
        query
    }
    /// Create a GNN-enhanced neural search query
    ///
    /// # Arguments
    /// * `vector` - Query vector
    /// * `k` - Number of results to return
    /// * `gnn_depth` - Number of GNN layers to apply
    ///
    /// # Example
    /// ```
    /// use ruvector_gnn::query::RuvectorQuery;
    ///
    /// let query = RuvectorQuery::neural_search(vec![0.1, 0.2, 0.3], 10, 3);
    /// assert_eq!(query.gnn_depth, 3);
    /// ```
    pub fn neural_search(vector: Vec<f32>, k: usize, gnn_depth: usize) -> Self {
        let mut query = Self::default();
        query.vector = Some(vector);
        query.mode = QueryMode::NeuralSearch;
        query.k = k;
        query.gnn_depth = gnn_depth;
        query
    }
    /// Create a subgraph extraction query
    ///
    /// # Arguments
    /// * `vector` - Query vector
    /// * `k` - Number of nodes in subgraph
    ///
    /// # Example
    /// ```
    /// use ruvector_gnn::query::RuvectorQuery;
    ///
    /// let query = RuvectorQuery::subgraph_search(vec![0.1, 0.2, 0.3], 20);
    /// assert_eq!(query.k, 20);
    /// ```
    pub fn subgraph_search(vector: Vec<f32>, k: usize) -> Self {
        let mut query = Self::default();
        query.vector = Some(vector);
        query.mode = QueryMode::SubgraphExtraction;
        query.k = k;
        query
    }
    /// Create a differentiable search query with temperature
    ///
    /// Attention weights are returned automatically for this mode.
    ///
    /// # Arguments
    /// * `vector` - Query vector
    /// * `k` - Number of results
    /// * `temperature` - Softmax temperature (higher = softer distribution)
    pub fn differentiable_search(vector: Vec<f32>, k: usize, temperature: f32) -> Self {
        let mut query = Self::default();
        query.vector = Some(vector);
        query.mode = QueryMode::DifferentiableSearch;
        query.k = k;
        query.temperature = temperature;
        query.return_attention = true;
        query
    }
    /// Set text query (requires embedding model)
    pub fn with_text(self, text: String) -> Self {
        Self {
            text: Some(text),
            ..self
        }
    }
    /// Set node ID for centered queries
    pub fn with_node(self, node_id: u64) -> Self {
        Self {
            node_id: Some(node_id),
            ..self
        }
    }
    /// Set EF parameter for HNSW search
    pub fn with_ef(self, ef: usize) -> Self {
        Self { ef, ..self }
    }
    /// Enable attention weight return
    pub fn with_attention(self) -> Self {
        Self {
            return_attention: true,
            ..self
        }
    }
}
/// Subgraph representation.
///
/// `nodes` lists the node IDs that belong to the subgraph; each edge is a
/// `(from, to, weight)` triple over those IDs.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct SubGraph {
    /// Node IDs in the subgraph
    pub nodes: Vec<u64>,
    /// Edges as (from, to, weight) tuples
    pub edges: Vec<(u64, u64, f32)>,
}
impl SubGraph {
    /// Create a new empty subgraph
    pub fn new() -> Self {
        Self {
            nodes: Vec::new(),
            edges: Vec::new(),
        }
    }
    /// Create subgraph with nodes and edges
    pub fn with_edges(nodes: Vec<u64>, edges: Vec<(u64, u64, f32)>) -> Self {
        Self { nodes, edges }
    }
    /// Get number of nodes
    pub fn node_count(&self) -> usize {
        self.nodes.len()
    }
    /// Get number of edges
    pub fn edge_count(&self) -> usize {
        self.edges.len()
    }
    /// Check if subgraph contains a node (linear scan over `nodes`)
    pub fn contains_node(&self, node_id: u64) -> bool {
        self.nodes.iter().any(|&n| n == node_id)
    }
    /// Get average edge weight; 0.0 for a subgraph with no edges
    pub fn average_edge_weight(&self) -> f32 {
        match self.edges.len() {
            0 => 0.0,
            n => self.edges.iter().map(|&(_, _, w)| w).sum::<f32>() / n as f32,
        }
    }
}
impl Default for SubGraph {
fn default() -> Self {
Self::new()
}
}
/// Query result with nodes, scores, and optional metadata.
///
/// Invariant: `nodes` and `scores` are parallel arrays (`scores[i]` belongs
/// to `nodes[i]`); when present, `embeddings` and `attention_weights` are
/// row-parallel to `nodes` as well.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryResult {
    /// Matched node IDs
    pub nodes: Vec<u64>,
    /// Similarity scores (higher = more similar)
    pub scores: Vec<f32>,
    /// Optional node embeddings after GNN processing
    pub embeddings: Option<Vec<Vec<f32>>>,
    /// Optional attention weights from differentiable search
    pub attention_weights: Option<Vec<Vec<f32>>>,
    /// Optional subgraph extraction
    pub subgraph: Option<SubGraph>,
    /// Query latency in milliseconds
    pub latency_ms: u64,
}
impl QueryResult {
    /// Create a new empty query result
    pub fn new() -> Self {
        Self {
            nodes: Vec::new(),
            scores: Vec::new(),
            embeddings: None,
            attention_weights: None,
            subgraph: None,
            latency_ms: 0,
        }
    }
    /// Create query result with nodes and scores
    ///
    /// `nodes` and `scores` are parallel arrays: `scores[i]` is the score
    /// for `nodes[i]`.
    ///
    /// # Arguments
    /// * `nodes` - Node IDs
    /// * `scores` - Similarity scores
    ///
    /// # Example
    /// ```
    /// use ruvector_gnn::query::QueryResult;
    ///
    /// let result = QueryResult::with_nodes(vec![1, 2, 3], vec![0.9, 0.8, 0.7]);
    /// assert_eq!(result.nodes.len(), 3);
    /// ```
    pub fn with_nodes(nodes: Vec<u64>, scores: Vec<f32>) -> Self {
        Self {
            nodes,
            scores,
            embeddings: None,
            attention_weights: None,
            subgraph: None,
            latency_ms: 0,
        }
    }
    /// Add embeddings to the result (row-parallel to `nodes`)
    pub fn with_embeddings(mut self, embeddings: Vec<Vec<f32>>) -> Self {
        self.embeddings = Some(embeddings);
        self
    }
    /// Add attention weights to the result (row-parallel to `nodes`)
    pub fn with_attention(mut self, attention: Vec<Vec<f32>>) -> Self {
        self.attention_weights = Some(attention);
        self
    }
    /// Add subgraph to the result
    pub fn with_subgraph(mut self, subgraph: SubGraph) -> Self {
        self.subgraph = Some(subgraph);
        self
    }
    /// Set query latency
    pub fn with_latency(mut self, latency_ms: u64) -> Self {
        self.latency_ms = latency_ms;
        self
    }
    /// Get number of results
    pub fn len(&self) -> usize {
        self.nodes.len()
    }
    /// Check if result is empty
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }
    /// Get top-k results
    ///
    /// Assumes the result is already ordered best-first (as produced by the
    /// search pipeline); this truncates the parallel arrays to the first `k`
    /// rows. The subgraph, if any, is carried over unchanged.
    pub fn top_k(&self, k: usize) -> Self {
        let k = k.min(self.nodes.len());
        Self {
            nodes: self.nodes[..k].to_vec(),
            scores: self.scores[..k].to_vec(),
            embeddings: self.embeddings.as_ref().map(|e| e[..k].to_vec()),
            attention_weights: self.attention_weights.as_ref().map(|a| a[..k].to_vec()),
            subgraph: self.subgraph.clone(),
            latency_ms: self.latency_ms,
        }
    }
    /// Get the best result (first entry — highest score when results are
    /// ordered best-first)
    pub fn best(&self) -> Option<(u64, f32)> {
        if self.nodes.is_empty() {
            None
        } else {
            Some((self.nodes[0], self.scores[0]))
        }
    }
    /// Filter results by minimum score, keeping rows with `score >= min_score`
    ///
    /// Row-parallel `embeddings` and `attention_weights` are filtered in
    /// lockstep when present.
    pub fn filter_by_score(mut self, min_score: f32) -> Self {
        let mut filtered_nodes = Vec::new();
        let mut filtered_scores = Vec::new();
        let mut filtered_embeddings = Vec::new();
        let mut filtered_attention = Vec::new();
        for i in 0..self.nodes.len() {
            if self.scores[i] >= min_score {
                filtered_nodes.push(self.nodes[i]);
                filtered_scores.push(self.scores[i]);
                if let Some(ref emb) = self.embeddings {
                    filtered_embeddings.push(emb[i].clone());
                }
                if let Some(ref att) = self.attention_weights {
                    filtered_attention.push(att[i].clone());
                }
            }
        }
        self.nodes = filtered_nodes;
        self.scores = filtered_scores;
        // Bug fix: the previous code only replaced `embeddings`/`attention_weights`
        // when the filtered vectors were non-empty, so filtering *everything* out
        // left the stale, unfiltered rows in place — an internally inconsistent
        // result (empty `nodes` but non-empty `embeddings`). Replace them whenever
        // they were present, even with an empty vector.
        if self.embeddings.is_some() {
            self.embeddings = Some(filtered_embeddings);
        }
        if self.attention_weights.is_some() {
            self.attention_weights = Some(filtered_attention);
        }
        self
    }
}
impl Default for QueryResult {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    // JSON round-trip must preserve the enum variant.
    #[test]
    fn test_query_mode_serialization() {
        let mode = QueryMode::NeuralSearch;
        let json = serde_json::to_string(&mode).unwrap();
        let deserialized: QueryMode = serde_json::from_str(&json).unwrap();
        assert_eq!(mode, deserialized);
    }
    // Pins the documented defaults of RuvectorQuery.
    #[test]
    fn test_ruvector_query_default() {
        let query = RuvectorQuery::default();
        assert_eq!(query.k, 10);
        assert_eq!(query.ef, 50);
        assert_eq!(query.gnn_depth, 2);
        assert_eq!(query.temperature, 1.0);
        assert_eq!(query.mode, QueryMode::VectorSearch);
        assert!(!query.return_attention);
    }
    #[test]
    fn test_vector_search_query() {
        let vector = vec![0.1, 0.2, 0.3, 0.4];
        let query = RuvectorQuery::vector_search(vector.clone(), 5);
        assert_eq!(query.vector, Some(vector));
        assert_eq!(query.k, 5);
        assert_eq!(query.mode, QueryMode::VectorSearch);
    }
    #[test]
    fn test_neural_search_query() {
        let vector = vec![0.1, 0.2, 0.3];
        let query = RuvectorQuery::neural_search(vector.clone(), 10, 3);
        assert_eq!(query.vector, Some(vector));
        assert_eq!(query.k, 10);
        assert_eq!(query.gnn_depth, 3);
        assert_eq!(query.mode, QueryMode::NeuralSearch);
    }
    #[test]
    fn test_subgraph_search_query() {
        let vector = vec![0.5, 0.5];
        let query = RuvectorQuery::subgraph_search(vector.clone(), 20);
        assert_eq!(query.vector, Some(vector));
        assert_eq!(query.k, 20);
        assert_eq!(query.mode, QueryMode::SubgraphExtraction);
    }
    // differentiable_search must also switch attention output on.
    #[test]
    fn test_differentiable_search_query() {
        let vector = vec![0.3, 0.4, 0.5];
        let query = RuvectorQuery::differentiable_search(vector.clone(), 15, 0.5);
        assert_eq!(query.vector, Some(vector));
        assert_eq!(query.k, 15);
        assert_eq!(query.temperature, 0.5);
        assert_eq!(query.mode, QueryMode::DifferentiableSearch);
        assert!(query.return_attention);
    }
    // The with_* builders must compose without clobbering each other.
    #[test]
    fn test_query_builder_pattern() {
        let query = RuvectorQuery::vector_search(vec![0.1, 0.2], 5)
            .with_text("hello world".to_string())
            .with_node(42)
            .with_ef(100)
            .with_attention();
        assert_eq!(query.text, Some("hello world".to_string()));
        assert_eq!(query.node_id, Some(42));
        assert_eq!(query.ef, 100);
        assert!(query.return_attention);
    }
    #[test]
    fn test_subgraph_new() {
        let subgraph = SubGraph::new();
        assert_eq!(subgraph.node_count(), 0);
        assert_eq!(subgraph.edge_count(), 0);
    }
    #[test]
    fn test_subgraph_with_edges() {
        let nodes = vec![1, 2, 3];
        let edges = vec![(1, 2, 0.8), (2, 3, 0.6), (1, 3, 0.5)];
        let subgraph = SubGraph::with_edges(nodes.clone(), edges.clone());
        assert_eq!(subgraph.nodes, nodes);
        assert_eq!(subgraph.edges, edges);
        assert_eq!(subgraph.node_count(), 3);
        assert_eq!(subgraph.edge_count(), 3);
    }
    #[test]
    fn test_subgraph_contains_node() {
        let nodes = vec![1, 2, 3];
        let subgraph = SubGraph::with_edges(nodes, vec![]);
        assert!(subgraph.contains_node(1));
        assert!(subgraph.contains_node(2));
        assert!(subgraph.contains_node(3));
        assert!(!subgraph.contains_node(4));
    }
    // mean(0.8, 0.6, 0.4) == 0.6, within float tolerance.
    #[test]
    fn test_subgraph_average_edge_weight() {
        let edges = vec![(1, 2, 0.8), (2, 3, 0.6), (1, 3, 0.4)];
        let subgraph = SubGraph::with_edges(vec![1, 2, 3], edges);
        let avg = subgraph.average_edge_weight();
        assert!((avg - 0.6).abs() < 0.001);
    }
    // No edges -> average is defined as 0.0 (not NaN).
    #[test]
    fn test_subgraph_empty_average() {
        let subgraph = SubGraph::new();
        assert_eq!(subgraph.average_edge_weight(), 0.0);
    }
    #[test]
    fn test_query_result_new() {
        let result = QueryResult::new();
        assert!(result.is_empty());
        assert_eq!(result.len(), 0);
        assert_eq!(result.latency_ms, 0);
    }
    #[test]
    fn test_query_result_with_nodes() {
        let nodes = vec![1, 2, 3];
        let scores = vec![0.9, 0.8, 0.7];
        let result = QueryResult::with_nodes(nodes.clone(), scores.clone());
        assert_eq!(result.nodes, nodes);
        assert_eq!(result.scores, scores);
        assert_eq!(result.len(), 3);
        assert!(!result.is_empty());
    }
    #[test]
    fn test_query_result_builder_pattern() {
        let embeddings = vec![vec![0.1, 0.2], vec![0.3, 0.4]];
        let attention = vec![vec![0.5, 0.5], vec![0.6, 0.4]];
        let subgraph = SubGraph::with_edges(vec![1, 2], vec![(1, 2, 0.8)]);
        let result = QueryResult::with_nodes(vec![1, 2], vec![0.9, 0.8])
            .with_embeddings(embeddings.clone())
            .with_attention(attention.clone())
            .with_subgraph(subgraph.clone())
            .with_latency(100);
        assert_eq!(result.embeddings, Some(embeddings));
        assert_eq!(result.attention_weights, Some(attention));
        assert_eq!(result.subgraph, Some(subgraph));
        assert_eq!(result.latency_ms, 100);
    }
    #[test]
    fn test_query_result_top_k() {
        let nodes = vec![1, 2, 3, 4, 5];
        let scores = vec![0.9, 0.8, 0.7, 0.6, 0.5];
        let result = QueryResult::with_nodes(nodes, scores);
        let top_3 = result.top_k(3);
        assert_eq!(top_3.len(), 3);
        assert_eq!(top_3.nodes, vec![1, 2, 3]);
        assert_eq!(top_3.scores, vec![0.9, 0.8, 0.7]);
    }
    // Asking for more than available must clamp, not panic.
    #[test]
    fn test_query_result_top_k_overflow() {
        let result = QueryResult::with_nodes(vec![1, 2], vec![0.9, 0.8]);
        let top_10 = result.top_k(10);
        assert_eq!(top_10.len(), 2); // Should only return available results
    }
    #[test]
    fn test_query_result_best() {
        let result = QueryResult::with_nodes(vec![1, 2, 3], vec![0.9, 0.8, 0.7]);
        let best = result.best();
        assert_eq!(best, Some((1, 0.9)));
    }
    #[test]
    fn test_query_result_best_empty() {
        let result = QueryResult::new();
        assert_eq!(result.best(), None);
    }
    #[test]
    fn test_query_result_filter_by_score() {
        let nodes = vec![1, 2, 3, 4, 5];
        let scores = vec![0.9, 0.8, 0.7, 0.6, 0.5];
        let result = QueryResult::with_nodes(nodes, scores);
        let filtered = result.filter_by_score(0.7);
        assert_eq!(filtered.len(), 3);
        assert_eq!(filtered.nodes, vec![1, 2, 3]);
        assert_eq!(filtered.scores, vec![0.9, 0.8, 0.7]);
    }
    // Embedding rows must be dropped in lockstep with their nodes.
    #[test]
    fn test_query_result_filter_with_embeddings() {
        let nodes = vec![1, 2, 3];
        let scores = vec![0.9, 0.6, 0.8];
        let embeddings = vec![vec![0.1], vec![0.2], vec![0.3]];
        let result = QueryResult::with_nodes(nodes, scores).with_embeddings(embeddings);
        let filtered = result.filter_by_score(0.7);
        assert_eq!(filtered.len(), 2);
        assert_eq!(filtered.nodes, vec![1, 3]);
        assert_eq!(filtered.embeddings, Some(vec![vec![0.1], vec![0.3]]));
    }
    // Attention rows must be dropped in lockstep with their nodes.
    #[test]
    fn test_query_result_filter_with_attention() {
        let nodes = vec![1, 2, 3];
        let scores = vec![0.9, 0.5, 0.8];
        let attention = vec![vec![0.5, 0.5], vec![0.6, 0.4], vec![0.7, 0.3]];
        let result = QueryResult::with_nodes(nodes, scores).with_attention(attention);
        let filtered = result.filter_by_score(0.75);
        assert_eq!(filtered.len(), 2);
        assert_eq!(filtered.nodes, vec![1, 3]);
        assert_eq!(
            filtered.attention_weights,
            Some(vec![vec![0.5, 0.5], vec![0.7, 0.3]])
        );
    }
    #[test]
    fn test_query_serialization() {
        let query = RuvectorQuery::neural_search(vec![0.1, 0.2], 5, 2);
        let json = serde_json::to_string(&query).unwrap();
        let deserialized: RuvectorQuery = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.k, query.k);
        assert_eq!(deserialized.gnn_depth, query.gnn_depth);
        assert_eq!(deserialized.mode, query.mode);
    }
    #[test]
    fn test_result_serialization() {
        let result = QueryResult::with_nodes(vec![1, 2], vec![0.9, 0.8]).with_latency(50);
        let json = serde_json::to_string(&result).unwrap();
        let deserialized: QueryResult = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.nodes, result.nodes);
        assert_eq!(deserialized.scores, result.scores);
        assert_eq!(deserialized.latency_ms, result.latency_ms);
    }
    #[test]
    fn test_subgraph_serialization() {
        let subgraph = SubGraph::with_edges(vec![1, 2, 3], vec![(1, 2, 0.8), (2, 3, 0.6)]);
        let json = serde_json::to_string(&subgraph).unwrap();
        let deserialized: SubGraph = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.nodes, subgraph.nodes);
        assert_eq!(deserialized.edges, subgraph.edges);
    }
    // A threshold above every score must yield an empty (but valid) result.
    #[test]
    fn test_edge_case_empty_filter() {
        let result = QueryResult::with_nodes(vec![1, 2], vec![0.5, 0.4]);
        let filtered = result.filter_by_score(0.9);
        assert!(filtered.is_empty());
        assert_eq!(filtered.len(), 0);
    }
    #[test]
    fn test_query_mode_variants() {
        // Test all query mode variants
        assert_eq!(QueryMode::VectorSearch, QueryMode::VectorSearch);
        assert_ne!(QueryMode::VectorSearch, QueryMode::NeuralSearch);
        assert_ne!(QueryMode::NeuralSearch, QueryMode::SubgraphExtraction);
        assert_ne!(
            QueryMode::SubgraphExtraction,
            QueryMode::DifferentiableSearch
        );
    }
}

View File

@@ -0,0 +1,502 @@
//! Experience Replay Buffer for GNN Training
//!
//! This module implements an experience replay buffer to mitigate catastrophic forgetting
//! during continual learning. The buffer stores past training samples and supports:
//! - Reservoir sampling for uniform distribution over time
//! - Batch sampling for training
//! - Distribution shift detection
use rand::Rng;
use std::collections::VecDeque;
use std::time::{SystemTime, UNIX_EPOCH};
/// A single entry in the replay buffer
#[derive(Debug, Clone)]
pub struct ReplayEntry {
    /// Query vector used for training
    pub query: Vec<f32>,
    /// IDs of positive nodes for this query
    pub positive_ids: Vec<usize>,
    /// Timestamp when this entry was added (milliseconds since epoch)
    pub timestamp: u64,
}
impl ReplayEntry {
    /// Build an entry stamped with the current wall-clock time.
    ///
    /// A system clock earlier than the Unix epoch yields a timestamp of 0
    /// instead of panicking.
    pub fn new(query: Vec<f32>, positive_ids: Vec<usize>) -> Self {
        let now_ms = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|elapsed| elapsed.as_millis() as u64)
            .unwrap_or(0);
        Self {
            query,
            positive_ids,
            timestamp: now_ms,
        }
    }
}
/// Statistics for tracking distribution characteristics
#[derive(Debug, Clone)]
pub struct DistributionStats {
    /// Running mean of query vectors
    pub mean: Vec<f32>,
    /// Running variance of query vectors
    pub variance: Vec<f32>,
    /// Number of samples used to compute statistics
    pub count: usize,
}
impl DistributionStats {
    /// Zeroed statistics for `dimension`-length samples.
    pub fn new(dimension: usize) -> Self {
        Self {
            mean: vec![0.0; dimension],
            variance: vec![0.0; dimension],
            count: 0,
        }
    }
    /// Fold one sample into the running statistics (Welford's algorithm).
    ///
    /// If the stats were created with dimension 0, the first non-empty sample
    /// fixes the dimension; samples of any other length are silently skipped.
    pub fn update(&mut self, sample: &[f32]) {
        if self.mean.is_empty() && !sample.is_empty() {
            self.mean = vec![0.0; sample.len()];
            self.variance = vec![0.0; sample.len()];
        }
        if self.mean.len() != sample.len() {
            return; // dimension mismatch: skip the update
        }
        self.count += 1;
        let n = self.count as f32;
        for (i, &x) in sample.iter().enumerate() {
            let delta = x - self.mean[i];
            self.mean[i] += delta / n;
            // `variance` accumulates M2 (sum of squared deviations);
            // it is normalized only in `std_dev`.
            self.variance[i] += delta * (x - self.mean[i]);
        }
    }
    /// Sample standard deviation (Bessel-corrected); all zeros for count <= 1.
    pub fn std_dev(&self) -> Vec<f32> {
        if self.count <= 1 {
            return vec![0.0; self.variance.len()];
        }
        let denom = (self.count - 1) as f32;
        self.variance.iter().map(|&m2| (m2 / denom).sqrt()).collect()
    }
    /// Zero the statistics while keeping the current dimension.
    pub fn reset(&mut self) {
        self.count = 0;
        self.mean.fill(0.0);
        self.variance.fill(0.0);
    }
}
/// Experience Replay Buffer for storing and sampling past training examples.
///
/// Uses reservoir sampling (Algorithm R) so that once the buffer is full,
/// every sample ever seen has equal probability `capacity / total_seen` of
/// being retained.
pub struct ReplayBuffer {
    /// Circular buffer of replay entries
    queries: VecDeque<ReplayEntry>,
    /// Maximum capacity of the buffer
    capacity: usize,
    /// Total number of samples seen (including evicted ones)
    total_seen: usize,
    /// Statistics of the overall distribution
    distribution_stats: DistributionStats,
}
impl ReplayBuffer {
    /// Create a new replay buffer with specified capacity
    ///
    /// # Arguments
    /// * `capacity` - Maximum number of entries to store
    pub fn new(capacity: usize) -> Self {
        Self {
            queries: VecDeque::with_capacity(capacity),
            capacity,
            total_seen: 0,
            // Dimension 0 here: the real dimension is adopted lazily from
            // the first sample passed to `add` (see DistributionStats::update).
            distribution_stats: DistributionStats::new(0),
        }
    }
    /// Add a new entry to the buffer using reservoir sampling
    ///
    /// Reservoir sampling ensures uniform distribution over all samples seen,
    /// even as old samples are evicted due to capacity constraints.
    ///
    /// # Arguments
    /// * `query` - Query vector
    /// * `positive_ids` - IDs of positive nodes for this query
    pub fn add(&mut self, query: &[f32], positive_ids: &[usize]) {
        let entry = ReplayEntry::new(query.to_vec(), positive_ids.to_vec());
        self.total_seen += 1;
        // Update distribution statistics
        self.distribution_stats.update(query);
        // If buffer is not full, just add the entry
        if self.queries.len() < self.capacity {
            self.queries.push_back(entry);
            return;
        }
        // Reservoir sampling (Algorithm R): replace a random entry with
        // probability capacity/total_seen; otherwise the new entry is dropped.
        let mut rng = rand::thread_rng();
        let random_index = rng.gen_range(0..self.total_seen);
        if random_index < self.capacity {
            self.queries[random_index] = entry;
        }
    }
    /// Sample a batch of entries uniformly at random
    ///
    /// # Arguments
    /// * `batch_size` - Number of entries to sample
    ///
    /// # Returns
    /// Vector of references to sampled entries (may be smaller than batch_size if buffer is small)
    pub fn sample(&self, batch_size: usize) -> Vec<&ReplayEntry> {
        if self.queries.is_empty() {
            return Vec::new();
        }
        let actual_batch_size = batch_size.min(self.queries.len());
        let mut rng = rand::thread_rng();
        let mut indices: Vec<usize> = (0..self.queries.len()).collect();
        // Partial Fisher-Yates shuffle of the first batch_size positions;
        // this guarantees the sampled indices are distinct.
        for i in 0..actual_batch_size {
            let j = rng.gen_range(i..indices.len());
            indices.swap(i, j);
        }
        indices[..actual_batch_size]
            .iter()
            .map(|&idx| &self.queries[idx])
            .collect()
    }
    /// Detect distribution shift between recent samples and overall distribution
    ///
    /// Uses Kullback-Leibler divergence approximation based on mean and variance changes.
    ///
    /// NOTE(review): once reservoir eviction starts, replacements land at
    /// random positions, so the tail of `queries` is no longer strictly the
    /// most recently inserted entries — confirm this approximation is intended.
    ///
    /// # Arguments
    /// * `recent_window` - Number of most recent samples to compare
    ///
    /// # Returns
    /// Shift score (higher values indicate more significant distribution shift)
    /// Returns 0.0 if insufficient data
    pub fn detect_distribution_shift(&self, recent_window: usize) -> f32 {
        if self.queries.len() < recent_window || recent_window == 0 {
            return 0.0;
        }
        // Compute statistics for recent window
        let mut recent_stats = DistributionStats::new(self.distribution_stats.mean.len());
        let start_idx = self.queries.len().saturating_sub(recent_window);
        for entry in self.queries.iter().skip(start_idx) {
            recent_stats.update(&entry.query);
        }
        // Compute shift using normalized mean difference
        let overall_mean = &self.distribution_stats.mean;
        let recent_mean = &recent_stats.mean;
        if overall_mean.is_empty() || recent_mean.is_empty() {
            return 0.0;
        }
        let overall_std = self.distribution_stats.std_dev();
        let mut shift_sum = 0.0;
        let mut count = 0;
        for i in 0..overall_mean.len() {
            // Skip near-constant dimensions to avoid dividing by ~0.
            if overall_std[i] > 1e-8 {
                let diff = (recent_mean[i] - overall_mean[i]).abs();
                shift_sum += diff / overall_std[i];
                count += 1;
            }
        }
        if count > 0 {
            shift_sum / count as f32
        } else {
            0.0
        }
    }
    /// Get the number of entries currently in the buffer
    pub fn len(&self) -> usize {
        self.queries.len()
    }
    /// Check if the buffer is empty
    pub fn is_empty(&self) -> bool {
        self.queries.is_empty()
    }
    /// Get the total capacity of the buffer
    pub fn capacity(&self) -> usize {
        self.capacity
    }
    /// Get the total number of samples seen (including evicted ones)
    pub fn total_seen(&self) -> usize {
        self.total_seen
    }
    /// Get a reference to the distribution statistics
    pub fn distribution_stats(&self) -> &DistributionStats {
        &self.distribution_stats
    }
    /// Clear all entries from the buffer
    pub fn clear(&mut self) {
        self.queries.clear();
        self.total_seen = 0;
        self.distribution_stats.reset();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_replay_buffer_basic() {
        let mut buffer = ReplayBuffer::new(10);
        assert_eq!(buffer.len(), 0);
        assert!(buffer.is_empty());
        assert_eq!(buffer.capacity(), 10);
        buffer.add(&[1.0, 2.0, 3.0], &[0, 1]);
        assert_eq!(buffer.len(), 1);
        assert!(!buffer.is_empty());
        buffer.add(&[4.0, 5.0, 6.0], &[2, 3]);
        assert_eq!(buffer.len(), 2);
        assert_eq!(buffer.total_seen(), 2);
    }
    // len() must never exceed capacity even as total_seen keeps growing.
    #[test]
    fn test_replay_buffer_capacity() {
        let mut buffer = ReplayBuffer::new(3);
        // Add entries up to capacity
        for i in 0..3 {
            buffer.add(&[i as f32], &[i]);
        }
        assert_eq!(buffer.len(), 3);
        // Adding more should maintain capacity through reservoir sampling
        for i in 3..10 {
            buffer.add(&[i as f32], &[i]);
        }
        assert_eq!(buffer.len(), 3);
        assert_eq!(buffer.total_seen(), 10);
    }
    #[test]
    fn test_sample_empty_buffer() {
        let buffer = ReplayBuffer::new(10);
        let samples = buffer.sample(5);
        assert!(samples.is_empty());
    }
    #[test]
    fn test_sample_basic() {
        let mut buffer = ReplayBuffer::new(10);
        for i in 0..5 {
            buffer.add(&[i as f32], &[i]);
        }
        let samples = buffer.sample(3);
        assert_eq!(samples.len(), 3);
        // Check that samples are from the buffer
        for sample in samples {
            assert!(sample.query[0] >= 0.0 && sample.query[0] < 5.0);
        }
    }
    #[test]
    fn test_sample_larger_than_buffer() {
        let mut buffer = ReplayBuffer::new(10);
        buffer.add(&[1.0], &[0]);
        buffer.add(&[2.0], &[1]);
        let samples = buffer.sample(5);
        assert_eq!(samples.len(), 2); // Can only return what's available
    }
    // Running mean must match the arithmetic mean of the samples so far.
    #[test]
    fn test_distribution_stats_update() {
        let mut stats = DistributionStats::new(2);
        stats.update(&[1.0, 2.0]);
        assert_eq!(stats.count, 1);
        assert_eq!(stats.mean, vec![1.0, 2.0]);
        stats.update(&[3.0, 4.0]);
        assert_eq!(stats.count, 2);
        assert_eq!(stats.mean, vec![2.0, 3.0]);
        stats.update(&[2.0, 3.0]);
        assert_eq!(stats.count, 3);
        assert_eq!(stats.mean, vec![2.0, 3.0]);
    }
    #[test]
    fn test_distribution_stats_std_dev() {
        let mut stats = DistributionStats::new(2);
        stats.update(&[1.0, 1.0]);
        stats.update(&[3.0, 3.0]);
        stats.update(&[5.0, 5.0]);
        let std_dev = stats.std_dev();
        // Expected std dev for [1, 3, 5] is 2.0
        assert!((std_dev[0] - 2.0).abs() < 0.01);
        assert!((std_dev[1] - 2.0).abs() < 0.01);
    }
    // Identical samples -> the recent window matches the overall distribution.
    #[test]
    fn test_detect_distribution_shift_no_shift() {
        let mut buffer = ReplayBuffer::new(100);
        // Add samples from the same distribution
        for _ in 0..50 {
            buffer.add(&[1.0, 2.0, 3.0], &[0]);
        }
        let shift = buffer.detect_distribution_shift(10);
        assert!(shift < 0.1); // Should be very low
    }
    // A step change in the input distribution must drive the score up.
    // (Capacity 100 > 50 samples, so no reservoir eviction scrambles order here.)
    #[test]
    fn test_detect_distribution_shift_with_shift() {
        let mut buffer = ReplayBuffer::new(100);
        // Add samples from one distribution
        for _ in 0..40 {
            buffer.add(&[1.0, 2.0, 3.0], &[0]);
        }
        // Add samples from a different distribution
        for _ in 0..10 {
            buffer.add(&[5.0, 6.0, 7.0], &[1]);
        }
        let shift = buffer.detect_distribution_shift(10);
        assert!(shift > 0.5); // Should detect significant shift
    }
    #[test]
    fn test_detect_distribution_shift_insufficient_data() {
        let mut buffer = ReplayBuffer::new(100);
        buffer.add(&[1.0, 2.0], &[0]);
        let shift = buffer.detect_distribution_shift(10);
        assert_eq!(shift, 0.0); // Not enough data
    }
    #[test]
    fn test_clear() {
        let mut buffer = ReplayBuffer::new(10);
        for i in 0..5 {
            buffer.add(&[i as f32], &[i]);
        }
        assert_eq!(buffer.len(), 5);
        assert_eq!(buffer.total_seen(), 5);
        buffer.clear();
        assert_eq!(buffer.len(), 0);
        assert_eq!(buffer.total_seen(), 0);
        assert!(buffer.is_empty());
        assert_eq!(buffer.distribution_stats().count, 0);
    }
    #[test]
    fn test_replay_entry_creation() {
        let entry = ReplayEntry::new(vec![1.0, 2.0, 3.0], vec![0, 1, 2]);
        assert_eq!(entry.query, vec![1.0, 2.0, 3.0]);
        assert_eq!(entry.positive_ids, vec![0, 1, 2]);
        assert!(entry.timestamp > 0);
    }
    // After heavy over-filling, the retained set should span the input
    // range rather than holding only the most recent entries.
    #[test]
    fn test_reservoir_sampling_distribution() {
        let mut buffer = ReplayBuffer::new(10);
        // Add 100 entries (much more than capacity)
        for i in 0..100 {
            buffer.add(&[i as f32], &[i]);
        }
        assert_eq!(buffer.len(), 10);
        assert_eq!(buffer.total_seen(), 100);
        // Sample multiple times and verify we get different samples
        let samples1 = buffer.sample(5);
        let samples2 = buffer.sample(5);
        assert_eq!(samples1.len(), 5);
        assert_eq!(samples2.len(), 5);
        // Check that samples come from the full range (not just recent entries)
        let sample_batch = buffer.sample(10);
        let values: Vec<f32> = sample_batch.iter().map(|e| e.query[0]).collect();
        // With reservoir sampling, we should have some diversity in values
        let unique_values: std::collections::HashSet<_> =
            values.iter().map(|&v| v as i32).collect();
        assert!(unique_values.len() > 1);
    }
    #[test]
    fn test_dimension_mismatch_handling() {
        let mut buffer = ReplayBuffer::new(10);
        buffer.add(&[1.0, 2.0], &[0]);
        // This should not panic, just be handled gracefully
        // The implementation will initialize stats on first add
        assert_eq!(buffer.len(), 1);
        assert_eq!(buffer.distribution_stats().mean.len(), 2);
    }
    // The partial Fisher-Yates in sample() must return distinct entries.
    #[test]
    fn test_sample_uniqueness() {
        let mut buffer = ReplayBuffer::new(5);
        for i in 0..5 {
            buffer.add(&[i as f32], &[i]);
        }
        // Sample all entries
        let samples = buffer.sample(5);
        let values: Vec<f32> = samples.iter().map(|e| e.query[0]).collect();
        // All samples should be unique (no duplicates in a single batch)
        let unique_values: std::collections::HashSet<_> =
            values.iter().map(|&v| v as i32).collect();
        assert_eq!(unique_values.len(), 5);
    }
}

View File

@@ -0,0 +1,531 @@
//! Learning rate scheduling for Graph Neural Networks
//!
//! Provides various learning rate scheduling strategies to prevent catastrophic
//! forgetting and optimize training dynamics in continual learning scenarios.
use std::f32::consts::PI;
/// Available strategies for adjusting the learning rate over training.
#[derive(Debug, Clone)]
pub enum SchedulerType {
    /// Keep the learning rate fixed at its base value.
    Constant,
    /// Multiply the rate by `gamma` once every `step_size` epochs:
    /// `lr = base_lr * gamma^(epoch / step_size)`.
    StepDecay { step_size: usize, gamma: f32 },
    /// Multiply the rate by `gamma` every epoch:
    /// `lr = base_lr * gamma^epoch`.
    Exponential { gamma: f32 },
    /// Cosine annealing with a warm restart every `t_max` steps:
    /// `lr = eta_min + 0.5 * (base_lr - eta_min) * (1 + cos(pi * (epoch % t_max) / t_max))`.
    CosineAnnealing { t_max: usize, eta_min: f32 },
    /// Linear warmup from 0 to `base_lr` over `warmup_steps`, then linear
    /// decay back to 0 by `total_steps` (and 0 thereafter).
    WarmupLinear {
        warmup_steps: usize,
        total_steps: usize,
    },
    /// Shrink the rate by `factor` (never below `min_lr`) whenever the
    /// monitored metric fails to improve for `patience` consecutive steps.
    /// Useful for online learning scenarios.
    ReduceOnPlateau {
        factor: f32,
        patience: usize,
        min_lr: f32,
    },
}
/// Drives the learning rate of GNN training according to a [`SchedulerType`].
///
/// Controlling the learning rate helps prevent catastrophic forgetting and
/// improves convergence in continual-learning settings.
#[derive(Debug, Clone)]
pub struct LearningRateScheduler {
    scheduler_type: SchedulerType,
    base_lr: f32,
    current_lr: f32,
    step_count: usize,
    best_metric: f32,
    patience_counter: usize,
}
impl LearningRateScheduler {
    /// Builds a scheduler that starts at `base_lr` and follows `scheduler_type`.
    ///
    /// # Example
    /// ```
    /// use ruvector_gnn::scheduler::{LearningRateScheduler, SchedulerType};
    ///
    /// let scheduler = LearningRateScheduler::new(
    ///     SchedulerType::StepDecay { step_size: 10, gamma: 0.9 },
    ///     0.001
    /// );
    /// ```
    pub fn new(scheduler_type: SchedulerType, base_lr: f32) -> Self {
        Self {
            scheduler_type,
            base_lr,
            current_lr: base_lr,
            step_count: 0,
            best_metric: f32::INFINITY,
            patience_counter: 0,
        }
    }
    /// Advances one step (typically one epoch) and returns the new rate.
    ///
    /// `ReduceOnPlateau` schedulers should use [`Self::step_with_metric`]
    /// instead, since they need a metric to react to.
    pub fn step(&mut self) -> f32 {
        self.step_count += 1;
        self.current_lr = self.calculate_lr();
        self.current_lr
    }
    /// Advances one step while reporting a monitored metric value
    /// (e.g. validation loss); required for `ReduceOnPlateau`.
    ///
    /// Returns the updated learning rate.
    pub fn step_with_metric(&mut self, metric: f32) -> f32 {
        self.step_count += 1;
        if let SchedulerType::ReduceOnPlateau {
            factor,
            patience,
            min_lr,
        } = &self.scheduler_type
        {
            if metric < self.best_metric - 1e-8 {
                // Improvement: remember it and restart the patience window.
                self.best_metric = metric;
                self.patience_counter = 0;
            } else {
                self.patience_counter += 1;
                if self.patience_counter >= *patience {
                    // Plateau confirmed: shrink the rate (respecting the
                    // floor) and start counting again.
                    self.current_lr = (self.current_lr * factor).max(*min_lr);
                    self.patience_counter = 0;
                }
            }
        } else {
            // Every other strategy ignores the metric and follows its curve.
            self.current_lr = self.calculate_lr();
        }
        self.current_lr
    }
    /// Returns the current learning rate without advancing the schedule.
    pub fn get_lr(&self) -> f32 {
        self.current_lr
    }
    /// Restores the scheduler to its freshly-constructed state.
    pub fn reset(&mut self) {
        self.current_lr = self.base_lr;
        self.step_count = 0;
        self.best_metric = f32::INFINITY;
        self.patience_counter = 0;
    }
    /// Evaluates the configured schedule at the current step count.
    fn calculate_lr(&self) -> f32 {
        match &self.scheduler_type {
            SchedulerType::Constant => self.base_lr,
            SchedulerType::StepDecay { step_size, gamma } => {
                // Integer division counts how many decay boundaries passed.
                self.base_lr * gamma.powi((self.step_count / step_size) as i32)
            }
            SchedulerType::Exponential { gamma } => {
                self.base_lr * gamma.powi(self.step_count as i32)
            }
            SchedulerType::CosineAnnealing { t_max, eta_min } => {
                // Position inside the current cycle; wrapping to 0 at t_max
                // produces the warm restart.
                let cycle_step = self.step_count % t_max;
                let cos_term = (PI * cycle_step as f32 / *t_max as f32).cos();
                eta_min + 0.5 * (self.base_lr - eta_min) * (1.0 + cos_term)
            }
            SchedulerType::WarmupLinear {
                warmup_steps,
                total_steps,
            } => {
                if self.step_count < *warmup_steps {
                    // Ramp up linearly from zero to base_lr.
                    self.base_lr * (self.step_count as f32 / *warmup_steps as f32)
                } else if self.step_count < *total_steps {
                    // Ramp back down linearly to zero.
                    let remaining = (*total_steps - self.step_count) as f32;
                    let span = (*total_steps - *warmup_steps) as f32;
                    self.base_lr * (remaining / span)
                } else {
                    // Past the schedule: stay at zero.
                    0.0
                }
            }
            // Plateau reductions happen in `step_with_metric`; the rate is
            // otherwise left untouched here.
            SchedulerType::ReduceOnPlateau { .. } => self.current_lr,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Behavioral coverage for every `SchedulerType` variant, plus reset and
    // clone semantics and numeric edge cases.
    const EPSILON: f32 = 1e-6;
    // Compare two learning rates to within EPSILON, labelling any failure.
    fn assert_close(a: f32, b: f32, msg: &str) {
        assert!((a - b).abs() < EPSILON, "{}: {} != {}", msg, a, b);
    }
    #[test]
    fn test_constant_scheduler() {
        let mut scheduler = LearningRateScheduler::new(SchedulerType::Constant, 0.01);
        assert_close(scheduler.get_lr(), 0.01, "Initial LR");
        for i in 1..=10 {
            let lr = scheduler.step();
            assert_close(lr, 0.01, &format!("Step {} LR", i));
        }
    }
    #[test]
    fn test_step_decay() {
        let mut scheduler = LearningRateScheduler::new(
            SchedulerType::StepDecay {
                step_size: 5,
                gamma: 0.5,
            },
            0.1,
        );
        assert_close(scheduler.get_lr(), 0.1, "Initial LR");
        // Steps 1-4: no decay
        for i in 1..=4 {
            let lr = scheduler.step();
            assert_close(lr, 0.1, &format!("Step {} LR", i));
        }
        // Step 5: first decay (0.1 * 0.5)
        let lr = scheduler.step();
        assert_close(lr, 0.05, "Step 5 LR (first decay)");
        // Steps 6-9: maintain decayed rate
        for i in 6..=9 {
            let lr = scheduler.step();
            assert_close(lr, 0.05, &format!("Step {} LR", i));
        }
        // Step 10: second decay (0.1 * 0.5^2)
        let lr = scheduler.step();
        assert_close(lr, 0.025, "Step 10 LR (second decay)");
    }
    #[test]
    fn test_exponential_decay() {
        let mut scheduler =
            LearningRateScheduler::new(SchedulerType::Exponential { gamma: 0.9 }, 0.1);
        assert_close(scheduler.get_lr(), 0.1, "Initial LR");
        let expected_lrs = vec![
            0.1 * 0.9,   // Step 1
            0.1 * 0.81,  // Step 2 (0.9^2)
            0.1 * 0.729, // Step 3 (0.9^3)
        ];
        for (i, expected) in expected_lrs.iter().enumerate() {
            let lr = scheduler.step();
            assert_close(lr, *expected, &format!("Step {} LR", i + 1));
        }
    }
    #[test]
    fn test_cosine_annealing() {
        let mut scheduler = LearningRateScheduler::new(
            SchedulerType::CosineAnnealing {
                t_max: 10,
                eta_min: 0.0,
            },
            1.0,
        );
        assert_close(scheduler.get_lr(), 1.0, "Initial LR");
        // Cosine annealing formula: lr = eta_min + 0.5 * (base_lr - eta_min) * (1 + cos(pi * cycle_step / t_max))
        // cycle_step = step_count % t_max
        // At step 5: cycle_step = 5, cos(pi * 5/10) = cos(pi/2) = 0, lr = 0 + 0.5 * 1 * (1 + 0) = 0.5
        // At step 10: cycle_step = 0 (wrapped), cos(0) = 1, lr = 0 + 0.5 * 1 * (1 + 1) = 1.0 (restart)
        for _ in 1..=5 {
            scheduler.step();
        }
        assert_close(scheduler.get_lr(), 0.5, "Mid-cycle LR (step 5)");
        // At step 9: cycle_step = 9, cos(pi * 9/10) ≈ -0.951, lr ≈ 0.025
        for _ in 6..=9 {
            scheduler.step();
        }
        let lr_step9 = scheduler.get_lr();
        assert!(
            lr_step9 < 0.1,
            "Near end of cycle LR (step 9) should be small: {}",
            lr_step9
        );
        // At step 10: warm restart (cycle_step = 0), LR goes back to base
        scheduler.step();
        assert_close(
            scheduler.get_lr(),
            1.0,
            "Restart at step 10 (cycle_step = 0)",
        );
        // Continue new cycle
        scheduler.step();
        assert!(
            scheduler.get_lr() < 1.0,
            "Step 11 should be less than base LR"
        );
    }
    #[test]
    fn test_warmup_linear() {
        let mut scheduler = LearningRateScheduler::new(
            SchedulerType::WarmupLinear {
                warmup_steps: 5,
                total_steps: 10,
            },
            1.0,
        );
        assert_close(scheduler.get_lr(), 1.0, "Initial LR");
        // Warmup phase: linear increase
        scheduler.step();
        assert_close(scheduler.get_lr(), 0.2, "Step 1 (warmup)");
        scheduler.step();
        assert_close(scheduler.get_lr(), 0.4, "Step 2 (warmup)");
        scheduler.step();
        assert_close(scheduler.get_lr(), 0.6, "Step 3 (warmup)");
        scheduler.step();
        assert_close(scheduler.get_lr(), 0.8, "Step 4 (warmup)");
        scheduler.step();
        assert_close(scheduler.get_lr(), 1.0, "Step 5 (warmup end)");
        // Decay phase: linear decrease
        scheduler.step();
        assert_close(scheduler.get_lr(), 0.8, "Step 6 (decay)");
        scheduler.step();
        assert_close(scheduler.get_lr(), 0.6, "Step 7 (decay)");
        scheduler.step();
        assert_close(scheduler.get_lr(), 0.4, "Step 8 (decay)");
        scheduler.step();
        assert_close(scheduler.get_lr(), 0.2, "Step 9 (decay)");
        scheduler.step();
        assert_close(scheduler.get_lr(), 0.0, "Step 10 (decay end)");
        // After total_steps
        scheduler.step();
        assert_close(scheduler.get_lr(), 0.0, "Step 11 (after total)");
    }
    #[test]
    fn test_reduce_on_plateau() {
        let mut scheduler = LearningRateScheduler::new(
            SchedulerType::ReduceOnPlateau {
                factor: 0.5,
                patience: 3,
                min_lr: 0.0001,
            },
            0.01,
        );
        assert_close(scheduler.get_lr(), 0.01, "Initial LR");
        // Improving metrics: no reduction (sets best_metric, resets patience)
        scheduler.step_with_metric(1.0);
        assert_close(
            scheduler.get_lr(),
            0.01,
            "Step 1 (first metric, sets baseline)",
        );
        scheduler.step_with_metric(0.9);
        assert_close(scheduler.get_lr(), 0.01, "Step 2 (improving)");
        // Plateau: metric not improving (patience counter: 1, 2, 3)
        scheduler.step_with_metric(0.91);
        assert_close(scheduler.get_lr(), 0.01, "Step 3 (plateau 1)");
        scheduler.step_with_metric(0.92);
        assert_close(scheduler.get_lr(), 0.01, "Step 4 (plateau 2)");
        // patience=3 means after 3 non-improvements, reduce LR
        // Step 5 is the 3rd non-improvement, so LR gets reduced
        scheduler.step_with_metric(0.93);
        assert_close(
            scheduler.get_lr(),
            0.005,
            "Step 5 (patience exceeded, reduced)",
        );
        // Counter is reset after reduction, so we need 3 more non-improvements
        scheduler.step_with_metric(0.94); // plateau 1 after reset
        assert_close(scheduler.get_lr(), 0.005, "Step 6 (plateau 1 after reset)");
        scheduler.step_with_metric(0.95); // plateau 2
        assert_close(scheduler.get_lr(), 0.005, "Step 7 (plateau 2)");
        scheduler.step_with_metric(0.96); // plateau 3 - triggers reduction
        assert_close(scheduler.get_lr(), 0.0025, "Step 8 (reduced again)");
        // Test min_lr floor
        for _ in 0..20 {
            scheduler.step_with_metric(1.0);
        }
        assert!(
            scheduler.get_lr() >= 0.0001,
            "LR should not go below min_lr"
        );
    }
    #[test]
    fn test_scheduler_reset() {
        let mut scheduler =
            LearningRateScheduler::new(SchedulerType::Exponential { gamma: 0.9 }, 0.1);
        // Run for several steps
        for _ in 0..5 {
            scheduler.step();
        }
        assert!(scheduler.get_lr() < 0.1, "LR should have decayed");
        // Reset and verify
        scheduler.reset();
        assert_close(scheduler.get_lr(), 0.1, "Reset LR");
        assert_eq!(scheduler.step_count, 0, "Reset step count");
    }
    #[test]
    fn test_scheduler_cloning() {
        let scheduler1 = LearningRateScheduler::new(
            SchedulerType::StepDecay {
                step_size: 10,
                gamma: 0.5,
            },
            0.01,
        );
        let mut scheduler2 = scheduler1.clone();
        // Advance clone
        scheduler2.step();
        // Original should be unchanged
        assert_close(scheduler1.get_lr(), 0.01, "Original LR");
        assert_close(scheduler2.get_lr(), 0.01, "Clone LR after step");
    }
    #[test]
    fn test_multiple_scheduler_types() {
        let schedulers = vec![
            (SchedulerType::Constant, 0.01),
            (
                SchedulerType::StepDecay {
                    step_size: 5,
                    gamma: 0.9,
                },
                0.01,
            ),
            (SchedulerType::Exponential { gamma: 0.95 }, 0.01),
            (
                SchedulerType::CosineAnnealing {
                    t_max: 10,
                    eta_min: 0.001,
                },
                0.01,
            ),
            (
                SchedulerType::WarmupLinear {
                    warmup_steps: 5,
                    total_steps: 20,
                },
                0.01,
            ),
            (
                SchedulerType::ReduceOnPlateau {
                    factor: 0.5,
                    patience: 5,
                    min_lr: 0.0001,
                },
                0.01,
            ),
        ];
        for (sched_type, base_lr) in schedulers {
            let mut scheduler = LearningRateScheduler::new(sched_type, base_lr);
            // All schedulers should start at base_lr
            assert_close(scheduler.get_lr(), base_lr, "Initial LR for scheduler type");
            // All schedulers should be able to step
            let _ = scheduler.step();
            assert!(scheduler.get_lr() >= 0.0, "LR should be non-negative");
        }
    }
    #[test]
    fn test_edge_cases() {
        // Zero learning rate
        let mut scheduler = LearningRateScheduler::new(SchedulerType::Constant, 0.0);
        assert_close(scheduler.get_lr(), 0.0, "Zero LR");
        scheduler.step();
        assert_close(scheduler.get_lr(), 0.0, "Zero LR after step");
        // Very small gamma
        let mut scheduler =
            LearningRateScheduler::new(SchedulerType::Exponential { gamma: 0.1 }, 1.0);
        for _ in 0..10 {
            scheduler.step();
        }
        assert!(scheduler.get_lr() > 0.0, "LR should remain positive");
        assert!(scheduler.get_lr() < 1e-8, "LR should be very small");
    }
}

View File

@@ -0,0 +1,247 @@
use crate::layer::RuvectorLayer;
/// Compute the cosine similarity between two equal-length vectors.
///
/// All accumulation (dot product and both norms) is done in `f64` before
/// the result is narrowed back to `f32`. The previous version only used an
/// `f64` accumulator for the norms, leaving the dot product exposed to
/// `f32` cancellation error despite the stated precision goal.
///
/// Returns 0.0 when either vector has zero norm (similarity is undefined).
///
/// # Panics
/// Panics if the vectors have different lengths.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Vectors must have the same length");
    // Single pass: accumulate dot product and squared norms together.
    let mut dot = 0.0f64;
    let mut norm_a_sq = 0.0f64;
    let mut norm_b_sq = 0.0f64;
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += (x as f64) * (y as f64);
        norm_a_sq += (x as f64) * (x as f64);
        norm_b_sq += (y as f64) * (y as f64);
    }
    if norm_a_sq == 0.0 || norm_b_sq == 0.0 {
        0.0
    } else {
        (dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt())) as f32
    }
}
/// Apply a temperature-scaled softmax to `values`.
///
/// The maximum value is subtracted before exponentiating for numerical
/// stability. `temperature` must be positive: small values sharpen the
/// distribution, large values flatten it.
fn softmax(values: &[f32], temperature: f32) -> Vec<f32> {
    if values.is_empty() {
        return Vec::new();
    }
    // Shift so the largest logit maps to exp(0) = 1, avoiding overflow.
    let peak = values.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut out: Vec<f32> = values
        .iter()
        .map(|&v| ((v - peak) / temperature).exp())
        .collect();
    // Guard the denominator against underflowing to zero.
    let total = out.iter().sum::<f32>().max(1e-10);
    for w in &mut out {
        *w /= total;
    }
    out
}
/// Differentiable top-k retrieval via a soft attention mechanism.
///
/// # Arguments
/// * `query` - The query vector
/// * `candidate_embeddings` - Candidate embedding vectors to score
/// * `k` - Number of top results to return (clamped to the candidate count)
/// * `temperature` - Softmax temperature (lower = sharper, higher = smoother)
///
/// # Returns
/// `(indices, soft_weights)` for the `k` highest-weighted candidates,
/// ordered by descending weight. Empty vectors if there are no candidates.
pub fn differentiable_search(
    query: &[f32],
    candidate_embeddings: &[Vec<f32>],
    k: usize,
    temperature: f32,
) -> (Vec<usize>, Vec<f32>) {
    if candidate_embeddings.is_empty() {
        return (Vec::new(), Vec::new());
    }
    let k = k.min(candidate_embeddings.len());
    // Score every candidate against the query, then soften the scores into
    // an attention distribution.
    let scores: Vec<f32> = candidate_embeddings
        .iter()
        .map(|emb| cosine_similarity(query, emb))
        .collect();
    let attention = softmax(&scores, temperature);
    // Rank candidates by attention weight (descending) and keep the top k.
    let mut ranked: Vec<(usize, f32)> = attention.into_iter().enumerate().collect();
    ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    ranked.truncate(k);
    let (indices, weights) = ranked.into_iter().unzip();
    (indices, weights)
}
/// Hierarchical forward pass through GNN layers.
///
/// At each level, the running embedding is matched against that level's node
/// embeddings with a differentiable search, the top hits are soft-aggregated,
/// the aggregate is averaged into the running embedding, and the result is
/// transformed by the corresponding GNN layer.
///
/// Fixes over the previous version: the unused `layer_idx` binding from
/// `.enumerate()` is gone, and `weights` is passed to `forward` directly
/// instead of through a redundant clone.
///
/// # Arguments
/// * `query` - The query vector
/// * `layer_embeddings` - Embeddings organized by layer (outer vec = layers, inner vec = nodes per layer)
/// * `gnn_layers` - The GNN layers to process through
///
/// # Returns
/// Final embedding after hierarchical processing; a copy of `query` if
/// either input is empty.
pub fn hierarchical_forward(
    query: &[f32],
    layer_embeddings: &[Vec<Vec<f32>>],
    gnn_layers: &[RuvectorLayer],
) -> Vec<f32> {
    if layer_embeddings.is_empty() || gnn_layers.is_empty() {
        return query.to_vec();
    }
    let mut current_embedding = query.to_vec();
    // Process each (embedding set, layer) pair from top to bottom.
    for (embeddings, gnn_layer) in layer_embeddings.iter().zip(gnn_layers.iter()) {
        if embeddings.is_empty() {
            continue;
        }
        // Find the most relevant nodes at this layer (top-5, default temperature).
        let (top_indices, weights) = differentiable_search(
            &current_embedding,
            embeddings,
            5.min(embeddings.len()),
            1.0,
        );
        // Soft-aggregate the selected node embeddings using their weights.
        let mut aggregated = vec![0.0; current_embedding.len()];
        for (&idx, &weight) in top_indices.iter().zip(weights.iter()) {
            for (i, &val) in embeddings[idx].iter().enumerate() {
                if i < aggregated.len() {
                    aggregated[i] += weight * val;
                }
            }
        }
        // Blend the aggregate with the current embedding (simple average).
        let combined: Vec<f32> = current_embedding
            .iter()
            .zip(&aggregated)
            .map(|(curr, agg)| (curr + agg) / 2.0)
            .collect();
        // Transform through the GNN layer, attending over the selected
        // neighbors with the soft weights acting as edge weights.
        let neighbor_embs: Vec<Vec<f32>> = top_indices
            .iter()
            .map(|&idx| embeddings[idx].clone())
            .collect();
        current_embedding = gnn_layer.forward(&combined, &neighbor_embs, &weights);
    }
    current_embedding
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_cosine_similarity() {
        // Identical unit vectors have similarity 1; orthogonal ones have 0.
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);
        let c = vec![1.0, 0.0, 0.0];
        let d = vec![0.0, 1.0, 0.0];
        assert!((cosine_similarity(&c, &d) - 0.0).abs() < 1e-6);
    }
    #[test]
    fn test_softmax() {
        let values = vec![1.0, 2.0, 3.0];
        let result = softmax(&values, 1.0);
        // Sum should be 1.0
        let sum: f32 = result.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
        // Higher values should have higher probabilities
        assert!(result[2] > result[1]);
        assert!(result[1] > result[0]);
    }
    #[test]
    fn test_softmax_with_temperature() {
        let values = vec![1.0, 2.0, 3.0];
        // Lower temperature = sharper distribution
        let sharp = softmax(&values, 0.1);
        let smooth = softmax(&values, 10.0);
        // Sharp should have more weight on max
        assert!(sharp[2] > smooth[2]);
    }
    #[test]
    fn test_differentiable_search() {
        let query = vec![1.0, 0.0, 0.0];
        let candidates = vec![
            vec![1.0, 0.0, 0.0], // Perfect match
            vec![0.9, 0.1, 0.0], // Close match
            vec![0.0, 1.0, 0.0], // Orthogonal
        ];
        let (indices, weights) = differentiable_search(&query, &candidates, 2, 1.0);
        assert_eq!(indices.len(), 2);
        assert_eq!(weights.len(), 2);
        // First result should be the perfect match
        assert_eq!(indices[0], 0);
        // Weights should sum to less than or equal to 1.0 (since we took top-k)
        let sum: f32 = weights.iter().sum();
        assert!(sum <= 1.0 + 1e-6);
    }
    #[test]
    fn test_hierarchical_forward() {
        // Use consistent dimensions throughout
        let query = vec![1.0, 0.0];
        // Layer embeddings should match the output dimensions of each layer
        let layer_embeddings = vec![
            // First layer: embeddings are 2-dimensional (match query)
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
        ];
        // Single GNN layer that maintains dimension
        let gnn_layers = vec![
            RuvectorLayer::new(2, 2, 1, 0.0).unwrap(), // input_dim, hidden_dim, heads, dropout
        ];
        let result = hierarchical_forward(&query, &layer_embeddings, &gnn_layers);
        assert_eq!(result.len(), 2); // Should match hidden_dim of last layer
    }
}

View File

@@ -0,0 +1,789 @@
//! Tensor operations for GNN computations.
//!
//! Provides efficient tensor operations including:
//! - Matrix multiplication
//! - Element-wise operations
//! - Activation functions
//! - Weight initialization
//! - Normalization
use crate::error::{GnnError, Result};
use rand::Rng;
use rand_distr::{Distribution, Normal, Uniform};
/// Basic tensor operations for GNN computations.
///
/// Elements are stored in a flat `Vec<f32>` in row-major order (2-D indexing
/// in `matmul` uses `data[i * cols + j]`); `data.len()` is expected to equal
/// the product of the dimensions in `shape`.
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor {
    /// Flattened tensor data (row-major)
    pub data: Vec<f32>,
    /// Shape of the tensor (dimensions)
    pub shape: Vec<usize>,
}
impl Tensor {
    /// Build a tensor from flat `data` and its `shape`.
    ///
    /// # Errors
    /// Returns `GnnError::InvalidShape` if `data.len()` is not the product
    /// of the dimensions in `shape`.
    pub fn new(data: Vec<f32>, shape: Vec<usize>) -> Result<Self> {
        let expected_len: usize = shape.iter().product();
        if data.len() == expected_len {
            Ok(Self { data, shape })
        } else {
            Err(GnnError::invalid_shape(format!(
                "Data length {} doesn't match shape {:?} (expected {})",
                data.len(),
                shape,
                expected_len
            )))
        }
    }
    /// Build a zero-filled tensor of the given `shape`.
    ///
    /// # Errors
    /// Returns `GnnError::InvalidShape` if `shape` is empty or any
    /// dimension is zero.
    pub fn zeros(shape: &[usize]) -> Result<Self> {
        if shape.is_empty() || shape.contains(&0) {
            return Err(GnnError::invalid_shape(format!(
                "Invalid shape: {:?}",
                shape
            )));
        }
        Ok(Self {
            data: vec![0.0; shape.iter().product()],
            shape: shape.to_vec(),
        })
    }
    /// Wrap a vector as a 1-D tensor.
    pub fn from_vec(data: Vec<f32>) -> Self {
        let shape = vec![data.len()];
        Self { data, shape }
    }
    /// Dot product of two 1-D tensors.
    ///
    /// # Errors
    /// Returns `GnnError::DimensionMismatch` unless both tensors are 1-D
    /// and of equal length.
    pub fn dot(&self, other: &Tensor) -> Result<f32> {
        match (&self.shape[..], &other.shape[..]) {
            // Both 1-D and equal length: elementwise multiply-accumulate.
            (&[a], &[b]) if a == b => Ok(self
                .data
                .iter()
                .zip(other.data.iter())
                .map(|(x, y)| x * y)
                .sum()),
            // Both 1-D but lengths differ.
            (&[a], &[b]) => Err(GnnError::dimension_mismatch(
                format!("length {}", a),
                format!("length {}", b),
            )),
            // At least one operand is not 1-D.
            _ => Err(GnnError::dimension_mismatch(
                "1D tensors",
                format!("{}D and {}D", self.shape.len(), other.shape.len()),
            )),
        }
    }
    /// Matrix multiplication supporting 1-D×1-D (dot product), 2-D×1-D
    /// (matrix-vector), and 2-D×2-D (matrix-matrix).
    ///
    /// # Errors
    /// Returns `GnnError::DimensionMismatch` for incompatible shapes or
    /// unsupported ranks.
    pub fn matmul(&self, other: &Tensor) -> Result<Tensor> {
        match (self.shape.len(), other.shape.len()) {
            (1, 1) => {
                // Dot product, wrapped in a length-1 tensor.
                Ok(Tensor::from_vec(vec![self.dot(other)?]))
            }
            (2, 1) => {
                let (m, n) = (self.shape[0], self.shape[1]);
                if n != other.shape[0] {
                    return Err(GnnError::dimension_mismatch(
                        format!("{}x{}", m, n),
                        format!("vector of length {}", other.shape[0]),
                    ));
                }
                // Each output element is one row of `self` dotted with `other`.
                let result: Vec<f32> = self
                    .data
                    .chunks(n)
                    .map(|row| row.iter().zip(other.data.iter()).map(|(a, b)| a * b).sum())
                    .collect();
                Ok(Tensor::from_vec(result))
            }
            (2, 2) => {
                let (m, n) = (self.shape[0], self.shape[1]);
                let p = other.shape[1];
                if n != other.shape[0] {
                    return Err(GnnError::dimension_mismatch(
                        format!("{}x{}", m, n),
                        format!("{}x{}", other.shape[0], p),
                    ));
                }
                let mut result = vec![0.0; m * p];
                for row in 0..m {
                    for col in 0..p {
                        // Accumulate in ascending k order (same order as a
                        // classic triple loop) to keep float results stable.
                        let mut acc = 0.0;
                        for k in 0..n {
                            acc += self.data[row * n + k] * other.data[k * p + col];
                        }
                        result[row * p + col] = acc;
                    }
                }
                Tensor::new(result, vec![m, p])
            }
            _ => Err(GnnError::dimension_mismatch(
                "1D or 2D tensors",
                format!("{}D and {}D", self.shape.len(), other.shape.len()),
            )),
        }
    }
    /// Element-wise addition of two equally-shaped tensors.
    ///
    /// # Errors
    /// Returns `GnnError::DimensionMismatch` if the shapes differ.
    pub fn add(&self, other: &Tensor) -> Result<Tensor> {
        if self.shape != other.shape {
            return Err(GnnError::dimension_mismatch(
                format!("{:?}", self.shape),
                format!("{:?}", other.shape),
            ));
        }
        let summed = self
            .data
            .iter()
            .zip(other.data.iter())
            .map(|(a, b)| a + b)
            .collect();
        Tensor::new(summed, self.shape.clone())
    }
    /// Multiply every element by `scalar`, preserving the shape.
    pub fn scale(&self, scalar: f32) -> Tensor {
        self.map_elems(|x| x * scalar)
    }
    /// Apply `f` to every element, preserving the shape.
    fn map_elems(&self, f: impl Fn(f32) -> f32) -> Tensor {
        Tensor {
            data: self.data.iter().map(|&x| f(x)).collect(),
            shape: self.shape.clone(),
        }
    }
    /// ReLU activation: `max(0, x)` element-wise.
    pub fn relu(&self) -> Tensor {
        self.map_elems(|x| x.max(0.0))
    }
    /// Numerically stable sigmoid: `1 / (1 + e^(-x))` element-wise.
    pub fn sigmoid(&self) -> Tensor {
        self.map_elems(|x| {
            // Choose the branch that keeps the exponent non-positive so
            // `exp` cannot overflow for large |x|.
            if x > 0.0 {
                1.0 / (1.0 + (-x).exp())
            } else {
                let ex = x.exp();
                ex / (1.0 + ex)
            }
        })
    }
    /// Hyperbolic tangent element-wise.
    pub fn tanh(&self) -> Tensor {
        self.map_elems(f32::tanh)
    }
    /// Euclidean (L2) norm, accumulated in `f64` for precision.
    pub fn l2_norm(&self) -> f32 {
        let sum_squares: f64 = self.data.iter().map(|&x| (x as f64) * (x as f64)).sum();
        sum_squares.sqrt() as f32
    }
    /// Scale the tensor to unit L2 norm.
    ///
    /// # Errors
    /// Returns `GnnError::InvalidInput` for the zero tensor.
    pub fn normalize(&self) -> Result<Tensor> {
        match self.l2_norm() {
            norm if norm == 0.0 => Err(GnnError::invalid_input(
                "Cannot normalize zero vector".to_string(),
            )),
            norm => Ok(self.scale(1.0 / norm)),
        }
    }
    /// Borrow the underlying flat data.
    pub fn as_slice(&self) -> &[f32] {
        &self.data
    }
    /// Consume the tensor, yielding its flat data.
    pub fn into_vec(self) -> Vec<f32> {
        self.data
    }
    /// Number of elements in the tensor.
    pub fn len(&self) -> usize {
        self.data.len()
    }
    /// Whether the tensor holds no elements.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }
}
/// Xavier/Glorot uniform initialization for neural network weights.
///
/// Draws `fan_in * fan_out` samples from U(-a, a) with
/// `a = sqrt(6 / (fan_in + fan_out))`.
///
/// # Arguments
/// * `fan_in` - Number of input units
/// * `fan_out` - Number of output units
///
/// # Panics
/// Panics if `fan_in` or `fan_out` is 0.
pub fn xavier_init(fan_in: usize, fan_out: usize) -> Vec<f32> {
    assert!(
        fan_in > 0 && fan_out > 0,
        "fan_in and fan_out must be positive"
    );
    let bound = (6.0 / (fan_in + fan_out) as f32).sqrt();
    let dist = Uniform::new(-bound, bound);
    let mut rng = rand::thread_rng();
    std::iter::repeat_with(|| dist.sample(&mut rng))
        .take(fan_in * fan_out)
        .collect()
}
/// He (Kaiming) initialization for ReLU networks.
///
/// Draws `fan_in` samples from N(0, sqrt(2 / fan_in)).
///
/// # Arguments
/// * `fan_in` - Number of input units
///
/// # Panics
/// Panics if `fan_in` is 0.
pub fn he_init(fan_in: usize) -> Vec<f32> {
    assert!(fan_in > 0, "fan_in must be positive");
    let sigma = (2.0 / fan_in as f32).sqrt();
    let gaussian = Normal::new(0.0, sigma).expect("Invalid normal distribution parameters");
    let mut rng = rand::thread_rng();
    std::iter::repeat_with(|| gaussian.sample(&mut rng))
        .take(fan_in)
        .collect()
}
/// Element-wise (Hadamard) product of two equal-length vectors.
///
/// # Panics
/// Panics if the vectors have different lengths.
pub fn hadamard_product(a: &[f32], b: &[f32]) -> Vec<f32> {
    assert_eq!(a.len(), b.len(), "Vectors must have the same length");
    a.iter().zip(b).map(|(&x, &y)| x * y).collect()
}
/// Element-wise sum of two equal-length vectors.
///
/// # Panics
/// Panics if the vectors have different lengths.
pub fn vector_add(a: &[f32], b: &[f32]) -> Vec<f32> {
    assert_eq!(a.len(), b.len(), "Vectors must have the same length");
    a.iter().zip(b).map(|(&x, &y)| x + y).collect()
}
/// Multiply every element of `v` by `scalar`.
pub fn vector_scale(v: &[f32], scalar: f32) -> Vec<f32> {
    v.iter().map(|&x| scalar * x).collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    const EPSILON: f32 = 1e-6;

    /// Fails unless `a` and `b` have the same length and agree element-wise
    /// within `epsilon`.
    fn assert_vec_approx_eq(a: &[f32], b: &[f32], epsilon: f32) {
        assert_eq!(a.len(), b.len(), "Vectors have different lengths");
        for i in 0..a.len() {
            let diff = (a[i] - b[i]).abs();
            assert!(
                diff < epsilon,
                "Values at index {} differ: {} vs {} (diff: {})",
                i,
                a[i],
                b[i],
                diff
            );
        }
    }

    #[test]
    fn test_tensor_new() {
        let vals = vec![1.0, 2.0, 3.0, 4.0];
        let t = Tensor::new(vals.clone(), vec![2, 2]).unwrap();
        assert_eq!(t.data, vals);
        assert_eq!(t.shape, vec![2, 2]);
    }

    #[test]
    fn test_tensor_new_invalid_shape() {
        // Three elements cannot fill a 2x2 shape.
        let vals = vec![1.0, 2.0, 3.0];
        assert!(Tensor::new(vals, vec![2, 2]).is_err());
    }

    #[test]
    fn test_tensor_zeros() {
        let t = Tensor::zeros(&[3, 2]).unwrap();
        assert_eq!(t.data, vec![0.0; 6]);
        assert_eq!(t.shape, vec![3, 2]);
    }

    #[test]
    fn test_tensor_zeros_invalid_shape() {
        // A zero-sized dimension and an empty shape are both rejected.
        assert!(Tensor::zeros(&[0, 2]).is_err());
        assert!(Tensor::zeros(&[]).is_err());
    }

    #[test]
    fn test_tensor_from_vec() {
        let vals = vec![1.0, 2.0, 3.0];
        let t = Tensor::from_vec(vals.clone());
        assert_eq!(t.data, vals);
        assert_eq!(t.shape, vec![3]);
    }

    #[test]
    fn test_dot_product() {
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Tensor::from_vec(vec![4.0, 5.0, 6.0]);
        let got = a.dot(&b).unwrap();
        // 1*4 + 2*5 + 3*6 = 32
        assert!((got - 32.0).abs() < EPSILON);
    }

    #[test]
    fn test_dot_product_dimension_mismatch() {
        let a = Tensor::from_vec(vec![1.0, 2.0]);
        let b = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        assert!(a.dot(&b).is_err());
    }

    #[test]
    fn test_matmul_1d() {
        // 1-D x 1-D behaves as a dot product wrapped in a length-1 tensor.
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Tensor::from_vec(vec![4.0, 5.0, 6.0]);
        let out = a.matmul(&b).unwrap();
        assert_eq!(out.shape, vec![1]);
        assert!((out.data[0] - 32.0).abs() < EPSILON);
    }

    #[test]
    fn test_matmul_2d_1d() {
        // Matrix-vector product.
        let mat = Tensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let v = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        let out = mat.matmul(&v).unwrap();
        assert_eq!(out.shape, vec![2]);
        // Row dot products: [1,2,3]·[1,2,3] = 14 and [4,5,6]·[1,2,3] = 32.
        assert_vec_approx_eq(&out.data, &[14.0, 32.0], EPSILON);
    }

    #[test]
    fn test_matmul_2d_2d() {
        // Matrix-matrix product.
        let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let b = Tensor::new(vec![5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
        let out = a.matmul(&b).unwrap();
        assert_eq!(out.shape, vec![2, 2]);
        // [[1,2],[3,4]] x [[5,6],[7,8]] = [[19,22],[43,50]]
        assert_vec_approx_eq(&out.data, &[19.0, 22.0, 43.0, 50.0], EPSILON);
    }

    #[test]
    fn test_matmul_dimension_mismatch() {
        // 2x2 matrix against a length-3 vector is ill-formed.
        let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let b = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        assert!(a.matmul(&b).is_err());
    }

    #[test]
    fn test_add() {
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Tensor::from_vec(vec![4.0, 5.0, 6.0]);
        let out = a.add(&b).unwrap();
        assert_eq!(out.data, vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_add_dimension_mismatch() {
        let a = Tensor::from_vec(vec![1.0, 2.0]);
        let b = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        assert!(a.add(&b).is_err());
    }

    #[test]
    fn test_scale() {
        let t = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        assert_eq!(t.scale(2.0).data, vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_relu() {
        let t = Tensor::from_vec(vec![-1.0, 0.0, 1.0, 2.0]);
        assert_eq!(t.relu().data, vec![0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_sigmoid() {
        let t = Tensor::from_vec(vec![0.0, 1.0, -1.0]);
        let out = t.sigmoid();
        // sigmoid(0) = 0.5; sigmoid(±1) from the logistic curve.
        assert!((out.data[0] - 0.5).abs() < EPSILON);
        assert!((out.data[1] - 0.7310586).abs() < EPSILON);
        assert!((out.data[2] - 0.26894143).abs() < EPSILON);
    }

    #[test]
    fn test_tanh() {
        let t = Tensor::from_vec(vec![0.0, 1.0, -1.0]);
        let out = t.tanh();
        assert!((out.data[0] - 0.0).abs() < EPSILON);
        assert!((out.data[1] - 0.7615942).abs() < EPSILON);
        assert!((out.data[2] - (-0.7615942)).abs() < EPSILON);
    }

    #[test]
    fn test_l2_norm() {
        // 3-4-5 triangle.
        let t = Tensor::from_vec(vec![3.0, 4.0]);
        assert!((t.l2_norm() - 5.0).abs() < EPSILON);
    }

    #[test]
    fn test_normalize() {
        let t = Tensor::from_vec(vec![3.0, 4.0]);
        let out = t.normalize().unwrap();
        assert_vec_approx_eq(&out.data, &[0.6, 0.8], EPSILON);
        assert!((out.l2_norm() - 1.0).abs() < EPSILON);
    }

    #[test]
    fn test_normalize_zero_vector() {
        // Zero norm cannot be normalized.
        let t = Tensor::from_vec(vec![0.0, 0.0]);
        assert!(t.normalize().is_err());
    }

    #[test]
    fn test_as_slice() {
        let vals = vec![1.0, 2.0, 3.0];
        let t = Tensor::from_vec(vals.clone());
        assert_eq!(t.as_slice(), &vals[..]);
    }

    #[test]
    fn test_into_vec() {
        let vals = vec![1.0, 2.0, 3.0];
        let t = Tensor::from_vec(vals.clone());
        assert_eq!(t.into_vec(), vals);
    }

    #[test]
    fn test_len() {
        assert_eq!(Tensor::from_vec(vec![1.0, 2.0, 3.0]).len(), 3);
    }

    #[test]
    fn test_is_empty() {
        assert!(Tensor::from_vec(vec![]).is_empty());
        assert!(!Tensor::from_vec(vec![1.0]).is_empty());
    }

    #[test]
    fn test_xavier_init() {
        let weights = xavier_init(100, 50);
        assert_eq!(weights.len(), 5000);
        // Every weight must fall inside the Glorot uniform bound
        // sqrt(6 / (fan_in + fan_out)).
        let limit = (6.0 / 150.0_f32).sqrt();
        assert!(weights.iter().all(|&w| w >= -limit && w <= limit));
        // The sample mean should be near zero for a symmetric distribution.
        let mean: f32 = weights.iter().sum::<f32>() / weights.len() as f32;
        assert!(mean.abs() < 0.1);
    }

    #[test]
    #[should_panic(expected = "fan_in and fan_out must be positive")]
    fn test_xavier_init_zero_fan() {
        xavier_init(0, 10);
    }

    #[test]
    fn test_he_init() {
        let weights = he_init(100);
        assert_eq!(weights.len(), 100);
        // The sample mean should be near zero.
        let mean: f32 = weights.iter().sum::<f32>() / weights.len() as f32;
        assert!(mean.abs() < 0.2);
    }

    #[test]
    #[should_panic(expected = "fan_in must be positive")]
    fn test_he_init_zero_fan() {
        he_init(0);
    }

    #[test]
    fn test_hadamard_product() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        assert_eq!(hadamard_product(&a, &b), vec![4.0, 10.0, 18.0]);
    }

    #[test]
    #[should_panic(expected = "Vectors must have the same length")]
    fn test_hadamard_product_length_mismatch() {
        hadamard_product(&[1.0, 2.0], &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_vector_add() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        assert_eq!(vector_add(&a, &b), vec![5.0, 7.0, 9.0]);
    }

    #[test]
    #[should_panic(expected = "Vectors must have the same length")]
    fn test_vector_add_length_mismatch() {
        vector_add(&[1.0, 2.0], &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_vector_scale() {
        let out = vector_scale(&[1.0, 2.0, 3.0], 2.5);
        assert_vec_approx_eq(&out, &[2.5, 5.0, 7.5], EPSILON);
    }

    #[test]
    fn test_complex_operations() {
        // Chain add -> scale -> relu -> normalize and verify the unit norm.
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Tensor::from_vec(vec![0.5, 1.0, 1.5]);
        let normalized = a.add(&b).unwrap().scale(2.0).relu().normalize().unwrap();
        assert!((normalized.l2_norm() - 1.0).abs() < EPSILON);
    }

    #[test]
    fn test_edge_case_single_element() {
        let t = Tensor::from_vec(vec![5.0]);
        assert_eq!(t.len(), 1);
        assert_eq!(t.l2_norm(), 5.0);
        let normalized = t.normalize().unwrap();
        assert_vec_approx_eq(&normalized.data, &[1.0], EPSILON);
    }

    #[test]
    fn test_edge_case_negative_values() {
        let t = Tensor::from_vec(vec![-3.0, -4.0]);
        assert!((t.l2_norm() - 5.0).abs() < EPSILON);
        // ReLU zeroes out strictly negative entries.
        assert_eq!(t.relu().data, vec![0.0, 0.0]);
    }

    #[test]
    fn test_large_matrix_multiplication() {
        // Smoke-test a 10x10 product: only shape and length are checked.
        let size = 10;
        let a_vals: Vec<f32> = (0..size * size).map(|i| i as f32).collect();
        let b_vals: Vec<f32> = (0..size * size).map(|i| (i % 2) as f32).collect();
        let a = Tensor::new(a_vals, vec![size, size]).unwrap();
        let b = Tensor::new(b_vals, vec![size, size]).unwrap();
        let out = a.matmul(&b).unwrap();
        assert_eq!(out.shape, vec![size, size]);
        assert_eq!(out.len(), size * size);
    }

    #[test]
    fn test_activation_functions_range() {
        let t = Tensor::from_vec(vec![-10.0, -1.0, 0.0, 1.0, 10.0]);
        // Sigmoid maps into the open interval (0, 1).
        assert!(t.sigmoid().data.iter().all(|&v| v > 0.0 && v < 1.0));
        // Tanh maps into the closed interval [-1, 1].
        assert!(t.tanh().data.iter().all(|&v| v >= -1.0 && v <= 1.0));
        // ReLU output is never negative.
        assert!(t.relu().data.iter().all(|&v| v >= 0.0));
    }
}

File diff suppressed because it is too large Load Diff