Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
198
vendor/ruvector/crates/ruvector-postgres/src/gnn/aggregators.rs
vendored
Normal file
198
vendor/ruvector/crates/ruvector-postgres/src/gnn/aggregators.rs
vendored
Normal file
@@ -0,0 +1,198 @@
|
||||
//! Aggregation functions for combining neighbor messages in GNNs
|
||||
|
||||
use rayon::prelude::*;
|
||||
|
||||
/// Aggregation methods for combining neighbor messages
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregationMethod {
    /// Sum all neighbor messages
    Sum,
    /// Average all neighbor messages
    Mean,
    /// Take maximum of neighbor messages (element-wise)
    Max,
}

impl AggregationMethod {
    /// Parse an aggregation method from a case-insensitive name.
    ///
    /// Recognizes "sum", "mean" (alias "avg"), and "max"; any other
    /// input yields `None`.
    pub fn from_str(s: &str) -> Option<Self> {
        let normalized = s.to_lowercase();
        match normalized.as_str() {
            "sum" => Some(Self::Sum),
            "mean" | "avg" => Some(Self::Mean),
            "max" => Some(Self::Max),
            _ => None,
        }
    }
}
|
||||
|
||||
/// Sum aggregation: element-wise sum of all neighbor messages.
///
/// # Arguments
/// * `messages` - Vector of messages from neighbors
///
/// # Returns
/// Element-wise sum of all messages; an empty vector when `messages`
/// is empty. The output length equals the length of the first message.
pub fn sum_aggregate(messages: Vec<Vec<f32>>) -> Vec<f32> {
    let mut remaining = messages.into_iter();

    // Seed the accumulator with the first message; no first message
    // means nothing to sum.
    let Some(mut acc) = remaining.next() else {
        return vec![];
    };

    for message in remaining {
        for (i, &val) in message.iter().enumerate() {
            acc[i] += val;
        }
    }

    acc
}
|
||||
|
||||
/// Mean aggregation: average all neighbor messages
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `messages` - Vector of messages from neighbors
|
||||
///
|
||||
/// # Returns
|
||||
/// Mean of all messages
|
||||
pub fn mean_aggregate(messages: Vec<Vec<f32>>) -> Vec<f32> {
|
||||
if messages.is_empty() {
|
||||
return vec![];
|
||||
}
|
||||
|
||||
let count = messages.len() as f32;
|
||||
let sum = sum_aggregate(messages);
|
||||
|
||||
sum.into_par_iter().map(|x| x / count).collect()
|
||||
}
|
||||
|
||||
/// Max aggregation: element-wise maximum of all neighbor messages
///
/// # Arguments
/// * `messages` - Vector of messages from neighbors
///
/// # Returns
/// Element-wise maximum of all messages; empty vector for empty input.
/// The output length equals the length of the first message.
pub fn max_aggregate(messages: Vec<Vec<f32>>) -> Vec<f32> {
    if messages.is_empty() {
        return Vec::new();
    }

    // Start from negative infinity so every finite value wins the first
    // comparison.
    let width = messages[0].len();
    let mut maxima = vec![f32::NEG_INFINITY; width];

    for message in &messages {
        for (i, &val) in message.iter().enumerate() {
            maxima[i] = maxima[i].max(val);
        }
    }

    maxima
}
|
||||
|
||||
/// Generic aggregation function that selects the appropriate aggregator
///
/// Dispatches `messages` to `sum_aggregate`, `mean_aggregate`, or
/// `max_aggregate` based on `method`. All three return an empty vector
/// for empty input.
pub fn aggregate(messages: Vec<Vec<f32>>, method: AggregationMethod) -> Vec<f32> {
    match method {
        AggregationMethod::Sum => sum_aggregate(messages),
        AggregationMethod::Mean => mean_aggregate(messages),
        AggregationMethod::Max => max_aggregate(messages),
    }
}
|
||||
|
||||
/// Weighted aggregation - multiply each message by its weight before aggregating
|
||||
pub fn weighted_aggregate(
|
||||
messages: Vec<Vec<f32>>,
|
||||
weights: &[f32],
|
||||
method: AggregationMethod,
|
||||
) -> Vec<f32> {
|
||||
if messages.is_empty() {
|
||||
return vec![];
|
||||
}
|
||||
|
||||
// Apply weights to messages
|
||||
let weighted_messages: Vec<Vec<f32>> = messages
|
||||
.into_par_iter()
|
||||
.enumerate()
|
||||
.map(|(idx, msg)| {
|
||||
let weight = if idx < weights.len() {
|
||||
weights[idx]
|
||||
} else {
|
||||
1.0
|
||||
};
|
||||
msg.iter().map(|&x| x * weight).collect()
|
||||
})
|
||||
.collect();
|
||||
|
||||
aggregate(weighted_messages, method)
|
||||
}
|
||||
|
||||
// Unit tests for the aggregation helpers. All fixtures are small 2-element
// vectors so expected values can be stated exactly.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sum_aggregate() {
        let messages = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];

        let result = sum_aggregate(messages);

        assert_eq!(result, vec![9.0, 12.0]);
    }

    #[test]
    fn test_mean_aggregate() {
        let messages = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];

        let result = mean_aggregate(messages);

        assert_eq!(result, vec![3.0, 4.0]);
    }

    #[test]
    fn test_max_aggregate() {
        let messages = vec![vec![1.0, 6.0], vec![5.0, 2.0], vec![3.0, 4.0]];

        let result = max_aggregate(messages);

        assert_eq!(result, vec![5.0, 6.0]);
    }

    // All aggregators share the convention: empty input -> empty output.
    #[test]
    fn test_empty_messages() {
        let messages: Vec<Vec<f32>> = vec![];
        let empty: Vec<f32> = vec![];

        assert_eq!(sum_aggregate(messages.clone()), empty);
        assert_eq!(mean_aggregate(messages.clone()), empty.clone());
        assert_eq!(max_aggregate(messages), empty);
    }

    #[test]
    fn test_weighted_aggregate() {
        let messages = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let weights = vec![2.0, 0.5];

        let result = weighted_aggregate(messages, &weights, AggregationMethod::Sum);

        // [1*2, 2*2] + [3*0.5, 4*0.5] = [2, 4] + [1.5, 2] = [3.5, 6]
        assert_eq!(result, vec![3.5, 6.0]);
    }

    // Parsing is case-insensitive and accepts "avg" as an alias for Mean.
    #[test]
    fn test_aggregation_method_from_str() {
        assert_eq!(
            AggregationMethod::from_str("sum"),
            Some(AggregationMethod::Sum)
        );
        assert_eq!(
            AggregationMethod::from_str("mean"),
            Some(AggregationMethod::Mean)
        );
        assert_eq!(
            AggregationMethod::from_str("max"),
            Some(AggregationMethod::Max)
        );
        assert_eq!(AggregationMethod::from_str("invalid"), None);
    }
}
|
||||
223
vendor/ruvector/crates/ruvector-postgres/src/gnn/gcn.rs
vendored
Normal file
223
vendor/ruvector/crates/ruvector-postgres/src/gnn/gcn.rs
vendored
Normal file
@@ -0,0 +1,223 @@
|
||||
//! Graph Convolutional Network (GCN) layer implementation
|
||||
//!
|
||||
//! Based on "Semi-Supervised Classification with Graph Convolutional Networks"
|
||||
//! by Kipf & Welling (2016)
|
||||
|
||||
use super::aggregators::sum_aggregate;
|
||||
use super::message_passing::MessagePassing;
|
||||
use rayon::prelude::*;
|
||||
|
||||
/// Graph Convolutional Network layer
///
/// Holds a dense weight matrix plus an optional bias, applied after
/// neighborhood aggregation (see the `MessagePassing` impl below).
#[derive(Debug, Clone)]
pub struct GCNLayer {
    /// Input feature dimension
    pub in_features: usize,
    /// Output feature dimension
    pub out_features: usize,
    /// Weight matrix [in_features x out_features]
    pub weights: Vec<Vec<f32>>,
    /// Bias term; constructors in this file always set `Some(zeros)`
    pub bias: Option<Vec<f32>>,
    /// Whether to normalize aggregated messages by degree
    pub normalize: bool,
}
|
||||
|
||||
impl GCNLayer {
    /// Create a new GCN layer with random weights
    ///
    /// Degree normalization is enabled by default.
    pub fn new(in_features: usize, out_features: usize) -> Self {
        Self::new_with_normalize(in_features, out_features, true)
    }

