Optimization Strategies: Latent Space ↔ Graph Reality
Executive Summary
This document explores optimization strategies for training GNNs that effectively bridge latent space and graph topology. We examine loss functions, training procedures, regularization techniques, and multi-objective optimization methods specific to the graph learning domain.
Core Challenge: Jointly optimize for graph structure preservation, downstream task performance, and latent space quality
1. Loss Function Taxonomy
1.1 Classification of Graph Learning Losses
Graph Learning Losses
│
├─ Reconstruction Losses
│ ├─ Link Prediction
│ ├─ Node Feature Reconstruction
│ └─ Graph Structure Reconstruction
│
├─ Contrastive Losses
│ ├─ InfoNCE (current in RuVector)
│ ├─ Local Contrastive (current)
│ ├─ Triplet Loss
│ └─ Deep Graph Infomax
│
├─ Regularization Losses
│ ├─ Spectral (Laplacian)
│ ├─ Sparsity
│ ├─ EWC (current in RuVector)
│ └─ Embedding Normalization
│
├─ Task-Specific Losses
│ ├─ Node Classification (Cross-Entropy)
│ ├─ Link Prediction (BCE)
│ └─ Graph Classification
│
└─ Geometric Losses
├─ Distance Preservation
├─ Angle Preservation
└─ Curvature-Based
2. Reconstruction Losses
2.1 Link Prediction Loss
Goal: Predict edge existence from latent embeddings
Current RuVector Approach (Implicit in search.rs):
// Cosine similarity for neighbor selection
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32
Binary Cross-Entropy Link Loss:
pub fn link_prediction_loss(
embeddings: &[Vec<f32>],
positive_edges: &[(usize, usize)],
negative_edges: &[(usize, usize)],
) -> f32 {
let mut loss = 0.0;
// Positive edges: should have high similarity
for &(i, j) in positive_edges {
let sim = cosine_similarity(&embeddings[i], &embeddings[j]);
let prob = sigmoid(sim);
loss -= (prob + 1e-10).ln(); // -log P(edge exists)
}
// Negative edges: should have low similarity
for &(i, j) in negative_edges {
let sim = cosine_similarity(&embeddings[i], &embeddings[j]);
let prob = sigmoid(sim);
loss -= (1.0 - prob + 1e-10).ln(); // -log P(edge doesn't exist)
}
loss / (positive_edges.len() + negative_edges.len()) as f32
}
fn sigmoid(x: f32) -> f32 {
1.0 / (1.0 + (-x).exp())
}
Negative Sampling Strategies:
1. Random Negatives
fn sample_random_negatives(
num_nodes: usize,
positive_edges: &[(usize, usize)],
ratio: usize, // Negatives per positive
) -> Vec<(usize, usize)> {
let mut negatives = Vec::new();
let positive_set: HashSet<_> = positive_edges.iter().collect();
while negatives.len() < positive_edges.len() * ratio {
let i = rand::random::<usize>() % num_nodes;
let j = rand::random::<usize>() % num_nodes;
if i != j && !positive_set.contains(&(i, j)) {
negatives.push((i, j));
}
}
negatives
}
2. Hard Negatives (Distance-Based)
fn sample_hard_negatives_distance(
embeddings: &[Vec<f32>],
positive_edges: &[(usize, usize)],
k_hops: usize, // Graph distance threshold
graph: &Graph,
) -> Vec<(usize, usize)> {
let mut hard_negatives = Vec::new();
for &(i, _) in positive_edges {
// Find nodes that are:
// 1. Latent-close (high embedding similarity)
// 2. Graph-far (> k_hops away)
let candidates: Vec<_> = (0..embeddings.len())
.filter(|&j| {
let dist = graph.shortest_path_distance(i, j);
dist > k_hops
})
.map(|j| (j, cosine_similarity(&embeddings[i], &embeddings[j])))
.collect();
// Sort by similarity (most similar = hardest negative)
candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
if let Some((j, _)) = candidates.first() {
hard_negatives.push((i, *j));
}
}
hard_negatives
}
3. Degree-Corrected Negatives
// Sample negatives with similar degree distribution to positives
fn sample_degree_corrected_negatives(
degrees: &[usize],
positive_edges: &[(usize, usize)],
tolerance: f32,
) -> Vec<(usize, usize)> {
let mut negatives = Vec::new();
for &(i, j) in positive_edges {
let target_deg_i = degrees[i] as f32;
let _target_deg_j = degrees[j] as f32; // deg(j) matching omitted in this simplified version
// Sample candidate endpoints with degree close to deg(i)
let candidates: Vec<_> = (0..degrees.len())
.filter(|&k| {
let deg_k = degrees[k] as f32;
(deg_k - target_deg_i).abs() / target_deg_i < tolerance
})
.collect();
if candidates.len() >= 2 {
let idx1 = rand::random::<usize>() % candidates.len();
let idx2 = rand::random::<usize>() % candidates.len();
if idx1 != idx2 {
negatives.push((candidates[idx1], candidates[idx2]));
}
}
}
negatives
}
2.2 Node Feature Reconstruction Loss
Goal: Reconstruct original node features from embeddings (autoencoder)
pub struct GraphAutoencoder {
encoder: RuvectorLayer,
decoder: Linear,
}
impl GraphAutoencoder {
fn reconstruction_loss(
&self,
node_features: &[Vec<f32>],
neighbor_structure: &[Vec<Vec<f32>>],
) -> f32 {
let mut total_loss = 0.0;
for (i, features) in node_features.iter().enumerate() {
// Encode
let embedding = self.encoder.forward(
features,
&neighbor_structure[i],
&vec![1.0; neighbor_structure[i].len()],
);
// Decode
let reconstructed = self.decoder.forward(&embedding);
// MSE loss
let mse = features.iter()
.zip(reconstructed.iter())
.map(|(f, r)| (f - r).powi(2))
.sum::<f32>();
total_loss += mse;
}
total_loss / node_features.len() as f32
}
}
Variants:
- MSE: Mean Squared Error (continuous features)
- BCE: Binary Cross-Entropy (binary features)
- Categorical CE: For discrete features
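For binary features, the MSE term above can be swapped for a per-dimension binary cross-entropy. A minimal stand-alone sketch (`bce_reconstruction_loss` is illustrative, not part of RuVector; it assumes the decoder emits raw logits):

```rust
// Sketch: per-dimension BCE reconstruction loss for binary node features.
// `reconstructed` holds raw decoder scores (logits), squashed here via sigmoid.
fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

pub fn bce_reconstruction_loss(features: &[f32], reconstructed: &[f32]) -> f32 {
    features
        .iter()
        .zip(reconstructed.iter())
        .map(|(&f, &r)| {
            // Clamp to avoid ln(0), mirroring the 1e-10 guards used elsewhere
            let p = sigmoid(r).clamp(1e-10, 1.0 - 1e-10);
            // -[f log p + (1 - f) log(1 - p)]
            -(f * p.ln() + (1.0 - f) * (1.0 - p).ln())
        })
        .sum::<f32>()
        / features.len() as f32
}
```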
2.3 Graph Structure Reconstruction
Adjacency Matrix Reconstruction:
pub fn adjacency_reconstruction_loss(
embeddings: &[Vec<f32>],
true_adjacency: &SparseMatrix,
) -> f32 {
let n = embeddings.len();
let mut loss = 0.0;
// Predicted adjacency: A'[i,j] = σ(h_i^T h_j)
for i in 0..n {
for j in i+1..n {
let score = embeddings[i].iter()
.zip(embeddings[j].iter())
.map(|(a, b)| a * b)
.sum::<f32>();
let pred_adj = sigmoid(score);
let true_adj = if true_adjacency.has_edge(i, j) { 1.0 } else { 0.0 };
// Binary cross-entropy
if true_adj == 1.0 {
loss -= (pred_adj + 1e-10).ln();
} else {
loss -= (1.0 - pred_adj + 1e-10).ln();
}
}
}
loss / (n * (n-1) / 2) as f32
}
Sparse Variant (For large graphs):
// Only compute loss on edges + sampled non-edges
pub fn sparse_adjacency_loss(
embeddings: &[Vec<f32>],
edges: &[(usize, usize)],
num_negative_samples: usize,
) -> f32 {
let mut loss = 0.0;
// Positive samples
for &(i, j) in edges {
let score = dot_product(&embeddings[i], &embeddings[j]);
loss -= (sigmoid(score) + 1e-10).ln();
}
// Negative samples
let negatives = sample_random_negatives(embeddings.len(), edges, num_negative_samples);
// sample_random_negatives treats its third argument as a per-positive ratio,
// so normalize by the actual number of negatives drawn
let total_samples = edges.len() + negatives.len();
for (i, j) in negatives {
let score = dot_product(&embeddings[i], &embeddings[j]);
loss -= (1.0 - sigmoid(score) + 1e-10).ln();
}
loss / total_samples as f32
}
3. Contrastive Losses (Current in RuVector)
3.1 InfoNCE Loss (Deep Dive)
Current Implementation (training.rs:362-411):
pub fn info_nce_loss(
anchor: &[f32],
positives: &[&[f32]],
negatives: &[&[f32]],
temperature: f32,
) -> f32
Mathematical Properties:
1. Temperature Scaling:
τ → 0: Hard selection (argmax-like)
τ → ∞: Uniform (all samples weighted equally)
τ = 0.07: Standard for vision (SimCLR)
τ = 0.1-0.5: Common for graphs
Temperature Effect:
fn analyze_temperature_effect() {
let anchor = vec![1.0, 0.0, 0.0];
let positive = vec![0.9, 0.1, 0.0];
let negative = vec![0.0, 1.0, 0.0];
for temp in [0.01, 0.07, 0.1, 0.5, 1.0] {
let loss = info_nce_loss(&anchor, &[&positive], &[&negative], temp);
println!("Temperature {}: Loss = {}", temp, loss);
}
}
2. Gradient Analysis:
∂L_InfoNCE / ∂h_v ∝ (h_+ - Σ_i w_i h_i)
where w_i = exp(sim(h_v, h_i) / τ) / Z
Interpretation: Gradient pulls anchor toward positive, pushes away from weighted average of negatives
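The softmax weights w_i from this gradient can be computed explicitly, which is useful for inspecting which negatives dominate the push direction. A sketch under that interpretation (`info_nce_negative_weights` is a hypothetical helper, not in training.rs):

```rust
// Sketch: softmax weights over negatives from the InfoNCE gradient analysis.
// w_i = exp(sim(h_v, h_i) / tau) / Z, where Z sums over the positive and all negatives.
fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}

pub fn info_nce_negative_weights(
    anchor: &[f32],
    positive: &[f32],
    negatives: &[&[f32]],
    temperature: f32,
) -> Vec<f32> {
    let pos_exp = (dot(anchor, positive) / temperature).exp();
    let neg_exps: Vec<f32> = negatives
        .iter()
        .map(|n| (dot(anchor, n) / temperature).exp())
        .collect();
    let z = pos_exp + neg_exps.iter().sum::<f32>();
    // Harder negatives (more similar to the anchor) receive larger weights
    neg_exps.iter().map(|e| e / z).collect()
}
```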
3.2 Local Contrastive Loss (Graph-Specific)
Current Implementation (training.rs:444-462):
pub fn local_contrastive_loss(
node_embedding: &[f32],
neighbor_embeddings: &[Vec<f32>],
non_neighbor_embeddings: &[Vec<f32>],
temperature: f32,
) -> f32
Enhancement: Multi-Hop Contrastive:
pub fn multi_hop_contrastive_loss(
node_embedding: &[f32],
k_hop_neighbors: &HashMap<usize, Vec<Vec<f32>>>, // k -> neighbors at distance k
non_neighbors: &[Vec<f32>],
temperature: f32,
hop_weights: &[f32], // Weight for each hop distance
) -> f32 {
let mut total_loss = 0.0;
for (k, neighbors) in k_hop_neighbors {
if neighbors.is_empty() {
continue;
}
let positives: Vec<&[f32]> = neighbors.iter().map(|n| n.as_slice()).collect();
let negatives: Vec<&[f32]> = non_neighbors.iter().map(|n| n.as_slice()).collect();
let loss = info_nce_loss(node_embedding, &positives, &negatives, temperature);
// Weight by hop distance (closer = more important)
total_loss += hop_weights[*k] * loss;
}
total_loss
}
3.3 Triplet Loss
Alternative to InfoNCE:
L_triplet = max(0, ||h_v - h_+||² - ||h_v - h_-||² + margin)
Implementation:
pub fn triplet_loss(
anchor: &[f32],
positive: &[f32],
negative: &[f32],
margin: f32,
) -> f32 {
let pos_dist = l2_distance_squared(anchor, positive);
let neg_dist = l2_distance_squared(anchor, negative);
(pos_dist - neg_dist + margin).max(0.0)
}
pub fn batch_triplet_loss(
anchors: &[Vec<f32>],
positives: &[Vec<f32>],
negatives: &[Vec<f32>],
margin: f32,
) -> f32 {
anchors.iter()
.zip(positives.iter())
.zip(negatives.iter())
.map(|((a, p), n)| triplet_loss(a, p, n, margin))
.sum::<f32>() / anchors.len() as f32
}
Hard Triplet Mining:
fn mine_hard_triplets(
embeddings: &[Vec<f32>],
edges: &[(usize, usize)],
k: usize, // Top-k hardest
) -> Vec<(usize, usize, usize)> { // (anchor, positive, negative)
let mut triplets = Vec::new();
for &(i, j) in edges {
// j is positive for i
// Find hardest negative: closest non-neighbor
// (named `cand` to avoid shadowing the top-k parameter `k`)
let mut candidates: Vec<(usize, f32)> = (0..embeddings.len())
.filter(|&cand| cand != i && cand != j && !edges.contains(&(i, cand)))
.map(|cand| (cand, l2_distance_squared(&embeddings[i], &embeddings[cand])))
.collect();
// Sort by distance (ascending = closest = hardest)
candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
for (neg_idx, _) in candidates.iter().take(k) {
triplets.push((i, j, *neg_idx));
}
}
triplets
}
4. Regularization Losses
4.1 Spectral Regularization (Laplacian)
Goal: Smooth embeddings along graph structure
Laplacian Smoothness:
L_Laplacian = Tr(H^T L H) = Σ_{(i,j) ∈ E} ||h_i - h_j||²
Implementation:
pub fn laplacian_regularization(
embeddings: &[Vec<f32>],
edges: &[(usize, usize)],
edge_weights: Option<&[f32]>, // Optional weighted graph
) -> f32 {
let mut loss = 0.0;
for (idx, &(i, j)) in edges.iter().enumerate() {
let diff = subtract(&embeddings[i], &embeddings[j]);
let norm_sq = l2_norm_squared(&diff);
let weight = edge_weights.map(|w| w[idx]).unwrap_or(1.0);
loss += weight * norm_sq;
}
loss / edges.len() as f32
}
Normalized Laplacian:
pub fn normalized_laplacian_regularization(
embeddings: &[Vec<f32>],
edges: &[(usize, usize)],
degrees: &[usize],
) -> f32 {
let mut loss = 0.0;
for &(i, j) in edges {
let diff = subtract(&embeddings[i], &embeddings[j]);
let norm_sq = l2_norm_squared(&diff);
// Normalize by sqrt of degrees
let normalization = (degrees[i] as f32 * degrees[j] as f32).sqrt();
loss += norm_sq / normalization;
}
loss / edges.len() as f32
}
4.2 Elastic Weight Consolidation (EWC) - Current
Implementation (ewc.rs:1-584):
pub struct ElasticWeightConsolidation {
fisher_diag: Vec<f32>, // Importance weights
anchor_weights: Vec<f32>, // Previous task optimal weights
lambda: f32, // Regularization strength
active: bool,
}
Loss Term:
L_EWC = (λ/2) Σ_i F_i (θ_i - θ*_i)²
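This loss term can be sketched as a free function over the struct's fields (a minimal illustrative version, not the actual ewc.rs implementation):

```rust
// Sketch of the EWC penalty term: (lambda/2) * sum_i F_i (theta_i - theta*_i)^2.
// `fisher_diag` and `anchor_weights` correspond to the struct fields above.
pub fn ewc_penalty(
    fisher_diag: &[f32],
    anchor_weights: &[f32],
    current_weights: &[f32],
    lambda: f32,
) -> f32 {
    let weighted_drift: f32 = fisher_diag
        .iter()
        .zip(anchor_weights.iter())
        .zip(current_weights.iter())
        // Each parameter's drift from the anchor, scaled by its Fisher importance
        .map(|((&f, &anchor), &theta)| f * (theta - anchor).powi(2))
        .sum();
    (lambda / 2.0) * weighted_drift
}
```

The penalty vanishes exactly when the current weights sit at the previous task's optimum, and grows fastest along directions the Fisher diagonal marks as important.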
Enhanced EWC for GNNs:
pub struct GNNElasticWeightConsolidation {
ewc_per_layer: Vec<ElasticWeightConsolidation>,
layer_importance: Vec<f32>, // Different λ per layer
}
impl GNNElasticWeightConsolidation {
fn compute_total_penalty(&self, current_weights: &[Vec<f32>]) -> f32 {
self.ewc_per_layer.iter()
.zip(self.layer_importance.iter())
.zip(current_weights.iter())
.map(|((ewc, &importance), weights)| {
importance * ewc.penalty(weights)
})
.sum()
}
fn adaptive_lambda_by_importance(&mut self) {
// Adjust lambda based on Fisher information magnitude
for ewc in &mut self.ewc_per_layer {
let avg_fisher = ewc.fisher_diag().iter().sum::<f32>()
/ ewc.fisher_diag().len() as f32;
// Higher Fisher = more important = higher lambda
ewc.set_lambda(1000.0 * avg_fisher);
}
}
}
4.3 Embedding Normalization
Unit Sphere Constraint:
||h_v|| = 1 for all v
Soft Constraint (Regularization):
pub fn embedding_norm_regularization(
embeddings: &[Vec<f32>],
target_norm: f32,
) -> f32 {
embeddings.iter()
.map(|h| {
let norm = l2_norm(h);
(norm - target_norm).powi(2)
})
.sum::<f32>() / embeddings.len() as f32
}
Hard Constraint (Projection):
pub fn normalize_embeddings(embeddings: &mut [Vec<f32>], target_norm: f32) {
for h in embeddings.iter_mut() {
let norm = l2_norm(h);
if norm > 1e-10 {
let scale = target_norm / norm;
for x in h.iter_mut() {
*x *= scale;
}
}
}
}
Benefits:
- Cosine similarity becomes dot product
- Prevents embedding collapse
- Bounded optimization landscape
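The first benefit can be verified directly: after projecting to unit norm, dot product and cosine similarity coincide. A small self-contained check (local helpers, not the ones in search.rs):

```rust
fn l2_norm(v: &[f32]) -> f32 {
    v.iter().map(|x| x * x).sum::<f32>().sqrt()
}

fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}

// After scaling both vectors to unit norm, cosine similarity reduces to a plain dot product.
pub fn normalized_dot(a: &[f32], b: &[f32]) -> f32 {
    let (na, nb) = (l2_norm(a), l2_norm(b));
    let a_unit: Vec<f32> = a.iter().map(|x| x / na).collect();
    let b_unit: Vec<f32> = b.iter().map(|x| x / nb).collect();
    dot(&a_unit, &b_unit)
}
```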
4.4 Orthogonality Regularization
Goal: Encourage diverse, non-redundant features
pub fn orthogonality_regularization(
embeddings: &[Vec<f32>],
) -> f32 {
// The Gram matrix G = H H^T (G[i][j] = h_i · h_j) should be close to identity:
// distinct node embeddings near-orthogonal, each with unit norm
let n = embeddings.len();
let mut gram_matrix = vec![vec![0.0; n]; n];
for i in 0..n {
for j in i..n {
let dot = dot_product(&embeddings[i], &embeddings[j]);
gram_matrix[i][j] = dot;
gram_matrix[j][i] = dot;
}
}
// Measure deviation from identity
let mut loss = 0.0;
for i in 0..n {
for j in 0..n {
let target = if i == j { 1.0 } else { 0.0 };
loss += (gram_matrix[i][j] - target).powi(2);
}
}
loss / (n * n) as f32
}
5. Multi-Objective Optimization
5.1 Weighted Combination
Total Loss:
L_total = λ_task L_task
+ λ_contrast L_contrast
+ λ_recon L_recon
+ λ_spectral L_spectral
+ λ_ewc L_ewc
Implementation:
pub struct MultiObjectiveLoss {
lambda_task: f32,
lambda_contrast: f32,
lambda_reconstruction: f32,
lambda_spectral: f32,
lambda_ewc: f32,
}
impl MultiObjectiveLoss {
fn compute_total_loss(
&self,
task_loss: f32,
contrastive_loss: f32,
reconstruction_loss: f32,
spectral_loss: f32,
ewc_penalty: f32,
) -> f32 {
self.lambda_task * task_loss
+ self.lambda_contrast * contrastive_loss
+ self.lambda_reconstruction * reconstruction_loss
+ self.lambda_spectral * spectral_loss
+ self.lambda_ewc * ewc_penalty
}
// Dynamic weight adjustment based on loss magnitudes
fn balance_weights(&mut self, losses: &LossComponents) {
// Normalize so all losses contribute roughly equally initially
let total_weighted = self.lambda_task * losses.task
+ self.lambda_contrast * losses.contrastive
+ self.lambda_reconstruction * losses.reconstruction
+ self.lambda_spectral * losses.spectral
+ self.lambda_ewc * losses.ewc;
// Adjust lambdas to balance contributions
let target_contribution = total_weighted / 5.0;
self.lambda_task *= target_contribution / (self.lambda_task * losses.task).max(1e-10);
self.lambda_contrast *= target_contribution / (self.lambda_contrast * losses.contrastive).max(1e-10);
// ... etc
}
}
5.2 Curriculum Learning
Idea: Schedule loss weights over training
Example Schedule:
Early training (epochs 0-10):
- High λ_reconstruction (learn basic features)
- Low λ_contrast (don't force structure yet)
Mid training (epochs 10-50):
- Increase λ_contrast (encode graph topology)
- Introduce λ_spectral (smooth embeddings)
Late training (epochs 50+):
- High λ_task (optimize for downstream)
- Introduce λ_ewc (if continual learning)
Implementation:
pub struct CurriculumSchedule {
current_epoch: usize,
schedules: HashMap<String, Box<dyn Fn(usize) -> f32>>,
}
impl CurriculumSchedule {
fn new() -> Self {
let mut schedules: HashMap<String, Box<dyn Fn(usize) -> f32>> = HashMap::new();
// Reconstruction: start high, decrease
schedules.insert(
"reconstruction".to_string(),
Box::new(|epoch| {
1.0 * (-(epoch as f32) / 50.0).exp()
})
);
// Contrastive: start low, increase, plateau
schedules.insert(
"contrastive".to_string(),
Box::new(|epoch| {
if epoch < 10 {
0.1 + 0.9 * (epoch as f32 / 10.0)
} else {
1.0
}
})
);
// Task: start low, ramp up late
schedules.insert(
"task".to_string(),
Box::new(|epoch| {
if epoch < 50 {
0.1
} else {
0.1 + 0.9 * ((epoch - 50) as f32 / 50.0).min(1.0)
}
})
);
Self {
current_epoch: 0,
schedules,
}
}
fn get_weight(&self, loss_name: &str) -> f32 {
self.schedules.get(loss_name)
.map(|f| f(self.current_epoch))
.unwrap_or(1.0)
}
fn step(&mut self) {
self.current_epoch += 1;
}
}
5.3 Gradient Surgery (Conflicting Objectives)
Problem: Gradients from different losses may conflict
PCGrad (Projected Conflicting Gradients):
pub fn project_conflicting_gradients(
grad_task: &[f32],
grad_contrast: &[f32],
) -> (Vec<f32>, Vec<f32>) {
let dot = dot_product(grad_task, grad_contrast);
if dot < 0.0 {
// Gradients conflict, project
let grad_contrast_norm_sq = l2_norm_squared(grad_contrast);
// Project grad_task away from grad_contrast
let projection_scale = dot / grad_contrast_norm_sq;
let grad_task_projected: Vec<f32> = grad_task.iter()
.zip(grad_contrast.iter())
.map(|(&gt, &gc)| gt - projection_scale * gc)
.collect();
(grad_task_projected, grad_contrast.to_vec())
} else {
// No conflict
(grad_task.to_vec(), grad_contrast.to_vec())
}
}
6. Training Strategies
6.1 Online Learning (Current in RuVector)
Config (training.rs:313-328):
pub struct OnlineConfig {
local_steps: usize,
propagate_updates: bool,
}
Enhanced Online Training:
pub struct OnlineGNNTrainer {
model: RuvectorLayer,
optimizer: Optimizer,
config: OnlineConfig,
replay_buffer: ReplayBuffer,
}
impl OnlineGNNTrainer {
fn update_on_new_node(
&mut self,
new_node_features: &[f32],
neighbors: &[Vec<f32>],
labels: Option<usize>,
) {
// 1. Forward pass with current model
let embedding = self.model.forward(new_node_features, neighbors, &vec![1.0; neighbors.len()]);
// 2. Compute loss
let loss = if let Some(label) = labels {
// Supervised: classification loss
self.classification_loss(&embedding, label)
} else {
// Unsupervised: contrastive loss
self.contrastive_loss(&embedding, neighbors)
};
// 3. Local optimization steps
for _ in 0..self.config.local_steps {
let grads = self.compute_gradients(loss);
self.optimizer.step(&mut self.model.weights, &grads);
}
// 4. Store in replay buffer (prevent catastrophic forgetting)
self.replay_buffer.add(new_node_features, neighbors, labels);
// 5. Periodic replay
if self.replay_buffer.should_replay() {
self.replay_past_experiences();
}
}
fn replay_past_experiences(&mut self) {
let samples = self.replay_buffer.sample(32);
for (features, neighbors, label) in samples {
// Re-train on past data; note this re-enters update_on_new_node,
// so the buffer should deduplicate to avoid unbounded growth
self.update_on_new_node(&features, &neighbors, label);
}
}
}
6.2 Batch Training with Graph Sampling
Challenge: Full-batch training on large graphs is expensive
Solution: Sample subgraphs for each batch
Neighbor Sampling:
pub fn sample_neighbors(
node: usize,
all_neighbors: &[Vec<usize>],
sample_size: usize,
) -> Vec<usize> {
let neighbors = &all_neighbors[node];
if neighbors.len() <= sample_size {
return neighbors.clone();
}
// Uniform sampling
let mut sampled = Vec::new();
let mut rng = rand::thread_rng();
while sampled.len() < sample_size {
let idx = rng.gen_range(0..neighbors.len());
if !sampled.contains(&neighbors[idx]) {
sampled.push(neighbors[idx]);
}
}
sampled
}
Layer-wise Sampling (GraphSAINT):
pub fn layer_wise_sampling(
root_nodes: &[usize],
all_neighbors: &[Vec<usize>],
num_layers: usize,
sample_sizes: &[usize], // Sample size per layer
) -> Vec<Vec<Vec<usize>>> { // [layer][node][neighbors]
let mut sampled_neighborhoods = vec![Vec::new(); num_layers];
let mut current_frontier = root_nodes.to_vec();
for layer in 0..num_layers {
let mut next_frontier = Vec::new();
for &node in &current_frontier {
let neighbors = sample_neighbors(node, all_neighbors, sample_sizes[layer]);
sampled_neighborhoods[layer].push(neighbors.clone());
next_frontier.extend(neighbors);
}
current_frontier = next_frontier;
}
sampled_neighborhoods
}
6.3 Meta-Learning for Few-Shot Graph Learning
MAML (Model-Agnostic Meta-Learning) for Graphs:
pub struct GraphMAML {
model: RuvectorLayer,
meta_lr: f32,
inner_lr: f32,
inner_steps: usize,
}
impl GraphMAML {
fn meta_train_step(
&mut self,
tasks: &[GraphTask], // Multiple graph learning tasks
) -> f32 {
let mut meta_gradients = vec![0.0; self.model.num_parameters()];
for task in tasks {
// 1. Clone model for inner loop
let mut task_model = self.model.clone();
// 2. Inner loop: adapt to this task
for _ in 0..self.inner_steps {
let loss = self.compute_task_loss(&task_model, &task.support_set);
let grads = self.compute_gradients(loss);
// Update task model
task_model.update_parameters(&grads, self.inner_lr);
}
// 3. Compute loss on query set with adapted model
let query_loss = self.compute_task_loss(&task_model, &task.query_set);
let query_grads = self.compute_gradients(query_loss);
// 4. Accumulate meta-gradients
for (i, &grad) in query_grads.iter().enumerate() {
meta_gradients[i] += grad;
}
}
// 5. Meta-update
let avg_meta_grads: Vec<f32> = meta_gradients.iter()
.map(|&g| g / tasks.len() as f32)
.collect();
self.model.update_parameters(&avg_meta_grads, self.meta_lr);
// Return average meta-loss
tasks.iter()
.map(|task| self.compute_task_loss(&self.model, &task.query_set))
.sum::<f32>() / tasks.len() as f32
}
}
7. Hyperparameter Optimization
7.1 Key Hyperparameters for GNNs
Architecture:
- num_layers: 2-8 typical
- hidden_dim: 64-512
- num_heads: 1-8 (attention)
- dropout: 0.0-0.5
Training:
- learning_rate: 1e-4 to 1e-2
- temperature: 0.01-1.0 (contrastive)
- lambda_*: Loss weights (0.0-10.0)
- batch_size: 32-512
HNSW-Specific:
- M: Neighbors per layer (16-64)
- ef_construction: Search depth (100-500)
- num_hnsw_layers: Typically log(N)
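These ranges could be bundled into a single config type with mid-range defaults as a starting point for search. A sketch (the type and field names are illustrative, not from the RuVector codebase):

```rust
// Sketch: hyperparameter bundle with defaults drawn from the ranges above.
// Names are illustrative; RuVector's actual config types may differ.
pub struct GnnHyperparams {
    pub num_layers: usize,
    pub hidden_dim: usize,
    pub num_heads: usize,
    pub dropout: f32,
    pub learning_rate: f32,
    pub temperature: f32,
    pub batch_size: usize,
    pub hnsw_m: usize,          // HNSW neighbors per layer
    pub ef_construction: usize, // HNSW construction search depth
}

impl Default for GnnHyperparams {
    fn default() -> Self {
        Self {
            num_layers: 3,
            hidden_dim: 128,
            num_heads: 4,
            dropout: 0.1,
            learning_rate: 1e-3,
            temperature: 0.1,
            batch_size: 128,
            hnsw_m: 32,
            ef_construction: 200,
        }
    }
}
```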
7.2 Grid Search
pub fn grid_search_hyperparameters(
param_grid: &HashMap<String, Vec<f32>>,
validation_set: &Dataset,
) -> HashMap<String, f32> {
let mut best_params = HashMap::new();
let mut best_score = f32::NEG_INFINITY;
// Generate all combinations
let combinations = generate_combinations(param_grid);
for params in combinations {
// Train model with these params
let model = train_model_with_params(&params, validation_set);
// Evaluate
let score = evaluate_model(&model, validation_set);
if score > best_score {
best_score = score;
best_params = params.clone();
}
}
best_params
}
7.3 Bayesian Optimization
Use Gaussian Process to model hyperparameter → performance:
pub struct BayesianHyperparamOptimizer {
gp: GaussianProcess,
acquisition_fn: AcquisitionFunction,
evaluated_points: Vec<(HashMap<String, f32>, f32)>,
}
impl BayesianHyperparamOptimizer {
fn suggest_next_params(&self) -> HashMap<String, f32> {
// Maximize acquisition function (e.g., Expected Improvement)
self.acquisition_fn.maximize(&self.gp)
}
fn observe(&mut self, params: HashMap<String, f32>, score: f32) {
self.evaluated_points.push((params.clone(), score));
self.gp.update(&self.evaluated_points);
}
fn optimize(&mut self, num_iterations: usize) -> HashMap<String, f32> {
for _ in 0..num_iterations {
let params = self.suggest_next_params();
let score = train_and_evaluate(&params);
self.observe(params, score);
}
// Return best observed
self.evaluated_points.iter()
.max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
.map(|(p, _)| p.clone())
.unwrap()
}
}
8. Distributed and Efficient Training
8.1 Data Parallelism
Distribute batches across GPUs/nodes:
pub fn distributed_training(
model: &RuvectorLayer,
dataset: &GraphDataset,
num_workers: usize,
) {
// 1. Replicate model to all workers
let models: Vec<_> = (0..num_workers)
.map(|_| model.clone())
.collect();
// 2. Split data
let batches_per_worker = dataset.len() / num_workers;
// 3. Train in parallel
let gradients: Vec<_> = (0..num_workers)
.into_par_iter()
.map(|worker_id| {
let start = worker_id * batches_per_worker;
let end = start + batches_per_worker;
let local_data = &dataset[start..end];
// Local forward + backward
train_on_subset(&models[worker_id], local_data)
})
.collect();
// 4. Aggregate gradients (AllReduce)
let avg_gradients = average_gradients(&gradients);
// 5. Update global model
model.apply_gradients(&avg_gradients);
}
8.2 Model Parallelism (Large Graphs)
Partition graph across devices:
pub struct DistributedGraph {
partitions: Vec<GraphPartition>,
partition_mapping: HashMap<usize, usize>, // node_id -> partition_id
}
impl DistributedGraph {
fn forward_distributed(
&self,
node_id: usize,
models: &[RuvectorLayer],
) -> Vec<f32> {
let partition_id = self.partition_mapping[&node_id];
let partition = &self.partitions[partition_id];
// Get local neighbors
let local_neighbors = partition.get_neighbors(node_id);
// Get remote neighbors (cross-partition edges)
let remote_neighbors = partition.get_remote_neighbors(node_id);
// Fetch remote embeddings via communication
let remote_embeddings = self.fetch_remote_embeddings(&remote_neighbors);
// Combine local and remote
let all_neighbors = [&local_neighbors[..], &remote_embeddings[..]].concat();
// Forward pass on local model
models[partition_id].forward(
&partition.get_features(node_id),
&all_neighbors,
&vec![1.0; all_neighbors.len()],
)
}
}
9. Implementation Recommendations for RuVector
9.1 Immediate Enhancements (Week 1-2)
1. Hard Negative Sampling
// Modify local_contrastive_loss to use hard negatives
// (arguments follow the sample_hard_negatives_distance signature from section 2.1)
let hard_negatives = sample_hard_negatives_distance(
&embeddings,
&positive_edges,
k_hops,
&graph,
);
2. Spectral Regularization
// Add to total loss
let spectral_loss = laplacian_regularization(embeddings, edges, None);
total_loss += lambda_spectral * spectral_loss;
3. Dynamic Loss Weight Balancing
// Automatically balance loss contributions
let mut loss_config = MultiObjectiveLoss::new();
loss_config.balance_weights(&current_losses);
9.2 Short-Term (Month 1)
4. Curriculum Learning Schedule
let mut curriculum = CurriculumSchedule::new();
for epoch in 0..num_epochs {
let lambda_contrast = curriculum.get_weight("contrastive");
let lambda_recon = curriculum.get_weight("reconstruction");
// Train with scheduled weights
train_epoch(lambda_contrast, lambda_recon);
curriculum.step();
}
5. Online Learning with Replay
let mut trainer = OnlineGNNTrainer::new(model, 1000); // replay buffer size
for new_node in streaming_data {
trainer.update_on_new_node(&new_node.features, &new_node.neighbors, None);
}
9.3 Medium-Term (Quarter 1)
6. Meta-Learning for Few-Shot
let mut maml = GraphMAML::new(0.001, 0.01); // meta_lr, inner_lr
for epoch in 0..meta_epochs {
let tasks = sample_tasks(task_distribution, 5); // k_shot = 5
maml.meta_train_step(&tasks);
}
7. Distributed Training
let distributed_graph = DistributedGraph::partition(graph, 4); // num_partitions
let models = replicate_model(base_model, num_partitions);
distributed_training(&distributed_graph, &models);
10. Benchmarking and Evaluation
10.1 Loss Tracking
pub struct LossTracker {
history: HashMap<String, Vec<f32>>,
}
impl LossTracker {
fn log(&mut self, loss_name: &str, value: f32) {
self.history.entry(loss_name.to_string())
.or_insert_with(Vec::new)
.push(value);
}
fn plot_losses(&self) {
// Visualization: loss curves over training
for (name, values) in &self.history {
println!("{}: {:?}", name, values);
}
}
fn detect_overfitting(&self) -> bool {
// Compare train vs. validation loss
let train_loss = self.history.get("train").unwrap();
let val_loss = self.history.get("validation").unwrap();
// If validation loss increasing while train decreasing
*train_loss.last().unwrap() < train_loss[0]
&& *val_loss.last().unwrap() > val_loss[val_loss.len() / 2]
}
}
10.2 Metrics
1. Latent Space Quality:
- Embedding norm distribution
- Pairwise distance distribution
- Nearest neighbor recall
2. Graph Preservation:
- Link prediction AUC
- Triangle closure accuracy
- Community detection modularity
3. Downstream Performance:
- Node classification accuracy
- Graph classification accuracy
- Search quality (Recall@K)
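Of these, Recall@K is the one most directly tied to search quality; a minimal sketch computing it from a ranked result list (`recall_at_k` is a hypothetical helper, not from the RuVector codebase):

```rust
use std::collections::HashSet;

// Sketch: Recall@K — fraction of ground-truth neighbors found in the top-K results.
pub fn recall_at_k(retrieved: &[usize], ground_truth: &[usize], k: usize) -> f32 {
    if ground_truth.is_empty() {
        return 0.0;
    }
    let truth: HashSet<usize> = ground_truth.iter().copied().collect();
    // Count ground-truth IDs appearing among the first K retrieved results
    let hits = retrieved
        .iter()
        .take(k)
        .filter(|&&id| truth.contains(&id))
        .count();
    hits as f32 / ground_truth.len() as f32
}
```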
References
Papers
- Contrastive Learning: Chen et al. (2020) - SimCLR, Oord et al. (2018) - CPC
- Meta-Learning: Finn et al. (2017) - MAML
- EWC: Kirkpatrick et al. (2017) - Overcoming Catastrophic Forgetting
- Curriculum: Bengio et al. (2009) - Curriculum Learning
- Hard Negatives: Schroff et al. (2015) - FaceNet (Triplet Loss)
- Spectral: Belkin & Niyogi (2003) - Laplacian Eigenmaps
- Multi-Objective: Yu et al. (2020) - Gradient Surgery
RuVector Code
- crates/ruvector-gnn/src/training.rs - Current losses and optimizers
- crates/ruvector-gnn/src/ewc.rs - Elastic Weight Consolidation
Document Version: 1.0 Last Updated: 2025-11-30 Author: RuVector Research Team