# GNN-Guided HNSW Routing - Implementation Plan

## Overview

### Problem Statement

Current HNSW (Hierarchical Navigable Small World) graph search uses a greedy routing strategy that selects the nearest neighbor at each step. This approach is locally optimal but often misses globally better paths, resulting in:

- Suboptimal query performance (increased distance computations)
- Redundant edge traversals in dense regions
- Poor scaling with graph size (20-40% performance degradation at 10M+ vectors)
- Inability to learn from query patterns

### Proposed Solution

Replace greedy HNSW routing with a learned GNN-based routing policy that:

1. **Path Learning**: Train on successful search trajectories to learn optimal routing decisions
2. **Context-Aware Selection**: Use graph structure + query context to predict best next hops
3. **Multi-Hop Reasoning**: Consider k-step lookahead instead of greedy single-step selection
4. **Adaptive Routing**: Adjust routing strategy based on query characteristics

The GNN will output edge selection probabilities for each node during search, replacing the greedy nearest-neighbor heuristic.

### Expected Benefits

**Quantified Performance Improvements:**
- **+25% QPS** (Queries Per Second) through reduced search iterations
- **-30% distance computations** via smarter edge selection
- **-15% average hop count** to reach target nodes
- **+18% recall@10** for challenging queries (edge cases, dense clusters)

**Qualitative Benefits:**
- Learns from query distribution patterns
- Adapts to graph topology changes
- Handles multi-modal embeddings better
- Reduces tail latencies (P99 improvement)

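The k-step lookahead idea in point 3 above can be illustrated with a toy sketch. The graph, distance values, and function names below are illustrative only, not part of the planned API; a real implementation would let the GNN predict these scores rather than enumerate paths:

```python
# Illustrative k-step lookahead: instead of greedily taking the single
# nearest neighbor, score each neighbor by the best query distance
# reachable from it within k more hops. Graph and distances are toy values.

def lookahead_score(graph, dist, node, k):
    """Best (smallest) query distance reachable from `node` in <= k hops."""
    best = dist[node]
    if k == 0:
        return best
    for nbr in graph[node]:
        best = min(best, lookahead_score(graph, dist, nbr, k - 1))
    return best

def pick_next_hop(graph, dist, node, k):
    """Choose the neighbor whose k-step lookahead score is lowest."""
    return min(graph[node], key=lambda n: lookahead_score(graph, dist, n, k - 1))

graph = {"a": ["b", "c"], "b": ["d"], "c": ["e"], "d": [], "e": []}
dist = {"a": 5.0, "b": 4.0, "c": 4.5, "d": 3.9, "e": 1.0}

# Greedy (k=1) prefers b (4.0 < 4.5); 2-step lookahead prefers c,
# because c leads on to e with distance 1.0.
print(pick_next_hop(graph, dist, "a", 1))  # b
print(pick_next_hop(graph, dist, "a", 2))  # c
```

Naive lookahead is exponential in k; the point of the GNN is to amortize this cost by learning to predict which edges lead to good regions.
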
## Technical Design

### Architecture Diagram (ASCII Art)

```
┌─────────────────────────────────────────────────────────────────┐
│                     GNN-Guided HNSW Search                      │
└─────────────────────────────────────────────────────────────────┘

                         Query Vector (q)
                                 │
                                 ▼
┌─────────────────────────────────────────────────────────────────┐
│         Entry Point Selection (standard HNSW top layer)         │
└─────────────────────────────────────────────────────────────────┘
                                 │
                                 ▼
┌─────────────────────────────────────────────────────────────────┐
│                     Layer L → 0 Search Loop                     │
│  ┌───────────────────────────────────────────────────────────┐  │
│  │ Current Node (c)                                          │  │
│  │       │                                                   │  │
│  │       ▼                                                   │  │
│  │  ┌──────────────────────────────────────────────┐         │  │
│  │  │ GNN Edge Scorer                              │         │  │
│  │  │   Input: [node_feat, query, edge_feat]       │         │  │
│  │  │   Graph Context: k-hop neighborhood          │         │  │
│  │  │   Attention Layer: Multi-head GAT            │         │  │
│  │  │   Output: Edge selection probabilities       │         │  │
│  │  └──────────────────────────────────────────────┘         │  │
│  │       │                                                   │  │
│  │       ▼                                                   │  │
│  │  ┌──────────────────────────────────────────────┐         │  │
│  │  │ Edge Selection Strategy                      │         │  │
│  │  │   - Top-k by GNN score                       │         │  │
│  │  │   - Temperature-based sampling (exploration) │         │  │
│  │  │   - Hybrid: GNN score * distance heuristic   │         │  │
│  │  └──────────────────────────────────────────────┘         │  │
│  │       │                                                   │  │
│  │       ▼                                                   │  │
│  │  Candidate Neighbors (N)                                  │  │
│  │       │                                                   │  │
│  │       ▼                                                   │  │
│  │  Update best candidates, move to next node                │  │
│  └───────────────────────────────────────────────────────────┘  │
└─────────────────────────────────────────────────────────────────┘
                                 │
                                 ▼
                       Return Top-K Results

┌─────────────────────────────────────────────────────────────────┐
│                        Training Pipeline                        │
└─────────────────────────────────────────────────────────────────┘

                          Query Workload
                                 │
                                 ▼
┌─────────────────────────────────────────────────────────────────┐
│ Path Collector (on standard HNSW)                               │
│  - Record: (query, node_seq, edges_taken, final_results)        │
│  - Label: edges_on_optimal_path = 1, others = 0                 │
└─────────────────────────────────────────────────────────────────┘
                                 │
                                 ▼
┌─────────────────────────────────────────────────────────────────┐
│ Offline Training (PyTorch/candle)                               │
│  - Loss: BCE(GNN_edge_score, optimal_edge_label)                │
│  - Optimizer: AdamW with lr=1e-3                                │
│  - Batch: 256 query trajectories                                │
└─────────────────────────────────────────────────────────────────┘
                                 │
                                 ▼
          Export ONNX Model → Load in ruvector-gnn (Rust)
```

### Core Data Structures (Rust)

```rust
// File: crates/ruvector-gnn/src/routing/mod.rs

use ndarray::{Array1, Array2};
use ort::{Session, Value};

/// GNN-guided routing policy for HNSW search
pub struct GnnRoutingPolicy {
    /// ONNX runtime session for GNN inference
    session: Session,

    /// Feature extractor for nodes and edges
    feature_extractor: FeatureExtractor,

    /// Configuration for routing behavior
    config: RoutingConfig,

    /// Performance metrics
    metrics: RoutingMetrics,
}

/// Configuration for GNN routing
#[derive(Debug, Clone)]
pub struct RoutingConfig {
    /// Number of top edges to consider per node
    pub top_k_edges: usize,

    /// Temperature for edge selection sampling (0.0 = greedy)
    pub temperature: f32,

    /// Hybrid weight: α * gnn_score + (1-α) * distance_score
    pub hybrid_alpha: f32,

    /// Maximum GNN inference batch size
    pub inference_batch_size: usize,

    /// Enable/disable GNN routing (fallback to greedy)
    pub enabled: bool,

    /// K-hop neighborhood size for graph context
    pub context_hops: usize,
}

/// Feature extraction for nodes and edges
pub struct FeatureExtractor {
    /// Dimensionality of node features
    node_dim: usize,

    /// Dimensionality of edge features
    edge_dim: usize,

    /// Cache for computed features
    cache: FeatureCache,
}

/// Features for a single node in the graph
#[derive(Debug, Clone)]
pub struct NodeFeatures {
    /// Node embedding vector
    pub embedding: Array1<f32>,

    /// Degree (number of neighbors)
    pub degree: usize,

    /// Layer in HNSW hierarchy
    pub layer: usize,

    /// Clustering coefficient
    pub clustering_coef: f32,

    /// Distance to query (dynamic)
    pub query_distance: f32,
}

/// Features for an edge in the graph
#[derive(Debug, Clone)]
pub struct EdgeFeatures {
    /// Euclidean distance between connected nodes
    pub distance: f32,

    /// Angular similarity (cosine)
    pub angular_similarity: f32,

    /// Edge betweenness (precomputed)
    pub betweenness: f32,

    /// Whether edge crosses layers
    pub cross_layer: bool,
}

/// GNN inference result for edge selection
#[derive(Debug)]
pub struct EdgeScore {
    /// Target node ID
    pub target_node: u32,

    /// GNN-predicted score [0, 1]
    pub gnn_score: f32,

    /// Distance-based heuristic score
    pub distance_score: f32,

    /// Final combined score
    pub combined_score: f32,
}

/// Performance tracking for routing
#[derive(Debug, Default)]
pub struct RoutingMetrics {
    /// Total GNN inference calls
    pub total_inferences: u64,

    /// Average inference latency (microseconds)
    pub avg_inference_us: f64,

    /// Total distance computations
    pub distance_computations: u64,

    /// Average hops per query
    pub avg_hops: f64,

    /// Queries that fell back to greedy routing on GNN errors
    pub fallback_count: u64,

    /// Cache hit rate for features
    pub feature_cache_hit_rate: f64,
}

/// Training data collection for offline learning
pub struct PathTrajectory {
    /// Query vector
    pub query: Vec<f32>,

    /// Sequence of nodes visited
    pub node_sequence: Vec<u32>,

    /// Edges taken at each step
    pub edges_taken: Vec<(u32, u32)>,

    /// All candidate edges at each step
    pub candidate_edges: Vec<Vec<(u32, u32)>>,

    /// Final k-NN results
    pub results: Vec<(u32, f32)>,
}
```

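The `combined_score` field in `EdgeScore` mixes the learned score with the distance heuristic via `hybrid_alpha`. A minimal sketch of that combination follows; the `1 / (1 + d)` mapping from distance to a [0, 1] score is an assumption of this sketch, not a settled design decision:

```python
# Sketch of the hybrid score: alpha * gnn_score + (1 - alpha) * distance_score.
# The distance -> [0, 1] normalization via 1 / (1 + d) is an assumption here.

def combined_score(gnn_score: float, distance: float, alpha: float) -> float:
    distance_score = 1.0 / (1.0 + distance)   # closer -> higher score
    return alpha * gnn_score + (1.0 - alpha) * distance_score

# alpha = 1.0 trusts the GNN alone; alpha = 0.0 reduces to the plain
# distance heuristic, i.e. the greedy ordering.
print(combined_score(0.9, 4.0, 0.8))
```

With `hybrid_alpha` exposed in `RoutingConfig`, this gives a continuous dial between pure greedy routing and pure learned routing.
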
### Key Algorithms (Pseudocode)

#### 1. GNN-Guided Search Algorithm

```python
function gnn_guided_search(query: Vector, graph: HNSWGraph, k: int) -> List[Result]:
    """
    HNSW search with GNN-guided routing instead of greedy selection.
    """
    # Initialize from top layer entry point
    current_nodes = {graph.entry_point}
    layer = graph.max_layer

    # Descend through layers
    while layer >= 0:
        # Find best candidates at this layer using GNN
        candidates = priority_queue()
        visited = set()

        for node in current_nodes:
            # Get neighbors at this layer
            neighbors = graph.get_neighbors(node, layer)

            # Extract features for GNN
            node_features = extract_node_features(node, query, graph)
            edge_features = [extract_edge_features(node, neighbor, graph)
                             for neighbor in neighbors]

            # GNN inference: score all edges from current node
            edge_scores = gnn_model.score_edges(
                node_features,
                edge_features,
                query
            )

            # Select edges based on GNN scores (not greedy distance)
            selected = select_edges_by_gnn_score(
                neighbors,
                edge_scores,
                config.top_k_edges,
                config.temperature
            )

            for neighbor in selected:
                if neighbor not in visited:
                    distance = compute_distance(query, graph.get_vector(neighbor))
                    candidates.push(neighbor, distance)
                    visited.add(neighbor)

        # Move to best candidates for next iteration
        current_nodes = candidates.top(config.beam_width)
        layer -= 1

    # Return top-k results from layer 0
    return candidates.top(k)


function select_edges_by_gnn_score(neighbors, scores, top_k, temperature):
    """
    Select edges based on GNN scores with optional exploration.

    Strategies:
    - temperature = 0: greedy top-k
    - temperature > 0: sampling from softmax distribution
    - hybrid mode: combine GNN score with distance heuristic
    """
    if temperature == 0:
        # Greedy: select top-k by GNN score
        return top_k_by_score(neighbors, scores)
    else:
        # Sampling: use temperature-scaled softmax
        probs = softmax(scores / temperature)
        return sample_without_replacement(neighbors, probs, top_k)


function extract_node_features(node, query, graph):
    """
    Extract node-level features for GNN input.
    """
    return NodeFeatures(
        embedding=graph.get_vector(node),
        degree=graph.get_degree(node),
        layer=graph.get_layer(node),
        clustering_coef=graph.get_clustering_coefficient(node),
        query_distance=distance(query, graph.get_vector(node))
    )


function extract_edge_features(source, target, graph):
    """
    Extract edge-level features for GNN input.
    """
    source_vec = graph.get_vector(source)
    target_vec = graph.get_vector(target)

    return EdgeFeatures(
        distance=euclidean_distance(source_vec, target_vec),
        angular_similarity=cosine_similarity(source_vec, target_vec),
        betweenness=graph.get_edge_betweenness(source, target),
        cross_layer=(graph.get_layer(source) != graph.get_layer(target))
    )
```

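The selection step of `select_edges_by_gnn_score` can be made concrete in runnable form. This is a sketch, assuming softmax sampling without replacement via repeated weighted draws; the helper names in the pseudocode (`top_k_by_score`, `sample_without_replacement`) are folded in here:

```python
import math
import random

# Runnable sketch of edge selection: temperature == 0 is greedy top-k by
# GNN score; otherwise sample k edges without replacement from a
# temperature-scaled softmax over the scores.

def select_edges(neighbors, scores, top_k, temperature, rng=random):
    if temperature == 0:
        ranked = sorted(zip(neighbors, scores), key=lambda p: p[1], reverse=True)
        return [n for n, _ in ranked[:top_k]]
    chosen = []
    pool = list(zip(neighbors, scores))
    for _ in range(min(top_k, len(pool))):
        m = max(s for _, s in pool)                       # stabilize the softmax
        weights = [math.exp((s - m) / temperature) for _, s in pool]
        pick = rng.choices(range(len(pool)), weights=weights)[0]
        chosen.append(pool.pop(pick)[0])                  # draw without replacement
    return chosen

neighbors = [101, 102, 103, 104]
scores = [0.9, 0.1, 0.7, 0.3]
print(select_edges(neighbors, scores, 2, 0.0))  # greedy: [101, 103]
```

A small positive temperature keeps some probability mass on lower-scored edges, which is the exploration behavior the plan wants during training-data collection.
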
#### 2. Offline Training Pipeline

```python
function collect_training_data(graph, query_workload, n_samples):
    """
    Collect path trajectories from standard HNSW for training.
    """
    trajectories = []

    for query in query_workload.sample(n_samples):
        # Run standard greedy HNSW search with full logging
        path = instrumented_hnsw_search(query, graph)

        # Label edges: 1 if on optimal path, 0 otherwise
        optimal_edges = set(path.edges_taken)

        # For each node in path, get all candidate edges
        for step in path.node_sequence:
            node = step.node
            neighbors = graph.get_neighbors(node, step.layer)

            # Create training examples
            for neighbor in neighbors:
                edge = (node, neighbor)
                label = 1.0 if edge in optimal_edges else 0.0

                node_feat = extract_node_features(node, query, graph)
                edge_feat = extract_edge_features(node, neighbor, graph)

                trajectories.append({
                    'node_features': node_feat,
                    'edge_features': edge_feat,
                    'query': query,
                    'label': label,
                    'distance_to_query': distance(query, graph.get_vector(neighbor))
                })

    return trajectories


function train_gnn_routing_model(trajectories, config):
    """
    Train GNN model to predict edge selection probabilities.

    Architecture: Graph Attention Network (GAT)
    - 3 attention layers with 4 heads each
    - Hidden dim: 128
    - Edge features concatenated with node features
    - Output: single logit per edge (probability of selection)
    """
    model = GAT(
        node_dim=config.node_feature_dim,
        edge_dim=config.edge_feature_dim,
        hidden_dim=128,
        num_layers=3,
        num_heads=4,
        output_dim=1
    )

    optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    loss_fn = BCEWithLogitsLoss()

    for epoch in range(config.num_epochs):
        for batch in DataLoader(trajectories, batch_size=256):
            # Forward pass
            edge_logits = model(
                batch.node_features,
                batch.edge_features,
                batch.query
            )

            # Binary cross-entropy loss
            loss = loss_fn(edge_logits, batch.labels)

            # Add distance-aware regularization:
            # encourage the model to respect the distance heuristic
            distance_scores = 1.0 / (1.0 + batch.distance_to_query)
            consistency_loss = mse_loss(sigmoid(edge_logits), distance_scores)

            total_loss = loss + 0.1 * consistency_loss

            # Backward pass
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

    return model


function export_to_onnx(model, output_path):
    """
    Export trained PyTorch model to ONNX for Rust inference.
    """
    dummy_input = {
        'node_features': torch.randn(1, node_dim),
        'edge_features': torch.randn(10, edge_dim),  # up to 10 neighbors
        'query': torch.randn(1, embedding_dim)
    }

    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        input_names=['node_features', 'edge_features', 'query'],
        output_names=['edge_scores'],
        dynamic_axes={
            'edge_features': {0: 'num_edges'},
            'edge_scores': {0: 'num_edges'}
        },
        opset_version=14
    )
```

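The combined loss can be checked on a single toy edge. The numbers below are illustrative; the sketch just mirrors the BCE-plus-consistency formula above for one logit, in plain Python rather than PyTorch:

```python
import math

# Worked toy example of the training loss: BCE on one edge logit plus the
# 0.1-weighted consistency term tying sigmoid(logit) to the distance
# heuristic 1 / (1 + d). All inputs are illustrative values.

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def bce_with_logits(logit, label):
    p = sigmoid(logit)
    return -(label * math.log(p) + (1 - label) * math.log(1 - p))

def total_loss(logit, label, distance_to_query, consistency_weight=0.1):
    distance_score = 1.0 / (1.0 + distance_to_query)
    consistency = (sigmoid(logit) - distance_score) ** 2   # squared error, one edge
    return bce_with_logits(logit, label) + consistency_weight * consistency

# A confident positive logit on an on-path edge (label 1) yields a small
# loss; the same logit with label 0 is penalized hard.
print(total_loss(3.0, 1.0, 0.5))
print(total_loss(3.0, 0.0, 0.5))
```

The consistency term keeps the learned scores from drifting arbitrarily far from the distance heuristic, which also makes the hybrid scoring mode better behaved.
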
### API Design (Function Signatures)

```rust
// File: crates/ruvector-gnn/src/routing/mod.rs

impl GnnRoutingPolicy {
    /// Create a new GNN routing policy from an ONNX model file
    pub fn from_onnx(
        model_path: impl AsRef<Path>,
        config: RoutingConfig,
    ) -> Result<Self, GnnError>;

    /// Score edges from a given node during HNSW search
    ///
    /// # Arguments
    /// * `current_node` - The node we're currently at
    /// * `candidate_neighbors` - Potential next hops
    /// * `query` - The query vector
    /// * `graph` - Reference to HNSW graph for feature extraction
    ///
    /// # Returns
    /// Vector of `EdgeScore` sorted by combined_score (descending)
    pub fn score_edges(
        &mut self,
        current_node: u32,
        candidate_neighbors: &[u32],
        query: &[f32],
        graph: &HnswGraph,
    ) -> Result<Vec<EdgeScore>, GnnError>;

    /// Select top-k edges based on GNN scores
    pub fn select_top_k(
        &self,
        edge_scores: &[EdgeScore],
        k: usize,
    ) -> Vec<u32>;

    /// Get current routing metrics
    pub fn metrics(&self) -> &RoutingMetrics;

    /// Reset metrics counters
    pub fn reset_metrics(&mut self);

    /// Update configuration at runtime
    pub fn update_config(&mut self, config: RoutingConfig);
}

impl FeatureExtractor {
    /// Create a new feature extractor
    pub fn new(node_dim: usize, edge_dim: usize) -> Self;

    /// Extract node features for GNN input
    pub fn extract_node_features(
        &self,
        node_id: u32,
        query: &[f32],
        graph: &HnswGraph,
    ) -> Result<NodeFeatures, GnnError>;

    /// Extract edge features for GNN input
    pub fn extract_edge_features(
        &self,
        source: u32,
        target: u32,
        graph: &HnswGraph,
    ) -> Result<EdgeFeatures, GnnError>;

    /// Batch extract features for multiple edges
    pub fn batch_extract_edge_features(
        &self,
        edges: &[(u32, u32)],
        graph: &HnswGraph,
    ) -> Result<Vec<EdgeFeatures>, GnnError>;

    /// Clear feature cache
    pub fn clear_cache(&mut self);
}

// Integration with existing HNSW implementation
// File: crates/ruvector-core/src/index/hnsw.rs

impl HnswIndex {
    /// Enable GNN-guided routing
    pub fn set_gnn_routing(
        &mut self,
        policy: GnnRoutingPolicy,
    ) -> Result<(), HnswError>;

    /// Disable GNN routing (fallback to greedy)
    pub fn disable_gnn_routing(&mut self);

    /// Get routing performance metrics
    pub fn routing_metrics(&self) -> Option<&RoutingMetrics>;
}

// Training utilities
// File: crates/ruvector-gnn/src/routing/training.rs

/// Collect path trajectories from HNSW search for training
pub fn collect_training_trajectories(
    graph: &HnswGraph,
    queries: &[Vec<f32>],
    output_path: impl AsRef<Path>,
) -> Result<usize, GnnError>;

/// Validate ONNX model compatibility
pub fn validate_onnx_model(
    model_path: impl AsRef<Path>,
) -> Result<ModelInfo, GnnError>;

#[derive(Debug)]
pub struct ModelInfo {
    pub input_dims: Vec<(String, Vec<i64>)>,
    pub output_dims: Vec<(String, Vec<i64>)>,
    pub opset_version: i64,
}
```

## Integration Points

### Affected Crates/Modules

1. **`ruvector-gnn`** (Primary)
   - New module: `src/routing/mod.rs` - GNN routing policy
   - New module: `src/routing/features.rs` - Feature extraction
   - New module: `src/routing/training.rs` - Training utilities
   - Modified: `src/lib.rs` - Export routing types

2. **`ruvector-core`** (Integration)
   - Modified: `src/index/hnsw.rs` - Integrate GNN routing into search
   - Modified: `src/index/mod.rs` - Add routing configuration
   - New: `src/index/hnsw_gnn.rs` - GNN-specific HNSW extensions

3. **`ruvector-api`** (Configuration)
   - Modified: `src/config.rs` - Add GNN routing config options
   - Modified: `src/index_manager.rs` - Support GNN model loading

4. **`ruvector-bindings`** (Exposure)
   - Modified: `python/src/lib.rs` - Expose routing config to Python
   - Modified: `nodejs/src/lib.rs` - Expose routing config to Node.js

### New Modules to Create

```
crates/ruvector-gnn/
├── src/
│   ├── routing/
│   │   ├── mod.rs            # Core routing policy
│   │   ├── features.rs       # Feature extraction
│   │   ├── training.rs       # Training data collection
│   │   ├── cache.rs          # Feature caching
│   │   └── metrics.rs        # Performance tracking
│   └── models/
│       └── routing_gnn.onnx  # Pre-trained model (optional)

examples/
├── gnn_routing/
│   ├── train_routing_model.py  # Python training script
│   ├── evaluate_routing.rs     # Rust evaluation benchmark
│   └── README.md               # Usage guide
```

### Dependencies on Other Features

**Independent** - Can be implemented standalone

**Synergies with:**
- **Incremental Graph Learning (Feature 2)**: Cached node features can be reused
- **Neuro-Symbolic Query (Feature 3)**: GNN routing can incorporate symbolic constraints
- **Existing Attention Mechanisms**: Reuse attention layers from Issue #38

**External Dependencies:**
- `ort` (ONNX Runtime) - Already in use for GNN inference
- `ndarray` - Already in use for tensor operations
- `parking_lot` - For feature cache concurrency

## Regression Prevention

### What Existing Functionality Could Break

1. **HNSW Search Correctness**
   - Risk: GNN routing might skip true nearest neighbors
   - Impact: Degraded recall, incorrect results

2. **Performance Degradation**
   - Risk: GNN inference overhead exceeds routing savings
   - Impact: Lower QPS than baseline greedy search

3. **Memory Usage**
   - Risk: Feature caching and the GNN model consume excessive RAM
   - Impact: OOM on large graphs

4. **Thread Safety**
   - Risk: Feature cache race conditions in concurrent queries
   - Impact: Corrupted features, crashes

5. **Build/Deployment**
   - Risk: ONNX model path resolution failures
   - Impact: Runtime errors, inability to use the feature

### Test Cases to Prevent Regressions

```rust
// File: crates/ruvector-gnn/tests/routing_regression_tests.rs

#[test]
fn test_gnn_routing_recall_matches_greedy() {
    // GNN routing must achieve ≥95% of greedy baseline recall
    let mut graph = build_test_hnsw(10_000, 512);
    let queries = generate_test_queries(1000);

    // Baseline: greedy search
    graph.disable_gnn_routing();
    let greedy_results = run_search_batch(&graph, &queries, 10);

    // GNN routing
    graph.set_gnn_routing(load_test_model()).unwrap();
    let gnn_results = run_search_batch(&graph, &queries, 10);

    let recall = compute_recall(&greedy_results, &gnn_results);
    assert!(recall >= 0.95, "GNN recall: {}, expected ≥0.95", recall);
}

#[test]
fn test_gnn_routing_performance_improvement() {
    // GNN routing must achieve ≥10% QPS improvement
    let mut graph = build_test_hnsw(100_000, 512);
    let queries = generate_test_queries(10_000);

    // Baseline
    graph.disable_gnn_routing();
    let greedy_qps = benchmark_qps(&graph, &queries);

    // GNN
    graph.set_gnn_routing(load_test_model()).unwrap();
    let gnn_qps = benchmark_qps(&graph, &queries);

    let improvement = (gnn_qps - greedy_qps) / greedy_qps;
    assert!(improvement >= 0.10, "QPS improvement: {:.2}%, expected ≥10%", improvement * 100.0);
}

#[test]
fn test_gnn_routing_distance_computation_reduction() {
    // Must reduce distance computations by ≥20%
    let mut graph = build_test_hnsw(50_000, 512);
    let queries = generate_test_queries(1000);

    graph.disable_gnn_routing();
    graph.reset_metrics();
    run_search_batch(&graph, &queries, 10);
    let greedy_dists = graph.metrics().distance_computations;

    graph.set_gnn_routing(load_test_model()).unwrap();
    graph.reset_metrics();
    run_search_batch(&graph, &queries, 10);
    let gnn_dists = graph.metrics().distance_computations;

    // Cast before subtracting: the counters are unsigned, and GNN routing is
    // not guaranteed to use fewer computations on every run
    let reduction = (greedy_dists as f64 - gnn_dists as f64) / greedy_dists as f64;
    assert!(reduction >= 0.20, "Distance reduction: {:.2}%, expected ≥20%", reduction * 100.0);
}

#[test]
fn test_feature_cache_thread_safety() {
    // Concurrent queries must not corrupt the feature cache
    let mut graph = build_test_hnsw(10_000, 512);
    graph.set_gnn_routing(load_test_model()).unwrap();
    let graph = Arc::new(graph);

    let handles: Vec<_> = (0..16)
        .map(|_| {
            let g = Arc::clone(&graph);
            thread::spawn(move || {
                let queries = generate_test_queries(100);
                run_search_batch(&g, &queries, 10)
            })
        })
        .collect();

    let results: Vec<_> = handles.into_iter()
        .map(|h| h.join().unwrap())
        .collect();

    // All results should be valid (no panics/corruptions)
    for result_set in results {
        assert!(validate_results(&result_set));
    }
}

#[test]
fn test_graceful_fallback_on_gnn_error() {
    // If the GNN fails, search must fall back to greedy without crashing
    let mut graph = build_test_hnsw(1000, 512);

    // Inject a faulty GNN model
    graph.set_gnn_routing(create_faulty_model()).unwrap();

    let queries = generate_test_queries(100);
    let results = run_search_batch(&graph, &queries, 10);

    // Should get valid results (from the fallback)
    assert_eq!(results.len(), 100);
    assert!(graph.routing_metrics().unwrap().fallback_count > 0);
}
```

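The `compute_recall` helper used in the tests above is not defined in this plan. A plausible sketch, assuming recall is measured as set overlap against the greedy baseline averaged over queries (the IDs below are illustrative):

```python
# Sketch of compute_recall: fraction of the greedy baseline's neighbors
# that GNN routing also returned, averaged over all queries.

def compute_recall(baseline_results, gnn_results):
    total = 0.0
    for base, gnn in zip(baseline_results, gnn_results):
        base_set, gnn_set = set(base), set(gnn)
        total += len(base_set & gnn_set) / len(base_set)
    return total / len(baseline_results)

baseline = [[1, 2, 3, 4], [10, 11, 12, 13]]
gnn      = [[1, 2, 3, 9], [10, 11, 12, 13]]
print(compute_recall(baseline, gnn))  # 0.875
```

Measuring against the greedy baseline (rather than brute-force ground truth) is exactly what the ≥95%-of-baseline regression gate needs; the oracle-optimal comparison in Success Metrics would use brute-force results instead.
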
### Backward Compatibility Strategy

1. **Default Disabled**
   - GNN routing is opt-in via configuration
   - Existing deployments are unaffected unless it is explicitly enabled

2. **Configuration Migration**

   ```yaml
   # Old config (still works)
   hnsw:
     ef_construction: 200
     M: 16

   # New config (optional)
   hnsw:
     ef_construction: 200
     M: 16
     gnn_routing:
       enabled: false  # Default: disabled
       model_path: "./models/routing_gnn.onnx"
       top_k_edges: 5
       temperature: 0.0
       hybrid_alpha: 0.8
   ```

3. **Feature Flags**

   ```rust
   #[cfg(feature = "gnn-routing")]
   pub mod routing;
   ```

   - Can be compiled out if not needed
   - Reduces binary size and dependencies

4. **Versioned Model Format**
   - ONNX models include version metadata
   - Runtime checks for compatibility
   - Graceful degradation on version mismatch

## Implementation Phases

### Phase 1: Core Implementation (Week 1-2)

**Goal**: Working GNN routing with ONNX inference

**Tasks**:
1. Implement `FeatureExtractor` for nodes and edges
2. Implement `GnnRoutingPolicy` with ONNX runtime
3. Add basic edge scoring logic
4. Unit tests for feature extraction
5. Unit tests for ONNX inference

**Deliverables**:
- `ruvector-gnn/src/routing/mod.rs`
- `ruvector-gnn/src/routing/features.rs`
- Passing unit tests
- Example ONNX model (mock, not trained)

**Success Criteria**:
- GNN can score edges without crashing
- Feature extraction produces valid tensors
- ONNX model loads and runs inference

### Phase 2: Integration (Week 2-3)

**Goal**: Integrate GNN routing into HNSW search

**Tasks**:
1. Modify `HnswIndex` to support GNN routing
2. Implement routing selection strategies (greedy, sampling, hybrid)
3. Add performance metrics tracking
4. Add feature caching for performance
5. Integration tests with real HNSW graphs

**Deliverables**:
- Modified `ruvector-core/src/index/hnsw.rs`
- Working end-to-end search with GNN
- Performance benchmarks vs baseline
- Feature cache implementation

**Success Criteria**:
- GNN routing produces correct k-NN results
- No crashes or panics in concurrent scenarios
- Metrics collection working

### Phase 3: Optimization (Week 3-4)

**Goal**: Achieve +25% QPS and -30% distance computations

**Tasks**:
1. Profile GNN inference overhead
2. Optimize feature extraction (batching, caching)
3. Tune `hybrid_alpha` and `temperature` parameters
4. Implement batch inference for multiple edges
5. Add SIMD optimizations where applicable
6. Train an actual GNN model on a real query workload

**Deliverables**:
- Trained ONNX model with documented performance
- Python training script (`examples/gnn_routing/train_routing_model.py`)
- Performance tuning guide
- Optimized feature cache

**Success Criteria**:
- +25% QPS improvement on the benchmark dataset
- -30% reduction in distance computations
- <2ms average GNN inference latency per query
- >80% feature cache hit rate

### Phase 4: Production Hardening (Week 4-5)

**Goal**: Production-ready feature with safety guarantees

**Tasks**:
1. Add comprehensive error handling
2. Implement graceful fallback to greedy on GNN errors
3. Add configuration validation
4. Write regression tests
5. Write documentation and examples
6. Add telemetry/observability hooks
7. Run performance benchmarks on large-scale datasets (10M+ vectors)

**Deliverables**:
- Full regression test suite
- User documentation
- Performance benchmark report
- Example configurations
- Migration guide

**Success Criteria**:
- All regression tests passing
- Zero crashes in stress tests
- Documentation complete
- Ready for alpha release

## Success Metrics

### Performance Benchmarks

**Primary Metrics** (Must Achieve):

| Metric | Baseline (Greedy) | Target (GNN) | Measurement |
|--------|-------------------|--------------|-------------|
| QPS (1M vectors) | 10,000 | 12,500 (+25%) | queries/second @ 16 threads |
| Distance Computations | 150/query | 105/query (-30%) | average per query |
| Average Hops | 12.5 | 10.6 (-15%) | hops to reach target |
| P99 Latency | 15ms | 12ms (-20%) | 99th percentile query time |

**Secondary Metrics** (Nice to Have):

| Metric | Baseline | Target | Measurement |
|--------|----------|--------|-------------|
| Feature Cache Hit Rate | N/A | >80% | cache hits / total accesses |
| GNN Inference Time | N/A | <2ms | average per query |
| Memory Overhead | N/A | <5% | additional RAM for GNN + cache |
| Recall@10 | 0.95 | 0.96 (+1pp) | fraction of true neighbors found |

### Accuracy Metrics

**Recall Preservation**:
- GNN routing must achieve ≥95% of greedy baseline recall
- No degradation on edge-case queries (dense clusters, outliers)

**Path Optimality**:
- GNN paths should be ≤5% longer than oracle-optimal paths
- Measured by comparing against brute-force ground truth

**Failure Rate**:
- Graceful fallback to greedy on <1% of queries
- Zero crashes or incorrect results

### Memory/Latency Targets

**Memory**:
- GNN model size: <50MB (ONNX file)
- Feature cache: <100MB per 1M vectors
- Total overhead: <5% of base HNSW index size

**Latency**:
- GNN inference: <2ms average, <5ms P99
- Feature extraction: <0.5ms per node
- Total query latency: ≤12ms P99 (vs 15ms greedy baseline)

**Throughput**:
- Concurrent queries: 16+ threads with near-linear scaling
- Batch inference: 10+ edges per batch for efficiency

## Risks and Mitigations
|
||
|
||
### Technical Risks
|
||
|
||
**Risk 1: GNN Inference Overhead Exceeds Routing Savings**
|
||
|
||
*Probability: Medium | Impact: High*
|
||
|
||
**Description**: If GNN model is too complex, inference time could negate benefits of reduced hops.
|
||
|
||
**Mitigation**:
|
||
- Profile GNN inference early in Phase 1
|
||
- Set hard latency budget (<2ms per query)
|
||
- Use lightweight GNN architecture (3-layer GAT, not deep networks)
|
||
- Batch inference across multiple edges
|
||
- Implement feature caching to avoid recomputation
|
||
- Add fallback to greedy if inference exceeds budget
|
||
|
||
**Contingency**: If overhead too high, switch to simpler models (MLP instead of GNN) or hybrid mode (GNN only for hard queries).
---

**Risk 2: Training Data Scarcity**

*Probability: Medium | Impact: Medium*

**Description**: There may not be enough diverse queries to train a robust GNN model.

**Mitigation**:
- Use query augmentation (add noise, rotations)
- Pretrain on synthetic queries (random vectors)
- Fine-tune on the actual workload
- Support transfer learning from similar datasets
- Provide a pre-trained baseline model

**Contingency**: Start with simple heuristic-based routing (e.g., distance + degree) and upgrade to GNN later.
---

**Risk 3: Model Generalization Failures**

*Probability: Low | Impact: High*

**Description**: A GNN trained on one dataset might not generalize to different embedding distributions.

**Mitigation**:
- Train on diverse datasets (text, images, multi-modal)
- Use domain-agnostic features (degree, distance, structure)
- Add online learning to adapt to new query patterns
- Provide model retraining tools
- Evaluate extensively on held-out datasets

**Contingency**: Support per-index model training for critical use cases.
---

**Risk 4: Feature Cache Memory Bloat**

*Probability: Low | Impact: Medium*

**Description**: Caching node/edge features could consume excessive memory on large graphs.

**Mitigation**:
- Use an LRU eviction policy (keep only recently used features)
- Set cache size limits (e.g., max 100MB)
- Make caching optional (can be disabled in low-memory environments)
- Use compressed feature representations
- Profile memory usage in Phase 3

**Contingency**: Disable feature caching by default; enable it only for latency-critical workloads.
---

**Risk 5: ONNX Compatibility Issues**

*Probability: Low | Impact: Medium*

**Description**: The ONNX runtime might not support specific GNN operations or might have platform-specific issues.

**Mitigation**:
- Use only standard ONNX ops (opset 14+)
- Test on multiple platforms (Linux, macOS, Windows)
- Provide a model validation tool to check compatibility
- Fall back to pure Rust inference if ONNX is unavailable

**Contingency**: Implement lightweight Rust-native GNN inference as a fallback.
---

**Risk 6: Regression in Recall**

*Probability: Medium | Impact: Critical*

**Description**: GNN routing might skip true nearest neighbors, degrading result quality.

**Mitigation**:
- Extensive recall testing in Phase 2
- Set a minimum recall threshold (≥95% of baseline)
- Add recall monitoring in production
- Use hybrid mode (GNN + distance heuristic) for safety
- Maintain a comprehensive regression test suite

**Contingency**: If recall drops, increase `hybrid_alpha` to rely more on the distance heuristic, or disable GNN routing entirely.
---

### Summary Risk Matrix

| Risk | Probability | Impact | Mitigation Priority |
|------|-------------|--------|---------------------|
| GNN inference overhead | Medium | High | **HIGH** - Profile early |
| Training data scarcity | Medium | Medium | Medium - Augmentation |
| Model generalization | Low | High | Medium - Diverse training |
| Feature cache bloat | Low | Medium | Low - Monitor in Phase 3 |
| ONNX compatibility | Low | Medium | Low - Validation tools |
| Recall regression | Medium | Critical | **HIGH** - Regression tests |

---
## Next Steps

1. **Prototype Phase 1**: Build minimal GNN routing with a mock model (1 week)
2. **Collect Training Data**: Run 100K queries on the existing HNSW index, logging search trajectories (3 days)
3. **Train Initial Model**: Use the collected data to train a baseline GAT model (2 days)
4. **Integration Testing**: Plug the GNN into HNSW and measure initial performance (1 week)
5. **Iterate**: Optimize based on profiling results (ongoing)

**Key Decision Points**:
- After Phase 1: Is GNN inference fast enough? (<5ms target)
- After Phase 2: Does GNN improve QPS? (>10% improvement required to continue)
- After Phase 3: Does GNN meet all success metrics? (Go/No-Go for Phase 4)