    /// Create a new GCN layer with normalization option
    ///
    /// Weights use a Xavier/Glorot scale but a deterministic (non-random)
    /// fill pattern, so two layers with the same dimensions are identical.
    pub fn new_with_normalize(in_features: usize, out_features: usize, normalize: bool) -> Self {
        // Initialize weights with Xavier/Glorot initialization
        let scale = (2.0 / (in_features + out_features) as f32).sqrt();
        let weights = (0..in_features)
            .map(|i| {
                (0..out_features)
                    .map(|j| {
                        // Simple deterministic initialization for testing
                        let val = ((i * out_features + j) as f32 * 0.01) % 1.0;
                        (val - 0.5) * scale
                    })
                    .collect()
            })
            .collect();

        Self {
            in_features,
            out_features,
            weights,
            bias: Some(vec![0.0; out_features]),
            normalize,
        }
    }

    /// Create GCN layer with provided weights
    ///
    /// # Panics
    /// Panics if `weights.len() != in_features`, if the first row's length
    /// differs from `out_features`, or if `weights` is empty (the
    /// `weights[0]` check indexes the first row).
    pub fn with_weights(in_features: usize, out_features: usize, weights: Vec<Vec<f32>>) -> Self {
        assert_eq!(weights.len(), in_features);
        assert_eq!(weights[0].len(), out_features);

        Self {
            in_features,
            out_features,
            weights,
            bias: Some(vec![0.0; out_features]),
            normalize: true,
        }
    }

    /// Apply linear transformation: features @ weights
    ///
    /// # Panics
    /// Panics if `features.len() != self.in_features`.
    pub fn linear_transform(&self, features: &[f32]) -> Vec<f32> {
        assert_eq!(features.len(), self.in_features);

        let mut result = vec![0.0; self.out_features];

        // Matrix multiplication: features @ weights
        for (i, &feature_val) in features.iter().enumerate() {
            for (j, &weight_val) in self.weights[i].iter().enumerate() {
                result[j] += feature_val * weight_val;
            }
        }

        // Add bias if present
        if let Some(ref bias) = self.bias {
            for (i, &b) in bias.iter().enumerate() {
                result[i] += b;
            }
        }

        result
    }

    /// Forward pass with edge index and optional edge weights
    ///
    /// Runs one round of message passing (weighted when `edge_weights` is
    /// provided), then applies ReLU element-wise across all nodes in
    /// parallel. Returns one output feature vector per input node.
    pub fn forward(
        &self,
        node_features: &[Vec<f32>],
        edge_index: &[(usize, usize)],
        edge_weights: Option<&[f32]>,
    ) -> Vec<Vec<f32>> {
        use super::message_passing::{propagate, propagate_weighted};

        // Apply message passing
        let result = if let Some(weights) = edge_weights {
            propagate_weighted(node_features, edge_index, weights, self)
        } else {
            propagate(node_features, edge_index, self)
        };

        // Apply ReLU activation
        result
            .into_par_iter()
            .map(|features| features.iter().map(|&x| x.max(0.0)).collect())
            .collect()
    }

    /// Compute degree normalization factor for a node
    ///
    /// 1/sqrt(degree) when normalization is on and the degree is positive;
    /// otherwise 1.0 (no scaling).
    fn compute_norm_factor(&self, degree: usize) -> f32 {
        if self.normalize && degree > 0 {
            1.0 / (degree as f32).sqrt()
        } else {
            1.0
        }
    }
}
|
||||
|
||||
impl MessagePassing for GCNLayer {
    /// Scale the source features by the edge weight (1.0 when unweighted).
    fn message(&self, source_features: &[f32], edge_weight: Option<f32>) -> Vec<f32> {
        let weight = edge_weight.unwrap_or(1.0);
        source_features.iter().map(|&x| x * weight).collect()
    }

    /// Sum incoming messages, then scale by 1/sqrt(degree) when enabled.
    ///
    /// NOTE(review): this is a per-target-node normalization, not the
    /// symmetric D^{-1/2} A D^{-1/2} of the Kipf & Welling formulation —
    /// confirm the simplification is intended.
    fn aggregate(&self, messages: Vec<Vec<f32>>) -> Vec<f32> {
        let degree = messages.len();
        let mut aggregated = sum_aggregate(messages);

        // Apply degree normalization
        if self.normalize && degree > 0 {
            let norm = self.compute_norm_factor(degree);
            aggregated.iter_mut().for_each(|x| *x *= norm);
        }

        aggregated
    }

