Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
947
vendor/ruvector/crates/ruvector-gnn/src/cold_tier.rs
vendored
Normal file
947
vendor/ruvector/crates/ruvector-gnn/src/cold_tier.rs
vendored
Normal file
@@ -0,0 +1,947 @@
|
||||
//! Cold-tier GNN training via hyperbatch I/O for graphs exceeding RAM.
|
||||
//!
|
||||
//! Implements AGNES-style block-aligned I/O with hotset caching
|
||||
//! for training on large-scale graphs that don't fit in memory.
|
||||
|
||||
#![cfg(all(feature = "cold-tier", not(target_arch = "wasm32")))]
|
||||
|
||||
use crate::error::{GnnError, Result};
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use std::fs::{File, OpenOptions};
|
||||
use std::io::{Read, Seek, SeekFrom, Write};
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
/// Size of an f32 in bytes.
|
||||
const F32_SIZE: usize = std::mem::size_of::<f32>();
|
||||
|
||||
/// Header size in bytes: dim (u64) + num_nodes (u64) + block_size (u64).
|
||||
const HEADER_SIZE: u64 = 24;
|
||||
|
||||
/// Return the system memory page size in bytes.
///
/// Delegates to the `page_size` crate.
/// NOTE(review): the original comment claimed a 4096 fallback, but no
/// fallback logic is visible here — whatever `page_size::get()` returns
/// is used directly. Confirm against the crate's documentation.
fn system_page_size() -> usize {
    page_size::get()
}
|
||||
|
||||
/// Align `value` up to the nearest multiple of `alignment`.
///
/// `alignment` must be non-zero (callers pass the system page size).
/// Computed via the remainder so the intermediate `value + alignment - 1`
/// of the naive formula cannot overflow for values near `usize::MAX`.
fn align_up(value: usize, alignment: usize) -> usize {
    let rem = value % alignment;
    if rem == 0 {
        value
    } else {
        value - rem + alignment
    }
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// FeatureStorage
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Block-aligned feature file for storing node feature vectors on disk.
///
/// On-disk layout: a 24-byte header (`dim`, `num_nodes`, `block_size` as
/// little-endian u64), followed by one page-aligned block per node; only
/// the first `dim * 4` bytes of each block carry data.
pub struct FeatureStorage {
    // Location of the backing file on disk.
    path: PathBuf,
    // Number of f32 values per node record.
    dim: usize,
    // Total node records the file was sized for.
    num_nodes: usize,
    // Bytes per on-disk record, rounded up to the page size.
    block_size: usize,
    // Open read/write handle; `None` means the storage is closed.
    file: Option<File>,
}
|
||||
|
||||
impl FeatureStorage {
|
||||
/// Create a new feature file at `path` for `num_nodes` with dimension `dim`.
|
||||
pub fn create(path: &Path, dim: usize, num_nodes: usize) -> Result<Self> {
|
||||
if dim == 0 {
|
||||
return Err(GnnError::invalid_input("dim must be > 0"));
|
||||
}
|
||||
let block_size = align_up(dim * F32_SIZE, system_page_size());
|
||||
let data_size = num_nodes as u64 * block_size as u64;
|
||||
|
||||
let mut file = OpenOptions::new()
|
||||
.read(true)
|
||||
.write(true)
|
||||
.create(true)
|
||||
.truncate(true)
|
||||
.open(path)
|
||||
.map_err(|e| GnnError::Io(e))?;
|
||||
|
||||
// Write header
|
||||
file.write_all(&(dim as u64).to_le_bytes())?;
|
||||
file.write_all(&(num_nodes as u64).to_le_bytes())?;
|
||||
file.write_all(&(block_size as u64).to_le_bytes())?;
|
||||
|
||||
// Extend file to full size
|
||||
file.set_len(HEADER_SIZE + data_size)?;
|
||||
|
||||
Ok(Self {
|
||||
path: path.to_path_buf(),
|
||||
dim,
|
||||
num_nodes,
|
||||
block_size,
|
||||
file: Some(file),
|
||||
})
|
||||
}
|
||||
|
||||
/// Open an existing feature file.
|
||||
pub fn open(path: &Path) -> Result<Self> {
|
||||
let mut file = OpenOptions::new()
|
||||
.read(true)
|
||||
.write(true)
|
||||
.open(path)
|
||||
.map_err(|e| GnnError::Io(e))?;
|
||||
|
||||
let mut buf = [0u8; 8];
|
||||
file.read_exact(&mut buf)?;
|
||||
let dim = u64::from_le_bytes(buf) as usize;
|
||||
file.read_exact(&mut buf)?;
|
||||
let num_nodes = u64::from_le_bytes(buf) as usize;
|
||||
file.read_exact(&mut buf)?;
|
||||
let block_size = u64::from_le_bytes(buf) as usize;
|
||||
|
||||
Ok(Self {
|
||||
path: path.to_path_buf(),
|
||||
dim,
|
||||
num_nodes,
|
||||
block_size,
|
||||
file: Some(file),
|
||||
})
|
||||
}
|
||||
|
||||
/// Write feature vector for a single node.
|
||||
pub fn write_features(&mut self, node_id: usize, features: &[f32]) -> Result<()> {
|
||||
if node_id >= self.num_nodes {
|
||||
return Err(GnnError::invalid_input(format!(
|
||||
"node_id {} out of bounds (num_nodes={})",
|
||||
node_id, self.num_nodes
|
||||
)));
|
||||
}
|
||||
if features.len() != self.dim {
|
||||
return Err(GnnError::dimension_mismatch(
|
||||
self.dim.to_string(),
|
||||
features.len().to_string(),
|
||||
));
|
||||
}
|
||||
let file = self
|
||||
.file
|
||||
.as_mut()
|
||||
.ok_or_else(|| GnnError::other("file not open"))?;
|
||||
let offset = HEADER_SIZE + (node_id as u64) * (self.block_size as u64);
|
||||
file.seek(SeekFrom::Start(offset))?;
|
||||
let bytes: &[u8] = unsafe {
|
||||
std::slice::from_raw_parts(features.as_ptr() as *const u8, features.len() * F32_SIZE)
|
||||
};
|
||||
file.write_all(bytes)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Read feature vector for a single node.
|
||||
pub fn read_features(&mut self, node_id: usize) -> Result<Vec<f32>> {
|
||||
if node_id >= self.num_nodes {
|
||||
return Err(GnnError::invalid_input(format!(
|
||||
"node_id {} out of bounds (num_nodes={})",
|
||||
node_id, self.num_nodes
|
||||
)));
|
||||
}
|
||||
let file = self
|
||||
.file
|
||||
.as_mut()
|
||||
.ok_or_else(|| GnnError::other("file not open"))?;
|
||||
let offset = HEADER_SIZE + (node_id as u64) * (self.block_size as u64);
|
||||
file.seek(SeekFrom::Start(offset))?;
|
||||
let mut buf = vec![0u8; self.dim * F32_SIZE];
|
||||
file.read_exact(&mut buf)?;
|
||||
let features: Vec<f32> = buf
|
||||
.chunks_exact(F32_SIZE)
|
||||
.map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
|
||||
.collect();
|
||||
Ok(features)
|
||||
}
|
||||
|
||||
/// Batch-read features for multiple nodes with block-aligned I/O.
|
||||
pub fn read_batch(&mut self, node_ids: &[usize]) -> Result<Vec<Vec<f32>>> {
|
||||
let mut results = Vec::with_capacity(node_ids.len());
|
||||
// Sort node_ids to improve sequential I/O locality
|
||||
let mut sorted: Vec<usize> = node_ids.to_vec();
|
||||
sorted.sort_unstable();
|
||||
// Read in sorted order, then reorder to match input
|
||||
let mut map: HashMap<usize, Vec<f32>> = HashMap::with_capacity(sorted.len());
|
||||
for &nid in &sorted {
|
||||
if !map.contains_key(&nid) {
|
||||
map.insert(nid, self.read_features(nid)?);
|
||||
}
|
||||
}
|
||||
for &nid in node_ids {
|
||||
results.push(map[&nid].clone());
|
||||
}
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
/// Flush pending writes to disk.
|
||||
pub fn flush(&mut self) -> Result<()> {
|
||||
if let Some(ref mut f) = self.file {
|
||||
f.flush()?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Dimension (f32 count) of each feature vector.
pub fn dim(&self) -> usize {
    self.dim
}

/// Number of node records the storage was sized for.
pub fn num_nodes(&self) -> usize {
    self.num_nodes
}

/// Path to the underlying file on disk.
pub fn path(&self) -> &Path {
    &self.path
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HyperbatchConfig / HyperbatchResult
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Configuration for hyperbatch I/O.
#[derive(Debug, Clone)]
pub struct HyperbatchConfig {
    /// Nodes per hyperbatch (default: 4096).
    pub batch_size: usize,
    /// Prefetch multiplier (default: 2).
    pub prefetch_factor: usize,
    /// I/O block alignment in bytes (default: 4096).
    pub block_align: usize,
    /// Double-buffering count (default: 2).
    pub num_buffers: usize,
    /// Fraction of nodes kept in the hotset (default: 0.05).
    pub hotset_fraction: f64,
}

impl Default for HyperbatchConfig {
    /// Defaults tuned for 4 KiB (page-sized) blocks with double buffering
    /// and a 5% hotset.
    fn default() -> Self {
        HyperbatchConfig {
            hotset_fraction: 0.05,
            num_buffers: 2,
            block_align: 4096,
            prefetch_factor: 2,
            batch_size: 4096,
        }
    }
}
|
||||
|
||||
/// Result from a single hyperbatch iteration.
#[derive(Debug, Clone)]
pub struct HyperbatchResult {
    /// Node identifiers in this batch, in traversal order.
    pub node_ids: Vec<usize>,
    /// Feature vectors, parallel to `node_ids`.
    pub features: Vec<Vec<f32>>,
    /// Zero-based index of this batch within the epoch.
    pub batch_index: usize,
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HyperbatchIterator
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Yields batches from disk following BFS vertex ordering for I/O locality.
pub struct HyperbatchIterator {
    // Disk-backed feature source owned by this iterator.
    storage: FeatureStorage,
    // Batch sizing / buffering parameters.
    config: HyperbatchConfig,
    // BFS-ordered node ids visited over one epoch.
    node_order: Vec<usize>,
    // Index into `node_order` where the next batch starts.
    current_offset: usize,
    // Round-robin feature buffers; written by `next_batch` but not yet
    // read back anywhere in this file.
    buffers: Vec<Vec<Vec<f32>>>,
    // Counter selecting which buffer slot receives the next batch.
    active_buffer: usize,
    // Zero-based index of the next batch within the epoch.
    batch_counter: usize,
}
|
||||
|
||||
impl HyperbatchIterator {
|
||||
/// Create a new iterator with BFS-ordered node traversal.
|
||||
pub fn new(
|
||||
storage: FeatureStorage,
|
||||
config: HyperbatchConfig,
|
||||
adjacency: &[(usize, usize)],
|
||||
) -> Self {
|
||||
let num_nodes = storage.num_nodes();
|
||||
let node_order = Self::reorder_bfs(adjacency, num_nodes);
|
||||
let num_buffers = config.num_buffers.max(1);
|
||||
let buffers = vec![Vec::new(); num_buffers];
|
||||
Self {
|
||||
storage,
|
||||
config,
|
||||
node_order,
|
||||
current_offset: 0,
|
||||
buffers,
|
||||
active_buffer: 0,
|
||||
batch_counter: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the next batch, or `None` when the epoch is complete.
///
/// NOTE(review): I/O errors from `read_batch` are converted to `None`
/// via `.ok()?`, so a failed read is indistinguishable from end-of-epoch
/// for callers — confirm this is intentional.
pub fn next_batch(&mut self) -> Option<HyperbatchResult> {
    if self.current_offset >= self.node_order.len() {
        return None;
    }
    // Clamp the final batch to the remaining nodes.
    let end = (self.current_offset + self.config.batch_size).min(self.node_order.len());
    let node_ids: Vec<usize> = self.node_order[self.current_offset..end].to_vec();
    let features = self.storage.read_batch(&node_ids).ok()?;

    // Store in active buffer for potential re-use
    // (round-robin over the buffer slots; nothing reads these back yet).
    let buf_idx = self.active_buffer % self.buffers.len();
    self.buffers[buf_idx] = features.clone();
    self.active_buffer += 1;

    let batch_index = self.batch_counter;
    self.batch_counter += 1;
    self.current_offset = end;

    Some(HyperbatchResult {
        node_ids,
        features,
        batch_index,
    })
}
|
||||
|
||||
/// Reset the iterator to the beginning of the epoch.
|
||||
pub fn reset(&mut self) {
|
||||
self.current_offset = 0;
|
||||
self.batch_counter = 0;
|
||||
self.active_buffer = 0;
|
||||
}
|
||||
|
||||
/// Produce a BFS vertex ordering for better I/O locality.
///
/// Edges are treated as undirected; edges referencing out-of-range node
/// ids are ignored. Disconnected components are each traversed starting
/// from their lowest-numbered unvisited node, so every node in
/// `0..num_nodes` appears exactly once.
pub fn reorder_bfs(adjacency: &[(usize, usize)], num_nodes: usize) -> Vec<usize> {
    if num_nodes == 0 {
        return Vec::new();
    }
    // Undirected adjacency list; drop edges touching invalid node ids.
    let mut neighbors: Vec<Vec<usize>> = vec![Vec::new(); num_nodes];
    for &(a, b) in adjacency.iter().filter(|&&(a, b)| a < num_nodes && b < num_nodes) {
        neighbors[a].push(b);
        neighbors[b].push(a);
    }

    let mut seen = vec![false; num_nodes];
    let mut order = Vec::with_capacity(num_nodes);
    let mut frontier = VecDeque::new();

    // Sweep all start points so disconnected components are covered.
    for root in 0..num_nodes {
        if seen[root] {
            continue;
        }
        seen[root] = true;
        frontier.push_back(root);
        while let Some(current) = frontier.pop_front() {
            order.push(current);
            for &next in &neighbors[current] {
                if !seen[next] {
                    seen[next] = true;
                    frontier.push_back(next);
                }
            }
        }
    }
    order
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// AdaptiveHotset
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// In-memory cache of frequently accessed node features.
///
/// Eviction is frequency-based: when full, the cached node with the
/// lowest access count is dropped. `decay_counts` ages counts so stale
/// entries eventually become eviction candidates.
pub struct AdaptiveHotset {
    // Cached feature vectors keyed by node id.
    features: HashMap<usize, Vec<f32>>,
    // Access frequency per node id; `record_access` may also track
    // nodes that are not currently cached.
    access_counts: HashMap<usize, u64>,
    // Maximum number of cached nodes.
    capacity: usize,
    // Multiplier applied to every count by `decay_counts`.
    decay_factor: f64,
    // Lifetime lookup counters backing `hit_rate`.
    total_lookups: u64,
    hits: u64,
}

impl AdaptiveHotset {
    /// Create a new hotset with the given capacity and decay factor.
    pub fn new(capacity: usize, decay_factor: f64) -> Self {
        Self {
            features: HashMap::with_capacity(capacity),
            access_counts: HashMap::with_capacity(capacity),
            capacity,
            decay_factor,
            total_lookups: 0,
            hits: 0,
        }
    }

    /// Look up cached features, updating hit statistics and access counts.
    ///
    /// Uses a single map probe; the previous version did
    /// `contains_key` + `get(..).unwrap()`, hashing the key twice.
    pub fn get(&mut self, node_id: usize) -> Option<&[f32]> {
        self.total_lookups += 1;
        if let Some(feats) = self.features.get(&node_id) {
            self.hits += 1;
            *self.access_counts.entry(node_id).or_insert(0) += 1;
            Some(feats.as_slice())
        } else {
            None
        }
    }

    /// Insert features, evicting the coldest entry if at capacity.
    pub fn insert(&mut self, node_id: usize, features: Vec<f32>) {
        // Evict only when adding a brand-new key would exceed capacity;
        // overwriting an existing key never triggers eviction.
        if self.features.len() >= self.capacity && !self.features.contains_key(&node_id) {
            self.evict_cold();
        }
        self.access_counts.entry(node_id).or_insert(0);
        self.features.insert(node_id, features);
    }

    /// Record an access without returning features (for tracking frequency).
    pub fn record_access(&mut self, node_id: usize) {
        *self.access_counts.entry(node_id).or_insert(0) += 1;
    }

    /// Evict the cached node with the lowest access count, if any.
    pub fn evict_cold(&mut self) {
        if self.access_counts.is_empty() {
            return;
        }
        // Scan only cached keys; access_counts may track uncached nodes too.
        let coldest = self
            .features
            .keys()
            .min_by_key(|nid| self.access_counts.get(nid).copied().unwrap_or(0))
            .copied();
        if let Some(nid) = coldest {
            self.features.remove(&nid);
            self.access_counts.remove(&nid);
        }
    }

    /// Cache hit rate since creation (0.0 before any lookup).
    pub fn hit_rate(&self) -> f64 {
        if self.total_lookups == 0 {
            return 0.0;
        }
        self.hits as f64 / self.total_lookups as f64
    }

    /// Multiply all access counts by `decay_factor` to age out stale entries.
    pub fn decay_counts(&mut self) {
        for count in self.access_counts.values_mut() {
            *count = (*count as f64 * self.decay_factor) as u64;
        }
    }

    /// Number of nodes currently cached.
    pub fn len(&self) -> usize {
        self.features.len()
    }

    /// Whether the hotset is empty.
    pub fn is_empty(&self) -> bool {
        self.features.is_empty()
    }
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ColdTierEpochResult
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Statistics from one cold-tier training epoch.
#[derive(Debug, Clone)]
pub struct ColdTierEpochResult {
    /// Epoch number (incremented before the result is built, so the
    /// first epoch reports 1).
    pub epoch: usize,
    /// Per-batch average of the summed per-node losses in this epoch.
    pub avg_loss: f64,
    /// Number of batches processed.
    pub batches: usize,
    /// Hotset hit rate — lifetime of the trainer's hotset, not per-epoch.
    pub hotset_hit_rate: f64,
    /// Milliseconds attributed to I/O during the epoch.
    pub io_time_ms: u64,
    /// Milliseconds spent on compute during the epoch.
    pub compute_time_ms: u64,
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ColdTierTrainer
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Orchestrates cold-tier training with hyperbatch I/O and hotset caching.
pub struct ColdTierTrainer {
    // Writable handle to the on-disk feature file.
    storage: FeatureStorage,
    // In-memory cache of hot node features.
    hotset: AdaptiveHotset,
    // Batch sizing / hotset parameters, cloned into each epoch's iterator.
    config: HyperbatchConfig,
    // Completed epoch count.
    epoch: usize,
    // Average loss of the most recent epoch (despite the name: train_epoch
    // stores epoch_loss / batch_count here).
    total_loss: f64,
    // Batch count of the most recent epoch.
    batches_processed: usize,
}
|
||||
|
||||
impl ColdTierTrainer {
|
||||
/// Create a new trainer, initializing feature storage and hotset.
|
||||
pub fn new(
|
||||
storage_path: &Path,
|
||||
dim: usize,
|
||||
num_nodes: usize,
|
||||
config: HyperbatchConfig,
|
||||
) -> Result<Self> {
|
||||
let storage = FeatureStorage::create(storage_path, dim, num_nodes)?;
|
||||
let hotset_cap = ((num_nodes as f64) * config.hotset_fraction).max(1.0) as usize;
|
||||
let hotset = AdaptiveHotset::new(hotset_cap, 0.95);
|
||||
Ok(Self {
|
||||
storage,
|
||||
hotset,
|
||||
config,
|
||||
epoch: 0,
|
||||
total_loss: 0.0,
|
||||
batches_processed: 0,
|
||||
})
|
||||
}
|
||||
|
||||
/// Run one training epoch over all hyperbatches.
|
||||
///
|
||||
/// For each batch a simple gradient-descent step is simulated:
|
||||
/// the loss is the L2 norm of the feature vector, and the gradient
|
||||
/// nudges each element toward zero by `learning_rate`.
|
||||
pub fn train_epoch(
|
||||
&mut self,
|
||||
adjacency: &[(usize, usize)],
|
||||
learning_rate: f64,
|
||||
) -> ColdTierEpochResult {
|
||||
let io_start = std::time::Instant::now();
|
||||
|
||||
// Build a fresh iterator each epoch (re-shuffles BFS ordering)
|
||||
let storage_for_iter = FeatureStorage::open(self.storage.path()).ok();
|
||||
let mut epoch_loss = 0.0;
|
||||
let mut batch_count: usize = 0;
|
||||
let mut io_ms: u64 = 0;
|
||||
let mut compute_ms: u64 = 0;
|
||||
|
||||
if let Some(iter_storage) = storage_for_iter {
|
||||
let mut iter = HyperbatchIterator::new(iter_storage, self.config.clone(), adjacency);
|
||||
|
||||
while let Some(batch) = iter.next_batch() {
|
||||
let io_elapsed = io_start.elapsed().as_millis() as u64;
|
||||
|
||||
let compute_start = std::time::Instant::now();
|
||||
|
||||
// Process each node in the batch
|
||||
for (i, node_id) in batch.node_ids.iter().enumerate() {
|
||||
let features = &batch.features[i];
|
||||
|
||||
// Simple L2 loss for demonstration
|
||||
let loss: f64 = features
|
||||
.iter()
|
||||
.map(|&x| (x as f64) * (x as f64))
|
||||
.sum::<f64>()
|
||||
* 0.5;
|
||||
epoch_loss += loss;
|
||||
|
||||
// Gradient: d(0.5 * x^2)/dx = x; step: x' = x - lr * x
|
||||
let updated: Vec<f32> = features
|
||||
.iter()
|
||||
.map(|&x| x - (learning_rate as f32) * x)
|
||||
.collect();
|
||||
|
||||
let _ = self.storage.write_features(*node_id, &updated);
|
||||
self.hotset.insert(*node_id, updated);
|
||||
}
|
||||
|
||||
compute_ms += compute_start.elapsed().as_millis() as u64;
|
||||
io_ms = io_elapsed;
|
||||
batch_count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
let _ = self.storage.flush();
|
||||
self.hotset.decay_counts();
|
||||
self.epoch += 1;
|
||||
self.total_loss = if batch_count > 0 {
|
||||
epoch_loss / batch_count as f64
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
self.batches_processed = batch_count;
|
||||
|
||||
ColdTierEpochResult {
|
||||
epoch: self.epoch,
|
||||
avg_loss: self.total_loss,
|
||||
batches: batch_count,
|
||||
hotset_hit_rate: self.hotset.hit_rate(),
|
||||
io_time_ms: io_ms,
|
||||
compute_time_ms: compute_ms,
|
||||
}
|
||||
}
|
||||
|
||||
/// Retrieve features for a node, checking the hotset first.
|
||||
pub fn get_features(&mut self, node_id: usize) -> Result<Vec<f32>> {
|
||||
if let Some(cached) = self.hotset.get(node_id) {
|
||||
return Ok(cached.to_vec());
|
||||
}
|
||||
let features = self.storage.read_features(node_id)?;
|
||||
self.hotset.insert(node_id, features.clone());
|
||||
Ok(features)
|
||||
}
|
||||
|
||||
/// Save a checkpoint (storage metadata, progress counters, and config)
/// as pretty-printed JSON at `path`.
///
/// The feature data itself stays in the storage file; only its path is
/// recorded, so the checkpoint is not self-contained. Hotset contents
/// and hit statistics are not persisted.
pub fn save_checkpoint(&self, path: &Path) -> Result<()> {
    let data = serde_json::json!({
        "storage_path": self.storage.path().to_string_lossy(),
        "dim": self.storage.dim(),
        "num_nodes": self.storage.num_nodes(),
        "epoch": self.epoch,
        "total_loss": self.total_loss,
        "batches_processed": self.batches_processed,
        "config": {
            "batch_size": self.config.batch_size,
            "prefetch_factor": self.config.prefetch_factor,
            "block_align": self.config.block_align,
            "num_buffers": self.config.num_buffers,
            "hotset_fraction": self.config.hotset_fraction,
        }
    });
    let content = serde_json::to_string_pretty(&data)
        .map_err(|e| GnnError::other(format!("serialize checkpoint: {}", e)))?;
    std::fs::write(path, content)?;
    Ok(())
}
|
||||
|
||||
/// Load a trainer from a checkpoint file written by `save_checkpoint`.
///
/// Re-opens the feature storage recorded in the checkpoint; the hotset
/// is rebuilt empty (its contents are not persisted).
///
/// # Errors
/// Fails if the checkpoint cannot be read or parsed, if `storage_path`
/// is missing, or if the storage file cannot be opened.
pub fn load_checkpoint(path: &Path) -> Result<Self> {
    let content = std::fs::read_to_string(path)?;
    let v: serde_json::Value = serde_json::from_str(&content)
        .map_err(|e| GnnError::other(format!("deserialize checkpoint: {}", e)))?;

    let storage_path = PathBuf::from(
        v["storage_path"]
            .as_str()
            .ok_or_else(|| GnnError::other("missing storage_path"))?,
    );
    // NOTE(review): `dim` is parsed but never validated against the
    // re-opened storage header — a mismatched file goes undetected here.
    let _dim = v["dim"].as_u64().unwrap_or(0) as usize;
    let num_nodes = v["num_nodes"].as_u64().unwrap_or(0) as usize;
    let epoch = v["epoch"].as_u64().unwrap_or(0) as usize;
    let total_loss = v["total_loss"].as_f64().unwrap_or(0.0);
    let batches_processed = v["batches_processed"].as_u64().unwrap_or(0) as usize;

    // Missing config keys silently fall back to the documented defaults.
    let cfg_val = &v["config"];
    let config = HyperbatchConfig {
        batch_size: cfg_val["batch_size"].as_u64().unwrap_or(4096) as usize,
        prefetch_factor: cfg_val["prefetch_factor"].as_u64().unwrap_or(2) as usize,
        block_align: cfg_val["block_align"].as_u64().unwrap_or(4096) as usize,
        num_buffers: cfg_val["num_buffers"].as_u64().unwrap_or(2) as usize,
        hotset_fraction: cfg_val["hotset_fraction"].as_f64().unwrap_or(0.05),
    };

    let storage = FeatureStorage::open(&storage_path).map_err(|_| {
        // The storage file is gone (or unreadable); nothing is recreated —
        // the caller must re-create it before loading.
        GnnError::other("storage file not found; re-create before loading")
    })?;

    // Same capacity formula as `new`: cache at least one node.
    let hotset_cap = ((num_nodes as f64) * config.hotset_fraction).max(1.0) as usize;
    let hotset = AdaptiveHotset::new(hotset_cap, 0.95);

    Ok(Self {
        storage,
        hotset,
        config,
        epoch,
        total_loss,
        batches_processed,
    })
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ColdTierEwc
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Disk-backed Elastic Weight Consolidation using FeatureStorage.
///
/// Stores the Fisher information diagonal and anchor weights on disk
/// so that EWC can scale to models that do not fit in RAM.
pub struct ColdTierEwc {
    // Per-row Fisher diagonal estimates (backed by "fisher.bin").
    fisher_storage: FeatureStorage,
    // Anchor weights w* captured at consolidation (backed by "anchor.bin").
    anchor_storage: FeatureStorage,
    // Penalty strength multiplier.
    lambda: f64,
    // False until `consolidate` runs; penalty/gradient are zero before that.
    active: bool,
    // Width of each parameter row.
    dim: usize,
    // Number of parameter rows.
    num_params: usize,
}
|
||||
|
||||
impl ColdTierEwc {
|
||||
/// Create a new disk-backed EWC instance.
|
||||
///
|
||||
/// `dim` is the width of each parameter "row" (analogous to feature dim),
|
||||
/// and `num_params` is the number of such rows.
|
||||
pub fn new(path: &Path, dim: usize, num_params: usize, lambda: f64) -> Result<Self> {
|
||||
let fisher_path = path.join("fisher.bin");
|
||||
let anchor_path = path.join("anchor.bin");
|
||||
std::fs::create_dir_all(path)?;
|
||||
let fisher_storage = FeatureStorage::create(&fisher_path, dim, num_params)?;
|
||||
let anchor_storage = FeatureStorage::create(&anchor_path, dim, num_params)?;
|
||||
Ok(Self {
|
||||
fisher_storage,
|
||||
anchor_storage,
|
||||
lambda,
|
||||
active: false,
|
||||
dim,
|
||||
num_params,
|
||||
})
|
||||
}
|
||||
|
||||
/// Compute the Fisher information diagonal from gradient samples and
/// persist one row per parameter into `fisher_storage`.
///
/// `gradients` is laid out sample-major: the gradient for parameter row
/// `p` of sample `s` lives at index `s * num_params + p`. Squared
/// gradients are summed over samples and normalized by `sample_count`.
///
/// No-ops when `gradients` is empty or holds fewer than one full sample.
pub fn compute_fisher(&mut self, gradients: &[Vec<f32>], sample_count: usize) -> Result<()> {
    if gradients.is_empty() {
        return Ok(());
    }
    // Number of complete samples available in the flat layout.
    let rows = gradients.len() / self.num_params;
    if rows == 0 {
        return Ok(());
    }
    // Normalize by the caller-declared sample count (guarded against 0).
    let norm = 1.0 / (sample_count as f32).max(1.0);

    for param_idx in 0..self.num_params {
        let mut fisher_row = vec![0.0f32; self.dim];
        for sample in 0..rows {
            let idx = sample * self.num_params + param_idx;
            if idx < gradients.len() {
                let grad = &gradients[idx];
                // F_j ~ E[g_j^2]: accumulate squared gradient entries,
                // clamped to the first `dim` values of the row.
                for (i, &g) in grad.iter().enumerate().take(self.dim) {
                    fisher_row[i] += g * g;
                }
            }
        }
        for v in &mut fisher_row {
            *v *= norm;
        }
        self.fisher_storage.write_features(param_idx, &fisher_row)?;
    }
    self.fisher_storage.flush()?;
    Ok(())
}
|
||||
|
||||
/// Consolidate current weights as anchors and activate EWC.
|
||||
pub fn consolidate(&mut self, current_weights: &[Vec<f32>]) -> Result<()> {
|
||||
if current_weights.len() != self.num_params {
|
||||
return Err(GnnError::dimension_mismatch(
|
||||
self.num_params.to_string(),
|
||||
current_weights.len().to_string(),
|
||||
));
|
||||
}
|
||||
for (i, w) in current_weights.iter().enumerate() {
|
||||
self.anchor_storage.write_features(i, w)?;
|
||||
}
|
||||
self.anchor_storage.flush()?;
|
||||
self.active = true;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Compute the EWC penalty: lambda/2 * sum(F_i * (w_i - w*_i)^2).
|
||||
pub fn penalty(&mut self, current_weights: &[Vec<f32>]) -> Result<f64> {
|
||||
if !self.active {
|
||||
return Ok(0.0);
|
||||
}
|
||||
let mut total = 0.0f64;
|
||||
for i in 0..self.num_params {
|
||||
let fisher = self.fisher_storage.read_features(i)?;
|
||||
let anchor = self.anchor_storage.read_features(i)?;
|
||||
let w = ¤t_weights[i];
|
||||
for j in 0..self.dim.min(w.len()) {
|
||||
let diff = w[j] - anchor[j];
|
||||
total += (fisher[j] as f64) * (diff as f64) * (diff as f64);
|
||||
}
|
||||
}
|
||||
Ok(total * self.lambda * 0.5)
|
||||
}
|
||||
|
||||
/// Compute the EWC gradient for a specific parameter row.
|
||||
pub fn gradient(&mut self, current_weights: &[Vec<f32>], param_idx: usize) -> Result<Vec<f32>> {
|
||||
if !self.active || param_idx >= self.num_params {
|
||||
return Ok(vec![0.0; self.dim]);
|
||||
}
|
||||
let fisher = self.fisher_storage.read_features(param_idx)?;
|
||||
let anchor = self.anchor_storage.read_features(param_idx)?;
|
||||
let w = ¤t_weights[param_idx];
|
||||
let grad: Vec<f32> = (0..self.dim)
|
||||
.map(|j| (self.lambda as f32) * fisher[j] * (w[j] - anchor[j]))
|
||||
.collect();
|
||||
Ok(grad)
|
||||
}
|
||||
|
||||
/// Whether EWC is active (set once `consolidate` has stored anchors).
pub fn is_active(&self) -> bool {
    self.active
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[test]
fn test_feature_storage_roundtrip() {
    let tmp = TempDir::new().unwrap();
    let path = tmp.path().join("features.bin");

    let dim = 8;
    let num_nodes = 10;
    let mut storage = FeatureStorage::create(&path, dim, num_nodes).unwrap();

    // Write a distinct, predictable pattern per node: value = nid*dim + j.
    for nid in 0..num_nodes {
        let features: Vec<f32> = (0..dim).map(|j| (nid * dim + j) as f32).collect();
        storage.write_features(nid, &features).unwrap();
    }
    storage.flush().unwrap();

    // Re-open from disk and verify the header round-trips.
    let mut storage2 = FeatureStorage::open(&path).unwrap();
    assert_eq!(storage2.dim(), dim);
    assert_eq!(storage2.num_nodes(), num_nodes);

    // Every value must survive the write/flush/open/read cycle exactly.
    for nid in 0..num_nodes {
        let features = storage2.read_features(nid).unwrap();
        assert_eq!(features.len(), dim);
        for j in 0..dim {
            assert!((features[j] - (nid * dim + j) as f32).abs() < 1e-6);
        }
    }
}
|
||||
|
||||
#[test]
fn test_hyperbatch_ordering() {
    // Build a simple chain: 0-1-2-3-4
    let adjacency = vec![(0, 1), (1, 2), (2, 3), (3, 4)];
    let order = HyperbatchIterator::reorder_bfs(&adjacency, 5);

    // BFS from 0 should visit 0, 1, 2, 3, 4 in order
    assert_eq!(order, vec![0, 1, 2, 3, 4]);

    // Star graph: 0 connected to 1..4
    let star = vec![(0, 1), (0, 2), (0, 3), (0, 4)];
    let star_order = HyperbatchIterator::reorder_bfs(&star, 5);
    // 0 first, then neighbors (order may vary but 0 must be first)
    assert_eq!(star_order[0], 0);
    // All five nodes must appear (each exactly once by BFS visitation).
    assert_eq!(star_order.len(), 5);
}
|
||||
|
||||
#[test]
fn test_hotset_eviction() {
    // Capacity 3; the decay factor is not exercised by this test.
    let mut hotset = AdaptiveHotset::new(3, 0.9);

    hotset.insert(0, vec![1.0, 2.0]);
    hotset.insert(1, vec![3.0, 4.0]);
    hotset.insert(2, vec![5.0, 6.0]);

    // Access node 0 and 1 more frequently
    for _ in 0..10 {
        hotset.record_access(0);
        hotset.record_access(1);
    }
    // Node 2 has fewest accesses (only the initial 0)

    // Insert a 4th node -> should evict node 2 (coldest)
    hotset.insert(3, vec![7.0, 8.0]);

    assert_eq!(hotset.len(), 3);
    // Node 2 should be gone
    assert!(hotset.get(2).is_none());
    // Nodes 0, 1, 3 should still be present
    assert!(hotset.get(0).is_some());
    assert!(hotset.get(1).is_some());
    assert!(hotset.get(3).is_some());
}
|
||||
|
||||
#[test]
fn test_cold_tier_epoch() {
    let tmp = TempDir::new().unwrap();
    let storage_path = tmp.path().join("train_features.bin");

    let dim = 4;
    let num_nodes = 16;
    let config = HyperbatchConfig {
        batch_size: 4,
        hotset_fraction: 0.25,
        ..Default::default()
    };

    let mut trainer = ColdTierTrainer::new(&storage_path, dim, num_nodes, config).unwrap();

    // Write initial features (all ones, so every node contributes loss).
    for nid in 0..num_nodes {
        let features = vec![1.0f32; dim];
        trainer.storage.write_features(nid, &features).unwrap();
    }
    trainer.storage.flush().unwrap();

    // Build a simple chain adjacency
    let adjacency: Vec<(usize, usize)> = (0..num_nodes.saturating_sub(1))
        .map(|i| (i, i + 1))
        .collect();

    let result = trainer.train_epoch(&adjacency, 0.1);

    assert_eq!(result.epoch, 1);
    assert!(result.batches > 0);
    // All 16 nodes in batches of 4 = 4 batches
    assert_eq!(result.batches, 4);
    // Loss should be positive (features started at 1.0)
    assert!(result.avg_loss > 0.0);
}
|
||||
|
||||
#[test]
fn test_cold_tier_ewc() {
    let tmp = TempDir::new().unwrap();
    let ewc_dir = tmp.path().join("ewc");

    let dim = 4;
    let num_params = 3;
    let lambda = 100.0;

    let mut ewc = ColdTierEwc::new(&ewc_dir, dim, num_params, lambda).unwrap();

    // Compute Fisher from gradients (1 sample, 3 param rows)
    let gradients = vec![
        vec![1.0, 2.0, 3.0, 4.0],
        vec![0.5, 0.5, 0.5, 0.5],
        vec![2.0, 1.0, 0.0, 1.0],
    ];
    ewc.compute_fisher(&gradients, 1).unwrap();

    // Verify Fisher was stored correctly (squared gradients / sample count)
    let fisher0 = ewc.fisher_storage.read_features(0).unwrap();
    assert!((fisher0[0] - 1.0).abs() < 1e-6); // 1^2 / 1
    assert!((fisher0[1] - 4.0).abs() < 1e-6); // 2^2 / 1

    // Consolidate at all-zero weights so the anchor is the origin.
    let weights = vec![
        vec![0.0, 0.0, 0.0, 0.0],
        vec![0.0, 0.0, 0.0, 0.0],
        vec![0.0, 0.0, 0.0, 0.0],
    ];
    ewc.consolidate(&weights).unwrap();
    assert!(ewc.is_active());

    // Penalty should be 0 at anchor
    let penalty = ewc.penalty(&weights).unwrap();
    assert!(penalty.abs() < 1e-6);

    // Deviation should produce a penalty
    let deviated = vec![
        vec![1.0, 1.0, 1.0, 1.0],
        vec![1.0, 1.0, 1.0, 1.0],
        vec![1.0, 1.0, 1.0, 1.0],
    ];
    let penalty = ewc.penalty(&deviated).unwrap();
    assert!(penalty > 0.0);

    // Gradient for param 0 should be lambda * fisher * diff
    let grad = ewc.gradient(&deviated, 0).unwrap();
    assert!((grad[0] - 100.0 * 1.0 * 1.0).abs() < 1e-4);
    assert!((grad[1] - 100.0 * 4.0 * 1.0).abs() < 1e-4);
}
|
||||
}
|
||||
678
vendor/ruvector/crates/ruvector-gnn/src/compress.rs
vendored
Normal file
678
vendor/ruvector/crates/ruvector-gnn/src/compress.rs
vendored
Normal file
@@ -0,0 +1,678 @@
|
||||
//! Tensor compression with adaptive level selection
|
||||
//!
|
||||
//! This module provides multi-level tensor compression based on access frequency:
|
||||
//! - Hot data (f > 0.8): Full precision
|
||||
//! - Warm data (f > 0.4): Half precision
|
||||
//! - Cool data (f > 0.1): 8-bit product quantization
|
||||
//! - Cold data (f > 0.01): 4-bit product quantization
|
||||
//! - Archive (f <= 0.01): Binary quantization
|
||||
|
||||
use crate::error::{GnnError, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Compression level with associated parameters.
///
/// Selected adaptively from access frequency by `TensorCompress::select_level`:
/// hotter data keeps more precision, colder data is quantized more aggressively.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum CompressionLevel {
    /// Full precision - no compression
    None,

    /// Half precision with scale factor
    /// (values are multiplied by `scale` before narrowing to 16 bits).
    Half { scale: f32 },

    /// Product quantization with 8-bit codes:
    /// `subvectors` chunks, each quantized against `centroids` codebook entries.
    PQ8 { subvectors: u8, centroids: u8 },

    /// Product quantization with 4-bit codes and outlier handling.
    /// Values further than `outlier_threshold` standard deviations from the
    /// vector's mean are stored verbatim instead of being quantized.
    PQ4 {
        subvectors: u8,
        outlier_threshold: f32,
    },

    /// Binary quantization with threshold: values above `threshold`
    /// decompress to 1.0, everything else to -1.0.
    Binary { threshold: f32 },
}
|
||||
|
||||
/// Compressed tensor data produced by `TensorCompress::compress`.
///
/// Each variant carries enough metadata (`dim`, codebooks, scale, ...) for
/// `TensorCompress::decompress` to reconstruct an approximation of the
/// original vector.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CompressedTensor {
    /// Uncompressed full precision data
    Full { data: Vec<f32> },

    /// Half precision data (16-bit fixed-point codes, see `f32_to_f16_bits`)
    Half {
        data: Vec<u16>,
        scale: f32,
        dim: usize,
    },

    /// 8-bit product quantization: one code and one flattened codebook per subvector
    PQ8 {
        codes: Vec<u8>,
        codebooks: Vec<Vec<f32>>,
        subvector_dim: usize,
        dim: usize,
    },

    /// 4-bit product quantization with outliers
    PQ4 {
        codes: Vec<u8>, // One 4-bit code per subvector (stored unpacked, one byte each)
        codebooks: Vec<Vec<f32>>,
        outliers: Vec<(usize, f32)>, // (index, value) pairs restored verbatim on decompress
        subvector_dim: usize,
        dim: usize,
    },

    /// Binary quantization: one bit per element, LSB-first within each byte
    Binary {
        bits: Vec<u8>,
        threshold: f32,
        dim: usize,
    },
}
|
||||
|
||||
/// Tensor compressor with adaptive level selection.
///
/// Stateless apart from configuration; `compress` picks the level from the
/// access frequency (see `select_level` for the thresholds).
#[derive(Debug, Clone)]
pub struct TensorCompress {
    /// Default compression parameters
    // NOTE(review): currently only set by `new()` and never read by the
    // visible compression paths (`compress` always derives the level from
    // `access_freq`) — confirm whether this field is still needed.
    default_level: CompressionLevel,
}
|
||||
|
||||
impl Default for TensorCompress {
    /// Equivalent to `TensorCompress::new()`.
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
impl TensorCompress {
|
||||
/// Create a new tensor compressor with default settings
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
default_level: CompressionLevel::None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compress an embedding based on access frequency
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `embedding` - The input embedding vector
|
||||
/// * `access_freq` - Access frequency in range [0.0, 1.0]
|
||||
///
|
||||
/// # Returns
|
||||
/// Compressed tensor using adaptive compression level
|
||||
pub fn compress(&self, embedding: &[f32], access_freq: f32) -> Result<CompressedTensor> {
|
||||
if embedding.is_empty() {
|
||||
return Err(GnnError::InvalidInput("Empty embedding vector".to_string()));
|
||||
}
|
||||
|
||||
let level = self.select_level(access_freq);
|
||||
self.compress_with_level(embedding, &level)
|
||||
}
|
||||
|
||||
/// Compress with explicit compression level
|
||||
pub fn compress_with_level(
|
||||
&self,
|
||||
embedding: &[f32],
|
||||
level: &CompressionLevel,
|
||||
) -> Result<CompressedTensor> {
|
||||
match level {
|
||||
CompressionLevel::None => self.compress_none(embedding),
|
||||
CompressionLevel::Half { scale } => self.compress_half(embedding, *scale),
|
||||
CompressionLevel::PQ8 {
|
||||
subvectors,
|
||||
centroids,
|
||||
} => self.compress_pq8(embedding, *subvectors, *centroids),
|
||||
CompressionLevel::PQ4 {
|
||||
subvectors,
|
||||
outlier_threshold,
|
||||
} => self.compress_pq4(embedding, *subvectors, *outlier_threshold),
|
||||
CompressionLevel::Binary { threshold } => self.compress_binary(embedding, *threshold),
|
||||
}
|
||||
}
|
||||
|
||||
/// Decompress a compressed tensor
|
||||
pub fn decompress(&self, compressed: &CompressedTensor) -> Result<Vec<f32>> {
|
||||
match compressed {
|
||||
CompressedTensor::Full { data } => Ok(data.clone()),
|
||||
CompressedTensor::Half { data, scale, dim } => self.decompress_half(data, *scale, *dim),
|
||||
CompressedTensor::PQ8 {
|
||||
codes,
|
||||
codebooks,
|
||||
subvector_dim,
|
||||
dim,
|
||||
} => self.decompress_pq8(codes, codebooks, *subvector_dim, *dim),
|
||||
CompressedTensor::PQ4 {
|
||||
codes,
|
||||
codebooks,
|
||||
outliers,
|
||||
subvector_dim,
|
||||
dim,
|
||||
} => self.decompress_pq4(codes, codebooks, outliers, *subvector_dim, *dim),
|
||||
CompressedTensor::Binary {
|
||||
bits,
|
||||
threshold,
|
||||
dim,
|
||||
} => self.decompress_binary(bits, *threshold, *dim),
|
||||
}
|
||||
}
|
||||
|
||||
/// Select compression level based on access frequency
|
||||
///
|
||||
/// Thresholds:
|
||||
/// - f > 0.8: None (hot data)
|
||||
/// - f > 0.4: Half (warm data)
|
||||
/// - f > 0.1: PQ8 (cool data)
|
||||
/// - f > 0.01: PQ4 (cold data)
|
||||
/// - f <= 0.01: Binary (archive)
|
||||
fn select_level(&self, access_freq: f32) -> CompressionLevel {
|
||||
if access_freq > 0.8 {
|
||||
CompressionLevel::None
|
||||
} else if access_freq > 0.4 {
|
||||
CompressionLevel::Half { scale: 1.0 }
|
||||
} else if access_freq > 0.1 {
|
||||
CompressionLevel::PQ8 {
|
||||
subvectors: 8,
|
||||
centroids: 16,
|
||||
}
|
||||
} else if access_freq > 0.01 {
|
||||
CompressionLevel::PQ4 {
|
||||
subvectors: 8,
|
||||
outlier_threshold: 3.0,
|
||||
}
|
||||
} else {
|
||||
CompressionLevel::Binary { threshold: 0.0 }
|
||||
}
|
||||
}
|
||||
|
||||
// === Compression implementations ===
|
||||
|
||||
fn compress_none(&self, embedding: &[f32]) -> Result<CompressedTensor> {
|
||||
Ok(CompressedTensor::Full {
|
||||
data: embedding.to_vec(),
|
||||
})
|
||||
}
|
||||
|
||||
fn compress_half(&self, embedding: &[f32], scale: f32) -> Result<CompressedTensor> {
|
||||
// Simple half precision: scale and convert to 16-bit
|
||||
let data: Vec<u16> = embedding
|
||||
.iter()
|
||||
.map(|&x| {
|
||||
let scaled = x * scale;
|
||||
let clamped = scaled.clamp(-65504.0, 65504.0);
|
||||
// Convert to half precision representation
|
||||
f32_to_f16_bits(clamped)
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(CompressedTensor::Half {
|
||||
data,
|
||||
scale,
|
||||
dim: embedding.len(),
|
||||
})
|
||||
}
|
||||
|
||||
fn compress_pq8(
|
||||
&self,
|
||||
embedding: &[f32],
|
||||
subvectors: u8,
|
||||
centroids: u8,
|
||||
) -> Result<CompressedTensor> {
|
||||
let dim = embedding.len();
|
||||
let subvectors = subvectors as usize;
|
||||
|
||||
if dim % subvectors != 0 {
|
||||
return Err(GnnError::InvalidInput(format!(
|
||||
"Dimension {} not divisible by subvectors {}",
|
||||
dim, subvectors
|
||||
)));
|
||||
}
|
||||
|
||||
let subvector_dim = dim / subvectors;
|
||||
let mut codes = Vec::with_capacity(subvectors);
|
||||
let mut codebooks = Vec::with_capacity(subvectors);
|
||||
|
||||
// For each subvector, create a codebook and quantize
|
||||
for i in 0..subvectors {
|
||||
let start = i * subvector_dim;
|
||||
let end = start + subvector_dim;
|
||||
let subvector = &embedding[start..end];
|
||||
|
||||
// Simple k-means clustering (k=centroids)
|
||||
let (codebook, code) = self.quantize_subvector(subvector, centroids as usize);
|
||||
codes.push(code);
|
||||
codebooks.push(codebook);
|
||||
}
|
||||
|
||||
Ok(CompressedTensor::PQ8 {
|
||||
codes,
|
||||
codebooks,
|
||||
subvector_dim,
|
||||
dim,
|
||||
})
|
||||
}
|
||||
|
||||
fn compress_pq4(
|
||||
&self,
|
||||
embedding: &[f32],
|
||||
subvectors: u8,
|
||||
outlier_threshold: f32,
|
||||
) -> Result<CompressedTensor> {
|
||||
let dim = embedding.len();
|
||||
let subvectors = subvectors as usize;
|
||||
|
||||
if dim % subvectors != 0 {
|
||||
return Err(GnnError::InvalidInput(format!(
|
||||
"Dimension {} not divisible by subvectors {}",
|
||||
dim, subvectors
|
||||
)));
|
||||
}
|
||||
|
||||
let subvector_dim = dim / subvectors;
|
||||
let mut codes = Vec::with_capacity(subvectors);
|
||||
let mut codebooks = Vec::with_capacity(subvectors);
|
||||
let mut outliers = Vec::new();
|
||||
|
||||
// Detect outliers based on magnitude
|
||||
let mean = embedding.iter().sum::<f32>() / dim as f32;
|
||||
let std_dev =
|
||||
(embedding.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / dim as f32).sqrt();
|
||||
|
||||
// For each subvector
|
||||
for i in 0..subvectors {
|
||||
let start = i * subvector_dim;
|
||||
let end = start + subvector_dim;
|
||||
let subvector = &embedding[start..end];
|
||||
|
||||
// Extract outliers
|
||||
let mut cleaned_subvector = subvector.to_vec();
|
||||
for (j, &val) in subvector.iter().enumerate() {
|
||||
if (val - mean).abs() > outlier_threshold * std_dev {
|
||||
outliers.push((start + j, val));
|
||||
cleaned_subvector[j] = mean; // Replace with mean
|
||||
}
|
||||
}
|
||||
|
||||
// Quantize to 4-bit (16 centroids)
|
||||
let (codebook, code) = self.quantize_subvector(&cleaned_subvector, 16);
|
||||
codes.push(code);
|
||||
codebooks.push(codebook);
|
||||
}
|
||||
|
||||
Ok(CompressedTensor::PQ4 {
|
||||
codes,
|
||||
codebooks,
|
||||
outliers,
|
||||
subvector_dim,
|
||||
dim,
|
||||
})
|
||||
}
|
||||
|
||||
fn compress_binary(&self, embedding: &[f32], threshold: f32) -> Result<CompressedTensor> {
|
||||
let dim = embedding.len();
|
||||
let num_bytes = (dim + 7) / 8;
|
||||
let mut bits = vec![0u8; num_bytes];
|
||||
|
||||
for (i, &val) in embedding.iter().enumerate() {
|
||||
if val > threshold {
|
||||
let byte_idx = i / 8;
|
||||
let bit_idx = i % 8;
|
||||
bits[byte_idx] |= 1 << bit_idx;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(CompressedTensor::Binary {
|
||||
bits,
|
||||
threshold,
|
||||
dim,
|
||||
})
|
||||
}
|
||||
|
||||
// === Decompression implementations ===
|
||||
|
||||
fn decompress_half(&self, data: &[u16], scale: f32, dim: usize) -> Result<Vec<f32>> {
|
||||
if data.len() != dim {
|
||||
return Err(GnnError::InvalidInput(format!(
|
||||
"Dimension mismatch: expected {}, got {}",
|
||||
dim,
|
||||
data.len()
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(data
|
||||
.iter()
|
||||
.map(|&bits| f16_bits_to_f32(bits) / scale)
|
||||
.collect())
|
||||
}
|
||||
|
||||
fn decompress_pq8(
|
||||
&self,
|
||||
codes: &[u8],
|
||||
codebooks: &[Vec<f32>],
|
||||
subvector_dim: usize,
|
||||
dim: usize,
|
||||
) -> Result<Vec<f32>> {
|
||||
let subvectors = codes.len();
|
||||
let expected_dim = subvectors * subvector_dim;
|
||||
|
||||
if expected_dim != dim {
|
||||
return Err(GnnError::InvalidInput(format!(
|
||||
"Dimension mismatch: expected {}, got {}",
|
||||
dim, expected_dim
|
||||
)));
|
||||
}
|
||||
|
||||
let mut result = Vec::with_capacity(dim);
|
||||
|
||||
for (code, codebook) in codes.iter().zip(codebooks.iter()) {
|
||||
let centroid_idx = *code as usize;
|
||||
if centroid_idx >= codebook.len() / subvector_dim {
|
||||
return Err(GnnError::InvalidInput(format!(
|
||||
"Invalid centroid index: {}",
|
||||
centroid_idx
|
||||
)));
|
||||
}
|
||||
|
||||
let start = centroid_idx * subvector_dim;
|
||||
let end = start + subvector_dim;
|
||||
result.extend_from_slice(&codebook[start..end]);
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn decompress_pq4(
|
||||
&self,
|
||||
codes: &[u8],
|
||||
codebooks: &[Vec<f32>],
|
||||
outliers: &[(usize, f32)],
|
||||
subvector_dim: usize,
|
||||
dim: usize,
|
||||
) -> Result<Vec<f32>> {
|
||||
// First decompress using PQ8 logic
|
||||
let mut result = self.decompress_pq8(codes, codebooks, subvector_dim, dim)?;
|
||||
|
||||
// Restore outliers
|
||||
for &(idx, val) in outliers {
|
||||
if idx < result.len() {
|
||||
result[idx] = val;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn decompress_binary(&self, bits: &[u8], _threshold: f32, dim: usize) -> Result<Vec<f32>> {
|
||||
let expected_bytes = (dim + 7) / 8;
|
||||
if bits.len() != expected_bytes {
|
||||
return Err(GnnError::InvalidInput(format!(
|
||||
"Dimension mismatch: expected {} bytes, got {}",
|
||||
expected_bytes,
|
||||
bits.len()
|
||||
)));
|
||||
}
|
||||
|
||||
let mut result = Vec::with_capacity(dim);
|
||||
|
||||
for i in 0..dim {
|
||||
let byte_idx = i / 8;
|
||||
let bit_idx = i % 8;
|
||||
let is_set = (bits[byte_idx] & (1 << bit_idx)) != 0;
|
||||
result.push(if is_set { 1.0 } else { -1.0 });
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
// === Helper methods ===
|
||||
|
||||
/// Simple quantization using k-means-like approach
|
||||
fn quantize_subvector(&self, subvector: &[f32], k: usize) -> (Vec<f32>, u8) {
|
||||
let dim = subvector.len();
|
||||
|
||||
// Initialize centroids using simple range-based approach
|
||||
let min_val = subvector.iter().cloned().fold(f32::INFINITY, f32::min);
|
||||
let max_val = subvector.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
|
||||
let range = max_val - min_val;
|
||||
|
||||
if range < 1e-6 {
|
||||
// All values are essentially the same
|
||||
let codebook = vec![min_val; dim * k];
|
||||
return (codebook, 0);
|
||||
}
|
||||
|
||||
// Create k centroids evenly spaced across the range
|
||||
let centroids: Vec<Vec<f32>> = (0..k)
|
||||
.map(|i| {
|
||||
let offset = min_val + (i as f32 / k as f32) * range;
|
||||
vec![offset; dim]
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Find nearest centroid for this subvector
|
||||
let code = self.nearest_centroid(subvector, ¢roids);
|
||||
|
||||
// Flatten codebook
|
||||
let codebook: Vec<f32> = centroids.into_iter().flatten().collect();
|
||||
|
||||
(codebook, code as u8)
|
||||
}
|
||||
|
||||
fn nearest_centroid(&self, subvector: &[f32], centroids: &[Vec<f32>]) -> usize {
|
||||
centroids
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, centroid)| {
|
||||
let dist: f32 = subvector
|
||||
.iter()
|
||||
.zip(centroid.iter())
|
||||
.map(|(a, b)| (a - b).powi(2))
|
||||
.sum();
|
||||
(i, dist)
|
||||
})
|
||||
.min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
|
||||
.map(|(i, _)| i)
|
||||
.unwrap_or(0)
|
||||
}
|
||||
}
|
||||
|
||||
// === Half precision conversion helpers ===
|
||||
|
||||
/// Convert f32 to a 16-bit code (simplified implementation).
///
/// Not IEEE 754 half precision: the value is scaled by 1000, clamped to the
/// i16 range, and biased by 32768 so the result fits in a `u16`.
fn f32_to_f16_bits(value: f32) -> u16 {
    let fixed = (value * 1000.0).clamp(-32768.0, 32767.0) as i32;
    (fixed + 32768) as u16
}
|
||||
|
||||
/// Convert a 16-bit code back to f32 (simplified implementation).
///
/// Inverse of `f32_to_f16_bits`: remove the 32768 bias, undo the 1000x scale.
fn f16_bits_to_f32(bits: u16) -> f32 {
    (i32::from(bits) - 32768) as f32 / 1000.0
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compress_none() {
        let compressor = TensorCompress::new();
        let embedding = vec![1.0, 2.0, 3.0, 4.0];

        // access_freq > 0.8 selects CompressionLevel::None, which is lossless.
        let compressed = compressor.compress(&embedding, 1.0).unwrap();
        let decompressed = compressor.decompress(&compressed).unwrap();

        assert_eq!(embedding, decompressed);
    }

    #[test]
    fn test_compress_half() {
        let compressor = TensorCompress::new();
        let embedding = vec![1.0, 2.0, 3.0, 4.0];

        // access_freq in (0.4, 0.8] selects the lossy Half level.
        let compressed = compressor.compress(&embedding, 0.5).unwrap();
        let decompressed = compressor.decompress(&compressed).unwrap();

        // Half precision should be close but not exact
        for (a, b) in embedding.iter().zip(decompressed.iter()) {
            assert!((a - b).abs() < 0.01, "Expected {}, got {}", a, b);
        }
    }

    #[test]
    fn test_compress_binary() {
        let compressor = TensorCompress::new();
        let embedding = vec![1.0, -1.0, 0.5, -0.5];

        // access_freq <= 0.01 selects Binary (archive) compression.
        let compressed = compressor.compress(&embedding, 0.005).unwrap();
        let decompressed = compressor.decompress(&compressed).unwrap();

        // Binary should be +1 or -1
        assert_eq!(decompressed.len(), embedding.len());
        for val in &decompressed {
            assert!(*val == 1.0 || *val == -1.0);
        }
    }

    #[test]
    fn test_select_level() {
        let compressor = TensorCompress::new();

        // Hot data
        assert!(matches!(
            compressor.select_level(0.9),
            CompressionLevel::None
        ));

        // Warm data
        assert!(matches!(
            compressor.select_level(0.5),
            CompressionLevel::Half { .. }
        ));

        // Cool data
        assert!(matches!(
            compressor.select_level(0.2),
            CompressionLevel::PQ8 { .. }
        ));

        // Cold data
        assert!(matches!(
            compressor.select_level(0.05),
            CompressionLevel::PQ4 { .. }
        ));

        // Archive
        assert!(matches!(
            compressor.select_level(0.001),
            CompressionLevel::Binary { .. }
        ));
    }

    #[test]
    fn test_empty_embedding() {
        let compressor = TensorCompress::new();
        // Compressing a zero-length vector must be rejected, not panic.
        let result = compressor.compress(&[], 0.5);
        assert!(result.is_err());
    }

    #[test]
    fn test_pq8_compression() {
        let compressor = TensorCompress::new();
        // 64 dims split into 8 subvectors of 8 dims each.
        let embedding: Vec<f32> = (0..64).map(|i| i as f32 * 0.1).collect();

        let compressed = compressor.compress_pq8(&embedding, 8, 16).unwrap();
        let decompressed = compressor.decompress(&compressed).unwrap();

        assert_eq!(decompressed.len(), embedding.len());
    }

    #[test]
    fn test_round_trip_all_levels() {
        let compressor = TensorCompress::new();
        let embedding: Vec<f32> = (0..128).map(|i| (i as f32 - 64.0) * 0.01).collect();

        // One frequency per compression tier: None/Half/PQ8/PQ4/Binary.
        let access_frequencies = vec![0.9, 0.5, 0.2, 0.05, 0.001];

        for freq in access_frequencies {
            let compressed = compressor.compress(&embedding, freq).unwrap();
            let decompressed = compressor.decompress(&compressed).unwrap();
            assert_eq!(decompressed.len(), embedding.len());
        }
    }

    #[test]
    fn test_half_precision_roundtrip() {
        let compressor = TensorCompress::new();
        // Use values within the supported range (-32.768 to 32.767)
        let values = vec![-30.0, -1.0, 0.0, 1.0, 30.0];

        for val in values {
            let embedding = vec![val; 4];
            let compressed = compressor
                .compress_with_level(&embedding, &CompressionLevel::Half { scale: 1.0 })
                .unwrap();
            let decompressed = compressor.decompress(&compressed).unwrap();

            // The fixed-point code quantizes to 1/1000 steps, so allow slack.
            for (a, b) in embedding.iter().zip(decompressed.iter()) {
                let diff = (a - b).abs();
                assert!(
                    diff < 0.1,
                    "Value {} decompressed to {}, diff: {}",
                    a,
                    b,
                    diff
                );
            }
        }
    }

    #[test]
    fn test_binary_threshold() {
        let compressor = TensorCompress::new();
        let embedding = vec![0.5, -0.5, 1.5, -1.5];

        let compressed = compressor
            .compress_with_level(&embedding, &CompressionLevel::Binary { threshold: 0.0 })
            .unwrap();
        let decompressed = compressor.decompress(&compressed).unwrap();

        // Values > 0 should be 1.0, values <= 0 should be -1.0
        assert_eq!(decompressed, vec![1.0, -1.0, 1.0, -1.0]);
    }

    #[test]
    fn test_pq4_with_outliers() {
        let compressor = TensorCompress::new();
        // Create embedding with some outliers
        let mut embedding: Vec<f32> = (0..64).map(|i| i as f32 * 0.01).collect();
        embedding[10] = 100.0; // Outlier
        embedding[30] = -100.0; // Outlier

        let compressed = compressor
            .compress_with_level(
                &embedding,
                &CompressionLevel::PQ4 {
                    subvectors: 8,
                    outlier_threshold: 2.0,
                },
            )
            .unwrap();
        let decompressed = compressor.decompress(&compressed).unwrap();

        assert_eq!(decompressed.len(), embedding.len());
        // Outliers should be preserved exactly: they bypass quantization.
        assert_eq!(decompressed[10], 100.0);
        assert_eq!(decompressed[30], -100.0);
    }

    #[test]
    fn test_dimension_validation() {
        let compressor = TensorCompress::new();
        let embedding = vec![1.0; 10]; // Not divisible by 8

        let result = compressor.compress_pq8(&embedding, 8, 16);
        assert!(result.is_err());
    }
}
|
||||
111
vendor/ruvector/crates/ruvector-gnn/src/error.rs
vendored
Normal file
111
vendor/ruvector/crates/ruvector-gnn/src/error.rs
vendored
Normal file
@@ -0,0 +1,111 @@
|
||||
//! Error types for the GNN module.
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
/// Result type alias for GNN operations.
pub type Result<T> = std::result::Result<T, GnnError>;

/// Errors that can occur during GNN operations.
///
/// `Io` and `Core` carry `#[from]` conversions, so `?` can propagate
/// `std::io::Error` and `RuvectorError` values into `GnnError` directly.
#[derive(Error, Debug)]
pub enum GnnError {
    /// Tensor dimension mismatch
    #[error("Tensor dimension mismatch: expected {expected}, got {actual}")]
    DimensionMismatch {
        /// Expected dimension
        expected: String,
        /// Actual dimension
        actual: String,
    },

    /// Invalid tensor shape
    #[error("Invalid tensor shape: {0}")]
    InvalidShape(String),

    /// Layer configuration error
    #[error("Layer configuration error: {0}")]
    LayerConfig(String),

    /// Training error
    #[error("Training error: {0}")]
    Training(String),

    /// Compression error
    #[error("Compression error: {0}")]
    Compression(String),

    /// Search error
    #[error("Search error: {0}")]
    Search(String),

    /// Invalid input
    #[error("Invalid input: {0}")]
    InvalidInput(String),

    /// Memory mapping error (not available on wasm32 targets)
    #[cfg(not(target_arch = "wasm32"))]
    #[error("Memory mapping error: {0}")]
    Mmap(String),

    /// I/O error
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// Core library error
    #[error("Core error: {0}")]
    Core(#[from] ruvector_core::error::RuvectorError),

    /// Generic error
    #[error("{0}")]
    Other(String),
}
|
||||
|
||||
// Convenience constructors mirroring the string-carrying variants; each
// accepts any `Into<String>` so call sites can pass `&str` or `String`
// without explicit `.to_string()` noise.
impl GnnError {
    /// Create a dimension mismatch error
    pub fn dimension_mismatch(expected: impl Into<String>, actual: impl Into<String>) -> Self {
        Self::DimensionMismatch {
            expected: expected.into(),
            actual: actual.into(),
        }
    }

    /// Create an invalid shape error
    pub fn invalid_shape(msg: impl Into<String>) -> Self {
        Self::InvalidShape(msg.into())
    }

    /// Create a layer config error
    pub fn layer_config(msg: impl Into<String>) -> Self {
        Self::LayerConfig(msg.into())
    }

    /// Create a training error
    pub fn training(msg: impl Into<String>) -> Self {
        Self::Training(msg.into())
    }

    /// Create a compression error
    pub fn compression(msg: impl Into<String>) -> Self {
        Self::Compression(msg.into())
    }

    /// Create a search error
    pub fn search(msg: impl Into<String>) -> Self {
        Self::Search(msg.into())
    }

    /// Create a memory mapping error (non-wasm targets only)
    #[cfg(not(target_arch = "wasm32"))]
    pub fn mmap(msg: impl Into<String>) -> Self {
        Self::Mmap(msg.into())
    }

    /// Create an invalid input error
    pub fn invalid_input(msg: impl Into<String>) -> Self {
        Self::InvalidInput(msg.into())
    }

    /// Create a generic error
    pub fn other(msg: impl Into<String>) -> Self {
        Self::Other(msg.into())
    }
}
|
||||
582
vendor/ruvector/crates/ruvector-gnn/src/ewc.rs
vendored
Normal file
582
vendor/ruvector/crates/ruvector-gnn/src/ewc.rs
vendored
Normal file
@@ -0,0 +1,582 @@
|
||||
/// Elastic Weight Consolidation (EWC) for preventing catastrophic forgetting in GNNs
|
||||
///
|
||||
/// EWC adds a regularization term that penalizes changes to important weights,
|
||||
/// where importance is measured by the Fisher information matrix diagonal.
|
||||
///
|
||||
/// The EWC loss term is: L_EWC = λ/2 * Σ F_i * (θ_i - θ*_i)²
|
||||
/// where:
|
||||
/// - λ is the regularization strength
|
||||
/// - F_i is the Fisher information for weight i
|
||||
/// - θ_i is the current weight
|
||||
/// - θ*_i is the anchor weight from the previous task
|
||||
use std::f32;
|
||||
|
||||
/// Elastic Weight Consolidation implementation
///
/// Prevents catastrophic forgetting by penalizing changes to important
/// weights learned from previous tasks. The penalty is
/// `λ/2 * Σ F_i * (θ_i - θ*_i)²`, where `F_i` is the Fisher information of
/// weight `i` and `θ*_i` its anchor value from the previous task.
#[derive(Debug, Clone)]
pub struct ElasticWeightConsolidation {
    /// Fisher information diagonal (importance of each weight);
    /// higher values indicate more important weights.
    fisher_diag: Vec<f32>,

    /// Anchor weights (optimal weights from the previous task) that the
    /// penalty pulls the current weights toward.
    anchor_weights: Vec<f32>,

    /// Regularization strength (λ) controlling how strongly deviations from
    /// the anchor are penalized.
    lambda: f32,

    /// Whether EWC is active; only true after `consolidate` has been called.
    active: bool,
}

impl ElasticWeightConsolidation {
    /// Create a new, inactive EWC instance.
    ///
    /// # Arguments
    /// * `lambda` - Regularization strength (typically 10-10000)
    ///
    /// # Panics
    /// Panics if `lambda` is negative.
    pub fn new(lambda: f32) -> Self {
        assert!(lambda >= 0.0, "Lambda must be non-negative");

        Self {
            fisher_diag: Vec::new(),
            anchor_weights: Vec::new(),
            lambda,
            active: false,
        }
    }

    /// Estimate the Fisher information diagonal from per-sample gradients.
    ///
    /// Approximated as the mean squared gradient over samples:
    /// `F_i ≈ (1/N) * Σ (∂L/∂θ_i)²`. Any previously computed Fisher values
    /// are discarded; a call with no gradients leaves the state untouched.
    ///
    /// # Arguments
    /// * `gradients` - Gradient vector for each sample
    /// * `sample_count` - Number of samples (for normalization; treated as 1 if 0)
    ///
    /// # Panics
    /// Panics if the gradient vectors have differing lengths.
    pub fn compute_fisher(&mut self, gradients: &[&[f32]], sample_count: usize) {
        let num_weights = match gradients.first() {
            Some(first) => first.len(),
            None => return,
        };

        // Start from zero: Fisher is recomputed fresh on every call.
        let mut fisher = vec![0.0f32; num_weights];

        for grad in gradients {
            assert_eq!(
                grad.len(),
                num_weights,
                "All gradient vectors must have the same length"
            );
            for (acc, &g) in fisher.iter_mut().zip(grad.iter()) {
                *acc += g * g;
            }
        }

        // Normalize by the sample count (guarding against zero).
        let inv_n = 1.0 / (sample_count as f32).max(1.0);
        for f in &mut fisher {
            *f *= inv_n;
        }

        self.fisher_diag = fisher;
    }

    /// Snapshot `weights` as the anchor and activate the EWC penalty.
    ///
    /// Call after training on a task, before moving to the next one.
    ///
    /// # Panics
    /// Panics if Fisher information has not been computed, or if `weights`
    /// has a different length than the Fisher diagonal.
    pub fn consolidate(&mut self, weights: &[f32]) {
        assert!(
            !self.fisher_diag.is_empty(),
            "Must compute Fisher information before consolidating"
        );
        assert_eq!(
            weights.len(),
            self.fisher_diag.len(),
            "Weight count must match Fisher information size"
        );

        self.anchor_weights = weights.to_vec();
        self.active = true;
    }

    /// EWC penalty `λ/2 * Σ F_i * (θ_i - θ*_i)²`, or 0.0 when inactive.
    ///
    /// Added to the loss to discourage changes to important weights.
    ///
    /// # Panics
    /// Panics if `weights` has a different length than the anchor weights.
    pub fn penalty(&self, weights: &[f32]) -> f32 {
        if !self.active {
            return 0.0;
        }

        assert_eq!(
            weights.len(),
            self.anchor_weights.len(),
            "Weight count must match anchor weights"
        );

        let quadratic: f32 = weights
            .iter()
            .zip(self.anchor_weights.iter())
            .zip(self.fisher_diag.iter())
            .map(|((&w, &anchor), &fisher)| {
                let diff = w - anchor;
                fisher * diff * diff
            })
            .sum();

        // Multiply by λ/2 (same operation order as the accumulation loop form).
        quadratic * self.lambda * 0.5
    }

    /// EWC gradient `λ * F_i * (θ_i - θ*_i)` per weight (zeros when inactive).
    ///
    /// Added to the model gradients to pull weights back toward the anchor.
    ///
    /// # Panics
    /// Panics if `weights` has a different length than the anchor weights.
    pub fn gradient(&self, weights: &[f32]) -> Vec<f32> {
        if !self.active {
            return vec![0.0; weights.len()];
        }

        assert_eq!(
            weights.len(),
            self.anchor_weights.len(),
            "Weight count must match anchor weights"
        );

        weights
            .iter()
            .zip(self.anchor_weights.iter())
            .zip(self.fisher_diag.iter())
            .map(|((&w, &anchor), &fisher)| self.lambda * fisher * (w - anchor))
            .collect()
    }

    /// Whether `consolidate` has been called.
    pub fn is_active(&self) -> bool {
        self.active
    }

    /// Get the regularization strength
    pub fn lambda(&self) -> f32 {
        self.lambda
    }

    /// Update the regularization strength
    ///
    /// # Panics
    /// Panics if `lambda` is negative.
    pub fn set_lambda(&mut self, lambda: f32) {
        assert!(lambda >= 0.0, "Lambda must be non-negative");
        self.lambda = lambda;
    }

    /// Get the Fisher information diagonal
    pub fn fisher_diag(&self) -> &[f32] {
        &self.fisher_diag
    }

    /// Get the anchor weights
    pub fn anchor_weights(&self) -> &[f32] {
        &self.anchor_weights
    }

    /// Clear all state and return to the inactive configuration.
    pub fn reset(&mut self) {
        self.fisher_diag.clear();
        self.anchor_weights.clear();
        self.active = false;
    }
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
fn test_new() {
    // A fresh instance stores lambda but starts inactive and empty.
    let ewc = ElasticWeightConsolidation::new(1000.0);
    assert_eq!(ewc.lambda(), 1000.0);
    assert!(!ewc.is_active());
    assert!(ewc.fisher_diag().is_empty());
    assert!(ewc.anchor_weights().is_empty());
}
|
||||
|
||||
#[test]
#[should_panic(expected = "Lambda must be non-negative")]
fn test_new_negative_lambda() {
    // A negative regularization strength violates the constructor precondition.
    let _ = ElasticWeightConsolidation::new(-1.0);
}
|
||||
|
||||
#[test]
|
||||
fn test_compute_fisher_single_sample() {
|
||||
let mut ewc = ElasticWeightConsolidation::new(100.0);
|
||||
|
||||
// Single gradient: [1.0, 2.0, 3.0]
|
||||
let grad1 = vec![1.0, 2.0, 3.0];
|
||||
let gradients = vec![grad1.as_slice()];
|
||||
|
||||
ewc.compute_fisher(&gradients, 1);
|
||||
|
||||
// Fisher should be squared gradients
|
||||
assert_eq!(ewc.fisher_diag(), &[1.0, 4.0, 9.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compute_fisher_multiple_samples() {
|
||||
let mut ewc = ElasticWeightConsolidation::new(100.0);
|
||||
|
||||
// Two gradients
|
||||
let grad1 = vec![1.0, 2.0, 3.0];
|
||||
let grad2 = vec![2.0, 1.0, 1.0];
|
||||
let gradients = vec![grad1.as_slice(), grad2.as_slice()];
|
||||
|
||||
ewc.compute_fisher(&gradients, 2);
|
||||
|
||||
// Fisher should be mean of squared gradients
|
||||
// Position 0: (1² + 2²) / 2 = 2.5
|
||||
// Position 1: (2² + 1²) / 2 = 2.5
|
||||
// Position 2: (3² + 1²) / 2 = 5.0
|
||||
let expected = vec![2.5, 2.5, 5.0];
|
||||
assert_eq!(ewc.fisher_diag().len(), expected.len());
|
||||
for (actual, exp) in ewc.fisher_diag().iter().zip(expected.iter()) {
|
||||
assert!((actual - exp).abs() < 1e-6);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compute_fisher_accumulates() {
|
||||
let mut ewc = ElasticWeightConsolidation::new(100.0);
|
||||
|
||||
// First computation
|
||||
let grad1 = vec![1.0, 2.0];
|
||||
ewc.compute_fisher(&[grad1.as_slice()], 1);
|
||||
assert_eq!(ewc.fisher_diag(), &[1.0, 4.0]);
|
||||
|
||||
// Second computation accumulates on top of first
|
||||
// When fisher_diag has same length, it's reset to zero first in compute_fisher
|
||||
// then accumulates: 0 + 2^2 = 4, 0 + 1^2 = 1
|
||||
// normalized by 1/1 = 4.0, 1.0
|
||||
let grad2 = vec![2.0, 1.0];
|
||||
ewc.compute_fisher(&[grad2.as_slice()], 1);
|
||||
// Fisher is reset and recomputed with new gradients
|
||||
assert_eq!(ewc.fisher_diag(), &[4.0, 1.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "All gradient vectors must have the same length")]
|
||||
fn test_compute_fisher_mismatched_sizes() {
|
||||
let mut ewc = ElasticWeightConsolidation::new(100.0);
|
||||
|
||||
let grad1 = vec![1.0, 2.0];
|
||||
let grad2 = vec![1.0, 2.0, 3.0];
|
||||
ewc.compute_fisher(&[grad1.as_slice(), grad2.as_slice()], 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_consolidate() {
|
||||
let mut ewc = ElasticWeightConsolidation::new(100.0);
|
||||
|
||||
// Setup Fisher information
|
||||
let grad = vec![1.0, 2.0, 3.0];
|
||||
ewc.compute_fisher(&[grad.as_slice()], 1);
|
||||
|
||||
// Consolidate weights
|
||||
let weights = vec![0.5, 1.0, 1.5];
|
||||
ewc.consolidate(&weights);
|
||||
|
||||
assert!(ewc.is_active());
|
||||
assert_eq!(ewc.anchor_weights(), &weights);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "Must compute Fisher information before consolidating")]
|
||||
fn test_consolidate_without_fisher() {
|
||||
let mut ewc = ElasticWeightConsolidation::new(100.0);
|
||||
let weights = vec![1.0, 2.0];
|
||||
ewc.consolidate(&weights);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "Weight count must match Fisher information size")]
|
||||
fn test_consolidate_size_mismatch() {
|
||||
let mut ewc = ElasticWeightConsolidation::new(100.0);
|
||||
|
||||
let grad = vec![1.0, 2.0];
|
||||
ewc.compute_fisher(&[grad.as_slice()], 1);
|
||||
|
||||
let weights = vec![1.0, 2.0, 3.0]; // Wrong size
|
||||
ewc.consolidate(&weights);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_penalty_inactive() {
|
||||
let ewc = ElasticWeightConsolidation::new(100.0);
|
||||
let weights = vec![1.0, 2.0, 3.0];
|
||||
|
||||
assert_eq!(ewc.penalty(&weights), 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_penalty_no_deviation() {
|
||||
let mut ewc = ElasticWeightConsolidation::new(100.0);
|
||||
|
||||
// Setup
|
||||
let grad = vec![1.0, 2.0, 3.0];
|
||||
ewc.compute_fisher(&[grad.as_slice()], 1);
|
||||
|
||||
let weights = vec![0.5, 1.0, 1.5];
|
||||
ewc.consolidate(&weights);
|
||||
|
||||
// Penalty should be 0 when weights match anchor
|
||||
assert_eq!(ewc.penalty(&weights), 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_penalty_with_deviation() {
|
||||
let mut ewc = ElasticWeightConsolidation::new(100.0);
|
||||
|
||||
// Fisher diagonal: [1.0, 4.0, 9.0]
|
||||
let grad = vec![1.0, 2.0, 3.0];
|
||||
ewc.compute_fisher(&[grad.as_slice()], 1);
|
||||
|
||||
// Anchor weights: [0.0, 0.0, 0.0]
|
||||
let anchor = vec![0.0, 0.0, 0.0];
|
||||
ewc.consolidate(&anchor);
|
||||
|
||||
// Current weights: [1.0, 1.0, 1.0]
|
||||
let weights = vec![1.0, 1.0, 1.0];
|
||||
|
||||
// Penalty = λ/2 * Σ F_i * (w_i - w*_i)²
|
||||
// = 100/2 * (1.0 * 1² + 4.0 * 1² + 9.0 * 1²)
|
||||
// = 50 * 14 = 700
|
||||
let penalty = ewc.penalty(&weights);
|
||||
assert!((penalty - 700.0).abs() < 1e-4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_penalty_increases_with_deviation() {
|
||||
let mut ewc = ElasticWeightConsolidation::new(100.0);
|
||||
|
||||
let grad = vec![1.0, 1.0, 1.0];
|
||||
ewc.compute_fisher(&[grad.as_slice()], 1);
|
||||
|
||||
let anchor = vec![0.0, 0.0, 0.0];
|
||||
ewc.consolidate(&anchor);
|
||||
|
||||
// Small deviation
|
||||
let weights1 = vec![0.1, 0.1, 0.1];
|
||||
let penalty1 = ewc.penalty(&weights1);
|
||||
|
||||
// Larger deviation
|
||||
let weights2 = vec![0.5, 0.5, 0.5];
|
||||
let penalty2 = ewc.penalty(&weights2);
|
||||
|
||||
// Penalty should increase
|
||||
assert!(penalty2 > penalty1);
|
||||
|
||||
// Penalty should scale quadratically
|
||||
// (0.5/0.1)² = 25
|
||||
assert!((penalty2 / penalty1 - 25.0).abs() < 1e-4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gradient_inactive() {
|
||||
let ewc = ElasticWeightConsolidation::new(100.0);
|
||||
let weights = vec![1.0, 2.0, 3.0];
|
||||
|
||||
let grad = ewc.gradient(&weights);
|
||||
assert_eq!(grad, vec![0.0, 0.0, 0.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gradient_no_deviation() {
|
||||
let mut ewc = ElasticWeightConsolidation::new(100.0);
|
||||
|
||||
let grad = vec![1.0, 2.0, 3.0];
|
||||
ewc.compute_fisher(&[grad.as_slice()], 1);
|
||||
|
||||
let weights = vec![0.5, 1.0, 1.5];
|
||||
ewc.consolidate(&weights);
|
||||
|
||||
// Gradient should be 0 when weights match anchor
|
||||
let grad = ewc.gradient(&weights);
|
||||
assert_eq!(grad, vec![0.0, 0.0, 0.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gradient_points_toward_anchor() {
|
||||
let mut ewc = ElasticWeightConsolidation::new(100.0);
|
||||
|
||||
// Fisher diagonal: [1.0, 4.0, 9.0]
|
||||
let grad = vec![1.0, 2.0, 3.0];
|
||||
ewc.compute_fisher(&[grad.as_slice()], 1);
|
||||
|
||||
// Anchor at origin
|
||||
let anchor = vec![0.0, 0.0, 0.0];
|
||||
ewc.consolidate(&anchor);
|
||||
|
||||
// Weights moved positive
|
||||
let weights = vec![1.0, 1.0, 1.0];
|
||||
|
||||
// Gradient = λ * F_i * (w_i - w*_i)
|
||||
// = 100 * [1.0, 4.0, 9.0] * [1.0, 1.0, 1.0]
|
||||
// = [100, 400, 900]
|
||||
let grad = ewc.gradient(&weights);
|
||||
assert_eq!(grad.len(), 3);
|
||||
assert!((grad[0] - 100.0).abs() < 1e-4);
|
||||
assert!((grad[1] - 400.0).abs() < 1e-4);
|
||||
assert!((grad[2] - 900.0).abs() < 1e-4);
|
||||
|
||||
// Weights moved negative
|
||||
let weights = vec![-1.0, -1.0, -1.0];
|
||||
let grad = ewc.gradient(&weights);
|
||||
|
||||
// Gradient should point opposite direction (toward anchor)
|
||||
assert!(grad[0] < 0.0);
|
||||
assert!(grad[1] < 0.0);
|
||||
assert!(grad[2] < 0.0);
|
||||
assert!((grad[0] + 100.0).abs() < 1e-4);
|
||||
assert!((grad[1] + 400.0).abs() < 1e-4);
|
||||
assert!((grad[2] + 900.0).abs() < 1e-4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gradient_magnitude_scales_with_fisher() {
|
||||
let mut ewc = ElasticWeightConsolidation::new(100.0);
|
||||
|
||||
// Fisher with varying importance
|
||||
let grad = vec![1.0, 2.0, 3.0];
|
||||
ewc.compute_fisher(&[grad.as_slice()], 1);
|
||||
|
||||
let anchor = vec![0.0, 0.0, 0.0];
|
||||
ewc.consolidate(&anchor);
|
||||
|
||||
let weights = vec![1.0, 1.0, 1.0];
|
||||
let grad = ewc.gradient(&weights);
|
||||
|
||||
// Gradient magnitude should increase with Fisher importance
|
||||
assert!(grad[0].abs() < grad[1].abs());
|
||||
assert!(grad[1].abs() < grad[2].abs());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lambda_scaling() {
|
||||
let mut ewc1 = ElasticWeightConsolidation::new(100.0);
|
||||
let mut ewc2 = ElasticWeightConsolidation::new(200.0);
|
||||
|
||||
// Same setup for both
|
||||
let grad = vec![1.0, 1.0, 1.0];
|
||||
ewc1.compute_fisher(&[grad.as_slice()], 1);
|
||||
ewc2.compute_fisher(&[grad.as_slice()], 1);
|
||||
|
||||
let anchor = vec![0.0, 0.0, 0.0];
|
||||
ewc1.consolidate(&anchor);
|
||||
ewc2.consolidate(&anchor);
|
||||
|
||||
let weights = vec![1.0, 1.0, 1.0];
|
||||
|
||||
// Penalty and gradient should scale with lambda
|
||||
let penalty1 = ewc1.penalty(&weights);
|
||||
let penalty2 = ewc2.penalty(&weights);
|
||||
assert!((penalty2 / penalty1 - 2.0).abs() < 1e-4);
|
||||
|
||||
let grad1 = ewc1.gradient(&weights);
|
||||
let grad2 = ewc2.gradient(&weights);
|
||||
assert!((grad2[0] / grad1[0] - 2.0).abs() < 1e-4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_set_lambda() {
|
||||
let mut ewc = ElasticWeightConsolidation::new(100.0);
|
||||
assert_eq!(ewc.lambda(), 100.0);
|
||||
|
||||
ewc.set_lambda(500.0);
|
||||
assert_eq!(ewc.lambda(), 500.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "Lambda must be non-negative")]
|
||||
fn test_set_lambda_negative() {
|
||||
let mut ewc = ElasticWeightConsolidation::new(100.0);
|
||||
ewc.set_lambda(-10.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reset() {
|
||||
let mut ewc = ElasticWeightConsolidation::new(100.0);
|
||||
|
||||
// Setup active EWC
|
||||
let grad = vec![1.0, 2.0, 3.0];
|
||||
ewc.compute_fisher(&[grad.as_slice()], 1);
|
||||
|
||||
let weights = vec![0.5, 1.0, 1.5];
|
||||
ewc.consolidate(&weights);
|
||||
|
||||
assert!(ewc.is_active());
|
||||
|
||||
// Reset
|
||||
ewc.reset();
|
||||
|
||||
assert!(!ewc.is_active());
|
||||
assert!(ewc.fisher_diag().is_empty());
|
||||
assert!(ewc.anchor_weights().is_empty());
|
||||
assert_eq!(ewc.lambda(), 100.0); // Lambda preserved
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sequential_task_learning() {
|
||||
// Simulate learning two tasks sequentially
|
||||
let mut ewc = ElasticWeightConsolidation::new(1000.0);
|
||||
|
||||
// Task 1: Learn weights [1.0, 2.0, 3.0]
|
||||
let task1_grad = vec![2.0, 1.0, 3.0];
|
||||
ewc.compute_fisher(&[task1_grad.as_slice()], 1);
|
||||
|
||||
let task1_weights = vec![1.0, 2.0, 3.0];
|
||||
ewc.consolidate(&task1_weights);
|
||||
|
||||
// Task 2: Try to learn very different weights
|
||||
let task2_weights = vec![5.0, 6.0, 7.0];
|
||||
|
||||
// EWC penalty should be significant
|
||||
let penalty = ewc.penalty(&task2_weights);
|
||||
assert!(penalty > 10000.0); // Large penalty for large deviation
|
||||
|
||||
// Gradient should point back toward task 1 weights
|
||||
let grad = ewc.gradient(&task2_weights);
|
||||
assert!(grad[0] > 0.0); // Push toward lower value
|
||||
assert!(grad[1] > 0.0);
|
||||
assert!(grad[2] > 0.0);
|
||||
}
|
||||
}
|
||||
546
vendor/ruvector/crates/ruvector-gnn/src/layer.rs
vendored
Normal file
546
vendor/ruvector/crates/ruvector-gnn/src/layer.rs
vendored
Normal file
@@ -0,0 +1,546 @@
|
||||
//! GNN Layer Implementation for HNSW Topology
|
||||
//!
|
||||
//! This module implements graph neural network layers that operate on HNSW graph structure,
|
||||
//! including attention mechanisms, normalization, and gated recurrent updates.
|
||||
|
||||
use crate::error::GnnError;
|
||||
use ndarray::{Array1, Array2, ArrayView1};
|
||||
use rand::Rng;
|
||||
use rand_distr::{Distribution, Normal};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Linear transformation layer (weight matrix multiplication)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Linear {
    /// Weight matrix of shape (output_dim, input_dim).
    weights: Array2<f32>,
    /// Bias vector of length output_dim (initialized to zeros).
    bias: Array1<f32>,
}
|
||||
|
||||
impl Linear {
|
||||
/// Create a new linear layer with Xavier/Glorot initialization
|
||||
pub fn new(input_dim: usize, output_dim: usize) -> Self {
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
// Xavier initialization: scale = sqrt(2.0 / (input_dim + output_dim))
|
||||
let scale = (2.0 / (input_dim + output_dim) as f32).sqrt();
|
||||
let normal = Normal::new(0.0, scale as f64).unwrap();
|
||||
|
||||
let weights =
|
||||
Array2::from_shape_fn((output_dim, input_dim), |_| normal.sample(&mut rng) as f32);
|
||||
|
||||
let bias = Array1::zeros(output_dim);
|
||||
|
||||
Self { weights, bias }
|
||||
}
|
||||
|
||||
/// Forward pass: y = Wx + b
|
||||
pub fn forward(&self, input: &[f32]) -> Vec<f32> {
|
||||
let x = ArrayView1::from(input);
|
||||
let output = self.weights.dot(&x) + &self.bias;
|
||||
output.to_vec()
|
||||
}
|
||||
|
||||
/// Get output dimension
|
||||
pub fn output_dim(&self) -> usize {
|
||||
self.weights.shape()[0]
|
||||
}
|
||||
}
|
||||
|
||||
/// Layer normalization
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerNorm {
    /// Learned per-dimension scale (initialized to ones).
    gamma: Array1<f32>,
    /// Learned per-dimension shift (initialized to zeros).
    beta: Array1<f32>,
    /// Small constant added to the variance for numerical stability.
    eps: f32,
}
|
||||
|
||||
impl LayerNorm {
    /// Create a new layer normalization layer
    ///
    /// With `gamma` at ones and `beta` at zeros, a fresh layer applies only
    /// normalization without additional scaling or shifting.
    pub fn new(dim: usize, eps: f32) -> Self {
        Self {
            gamma: Array1::ones(dim),
            beta: Array1::zeros(dim),
            eps,
        }
    }

    /// Forward pass: normalize and scale
    ///
    /// `input` length must match the `dim` used at construction; ndarray
    /// will panic on a shape mismatch when applying `gamma`/`beta`.
    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
        let x = ArrayView1::from(input);

        // Compute mean and variance (population variance, divided by n)
        let mean = x.mean().unwrap_or(0.0);
        let variance = x.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / x.len() as f32;

        // Normalize; eps keeps the divisor nonzero for constant inputs
        let normalized = x.mapv(|v| (v - mean) / (variance + self.eps).sqrt());

        // Scale and shift with the learned parameters
        let output = &self.gamma * &normalized + &self.beta;
        output.to_vec()
    }
}
|
||||
|
||||
/// Multi-head attention mechanism
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiHeadAttention {
    /// Number of parallel attention heads.
    num_heads: usize,
    /// Dimension of each head (embed_dim / num_heads).
    head_dim: usize,
    /// Query projection (embed_dim -> embed_dim).
    q_linear: Linear,
    /// Key projection (embed_dim -> embed_dim).
    k_linear: Linear,
    /// Value projection (embed_dim -> embed_dim).
    v_linear: Linear,
    /// Output projection applied after concatenating head outputs.
    out_linear: Linear,
}
|
||||
|
||||
impl MultiHeadAttention {
    /// Create a new multi-head attention layer
    ///
    /// # Errors
    /// Returns `GnnError::LayerConfig` if `embed_dim` is not divisible by `num_heads`.
    pub fn new(embed_dim: usize, num_heads: usize) -> Result<Self, GnnError> {
        if embed_dim % num_heads != 0 {
            return Err(GnnError::layer_config(format!(
                "Embedding dimension ({}) must be divisible by number of heads ({})",
                embed_dim, num_heads
            )));
        }

        let head_dim = embed_dim / num_heads;

        Ok(Self {
            num_heads,
            head_dim,
            // All four projections are embed_dim -> embed_dim; individual
            // heads are carved out of the projected vectors by `split_heads`.
            q_linear: Linear::new(embed_dim, embed_dim),
            k_linear: Linear::new(embed_dim, embed_dim),
            v_linear: Linear::new(embed_dim, embed_dim),
            out_linear: Linear::new(embed_dim, embed_dim),
        })
    }

    /// Forward pass: compute multi-head attention
    ///
    /// # Arguments
    /// * `query` - Query vector
    /// * `keys` - Key vectors from neighbors
    /// * `values` - Value vectors from neighbors
    ///
    /// # Returns
    /// Attention-weighted output vector. With no keys/values the query is
    /// returned unchanged (identity fallback).
    pub fn forward(&self, query: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>]) -> Vec<f32> {
        if keys.is_empty() || values.is_empty() {
            return query.to_vec();
        }

        // Project query, keys, and values
        let q = self.q_linear.forward(query);
        let k: Vec<Vec<f32>> = keys.iter().map(|k| self.k_linear.forward(k)).collect();
        let v: Vec<Vec<f32>> = values.iter().map(|v| self.v_linear.forward(v)).collect();

        // Reshape for multi-head attention: each projected vector is split
        // into num_heads contiguous chunks of head_dim.
        let q_heads = self.split_heads(&q);
        let k_heads: Vec<Vec<Vec<f32>>> = k.iter().map(|k_vec| self.split_heads(k_vec)).collect();
        let v_heads: Vec<Vec<Vec<f32>>> = v.iter().map(|v_vec| self.split_heads(v_vec)).collect();

        // Compute attention independently for each head
        let mut head_outputs = Vec::new();
        for h in 0..self.num_heads {
            let q_h = &q_heads[h];
            // Gather this head's slice from every neighbor's key/value
            let k_h: Vec<&Vec<f32>> = k_heads.iter().map(|heads| &heads[h]).collect();
            let v_h: Vec<&Vec<f32>> = v_heads.iter().map(|heads| &heads[h]).collect();

            let head_output = self.scaled_dot_product_attention(q_h, &k_h, &v_h);
            head_outputs.push(head_output);
        }

        // Concatenate heads back into a single embed_dim vector
        let concat: Vec<f32> = head_outputs.into_iter().flatten().collect();

        // Final linear projection
        self.out_linear.forward(&concat)
    }

    /// Split vector into multiple heads
    ///
    /// Panics if `x.len() < num_heads * head_dim` (slice out of range).
    fn split_heads(&self, x: &[f32]) -> Vec<Vec<f32>> {
        let mut heads = Vec::new();
        for h in 0..self.num_heads {
            let start = h * self.head_dim;
            let end = start + self.head_dim;
            heads.push(x[start..end].to_vec());
        }
        heads
    }

    /// Scaled dot-product attention
    ///
    /// Computes softmax(q·k / sqrt(head_dim)) over the keys, then returns
    /// the attention-weighted sum of the values.
    fn scaled_dot_product_attention(
        &self,
        query: &[f32],
        keys: &[&Vec<f32>],
        values: &[&Vec<f32>],
    ) -> Vec<f32> {
        if keys.is_empty() {
            return query.to_vec();
        }

        let scale = (self.head_dim as f32).sqrt();

        // Compute attention scores (scaled dot products)
        let scores: Vec<f32> = keys
            .iter()
            .map(|k| {
                let dot: f32 = query.iter().zip(k.iter()).map(|(q, k)| q * k).sum();
                dot / scale
            })
            .collect();

        // Softmax with epsilon guard against division by zero; subtracting
        // the max score keeps exp() from overflowing.
        let max_score = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exp_scores: Vec<f32> = scores.iter().map(|&s| (s - max_score).exp()).collect();
        let sum_exp: f32 = exp_scores.iter().sum::<f32>().max(1e-10);
        let attention_weights: Vec<f32> = exp_scores.iter().map(|&e| e / sum_exp).collect();

        // Weighted sum of values
        let mut output = vec![0.0; self.head_dim];
        for (weight, value) in attention_weights.iter().zip(values.iter()) {
            for (out, &val) in output.iter_mut().zip(value.iter()) {
                *out += weight * val;
            }
        }

        output
    }
}
|
||||
|
||||
/// Gated Recurrent Unit (GRU) cell for state updates
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GRUCell {
    // Update gate: z_t = sigmoid(W_z x_t + U_z h_{t-1})
    /// Input-to-hidden weights for the update gate.
    w_z: Linear,
    /// Hidden-to-hidden weights for the update gate.
    u_z: Linear,

    // Reset gate: r_t = sigmoid(W_r x_t + U_r h_{t-1})
    /// Input-to-hidden weights for the reset gate.
    w_r: Linear,
    /// Hidden-to-hidden weights for the reset gate.
    u_r: Linear,

    // Candidate hidden state: h~ = tanh(W_h x_t + U_h (r_t ⊙ h_{t-1}))
    /// Input-to-hidden weights for the candidate state.
    w_h: Linear,
    /// Hidden-to-hidden weights for the candidate state.
    u_h: Linear,
}
|
||||
|
||||
impl GRUCell {
    /// Create a new GRU cell
    ///
    /// Gate `W_*` matrices map input_dim -> hidden_dim; `U_*` matrices map
    /// hidden_dim -> hidden_dim.
    pub fn new(input_dim: usize, hidden_dim: usize) -> Self {
        Self {
            // Update gate
            w_z: Linear::new(input_dim, hidden_dim),
            u_z: Linear::new(hidden_dim, hidden_dim),

            // Reset gate
            w_r: Linear::new(input_dim, hidden_dim),
            u_r: Linear::new(hidden_dim, hidden_dim),

            // Candidate hidden state
            w_h: Linear::new(input_dim, hidden_dim),
            u_h: Linear::new(hidden_dim, hidden_dim),
        }
    }

    /// Forward pass: update hidden state
    ///
    /// # Arguments
    /// * `input` - Current input (length must match input_dim)
    /// * `hidden` - Previous hidden state (length must match hidden_dim)
    ///
    /// # Returns
    /// Updated hidden state
    pub fn forward(&self, input: &[f32], hidden: &[f32]) -> Vec<f32> {
        // Update gate: z_t = sigmoid(W_z * x_t + U_z * h_{t-1})
        let z =
            self.sigmoid_vec(&self.add_vecs(&self.w_z.forward(input), &self.u_z.forward(hidden)));

        // Reset gate: r_t = sigmoid(W_r * x_t + U_r * h_{t-1})
        let r =
            self.sigmoid_vec(&self.add_vecs(&self.w_r.forward(input), &self.u_r.forward(hidden)));

        // Candidate hidden state: h_tilde = tanh(W_h * x_t + U_h * (r_t ⊙ h_{t-1}))
        let r_hidden = self.mul_vecs(&r, hidden);
        let h_tilde =
            self.tanh_vec(&self.add_vecs(&self.w_h.forward(input), &self.u_h.forward(&r_hidden)));

        // Final hidden state: h_t = (1 - z_t) ⊙ h_{t-1} + z_t ⊙ h_tilde
        // z near 0 keeps the old state; z near 1 adopts the candidate.
        let one_minus_z: Vec<f32> = z.iter().map(|&zval| 1.0 - zval).collect();
        let term1 = self.mul_vecs(&one_minus_z, hidden);
        let term2 = self.mul_vecs(&z, &h_tilde);

        self.add_vecs(&term1, &term2)
    }

    /// Sigmoid activation with numerical stability
    ///
    /// Branches on the sign of `x` so exp() is only ever called on a
    /// non-positive argument, avoiding overflow for large |x|.
    fn sigmoid(&self, x: f32) -> f32 {
        if x > 0.0 {
            1.0 / (1.0 + (-x).exp())
        } else {
            let ex = x.exp();
            ex / (1.0 + ex)
        }
    }

    /// Sigmoid for vectors
    fn sigmoid_vec(&self, v: &[f32]) -> Vec<f32> {
        v.iter().map(|&x| self.sigmoid(x)).collect()
    }

    /// Tanh activation
    fn tanh(&self, x: f32) -> f32 {
        x.tanh()
    }

    /// Tanh for vectors
    fn tanh_vec(&self, v: &[f32]) -> Vec<f32> {
        v.iter().map(|&x| self.tanh(x)).collect()
    }

    /// Element-wise addition (zip semantics: output is as long as the shorter input)
    fn add_vecs(&self, a: &[f32], b: &[f32]) -> Vec<f32> {
        a.iter().zip(b.iter()).map(|(x, y)| x + y).collect()
    }

    /// Element-wise multiplication (zip semantics: output is as long as the shorter input)
    fn mul_vecs(&self, a: &[f32], b: &[f32]) -> Vec<f32> {
        a.iter().zip(b.iter()).map(|(x, y)| x * y).collect()
    }
}
|
||||
|
||||
/// Main GNN layer operating on HNSW topology
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RuvectorLayer {
    /// Message weight matrix (input_dim -> hidden_dim)
    w_msg: Linear,

    /// Aggregation weight matrix (hidden_dim -> hidden_dim)
    w_agg: Linear,

    /// GRU update cell combining aggregated messages with the node message
    w_update: GRUCell,

    /// Multi-head attention over neighbor messages
    attention: MultiHeadAttention,

    /// Layer normalization applied to the final output
    norm: LayerNorm,

    /// Dropout rate in [0.0, 1.0]; applied as a simple (1 - rate) scaling
    dropout: f32,
}
|
||||
|
||||
impl RuvectorLayer {
    /// Create a new Ruvector GNN layer
    ///
    /// # Arguments
    /// * `input_dim` - Dimension of input node embeddings
    /// * `hidden_dim` - Dimension of hidden representations
    /// * `heads` - Number of attention heads
    /// * `dropout` - Dropout rate (0.0 to 1.0)
    ///
    /// # Errors
    /// Returns `GnnError::LayerConfig` if `dropout` is outside `[0.0, 1.0]` or
    /// if `hidden_dim` is not divisible by `heads`.
    pub fn new(
        input_dim: usize,
        hidden_dim: usize,
        heads: usize,
        dropout: f32,
    ) -> Result<Self, GnnError> {
        if !(0.0..=1.0).contains(&dropout) {
            return Err(GnnError::layer_config(format!(
                "Dropout must be between 0.0 and 1.0, got {}",
                dropout
            )));
        }

        Ok(Self {
            w_msg: Linear::new(input_dim, hidden_dim),
            w_agg: Linear::new(hidden_dim, hidden_dim),
            w_update: GRUCell::new(hidden_dim, hidden_dim),
            // Propagates the divisibility error from MultiHeadAttention::new
            attention: MultiHeadAttention::new(hidden_dim, heads)?,
            norm: LayerNorm::new(hidden_dim, 1e-5),
            dropout,
        })
    }

    /// Forward pass through the GNN layer
    ///
    /// # Arguments
    /// * `node_embedding` - Current node's embedding
    /// * `neighbor_embeddings` - Embeddings of neighbor nodes
    /// * `edge_weights` - Weights of edges to neighbors (e.g., distances)
    ///
    /// # Returns
    /// Updated node embedding of length hidden_dim
    pub fn forward(
        &self,
        node_embedding: &[f32],
        neighbor_embeddings: &[Vec<f32>],
        edge_weights: &[f32],
    ) -> Vec<f32> {
        if neighbor_embeddings.is_empty() {
            // No neighbors: return normalized projection
            let projected = self.w_msg.forward(node_embedding);
            return self.norm.forward(&projected);
        }

        // Step 1: Message passing - transform node and neighbor embeddings
        let node_msg = self.w_msg.forward(node_embedding);
        let neighbor_msgs: Vec<Vec<f32>> = neighbor_embeddings
            .iter()
            .map(|n| self.w_msg.forward(n))
            .collect();

        // Step 2: Attention-based aggregation (neighbor messages serve as
        // both keys and values)
        let attention_output = self
            .attention
            .forward(&node_msg, &neighbor_msgs, &neighbor_msgs);

        // Step 3: Weighted aggregation using edge weights
        let weighted_msgs = self.aggregate_messages(&neighbor_msgs, edge_weights);

        // Step 4: Combine attention and weighted aggregation
        let combined = self.add_vecs(&attention_output, &weighted_msgs);
        let aggregated = self.w_agg.forward(&combined);

        // Step 5: GRU update, treating the projected node message as the
        // previous hidden state
        let updated = self.w_update.forward(&aggregated, &node_msg);

        // Step 6: Apply dropout (simplified - always apply scaling)
        let dropped = self.apply_dropout(&updated);

        // Step 7: Layer normalization
        self.norm.forward(&dropped)
    }

    /// Aggregate neighbor messages with edge weights
    ///
    /// Edge weights are normalized to sum to 1; if they sum to zero (or
    /// less), uniform weights are used instead.
    fn aggregate_messages(&self, messages: &[Vec<f32>], weights: &[f32]) -> Vec<f32> {
        if messages.is_empty() || weights.is_empty() {
            return vec![0.0; self.w_msg.output_dim()];
        }

        // Normalize weights to sum to 1
        let weight_sum: f32 = weights.iter().sum();
        let normalized_weights: Vec<f32> = if weight_sum > 0.0 {
            weights.iter().map(|&w| w / weight_sum).collect()
        } else {
            vec![1.0 / weights.len() as f32; weights.len()]
        };

        // Weighted sum
        let dim = messages[0].len();
        let mut aggregated = vec![0.0; dim];

        for (msg, &weight) in messages.iter().zip(normalized_weights.iter()) {
            for (agg, &m) in aggregated.iter_mut().zip(msg.iter()) {
                *agg += weight * m;
            }
        }

        aggregated
    }

    /// Apply dropout (simplified version - just scales by (1-dropout))
    ///
    /// NOTE(review): this deterministic scaling resembles inference-mode
    /// inverted dropout; no elements are actually zeroed.
    fn apply_dropout(&self, input: &[f32]) -> Vec<f32> {
        let scale = 1.0 - self.dropout;
        input.iter().map(|&x| x * scale).collect()
    }

    /// Element-wise vector addition (zip semantics: truncates to the shorter input)
    fn add_vecs(&self, a: &[f32], b: &[f32]) -> Vec<f32> {
        a.iter().zip(b.iter()).map(|(x, y)| x + y).collect()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    //! Shape and error-path tests for the GNN layer building blocks.

    use super::*;

    #[test]
    fn test_linear_layer() {
        let linear = Linear::new(4, 2);
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let output = linear.forward(&input);
        assert_eq!(output.len(), 2);
    }

    #[test]
    fn test_layer_norm() {
        let norm = LayerNorm::new(4, 1e-5);
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let output = norm.forward(&input);

        // Check that output has zero mean (approximately); gamma starts at
        // ones and beta at zeros, so the fresh layer only normalizes.
        let mean: f32 = output.iter().sum::<f32>() / output.len() as f32;
        assert!((mean).abs() < 1e-5);
    }

    #[test]
    fn test_multihead_attention() {
        let attention = MultiHeadAttention::new(8, 2).unwrap();
        let query = vec![0.5; 8];
        let keys = vec![vec![0.3; 8], vec![0.7; 8]];
        let values = vec![vec![0.2; 8], vec![0.8; 8]];

        let output = attention.forward(&query, &keys, &values);
        assert_eq!(output.len(), 8);
    }

    #[test]
    fn test_multihead_attention_invalid_dims() {
        // 10 is not divisible by 3 heads
        let result = MultiHeadAttention::new(10, 3);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("divisible"));
    }

    #[test]
    fn test_gru_cell() {
        let gru = GRUCell::new(4, 8);
        let input = vec![1.0; 4];
        let hidden = vec![0.5; 8];

        let new_hidden = gru.forward(&input, &hidden);
        assert_eq!(new_hidden.len(), 8);
    }

    #[test]
    fn test_ruvector_layer() {
        let layer = RuvectorLayer::new(4, 8, 2, 0.1).unwrap();

        let node = vec![1.0, 2.0, 3.0, 4.0];
        let neighbors = vec![vec![0.5, 1.0, 1.5, 2.0], vec![2.0, 3.0, 4.0, 5.0]];
        let weights = vec![0.3, 0.7];

        let output = layer.forward(&node, &neighbors, &weights);
        assert_eq!(output.len(), 8);
    }

    #[test]
    fn test_ruvector_layer_no_neighbors() {
        let layer = RuvectorLayer::new(4, 8, 2, 0.1).unwrap();

        // Empty neighbor set exercises the projection-only fallback path.
        let node = vec![1.0, 2.0, 3.0, 4.0];
        let neighbors: Vec<Vec<f32>> = vec![];
        let weights: Vec<f32> = vec![];

        let output = layer.forward(&node, &neighbors, &weights);
        assert_eq!(output.len(), 8);
    }

    #[test]
    fn test_ruvector_layer_invalid_dropout() {
        let result = RuvectorLayer::new(4, 8, 2, 1.5);
        assert!(result.is_err());
    }

    #[test]
    fn test_ruvector_layer_invalid_heads() {
        // hidden_dim 7 is not divisible by 3 heads
        let result = RuvectorLayer::new(4, 7, 3, 0.1);
        assert!(result.is_err());
    }
}
|
||||
92
vendor/ruvector/crates/ruvector-gnn/src/lib.rs
vendored
Normal file
92
vendor/ruvector/crates/ruvector-gnn/src/lib.rs
vendored
Normal file
@@ -0,0 +1,92 @@
|
||||
//! # RuVector GNN
|
||||
//!
|
||||
//! Graph Neural Network capabilities for RuVector, providing tensor operations,
|
||||
//! GNN layers, compression, and differentiable search.
|
||||
//!
|
||||
//! ## Forgetting Mitigation (Issue #17)
|
||||
//!
|
||||
//! This crate includes comprehensive forgetting mitigation for continual learning:
|
||||
//!
|
||||
//! - **Adam Optimizer**: Full implementation with momentum and bias correction
|
||||
//! - **Replay Buffer**: Experience replay with reservoir sampling for uniform coverage
|
||||
//! - **EWC (Elastic Weight Consolidation)**: Prevents catastrophic forgetting
|
||||
//! - **Learning Rate Scheduling**: Multiple strategies including warmup and plateau detection
|
||||
//!
|
||||
//! ### Usage Example
|
||||
//!
|
||||
//! ```rust,ignore
|
||||
//! use ruvector_gnn::{
|
||||
//! training::{Optimizer, OptimizerType},
|
||||
//! replay::ReplayBuffer,
|
||||
//! ewc::ElasticWeightConsolidation,
|
||||
//! scheduler::{LearningRateScheduler, SchedulerType},
|
||||
//! };
|
||||
//!
|
||||
//! // Create Adam optimizer
|
||||
//! let mut optimizer = Optimizer::new(OptimizerType::Adam {
|
||||
//! learning_rate: 0.001,
|
||||
//! beta1: 0.9,
|
||||
//! beta2: 0.999,
|
||||
//! epsilon: 1e-8,
|
||||
//! });
|
||||
//!
|
||||
//! // Create replay buffer for experience replay
|
||||
//! let mut replay = ReplayBuffer::new(10000);
|
||||
//!
|
||||
//! // Create EWC for preventing forgetting
|
||||
//! let mut ewc = ElasticWeightConsolidation::new(0.4);
|
||||
//!
|
||||
//! // Create learning rate scheduler
|
||||
//! let mut scheduler = LearningRateScheduler::new(
|
||||
//! SchedulerType::CosineAnnealing { t_max: 100, eta_min: 1e-6 },
|
||||
//! 0.001
|
||||
//! );
|
||||
//! ```
|
||||
|
||||
#![warn(missing_docs)]
|
||||
#![deny(unsafe_op_in_unsafe_fn)]
|
||||
|
||||
pub mod compress;
|
||||
pub mod error;
|
||||
pub mod ewc;
|
||||
pub mod layer;
|
||||
pub mod query;
|
||||
pub mod replay;
|
||||
pub mod scheduler;
|
||||
pub mod search;
|
||||
pub mod tensor;
|
||||
pub mod training;
|
||||
|
||||
#[cfg(all(not(target_arch = "wasm32"), feature = "mmap"))]
|
||||
pub mod mmap;
|
||||
|
||||
#[cfg(all(feature = "cold-tier", not(target_arch = "wasm32")))]
|
||||
pub mod cold_tier;
|
||||
|
||||
// Re-export commonly used types
|
||||
pub use compress::{CompressedTensor, CompressionLevel, TensorCompress};
|
||||
pub use error::{GnnError, Result};
|
||||
pub use ewc::ElasticWeightConsolidation;
|
||||
pub use layer::RuvectorLayer;
|
||||
pub use query::{QueryMode, QueryResult, RuvectorQuery, SubGraph};
|
||||
pub use replay::{DistributionStats, ReplayBuffer, ReplayEntry};
|
||||
pub use scheduler::{LearningRateScheduler, SchedulerType};
|
||||
pub use search::{cosine_similarity, differentiable_search, hierarchical_forward};
|
||||
pub use training::{
|
||||
info_nce_loss, local_contrastive_loss, sgd_step, Loss, LossType, OnlineConfig, Optimizer,
|
||||
OptimizerType, TrainConfig,
|
||||
};
|
||||
|
||||
#[cfg(all(not(target_arch = "wasm32"), feature = "mmap"))]
|
||||
pub use mmap::{AtomicBitmap, MmapGradientAccumulator, MmapManager};
|
||||
|
||||
#[cfg(test)]
mod tests {
    // The glob import itself verifies that every `pub use` re-export above
    // resolves; it is otherwise unused, hence the allow.
    #[allow(unused_imports)]
    use super::*;

    /// Compile-and-link smoke test.
    ///
    /// The assertion-free body is intentional: this test passing means the
    /// crate (including the `use super::*` glob of all re-exports) compiled.
    /// The previous `assert!(true)` was removed — it is a constant no-op
    /// that Clippy flags via `assertions_on_constants`.
    #[test]
    fn test_basic() {}
}
|
||||
939
vendor/ruvector/crates/ruvector-gnn/src/mmap.rs
vendored
Normal file
939
vendor/ruvector/crates/ruvector-gnn/src/mmap.rs
vendored
Normal file
@@ -0,0 +1,939 @@
|
||||
//! Memory-mapped embedding management for large-scale GNN training.
|
||||
//!
|
||||
//! This module provides efficient memory-mapped access to embeddings and gradients
|
||||
//! that don't fit in RAM. It includes:
|
||||
//! - `MmapManager`: Memory-mapped embedding storage with dirty tracking
|
||||
//! - `MmapGradientAccumulator`: Lock-free gradient accumulation
|
||||
//! - `AtomicBitmap`: Thread-safe bitmap for access/dirty tracking
|
||||
//!
|
||||
//! Only available on non-WASM targets.
|
||||
|
||||
#![cfg(all(not(target_arch = "wasm32"), feature = "mmap"))]
|
||||
|
||||
use crate::error::{GnnError, Result};
|
||||
use memmap2::{MmapMut, MmapOptions};
|
||||
use parking_lot::RwLock;
|
||||
use std::fs::{File, OpenOptions};
|
||||
use std::io::{self, Write};
|
||||
use std::path::Path;
|
||||
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
|
||||
|
||||
/// Thread-safe bitmap using atomic operations.
///
/// Used for tracking which embeddings have been accessed or modified.
/// Each bit represents one embedding node.
#[derive(Debug)]
pub struct AtomicBitmap {
    /// Array of 64-bit atomic integers, each storing 64 bits
    bits: Vec<AtomicU64>,
    /// Total number of bits (nodes)
    size: usize,
}

impl AtomicBitmap {
    /// Create a new atomic bitmap with the specified capacity.
    ///
    /// # Arguments
    /// * `size` - Number of bits to allocate
    pub fn new(size: usize) -> Self {
        // One u64 word holds 64 bits; round up so every bit has a home.
        let num_words = (size + 63) / 64;
        Self {
            // AtomicU64::default() is zero, so all bits start cleared.
            bits: std::iter::repeat_with(AtomicU64::default)
                .take(num_words)
                .collect(),
            size,
        }
    }

    /// Split a bit index into its (word index, single-bit mask) pair.
    ///
    /// Returns `None` for out-of-range indices so every public method can
    /// silently ignore invalid input, matching this type's contract.
    #[inline]
    fn locate(&self, index: usize) -> Option<(usize, u64)> {
        (index < self.size).then(|| (index / 64, 1u64 << (index % 64)))
    }

    /// Set a bit to 1 (mark as accessed/dirty).
    ///
    /// Out-of-range indices are ignored.
    ///
    /// # Arguments
    /// * `index` - Bit index to set
    pub fn set(&self, index: usize) {
        if let Some((word, mask)) = self.locate(index) {
            self.bits[word].fetch_or(mask, Ordering::Release);
        }
    }

    /// Clear a bit to 0 (mark as clean/not accessed).
    ///
    /// Out-of-range indices are ignored.
    ///
    /// # Arguments
    /// * `index` - Bit index to clear
    pub fn clear(&self, index: usize) {
        if let Some((word, mask)) = self.locate(index) {
            self.bits[word].fetch_and(!mask, Ordering::Release);
        }
    }

    /// Check whether a bit is set.
    ///
    /// # Arguments
    /// * `index` - Bit index to check
    ///
    /// # Returns
    /// `true` if the bit is set; `false` if clear or out of range.
    pub fn get(&self, index: usize) -> bool {
        match self.locate(index) {
            Some((word, mask)) => self.bits[word].load(Ordering::Acquire) & mask != 0,
            None => false,
        }
    }

    /// Clear every bit in the bitmap.
    pub fn clear_all(&self) {
        self.bits
            .iter()
            .for_each(|word| word.store(0, Ordering::Release));
    }

    /// Collect the indices of all set bits (e.g. dirty nodes), in
    /// ascending order.
    ///
    /// # Returns
    /// Vector of indices where bits are set
    pub fn get_set_indices(&self) -> Vec<usize> {
        self.bits
            .iter()
            .enumerate()
            .flat_map(|(word_idx, word)| {
                // Snapshot the word once, then walk its 64 bit positions.
                let w = word.load(Ordering::Acquire);
                (0..64).filter_map(move |bit| {
                    if (w >> bit) & 1 == 1 {
                        Some(word_idx * 64 + bit)
                    } else {
                        None
                    }
                })
            })
            .collect()
    }
}
|
||||
|
||||
/// Memory-mapped embedding manager with dirty tracking and prefetching.
///
/// Manages large embedding matrices that may not fit in RAM using memory-mapped files.
/// Tracks which embeddings have been accessed and modified for efficient I/O.
/// One embedding occupies `d_embed * 4` contiguous bytes at offset
/// `node_id * d_embed * 4` in the backing file.
#[derive(Debug)]
pub struct MmapManager {
    /// The memory-mapped file; kept open so the mapping stays backed for
    /// the lifetime of this struct.
    file: File,
    /// Mutable memory mapping over the entire file
    mmap: MmapMut,
    /// Operating system page size (used for page-aligned prefetch hints)
    page_size: usize,
    /// Embedding dimension (number of f32 elements per node)
    d_embed: usize,
    /// Bitmap tracking which embeddings have been accessed (one bit per node)
    access_bitmap: AtomicBitmap,
    /// Bitmap tracking which embeddings were modified since the last flush
    dirty_bitmap: AtomicBitmap,
    /// Pin count for each page (prevents eviction)
    // NOTE(review): no method in this module increments or reads these
    // counters — confirm whether pinning is driven from elsewhere, or dead.
    pin_count: Vec<AtomicU32>,
    /// Maximum number of nodes (exclusive upper bound for valid node ids)
    max_nodes: usize,
}
|
||||
|
||||
impl MmapManager {
    /// Create a new memory-mapped embedding manager.
    ///
    /// Opens (or creates) the file at `path`, resizes it to
    /// `max_nodes * d_embed * 4` bytes, and maps it read/write. Existing
    /// file contents are preserved, which is what makes reopening an
    /// embedding file restore its data.
    ///
    /// # Arguments
    /// * `path` - Path to the memory-mapped file
    /// * `d_embed` - Embedding dimension
    /// * `max_nodes` - Maximum number of nodes to support
    ///
    /// # Returns
    /// A new `MmapManager` instance
    pub fn new(path: &Path, d_embed: usize, max_nodes: usize) -> Result<Self> {
        // Calculate required file size
        // NOTE(review): unchecked multiplications — on 32-bit targets a large
        // d_embed * max_nodes could wrap usize; confirm callers bound these.
        let embedding_size = d_embed * std::mem::size_of::<f32>();
        let file_size = max_nodes * embedding_size;

        // Create or open the file
        let file = OpenOptions::new()
            .read(true)
            .write(true)
            .create(true)
            .open(path)
            .map_err(|e| GnnError::mmap(format!("Failed to open mmap file: {}", e)))?;

        // Set file size (extends with zero bytes; truncates a larger file)
        file.set_len(file_size as u64)
            .map_err(|e| GnnError::mmap(format!("Failed to set file size: {}", e)))?;

        // Create memory mapping
        // SAFETY: the fd is stored in `self.file` and stays open for the
        // mapping's lifetime; memmap2 additionally requires that no other
        // process truncates the file while mapped (undefined behavior
        // otherwise) — an inherent mmap contract, not checkable here.
        let mmap = unsafe {
            MmapOptions::new()
                .len(file_size)
                .map_mut(&file)
                .map_err(|e| GnnError::mmap(format!("Failed to create mmap: {}", e)))?
        };

        // Get system page size
        let page_size = page_size::get();
        // Ceiling division: number of OS pages covering the file
        let num_pages = (file_size + page_size - 1) / page_size;

        Ok(Self {
            file,
            mmap,
            page_size,
            d_embed,
            access_bitmap: AtomicBitmap::new(max_nodes),
            dirty_bitmap: AtomicBitmap::new(max_nodes),
            pin_count: (0..num_pages).map(|_| AtomicU32::new(0)).collect(),
            max_nodes,
        })
    }

    /// Calculate the byte offset for a given node's embedding.
    ///
    /// # Arguments
    /// * `node_id` - Node identifier
    ///
    /// # Returns
    /// Byte offset in the memory-mapped file, or None if overflow would occur
    ///
    /// # Security
    /// Uses checked arithmetic to prevent integer overflow attacks.
    /// Note: does NOT bounds-check against `max_nodes` — callers do that.
    #[inline]
    pub fn embedding_offset(&self, node_id: u64) -> Option<usize> {
        let node_idx = usize::try_from(node_id).ok()?;
        let elem_size = std::mem::size_of::<f32>();
        // offset = node_idx * d_embed * 4, with overflow propagated as None
        node_idx.checked_mul(self.d_embed)?.checked_mul(elem_size)
    }

    /// Validate that a node_id is within bounds.
    // `node_id as usize` is fine here: on targets where u64 > usize a huge
    // id would also fail the usize::try_from in embedding_offset.
    #[inline]
    fn validate_node_id(&self, node_id: u64) -> bool {
        (node_id as usize) < self.max_nodes
    }

    /// Get a read-only reference to a node's embedding.
    ///
    /// Also marks the node as accessed in the access bitmap.
    ///
    /// # Arguments
    /// * `node_id` - Node identifier
    ///
    /// # Returns
    /// Slice containing the embedding vector (`d_embed` f32 values)
    ///
    /// # Panics
    /// Panics if node_id is out of bounds or would cause overflow
    pub fn get_embedding(&self, node_id: u64) -> &[f32] {
        // Security: Validate bounds before any pointer arithmetic
        assert!(
            self.validate_node_id(node_id),
            "node_id {} out of bounds (max: {})",
            node_id,
            self.max_nodes
        );

        let offset = self
            .embedding_offset(node_id)
            .expect("embedding offset calculation overflow");
        // End of the embedding span, checked so the assert below is sound
        let end = offset
            .checked_add(
                self.d_embed
                    .checked_mul(std::mem::size_of::<f32>())
                    .unwrap(),
            )
            .expect("end offset overflow");
        assert!(
            end <= self.mmap.len(),
            "embedding extends beyond mmap bounds"
        );

        // Mark as accessed
        self.access_bitmap.set(node_id as usize);

        // SAFETY: offset..end was verified in-bounds above; the mmap base is
        // page-aligned and offset is a multiple of 4, so the f32 pointer is
        // properly aligned. The slice borrows &self, so the mapping outlives it.
        unsafe {
            let ptr = self.mmap.as_ptr().add(offset) as *const f32;
            std::slice::from_raw_parts(ptr, self.d_embed)
        }
    }

    /// Set a node's embedding data.
    ///
    /// Marks the node both accessed and dirty (picked up by `flush_dirty`).
    ///
    /// # Arguments
    /// * `node_id` - Node identifier
    /// * `data` - Embedding vector to write (must be exactly `d_embed` long)
    ///
    /// # Panics
    /// Panics if node_id is out of bounds, data length doesn't match d_embed,
    /// or offset calculation would overflow.
    pub fn set_embedding(&mut self, node_id: u64, data: &[f32]) {
        // Security: Validate bounds first
        assert!(
            self.validate_node_id(node_id),
            "node_id {} out of bounds (max: {})",
            node_id,
            self.max_nodes
        );
        assert_eq!(
            data.len(),
            self.d_embed,
            "Embedding data length must match d_embed"
        );

        let offset = self
            .embedding_offset(node_id)
            .expect("embedding offset calculation overflow");
        let end = offset
            .checked_add(data.len().checked_mul(std::mem::size_of::<f32>()).unwrap())
            .expect("end offset overflow");
        assert!(
            end <= self.mmap.len(),
            "embedding extends beyond mmap bounds"
        );

        // Mark as accessed and dirty
        self.access_bitmap.set(node_id as usize);
        self.dirty_bitmap.set(node_id as usize);

        // SAFETY: offset..end verified in-bounds above; &mut self guarantees
        // exclusive access, and src/dst cannot overlap (data is a separate slice).
        unsafe {
            let ptr = self.mmap.as_mut_ptr().add(offset) as *mut f32;
            std::ptr::copy_nonoverlapping(data.as_ptr(), ptr, self.d_embed);
        }
    }

    /// Flush all dirty pages to disk.
    ///
    /// Flushes the whole mapping (not just dirty ranges), then clears the
    /// dirty bitmap.
    ///
    /// # Returns
    /// `Ok(())` on success, error otherwise
    pub fn flush_dirty(&self) -> io::Result<()> {
        let dirty_nodes = self.dirty_bitmap.get_set_indices();

        if dirty_nodes.is_empty() {
            return Ok(());
        }

        // Flush the entire mmap for simplicity
        // In a production system, you might want to flush only dirty pages
        self.mmap.flush()?;

        // Clear dirty bitmap after successful flush — cleared only on the
        // Ok path so a failed flush leaves nodes marked dirty for retry.
        for &node_id in &dirty_nodes {
            self.dirty_bitmap.clear(node_id);
        }

        Ok(())
    }

    /// Prefetch embeddings into memory for better cache locality.
    ///
    /// Best-effort: on Linux this issues `madvise(MADV_WILLNEED)` hints and
    /// ignores failures; elsewhere it simply touches the embeddings (which
    /// also sets their access bits as a side effect).
    ///
    /// # Arguments
    /// * `node_ids` - List of node IDs to prefetch
    pub fn prefetch(&self, node_ids: &[u64]) {
        #[cfg(target_os = "linux")]
        {
            #[allow(unused_imports)]
            use std::os::unix::io::AsRawFd;

            for &node_id in node_ids {
                // Skip invalid node IDs
                if !self.validate_node_id(node_id) {
                    continue;
                }
                let offset = match self.embedding_offset(node_id) {
                    Some(o) => o,
                    None => continue,
                };
                // Round down to the containing page: madvise wants a
                // page-aligned address.
                let page_offset = (offset / self.page_size) * self.page_size;
                let length = self.d_embed * std::mem::size_of::<f32>();

                // SAFETY: page_offset <= offset, which was validated against
                // max_nodes; the hint range lies within the mapping. The
                // madvise return value is deliberately ignored (advisory only).
                unsafe {
                    // Use madvise to hint the kernel to prefetch
                    libc::madvise(
                        self.mmap.as_ptr().add(page_offset) as *mut libc::c_void,
                        length,
                        libc::MADV_WILLNEED,
                    );
                }
            }
        }

        // On non-Linux platforms, just access the data to bring it into cache
        #[cfg(not(target_os = "linux"))]
        {
            for &node_id in node_ids {
                if self.validate_node_id(node_id) {
                    let _ = self.get_embedding(node_id);
                }
            }
        }
    }

    /// Get the embedding dimension.
    pub fn d_embed(&self) -> usize {
        self.d_embed
    }

    /// Get the maximum number of nodes.
    pub fn max_nodes(&self) -> usize {
        self.max_nodes
    }
}
|
||||
|
||||
/// Memory-mapped gradient accumulator with fine-grained locking.
///
/// Allows multiple threads to accumulate gradients concurrently with minimal contention.
/// Uses reader-writer locks at a configurable granularity (a contiguous run
/// of `lock_granularity` nodes shares one lock).
pub struct MmapGradientAccumulator {
    /// Memory-mapped gradient storage (using UnsafeCell for interior mutability)
    // NOTE(review): see the aliasing caveat on `accumulate` — the per-region
    // locks guard disjoint byte ranges, but `&mut MmapMut` is still formed
    // for the whole mapping.
    grad_mmap: std::cell::UnsafeCell<MmapMut>,
    /// Number of nodes per lock (lock granularity)
    lock_granularity: usize,
    /// Reader-writer locks for gradient regions
    locks: Vec<RwLock<()>>,
    /// Number of nodes
    n_nodes: usize,
    /// Embedding dimension
    d_embed: usize,
    /// Gradient file; held only to keep the fd alive for the mapping
    _file: File,
}
|
||||
|
||||
impl MmapGradientAccumulator {
    /// Create a new memory-mapped gradient accumulator.
    ///
    /// # Arguments
    /// * `path` - Path to the gradient file
    /// * `d_embed` - Embedding dimension
    /// * `max_nodes` - Maximum number of nodes
    ///
    /// # Returns
    /// A new `MmapGradientAccumulator` instance
    pub fn new(path: &Path, d_embed: usize, max_nodes: usize) -> Result<Self> {
        // Calculate required file size
        let grad_size = d_embed * std::mem::size_of::<f32>();
        let file_size = max_nodes * grad_size;

        // Create or open the file
        let file = OpenOptions::new()
            .read(true)
            .write(true)
            .create(true)
            .open(path)
            .map_err(|e| GnnError::mmap(format!("Failed to open gradient file: {}", e)))?;

        // Set file size (newly extended bytes are zero-filled by the OS)
        file.set_len(file_size as u64)
            .map_err(|e| GnnError::mmap(format!("Failed to set gradient file size: {}", e)))?;

        // Create memory mapping
        // SAFETY: fd kept alive in `self._file`; file must not be truncated
        // externally while mapped (mmap contract).
        let grad_mmap = unsafe {
            MmapOptions::new()
                .len(file_size)
                .map_mut(&file)
                .map_err(|e| GnnError::mmap(format!("Failed to create gradient mmap: {}", e)))?
        };

        // Zero out the gradients
        // NOTE(review): this loop only READS each byte — it faults pages in
        // but does not write zeroes. A brand-new or extended file is zeroed
        // by set_len, but reopening an existing gradient file would keep its
        // stale contents. Confirm whether callers ever reuse a file; if so,
        // this should be an explicit zeroing write (or call zero_grad()).
        for byte in grad_mmap.iter() {
            // This forces the pages to be allocated and zeroed
            let _ = byte;
        }

        // Use a lock granularity of 64 nodes per lock for good parallelism
        let lock_granularity = 64;
        // Ceiling division so trailing nodes also get a lock
        let num_locks = (max_nodes + lock_granularity - 1) / lock_granularity;
        let locks = (0..num_locks).map(|_| RwLock::new(())).collect();

        Ok(Self {
            grad_mmap: std::cell::UnsafeCell::new(grad_mmap),
            lock_granularity,
            locks,
            n_nodes: max_nodes,
            d_embed,
            _file: file,
        })
    }

    /// Calculate the byte offset for a node's gradient.
    ///
    /// Unlike `MmapManager::embedding_offset`, this also bounds-checks
    /// against `n_nodes`.
    ///
    /// # Arguments
    /// * `node_id` - Node identifier
    ///
    /// # Returns
    /// Byte offset in the gradient file, or None on overflow or out-of-bounds
    ///
    /// # Security
    /// Uses checked arithmetic to prevent integer overflow (SEC-001).
    #[inline]
    pub fn grad_offset(&self, node_id: u64) -> Option<usize> {
        let node_idx = usize::try_from(node_id).ok()?;
        if node_idx >= self.n_nodes {
            return None;
        }
        let elem_size = std::mem::size_of::<f32>();
        node_idx.checked_mul(self.d_embed)?.checked_mul(elem_size)
    }

    /// Accumulate gradients for a specific node (element-wise `+=`).
    ///
    /// Takes the write lock for the node's region, so concurrent calls for
    /// the same node serialize.
    ///
    /// # Arguments
    /// * `node_id` - Node identifier
    /// * `grad` - Gradient vector to accumulate
    ///
    /// # Panics
    /// Panics if grad length doesn't match d_embed, or node_id is out of
    /// bounds / the offset overflows.
    pub fn accumulate(&self, node_id: u64, grad: &[f32]) {
        assert_eq!(
            grad.len(),
            self.d_embed,
            "Gradient length must match d_embed"
        );

        let offset = self
            .grad_offset(node_id)
            .expect("node_id out of bounds or offset overflow");

        let lock_idx = (node_id as usize) / self.lock_granularity;
        assert!(lock_idx < self.locks.len(), "lock index out of bounds");
        let _lock = self.locks[lock_idx].write();

        // Safety: We validated node_id bounds and offset above, and hold the write lock
        // NOTE(review): the lock only covers this node's region, yet
        // `&mut *self.grad_mmap.get()` creates a `&mut MmapMut` over the WHOLE
        // mapping; two threads on different regions alias that `&mut`, which
        // is technically UB even though the written byte ranges are disjoint.
        // Consider raw-pointer access (base ptr + len captured once) instead.
        unsafe {
            let mmap = &mut *self.grad_mmap.get();
            assert!(
                offset + self.d_embed * std::mem::size_of::<f32>() <= mmap.len(),
                "gradient write would exceed mmap bounds"
            );
            let ptr = mmap.as_mut_ptr().add(offset) as *mut f32;
            let grad_slice = std::slice::from_raw_parts_mut(ptr, self.d_embed);

            // Accumulate gradients
            for (g, &new_g) in grad_slice.iter_mut().zip(grad.iter()) {
                *g += new_g;
            }
        }
    }

    /// Apply accumulated gradients to embeddings and zero out gradients.
    ///
    /// Performs plain SGD: `embedding -= learning_rate * grad` for every
    /// node id shared by both stores, then clears all gradients.
    ///
    /// # Arguments
    /// * `learning_rate` - Learning rate for gradient descent
    /// * `embeddings` - Embedding manager to update
    ///
    /// # Panics
    /// Panics if gradient and embedding dimensions differ.
    pub fn apply(&mut self, learning_rate: f32, embeddings: &mut MmapManager) {
        assert_eq!(
            self.d_embed, embeddings.d_embed,
            "Gradient and embedding dimensions must match"
        );

        // Process all nodes (only ids valid in BOTH stores)
        for node_id in 0..self.n_nodes.min(embeddings.max_nodes) {
            let grad = self.get_grad(node_id as u64);
            let embedding = embeddings.get_embedding(node_id as u64);

            // Apply gradient descent: embedding -= learning_rate * grad
            // Copied into a scratch vec because set_embedding needs &mut
            // while `embedding` borrows the mapping.
            let mut updated = vec![0.0f32; self.d_embed];
            for i in 0..self.d_embed {
                updated[i] = embedding[i] - learning_rate * grad[i];
            }

            embeddings.set_embedding(node_id as u64, &updated);
        }

        // Zero out gradients after applying
        self.zero_grad();
    }

    /// Zero out all accumulated gradients.
    ///
    /// Takes `&mut self`, so no region locks are needed — exclusive access
    /// is guaranteed by the borrow checker.
    pub fn zero_grad(&mut self) {
        // Zero the entire gradient buffer
        // SAFETY: &mut self gives exclusive access to the UnsafeCell contents.
        unsafe {
            let mmap = &mut *self.grad_mmap.get();
            for byte in mmap.iter_mut() {
                *byte = 0;
            }
        }
    }

    /// Get a read-only reference to a node's accumulated gradient.
    ///
    /// Takes the read lock for the node's region while forming the slice.
    /// Note the lock guard is dropped on return, while the returned slice
    /// lives as long as `&self`.
    ///
    /// # Arguments
    /// * `node_id` - Node identifier
    ///
    /// # Returns
    /// Slice containing the gradient vector
    ///
    /// # Panics
    /// Panics if node_id is out of bounds or the offset overflows.
    pub fn get_grad(&self, node_id: u64) -> &[f32] {
        let offset = self
            .grad_offset(node_id)
            .expect("node_id out of bounds or offset overflow");

        let lock_idx = (node_id as usize) / self.lock_granularity;
        assert!(lock_idx < self.locks.len(), "lock index out of bounds");
        let _lock = self.locks[lock_idx].read();

        // Safety: We validated node_id bounds and offset above, and hold the read lock
        // NOTE(review): the read lock is released when this function returns,
        // so a concurrent accumulate() may mutate the bytes behind the
        // returned slice — confirm callers only read it before further writes.
        unsafe {
            let mmap = &*self.grad_mmap.get();
            assert!(
                offset + self.d_embed * std::mem::size_of::<f32>() <= mmap.len(),
                "gradient read would exceed mmap bounds"
            );
            let ptr = mmap.as_ptr().add(offset) as *const f32;
            std::slice::from_raw_parts(ptr, self.d_embed)
        }
    }

    /// Get the embedding dimension.
    pub fn d_embed(&self) -> usize {
        self.d_embed
    }

    /// Get the number of nodes.
    pub fn n_nodes(&self) -> usize {
        self.n_nodes
    }
}
|
||||
|
||||
// Implement Drop to ensure proper cleanup
impl Drop for MmapManager {
    fn drop(&mut self) {
        // Try to flush dirty pages before dropping. Errors are intentionally
        // ignored: Drop cannot return them, and callers who need a guaranteed
        // durable write should call flush_dirty() explicitly first.
        let _ = self.flush_dirty();
    }
}
|
||||
|
||||
impl Drop for MmapGradientAccumulator {
    fn drop(&mut self) {
        // Flush gradient data (best effort — flush errors are swallowed,
        // same rationale as MmapManager's Drop).
        // SAFETY: &mut self in drop guarantees exclusive access to the cell.
        unsafe {
            let mmap = &mut *self.grad_mmap.get();
            let _ = mmap.flush();
        }
    }
}
|
||||
|
||||
// SAFETY: MmapGradientAccumulator is claimed Send + Sync because all shared
// (&self) mutation goes through `accumulate`, which takes the region's write
// lock, and reads go through `get_grad` under the read lock; `zero_grad` and
// Drop require &mut self and are therefore exclusive by construction.
// NOTE(review): this argument is weakened by `accumulate` forming a
// `&mut MmapMut` over the whole mapping while only holding a per-region lock
// (see the note on that method) — audit before relying on Sync under
// contention from multiple regions.
unsafe impl Send for MmapGradientAccumulator {}
unsafe impl Sync for MmapGradientAccumulator {}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    // NOTE(review): `std::fs` appears unused in this module — candidate for removal.
    use std::fs;
    use tempfile::TempDir;

    // Bit set/clear/get round-trips, including word-boundary indices (63/64).
    #[test]
    fn test_atomic_bitmap_basic() {
        let bitmap = AtomicBitmap::new(128);

        assert!(!bitmap.get(0));
        assert!(!bitmap.get(127));

        bitmap.set(0);
        bitmap.set(127);
        bitmap.set(64);

        assert!(bitmap.get(0));
        assert!(bitmap.get(127));
        assert!(bitmap.get(64));
        assert!(!bitmap.get(1));

        bitmap.clear(0);
        assert!(!bitmap.get(0));
        assert!(bitmap.get(127));
    }

    // get_set_indices enumerates every set bit exactly once, across words.
    #[test]
    fn test_atomic_bitmap_get_set_indices() {
        let bitmap = AtomicBitmap::new(256);

        bitmap.set(0);
        bitmap.set(63);
        bitmap.set(64);
        bitmap.set(128);
        bitmap.set(255);

        let mut indices = bitmap.get_set_indices();
        indices.sort();

        assert_eq!(indices, vec![0, 63, 64, 128, 255]);
    }

    // clear_all resets every word, not just the first.
    #[test]
    fn test_atomic_bitmap_clear_all() {
        let bitmap = AtomicBitmap::new(128);

        bitmap.set(0);
        bitmap.set(64);
        bitmap.set(127);

        assert!(bitmap.get(0));

        bitmap.clear_all();

        assert!(!bitmap.get(0));
        assert!(!bitmap.get(64));
        assert!(!bitmap.get(127));
    }

    // Constructor creates the backing file and records dimensions.
    #[test]
    fn test_mmap_manager_creation() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let manager = MmapManager::new(&path, 128, 1000).unwrap();

        assert_eq!(manager.d_embed(), 128);
        assert_eq!(manager.max_nodes(), 1000);
        assert!(path.exists());
    }

    // Write/read round-trip of a single embedding.
    #[test]
    fn test_mmap_manager_set_get_embedding() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let mut manager = MmapManager::new(&path, 64, 100).unwrap();

        let embedding = vec![1.0f32; 64];
        manager.set_embedding(0, &embedding);

        let retrieved = manager.get_embedding(0);
        assert_eq!(retrieved.len(), 64);
        assert_eq!(retrieved[0], 1.0);
        assert_eq!(retrieved[63], 1.0);
    }

    // Adjacent embeddings must not clobber each other (offset math check).
    #[test]
    fn test_mmap_manager_multiple_embeddings() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let mut manager = MmapManager::new(&path, 32, 100).unwrap();

        for i in 0..10 {
            let embedding: Vec<f32> = (0..32).map(|j| (i * 32 + j) as f32).collect();
            manager.set_embedding(i, &embedding);
        }

        // Verify each embedding
        for i in 0..10 {
            let retrieved = manager.get_embedding(i);
            assert_eq!(retrieved.len(), 32);
            assert_eq!(retrieved[0], (i * 32) as f32);
            assert_eq!(retrieved[31], (i * 32 + 31) as f32);
        }
    }

    // set_embedding marks the node dirty; flush_dirty clears the bit.
    #[test]
    fn test_mmap_manager_dirty_tracking() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let mut manager = MmapManager::new(&path, 64, 100).unwrap();

        let embedding = vec![2.0f32; 64];
        manager.set_embedding(5, &embedding);

        // Should be marked as dirty
        assert!(manager.dirty_bitmap.get(5));

        // Flush and check it's clean
        manager.flush_dirty().unwrap();
        assert!(!manager.dirty_bitmap.get(5));
    }

    // Data written + flushed in one manager instance survives a reopen
    // (relies on MmapManager::new preserving existing file contents).
    #[test]
    fn test_mmap_manager_persistence() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        {
            let mut manager = MmapManager::new(&path, 64, 100).unwrap();
            let embedding = vec![3.14f32; 64];
            manager.set_embedding(10, &embedding);
            manager.flush_dirty().unwrap();
        }

        // Reopen and verify data persisted
        {
            let manager = MmapManager::new(&path, 64, 100).unwrap();
            let retrieved = manager.get_embedding(10);
            assert_eq!(retrieved[0], 3.14);
            assert_eq!(retrieved[63], 3.14);
        }
    }

    // Accumulator constructor creates the file and records dimensions.
    #[test]
    fn test_gradient_accumulator_creation() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let accumulator = MmapGradientAccumulator::new(&path, 128, 1000).unwrap();

        assert_eq!(accumulator.d_embed(), 128);
        assert_eq!(accumulator.n_nodes(), 1000);
        assert!(path.exists());
    }

    // Two accumulates sum element-wise (1.0 + 2.0 = 3.0).
    #[test]
    fn test_gradient_accumulator_accumulate() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let accumulator = MmapGradientAccumulator::new(&path, 64, 100).unwrap();

        let grad1 = vec![1.0f32; 64];
        let grad2 = vec![2.0f32; 64];

        accumulator.accumulate(0, &grad1);
        accumulator.accumulate(0, &grad2);

        let accumulated = accumulator.get_grad(0);
        assert_eq!(accumulated[0], 3.0);
        assert_eq!(accumulated[63], 3.0);
    }

    // zero_grad wipes previously accumulated values.
    #[test]
    fn test_gradient_accumulator_zero_grad() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let mut accumulator = MmapGradientAccumulator::new(&path, 64, 100).unwrap();

        let grad = vec![1.5f32; 64];
        accumulator.accumulate(0, &grad);

        let accumulated = accumulator.get_grad(0);
        assert_eq!(accumulated[0], 1.5);

        accumulator.zero_grad();

        let zeroed = accumulator.get_grad(0);
        assert_eq!(zeroed[0], 0.0);
        assert_eq!(zeroed[63], 0.0);
    }

    // apply() performs one SGD step and then zeroes the gradients.
    #[test]
    fn test_gradient_accumulator_apply() {
        let temp_dir = TempDir::new().unwrap();
        let embed_path = temp_dir.path().join("embeddings.bin");
        let grad_path = temp_dir.path().join("gradients.bin");

        let mut embeddings = MmapManager::new(&embed_path, 32, 100).unwrap();
        let mut accumulator = MmapGradientAccumulator::new(&grad_path, 32, 100).unwrap();

        // Set initial embedding
        let initial = vec![10.0f32; 32];
        embeddings.set_embedding(0, &initial);

        // Accumulate gradient
        let grad = vec![1.0f32; 32];
        accumulator.accumulate(0, &grad);

        // Apply with learning rate 0.1
        accumulator.apply(0.1, &mut embeddings);

        // Check updated embedding: 10.0 - 0.1 * 1.0 = 9.9
        let updated = embeddings.get_embedding(0);
        assert!((updated[0] - 9.9).abs() < 1e-6);

        // Check gradients were zeroed
        let zeroed_grad = accumulator.get_grad(0);
        assert_eq!(zeroed_grad[0], 0.0);
    }

    // Exercises the Send/Sync claim: concurrent accumulates on the SAME node
    // must serialize through the region write lock and lose no updates.
    #[test]
    fn test_gradient_accumulator_concurrent_accumulation() {
        use std::thread;

        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let accumulator =
            std::sync::Arc::new(MmapGradientAccumulator::new(&path, 64, 100).unwrap());

        let mut handles = vec![];

        // Spawn 10 threads, each accumulating 1.0 to node 0
        for _ in 0..10 {
            let acc = accumulator.clone();
            let handle = thread::spawn(move || {
                let grad = vec![1.0f32; 64];
                acc.accumulate(0, &grad);
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        // Should have accumulated 10.0
        let result = accumulator.get_grad(0);
        assert_eq!(result[0], 10.0);
    }

    // Offsets are node_id * d_embed * sizeof(f32).
    #[test]
    fn test_embedding_offset_calculation() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let manager = MmapManager::new(&path, 64, 100).unwrap();

        assert_eq!(manager.embedding_offset(0), Some(0));
        assert_eq!(manager.embedding_offset(1), Some(64 * 4)); // 64 floats * 4 bytes
        assert_eq!(manager.embedding_offset(10), Some(64 * 4 * 10));
    }

    // Same formula for gradient offsets.
    #[test]
    fn test_grad_offset_calculation() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let accumulator = MmapGradientAccumulator::new(&path, 128, 100).unwrap();

        assert_eq!(accumulator.grad_offset(0), Some(0));
        assert_eq!(accumulator.grad_offset(1), Some(128 * 4)); // 128 floats * 4 bytes
        assert_eq!(accumulator.grad_offset(5), Some(128 * 4 * 5));
    }

    // Dimension mismatch is a contract violation and must panic.
    #[test]
    #[should_panic(expected = "Embedding data length must match d_embed")]
    fn test_set_embedding_wrong_size() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let mut manager = MmapManager::new(&path, 64, 100).unwrap();
        let wrong_size = vec![1.0f32; 32]; // Should be 64
        manager.set_embedding(0, &wrong_size);
    }

    // Same contract check on the gradient side.
    #[test]
    #[should_panic(expected = "Gradient length must match d_embed")]
    fn test_accumulate_wrong_size() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let accumulator = MmapGradientAccumulator::new(&path, 64, 100).unwrap();
        let wrong_size = vec![1.0f32; 32]; // Should be 64
        accumulator.accumulate(0, &wrong_size);
    }

    // prefetch is advisory: it must neither crash nor corrupt data.
    #[test]
    fn test_prefetch() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let mut manager = MmapManager::new(&path, 64, 100).unwrap();

        // Set some embeddings
        for i in 0..10 {
            let embedding = vec![i as f32; 64];
            manager.set_embedding(i, &embedding);
        }

        // Prefetch should not crash
        manager.prefetch(&[0, 1, 2, 3, 4]);

        // Access should still work
        let retrieved = manager.get_embedding(2);
        assert_eq!(retrieved[0], 2.0);
    }
}
|
||||
82
vendor/ruvector/crates/ruvector-gnn/src/mmap_fixed.rs
vendored
Normal file
82
vendor/ruvector/crates/ruvector-gnn/src/mmap_fixed.rs
vendored
Normal file
@@ -0,0 +1,82 @@
|
||||
//! Memory-mapped embedding management for large-scale GNN training.
|
||||
//!
|
||||
//! This module provides efficient memory-mapped access to embeddings and gradients
|
||||
//! that don't fit in RAM. It includes:
|
||||
//! - `MmapManager`: Memory-mapped embedding storage with dirty tracking
|
||||
//! - `MmapGradientAccumulator`: Lock-free gradient accumulation
|
||||
//! - `AtomicBitmap`: Thread-safe bitmap for access/dirty tracking
|
||||
//!
|
||||
//! Only available on non-WASM targets.
|
||||
|
||||
#![cfg(all(not(target_arch = "wasm32"), feature = "mmap"))]
|
||||
|
||||
use crate::error::{GnnError, Result};
|
||||
use std::cell::UnsafeCell;
|
||||
use std::fs::{File, OpenOptions};
|
||||
use std::io;
|
||||
use std::path::Path;
|
||||
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
|
||||
use parking_lot::RwLock;
|
||||
use memmap2::{MmapMut, MmapOptions};
|
||||
|
||||
/// Thread-safe bitmap backed by a vector of 64-bit atomic words.
#[derive(Debug)]
pub struct AtomicBitmap {
    bits: Vec<AtomicU64>,
    size: usize,
}

impl AtomicBitmap {
    /// Create a bitmap able to track `size` bits, all initially clear.
    pub fn new(size: usize) -> Self {
        let num_words = (size + 63) / 64;
        Self {
            bits: (0..num_words).map(|_| AtomicU64::new(0)).collect(),
            size,
        }
    }

    /// Set the bit at `index`; out-of-range indices are silently ignored.
    pub fn set(&self, index: usize) {
        if index < self.size {
            let (word, bit) = (index / 64, index % 64);
            self.bits[word].fetch_or(1u64 << bit, Ordering::Release);
        }
    }

    /// Clear the bit at `index`; out-of-range indices are silently ignored.
    pub fn clear(&self, index: usize) {
        if index < self.size {
            let (word, bit) = (index / 64, index % 64);
            self.bits[word].fetch_and(!(1u64 << bit), Ordering::Release);
        }
    }

    /// Read the bit at `index`; out-of-range indices read as `false`.
    pub fn get(&self, index: usize) -> bool {
        if index >= self.size {
            return false;
        }
        let (word, bit) = (index / 64, index % 64);
        self.bits[word].load(Ordering::Acquire) & (1u64 << bit) != 0
    }

    /// Reset every bit to zero.
    pub fn clear_all(&self) {
        for word in &self.bits {
            word.store(0, Ordering::Release);
        }
    }

    /// Collect the indices of all currently-set bits, in ascending order.
    pub fn get_set_indices(&self) -> Vec<usize> {
        let mut out = Vec::new();
        for (word_idx, word) in self.bits.iter().enumerate() {
            let mut remaining = word.load(Ordering::Acquire);
            while remaining != 0 {
                out.push(word_idx * 64 + remaining.trailing_zeros() as usize);
                remaining &= remaining - 1; // drop the lowest set bit
            }
        }
        out
    }
}
|
||||
670
vendor/ruvector/crates/ruvector-gnn/src/query.rs
vendored
Normal file
670
vendor/ruvector/crates/ruvector-gnn/src/query.rs
vendored
Normal file
@@ -0,0 +1,670 @@
|
||||
//! Query API for RuVector GNN
|
||||
//!
|
||||
//! Provides high-level query interfaces for vector search, neural search,
|
||||
//! and subgraph extraction.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Query mode for different search strategies
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum QueryMode {
|
||||
/// Pure HNSW vector search
|
||||
VectorSearch,
|
||||
/// GNN-enhanced neural search
|
||||
NeuralSearch,
|
||||
/// Extract k-hop subgraph around results
|
||||
SubgraphExtraction,
|
||||
/// Differentiable search with soft attention
|
||||
DifferentiableSearch,
|
||||
}
|
||||
|
||||
/// Query configuration for RuVector searches
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RuvectorQuery {
|
||||
/// Query vector for similarity search
|
||||
pub vector: Option<Vec<f32>>,
|
||||
/// Text query (requires embedding model)
|
||||
pub text: Option<String>,
|
||||
/// Node ID for subgraph extraction
|
||||
pub node_id: Option<u64>,
|
||||
/// Search mode
|
||||
pub mode: QueryMode,
|
||||
/// Number of results to return
|
||||
pub k: usize,
|
||||
/// HNSW search parameter (exploration factor)
|
||||
pub ef: usize,
|
||||
/// GNN depth for neural search
|
||||
pub gnn_depth: usize,
|
||||
/// Temperature for differentiable search (higher = softer)
|
||||
pub temperature: f32,
|
||||
/// Whether to return attention weights
|
||||
pub return_attention: bool,
|
||||
}
|
||||
|
||||
impl Default for RuvectorQuery {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
vector: None,
|
||||
text: None,
|
||||
node_id: None,
|
||||
mode: QueryMode::VectorSearch,
|
||||
k: 10,
|
||||
ef: 50,
|
||||
gnn_depth: 2,
|
||||
temperature: 1.0,
|
||||
return_attention: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RuvectorQuery {
|
||||
/// Create a basic vector search query
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `vector` - Query vector
|
||||
/// * `k` - Number of results to return
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use ruvector_gnn::query::RuvectorQuery;
|
||||
///
|
||||
/// let query = RuvectorQuery::vector_search(vec![0.1, 0.2, 0.3], 10);
|
||||
/// assert_eq!(query.k, 10);
|
||||
/// ```
|
||||
pub fn vector_search(vector: Vec<f32>, k: usize) -> Self {
|
||||
Self {
|
||||
vector: Some(vector),
|
||||
mode: QueryMode::VectorSearch,
|
||||
k,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a GNN-enhanced neural search query
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `vector` - Query vector
|
||||
/// * `k` - Number of results to return
|
||||
/// * `gnn_depth` - Number of GNN layers to apply
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use ruvector_gnn::query::RuvectorQuery;
|
||||
///
|
||||
/// let query = RuvectorQuery::neural_search(vec![0.1, 0.2, 0.3], 10, 3);
|
||||
/// assert_eq!(query.gnn_depth, 3);
|
||||
/// ```
|
||||
pub fn neural_search(vector: Vec<f32>, k: usize, gnn_depth: usize) -> Self {
|
||||
Self {
|
||||
vector: Some(vector),
|
||||
mode: QueryMode::NeuralSearch,
|
||||
k,
|
||||
gnn_depth,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a subgraph extraction query
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `vector` - Query vector
|
||||
/// * `k` - Number of nodes in subgraph
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use ruvector_gnn::query::RuvectorQuery;
|
||||
///
|
||||
/// let query = RuvectorQuery::subgraph_search(vec![0.1, 0.2, 0.3], 20);
|
||||
/// assert_eq!(query.k, 20);
|
||||
/// ```
|
||||
pub fn subgraph_search(vector: Vec<f32>, k: usize) -> Self {
|
||||
Self {
|
||||
vector: Some(vector),
|
||||
mode: QueryMode::SubgraphExtraction,
|
||||
k,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a differentiable search query with temperature
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `vector` - Query vector
|
||||
/// * `k` - Number of results
|
||||
/// * `temperature` - Softmax temperature (higher = softer distribution)
|
||||
pub fn differentiable_search(vector: Vec<f32>, k: usize, temperature: f32) -> Self {
|
||||
Self {
|
||||
vector: Some(vector),
|
||||
mode: QueryMode::DifferentiableSearch,
|
||||
k,
|
||||
temperature,
|
||||
return_attention: true,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Set text query (requires embedding model)
|
||||
pub fn with_text(mut self, text: String) -> Self {
|
||||
self.text = Some(text);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set node ID for centered queries
|
||||
pub fn with_node(mut self, node_id: u64) -> Self {
|
||||
self.node_id = Some(node_id);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set EF parameter for HNSW search
|
||||
pub fn with_ef(mut self, ef: usize) -> Self {
|
||||
self.ef = ef;
|
||||
self
|
||||
}
|
||||
|
||||
/// Enable attention weight return
|
||||
pub fn with_attention(mut self) -> Self {
|
||||
self.return_attention = true;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Subgraph representation
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct SubGraph {
|
||||
/// Node IDs in the subgraph
|
||||
pub nodes: Vec<u64>,
|
||||
/// Edges as (from, to, weight) tuples
|
||||
pub edges: Vec<(u64, u64, f32)>,
|
||||
}
|
||||
|
||||
impl SubGraph {
|
||||
/// Create a new empty subgraph
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
nodes: Vec::new(),
|
||||
edges: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create subgraph with nodes and edges
|
||||
pub fn with_edges(nodes: Vec<u64>, edges: Vec<(u64, u64, f32)>) -> Self {
|
||||
Self { nodes, edges }
|
||||
}
|
||||
|
||||
/// Get number of nodes
|
||||
pub fn node_count(&self) -> usize {
|
||||
self.nodes.len()
|
||||
}
|
||||
|
||||
/// Get number of edges
|
||||
pub fn edge_count(&self) -> usize {
|
||||
self.edges.len()
|
||||
}
|
||||
|
||||
/// Check if subgraph contains a node
|
||||
pub fn contains_node(&self, node_id: u64) -> bool {
|
||||
self.nodes.contains(&node_id)
|
||||
}
|
||||
|
||||
/// Get average edge weight
|
||||
pub fn average_edge_weight(&self) -> f32 {
|
||||
if self.edges.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
let sum: f32 = self.edges.iter().map(|(_, _, w)| w).sum();
|
||||
sum / self.edges.len() as f32
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SubGraph {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Query result with nodes, scores, and optional metadata
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct QueryResult {
|
||||
/// Matched node IDs
|
||||
pub nodes: Vec<u64>,
|
||||
/// Similarity scores (higher = more similar)
|
||||
pub scores: Vec<f32>,
|
||||
/// Optional node embeddings after GNN processing
|
||||
pub embeddings: Option<Vec<Vec<f32>>>,
|
||||
/// Optional attention weights from differentiable search
|
||||
pub attention_weights: Option<Vec<Vec<f32>>>,
|
||||
/// Optional subgraph extraction
|
||||
pub subgraph: Option<SubGraph>,
|
||||
/// Query latency in milliseconds
|
||||
pub latency_ms: u64,
|
||||
}
|
||||
|
||||
impl QueryResult {
|
||||
/// Create a new empty query result
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
nodes: Vec::new(),
|
||||
scores: Vec::new(),
|
||||
embeddings: None,
|
||||
attention_weights: None,
|
||||
subgraph: None,
|
||||
latency_ms: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create query result with nodes and scores
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `nodes` - Node IDs
|
||||
/// * `scores` - Similarity scores
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use ruvector_gnn::query::QueryResult;
|
||||
///
|
||||
/// let result = QueryResult::with_nodes(vec![1, 2, 3], vec![0.9, 0.8, 0.7]);
|
||||
/// assert_eq!(result.nodes.len(), 3);
|
||||
/// ```
|
||||
pub fn with_nodes(nodes: Vec<u64>, scores: Vec<f32>) -> Self {
|
||||
Self {
|
||||
nodes,
|
||||
scores,
|
||||
embeddings: None,
|
||||
attention_weights: None,
|
||||
subgraph: None,
|
||||
latency_ms: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Add embeddings to the result
|
||||
pub fn with_embeddings(mut self, embeddings: Vec<Vec<f32>>) -> Self {
|
||||
self.embeddings = Some(embeddings);
|
||||
self
|
||||
}
|
||||
|
||||
/// Add attention weights to the result
|
||||
pub fn with_attention(mut self, attention: Vec<Vec<f32>>) -> Self {
|
||||
self.attention_weights = Some(attention);
|
||||
self
|
||||
}
|
||||
|
||||
/// Add subgraph to the result
|
||||
pub fn with_subgraph(mut self, subgraph: SubGraph) -> Self {
|
||||
self.subgraph = Some(subgraph);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set query latency
|
||||
pub fn with_latency(mut self, latency_ms: u64) -> Self {
|
||||
self.latency_ms = latency_ms;
|
||||
self
|
||||
}
|
||||
|
||||
/// Get number of results
|
||||
pub fn len(&self) -> usize {
|
||||
self.nodes.len()
|
||||
}
|
||||
|
||||
/// Check if result is empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.nodes.is_empty()
|
||||
}
|
||||
|
||||
/// Get top-k results
|
||||
pub fn top_k(&self, k: usize) -> Self {
|
||||
let k = k.min(self.nodes.len());
|
||||
Self {
|
||||
nodes: self.nodes[..k].to_vec(),
|
||||
scores: self.scores[..k].to_vec(),
|
||||
embeddings: self.embeddings.as_ref().map(|e| e[..k].to_vec()),
|
||||
attention_weights: self.attention_weights.as_ref().map(|a| a[..k].to_vec()),
|
||||
subgraph: self.subgraph.clone(),
|
||||
latency_ms: self.latency_ms,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the best result (highest score)
|
||||
pub fn best(&self) -> Option<(u64, f32)> {
|
||||
if self.nodes.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some((self.nodes[0], self.scores[0]))
|
||||
}
|
||||
}
|
||||
|
||||
/// Filter results by minimum score
|
||||
pub fn filter_by_score(mut self, min_score: f32) -> Self {
|
||||
let mut filtered_nodes = Vec::new();
|
||||
let mut filtered_scores = Vec::new();
|
||||
let mut filtered_embeddings = Vec::new();
|
||||
let mut filtered_attention = Vec::new();
|
||||
|
||||
for i in 0..self.nodes.len() {
|
||||
if self.scores[i] >= min_score {
|
||||
filtered_nodes.push(self.nodes[i]);
|
||||
filtered_scores.push(self.scores[i]);
|
||||
|
||||
if let Some(ref emb) = self.embeddings {
|
||||
filtered_embeddings.push(emb[i].clone());
|
||||
}
|
||||
|
||||
if let Some(ref att) = self.attention_weights {
|
||||
filtered_attention.push(att[i].clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.nodes = filtered_nodes;
|
||||
self.scores = filtered_scores;
|
||||
|
||||
if !filtered_embeddings.is_empty() {
|
||||
self.embeddings = Some(filtered_embeddings);
|
||||
}
|
||||
|
||||
if !filtered_attention.is_empty() {
|
||||
self.attention_weights = Some(filtered_attention);
|
||||
}
|
||||
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for QueryResult {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_query_mode_serialization() {
        let mode = QueryMode::NeuralSearch;
        let json = serde_json::to_string(&mode).unwrap();
        let roundtrip: QueryMode = serde_json::from_str(&json).unwrap();
        assert_eq!(mode, roundtrip);
    }

    #[test]
    fn test_ruvector_query_default() {
        let query = RuvectorQuery::default();
        assert_eq!(query.k, 10);
        assert_eq!(query.ef, 50);
        assert_eq!(query.gnn_depth, 2);
        assert_eq!(query.temperature, 1.0);
        assert_eq!(query.mode, QueryMode::VectorSearch);
        assert!(!query.return_attention);
    }

    #[test]
    fn test_vector_search_query() {
        let v = vec![0.1, 0.2, 0.3, 0.4];
        let query = RuvectorQuery::vector_search(v.clone(), 5);

        assert_eq!(query.vector, Some(v));
        assert_eq!(query.k, 5);
        assert_eq!(query.mode, QueryMode::VectorSearch);
    }

    #[test]
    fn test_neural_search_query() {
        let v = vec![0.1, 0.2, 0.3];
        let query = RuvectorQuery::neural_search(v.clone(), 10, 3);

        assert_eq!(query.vector, Some(v));
        assert_eq!(query.k, 10);
        assert_eq!(query.gnn_depth, 3);
        assert_eq!(query.mode, QueryMode::NeuralSearch);
    }

    #[test]
    fn test_subgraph_search_query() {
        let v = vec![0.5, 0.5];
        let query = RuvectorQuery::subgraph_search(v.clone(), 20);

        assert_eq!(query.vector, Some(v));
        assert_eq!(query.k, 20);
        assert_eq!(query.mode, QueryMode::SubgraphExtraction);
    }

    #[test]
    fn test_differentiable_search_query() {
        let v = vec![0.3, 0.4, 0.5];
        let query = RuvectorQuery::differentiable_search(v.clone(), 15, 0.5);

        assert_eq!(query.vector, Some(v));
        assert_eq!(query.k, 15);
        assert_eq!(query.temperature, 0.5);
        assert_eq!(query.mode, QueryMode::DifferentiableSearch);
        assert!(query.return_attention);
    }

    #[test]
    fn test_query_builder_pattern() {
        let query = RuvectorQuery::vector_search(vec![0.1, 0.2], 5)
            .with_text("hello world".to_string())
            .with_node(42)
            .with_ef(100)
            .with_attention();

        assert_eq!(query.text, Some("hello world".to_string()));
        assert_eq!(query.node_id, Some(42));
        assert_eq!(query.ef, 100);
        assert!(query.return_attention);
    }

    #[test]
    fn test_subgraph_new() {
        let sg = SubGraph::new();
        assert_eq!(sg.node_count(), 0);
        assert_eq!(sg.edge_count(), 0);
    }

    #[test]
    fn test_subgraph_with_edges() {
        let nodes = vec![1, 2, 3];
        let edges = vec![(1, 2, 0.8), (2, 3, 0.6), (1, 3, 0.5)];
        let sg = SubGraph::with_edges(nodes.clone(), edges.clone());

        assert_eq!(sg.nodes, nodes);
        assert_eq!(sg.edges, edges);
        assert_eq!(sg.node_count(), 3);
        assert_eq!(sg.edge_count(), 3);
    }

    #[test]
    fn test_subgraph_contains_node() {
        let sg = SubGraph::with_edges(vec![1, 2, 3], vec![]);

        assert!(sg.contains_node(1));
        assert!(sg.contains_node(2));
        assert!(sg.contains_node(3));
        assert!(!sg.contains_node(4));
    }

    #[test]
    fn test_subgraph_average_edge_weight() {
        let edges = vec![(1, 2, 0.8), (2, 3, 0.6), (1, 3, 0.4)];
        let sg = SubGraph::with_edges(vec![1, 2, 3], edges);

        assert!((sg.average_edge_weight() - 0.6).abs() < 0.001);
    }

    #[test]
    fn test_subgraph_empty_average() {
        assert_eq!(SubGraph::new().average_edge_weight(), 0.0);
    }

    #[test]
    fn test_query_result_new() {
        let result = QueryResult::new();
        assert!(result.is_empty());
        assert_eq!(result.len(), 0);
        assert_eq!(result.latency_ms, 0);
    }

    #[test]
    fn test_query_result_with_nodes() {
        let nodes = vec![1, 2, 3];
        let scores = vec![0.9, 0.8, 0.7];
        let result = QueryResult::with_nodes(nodes.clone(), scores.clone());

        assert_eq!(result.nodes, nodes);
        assert_eq!(result.scores, scores);
        assert_eq!(result.len(), 3);
        assert!(!result.is_empty());
    }

    #[test]
    fn test_query_result_builder_pattern() {
        let embeddings = vec![vec![0.1, 0.2], vec![0.3, 0.4]];
        let attention = vec![vec![0.5, 0.5], vec![0.6, 0.4]];
        let subgraph = SubGraph::with_edges(vec![1, 2], vec![(1, 2, 0.8)]);

        let result = QueryResult::with_nodes(vec![1, 2], vec![0.9, 0.8])
            .with_embeddings(embeddings.clone())
            .with_attention(attention.clone())
            .with_subgraph(subgraph.clone())
            .with_latency(100);

        assert_eq!(result.embeddings, Some(embeddings));
        assert_eq!(result.attention_weights, Some(attention));
        assert_eq!(result.subgraph, Some(subgraph));
        assert_eq!(result.latency_ms, 100);
    }

    #[test]
    fn test_query_result_top_k() {
        let result = QueryResult::with_nodes(vec![1, 2, 3, 4, 5], vec![0.9, 0.8, 0.7, 0.6, 0.5]);

        let top_3 = result.top_k(3);
        assert_eq!(top_3.len(), 3);
        assert_eq!(top_3.nodes, vec![1, 2, 3]);
        assert_eq!(top_3.scores, vec![0.9, 0.8, 0.7]);
    }

    #[test]
    fn test_query_result_top_k_overflow() {
        let result = QueryResult::with_nodes(vec![1, 2], vec![0.9, 0.8]);
        // Asking for more than is available returns only what exists.
        assert_eq!(result.top_k(10).len(), 2);
    }

    #[test]
    fn test_query_result_best() {
        let result = QueryResult::with_nodes(vec![1, 2, 3], vec![0.9, 0.8, 0.7]);
        assert_eq!(result.best(), Some((1, 0.9)));
    }

    #[test]
    fn test_query_result_best_empty() {
        assert_eq!(QueryResult::new().best(), None);
    }

    #[test]
    fn test_query_result_filter_by_score() {
        let result = QueryResult::with_nodes(vec![1, 2, 3, 4, 5], vec![0.9, 0.8, 0.7, 0.6, 0.5]);

        let filtered = result.filter_by_score(0.7);
        assert_eq!(filtered.len(), 3);
        assert_eq!(filtered.nodes, vec![1, 2, 3]);
        assert_eq!(filtered.scores, vec![0.9, 0.8, 0.7]);
    }

    #[test]
    fn test_query_result_filter_with_embeddings() {
        let embeddings = vec![vec![0.1], vec![0.2], vec![0.3]];
        let result = QueryResult::with_nodes(vec![1, 2, 3], vec![0.9, 0.6, 0.8])
            .with_embeddings(embeddings);

        let filtered = result.filter_by_score(0.7);
        assert_eq!(filtered.len(), 2);
        assert_eq!(filtered.nodes, vec![1, 3]);
        assert_eq!(filtered.embeddings, Some(vec![vec![0.1], vec![0.3]]));
    }

    #[test]
    fn test_query_result_filter_with_attention() {
        let attention = vec![vec![0.5, 0.5], vec![0.6, 0.4], vec![0.7, 0.3]];
        let result = QueryResult::with_nodes(vec![1, 2, 3], vec![0.9, 0.5, 0.8])
            .with_attention(attention);

        let filtered = result.filter_by_score(0.75);
        assert_eq!(filtered.len(), 2);
        assert_eq!(filtered.nodes, vec![1, 3]);
        assert_eq!(
            filtered.attention_weights,
            Some(vec![vec![0.5, 0.5], vec![0.7, 0.3]])
        );
    }

    #[test]
    fn test_query_serialization() {
        let query = RuvectorQuery::neural_search(vec![0.1, 0.2], 5, 2);
        let json = serde_json::to_string(&query).unwrap();
        let roundtrip: RuvectorQuery = serde_json::from_str(&json).unwrap();

        assert_eq!(roundtrip.k, query.k);
        assert_eq!(roundtrip.gnn_depth, query.gnn_depth);
        assert_eq!(roundtrip.mode, query.mode);
    }

    #[test]
    fn test_result_serialization() {
        let result = QueryResult::with_nodes(vec![1, 2], vec![0.9, 0.8]).with_latency(50);

        let json = serde_json::to_string(&result).unwrap();
        let roundtrip: QueryResult = serde_json::from_str(&json).unwrap();

        assert_eq!(roundtrip.nodes, result.nodes);
        assert_eq!(roundtrip.scores, result.scores);
        assert_eq!(roundtrip.latency_ms, result.latency_ms);
    }

    #[test]
    fn test_subgraph_serialization() {
        let sg = SubGraph::with_edges(vec![1, 2, 3], vec![(1, 2, 0.8), (2, 3, 0.6)]);

        let json = serde_json::to_string(&sg).unwrap();
        let roundtrip: SubGraph = serde_json::from_str(&json).unwrap();

        assert_eq!(roundtrip.nodes, sg.nodes);
        assert_eq!(roundtrip.edges, sg.edges);
    }

    #[test]
    fn test_edge_case_empty_filter() {
        let result = QueryResult::with_nodes(vec![1, 2], vec![0.5, 0.4]);
        let filtered = result.filter_by_score(0.9);

        assert!(filtered.is_empty());
        assert_eq!(filtered.len(), 0);
    }

    #[test]
    fn test_query_mode_variants() {
        // Equality/inequality across every mode pair we care about.
        assert_eq!(QueryMode::VectorSearch, QueryMode::VectorSearch);
        assert_ne!(QueryMode::VectorSearch, QueryMode::NeuralSearch);
        assert_ne!(QueryMode::NeuralSearch, QueryMode::SubgraphExtraction);
        assert_ne!(
            QueryMode::SubgraphExtraction,
            QueryMode::DifferentiableSearch
        );
    }
}
|
||||
502
vendor/ruvector/crates/ruvector-gnn/src/replay.rs
vendored
Normal file
502
vendor/ruvector/crates/ruvector-gnn/src/replay.rs
vendored
Normal file
@@ -0,0 +1,502 @@
|
||||
//! Experience Replay Buffer for GNN Training
|
||||
//!
|
||||
//! This module implements an experience replay buffer to mitigate catastrophic forgetting
|
||||
//! during continual learning. The buffer stores past training samples and supports:
|
||||
//! - Reservoir sampling for uniform distribution over time
|
||||
//! - Batch sampling for training
|
||||
//! - Distribution shift detection
|
||||
|
||||
use rand::Rng;
|
||||
use std::collections::VecDeque;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
/// A single query/label pair stored in the replay buffer.
#[derive(Debug, Clone)]
pub struct ReplayEntry {
    /// Query vector used for training
    pub query: Vec<f32>,
    /// IDs of positive nodes for this query
    pub positive_ids: Vec<usize>,
    /// Creation time in milliseconds since the Unix epoch
    pub timestamp: u64,
}

impl ReplayEntry {
    /// Wrap a query and its positive node IDs, stamping the current time.
    /// A system clock before the epoch yields timestamp 0.
    pub fn new(query: Vec<f32>, positive_ids: Vec<usize>) -> Self {
        let now_ms = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_millis() as u64)
            .unwrap_or(0);

        Self {
            query,
            positive_ids,
            timestamp: now_ms,
        }
    }
}
|
||||
|
||||
/// Online per-dimension mean/variance tracker (Welford's algorithm).
#[derive(Debug, Clone)]
pub struct DistributionStats {
    /// Running per-dimension mean of the observed vectors.
    pub mean: Vec<f32>,
    /// Per-dimension sum of squared deviations (Welford's M2 aggregate);
    /// divide by `count - 1` to obtain the sample variance.
    pub variance: Vec<f32>,
    /// Number of samples folded in so far.
    pub count: usize,
}

impl DistributionStats {
    /// Create zeroed statistics for vectors of `dimension` elements.
    pub fn new(dimension: usize) -> Self {
        Self {
            mean: vec![0.0; dimension],
            variance: vec![0.0; dimension],
            count: 0,
        }
    }

    /// Fold one sample into the running statistics (Welford update).
    ///
    /// A zero dimension is promoted lazily from the first non-empty sample;
    /// samples whose length disagrees with the tracked dimension are ignored.
    pub fn update(&mut self, sample: &[f32]) {
        if self.mean.is_empty() && !sample.is_empty() {
            self.mean = vec![0.0; sample.len()];
            self.variance = vec![0.0; sample.len()];
        }

        if self.mean.len() != sample.len() {
            return; // dimension mismatch: skip this sample
        }

        self.count += 1;
        let n = self.count as f32;

        for (i, &x) in sample.iter().enumerate() {
            let before = x - self.mean[i];
            self.mean[i] += before / n;
            let after = x - self.mean[i];
            self.variance[i] += before * after;
        }
    }

    /// Sample standard deviation per dimension (all zeros when count <= 1).
    pub fn std_dev(&self) -> Vec<f32> {
        if self.count <= 1 {
            vec![0.0; self.variance.len()]
        } else {
            let denom = (self.count - 1) as f32;
            self.variance
                .iter()
                .map(|&m2| (m2 / denom).sqrt())
                .collect()
        }
    }

    /// Zero everything while keeping the tracked dimension.
    pub fn reset(&mut self) {
        self.mean.iter_mut().for_each(|v| *v = 0.0);
        self.variance.iter_mut().for_each(|v| *v = 0.0);
        self.count = 0;
    }
}
|
||||
|
||||
/// Experience Replay Buffer for storing and sampling past training examples
|
||||
pub struct ReplayBuffer {
|
||||
/// Circular buffer of replay entries
|
||||
queries: VecDeque<ReplayEntry>,
|
||||
/// Maximum capacity of the buffer
|
||||
capacity: usize,
|
||||
/// Total number of samples seen (including evicted ones)
|
||||
total_seen: usize,
|
||||
/// Statistics of the overall distribution
|
||||
distribution_stats: DistributionStats,
|
||||
}
|
||||
|
||||
impl ReplayBuffer {
|
||||
/// Create a new replay buffer with specified capacity
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `capacity` - Maximum number of entries to store
|
||||
pub fn new(capacity: usize) -> Self {
|
||||
Self {
|
||||
queries: VecDeque::with_capacity(capacity),
|
||||
capacity,
|
||||
total_seen: 0,
|
||||
distribution_stats: DistributionStats::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a new entry to the buffer using reservoir sampling
|
||||
///
|
||||
/// Reservoir sampling ensures uniform distribution over all samples seen,
|
||||
/// even as old samples are evicted due to capacity constraints.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `query` - Query vector
|
||||
/// * `positive_ids` - IDs of positive nodes for this query
|
||||
pub fn add(&mut self, query: &[f32], positive_ids: &[usize]) {
|
||||
let entry = ReplayEntry::new(query.to_vec(), positive_ids.to_vec());
|
||||
|
||||
self.total_seen += 1;
|
||||
|
||||
// Update distribution statistics
|
||||
self.distribution_stats.update(query);
|
||||
|
||||
// If buffer is not full, just add the entry
|
||||
if self.queries.len() < self.capacity {
|
||||
self.queries.push_back(entry);
|
||||
return;
|
||||
}
|
||||
|
||||
// Reservoir sampling: replace a random entry with probability capacity/total_seen
|
||||
let mut rng = rand::thread_rng();
|
||||
let random_index = rng.gen_range(0..self.total_seen);
|
||||
|
||||
if random_index < self.capacity {
|
||||
self.queries[random_index] = entry;
|
||||
}
|
||||
}
|
||||
|
||||
/// Sample a batch of entries uniformly at random
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `batch_size` - Number of entries to sample
|
||||
///
|
||||
/// # Returns
|
||||
/// Vector of references to sampled entries (may be smaller than batch_size if buffer is small)
|
||||
pub fn sample(&self, batch_size: usize) -> Vec<&ReplayEntry> {
|
||||
if self.queries.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let actual_batch_size = batch_size.min(self.queries.len());
|
||||
let mut rng = rand::thread_rng();
|
||||
let mut indices: Vec<usize> = (0..self.queries.len()).collect();
|
||||
|
||||
// Fisher-Yates shuffle for first batch_size elements
|
||||
for i in 0..actual_batch_size {
|
||||
let j = rng.gen_range(i..indices.len());
|
||||
indices.swap(i, j);
|
||||
}
|
||||
|
||||
indices[..actual_batch_size]
|
||||
.iter()
|
||||
.map(|&idx| &self.queries[idx])
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Detect distribution shift between recent samples and overall distribution
|
||||
///
|
||||
/// Uses Kullback-Leibler divergence approximation based on mean and variance changes.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `recent_window` - Number of most recent samples to compare
|
||||
///
|
||||
/// # Returns
|
||||
/// Shift score (higher values indicate more significant distribution shift)
|
||||
/// Returns 0.0 if insufficient data
|
||||
pub fn detect_distribution_shift(&self, recent_window: usize) -> f32 {
|
||||
if self.queries.len() < recent_window || recent_window == 0 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
// Compute statistics for recent window
|
||||
let mut recent_stats = DistributionStats::new(self.distribution_stats.mean.len());
|
||||
|
||||
let start_idx = self.queries.len().saturating_sub(recent_window);
|
||||
for entry in self.queries.iter().skip(start_idx) {
|
||||
recent_stats.update(&entry.query);
|
||||
}
|
||||
|
||||
// Compute shift using normalized mean difference
|
||||
let overall_mean = &self.distribution_stats.mean;
|
||||
let recent_mean = &recent_stats.mean;
|
||||
|
||||
if overall_mean.is_empty() || recent_mean.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let overall_std = self.distribution_stats.std_dev();
|
||||
let mut shift_sum = 0.0;
|
||||
let mut count = 0;
|
||||
|
||||
for i in 0..overall_mean.len() {
|
||||
if overall_std[i] > 1e-8 {
|
||||
let diff = (recent_mean[i] - overall_mean[i]).abs();
|
||||
shift_sum += diff / overall_std[i];
|
||||
count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
shift_sum / count as f32
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the number of entries currently in the buffer
|
||||
pub fn len(&self) -> usize {
|
||||
self.queries.len()
|
||||
}
|
||||
|
||||
/// Check if the buffer is empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.queries.is_empty()
|
||||
}
|
||||
|
||||
/// Get the total capacity of the buffer
|
||||
pub fn capacity(&self) -> usize {
|
||||
self.capacity
|
||||
}
|
||||
|
||||
/// Get the total number of samples seen (including evicted ones)
|
||||
pub fn total_seen(&self) -> usize {
|
||||
self.total_seen
|
||||
}
|
||||
|
||||
/// Get a reference to the distribution statistics accumulated over
/// every sample ever added (not just the currently retained entries).
pub fn distribution_stats(&self) -> &DistributionStats {
    &self.distribution_stats
}
|
||||
|
||||
/// Clear all entries from the buffer.
///
/// Also resets the lifetime sample counter and the accumulated
/// distribution statistics, returning the buffer to a freshly
/// constructed state (capacity is unchanged).
pub fn clear(&mut self) {
    self.queries.clear();
    self.total_seen = 0;
    self.distribution_stats.reset();
}
|
||||
}
|
||||
|
||||
/// Unit tests for `ReplayBuffer`, `DistributionStats`, and `ReplayEntry`:
/// basic add/len/capacity accounting, reservoir-sampling behavior once the
/// buffer overflows, running-statistics math, and distribution-shift scoring.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_replay_buffer_basic() {
        let mut buffer = ReplayBuffer::new(10);
        assert_eq!(buffer.len(), 0);
        assert!(buffer.is_empty());
        assert_eq!(buffer.capacity(), 10);

        buffer.add(&[1.0, 2.0, 3.0], &[0, 1]);
        assert_eq!(buffer.len(), 1);
        assert!(!buffer.is_empty());

        buffer.add(&[4.0, 5.0, 6.0], &[2, 3]);
        assert_eq!(buffer.len(), 2);
        assert_eq!(buffer.total_seen(), 2);
    }

    #[test]
    fn test_replay_buffer_capacity() {
        let mut buffer = ReplayBuffer::new(3);

        // Add entries up to capacity
        for i in 0..3 {
            buffer.add(&[i as f32], &[i]);
        }
        assert_eq!(buffer.len(), 3);

        // Adding more should maintain capacity through reservoir sampling
        for i in 3..10 {
            buffer.add(&[i as f32], &[i]);
        }
        assert_eq!(buffer.len(), 3);
        assert_eq!(buffer.total_seen(), 10);
    }

    #[test]
    fn test_sample_empty_buffer() {
        let buffer = ReplayBuffer::new(10);
        let samples = buffer.sample(5);
        assert!(samples.is_empty());
    }

    #[test]
    fn test_sample_basic() {
        let mut buffer = ReplayBuffer::new(10);

        for i in 0..5 {
            buffer.add(&[i as f32], &[i]);
        }

        let samples = buffer.sample(3);
        assert_eq!(samples.len(), 3);

        // Check that samples are from the buffer
        for sample in samples {
            assert!(sample.query[0] >= 0.0 && sample.query[0] < 5.0);
        }
    }

    #[test]
    fn test_sample_larger_than_buffer() {
        let mut buffer = ReplayBuffer::new(10);

        buffer.add(&[1.0], &[0]);
        buffer.add(&[2.0], &[1]);

        let samples = buffer.sample(5);
        assert_eq!(samples.len(), 2); // Can only return what's available
    }

    #[test]
    fn test_distribution_stats_update() {
        let mut stats = DistributionStats::new(2);

        stats.update(&[1.0, 2.0]);
        assert_eq!(stats.count, 1);
        assert_eq!(stats.mean, vec![1.0, 2.0]);

        stats.update(&[3.0, 4.0]);
        assert_eq!(stats.count, 2);
        assert_eq!(stats.mean, vec![2.0, 3.0]);

        stats.update(&[2.0, 3.0]);
        assert_eq!(stats.count, 3);
        assert_eq!(stats.mean, vec![2.0, 3.0]);
    }

    #[test]
    fn test_distribution_stats_std_dev() {
        let mut stats = DistributionStats::new(2);

        stats.update(&[1.0, 1.0]);
        stats.update(&[3.0, 3.0]);
        stats.update(&[5.0, 5.0]);

        let std_dev = stats.std_dev();
        // Expected std dev for [1, 3, 5] is 2.0
        assert!((std_dev[0] - 2.0).abs() < 0.01);
        assert!((std_dev[1] - 2.0).abs() < 0.01);
    }

    #[test]
    fn test_detect_distribution_shift_no_shift() {
        let mut buffer = ReplayBuffer::new(100);

        // Add samples from the same distribution
        for _ in 0..50 {
            buffer.add(&[1.0, 2.0, 3.0], &[0]);
        }

        let shift = buffer.detect_distribution_shift(10);
        assert!(shift < 0.1); // Should be very low
    }

    #[test]
    fn test_detect_distribution_shift_with_shift() {
        let mut buffer = ReplayBuffer::new(100);

        // Add samples from one distribution
        for _ in 0..40 {
            buffer.add(&[1.0, 2.0, 3.0], &[0]);
        }

        // Add samples from a different distribution
        for _ in 0..10 {
            buffer.add(&[5.0, 6.0, 7.0], &[1]);
        }

        let shift = buffer.detect_distribution_shift(10);
        assert!(shift > 0.5); // Should detect significant shift
    }

    #[test]
    fn test_detect_distribution_shift_insufficient_data() {
        let mut buffer = ReplayBuffer::new(100);

        buffer.add(&[1.0, 2.0], &[0]);

        let shift = buffer.detect_distribution_shift(10);
        assert_eq!(shift, 0.0); // Not enough data
    }

    #[test]
    fn test_clear() {
        let mut buffer = ReplayBuffer::new(10);

        for i in 0..5 {
            buffer.add(&[i as f32], &[i]);
        }

        assert_eq!(buffer.len(), 5);
        assert_eq!(buffer.total_seen(), 5);

        buffer.clear();
        assert_eq!(buffer.len(), 0);
        assert_eq!(buffer.total_seen(), 0);
        assert!(buffer.is_empty());
        assert_eq!(buffer.distribution_stats().count, 0);
    }

    #[test]
    fn test_replay_entry_creation() {
        let entry = ReplayEntry::new(vec![1.0, 2.0, 3.0], vec![0, 1, 2]);

        assert_eq!(entry.query, vec![1.0, 2.0, 3.0]);
        assert_eq!(entry.positive_ids, vec![0, 1, 2]);
        assert!(entry.timestamp > 0);
    }

    #[test]
    fn test_reservoir_sampling_distribution() {
        let mut buffer = ReplayBuffer::new(10);

        // Add 100 entries (much more than capacity)
        for i in 0..100 {
            buffer.add(&[i as f32], &[i]);
        }

        assert_eq!(buffer.len(), 10);
        assert_eq!(buffer.total_seen(), 100);

        // Sample multiple times and verify we get different samples
        let samples1 = buffer.sample(5);
        let samples2 = buffer.sample(5);

        assert_eq!(samples1.len(), 5);
        assert_eq!(samples2.len(), 5);

        // Check that samples come from the full range (not just recent entries)
        let sample_batch = buffer.sample(10);
        let values: Vec<f32> = sample_batch.iter().map(|e| e.query[0]).collect();

        // With reservoir sampling, we should have some diversity in values
        let unique_values: std::collections::HashSet<_> =
            values.iter().map(|&v| v as i32).collect();
        assert!(unique_values.len() > 1);
    }

    #[test]
    fn test_dimension_mismatch_handling() {
        let mut buffer = ReplayBuffer::new(10);

        buffer.add(&[1.0, 2.0], &[0]);

        // This should not panic, just be handled gracefully
        // The implementation will initialize stats on first add
        assert_eq!(buffer.len(), 1);
        assert_eq!(buffer.distribution_stats().mean.len(), 2);
    }

    #[test]
    fn test_sample_uniqueness() {
        let mut buffer = ReplayBuffer::new(5);

        for i in 0..5 {
            buffer.add(&[i as f32], &[i]);
        }

        // Sample all entries
        let samples = buffer.sample(5);
        let values: Vec<f32> = samples.iter().map(|e| e.query[0]).collect();

        // All samples should be unique (no duplicates in a single batch)
        let unique_values: std::collections::HashSet<_> =
            values.iter().map(|&v| v as i32).collect();
        assert_eq!(unique_values.len(), 5);
    }
}
|
||||
531
vendor/ruvector/crates/ruvector-gnn/src/scheduler.rs
vendored
Normal file
531
vendor/ruvector/crates/ruvector-gnn/src/scheduler.rs
vendored
Normal file
@@ -0,0 +1,531 @@
|
||||
//! Learning rate scheduling for Graph Neural Networks
|
||||
//!
|
||||
//! Provides various learning rate scheduling strategies to prevent catastrophic
|
||||
//! forgetting and optimize training dynamics in continual learning scenarios.
|
||||
|
||||
use std::f32::consts::PI;
|
||||
|
||||
/// Learning rate scheduling strategies
#[derive(Debug, Clone)]
pub enum SchedulerType {
    /// Constant learning rate throughout training
    Constant,

    /// Step decay: multiply learning rate by gamma every step_size epochs
    /// Formula: lr = base_lr * gamma^(epoch / step_size)
    StepDecay { step_size: usize, gamma: f32 },

    /// Exponential decay: multiply learning rate by gamma each epoch
    /// Formula: lr = base_lr * gamma^epoch
    Exponential { gamma: f32 },

    /// Cosine annealing with warm restarts
    /// Formula: lr = eta_min + 0.5 * (base_lr - eta_min) * (1 + cos(pi * (epoch % t_max) / t_max))
    CosineAnnealing { t_max: usize, eta_min: f32 },

    /// Warmup phase followed by linear decay
    /// Linearly increases lr from 0 to base_lr over warmup_steps,
    /// then linearly decreases to 0 over remaining steps
    WarmupLinear {
        /// Number of steps spent linearly ramping up to base_lr
        warmup_steps: usize,
        /// Total step budget (warmup + linear decay); LR is 0 afterwards
        total_steps: usize,
    },

    /// Reduce learning rate when a metric plateaus
    /// Useful for online learning scenarios
    ReduceOnPlateau {
        /// Multiplicative factor applied to the LR on each reduction
        factor: f32,
        /// Number of consecutive non-improving steps tolerated before reducing
        patience: usize,
        /// Floor below which the LR is never reduced
        min_lr: f32,
    },
}
|
||||
|
||||
/// Learning rate scheduler for GNN training
///
/// Implements various scheduling strategies to control learning rate
/// during training, helping prevent catastrophic forgetting and
/// improve convergence.
#[derive(Debug, Clone)]
pub struct LearningRateScheduler {
    // The scheduling strategy in use.
    scheduler_type: SchedulerType,
    // Initial learning rate the schedule is derived from.
    base_lr: f32,
    // Most recently computed learning rate (returned by get_lr()).
    current_lr: f32,
    // Number of times step()/step_with_metric() has been called.
    step_count: usize,
    // Best (lowest) metric observed so far; used by ReduceOnPlateau only.
    best_metric: f32,
    // Consecutive non-improving steps; used by ReduceOnPlateau only.
    patience_counter: usize,
}
|
||||
|
||||
impl LearningRateScheduler {
    /// Creates a new learning rate scheduler
    ///
    /// # Arguments
    /// * `scheduler_type` - The scheduling strategy to use
    /// * `base_lr` - The initial/base learning rate
    ///
    /// # Example
    /// ```
    /// use ruvector_gnn::scheduler::{LearningRateScheduler, SchedulerType};
    ///
    /// let scheduler = LearningRateScheduler::new(
    ///     SchedulerType::StepDecay { step_size: 10, gamma: 0.9 },
    ///     0.001
    /// );
    /// ```
    pub fn new(scheduler_type: SchedulerType, base_lr: f32) -> Self {
        Self {
            scheduler_type,
            base_lr,
            current_lr: base_lr,
            step_count: 0,
            // INFINITY so the first observed metric always registers as
            // an improvement in step_with_metric().
            best_metric: f32::INFINITY,
            patience_counter: 0,
        }
    }

    /// Advances the scheduler by one step and returns the new learning rate
    ///
    /// For most schedulers, this should be called once per epoch.
    /// For ReduceOnPlateau, use `step_with_metric` instead.
    ///
    /// # Returns
    /// The updated learning rate
    pub fn step(&mut self) -> f32 {
        self.step_count += 1;
        self.current_lr = self.calculate_lr();
        self.current_lr
    }

    /// Advances the scheduler with a metric value (for ReduceOnPlateau)
    ///
    /// # Arguments
    /// * `metric` - The metric value to monitor (e.g., validation loss);
    ///   lower is treated as better
    ///
    /// # Returns
    /// The updated learning rate
    pub fn step_with_metric(&mut self, metric: f32) -> f32 {
        self.step_count += 1;

        match &self.scheduler_type {
            SchedulerType::ReduceOnPlateau {
                factor,
                patience,
                min_lr,
            } => {
                // Check if metric improved; the 1e-8 epsilon keeps float
                // noise from counting as an improvement.
                if metric < self.best_metric - 1e-8 {
                    self.best_metric = metric;
                    self.patience_counter = 0;
                } else {
                    self.patience_counter += 1;

                    // Reduce learning rate if patience exceeded
                    if self.patience_counter >= *patience {
                        self.current_lr = (self.current_lr * factor).max(*min_lr);
                        // Reset so another full `patience` run of
                        // non-improvements is required before the next cut.
                        self.patience_counter = 0;
                    }
                }
            }
            _ => {
                // For non-plateau schedulers, just use step()
                self.current_lr = self.calculate_lr();
            }
        }

        self.current_lr
    }

    /// Gets the current learning rate without advancing the scheduler
    pub fn get_lr(&self) -> f32 {
        self.current_lr
    }

    /// Resets the scheduler to its initial state
    pub fn reset(&mut self) {
        self.current_lr = self.base_lr;
        self.step_count = 0;
        self.best_metric = f32::INFINITY;
        self.patience_counter = 0;
    }

    /// Calculates the learning rate based on the current step and scheduler type
    fn calculate_lr(&self) -> f32 {
        match &self.scheduler_type {
            SchedulerType::Constant => self.base_lr,

            SchedulerType::StepDecay { step_size, gamma } => {
                // Integer division: the exponent bumps once per `step_size` steps.
                let decay_factor = (*gamma).powi((self.step_count / step_size) as i32);
                self.base_lr * decay_factor
            }

            SchedulerType::Exponential { gamma } => {
                let decay_factor = (*gamma).powi(self.step_count as i32);
                self.base_lr * decay_factor
            }

            SchedulerType::CosineAnnealing { t_max, eta_min } => {
                // `%` implements warm restarts: the cycle position wraps to 0
                // every `t_max` steps, jumping the LR back to base_lr.
                let cycle_step = self.step_count % t_max;
                let cos_term = (PI * cycle_step as f32 / *t_max as f32).cos();
                eta_min + 0.5 * (self.base_lr - eta_min) * (1.0 + cos_term)
            }

            SchedulerType::WarmupLinear {
                warmup_steps,
                total_steps,
            } => {
                if self.step_count < *warmup_steps {
                    // Warmup phase: linear increase
                    self.base_lr * (self.step_count as f32 / *warmup_steps as f32)
                } else if self.step_count < *total_steps {
                    // Decay phase: linear decrease
                    let remaining_steps = *total_steps - self.step_count;
                    let total_decay_steps = *total_steps - *warmup_steps;
                    self.base_lr * (remaining_steps as f32 / total_decay_steps as f32)
                } else {
                    // After total_steps, keep at 0
                    0.0
                }
            }

            SchedulerType::ReduceOnPlateau { .. } => {
                // For plateau scheduler, lr is updated in step_with_metric
                self.current_lr
            }
        }
    }
}
|
||||
|
||||
/// Unit tests covering every `SchedulerType` variant: exact LR values for
/// the closed-form schedules, plateau bookkeeping, reset, and cloning.
#[cfg(test)]
mod tests {
    use super::*;

    const EPSILON: f32 = 1e-6;

    // Helper: assert two f32 values agree within EPSILON, with a labelled message.
    fn assert_close(a: f32, b: f32, msg: &str) {
        assert!((a - b).abs() < EPSILON, "{}: {} != {}", msg, a, b);
    }

    #[test]
    fn test_constant_scheduler() {
        let mut scheduler = LearningRateScheduler::new(SchedulerType::Constant, 0.01);

        assert_close(scheduler.get_lr(), 0.01, "Initial LR");

        for i in 1..=10 {
            let lr = scheduler.step();
            assert_close(lr, 0.01, &format!("Step {} LR", i));
        }
    }

    #[test]
    fn test_step_decay() {
        let mut scheduler = LearningRateScheduler::new(
            SchedulerType::StepDecay {
                step_size: 5,
                gamma: 0.5,
            },
            0.1,
        );

        assert_close(scheduler.get_lr(), 0.1, "Initial LR");

        // Steps 1-4: no decay
        for i in 1..=4 {
            let lr = scheduler.step();
            assert_close(lr, 0.1, &format!("Step {} LR", i));
        }

        // Step 5: first decay (0.1 * 0.5)
        let lr = scheduler.step();
        assert_close(lr, 0.05, "Step 5 LR (first decay)");

        // Steps 6-9: maintain decayed rate
        for i in 6..=9 {
            let lr = scheduler.step();
            assert_close(lr, 0.05, &format!("Step {} LR", i));
        }

        // Step 10: second decay (0.1 * 0.5^2)
        let lr = scheduler.step();
        assert_close(lr, 0.025, "Step 10 LR (second decay)");
    }

    #[test]
    fn test_exponential_decay() {
        let mut scheduler =
            LearningRateScheduler::new(SchedulerType::Exponential { gamma: 0.9 }, 0.1);

        assert_close(scheduler.get_lr(), 0.1, "Initial LR");

        let expected_lrs = vec![
            0.1 * 0.9,   // Step 1
            0.1 * 0.81,  // Step 2 (0.9^2)
            0.1 * 0.729, // Step 3 (0.9^3)
        ];

        for (i, expected) in expected_lrs.iter().enumerate() {
            let lr = scheduler.step();
            assert_close(lr, *expected, &format!("Step {} LR", i + 1));
        }
    }

    #[test]
    fn test_cosine_annealing() {
        let mut scheduler = LearningRateScheduler::new(
            SchedulerType::CosineAnnealing {
                t_max: 10,
                eta_min: 0.0,
            },
            1.0,
        );

        assert_close(scheduler.get_lr(), 1.0, "Initial LR");

        // Cosine annealing formula: lr = eta_min + 0.5 * (base_lr - eta_min) * (1 + cos(pi * cycle_step / t_max))
        // cycle_step = step_count % t_max
        // At step 5: cycle_step = 5, cos(pi * 5/10) = cos(pi/2) = 0, lr = 0 + 0.5 * 1 * (1 + 0) = 0.5
        // At step 10: cycle_step = 0 (wrapped), cos(0) = 1, lr = 0 + 0.5 * 1 * (1 + 1) = 1.0 (restart)

        for _ in 1..=5 {
            scheduler.step();
        }
        assert_close(scheduler.get_lr(), 0.5, "Mid-cycle LR (step 5)");

        // At step 9: cycle_step = 9, cos(pi * 9/10) ≈ -0.951, lr ≈ 0.025
        for _ in 6..=9 {
            scheduler.step();
        }
        let lr_step9 = scheduler.get_lr();
        assert!(
            lr_step9 < 0.1,
            "Near end of cycle LR (step 9) should be small: {}",
            lr_step9
        );

        // At step 10: warm restart (cycle_step = 0), LR goes back to base
        scheduler.step();
        assert_close(
            scheduler.get_lr(),
            1.0,
            "Restart at step 10 (cycle_step = 0)",
        );

        // Continue new cycle
        scheduler.step();
        assert!(
            scheduler.get_lr() < 1.0,
            "Step 11 should be less than base LR"
        );
    }

    #[test]
    fn test_warmup_linear() {
        let mut scheduler = LearningRateScheduler::new(
            SchedulerType::WarmupLinear {
                warmup_steps: 5,
                total_steps: 10,
            },
            1.0,
        );

        assert_close(scheduler.get_lr(), 1.0, "Initial LR");

        // Warmup phase: linear increase
        scheduler.step();
        assert_close(scheduler.get_lr(), 0.2, "Step 1 (warmup)");

        scheduler.step();
        assert_close(scheduler.get_lr(), 0.4, "Step 2 (warmup)");

        scheduler.step();
        assert_close(scheduler.get_lr(), 0.6, "Step 3 (warmup)");

        scheduler.step();
        assert_close(scheduler.get_lr(), 0.8, "Step 4 (warmup)");

        scheduler.step();
        assert_close(scheduler.get_lr(), 1.0, "Step 5 (warmup end)");

        // Decay phase: linear decrease
        scheduler.step();
        assert_close(scheduler.get_lr(), 0.8, "Step 6 (decay)");

        scheduler.step();
        assert_close(scheduler.get_lr(), 0.6, "Step 7 (decay)");

        scheduler.step();
        assert_close(scheduler.get_lr(), 0.4, "Step 8 (decay)");

        scheduler.step();
        assert_close(scheduler.get_lr(), 0.2, "Step 9 (decay)");

        scheduler.step();
        assert_close(scheduler.get_lr(), 0.0, "Step 10 (decay end)");

        // After total_steps
        scheduler.step();
        assert_close(scheduler.get_lr(), 0.0, "Step 11 (after total)");
    }

    #[test]
    fn test_reduce_on_plateau() {
        let mut scheduler = LearningRateScheduler::new(
            SchedulerType::ReduceOnPlateau {
                factor: 0.5,
                patience: 3,
                min_lr: 0.0001,
            },
            0.01,
        );

        assert_close(scheduler.get_lr(), 0.01, "Initial LR");

        // Improving metrics: no reduction (sets best_metric, resets patience)
        scheduler.step_with_metric(1.0);
        assert_close(
            scheduler.get_lr(),
            0.01,
            "Step 1 (first metric, sets baseline)",
        );

        scheduler.step_with_metric(0.9);
        assert_close(scheduler.get_lr(), 0.01, "Step 2 (improving)");

        // Plateau: metric not improving (patience counter: 1, 2, 3)
        scheduler.step_with_metric(0.91);
        assert_close(scheduler.get_lr(), 0.01, "Step 3 (plateau 1)");

        scheduler.step_with_metric(0.92);
        assert_close(scheduler.get_lr(), 0.01, "Step 4 (plateau 2)");

        // patience=3 means after 3 non-improvements, reduce LR
        // Step 5 is the 3rd non-improvement, so LR gets reduced
        scheduler.step_with_metric(0.93);
        assert_close(
            scheduler.get_lr(),
            0.005,
            "Step 5 (patience exceeded, reduced)",
        );

        // Counter is reset after reduction, so we need 3 more non-improvements
        scheduler.step_with_metric(0.94); // plateau 1 after reset
        assert_close(scheduler.get_lr(), 0.005, "Step 6 (plateau 1 after reset)");

        scheduler.step_with_metric(0.95); // plateau 2
        assert_close(scheduler.get_lr(), 0.005, "Step 7 (plateau 2)");

        scheduler.step_with_metric(0.96); // plateau 3 - triggers reduction
        assert_close(scheduler.get_lr(), 0.0025, "Step 8 (reduced again)");

        // Test min_lr floor
        for _ in 0..20 {
            scheduler.step_with_metric(1.0);
        }
        assert!(
            scheduler.get_lr() >= 0.0001,
            "LR should not go below min_lr"
        );
    }

    #[test]
    fn test_scheduler_reset() {
        let mut scheduler =
            LearningRateScheduler::new(SchedulerType::Exponential { gamma: 0.9 }, 0.1);

        // Run for several steps
        for _ in 0..5 {
            scheduler.step();
        }
        assert!(scheduler.get_lr() < 0.1, "LR should have decayed");

        // Reset and verify
        scheduler.reset();
        assert_close(scheduler.get_lr(), 0.1, "Reset LR");
        assert_eq!(scheduler.step_count, 0, "Reset step count");
    }

    #[test]
    fn test_scheduler_cloning() {
        let scheduler1 = LearningRateScheduler::new(
            SchedulerType::StepDecay {
                step_size: 10,
                gamma: 0.5,
            },
            0.01,
        );

        let mut scheduler2 = scheduler1.clone();

        // Advance clone
        scheduler2.step();

        // Original should be unchanged
        assert_close(scheduler1.get_lr(), 0.01, "Original LR");
        assert_close(scheduler2.get_lr(), 0.01, "Clone LR after step");
    }

    #[test]
    fn test_multiple_scheduler_types() {
        let schedulers = vec![
            (SchedulerType::Constant, 0.01),
            (
                SchedulerType::StepDecay {
                    step_size: 5,
                    gamma: 0.9,
                },
                0.01,
            ),
            (SchedulerType::Exponential { gamma: 0.95 }, 0.01),
            (
                SchedulerType::CosineAnnealing {
                    t_max: 10,
                    eta_min: 0.001,
                },
                0.01,
            ),
            (
                SchedulerType::WarmupLinear {
                    warmup_steps: 5,
                    total_steps: 20,
                },
                0.01,
            ),
            (
                SchedulerType::ReduceOnPlateau {
                    factor: 0.5,
                    patience: 5,
                    min_lr: 0.0001,
                },
                0.01,
            ),
        ];

        for (sched_type, base_lr) in schedulers {
            let mut scheduler = LearningRateScheduler::new(sched_type, base_lr);

            // All schedulers should start at base_lr
            assert_close(scheduler.get_lr(), base_lr, "Initial LR for scheduler type");

            // All schedulers should be able to step
            let _ = scheduler.step();
            assert!(scheduler.get_lr() >= 0.0, "LR should be non-negative");
        }
    }

    #[test]
    fn test_edge_cases() {
        // Zero learning rate
        let mut scheduler = LearningRateScheduler::new(SchedulerType::Constant, 0.0);
        assert_close(scheduler.get_lr(), 0.0, "Zero LR");
        scheduler.step();
        assert_close(scheduler.get_lr(), 0.0, "Zero LR after step");

        // Very small gamma
        let mut scheduler =
            LearningRateScheduler::new(SchedulerType::Exponential { gamma: 0.1 }, 1.0);
        for _ in 0..10 {
            scheduler.step();
        }
        assert!(scheduler.get_lr() > 0.0, "LR should remain positive");
        assert!(scheduler.get_lr() < 1e-8, "LR should be very small");
    }
}
|
||||
247
vendor/ruvector/crates/ruvector-gnn/src/search.rs
vendored
Normal file
247
vendor/ruvector/crates/ruvector-gnn/src/search.rs
vendored
Normal file
@@ -0,0 +1,247 @@
|
||||
use crate::layer::RuvectorLayer;
|
||||
|
||||
/// Compute cosine similarity between two vectors with improved precision
///
/// Accumulates the dot product and both norms in `f64` to reduce rounding
/// error for long vectors, then converts the final ratio back to `f32`.
/// (The previous version accumulated the dot product in `f32` while using
/// `f64` for the norms; all three reductions now use `f64` consistently.)
///
/// # Panics
/// Panics if `a` and `b` have different lengths.
///
/// # Returns
/// Cosine similarity in [-1, 1]; 0.0 if either vector has zero norm.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Vectors must have the same length");

    let mut dot = 0.0f64;
    let mut norm_a_sq = 0.0f64;
    let mut norm_b_sq = 0.0f64;
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x as f64 * y as f64;
        norm_a_sq += (x as f64) * (x as f64);
        norm_b_sq += (y as f64) * (y as f64);
    }

    if norm_a_sq == 0.0 || norm_b_sq == 0.0 {
        // Zero-norm vectors have no direction; define similarity as 0.
        0.0
    } else {
        (dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt())) as f32
    }
}
|
||||
|
||||
/// Apply softmax with temperature scaling.
///
/// Subtracts the maximum value before exponentiation for numerical
/// stability; lower temperatures sharpen the distribution, higher
/// temperatures flatten it. Returns an empty vector for empty input.
fn softmax(values: &[f32], temperature: f32) -> Vec<f32> {
    if values.is_empty() {
        return Vec::new();
    }

    // Stability shift: exp((x - max) / T) never overflows for x <= max.
    let max_val = values.iter().copied().fold(f32::NEG_INFINITY, f32::max);

    let mut exps = Vec::with_capacity(values.len());
    for &v in values {
        exps.push(((v - max_val) / temperature).exp());
    }

    // Clamp the denominator away from zero to avoid division blow-ups.
    let denom = exps.iter().sum::<f32>().max(1e-10);
    for e in exps.iter_mut() {
        *e /= denom;
    }
    exps
}
|
||||
|
||||
/// Differentiable search using soft attention mechanism
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `query` - The query vector
|
||||
/// * `candidate_embeddings` - List of candidate embedding vectors
|
||||
/// * `k` - Number of top results to return
|
||||
/// * `temperature` - Temperature for softmax (lower = sharper, higher = smoother)
|
||||
///
|
||||
/// # Returns
|
||||
/// * Tuple of (indices, soft_weights) for top-k candidates
|
||||
pub fn differentiable_search(
|
||||
query: &[f32],
|
||||
candidate_embeddings: &[Vec<f32>],
|
||||
k: usize,
|
||||
temperature: f32,
|
||||
) -> (Vec<usize>, Vec<f32>) {
|
||||
if candidate_embeddings.is_empty() {
|
||||
return (Vec::new(), Vec::new());
|
||||
}
|
||||
|
||||
let k = k.min(candidate_embeddings.len());
|
||||
|
||||
// 1. Compute similarities using cosine similarity
|
||||
let similarities: Vec<f32> = candidate_embeddings
|
||||
.iter()
|
||||
.map(|embedding| cosine_similarity(query, embedding))
|
||||
.collect();
|
||||
|
||||
// 2. Apply softmax with temperature to get soft weights
|
||||
let soft_weights = softmax(&similarities, temperature);
|
||||
|
||||
// 3. Get top-k indices by sorting similarities
|
||||
let mut indexed_weights: Vec<(usize, f32)> = soft_weights
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, &w)| (i, w))
|
||||
.collect();
|
||||
|
||||
// Sort by weight descending
|
||||
indexed_weights.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||
|
||||
// Take top-k
|
||||
let top_k: Vec<(usize, f32)> = indexed_weights.into_iter().take(k).collect();
|
||||
|
||||
let indices: Vec<usize> = top_k.iter().map(|&(i, _)| i).collect();
|
||||
let weights: Vec<f32> = top_k.iter().map(|&(_, w)| w).collect();
|
||||
|
||||
(indices, weights)
|
||||
}
|
||||
|
||||
/// Hierarchical forward pass through GNN layers
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `query` - The query vector
|
||||
/// * `layer_embeddings` - Embeddings organized by layer (outer vec = layers, inner vec = nodes per layer)
|
||||
/// * `gnn_layers` - The GNN layers to process through
|
||||
///
|
||||
/// # Returns
|
||||
/// * Final embedding after hierarchical processing
|
||||
pub fn hierarchical_forward(
|
||||
query: &[f32],
|
||||
layer_embeddings: &[Vec<Vec<f32>>],
|
||||
gnn_layers: &[RuvectorLayer],
|
||||
) -> Vec<f32> {
|
||||
if layer_embeddings.is_empty() || gnn_layers.is_empty() {
|
||||
return query.to_vec();
|
||||
}
|
||||
|
||||
let mut current_embedding = query.to_vec();
|
||||
|
||||
// Process through each layer from top to bottom
|
||||
for (layer_idx, (embeddings, gnn_layer)) in
|
||||
layer_embeddings.iter().zip(gnn_layers.iter()).enumerate()
|
||||
{
|
||||
if embeddings.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Find most relevant nodes at this layer using differentiable search
|
||||
let (top_indices, weights) = differentiable_search(
|
||||
¤t_embedding,
|
||||
embeddings,
|
||||
5.min(embeddings.len()), // Top-5 or all if less
|
||||
1.0, // Default temperature
|
||||
);
|
||||
|
||||
// Aggregate embeddings from top nodes using soft weights
|
||||
let mut aggregated = vec![0.0; current_embedding.len()];
|
||||
for (&idx, &weight) in top_indices.iter().zip(weights.iter()) {
|
||||
for (i, &val) in embeddings[idx].iter().enumerate() {
|
||||
if i < aggregated.len() {
|
||||
aggregated[i] += weight * val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Combine with current embedding
|
||||
let combined: Vec<f32> = current_embedding
|
||||
.iter()
|
||||
.zip(&aggregated)
|
||||
.map(|(curr, agg)| (curr + agg) / 2.0)
|
||||
.collect();
|
||||
|
||||
// Apply GNN layer transformation
|
||||
// Extract neighbor embeddings and compute edge weights
|
||||
let neighbor_embs: Vec<Vec<f32>> = top_indices
|
||||
.iter()
|
||||
.map(|&idx| embeddings[idx].clone())
|
||||
.collect();
|
||||
|
||||
let edge_weights_vec: Vec<f32> = weights.clone();
|
||||
|
||||
current_embedding = gnn_layer.forward(&combined, &neighbor_embs, &edge_weights_vec);
|
||||
}
|
||||
|
||||
current_embedding
|
||||
}
|
||||
|
||||
/// Unit tests for similarity, softmax, differentiable search, and the
/// hierarchical forward pass (the last one requires a real `RuvectorLayer`).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);

        let c = vec![1.0, 0.0, 0.0];
        let d = vec![0.0, 1.0, 0.0];
        assert!((cosine_similarity(&c, &d) - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_softmax() {
        let values = vec![1.0, 2.0, 3.0];
        let result = softmax(&values, 1.0);

        // Sum should be 1.0
        let sum: f32 = result.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);

        // Higher values should have higher probabilities
        assert!(result[2] > result[1]);
        assert!(result[1] > result[0]);
    }

    #[test]
    fn test_softmax_with_temperature() {
        let values = vec![1.0, 2.0, 3.0];

        // Lower temperature = sharper distribution
        let sharp = softmax(&values, 0.1);
        let smooth = softmax(&values, 10.0);

        // Sharp should have more weight on max
        assert!(sharp[2] > smooth[2]);
    }

    #[test]
    fn test_differentiable_search() {
        let query = vec![1.0, 0.0, 0.0];
        let candidates = vec![
            vec![1.0, 0.0, 0.0], // Perfect match
            vec![0.9, 0.1, 0.0], // Close match
            vec![0.0, 1.0, 0.0], // Orthogonal
        ];

        let (indices, weights) = differentiable_search(&query, &candidates, 2, 1.0);

        assert_eq!(indices.len(), 2);
        assert_eq!(weights.len(), 2);

        // First result should be the perfect match
        assert_eq!(indices[0], 0);

        // Weights should sum to less than or equal to 1.0 (since we took top-k)
        let sum: f32 = weights.iter().sum();
        assert!(sum <= 1.0 + 1e-6);
    }

    #[test]
    fn test_hierarchical_forward() {
        // Use consistent dimensions throughout
        let query = vec![1.0, 0.0];

        // Layer embeddings should match the output dimensions of each layer
        let layer_embeddings = vec![
            // First layer: embeddings are 2-dimensional (match query)
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
        ];

        // Single GNN layer that maintains dimension
        let gnn_layers = vec![
            RuvectorLayer::new(2, 2, 1, 0.0).unwrap(), // input_dim, hidden_dim, heads, dropout
        ];

        let result = hierarchical_forward(&query, &layer_embeddings, &gnn_layers);

        assert_eq!(result.len(), 2); // Should match hidden_dim of last layer
    }
}
|
||||
789
vendor/ruvector/crates/ruvector-gnn/src/tensor.rs
vendored
Normal file
789
vendor/ruvector/crates/ruvector-gnn/src/tensor.rs
vendored
Normal file
@@ -0,0 +1,789 @@
|
||||
//! Tensor operations for GNN computations.
|
||||
//!
|
||||
//! Provides efficient tensor operations including:
|
||||
//! - Matrix multiplication
|
||||
//! - Element-wise operations
|
||||
//! - Activation functions
|
||||
//! - Weight initialization
|
||||
//! - Normalization
|
||||
|
||||
use crate::error::{GnnError, Result};
|
||||
use rand::Rng;
|
||||
use rand_distr::{Distribution, Normal, Uniform};
|
||||
|
||||
/// Basic tensor operations for GNN computations.
///
/// Values are stored flat in `data`, with `shape` giving the logical
/// dimensions; 2-D tensors are laid out row-major (see the `i * n + j`
/// indexing in `matmul`).
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor {
    /// Flattened tensor data (row-major for 2-D tensors)
    pub data: Vec<f32>,
    /// Shape of the tensor (dimensions)
    pub shape: Vec<usize>,
}
|
||||
|
||||
impl Tensor {
|
||||
/// Create a new tensor from data and shape
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `data` - Flattened tensor data
|
||||
/// * `shape` - Dimensions of the tensor
|
||||
///
|
||||
/// # Returns
|
||||
/// A new `Tensor` instance
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns `GnnError::InvalidShape` if data length doesn't match shape
|
||||
pub fn new(data: Vec<f32>, shape: Vec<usize>) -> Result<Self> {
|
||||
let expected_len: usize = shape.iter().product();
|
||||
if data.len() != expected_len {
|
||||
return Err(GnnError::invalid_shape(format!(
|
||||
"Data length {} doesn't match shape {:?} (expected {})",
|
||||
data.len(),
|
||||
shape,
|
||||
expected_len
|
||||
)));
|
||||
}
|
||||
Ok(Self { data, shape })
|
||||
}
|
||||
|
||||
/// Create a zero-filled tensor with the given shape
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `shape` - Dimensions of the tensor
|
||||
///
|
||||
/// # Returns
|
||||
/// A new zero-filled `Tensor`
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns `GnnError::InvalidShape` if shape is empty or contains zero
|
||||
pub fn zeros(shape: &[usize]) -> Result<Self> {
|
||||
if shape.is_empty() || shape.iter().any(|&d| d == 0) {
|
||||
return Err(GnnError::invalid_shape(format!(
|
||||
"Invalid shape: {:?}",
|
||||
shape
|
||||
)));
|
||||
}
|
||||
let size: usize = shape.iter().product();
|
||||
Ok(Self {
|
||||
data: vec![0.0; size],
|
||||
shape: shape.to_vec(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a 1D tensor from a vector
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `data` - Vector data
|
||||
///
|
||||
/// # Returns
|
||||
/// A new 1D `Tensor`
|
||||
pub fn from_vec(data: Vec<f32>) -> Self {
|
||||
let len = data.len();
|
||||
Self {
|
||||
data,
|
||||
shape: vec![len],
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute dot product with another tensor (both must be 1D)
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `other` - Another tensor to compute dot product with
|
||||
///
|
||||
/// # Returns
|
||||
/// The dot product as a scalar
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns `GnnError::DimensionMismatch` if tensors are not 1D or have different lengths
|
||||
pub fn dot(&self, other: &Tensor) -> Result<f32> {
|
||||
if self.shape.len() != 1 || other.shape.len() != 1 {
|
||||
return Err(GnnError::dimension_mismatch(
|
||||
"1D tensors",
|
||||
format!("{}D and {}D", self.shape.len(), other.shape.len()),
|
||||
));
|
||||
}
|
||||
if self.shape[0] != other.shape[0] {
|
||||
return Err(GnnError::dimension_mismatch(
|
||||
format!("length {}", self.shape[0]),
|
||||
format!("length {}", other.shape[0]),
|
||||
));
|
||||
}
|
||||
|
||||
let result = self
|
||||
.data
|
||||
.iter()
|
||||
.zip(other.data.iter())
|
||||
.map(|(a, b)| a * b)
|
||||
.sum();
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Matrix multiplication
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `other` - Another tensor to multiply with
|
||||
///
|
||||
/// # Returns
|
||||
/// The result of matrix multiplication
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns `GnnError::DimensionMismatch` if dimensions are incompatible
|
||||
pub fn matmul(&self, other: &Tensor) -> Result<Tensor> {
|
||||
// Support 1D x 1D (dot product), 2D x 1D, 2D x 2D
|
||||
match (self.shape.len(), other.shape.len()) {
|
||||
(1, 1) => {
|
||||
let dot = self.dot(other)?;
|
||||
Ok(Tensor::from_vec(vec![dot]))
|
||||
}
|
||||
(2, 1) => {
|
||||
// Matrix-vector multiplication
|
||||
let m = self.shape[0];
|
||||
let n = self.shape[1];
|
||||
if n != other.shape[0] {
|
||||
return Err(GnnError::dimension_mismatch(
|
||||
format!("{}x{}", m, n),
|
||||
format!("vector of length {}", other.shape[0]),
|
||||
));
|
||||
}
|
||||
|
||||
let mut result = vec![0.0; m];
|
||||
for i in 0..m {
|
||||
for j in 0..n {
|
||||
result[i] += self.data[i * n + j] * other.data[j];
|
||||
}
|
||||
}
|
||||
Ok(Tensor::from_vec(result))
|
||||
}
|
||||
(2, 2) => {
|
||||
// Matrix-matrix multiplication
|
||||
let m = self.shape[0];
|
||||
let n = self.shape[1];
|
||||
let p = other.shape[1];
|
||||
|
||||
if n != other.shape[0] {
|
||||
return Err(GnnError::dimension_mismatch(
|
||||
format!("{}x{}", m, n),
|
||||
format!("{}x{}", other.shape[0], p),
|
||||
));
|
||||
}
|
||||
|
||||
let mut result = vec![0.0; m * p];
|
||||
for i in 0..m {
|
||||
for j in 0..p {
|
||||
for k in 0..n {
|
||||
result[i * p + j] += self.data[i * n + k] * other.data[k * p + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
Tensor::new(result, vec![m, p])
|
||||
}
|
||||
_ => Err(GnnError::dimension_mismatch(
|
||||
"1D or 2D tensors",
|
||||
format!("{}D and {}D", self.shape.len(), other.shape.len()),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Element-wise addition
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `other` - Another tensor to add
|
||||
///
|
||||
/// # Returns
|
||||
/// The sum of the two tensors
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns `GnnError::DimensionMismatch` if shapes don't match
|
||||
pub fn add(&self, other: &Tensor) -> Result<Tensor> {
|
||||
if self.shape != other.shape {
|
||||
return Err(GnnError::dimension_mismatch(
|
||||
format!("{:?}", self.shape),
|
||||
format!("{:?}", other.shape),
|
||||
));
|
||||
}
|
||||
|
||||
let result: Vec<f32> = self
|
||||
.data
|
||||
.iter()
|
||||
.zip(other.data.iter())
|
||||
.map(|(a, b)| a + b)
|
||||
.collect();
|
||||
|
||||
Tensor::new(result, self.shape.clone())
|
||||
}
|
||||
|
||||
/// Scalar multiplication
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `scalar` - Scalar value to multiply by
|
||||
///
|
||||
/// # Returns
|
||||
/// A new tensor with all elements scaled
|
||||
pub fn scale(&self, scalar: f32) -> Tensor {
|
||||
let result: Vec<f32> = self.data.iter().map(|&x| x * scalar).collect();
|
||||
Tensor {
|
||||
data: result,
|
||||
shape: self.shape.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// ReLU activation function (max(0, x))
|
||||
///
|
||||
/// # Returns
|
||||
/// A new tensor with ReLU applied element-wise
|
||||
pub fn relu(&self) -> Tensor {
|
||||
let result: Vec<f32> = self.data.iter().map(|&x| x.max(0.0)).collect();
|
||||
Tensor {
|
||||
data: result,
|
||||
shape: self.shape.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Sigmoid activation function (1 / (1 + e^(-x))) with numerical stability
|
||||
///
|
||||
/// # Returns
|
||||
/// A new tensor with sigmoid applied element-wise
|
||||
pub fn sigmoid(&self) -> Tensor {
|
||||
let result: Vec<f32> = self
|
||||
.data
|
||||
.iter()
|
||||
.map(|&x| {
|
||||
if x > 0.0 {
|
||||
1.0 / (1.0 + (-x).exp())
|
||||
} else {
|
||||
let ex = x.exp();
|
||||
ex / (1.0 + ex)
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
Tensor {
|
||||
data: result,
|
||||
shape: self.shape.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Tanh activation function
|
||||
///
|
||||
/// # Returns
|
||||
/// A new tensor with tanh applied element-wise
|
||||
pub fn tanh(&self) -> Tensor {
|
||||
let result: Vec<f32> = self.data.iter().map(|&x| x.tanh()).collect();
|
||||
Tensor {
|
||||
data: result,
|
||||
shape: self.shape.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute L2 norm (Euclidean norm) with improved precision
|
||||
///
|
||||
/// # Returns
|
||||
/// The L2 norm of the tensor
|
||||
pub fn l2_norm(&self) -> f32 {
|
||||
// Use f64 accumulator for better numerical precision
|
||||
let sum_squares: f64 = self.data.iter().map(|&x| (x as f64) * (x as f64)).sum();
|
||||
(sum_squares.sqrt()) as f32
|
||||
}
|
||||
|
||||
/// Normalize the tensor to unit L2 norm
|
||||
///
|
||||
/// # Returns
|
||||
/// A normalized tensor
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns `GnnError::InvalidInput` if norm is zero
|
||||
pub fn normalize(&self) -> Result<Tensor> {
|
||||
let norm = self.l2_norm();
|
||||
if norm == 0.0 {
|
||||
return Err(GnnError::invalid_input(
|
||||
"Cannot normalize zero vector".to_string(),
|
||||
));
|
||||
}
|
||||
Ok(self.scale(1.0 / norm))
|
||||
}
|
||||
|
||||
/// Get a slice view of the tensor data
|
||||
///
|
||||
/// # Returns
|
||||
/// A slice reference to the underlying data
|
||||
pub fn as_slice(&self) -> &[f32] {
|
||||
&self.data
|
||||
}
|
||||
|
||||
/// Consume the tensor and return the underlying vector
|
||||
///
|
||||
/// # Returns
|
||||
/// The vector containing the tensor data
|
||||
pub fn into_vec(self) -> Vec<f32> {
|
||||
self.data
|
||||
}
|
||||
|
||||
/// Get the number of elements in the tensor
|
||||
pub fn len(&self) -> usize {
|
||||
self.data.len()
|
||||
}
|
||||
|
||||
/// Check if the tensor is empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.data.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
/// Xavier/Glorot initialization for neural network weights
|
||||
///
|
||||
/// Samples from uniform distribution U(-a, a) where a = sqrt(6 / (fan_in + fan_out))
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `fan_in` - Number of input units
|
||||
/// * `fan_out` - Number of output units
|
||||
///
|
||||
/// # Returns
|
||||
/// A vector of initialized weights
|
||||
///
|
||||
/// # Panics
|
||||
/// Panics if fan_in or fan_out is 0
|
||||
pub fn xavier_init(fan_in: usize, fan_out: usize) -> Vec<f32> {
|
||||
assert!(
|
||||
fan_in > 0 && fan_out > 0,
|
||||
"fan_in and fan_out must be positive"
|
||||
);
|
||||
|
||||
let limit = (6.0 / (fan_in + fan_out) as f32).sqrt();
|
||||
let uniform = Uniform::new(-limit, limit);
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
(0..fan_in * fan_out)
|
||||
.map(|_| uniform.sample(&mut rng))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// He initialization for ReLU networks
|
||||
///
|
||||
/// Samples from normal distribution N(0, sqrt(2 / fan_in))
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `fan_in` - Number of input units
|
||||
///
|
||||
/// # Returns
|
||||
/// A vector of initialized weights
|
||||
///
|
||||
/// # Panics
|
||||
/// Panics if fan_in is 0
|
||||
pub fn he_init(fan_in: usize) -> Vec<f32> {
|
||||
assert!(fan_in > 0, "fan_in must be positive");
|
||||
|
||||
let std_dev = (2.0 / fan_in as f32).sqrt();
|
||||
let normal = Normal::new(0.0, std_dev).expect("Invalid normal distribution parameters");
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
(0..fan_in).map(|_| normal.sample(&mut rng)).collect()
|
||||
}
|
||||
|
||||
/// Element-wise (Hadamard) product of two equal-length slices.
///
/// # Arguments
/// * `a` - First vector
/// * `b` - Second vector
///
/// # Returns
/// Element-wise product of the two vectors
///
/// # Panics
/// Panics if the slices differ in length.
pub fn hadamard_product(a: &[f32], b: &[f32]) -> Vec<f32> {
    assert_eq!(a.len(), b.len(), "Vectors must have the same length");
    a.iter().zip(b).map(|(x, y)| x * y).collect()
}

/// Element-wise sum of two equal-length slices.
///
/// # Arguments
/// * `a` - First vector
/// * `b` - Second vector
///
/// # Returns
/// Element-wise sum of the two vectors
///
/// # Panics
/// Panics if the slices differ in length.
pub fn vector_add(a: &[f32], b: &[f32]) -> Vec<f32> {
    assert_eq!(a.len(), b.len(), "Vectors must have the same length");
    a.iter().zip(b).map(|(x, y)| x + y).collect()
}

/// Multiply every element of `v` by `scalar`.
///
/// # Arguments
/// * `v` - Input vector
/// * `scalar` - Scalar multiplier
///
/// # Returns
/// Vector with all elements multiplied by scalar
pub fn vector_scale(v: &[f32], scalar: f32) -> Vec<f32> {
    v.iter().map(|&x| x * scalar).collect()
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used by the approximate float comparisons below.
    const EPSILON: f32 = 1e-6;

    /// Assert that two float slices are element-wise equal within `epsilon`,
    /// reporting the first offending index on failure.
    fn assert_vec_approx_eq(a: &[f32], b: &[f32], epsilon: f32) {
        assert_eq!(a.len(), b.len(), "Vectors have different lengths");
        for (i, (&x, &y)) in a.iter().zip(b.iter()).enumerate() {
            assert!(
                (x - y).abs() < epsilon,
                "Values at index {} differ: {} vs {} (diff: {})",
                i,
                x,
                y,
                (x - y).abs()
            );
        }
    }
|
||||
|
||||
#[test]
fn test_tensor_new() {
    let values = vec![1.0, 2.0, 3.0, 4.0];
    let t = Tensor::new(values.clone(), vec![2, 2]).unwrap();
    assert_eq!(t.data, values);
    assert_eq!(t.shape, vec![2, 2]);
}

#[test]
fn test_tensor_new_invalid_shape() {
    // Three elements cannot fill a 2x2 tensor.
    assert!(Tensor::new(vec![1.0, 2.0, 3.0], vec![2, 2]).is_err());
}

#[test]
fn test_tensor_zeros() {
    let t = Tensor::zeros(&[3, 2]).unwrap();
    assert_eq!(t.data, vec![0.0; 6]);
    assert_eq!(t.shape, vec![3, 2]);
}

#[test]
fn test_tensor_zeros_invalid_shape() {
    // A zero-sized dimension and an empty shape are both rejected.
    assert!(Tensor::zeros(&[0, 2]).is_err());
    assert!(Tensor::zeros(&[]).is_err());
}

#[test]
fn test_tensor_from_vec() {
    let values = vec![1.0, 2.0, 3.0];
    let t = Tensor::from_vec(values.clone());
    assert_eq!(t.data, values);
    assert_eq!(t.shape, vec![3]);
}
|
||||
|
||||
#[test]
fn test_dot_product() {
    let u = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
    let v = Tensor::from_vec(vec![4.0, 5.0, 6.0]);
    // 1*4 + 2*5 + 3*6 = 32
    assert!((u.dot(&v).unwrap() - 32.0).abs() < EPSILON);
}

#[test]
fn test_dot_product_dimension_mismatch() {
    let u = Tensor::from_vec(vec![1.0, 2.0]);
    let v = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
    assert!(u.dot(&v).is_err());
}

#[test]
fn test_matmul_1d() {
    let u = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
    let v = Tensor::from_vec(vec![4.0, 5.0, 6.0]);
    let prod = u.matmul(&v).unwrap();
    assert_eq!(prod.shape, vec![1]);
    assert!((prod.data[0] - 32.0).abs() < EPSILON);
}

#[test]
fn test_matmul_2d_1d() {
    // Matrix-vector multiplication.
    let mat = Tensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
    let x = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
    let prod = mat.matmul(&x).unwrap();

    assert_eq!(prod.shape, vec![2]);
    // Row dot products: [1,2,3]·[1,2,3] = 14 and [4,5,6]·[1,2,3] = 32.
    assert_vec_approx_eq(&prod.data, &[14.0, 32.0], EPSILON);
}

#[test]
fn test_matmul_2d_2d() {
    // Matrix-matrix multiplication.
    let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    let b = Tensor::new(vec![5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
    let prod = a.matmul(&b).unwrap();

    assert_eq!(prod.shape, vec![2, 2]);
    // [[1,2],[3,4]] x [[5,6],[7,8]] = [[19,22],[43,50]]
    assert_vec_approx_eq(&prod.data, &[19.0, 22.0, 43.0, 50.0], EPSILON);
}

#[test]
fn test_matmul_dimension_mismatch() {
    let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    let v = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
    assert!(a.matmul(&v).is_err());
}
|
||||
|
||||
#[test]
fn test_add() {
    let a = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
    let b = Tensor::from_vec(vec![4.0, 5.0, 6.0]);
    assert_eq!(a.add(&b).unwrap().data, vec![5.0, 7.0, 9.0]);
}

#[test]
fn test_add_dimension_mismatch() {
    let a = Tensor::from_vec(vec![1.0, 2.0]);
    let b = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
    assert!(a.add(&b).is_err());
}

#[test]
fn test_scale() {
    let t = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
    assert_eq!(t.scale(2.0).data, vec![2.0, 4.0, 6.0]);
}

#[test]
fn test_relu() {
    // Negatives clamp to zero; non-negatives pass through.
    let t = Tensor::from_vec(vec![-1.0, 0.0, 1.0, 2.0]);
    assert_eq!(t.relu().data, vec![0.0, 0.0, 1.0, 2.0]);
}

#[test]
fn test_sigmoid() {
    let out = Tensor::from_vec(vec![0.0, 1.0, -1.0]).sigmoid();

    assert!((out.data[0] - 0.5).abs() < EPSILON);
    assert!((out.data[1] - 0.7310586).abs() < EPSILON);
    assert!((out.data[2] - 0.26894143).abs() < EPSILON);
}

#[test]
fn test_tanh() {
    let out = Tensor::from_vec(vec![0.0, 1.0, -1.0]).tanh();

    assert!((out.data[0] - 0.0).abs() < EPSILON);
    assert!((out.data[1] - 0.7615942).abs() < EPSILON);
    assert!((out.data[2] - (-0.7615942)).abs() < EPSILON);
}
|
||||
|
||||
#[test]
fn test_l2_norm() {
    // 3-4-5 right triangle.
    let t = Tensor::from_vec(vec![3.0, 4.0]);
    assert!((t.l2_norm() - 5.0).abs() < EPSILON);
}

#[test]
fn test_normalize() {
    let unit = Tensor::from_vec(vec![3.0, 4.0]).normalize().unwrap();
    assert_vec_approx_eq(&unit.data, &[0.6, 0.8], EPSILON);
    assert!((unit.l2_norm() - 1.0).abs() < EPSILON);
}

#[test]
fn test_normalize_zero_vector() {
    // The zero vector has no direction and must be rejected.
    assert!(Tensor::from_vec(vec![0.0, 0.0]).normalize().is_err());
}

#[test]
fn test_as_slice() {
    let values = vec![1.0, 2.0, 3.0];
    let t = Tensor::from_vec(values.clone());
    assert_eq!(t.as_slice(), &values[..]);
}

#[test]
fn test_into_vec() {
    let values = vec![1.0, 2.0, 3.0];
    let t = Tensor::from_vec(values.clone());
    assert_eq!(t.into_vec(), values);
}

#[test]
fn test_len() {
    assert_eq!(Tensor::from_vec(vec![1.0, 2.0, 3.0]).len(), 3);
}

#[test]
fn test_is_empty() {
    assert!(Tensor::from_vec(vec![]).is_empty());
    assert!(!Tensor::from_vec(vec![1.0]).is_empty());
}
|
||||
|
||||
#[test]
fn test_xavier_init() {
    let weights = xavier_init(100, 50);
    assert_eq!(weights.len(), 5000);

    // Every sample lies inside U(-a, a) with a = sqrt(6 / 150).
    let limit = (6.0 / 150.0_f32).sqrt();
    assert!(weights.iter().all(|&w| w >= -limit && w <= limit));

    // The sample mean should be near zero.
    let mean: f32 = weights.iter().sum::<f32>() / weights.len() as f32;
    assert!(mean.abs() < 0.1);
}

#[test]
#[should_panic(expected = "fan_in and fan_out must be positive")]
fn test_xavier_init_zero_fan() {
    xavier_init(0, 10);
}

#[test]
fn test_he_init() {
    let weights = he_init(100);
    assert_eq!(weights.len(), 100);

    // The sample mean should be near zero.
    let mean: f32 = weights.iter().sum::<f32>() / weights.len() as f32;
    assert!(mean.abs() < 0.2);
}

#[test]
#[should_panic(expected = "fan_in must be positive")]
fn test_he_init_zero_fan() {
    he_init(0);
}

#[test]
fn test_hadamard_product() {
    let lhs = vec![1.0, 2.0, 3.0];
    let rhs = vec![4.0, 5.0, 6.0];
    assert_eq!(hadamard_product(&lhs, &rhs), vec![4.0, 10.0, 18.0]);
}

#[test]
#[should_panic(expected = "Vectors must have the same length")]
fn test_hadamard_product_length_mismatch() {
    hadamard_product(&[1.0, 2.0], &[1.0, 2.0, 3.0]);
}

#[test]
fn test_vector_add() {
    let lhs = vec![1.0, 2.0, 3.0];
    let rhs = vec![4.0, 5.0, 6.0];
    assert_eq!(vector_add(&lhs, &rhs), vec![5.0, 7.0, 9.0]);
}

#[test]
#[should_panic(expected = "Vectors must have the same length")]
fn test_vector_add_length_mismatch() {
    vector_add(&[1.0, 2.0], &[1.0, 2.0, 3.0]);
}

#[test]
fn test_vector_scale() {
    let scaled = vector_scale(&[1.0, 2.0, 3.0], 2.5);
    assert_vec_approx_eq(&scaled, &[2.5, 5.0, 7.5], EPSILON);
}
|
||||
|
||||
#[test]
fn test_complex_operations() {
    // Chain add -> scale -> relu -> normalize and check the final norm.
    let a = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
    let b = Tensor::from_vec(vec![0.5, 1.0, 1.5]);

    let chained = a.add(&b).unwrap().scale(2.0).relu().normalize().unwrap();

    assert!((chained.l2_norm() - 1.0).abs() < EPSILON);
}

#[test]
fn test_edge_case_single_element() {
    let t = Tensor::from_vec(vec![5.0]);
    assert_eq!(t.len(), 1);
    assert_eq!(t.l2_norm(), 5.0);

    let unit = t.normalize().unwrap();
    assert_vec_approx_eq(&unit.data, &[1.0], EPSILON);
}

#[test]
fn test_edge_case_negative_values() {
    let t = Tensor::from_vec(vec![-3.0, -4.0]);
    assert!((t.l2_norm() - 5.0).abs() < EPSILON);

    // ReLU zeroes out all-negative input.
    assert_eq!(t.relu().data, vec![0.0, 0.0]);
}

#[test]
fn test_large_matrix_multiplication() {
    // 10x10 product: checks shape and size only, not individual entries.
    let size = 10;
    let a_data: Vec<f32> = (0..size * size).map(|i| i as f32).collect();
    let b_data: Vec<f32> = (0..size * size).map(|i| (i % 2) as f32).collect();

    let a = Tensor::new(a_data, vec![size, size]).unwrap();
    let b = Tensor::new(b_data, vec![size, size]).unwrap();

    let prod = a.matmul(&b).unwrap();
    assert_eq!(prod.shape, vec![size, size]);
    assert_eq!(prod.len(), size * size);
}

#[test]
fn test_activation_functions_range() {
    let t = Tensor::from_vec(vec![-10.0, -1.0, 0.0, 1.0, 10.0]);

    // Sigmoid maps into the open interval (0, 1).
    assert!(t.sigmoid().data.iter().all(|&v| v > 0.0 && v < 1.0));

    // Tanh maps into the closed interval [-1, 1].
    assert!(t.tanh().data.iter().all(|&v| v >= -1.0 && v <= 1.0));

    // ReLU output is never negative.
    assert!(t.relu().data.iter().all(|&v| v >= 0.0));
}
}
|
||||
1367
vendor/ruvector/crates/ruvector-gnn/src/training.rs
vendored
Normal file
1367
vendor/ruvector/crates/ruvector-gnn/src/training.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user