Files
wifi-densepose/vendor/ruvector/docs/gnn/gnn-layer-implementation.md

5.1 KiB
Raw Permalink Blame History

Ruvector GNN Layer Implementation

Overview

Implemented a complete Graph Neural Network (GNN) layer for Ruvector that operates on HNSW topology, providing message passing, attention mechanisms, and recurrent state updates.

Location

Implementation: /home/user/ruvector/crates/ruvector-gnn/src/layer.rs

Components Implemented

1. Linear Layer

  • Purpose: Weight matrix multiplication for transformations
  • Initialization: Xavier/Glorot initialization for stable gradients
  • API:
    Linear::new(input_dim: usize, output_dim: usize) -> Self
    forward(&self, input: &[f32]) -> Vec<f32>
    

2. Layer Normalization

  • Purpose: Normalize activations for stable training
  • Features: Learnable scale (gamma) and shift (beta) parameters
  • API:
    LayerNorm::new(dim: usize, eps: f32) -> Self
    forward(&self, input: &[f32]) -> Vec<f32>
    

3. Multi-Head Attention

  • Purpose: Attention-based neighbor aggregation
  • Features:
    • Separate Q, K, V projections
    • Scaled dot-product attention
    • Multi-head parallelization
  • API:
    MultiHeadAttention::new(embed_dim: usize, num_heads: usize) -> Self
    forward(&self, query: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>]) -> Vec<f32>
    

4. GRU Cell (Gated Recurrent Unit)

  • Purpose: State updates with gating mechanisms
  • Features:
    • Update gate: Controls how much of new information to accept
    • Reset gate: Controls how much of past information to forget
    • Candidate state: Proposes new hidden state
  • API:
    GRUCell::new(input_dim: usize, hidden_dim: usize) -> Self
    forward(&self, input: &[f32], hidden: &[f32]) -> Vec<f32>
    

5. RuvectorLayer (Main GNN Layer)

  • Purpose: Complete GNN layer combining all components
  • Architecture:
    1. Message passing through linear transformations
    2. Attention-based neighbor aggregation
    3. Weighted message aggregation using edge weights
    4. GRU-based state update
    5. Dropout regularization
    6. Layer normalization
  • API:
    RuvectorLayer::new(
        input_dim: usize,
        hidden_dim: usize, 
        heads: usize,
        dropout: f32
    ) -> Self
    
    forward(
        &self,
        node_embedding: &[f32],
        neighbor_embeddings: &[Vec<f32>],
        edge_weights: &[f32]
    ) -> Vec<f32>
    

Usage Example

use ruvector_gnn::RuvectorLayer;

// Create GNN layer: 128-dim input -> 256-dim hidden, 4 attention heads, 10% dropout
let layer = RuvectorLayer::new(128, 256, 4, 0.1);

// Node and neighbor embeddings
let node = vec![0.5; 128];
let neighbors = vec![
    vec![0.3; 128],
    vec![0.7; 128],
];
let edge_weights = vec![0.8, 0.6]; // e.g., inverse distances

// Forward pass
let updated_embedding = layer.forward(&node, &neighbors, &edge_weights);
// Output: 256-dimensional embedding

Key Features

  1. HNSW-Aware: Designed to operate on HNSW graph topology
  2. Message Passing: Transforms and aggregates neighbor information
  3. Attention Mechanism: Learns importance of different neighbors
  4. Edge Weights: Incorporates graph structure (e.g., distances)
  5. State Updates: GRU cells maintain and update node states
  6. Normalization: Layer norm for training stability
  7. Regularization: Dropout to prevent overfitting

Mathematical Operations

Forward Pass Flow:

1. node_msg = W_msg × node_embedding
2. neighbor_msgs = [W_msg × neighbor_i for all neighbors]
3. attention_out = MultiHeadAttention(node_msg, neighbor_msgs)
4. weighted_msgs = Σ(weight_i × neighbor_msg_i) / Σ(weights)
5. combined = attention_out + weighted_msgs
6. aggregated = W_agg × combined
7. updated = GRU(aggregated, node_msg)
8. dropped = Dropout(updated)
9. output = LayerNorm(dropped)

Testing

All components include comprehensive unit tests:

  • ✓ Linear layer transformation
  • ✓ Layer normalization (zero mean check)
  • ✓ Multi-head attention with multiple neighbors
  • ✓ GRU state updates
  • ✓ RuvectorLayer with neighbors
  • ✓ RuvectorLayer without neighbors (edge case)

Test Results: All 6 layer tests passing

Integration

The layer integrates with existing ruvector-gnn components:

  • Used in search.rs for hierarchical forward passes
  • Compatible with HNSW topology from ruvector-core
  • Supports differentiable search operations

Dependencies

  • ndarray: Matrix operations and linear algebra
  • rand/rand_distr: Weight initialization
  • serde: Serialization support

Performance Considerations

  1. Xavier Initialization: Helps gradient flow during training
  2. Batch Operations: Uses ndarray for efficient matrix ops
  3. Attention Caching: Could be added for repeated queries
  4. Edge Weight Normalization: Ensures stable aggregation

Future Enhancements

  1. Actual dropout sampling (current: deterministic scaling)
  2. Gradient computation for training
  3. Batch processing support
  4. GPU acceleration via specialized backends
  5. Additional aggregation schemes (mean, max, sum)

Status: Implemented and tested successfully Build: Compiles without errors (warnings: documentation only) Tests: 26/26 tests passing