    /// Project the aggregated neighborhood through the weight matrix; the
    /// node's own previous features are intentionally ignored here.
    fn update(&self, _node_features: &[f32], aggregated: &[f32]) -> Vec<f32> {
        // Apply linear transformation to aggregated features
        self.linear_transform(aggregated)
    }
}
|
||||
|
||||
// Unit tests for GCNLayer: construction, the linear transform, the
// forward pass shape, and the MessagePassing hooks.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gcn_layer_creation() {
        let layer = GCNLayer::new(16, 32);
        assert_eq!(layer.in_features, 16);
        assert_eq!(layer.out_features, 32);
        assert_eq!(layer.weights.len(), 16);
        assert_eq!(layer.weights[0].len(), 32);
    }

    #[test]
    fn test_linear_transform() {
        let weights = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let layer = GCNLayer::with_weights(2, 2, weights);

        let features = vec![1.0, 2.0];
        let result = layer.linear_transform(&features);

        // [1, 2] @ [[1, 2], [3, 4]] = [1*1 + 2*3, 1*2 + 2*4] = [7, 10]
        assert_eq!(result, vec![7.0, 10.0]);
    }

    // Shape-only check: forward must return one vector per node.
    #[test]
    fn test_gcn_forward() {
        let weights = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let layer = GCNLayer::with_weights(2, 2, weights);

        let node_features = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];

        let edge_index = vec![(0, 1), (1, 2), (2, 0)];

        let result = layer.forward(&node_features, &edge_index, None);

        assert_eq!(result.len(), 3);
        assert_eq!(result[0].len(), 2);
    }

    #[test]
    fn test_message_passing() {
        let layer = GCNLayer::new(2, 2);

        let features = vec![1.0, 2.0];
        let message = layer.message(&features, Some(2.0));

        assert_eq!(message, vec![2.0, 4.0]);
    }

    #[test]
    fn test_aggregation() {
        let layer = GCNLayer::new_with_normalize(2, 2, false);

        let messages = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let result = layer.aggregate(messages);

        assert_eq!(result, vec![4.0, 6.0]);
    }

    #[test]
    fn test_normalization() {
        let layer = GCNLayer::new_with_normalize(2, 2, true);

        let messages = vec![vec![4.0, 6.0], vec![0.0, 0.0]];
        let result = layer.aggregate(messages);

        // Degree = 2, norm = 1/sqrt(2) ≈ 0.707
        let expected_norm = 1.0 / (2.0_f32).sqrt();
        assert!((result[0] - 4.0 * expected_norm).abs() < 1e-5);
        assert!((result[1] - 6.0 * expected_norm).abs() < 1e-5);
    }
}
|
||||
295
vendor/ruvector/crates/ruvector-postgres/src/gnn/graphsage.rs
vendored
Normal file
295
vendor/ruvector/crates/ruvector-postgres/src/gnn/graphsage.rs
vendored
Normal file
@@ -0,0 +1,295 @@
|
||||
//! GraphSAGE layer implementation with neighbor sampling
|
||||
//!
|
||||
//! Based on "Inductive Representation Learning on Large Graphs"
|
||||
//! by Hamilton et al. (2017)
|
||||
|
||||
use super::aggregators::mean_aggregate;
|
||||
use super::message_passing::MessagePassing;
|
||||
use rand::seq::SliceRandom;
|
||||
use rand::SeedableRng;
|
||||
use rayon::prelude::*;
|
||||
|
||||
/// GraphSAGE aggregation types
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SAGEAggregator {
    /// Mean aggregator
    Mean,
    /// Max pooling aggregator
    MaxPool,
    /// LSTM aggregator
    ///
    /// NOTE(review): currently implemented as a mean fallback in this
    /// file, not a real LSTM.
    LSTM,
}
|
||||
|
||||
/// GraphSAGE layer with neighbor sampling
///
/// Learns separate transforms for a node's own features and its sampled
/// neighborhood aggregate; the two are combined in the forward pass.
#[derive(Debug, Clone)]
pub struct GraphSAGELayer {
    /// Input feature dimension
    pub in_features: usize,
    /// Output feature dimension
    pub out_features: usize,
    /// Weight matrix for neighbor features [in_features x out_features]
    pub neighbor_weights: Vec<Vec<f32>>,
    /// Weight matrix for self features [in_features x out_features]
    pub self_weights: Vec<Vec<f32>>,
    /// Aggregator type
    pub aggregator: SAGEAggregator,
    /// Default number of neighbors to sample per node
    pub num_samples: usize,
    /// Whether to L2-normalize the output features
    pub normalize: bool,
}
|
||||
|
||||
impl GraphSAGELayer {
    /// Create a new GraphSAGE layer with the mean aggregator.
    pub fn new(in_features: usize, out_features: usize, num_samples: usize) -> Self {
        Self::with_aggregator(in_features, out_features, num_samples, SAGEAggregator::Mean)
    }

    /// Create GraphSAGE layer with specific aggregator
    ///
    /// Both weight matrices use a Xavier/Glorot scale with a deterministic
    /// fill (the +1000 offset merely makes self_weights differ from
    /// neighbor_weights).
    pub fn with_aggregator(
        in_features: usize,
        out_features: usize,
        num_samples: usize,
        aggregator: SAGEAggregator,
    ) -> Self {
        // Initialize weights
        let scale = (2.0 / (in_features + out_features) as f32).sqrt();

        let neighbor_weights = (0..in_features)
            .map(|i| {
                (0..out_features)
                    .map(|j| {
                        let val = ((i * out_features + j) as f32 * 0.01) % 1.0;
                        (val - 0.5) * scale
                    })
                    .collect()
            })
            .collect();

        let self_weights = (0..in_features)
            .map(|i| {
                (0..out_features)
                    .map(|j| {
                        let val = ((i * out_features + j + 1000) as f32 * 0.01) % 1.0;
                        (val - 0.5) * scale
                    })
                    .collect()
            })
            .collect();

        Self {
            in_features,
            out_features,
            neighbor_weights,
            self_weights,
            aggregator,
            num_samples,
            normalize: true,
        }
    }

    /// Sample k neighbors uniformly at random
    ///
    /// Returns all neighbors when there are `k` or fewer.
    ///
    /// NOTE(review): the RNG is seeded with a fixed 42 on every call, so
    /// each call over the same input produces the same sample — confirm
    /// this determinism is intended outside of tests.
    pub fn sample_neighbors(&self, neighbors: &[usize], k: usize) -> Vec<usize> {
        if neighbors.len() <= k {
            return neighbors.to_vec();
        }

        // Use deterministic sampling for reproducibility in tests
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let mut sampled = neighbors.to_vec();
        sampled.partial_shuffle(&mut rng, k);
        sampled[..k].to_vec()
    }

    /// Apply linear transformation: features @ weights.
    /// Assumes `features.len() <= weights.len()`; rows are indexed by
    /// feature position.
    fn linear_transform(&self, features: &[f32], weights: &[Vec<f32>]) -> Vec<f32> {
        let mut result = vec![0.0; self.out_features];

        for (i, &feature_val) in features.iter().enumerate() {
            for (j, &weight_val) in weights[i].iter().enumerate() {
                result[j] += feature_val * weight_val;
            }
        }

        result
    }

