5.1 KiB
5.1 KiB
Ruvector GNN Layer Implementation
Overview
Implemented a complete Graph Neural Network (GNN) layer for Ruvector that operates on HNSW topology, providing message passing, attention mechanisms, and recurrent state updates.
Location
Implementation: /home/user/ruvector/crates/ruvector-gnn/src/layer.rs
Components Implemented
1. Linear Layer
- Purpose: Weight matrix multiplication for transformations
- Initialization: Xavier/Glorot initialization for stable gradients
- API:
Linear::new(input_dim: usize, output_dim: usize) -> Self forward(&self, input: &[f32]) -> Vec<f32>
2. Layer Normalization
- Purpose: Normalize activations for stable training
- Features: Learnable scale (gamma) and shift (beta) parameters
- API:
LayerNorm::new(dim: usize, eps: f32) -> Self forward(&self, input: &[f32]) -> Vec<f32>
3. Multi-Head Attention
- Purpose: Attention-based neighbor aggregation
- Features:
- Separate Q, K, V projections
- Scaled dot-product attention
- Multi-head parallelization
- API:
MultiHeadAttention::new(embed_dim: usize, num_heads: usize) -> Self forward(&self, query: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>]) -> Vec<f32>
4. GRU Cell (Gated Recurrent Unit)
- Purpose: State updates with gating mechanisms
- Features:
- Update gate: Controls how much of new information to accept
- Reset gate: Controls how much of past information to forget
- Candidate state: Proposes new hidden state
- API:
GRUCell::new(input_dim: usize, hidden_dim: usize) -> Self forward(&self, input: &[f32], hidden: &[f32]) -> Vec<f32>
5. RuvectorLayer (Main GNN Layer)
- Purpose: Complete GNN layer combining all components
- Architecture:
- Message passing through linear transformations
- Attention-based neighbor aggregation
- Weighted message aggregation using edge weights
- GRU-based state update
- Dropout regularization
- Layer normalization
- API:
RuvectorLayer::new( input_dim: usize, hidden_dim: usize, heads: usize, dropout: f32 ) -> Self forward( &self, node_embedding: &[f32], neighbor_embeddings: &[Vec<f32>], edge_weights: &[f32] ) -> Vec<f32>
Usage Example
use ruvector_gnn::RuvectorLayer;
// Create GNN layer: 128-dim input -> 256-dim hidden, 4 attention heads, 10% dropout
let layer = RuvectorLayer::new(128, 256, 4, 0.1);
// Node and neighbor embeddings
let node = vec![0.5; 128];
let neighbors = vec![
vec![0.3; 128],
vec![0.7; 128],
];
let edge_weights = vec![0.8, 0.6]; // e.g., inverse distances
// Forward pass
let updated_embedding = layer.forward(&node, &neighbors, &edge_weights);
// Output: 256-dimensional embedding
Key Features
- HNSW-Aware: Designed to operate on HNSW graph topology
- Message Passing: Transforms and aggregates neighbor information
- Attention Mechanism: Learns importance of different neighbors
- Edge Weights: Incorporates graph structure (e.g., distances)
- State Updates: GRU cells maintain and update node states
- Normalization: Layer norm for training stability
- Regularization: Dropout to prevent overfitting
Mathematical Operations
Forward Pass Flow:
1. node_msg = W_msg × node_embedding
2. neighbor_msgs = [W_msg × neighbor_i for all neighbors]
3. attention_out = MultiHeadAttention(node_msg, neighbor_msgs)
4. weighted_msgs = Σ(weight_i × neighbor_msg_i) / Σ(weights)
5. combined = attention_out + weighted_msgs
6. aggregated = W_agg × combined
7. updated = GRU(aggregated, node_msg)
8. dropped = Dropout(updated)
9. output = LayerNorm(dropped)
Testing
All components include comprehensive unit tests:
- ✓ Linear layer transformation
- ✓ Layer normalization (zero mean check)
- ✓ Multi-head attention with multiple neighbors
- ✓ GRU state updates
- ✓ RuvectorLayer with neighbors
- ✓ RuvectorLayer without neighbors (edge case)
Test Results: All 6 layer tests passing
Integration
The layer integrates with existing ruvector-gnn components:
- Used in
search.rsfor hierarchical forward passes - Compatible with HNSW topology from
ruvector-core - Supports differentiable search operations
Dependencies
- ndarray: Matrix operations and linear algebra
- rand/rand_distr: Weight initialization
- serde: Serialization support
Performance Considerations
- Xavier Initialization: Helps gradient flow during training
- Batch Operations: Uses ndarray for efficient matrix ops
- Attention Caching: Could be added for repeated queries
- Edge Weight Normalization: Ensures stable aggregation
Future Enhancements
- Actual dropout sampling (current: deterministic scaling)
- Gradient computation for training
- Batch processing support
- GPU acceleration via specialized backends
- Additional aggregation schemes (mean, max, sum)
Status: ✅ Implemented and tested successfully Build: ✅ Compiles without errors (warnings: documentation only) Tests: ✅ 26/26 tests passing