Files
wifi-densepose/vendor/ruvector/docs/gnn/gnn-layer-implementation.md

172 lines
5.1 KiB
Markdown
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# Ruvector GNN Layer Implementation
## Overview
Implemented a complete Graph Neural Network (GNN) layer for Ruvector that operates on HNSW topology, providing message passing, attention mechanisms, and recurrent state updates.
## Location
**Implementation:** `/home/user/ruvector/crates/ruvector-gnn/src/layer.rs`
## Components Implemented
### 1. Linear Layer
- **Purpose:** Weight matrix multiplication for transformations
- **Initialization:** Xavier/Glorot initialization for stable gradients
- **API:**
```rust
Linear::new(input_dim: usize, output_dim: usize) -> Self
forward(&self, input: &[f32]) -> Vec<f32>
```
### 2. Layer Normalization
- **Purpose:** Normalize activations for stable training
- **Features:** Learnable scale (gamma) and shift (beta) parameters
- **API:**
```rust
LayerNorm::new(dim: usize, eps: f32) -> Self
forward(&self, input: &[f32]) -> Vec<f32>
```
### 3. Multi-Head Attention
- **Purpose:** Attention-based neighbor aggregation
- **Features:**
- Separate Q, K, V projections
- Scaled dot-product attention
- Multi-head parallelization
- **API:**
```rust
MultiHeadAttention::new(embed_dim: usize, num_heads: usize) -> Self
forward(&self, query: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>]) -> Vec<f32>
```
### 4. GRU Cell (Gated Recurrent Unit)
- **Purpose:** State updates with gating mechanisms
- **Features:**
- Update gate: Controls how much of new information to accept
- Reset gate: Controls how much of past information to forget
- Candidate state: Proposes new hidden state
- **API:**
```rust
GRUCell::new(input_dim: usize, hidden_dim: usize) -> Self
forward(&self, input: &[f32], hidden: &[f32]) -> Vec<f32>
```
### 5. RuvectorLayer (Main GNN Layer)
- **Purpose:** Complete GNN layer combining all components
- **Architecture:**
1. Message passing through linear transformations
2. Attention-based neighbor aggregation
3. Weighted message aggregation using edge weights
4. GRU-based state update
5. Dropout regularization
6. Layer normalization
- **API:**
```rust
RuvectorLayer::new(
input_dim: usize,
hidden_dim: usize,
heads: usize,
dropout: f32
) -> Self
forward(
&self,
node_embedding: &[f32],
neighbor_embeddings: &[Vec<f32>],
edge_weights: &[f32]
) -> Vec<f32>
```
## Usage Example
```rust
use ruvector_gnn::RuvectorLayer;
// Create GNN layer: 128-dim input -> 256-dim hidden, 4 attention heads, 10% dropout
let layer = RuvectorLayer::new(128, 256, 4, 0.1);
// Node and neighbor embeddings
let node = vec![0.5; 128];
let neighbors = vec![
vec![0.3; 128],
vec![0.7; 128],
];
let edge_weights = vec![0.8, 0.6]; // e.g., inverse distances
// Forward pass
let updated_embedding = layer.forward(&node, &neighbors, &edge_weights);
// Output: 256-dimensional embedding
```
## Key Features
1. **HNSW-Aware:** Designed to operate on HNSW graph topology
2. **Message Passing:** Transforms and aggregates neighbor information
3. **Attention Mechanism:** Learns importance of different neighbors
4. **Edge Weights:** Incorporates graph structure (e.g., distances)
5. **State Updates:** GRU cells maintain and update node states
6. **Normalization:** Layer norm for training stability
7. **Regularization:** Dropout to prevent overfitting
## Mathematical Operations
### Forward Pass Flow:
```
1. node_msg = W_msg × node_embedding
2. neighbor_msgs = [W_msg × neighbor_i for all neighbors]
3. attention_out = MultiHeadAttention(node_msg, neighbor_msgs)
4. weighted_msgs = Σ(weight_i × neighbor_msg_i) / Σ(weights)
5. combined = attention_out + weighted_msgs
6. aggregated = W_agg × combined
7. updated = GRU(aggregated, node_msg)
8. dropped = Dropout(updated)
9. output = LayerNorm(dropped)
```
## Testing
All components include comprehensive unit tests:
- ✓ Linear layer transformation
- ✓ Layer normalization (zero mean check)
- ✓ Multi-head attention with multiple neighbors
- ✓ GRU state updates
- ✓ RuvectorLayer with neighbors
- ✓ RuvectorLayer without neighbors (edge case)
**Test Results:** All 6 layer tests passing
## Integration
The layer integrates with existing ruvector-gnn components:
- Used in `search.rs` for hierarchical forward passes
- Compatible with HNSW topology from `ruvector-core`
- Supports differentiable search operations
## Dependencies
- **ndarray:** Matrix operations and linear algebra
- **rand/rand_distr:** Weight initialization
- **serde:** Serialization support
## Performance Considerations
1. **Xavier Initialization:** Helps gradient flow during training
2. **Batch Operations:** Uses ndarray for efficient matrix ops
3. **Attention Caching:** Could be added for repeated queries
4. **Edge Weight Normalization:** Ensures stable aggregation
## Future Enhancements
1. Actual dropout sampling (current: deterministic scaling)
2. Gradient computation for training
3. Batch processing support
4. GPU acceleration via specialized backends
5. Additional aggregation schemes (mean, max, sum)
---
**Status:** ✅ Implemented and tested successfully
**Build:** ✅ Compiles without errors (warnings: documentation only)
**Tests:** ✅ 26/26 tests passing