    /// Forward pass with neighbor sampling
    ///
    /// Per node (in parallel): sample up to `k` in-neighbors, aggregate
    /// their features, transform the aggregate and the node's own features
    /// through their respective weight matrices, sum + ReLU, and finally
    /// L2-normalize when enabled. `num_samples = None` uses the layer
    /// default.
    pub fn forward_with_sampling(
        &self,
        node_features: &[Vec<f32>],
        edge_index: &[(usize, usize)],
        num_samples: Option<usize>,
    ) -> Vec<Vec<f32>> {
        use super::message_passing::build_adjacency_list;

        let num_nodes = node_features.len();
        let k = num_samples.unwrap_or(self.num_samples);
        let adj_list = build_adjacency_list(edge_index, num_nodes);

        (0..num_nodes)
            .into_par_iter()
            .map(|node_id| {
                // Safe: build_adjacency_list inserts an entry for every
                // node id in 0..num_nodes.
                let neighbors = adj_list.get(&node_id).unwrap();

                // Sample neighbors
                let sampled = self.sample_neighbors(neighbors, k);

                // Collect neighbor features
                let neighbor_features: Vec<Vec<f32>> = sampled
                    .iter()
                    .filter_map(|&neighbor_id| {
                        if neighbor_id < num_nodes {
                            Some(node_features[neighbor_id].clone())
                        } else {
                            None
                        }
                    })
                    .collect();

                // Aggregate neighbor features (zeros for isolated nodes)
                let aggregated = if neighbor_features.is_empty() {
                    vec![0.0; self.in_features]
                } else {
                    match self.aggregator {
                        SAGEAggregator::Mean => mean_aggregate(neighbor_features),
                        SAGEAggregator::MaxPool => {
                            super::aggregators::max_aggregate(neighbor_features)
                        }
                        SAGEAggregator::LSTM => mean_aggregate(neighbor_features), // Simplified
                    }
                };

                // Transform neighbor aggregation
                let neighbor_h = self.linear_transform(&aggregated, &self.neighbor_weights);

                // Transform self features
                let self_h = self.linear_transform(&node_features[node_id], &self.self_weights);

                // Combine neighbor and self transforms (element-wise sum —
                // not the paper's concatenation) and apply ReLU
                let mut combined: Vec<f32> = neighbor_h
                    .iter()
                    .zip(self_h.iter())
                    .map(|(&n, &s)| (n + s).max(0.0))
                    .collect();

                // L2 normalization if enabled
                if self.normalize {
                    let norm: f32 = combined.iter().map(|&x| x * x).sum::<f32>().sqrt();
                    if norm > 0.0 {
                        combined.iter_mut().for_each(|x| *x /= norm);
                    }
                }

                combined
            })
            .collect()
    }

    /// Standard forward pass (uses default num_samples)
    pub fn forward(
        &self,
        node_features: &[Vec<f32>],
        edge_index: &[(usize, usize)],
    ) -> Vec<Vec<f32>> {
        self.forward_with_sampling(node_features, edge_index, None)
    }
}
|
||||
|
||||
impl MessagePassing for GraphSAGELayer {
    /// GraphSAGE messages are the raw source features; edge weights are
    /// ignored by this layer.
    fn message(&self, source_features: &[f32], _edge_weight: Option<f32>) -> Vec<f32> {
        source_features.to_vec()
    }

    /// Dispatch to the configured aggregator; LSTM falls back to mean.
    fn aggregate(&self, messages: Vec<Vec<f32>>) -> Vec<f32> {
        match self.aggregator {
            SAGEAggregator::Mean => mean_aggregate(messages),
            SAGEAggregator::MaxPool => super::aggregators::max_aggregate(messages),
            SAGEAggregator::LSTM => mean_aggregate(messages),
        }
    }

    /// Sum the transformed neighbor aggregate with the transformed self
    /// features and apply ReLU. Unlike `forward_with_sampling`, no L2
    /// normalization is applied on this path.
    fn update(&self, node_features: &[f32], aggregated: &[f32]) -> Vec<f32> {
        let neighbor_h = self.linear_transform(aggregated, &self.neighbor_weights);
        let self_h = self.linear_transform(node_features, &self.self_weights);

        neighbor_h
            .iter()
            .zip(self_h.iter())
            .map(|(&n, &s)| (n + s).max(0.0))
            .collect()
    }
}
|
||||
|
||||
// Unit tests for GraphSAGELayer: construction, sampling bounds, forward
// shape, aggregator selection, and output L2 normalization.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_graphsage_creation() {
        let layer = GraphSAGELayer::new(16, 32, 10);
        assert_eq!(layer.in_features, 16);
        assert_eq!(layer.out_features, 32);
        assert_eq!(layer.num_samples, 10);
    }

    #[test]
    fn test_sample_neighbors() {
        let layer = GraphSAGELayer::new(4, 8, 3);

        let neighbors = vec![0, 1, 2, 3, 4, 5];
        let sampled = layer.sample_neighbors(&neighbors, 3);

        assert_eq!(sampled.len(), 3);

        // Test with fewer neighbors than k
        let few_neighbors = vec![0, 1];
        let sampled_few = layer.sample_neighbors(&few_neighbors, 5);
        assert_eq!(sampled_few.len(), 2);
    }

    // Shape-only check: one output vector per node.
    #[test]
    fn test_graphsage_forward() {
        let layer = GraphSAGELayer::new(2, 2, 2);

        let node_features = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];

        let edge_index = vec![(0, 1), (1, 2), (2, 0)];

        let result = layer.forward(&node_features, &edge_index);

        assert_eq!(result.len(), 3);
        assert_eq!(result[0].len(), 2);
    }

    #[test]
    fn test_different_aggregators() {
        let mean_layer = GraphSAGELayer::with_aggregator(2, 2, 2, SAGEAggregator::Mean);
        let max_layer = GraphSAGELayer::with_aggregator(2, 2, 2, SAGEAggregator::MaxPool);

        let node_features = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let edge_index = vec![(0, 1)];

        let mean_result = mean_layer.forward(&node_features, &edge_index);
        let max_result = max_layer.forward(&node_features, &edge_index);

        assert_eq!(mean_result.len(), 2);
        assert_eq!(max_result.len(), 2);
    }

    // Outputs should have unit L2 norm (or be all-zero after ReLU).
    #[test]
    fn test_normalization() {
        let layer = GraphSAGELayer::new(2, 2, 2);

        let node_features = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let edge_index = vec![(0, 1)];

        let result = layer.forward(&node_features, &edge_index);

        // Check L2 normalization
        for features in result {
            let norm: f32 = features.iter().map(|&x| x * x).sum::<f32>().sqrt();
            assert!((norm - 1.0).abs() < 1e-5 || norm == 0.0);
        }
    }
}
|
||||
233
vendor/ruvector/crates/ruvector-postgres/src/gnn/message_passing.rs
vendored
Normal file
233
vendor/ruvector/crates/ruvector-postgres/src/gnn/message_passing.rs
vendored
Normal file
@@ -0,0 +1,233 @@
|
||||
//! Core message passing framework for Graph Neural Networks
|
||||
//!
|
||||
//! This module implements the fundamental message passing paradigm used in GNNs:
|
||||
//! 1. Message: Compute messages from neighbors
|
||||
//! 2. Aggregate: Combine messages from all neighbors
|
||||
//! 3. Update: Update node representations
|
||||
|
||||
use rayon::prelude::*;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Adjacency list representation of a graph
///
/// Maps each node id to the ids of its in-neighbors (the sources of edges
/// pointing at it).
pub type AdjacencyList = HashMap<usize, Vec<usize>>;

/// Message passing trait for GNN layers
///
/// Implementors define the three stages of one message-passing round:
/// message construction, aggregation, and the node update.
pub trait MessagePassing {
    /// Compute message from source node to target node
    fn message(&self, source_features: &[f32], edge_weight: Option<f32>) -> Vec<f32>;

    /// Aggregate messages from all neighbors
    fn aggregate(&self, messages: Vec<Vec<f32>>) -> Vec<f32>;

    /// Update node features based on aggregated messages
    fn update(&self, node_features: &[f32], aggregated: &[f32]) -> Vec<f32>;
}

