GNN Layers Implementation Summary
Overview
Complete implementation of Graph Neural Network (GNN) layers for the ruvector-postgres PostgreSQL extension. This module enables efficient graph learning directly on relational data.
Module Structure
src/gnn/
├── mod.rs # Module exports and organization
├── message_passing.rs # Core message passing framework
├── aggregators.rs # Neighbor message aggregation functions
├── gcn.rs # Graph Convolutional Network layer
├── graphsage.rs # GraphSAGE with neighbor sampling
└── operators.rs # PostgreSQL operator functions
Core Components
1. Message Passing Framework (message_passing.rs)
MessagePassing Trait:
- message() - Compute messages from neighbors
- aggregate() - Combine messages from all neighbors
- update() - Update node representations
Key Functions:
- build_adjacency_list(edge_index, num_nodes) - Build graph adjacency structure
- propagate(node_features, edge_index, layer) - Standard message passing
- propagate_weighted(...) - Weighted message passing with edge weights
Features:
- Parallel node processing with Rayon
- Support for disconnected nodes
- Edge weight handling
- Efficient adjacency list representation
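The adjacency-list and propagation steps above can be illustrated with a minimal, single-threaded sketch. It assumes that `edge_index` is a slice of `(source, target)` pairs and that aggregation is a plain sum. `propagate_sum` is a hypothetical helper, and the actual signatures in message_passing.rs may differ (the module also parallelizes with Rayon, which is omitted here).

```rust
use std::collections::HashMap;

/// Map each target node to the list of its source neighbors.
/// ASSUMPTION: edges are (source, target) pairs; the real layout may differ.
pub fn build_adjacency_list(
    edge_index: &[(usize, usize)],
    num_nodes: usize,
) -> HashMap<usize, Vec<usize>> {
    let mut adj: HashMap<usize, Vec<usize>> = HashMap::new();
    // Give every node an entry so disconnected nodes are still represented.
    for node in 0..num_nodes {
        adj.insert(node, Vec::new());
    }
    for &(src, dst) in edge_index {
        // Ignore edges that point outside the node range.
        if src < num_nodes && dst < num_nodes {
            adj.get_mut(&dst).unwrap().push(src);
        }
    }
    adj
}

/// One round of sum-based message passing (hypothetical helper).
/// A node with no incoming edges keeps its own features.
pub fn propagate_sum(
    node_features: &[Vec<f32>],
    edge_index: &[(usize, usize)],
) -> Vec<Vec<f32>> {
    let adj = build_adjacency_list(edge_index, node_features.len());
    (0..node_features.len())
        .map(|node| {
            let neighbors = &adj[&node];
            if neighbors.is_empty() {
                return node_features[node].clone();
            }
            let mut out = vec![0.0f32; node_features[node].len()];
            for &n in neighbors {
                for (o, v) in out.iter_mut().zip(&node_features[n]) {
                    *o += *v;
                }
            }
            out
        })
        .collect::<Vec<Vec<f32>>>()
}
```

With three nodes and edges 0→1 and 2→1, node 1 receives the sum of the features of nodes 0 and 2, while nodes 0 and 2 (no incoming edges) are returned unchanged.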
2. Aggregation Functions (aggregators.rs)
AggregationMethod Enum:
- Sum - Sum all neighbor messages
- Mean - Average all neighbor messages
- Max - Element-wise maximum of messages
Functions:
- sum_aggregate(messages) - Sum aggregation
- mean_aggregate(messages) - Mean aggregation
- max_aggregate(messages) - Max aggregation
- weighted_aggregate(messages, weights, method) - Weighted aggregation
Performance:
- Parallel aggregation using Rayon
- Zero-copy operations where possible
- Efficient memory layout
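As a rough illustration of the three aggregation methods, the sketch below folds a list of equal-length message vectors into one vector. The single `aggregate` entry point is an assumption made for brevity; aggregators.rs exposes separate `sum_aggregate`, `mean_aggregate`, and `max_aggregate` functions whose exact signatures are not shown in this summary.

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AggregationMethod {
    Sum,
    Mean,
    Max,
}

/// Combine neighbor messages element-wise (hypothetical combined helper).
/// Returns an empty vector when there are no messages.
pub fn aggregate(messages: &[Vec<f32>], method: AggregationMethod) -> Vec<f32> {
    let dim = match messages.first() {
        Some(first) => first.len(),
        None => return Vec::new(),
    };
    // Max starts from negative infinity; Sum and Mean start from zero.
    let mut out = match method {
        AggregationMethod::Max => vec![f32::NEG_INFINITY; dim],
        _ => vec![0.0f32; dim],
    };
    for msg in messages {
        for (o, &v) in out.iter_mut().zip(msg.iter()) {
            match method {
                AggregationMethod::Max => *o = (*o).max(v),
                _ => *o += v,
            }
        }
    }
    if method == AggregationMethod::Mean {
        let n = messages.len() as f32;
        for o in out.iter_mut() {
            *o /= n;
        }
    }
    out
}
```

For the messages `[1.0, 2.0]` and `[3.0, 4.0]`, this yields `[4.0, 6.0]` for Sum, `[2.0, 3.0]` for Mean, and `[3.0, 4.0]` for Max.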
3. Graph Convolutional Network (gcn.rs)
GCNLayer Structure:
pub struct GCNLayer {
    pub in_features: usize,
    pub out_features: usize,
    pub weights: Vec<Vec<f32>>,
    pub bias: Option<Vec<f32>>,
    pub normalize: bool,
}
Key Methods:
- new(in_features, out_features) - Create layer with Xavier initialization
- linear_transform(features) - Apply weight matrix
- forward(x, edge_index, edge_weights) - Full forward pass with ReLU
- compute_norm_factor(degree) - Degree normalization
Features:
- Degree normalization for stable gradients
- Optional bias terms
- ReLU activation
- Edge weight support
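The forward pass can be sketched as a linear transform, a degree-normalized neighbor sum, and a ReLU. The version below is an approximation under stated assumptions, not the code in gcn.rs: it uses free functions instead of `GCNLayer` methods, stores weights as `out_features` rows of `in_features` columns, adds a self-loop to every node, uses in-degree, takes `1/sqrt(degree)` as the normalization factor in the spirit of Kipf & Welling, and omits bias and edge weights.

```rust
/// ASSUMPTION: normalization factor is 1/sqrt(degree); zero degree maps to 1.0.
pub fn compute_norm_factor(degree: usize) -> f32 {
    if degree == 0 {
        1.0
    } else {
        1.0 / (degree as f32).sqrt()
    }
}

/// Multiply each feature row by the weight matrix.
/// ASSUMPTION: `weights` has out_features rows of in_features columns.
pub fn linear_transform(features: &[Vec<f32>], weights: &[Vec<f32>]) -> Vec<Vec<f32>> {
    features
        .iter()
        .map(|row| {
            weights
                .iter()
                .map(|w| w.iter().zip(row).map(|(a, b)| a * b).sum::<f32>())
                .collect::<Vec<f32>>()
        })
        .collect()
}

/// Simplified GCN forward pass: transform, normalized aggregation, ReLU.
pub fn gcn_forward(
    x: &[Vec<f32>],
    weights: &[Vec<f32>],
    edges: &[(usize, usize)],
) -> Vec<Vec<f32>> {
    let n = x.len();
    let h = linear_transform(x, weights);
    // In-degree of each node, counting one self-loop per node.
    let mut degree = vec![1usize; n];
    for &(src, dst) in edges {
        if src < n && dst < n {
            degree[dst] += 1;
        }
    }
    // Self-loop contribution: norm_i * norm_i * h_i.
    let mut out: Vec<Vec<f32>> = (0..n)
        .map(|i| {
            let c = compute_norm_factor(degree[i]).powi(2);
            h[i].iter().map(|v| v * c).collect::<Vec<f32>>()
        })
        .collect();
    // Neighbor contributions: norm_src * norm_dst * h_src.
    for &(src, dst) in edges {
        if src >= n || dst >= n {
            continue;
        }
        let c = compute_norm_factor(degree[src]) * compute_norm_factor(degree[dst]);
        for k in 0..out[dst].len() {
            out[dst][k] += c * h[src][k];
        }
    }
    // ReLU activation.
    for row in out.iter_mut() {
        for v in row.iter_mut() {
            *v = (*v).max(0.0);
        }
    }
    out
}
```

Scaling each contribution by the degrees of both endpoints keeps the magnitude of a node's output roughly independent of how many neighbors it has, which is the motivation behind the degree-normalization feature listed above.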
4. GraphSAGE Layer (graphsage.rs)
GraphSAGELayer Structure:
pub struct GraphSAGELayer {
    pub in_features: usize,
    pub out_features: usize,
    pub neighbor_weights: Vec<Vec<f32>>,
    pub self_weights: Vec<Vec<f32>>,
    pub aggregator: SAGEAggregator,
    pub num_samples: usize,
    pub normalize: bool,
}
SAGEAggregator Types:
- Mean - Mean aggregator
- MaxPool - Max pooling aggregator
- LSTM - LSTM aggregator (simplified)
Key Methods:
- sample_neighbors(neighbors, k) - Uniform neighbor sampling
- forward_with_sampling(x, edge_index, num_samples) - Forward with sampling
- forward(x, edge_index) - Standard forward pass
Features:
- Neighbor sampling for scalability
- Separate weight matrices for neighbors and self
- L2 normalization of outputs
- Multiple aggregator types
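Neighbor sampling and output normalization could look roughly like the sketch below. To stay dependency-free it swaps the `rand` crate for a tiny xorshift generator, so `XorShift64` and the extra `rng` parameter are illustration-only assumptions; the real `sample_neighbors(neighbors, k)` takes no generator argument. Sampling is done without replacement via a partial Fisher-Yates shuffle and is only approximately uniform because of the modulo step.

```rust
/// Minimal xorshift64 generator (stand-in for `rand`; seed must be nonzero).
pub struct XorShift64(pub u64);

impl XorShift64 {
    pub fn next_u64(&mut self) -> u64 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        x
    }
}

/// Pick at most `k` distinct neighbors using a partial Fisher-Yates shuffle.
/// If the node has `k` or fewer neighbors, all of them are returned.
pub fn sample_neighbors(neighbors: &[usize], k: usize, rng: &mut XorShift64) -> Vec<usize> {
    if neighbors.len() <= k {
        return neighbors.to_vec();
    }
    let mut pool = neighbors.to_vec();
    for i in 0..k {
        // Choose a position in the not-yet-fixed tail and swap it into slot i.
        let j = i + (rng.next_u64() as usize) % (pool.len() - i);
        pool.swap(i, j);
    }
    pool.truncate(k);
    pool
}

/// Scale a vector to unit L2 norm; a zero vector is left unchanged.
pub fn l2_normalize(v: &mut [f32]) {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 {
        for x in v.iter_mut() {
            *x /= norm;
        }
    }
}
```

Capping the neighborhood at `k` bounds the per-node aggregation cost regardless of degree, which is what makes sampling attractive on graphs with a few very high-degree nodes.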
5. PostgreSQL Operators (operators.rs)
SQL Functions:
- ruvector_gcn_forward(embeddings, src, dst, weights, out_dim) - Apply GCN layer to node embeddings
  - Returns: Updated embeddings after GCN
- ruvector_gnn_aggregate(messages, method) - Aggregate neighbor messages
  - Methods: 'sum', 'mean', 'max'
  - Returns: Aggregated message vector
- ruvector_message_pass(node_table, edge_table, embedding_col, hops, layer_type) - Multi-hop message passing
  - Layer types: 'gcn', 'sage'
  - Returns: Query description
- ruvector_graphsage_forward(embeddings, src, dst, out_dim, num_samples) - Apply GraphSAGE with neighbor sampling
  - Returns: Updated embeddings after GraphSAGE
- ruvector_gnn_batch_forward(embeddings_batch, edge_indices, graph_sizes, layer_type, out_dim) - Batch processing for multiple graphs
  - Supports 'gcn' and 'sage' layers
  - Returns: Batch of updated embeddings
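One plausible reading of the `graph_sizes` argument is that `embeddings_batch` concatenates the node rows of several graphs and `graph_sizes` gives the node count of each. Under that assumption (the summary does not spell out the layout), a batch could be split into per-graph slices as sketched below; `split_batch` is a hypothetical helper, not a function from operators.rs.

```rust
/// Split a concatenated batch of node embeddings into per-graph slices.
/// ASSUMPTION: rows are stored graph after graph, and `graph_sizes[i]` is the
/// number of nodes in graph i. Returns None if the sizes do not add up.
pub fn split_batch<'a>(
    embeddings: &'a [Vec<f32>],
    graph_sizes: &[usize],
) -> Option<Vec<&'a [Vec<f32>]>> {
    if graph_sizes.iter().sum::<usize>() != embeddings.len() {
        return None;
    }
    let mut out = Vec::with_capacity(graph_sizes.len());
    let mut start = 0;
    for &size in graph_sizes {
        out.push(&embeddings[start..start + size]);
        start += size;
    }
    Some(out)
}
```

Each returned slice can then be passed through a layer independently, which matches the "multiple graphs processed independently" behavior described under Performance Characteristics.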
Usage Examples
Basic GCN Example
-- Apply GCN forward pass
SELECT ruvector_gcn_forward(
    ARRAY[ARRAY[1.0, 2.0], ARRAY[3.0, 4.0], ARRAY[5.0, 6.0]]::FLOAT[][], -- embeddings
    ARRAY[0, 1, 2]::INT[], -- source nodes
    ARRAY[1, 2, 0]::INT[], -- target nodes
    NULL, -- edge weights
    8 -- output dimension
);
Aggregation Example
-- Aggregate neighbor messages using mean
SELECT ruvector_gnn_aggregate(
    ARRAY[ARRAY[1.0, 2.0], ARRAY[3.0, 4.0]]::FLOAT[][],
    'mean'
);
-- Returns: [2.0, 3.0]
GraphSAGE Example
-- Apply GraphSAGE with neighbor sampling
SELECT ruvector_graphsage_forward(
    node_embeddings,
    edge_sources,
    edge_targets,
    64, -- output dimension
    10 -- sample 10 neighbors per node
)
FROM graph_data;
Performance Characteristics
Parallelization
- Node-level parallelism: All nodes processed in parallel using Rayon
- Aggregation parallelism: Vector operations parallelized
- Batch processing: Multiple graphs processed independently
Memory Efficiency
- Adjacency lists: HashMap-based for sparse graphs
- Zero-copy: Minimal data copying during aggregation
- Streaming: Process nodes without materializing full graph
Scalability
- GraphSAGE sampling: O(k) neighbors instead of O(degree)
- Sparse graphs: Efficient for large, sparse graphs
- Batch support: Process multiple graphs simultaneously
Testing
Unit Tests
Every module includes #[test] unit tests covering:
- Message passing correctness
- Aggregation functions
- Layer forward passes
- Neighbor sampling
- Edge cases (empty graphs, disconnected nodes)
PostgreSQL Tests
operators.rs includes #[pg_test] tests covering:
- SQL function correctness
- Empty input handling
- Weighted edges
- Batch processing
Test Coverage
- ✅ Message passing framework
- ✅ All aggregation methods
- ✅ GCN layer operations
- ✅ GraphSAGE with sampling
- ✅ PostgreSQL operators
- ✅ Edge cases and error handling
Integration
The GNN module is integrated into the main extension via src/lib.rs:
pub mod gnn;
All operator functions are automatically registered with PostgreSQL via pgrx macros.
Design Decisions
- Trait-Based Architecture: MessagePassing trait enables extensibility
- Parallel-First: Rayon used throughout for parallelism
- Type Safety: Strong typing catches many errors at compile time rather than at runtime
- PostgreSQL Native: Deep integration with PostgreSQL types
- Testability: Comprehensive test coverage at all levels
Future Enhancements
Potential improvements:
- GPU acceleration via CUDA
- Additional GNN layers (GAT, GIN, etc.)
- Dynamic graph support
- Graph pooling operations
- Mini-batch training support
- Gradient computation for training
Dependencies
- pgrx - PostgreSQL extension framework
- rayon - Data parallelism
- rand - Random neighbor sampling
- serde_json - JSON serialization (for results)
Files Summary
| File | Lines | Description |
|---|---|---|
| mod.rs | ~40 | Module exports and organization |
| message_passing.rs | ~250 | Core message passing framework |
| aggregators.rs | ~200 | Aggregation functions |
| gcn.rs | ~280 | GCN layer implementation |
| graphsage.rs | ~330 | GraphSAGE layer with sampling |
| operators.rs | ~400 | PostgreSQL operator functions |
| Total | ~1,500 | Complete GNN implementation |
References
- Kipf & Welling (2016) - "Semi-Supervised Classification with Graph Convolutional Networks"
- Hamilton et al. (2017) - "Inductive Representation Learning on Large Graphs"
- PostgreSQL Extension Development Guide
- pgrx Documentation
Implementation Status: ✅ Complete
All components are implemented, tested, and integrated into the ruvector-postgres extension.