git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
281 lines
8.2 KiB
Markdown
281 lines
8.2 KiB
Markdown
# GNN Layers Implementation Summary
|
|
|
|
## Overview
|
|
|
|
Complete implementation of Graph Neural Network (GNN) layers for the ruvector-postgres PostgreSQL extension. This module enables efficient graph learning directly on relational data.
|
|
|
|
## Module Structure
|
|
|
|
```
|
|
src/gnn/
|
|
├── mod.rs # Module exports and organization
|
|
├── message_passing.rs # Core message passing framework
|
|
├── aggregators.rs # Neighbor message aggregation functions
|
|
├── gcn.rs # Graph Convolutional Network layer
|
|
├── graphsage.rs # GraphSAGE with neighbor sampling
|
|
└── operators.rs # PostgreSQL operator functions
|
|
```
|
|
|
|
## Core Components
|
|
|
|
### 1. Message Passing Framework (`message_passing.rs`)
|
|
|
|
**MessagePassing Trait**:
|
|
- `message()` - Compute messages from neighbors
|
|
- `aggregate()` - Combine messages from all neighbors
|
|
- `update()` - Update node representations
|
|
|
|
**Key Functions**:
|
|
- `build_adjacency_list(edge_index, num_nodes)` - Build graph adjacency structure
|
|
- `propagate(node_features, edge_index, layer)` - Standard message passing
|
|
- `propagate_weighted(...)` - Weighted message passing with edge weights
|
|
|
|
**Features**:
|
|
- Parallel node processing with Rayon
|
|
- Support for disconnected nodes
|
|
- Edge weight handling
|
|
- Efficient adjacency list representation
|
|
|
|
### 2. Aggregation Functions (`aggregators.rs`)
|
|
|
|
**AggregationMethod Enum**:
|
|
- `Sum` - Sum all neighbor messages
|
|
- `Mean` - Average all neighbor messages
|
|
- `Max` - Element-wise maximum of messages
|
|
|
|
**Functions**:
|
|
- `sum_aggregate(messages)` - Sum aggregation
|
|
- `mean_aggregate(messages)` - Mean aggregation
|
|
- `max_aggregate(messages)` - Max aggregation
|
|
- `weighted_aggregate(messages, weights, method)` - Weighted aggregation
|
|
|
|
**Performance**:
|
|
- Parallel aggregation using Rayon
|
|
- Zero-copy operations where possible
|
|
- Efficient memory layout
|
|
|
|
### 3. Graph Convolutional Network (`gcn.rs`)
|
|
|
|
**GCNLayer Structure**:
|
|
```rust
|
|
pub struct GCNLayer {
|
|
pub in_features: usize,
|
|
pub out_features: usize,
|
|
pub weights: Vec<Vec<f32>>,
|
|
pub bias: Option<Vec<f32>>,
|
|
pub normalize: bool,
|
|
}
|
|
```
|
|
|
|
**Key Methods**:
|
|
- `new(in_features, out_features)` - Create layer with Xavier initialization
|
|
- `linear_transform(features)` - Apply weight matrix
|
|
- `forward(x, edge_index, edge_weights)` - Full forward pass with ReLU
|
|
- `compute_norm_factor(degree)` - Degree normalization
|
|
|
|
**Features**:
|
|
- Degree normalization for stable gradients
|
|
- Optional bias terms
|
|
- ReLU activation
|
|
- Edge weight support
|
|
|
|
### 4. GraphSAGE Layer (`graphsage.rs`)
|
|
|
|
**GraphSAGELayer Structure**:
|
|
```rust
|
|
pub struct GraphSAGELayer {
|
|
pub in_features: usize,
|
|
pub out_features: usize,
|
|
pub neighbor_weights: Vec<Vec<f32>>,
|
|
pub self_weights: Vec<Vec<f32>>,
|
|
pub aggregator: SAGEAggregator,
|
|
pub num_samples: usize,
|
|
pub normalize: bool,
|
|
}
|
|
```
|
|
|
|
**SAGEAggregator Types**:
|
|
- `Mean` - Mean aggregator
|
|
- `MaxPool` - Max pooling aggregator
|
|
- `LSTM` - LSTM aggregator (simplified)
|
|
|
|
**Key Methods**:
|
|
- `sample_neighbors(neighbors, k)` - Uniform neighbor sampling
|
|
- `forward_with_sampling(x, edge_index, num_samples)` - Forward with sampling
|
|
- `forward(x, edge_index)` - Standard forward pass
|
|
|
|
**Features**:
|
|
- Neighbor sampling for scalability
|
|
- Separate weight matrices for neighbors and self
|
|
- L2 normalization of outputs
|
|
- Multiple aggregator types
|
|
|
|
### 5. PostgreSQL Operators (`operators.rs`)
|
|
|
|
**SQL Functions**:
|
|
|
|
1. **`ruvector_gcn_forward(embeddings, src, dst, weights, out_dim)`**
|
|
- Apply GCN layer to node embeddings
|
|
- Returns: Updated embeddings after GCN
|
|
|
|
2. **`ruvector_gnn_aggregate(messages, method)`**
|
|
- Aggregate neighbor messages
|
|
- Methods: 'sum', 'mean', 'max'
|
|
- Returns: Aggregated message vector
|
|
|
|
3. **`ruvector_message_pass(node_table, edge_table, embedding_col, hops, layer_type)`**
|
|
- Multi-hop message passing
|
|
- Layer types: 'gcn', 'sage'
|
|
- Returns: Query description
|
|
|
|
4. **`ruvector_graphsage_forward(embeddings, src, dst, out_dim, num_samples)`**
|
|
- Apply GraphSAGE with neighbor sampling
|
|
- Returns: Updated embeddings after GraphSAGE
|
|
|
|
5. **`ruvector_gnn_batch_forward(embeddings_batch, edge_indices, graph_sizes, layer_type, out_dim)`**
|
|
- Batch processing for multiple graphs
|
|
- Supports 'gcn' and 'sage' layers
|
|
- Returns: Batch of updated embeddings
|
|
|
|
## Usage Examples
|
|
|
|
### Basic GCN Example
|
|
|
|
```sql
|
|
-- Apply GCN forward pass
|
|
SELECT ruvector_gcn_forward(
|
|
ARRAY[ARRAY[1.0, 2.0], ARRAY[3.0, 4.0], ARRAY[5.0, 6.0]]::FLOAT[][], -- embeddings
|
|
ARRAY[0, 1, 2]::INT[], -- source nodes
|
|
ARRAY[1, 2, 0]::INT[], -- target nodes
|
|
NULL, -- edge weights
|
|
8 -- output dimension
|
|
);
|
|
```
|
|
|
|
### Aggregation Example
|
|
|
|
```sql
|
|
-- Aggregate neighbor messages using mean
|
|
SELECT ruvector_gnn_aggregate(
|
|
ARRAY[ARRAY[1.0, 2.0], ARRAY[3.0, 4.0]]::FLOAT[][],
|
|
'mean'
|
|
);
|
|
-- Returns: [2.0, 3.0]
|
|
```
|
|
|
|
### GraphSAGE Example
|
|
|
|
```sql
|
|
-- Apply GraphSAGE with neighbor sampling
|
|
SELECT ruvector_graphsage_forward(
|
|
node_embeddings,
|
|
edge_sources,
|
|
edge_targets,
|
|
64, -- output dimension
|
|
10 -- sample 10 neighbors per node
|
|
)
|
|
FROM graph_data;
|
|
```
|
|
|
|
## Performance Characteristics
|
|
|
|
### Parallelization
|
|
- **Node-level parallelism**: All nodes processed in parallel using Rayon
|
|
- **Aggregation parallelism**: Vector operations parallelized
|
|
- **Batch processing**: Multiple graphs processed independently
|
|
|
|
### Memory Efficiency
|
|
- **Adjacency lists**: HashMap-based for sparse graphs
|
|
- **Zero-copy**: Minimal data copying during aggregation
|
|
- **Streaming**: Process nodes without materializing full graph
|
|
|
|
### Scalability
|
|
- **GraphSAGE sampling**: O(k) neighbors instead of O(degree)
|
|
- **Sparse graphs**: Efficient for large, sparse graphs
|
|
- **Batch support**: Process multiple graphs simultaneously
|
|
|
|
## Testing
|
|
|
|
### Unit Tests
|
|
All modules include comprehensive `#[test]` tests:
|
|
- Message passing correctness
|
|
- Aggregation functions
|
|
- Layer forward passes
|
|
- Neighbor sampling
|
|
- Edge cases (empty graphs, disconnected nodes)
|
|
|
|
### PostgreSQL Tests
|
|
Extensive `#[pg_test]` tests in `operators.rs`:
|
|
- SQL function correctness
|
|
- Empty input handling
|
|
- Weighted edges
|
|
- Batch processing
|
|
|
|
### Test Coverage
|
|
- ✅ Message passing framework
|
|
- ✅ All aggregation methods
|
|
- ✅ GCN layer operations
|
|
- ✅ GraphSAGE with sampling
|
|
- ✅ PostgreSQL operators
|
|
- ✅ Edge cases and error handling
|
|
|
|
## Integration
|
|
|
|
The GNN module is integrated into the main extension via `src/lib.rs`:
|
|
|
|
```rust
|
|
pub mod gnn;
|
|
```
|
|
|
|
All operator functions are automatically registered with PostgreSQL via pgrx macros.
|
|
|
|
## Design Decisions
|
|
|
|
1. **Trait-Based Architecture**: MessagePassing trait enables extensibility
|
|
2. **Parallel-First**: Rayon used throughout for parallelism
|
|
3. **Type Safety**: Strong typing prevents runtime errors
|
|
4. **PostgreSQL Native**: Deep integration with PostgreSQL types
|
|
5. **Testability**: Comprehensive test coverage at all levels
|
|
|
|
## Future Enhancements
|
|
|
|
Potential improvements:
|
|
1. GPU acceleration via CUDA
|
|
2. Additional GNN layers (GAT, GIN, etc.)
|
|
3. Dynamic graph support
|
|
4. Graph pooling operations
|
|
5. Mini-batch training support
|
|
6. Gradient computation for training
|
|
|
|
## Dependencies
|
|
|
|
- `pgrx` - PostgreSQL extension framework
|
|
- `rayon` - Data parallelism
|
|
- `rand` - Random neighbor sampling
|
|
- `serde_json` - JSON serialization (for results)
|
|
|
|
## Files Summary
|
|
|
|
| File | Lines | Description |
|
|
|------|-------|-------------|
|
|
| `mod.rs` | ~40 | Module exports and organization |
|
|
| `message_passing.rs` | ~250 | Core message passing framework |
|
|
| `aggregators.rs` | ~200 | Aggregation functions |
|
|
| `gcn.rs` | ~280 | GCN layer implementation |
|
|
| `graphsage.rs` | ~330 | GraphSAGE layer with sampling |
|
|
| `operators.rs` | ~400 | PostgreSQL operator functions |
|
|
| **Total** | **~1,500** | Complete GNN implementation |
|
|
|
|
## References
|
|
|
|
1. Kipf & Welling (2016) - "Semi-Supervised Classification with Graph Convolutional Networks"
|
|
2. Hamilton et al. (2017) - "Inductive Representation Learning on Large Graphs"
|
|
3. PostgreSQL Extension Development Guide
|
|
4. pgrx Documentation
|
|
|
|
---
|
|
|
|
**Implementation Status**: ✅ Complete
|
|
|
|
All components implemented, tested, and integrated into ruvector-postgres extension.
|