/// Build adjacency list from edge index
///
/// Each `(src, dst)` edge records `src` as an in-neighbor of `dst`.
/// Edges referencing ids outside `0..num_nodes` are skipped. Every node
/// in range gets an entry, so lookups for isolated nodes succeed with an
/// empty list.
///
/// # Arguments
/// * `edge_index` - Array of (source, target) edges
/// * `num_nodes` - Total number of nodes in the graph
///
/// # Returns
/// HashMap mapping each node to its list of in-neighbors
pub fn build_adjacency_list(edge_index: &[(usize, usize)], num_nodes: usize) -> AdjacencyList {
    // Seed an entry for every node up front so isolated nodes are present.
    let mut adj_list: AdjacencyList = (0..num_nodes).map(|i| (i, Vec::new())).collect();

    for &(src, dst) in edge_index {
        if src < num_nodes && dst < num_nodes {
            if let Some(neighbors) = adj_list.get_mut(&dst) {
                neighbors.push(src);
            }
        }
    }

    adj_list
}
|
||||
|
||||
/// Propagate features through the graph using message passing
///
/// For each node: gathers messages from its in-neighbors (a node receives
/// from the sources of edges pointing at it), aggregates them, and updates
/// the node's features. Nodes with no neighbors keep their original
/// features. Nodes are processed in parallel via rayon.
///
/// # Arguments
/// * `node_features` - Features for each node [num_nodes x feature_dim]
/// * `edge_index` - Array of (source, target) edges
/// * `layer` - GNN layer implementing MessagePassing trait
///
/// # Returns
/// Updated node features after message passing
pub fn propagate<L: MessagePassing + Sync>(
    node_features: &[Vec<f32>],
    edge_index: &[(usize, usize)],
    layer: &L,
) -> Vec<Vec<f32>> {
    let num_nodes = node_features.len();
    let adj_list = build_adjacency_list(edge_index, num_nodes);

    // Parallel processing of nodes
    (0..num_nodes)
        .into_par_iter()
        .map(|node_id| {
            // Safe: build_adjacency_list inserts an entry for every node
            // id in 0..num_nodes.
            let neighbors = adj_list.get(&node_id).unwrap();

            if neighbors.is_empty() {
                // Disconnected node - return original features
                return node_features[node_id].clone();
            }

            // Collect messages from neighbors (unweighted: edge_weight None)
            let messages: Vec<Vec<f32>> = neighbors
                .iter()
                .filter_map(|&neighbor_id| {
                    if neighbor_id < num_nodes {
                        Some(layer.message(&node_features[neighbor_id], None))
                    } else {
                        None
                    }
                })
                .collect();

            if messages.is_empty() {
                return node_features[node_id].clone();
            }

            // Aggregate messages
            let aggregated = layer.aggregate(messages);

            // Update node features
            layer.update(&node_features[node_id], &aggregated)
        })
        .collect()
}
|
||||
|
||||
/// Propagate features with edge weights
///
/// Same scheme as [`propagate`], but each edge carries a weight passed to
/// `layer.message`. Edges beyond `edge_weights.len()` default to weight
/// 1.0; out-of-range node ids are skipped. Disconnected nodes keep their
/// original features.
pub fn propagate_weighted<L: MessagePassing + Sync>(
    node_features: &[Vec<f32>],
    edge_index: &[(usize, usize)],
    edge_weights: &[f32],
    layer: &L,
) -> Vec<Vec<f32>> {
    let num_nodes = node_features.len();

    // Build weighted adjacency list: dst -> [(src, weight)]
    let mut adj_list: HashMap<usize, Vec<(usize, f32)>> = HashMap::with_capacity(num_nodes);
    for i in 0..num_nodes {
        adj_list.insert(i, Vec::new());
    }

    for (idx, &(src, dst)) in edge_index.iter().enumerate() {
        if src < num_nodes && dst < num_nodes {
            // Missing weights default to 1.0 (unscaled message).
            let weight = if idx < edge_weights.len() {
                edge_weights[idx]
            } else {
                1.0
            };
            adj_list.get_mut(&dst).unwrap().push((src, weight));
        }
    }

    // Parallel processing of nodes
    (0..num_nodes)
        .into_par_iter()
        .map(|node_id| {
            // Safe: every node id was inserted above.
            let neighbors = adj_list.get(&node_id).unwrap();

            if neighbors.is_empty() {
                return node_features[node_id].clone();
            }

            // Collect weighted messages from neighbors
            let messages: Vec<Vec<f32>> = neighbors
                .iter()
                .filter_map(|&(neighbor_id, weight)| {
                    if neighbor_id < num_nodes {
                        Some(layer.message(&node_features[neighbor_id], Some(weight)))
                    } else {
                        None
                    }
                })
                .collect();

            if messages.is_empty() {
                return node_features[node_id].clone();
            }

            // Aggregate and update
            let aggregated = layer.aggregate(messages);
            layer.update(&node_features[node_id], &aggregated)
        })
        .collect()
}
|
||||
|
||||
// Unit tests for the message-passing framework, driven by a minimal
// identity-message / sum-aggregate / add-update layer.
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal MessagePassing impl: weighted copy, element-wise sum,
    // additive update.
    struct SimpleLayer;

    impl MessagePassing for SimpleLayer {
        fn message(&self, source_features: &[f32], edge_weight: Option<f32>) -> Vec<f32> {
            let weight = edge_weight.unwrap_or(1.0);
            source_features.iter().map(|&x| x * weight).collect()
        }

        fn aggregate(&self, messages: Vec<Vec<f32>>) -> Vec<f32> {
            if messages.is_empty() {
                return vec![];
            }
            let dim = messages[0].len();
            let mut result = vec![0.0; dim];
            for msg in messages {
                for (i, &val) in msg.iter().enumerate() {
                    result[i] += val;
                }
            }
            result
        }

        fn update(&self, node_features: &[f32], aggregated: &[f32]) -> Vec<f32> {
            node_features
                .iter()
                .zip(aggregated.iter())
                .map(|(&x, &y)| x + y)
                .collect()
        }
    }

    // Edges are (src, dst): each dst lists its sources as neighbors.
    #[test]
    fn test_build_adjacency_list() {
        let edges = vec![(0, 1), (1, 2), (2, 0)];
        let adj_list = build_adjacency_list(&edges, 3);

        assert_eq!(adj_list.get(&0).unwrap(), &vec![2]);
        assert_eq!(adj_list.get(&1).unwrap(), &vec![0]);
        assert_eq!(adj_list.get(&2).unwrap(), &vec![1]);
    }

    #[test]
    fn test_propagate() {
        let node_features = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];

        let edge_index = vec![(0, 1), (1, 2)];

        let layer = SimpleLayer;
        let result = propagate(&node_features, &edge_index, &layer);

        assert_eq!(result.len(), 3);
        assert_eq!(result[0].len(), 2);
    }

    #[test]
    fn test_disconnected_nodes() {
        let node_features = vec![vec![1.0], vec![2.0], vec![3.0]];
        let edge_index = vec![(0, 1)]; // Node 2 is disconnected

        let layer = SimpleLayer;
        let result = propagate(&node_features, &edge_index, &layer);

        // Disconnected node should retain original features
        assert_eq!(result[2], vec![3.0]);
    }
}
|
||||
125
vendor/ruvector/crates/ruvector-postgres/src/gnn/mod.rs
vendored
Normal file
125
vendor/ruvector/crates/ruvector-postgres/src/gnn/mod.rs
vendored
Normal file
@@ -0,0 +1,125 @@
|
||||
//! # Graph Neural Network Module
|
||||
//!
|
||||
//! Provides GNN-based embeddings and graph-aware vector operations.
|
||||
|
||||
// GNN sub-modules
|
||||
pub mod aggregators;
|
||||
pub mod gcn;
|
||||
pub mod graphsage;
|
||||
pub mod message_passing;
|
||||
pub mod operators;
|
||||
|
||||
// Re-export operator functions for PostgreSQL
|
||||
pub use operators::*;
|
||||
|
||||
use pgrx::prelude::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// GNN model configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GnnConfig {
    /// Number of message-passing layers
    pub num_layers: usize,
    /// Hidden feature dimension per layer
    pub hidden_dim: usize,
    /// Dropout probability — NOTE(review): stored but not applied anywhere
    /// visible in this module; confirm downstream use
    pub dropout: f32,
    /// Aggregation name (e.g. "sum", "mean"/"avg", "max")
    pub aggregation: String,
}
|
||||
|
||||
impl Default for GnnConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
num_layers: 2,
|
||||
hidden_dim: 128,
|
||||
dropout: 0.1,
|
||||
aggregation: "mean".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// GNN training status
///
/// Progress snapshot returned by [`GnnModel::train`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GnnTrainingStatus {
    /// Last completed epoch (1-based)
    pub epoch: usize,
    /// Total epochs requested
    pub total_epochs: usize,
    /// Training loss at the last epoch
    pub loss: f64,
    /// Training accuracy at the last epoch
    pub accuracy: f64,
    /// Whether training has finished
    pub completed: bool,
}
|
||||
|
||||
/// GNN model state
pub struct GnnModel {
    // Hyperparameters this model was built with.
    config: GnnConfig,
    // Flipped to true once `train` has been called.
    trained: bool,
}
|
||||
|
||||
impl GnnModel {
    /// Create a model with the default configuration.
    pub fn new() -> Self {
        Self::with_config(GnnConfig::default())
    }

