Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
This commit is contained in:
280
crates/ruvector-postgres/docs/GNN_IMPLEMENTATION_SUMMARY.md
Normal file
280
crates/ruvector-postgres/docs/GNN_IMPLEMENTATION_SUMMARY.md
Normal file
@@ -0,0 +1,280 @@
|
||||
# GNN Layers Implementation Summary
|
||||
|
||||
## Overview
|
||||
|
||||
Complete implementation of Graph Neural Network (GNN) layers for the ruvector-postgres PostgreSQL extension. This module enables efficient graph learning directly on relational data.
|
||||
|
||||
## Module Structure
|
||||
|
||||
```
|
||||
src/gnn/
|
||||
├── mod.rs # Module exports and organization
|
||||
├── message_passing.rs # Core message passing framework
|
||||
├── aggregators.rs # Neighbor message aggregation functions
|
||||
├── gcn.rs # Graph Convolutional Network layer
|
||||
├── graphsage.rs # GraphSAGE with neighbor sampling
|
||||
└── operators.rs # PostgreSQL operator functions
|
||||
```
|
||||
|
||||
## Core Components
|
||||
|
||||
### 1. Message Passing Framework (`message_passing.rs`)
|
||||
|
||||
**MessagePassing Trait**:
|
||||
- `message()` - Compute messages from neighbors
|
||||
- `aggregate()` - Combine messages from all neighbors
|
||||
- `update()` - Update node representations
|
||||
|
||||
**Key Functions**:
|
||||
- `build_adjacency_list(edge_index, num_nodes)` - Build graph adjacency structure
|
||||
- `propagate(node_features, edge_index, layer)` - Standard message passing
|
||||
- `propagate_weighted(...)` - Weighted message passing with edge weights
|
||||
|
||||
**Features**:
|
||||
- Parallel node processing with Rayon
|
||||
- Support for disconnected nodes
|
||||
- Edge weight handling
|
||||
- Efficient adjacency list representation
|
||||
|
||||
### 2. Aggregation Functions (`aggregators.rs`)
|
||||
|
||||
**AggregationMethod Enum**:
|
||||
- `Sum` - Sum all neighbor messages
|
||||
- `Mean` - Average all neighbor messages
|
||||
- `Max` - Element-wise maximum of messages
|
||||
|
||||
**Functions**:
|
||||
- `sum_aggregate(messages)` - Sum aggregation
|
||||
- `mean_aggregate(messages)` - Mean aggregation
|
||||
- `max_aggregate(messages)` - Max aggregation
|
||||
- `weighted_aggregate(messages, weights, method)` - Weighted aggregation
|
||||
|
||||
**Performance**:
|
||||
- Parallel aggregation using Rayon
|
||||
- Zero-copy operations where possible
|
||||
- Efficient memory layout
|
||||
|
||||
### 3. Graph Convolutional Network (`gcn.rs`)
|
||||
|
||||
**GCNLayer Structure**:
|
||||
```rust
|
||||
pub struct GCNLayer {
|
||||
pub in_features: usize,
|
||||
pub out_features: usize,
|
||||
pub weights: Vec<Vec<f32>>,
|
||||
pub bias: Option<Vec<f32>>,
|
||||
pub normalize: bool,
|
||||
}
|
||||
```
|
||||
|
||||
**Key Methods**:
|
||||
- `new(in_features, out_features)` - Create layer with Xavier initialization
|
||||
- `linear_transform(features)` - Apply weight matrix
|
||||
- `forward(x, edge_index, edge_weights)` - Full forward pass with ReLU
|
||||
- `compute_norm_factor(degree)` - Degree normalization
|
||||
|
||||
**Features**:
|
||||
- Degree normalization for stable gradients
|
||||
- Optional bias terms
|
||||
- ReLU activation
|
||||
- Edge weight support
|
||||
|
||||
### 4. GraphSAGE Layer (`graphsage.rs`)
|
||||
|
||||
**GraphSAGELayer Structure**:
|
||||
```rust
|
||||
pub struct GraphSAGELayer {
|
||||
pub in_features: usize,
|
||||
pub out_features: usize,
|
||||
pub neighbor_weights: Vec<Vec<f32>>,
|
||||
pub self_weights: Vec<Vec<f32>>,
|
||||
pub aggregator: SAGEAggregator,
|
||||
pub num_samples: usize,
|
||||
pub normalize: bool,
|
||||
}
|
||||
```
|
||||
|
||||
**SAGEAggregator Types**:
|
||||
- `Mean` - Mean aggregator
|
||||
- `MaxPool` - Max pooling aggregator
|
||||
- `LSTM` - LSTM aggregator (simplified)
|
||||
|
||||
**Key Methods**:
|
||||
- `sample_neighbors(neighbors, k)` - Uniform neighbor sampling
|
||||
- `forward_with_sampling(x, edge_index, num_samples)` - Forward with sampling
|
||||
- `forward(x, edge_index)` - Standard forward pass
|
||||
|
||||
**Features**:
|
||||
- Neighbor sampling for scalability
|
||||
- Separate weight matrices for neighbors and self
|
||||
- L2 normalization of outputs
|
||||
- Multiple aggregator types
|
||||
|
||||
### 5. PostgreSQL Operators (`operators.rs`)
|
||||
|
||||
**SQL Functions**:
|
||||
|
||||
1. **`ruvector_gcn_forward(embeddings, src, dst, weights, out_dim)`**
|
||||
- Apply GCN layer to node embeddings
|
||||
- Returns: Updated embeddings after GCN
|
||||
|
||||
2. **`ruvector_gnn_aggregate(messages, method)`**
|
||||
- Aggregate neighbor messages
|
||||
- Methods: 'sum', 'mean', 'max'
|
||||
- Returns: Aggregated message vector
|
||||
|
||||
3. **`ruvector_message_pass(node_table, edge_table, embedding_col, hops, layer_type)`**
|
||||
- Multi-hop message passing
|
||||
- Layer types: 'gcn', 'sage'
|
||||
- Returns: Query description
|
||||
|
||||
4. **`ruvector_graphsage_forward(embeddings, src, dst, out_dim, num_samples)`**
|
||||
- Apply GraphSAGE with neighbor sampling
|
||||
- Returns: Updated embeddings after GraphSAGE
|
||||
|
||||
5. **`ruvector_gnn_batch_forward(embeddings_batch, edge_indices, graph_sizes, layer_type, out_dim)`**
|
||||
- Batch processing for multiple graphs
|
||||
- Supports 'gcn' and 'sage' layers
|
||||
- Returns: Batch of updated embeddings
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Basic GCN Example
|
||||
|
||||
```sql
|
||||
-- Apply GCN forward pass
|
||||
SELECT ruvector_gcn_forward(
|
||||
ARRAY[ARRAY[1.0, 2.0], ARRAY[3.0, 4.0], ARRAY[5.0, 6.0]]::FLOAT[][], -- embeddings
|
||||
ARRAY[0, 1, 2]::INT[], -- source nodes
|
||||
ARRAY[1, 2, 0]::INT[], -- target nodes
|
||||
NULL, -- edge weights
|
||||
8 -- output dimension
|
||||
);
|
||||
```
|
||||
|
||||
### Aggregation Example
|
||||
|
||||
```sql
|
||||
-- Aggregate neighbor messages using mean
|
||||
SELECT ruvector_gnn_aggregate(
|
||||
ARRAY[ARRAY[1.0, 2.0], ARRAY[3.0, 4.0]]::FLOAT[][],
|
||||
'mean'
|
||||
);
|
||||
-- Returns: [2.0, 3.0]
|
||||
```
|
||||
|
||||
### GraphSAGE Example
|
||||
|
||||
```sql
|
||||
-- Apply GraphSAGE with neighbor sampling
|
||||
SELECT ruvector_graphsage_forward(
|
||||
node_embeddings,
|
||||
edge_sources,
|
||||
edge_targets,
|
||||
64, -- output dimension
|
||||
10 -- sample 10 neighbors per node
|
||||
)
|
||||
FROM graph_data;
|
||||
```
|
||||
|
||||
## Performance Characteristics
|
||||
|
||||
### Parallelization
|
||||
- **Node-level parallelism**: All nodes processed in parallel using Rayon
|
||||
- **Aggregation parallelism**: Vector operations parallelized
|
||||
- **Batch processing**: Multiple graphs processed independently
|
||||
|
||||
### Memory Efficiency
|
||||
- **Adjacency lists**: HashMap-based for sparse graphs
|
||||
- **Zero-copy**: Minimal data copying during aggregation
|
||||
- **Streaming**: Process nodes without materializing full graph
|
||||
|
||||
### Scalability
|
||||
- **GraphSAGE sampling**: O(k) neighbors instead of O(degree)
|
||||
- **Sparse graphs**: Efficient for large, sparse graphs
|
||||
- **Batch support**: Process multiple graphs simultaneously
|
||||
|
||||
## Testing
|
||||
|
||||
### Unit Tests
|
||||
All modules include comprehensive `#[test]` tests:
|
||||
- Message passing correctness
|
||||
- Aggregation functions
|
||||
- Layer forward passes
|
||||
- Neighbor sampling
|
||||
- Edge cases (empty graphs, disconnected nodes)
|
||||
|
||||
### PostgreSQL Tests
|
||||
Extensive `#[pg_test]` tests in `operators.rs`:
|
||||
- SQL function correctness
|
||||
- Empty input handling
|
||||
- Weighted edges
|
||||
- Batch processing
|
||||
|
||||
### Test Coverage
|
||||
- ✅ Message passing framework
|
||||
- ✅ All aggregation methods
|
||||
- ✅ GCN layer operations
|
||||
- ✅ GraphSAGE with sampling
|
||||
- ✅ PostgreSQL operators
|
||||
- ✅ Edge cases and error handling
|
||||
|
||||
## Integration
|
||||
|
||||
The GNN module is integrated into the main extension via `src/lib.rs`:
|
||||
|
||||
```rust
|
||||
pub mod gnn;
|
||||
```
|
||||
|
||||
All operator functions are automatically registered with PostgreSQL via pgrx macros.
|
||||
|
||||
## Design Decisions
|
||||
|
||||
1. **Trait-Based Architecture**: MessagePassing trait enables extensibility
|
||||
2. **Parallel-First**: Rayon used throughout for parallelism
|
||||
3. **Type Safety**: Strong typing prevents runtime errors
|
||||
4. **PostgreSQL Native**: Deep integration with PostgreSQL types
|
||||
5. **Testability**: Comprehensive test coverage at all levels
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
Potential improvements:
|
||||
1. GPU acceleration via CUDA
|
||||
2. Additional GNN layers (GAT, GIN, etc.)
|
||||
3. Dynamic graph support
|
||||
4. Graph pooling operations
|
||||
5. Mini-batch training support
|
||||
6. Gradient computation for training
|
||||
|
||||
## Dependencies
|
||||
|
||||
- `pgrx` - PostgreSQL extension framework
|
||||
- `rayon` - Data parallelism
|
||||
- `rand` - Random neighbor sampling
|
||||
- `serde_json` - JSON serialization (for results)
|
||||
|
||||
## Files Summary
|
||||
|
||||
| File | Lines | Description |
|
||||
|------|-------|-------------|
|
||||
| `mod.rs` | ~40 | Module exports and organization |
|
||||
| `message_passing.rs` | ~250 | Core message passing framework |
|
||||
| `aggregators.rs` | ~200 | Aggregation functions |
|
||||
| `gcn.rs` | ~280 | GCN layer implementation |
|
||||
| `graphsage.rs` | ~330 | GraphSAGE layer with sampling |
|
||||
| `operators.rs` | ~400 | PostgreSQL operator functions |
|
||||
| **Total** | **~1,500** | Complete GNN implementation |
|
||||
|
||||
## References
|
||||
|
||||
1. Kipf & Welling (2016) - "Semi-Supervised Classification with Graph Convolutional Networks"
|
||||
2. Hamilton et al. (2017) - "Inductive Representation Learning on Large Graphs"
|
||||
3. PostgreSQL Extension Development Guide
|
||||
4. pgrx Documentation
|
||||
|
||||
---
|
||||
|
||||
**Implementation Status**: ✅ Complete
|
||||
|
||||
All components implemented, tested, and integrated into ruvector-postgres extension.
|
||||
Reference in New Issue
Block a user