    /// Create a model from an explicit configuration.
    pub fn with_config(config: GnnConfig) -> Self {
        Self {
            config,
            trained: false,
        }
    }

    /// Whether `train` has been called on this model.
    pub fn is_trained(&self) -> bool {
        self.trained
    }

    /// Borrow the model's configuration.
    pub fn config(&self) -> &GnnConfig {
        &self.config
    }

    /// Forward pass.
    ///
    /// NOTE(review): placeholder — returns the input features unchanged and
    /// ignores the adjacency list entirely.
    pub fn forward(&self, node_features: &[f32], _adjacency: &[(usize, usize)]) -> Vec<f32> {
        node_features.to_vec()
    }

    /// Train the model.
    ///
    /// NOTE(review): placeholder — performs no optimization; it only flips
    /// the `trained` flag and reports a single completed epoch with zero
    /// loss and perfect accuracy. All arguments are ignored.
    pub fn train(
        &mut self,
        _node_features: &[Vec<f32>],
        _adjacency: &[(usize, usize)],
        _epochs: usize,
    ) -> GnnTrainingStatus {
        self.trained = true;
        GnnTrainingStatus {
            epoch: 1,
            total_epochs: 1,
            loss: 0.0,
            accuracy: 1.0,
            completed: true,
        }
    }
}
|
||||
|
||||
impl Default for GnnModel {
    /// Equivalent to [`GnnModel::new`]: default config, untrained.
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
#[pg_extern]
|
||||
fn ruvector_gnn_status() -> pgrx::JsonB {
|
||||
pgrx::JsonB(serde_json::json!({
|
||||
"enabled": true,
|
||||
"model_loaded": false,
|
||||
"version": "1.0.0"
|
||||
}))
|
||||
}
|
||||
|
||||
#[pg_extern]
|
||||
fn ruvector_gnn_default_config() -> pgrx::JsonB {
|
||||
pgrx::JsonB(serde_json::json!(GnnConfig::default()))
|
||||
}
|
||||
|
||||
#[cfg(feature = "pg_test")]
#[pg_schema]
mod tests {
    use super::*;

    /// The status payload must at least carry the `enabled` key.
    #[pg_test]
    fn test_gnn_status() {
        let payload = ruvector_gnn_status();
        let enabled = payload.0.get("enabled");
        assert!(enabled.is_some());
    }
}
|
||||
425
vendor/ruvector/crates/ruvector-postgres/src/gnn/operators.rs
vendored
Normal file
425
vendor/ruvector/crates/ruvector-postgres/src/gnn/operators.rs
vendored
Normal file
@@ -0,0 +1,425 @@
|
||||
//! PostgreSQL operator functions for GNN operations
|
||||
|
||||
use super::aggregators::{aggregate, AggregationMethod};
|
||||
use super::gcn::GCNLayer;
|
||||
use super::graphsage::GraphSAGELayer;
|
||||
use pgrx::prelude::*;
|
||||
use pgrx::JsonB;
|
||||
|
||||
/// Apply GCN forward pass on embeddings
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `embeddings_json` - Node embeddings as JSON array [num_nodes x in_features]
|
||||
/// * `src` - Source node indices
|
||||
/// * `dst` - Destination node indices
|
||||
/// * `weights` - Edge weights (optional)
|
||||
/// * `out_dim` - Output dimension
|
||||
///
|
||||
/// # Returns
|
||||
/// Updated node embeddings after GCN layer as JSON
|
||||
#[pg_extern(immutable, parallel_safe)]
|
||||
pub fn ruvector_gcn_forward(
|
||||
embeddings_json: JsonB,
|
||||
src: Vec<i32>,
|
||||
dst: Vec<i32>,
|
||||
weights: Option<Vec<f32>>,
|
||||
out_dim: i32,
|
||||
) -> JsonB {
|
||||
// Parse embeddings from JSON
|
||||
let embeddings: Vec<Vec<f32>> = match embeddings_json.0.as_array() {
|
||||
Some(arr) => arr
|
||||
.iter()
|
||||
.filter_map(|v| {
|
||||
v.as_array().map(|a| {
|
||||
a.iter()
|
||||
.filter_map(|x| x.as_f64().map(|f| f as f32))
|
||||
.collect()
|
||||
})
|
||||
})
|
||||
.collect(),
|
||||
None => return JsonB(serde_json::json!([])),
|
||||
};
|
||||
|
||||
if embeddings.is_empty() {
|
||||
return JsonB(serde_json::json!([]));
|
||||
}
|
||||
|
||||
let in_features = embeddings[0].len();
|
||||
let out_features = out_dim as usize;
|
||||
|
||||
// Build edge index
|
||||
let edge_index: Vec<(usize, usize)> = src
|
||||
.iter()
|
||||
.zip(dst.iter())
|
||||
.map(|(&s, &d)| (s as usize, d as usize))
|
||||
.collect();
|
||||
|
||||
// Create GCN layer
|
||||
let layer = GCNLayer::new(in_features, out_features);
|
||||
|
||||
// Forward pass
|
||||
let result = layer.forward(&embeddings, &edge_index, weights.as_deref());
|
||||
|
||||
JsonB(serde_json::json!(result))
|
||||
}
|
||||
|
||||
/// Aggregate neighbor messages using specified method
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `messages_json` - Vector of neighbor messages as JSON array
|
||||
/// * `method` - Aggregation method: 'sum', 'mean', or 'max'
|
||||
///
|
||||
/// # Returns
|
||||
/// Aggregated message vector
|
||||
#[pg_extern(immutable, parallel_safe)]
|
||||
pub fn ruvector_gnn_aggregate(messages_json: JsonB, method: String) -> Vec<f32> {
|
||||
// Parse messages from JSON
|
||||
let messages: Vec<Vec<f32>> = match messages_json.0.as_array() {
|
||||
Some(arr) => arr
|
||||
.iter()
|
||||
.filter_map(|v| {
|
||||
v.as_array().map(|a| {
|
||||
a.iter()
|
||||
.filter_map(|x| x.as_f64().map(|f| f as f32))
|
||||
.collect()
|
||||
})
|
||||
})
|
||||
.collect(),
|
||||
None => return vec![],
|
||||
};
|
||||
|
||||
if messages.is_empty() {
|
||||
return vec![];
|
||||
}
|
||||
|
||||
let agg_method = AggregationMethod::from_str(&method).unwrap_or(AggregationMethod::Mean);
|
||||
|
||||
aggregate(messages, agg_method)
|
||||
}
|
||||
|
||||
/// Multi-hop message passing over graph
///
/// NOTE(review): placeholder implementation — despite the description
/// below, no SQL is generated or executed; the function only validates its
/// arguments and returns a human-readable summary string.
///
/// # Arguments
/// * `node_table` - Name of table containing node features
/// * `edge_table` - Name of table containing edges
/// * `embedding_col` - Column name for node embeddings
/// * `hops` - Number of message passing hops (must be >= 1)
/// * `layer_type` - Type of GNN layer: 'gcn' or 'sage' (case-insensitive)
///
/// # Returns
/// A human-readable description of the requested operation (placeholder;
/// not a SQL query result)
///
/// # Errors
/// Raises a PostgreSQL error when `hops < 1` or `layer_type` is neither
/// 'gcn' nor 'sage'.
#[pg_extern(immutable, parallel_safe)]
pub fn ruvector_message_pass(
    node_table: String,
    edge_table: String,
    embedding_col: String,
    hops: i32,
    layer_type: String,
) -> String {
    // Validate inputs
    if hops < 1 {
        error!("Number of hops must be at least 1");
    }

    let layer = layer_type.to_lowercase();
    if layer != "gcn" && layer != "sage" {
        error!("layer_type must be 'gcn' or 'sage'");
    }

    // Placeholder: describe the requested multi-hop pass instead of
    // generating/executing SQL (table/column names are interpolated
    // verbatim — do not treat the output as executable SQL).
    format!(
        "Multi-hop {} message passing over {} hops from table {} using edges from {} on column {}",
        layer, hops, node_table, edge_table, embedding_col
    )
}
|
||||
|
||||
/// Apply GraphSAGE layer with neighbor sampling
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `embeddings_json` - Node embeddings as JSON [num_nodes x in_features]
|
||||
/// * `src` - Source node indices
|
||||
/// * `dst` - Destination node indices
|
||||
/// * `out_dim` - Output dimension
|
||||
/// * `num_samples` - Number of neighbors to sample per node
|
||||
///
|
||||
/// # Returns
|
||||
/// Updated node embeddings after GraphSAGE layer as JSON
|
||||
#[pg_extern(immutable, parallel_safe)]
|
||||
pub fn ruvector_graphsage_forward(
|
||||
embeddings_json: JsonB,
|
||||
src: Vec<i32>,
|
||||
dst: Vec<i32>,
|
||||
out_dim: i32,
|
||||
num_samples: i32,
|
||||
) -> JsonB {
|
||||
// Parse embeddings from JSON
|
||||
let embeddings: Vec<Vec<f32>> = match embeddings_json.0.as_array() {
|
||||
Some(arr) => arr
|
||||
.iter()
|
||||
.filter_map(|v| {
|
||||
v.as_array().map(|a| {
|
||||
a.iter()
|
||||
.filter_map(|x| x.as_f64().map(|f| f as f32))
|
||||
.collect()
|
||||
})
|
||||
})
|
||||
.collect(),
|
||||
None => return JsonB(serde_json::json!([])),
|
||||
};
|
||||
|
||||
if embeddings.is_empty() {
|
||||
return JsonB(serde_json::json!([]));
|
||||
}
|
||||
|
||||
let in_features = embeddings[0].len();
|
||||
let out_features = out_dim as usize;
|
||||
|
||||
// Build edge index
|
||||
let edge_index: Vec<(usize, usize)> = src
|
||||
.iter()
|
||||
.zip(dst.iter())
|
||||
.map(|(&s, &d)| (s as usize, d as usize))
|
||||
.collect();
|
||||
|
||||
// Create GraphSAGE layer
|
||||
let layer = GraphSAGELayer::new(in_features, out_features, num_samples as usize);
|
||||
|
||||
// Forward pass
|
||||
let result = layer.forward(&embeddings, &edge_index);
|
||||
|
||||
JsonB(serde_json::json!(result))
|
||||
}
|
||||
|
||||
/// Batch GNN inference on multiple graphs
///
/// Processes several graphs packed into a single flat embedding matrix and
/// a single flat edge list, applying one GNN layer per graph.
///
/// # Arguments
/// * `embeddings_batch_json` - Batch of node embeddings as JSON
///   (all graphs' rows concatenated; `graph_sizes` delimits them)
/// * `edge_indices_batch` - Batch of edge indices (flattened as
///   interleaved pairs [src0, dst0, src1, dst1, ...] using GLOBAL node
///   indices, i.e. offsets into the concatenated embedding matrix)
/// * `graph_sizes` - Number of nodes in each graph
/// * `layer_type` - Type of layer: 'gcn' or 'sage' (anything else passes
///   embeddings through unchanged)
/// * `out_dim` - Output dimension
///
/// # Returns
/// Batch of updated embeddings as JSON (concatenated in graph order);
/// empty array when the embeddings or graph sizes are empty
///
/// NOTE(review): the per-graph edge slicing below assumes (a) edges are
/// grouped by graph in the same order as `graph_sizes`, and (b) both
/// endpoints of every edge belong to the same graph. It scans forward with
/// `take_while` and stops at the first index outside the current graph's
/// node range — edges violating these assumptions are silently dropped or
/// misattributed. Confirm callers uphold this layout.
#[pg_extern(immutable, parallel_safe)]
pub fn ruvector_gnn_batch_forward(
    embeddings_batch_json: JsonB,
    edge_indices_batch: Vec<i32>,
    graph_sizes: Vec<i32>,
    layer_type: String,
    out_dim: i32,
) -> JsonB {
    // Parse embeddings from JSON (non-array rows / non-numeric entries dropped)
    let embeddings_batch: Vec<Vec<f32>> = match embeddings_batch_json.0.as_array() {
        Some(arr) => arr
            .iter()
            .filter_map(|v| {
                v.as_array().map(|a| {
                    a.iter()
                        .filter_map(|x| x.as_f64().map(|f| f as f32))
                        .collect()
                })
            })
            .collect(),
        None => return JsonB(serde_json::json!([])),
    };

    if embeddings_batch.is_empty() || graph_sizes.is_empty() {
        return JsonB(serde_json::json!([]));
    }

    let mut result: Vec<Vec<f32>> = Vec::new();
    // Running offsets into the flat node matrix and flat edge array.
    let mut node_offset = 0;
    let mut edge_offset = 0;

    for &graph_size in &graph_sizes {
        let num_nodes = graph_size as usize;

        // Extract embeddings for this graph
        // (panics if graph_sizes overruns the embedding matrix — TODO confirm
        // inputs are always consistent)
        let graph_embeddings: Vec<Vec<f32>> =
            embeddings_batch[node_offset..node_offset + num_nodes].to_vec();

        // Extract edges for this graph (simplified - assumes edges come in pairs)
        // Count consecutive flat entries whose global index falls inside this
        // graph's node range; each edge contributes 2 entries.
        let num_edges = edge_indices_batch
            .iter()
            .skip(edge_offset)
            .take_while(|&&idx| (idx as usize) < node_offset + num_nodes)
            .count()
            / 2;

        // Even positions are sources; rebased to graph-local indices.
        let src: Vec<i32> = edge_indices_batch
            .iter()
            .skip(edge_offset)
            .step_by(2)
            .take(num_edges)
            .map(|&x| x - node_offset as i32)
            .collect();

        // Odd positions are destinations; rebased to graph-local indices.
        let dst: Vec<i32> = edge_indices_batch
            .iter()
            .skip(edge_offset + 1)
            .step_by(2)
            .take(num_edges)
            .map(|&x| x - node_offset as i32)
            .collect();

        // Build edge index
        let edge_index: Vec<(usize, usize)> = src
            .iter()
            .zip(dst.iter())
            .map(|(&s, &d)| (s as usize, d as usize))
            .collect();

        // Apply GNN layer
        let in_features = if graph_embeddings.is_empty() {
            0
        } else {
            graph_embeddings[0].len()
        };
        let out_features = out_dim as usize;

        let graph_result = match layer_type.to_lowercase().as_str() {
            "gcn" => {
                let layer = GCNLayer::new(in_features, out_features);
                layer.forward(&graph_embeddings, &edge_index, None)
            }
            "sage" => {
                // Fixed neighbor sample size of 10 for batch mode.
                let layer = GraphSAGELayer::new(in_features, out_features, 10);
                layer.forward(&graph_embeddings, &edge_index)
            }
            // Unknown layer types pass embeddings through unchanged.
            _ => graph_embeddings,
        };

        result.extend(graph_result);

        // Advance to the next graph's nodes and edge entries.
        node_offset += num_nodes;
        edge_offset += num_edges * 2;
    }

    JsonB(serde_json::json!(result))
}
|
||||
|
||||
#[cfg(feature = "pg_test")]
#[pg_schema]
mod tests {
    use super::*;

    /// Wrap a nested f32 vector in a JsonB value.
    fn vecs_to_jsonb(rows: Vec<Vec<f32>>) -> JsonB {
        JsonB(serde_json::json!(rows))
    }

    /// Decode a JsonB value back into a nested f32 vector
    /// (empty when the payload is not a JSON array).
    fn jsonb_to_vecs(payload: &JsonB) -> Vec<Vec<f32>> {
        match payload.0.as_array() {
            Some(rows) => rows
                .iter()
                .filter_map(|row| {
                    row.as_array().map(|vals| {
                        vals.iter()
                            .filter_map(|x| x.as_f64().map(|f| f as f32))
                            .collect()
                    })
                })
                .collect(),
            None => Vec::new(),
        }
    }

    #[pg_test]
    fn test_ruvector_gcn_forward() {
        // Directed 3-cycle over three 2-d nodes.
        let features = vecs_to_jsonb(vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]]);
        let sources = vec![0, 1, 2];
        let targets = vec![1, 2, 0];

        let output = jsonb_to_vecs(&ruvector_gcn_forward(features, sources, targets, None, 2));

        // Shape must be preserved: 3 nodes x out_dim 2.
        assert_eq!(output.len(), 3);
        assert_eq!(output[0].len(), 2);
    }

    #[pg_test]
    fn test_ruvector_gnn_aggregate_sum() {
        let messages = vecs_to_jsonb(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        let combined = ruvector_gnn_aggregate(messages, "sum".to_string());
        assert_eq!(combined, vec![4.0, 6.0]);
    }

    #[pg_test]
    fn test_ruvector_gnn_aggregate_mean() {
        let messages = vecs_to_jsonb(vec![vec![2.0, 4.0], vec![4.0, 6.0]]);
        let combined = ruvector_gnn_aggregate(messages, "mean".to_string());
        assert_eq!(combined, vec![3.0, 5.0]);
    }

    #[pg_test]
    fn test_ruvector_gnn_aggregate_max() {
        let messages = vecs_to_jsonb(vec![vec![1.0, 6.0], vec![5.0, 2.0]]);
        let combined = ruvector_gnn_aggregate(messages, "max".to_string());
        // Element-wise maximum across messages.
        assert_eq!(combined, vec![5.0, 6.0]);
    }

    #[pg_test]
    fn test_ruvector_graphsage_forward() {
        let features = vecs_to_jsonb(vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]]);
        let sources = vec![0, 1, 2];
        let targets = vec![1, 2, 0];

        let output = jsonb_to_vecs(&ruvector_graphsage_forward(features, sources, targets, 2, 2));

        assert_eq!(output.len(), 3);
        assert_eq!(output[0].len(), 2);
    }

    #[pg_test]
    fn test_ruvector_message_pass() {
        let description = ruvector_message_pass(
            "nodes".to_string(),
            "edges".to_string(),
            "embedding".to_string(),
            3,
            "gcn".to_string(),
        );

        // The summary string must reflect the layer type and hop count.
        assert!(description.contains("gcn"));
        assert!(description.contains("3 hops"));
    }

    #[pg_test]
    fn test_empty_inputs() {
        // An empty embedding matrix must yield an empty result regardless
        // of the requested output dimension.
        let no_edges: Vec<i32> = vec![];
        let output = ruvector_gcn_forward(vecs_to_jsonb(vec![]), no_edges.clone(), no_edges, None, 4);
        assert_eq!(jsonb_to_vecs(&output).len(), 0);
    }

    #[pg_test]
    fn test_weighted_gcn() {
        // Single weighted edge 0 -> 1.
        let features = vecs_to_jsonb(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        let output = ruvector_gcn_forward(features, vec![0], vec![1], Some(vec![2.0]), 2);
        assert_eq!(jsonb_to_vecs(&output).len(), 2);
    }
}
|
||||
Reference in New Issue
Block a user