GNN-Guided HNSW Routing - Implementation Plan
Overview
Problem Statement
Current HNSW (Hierarchical Navigable Small World) graph search uses a greedy routing strategy that selects the nearest neighbor at each step. This approach is locally optimal but often misses globally better paths, resulting in:
- Suboptimal query performance (increased distance computations)
- Redundant edge traversals in dense regions
- Poor scaling with graph size (20-40% performance degradation at 10M+ vectors)
- Inability to learn from query patterns
Proposed Solution
Replace greedy HNSW routing with a learned GNN-based routing policy that:
- Path Learning: Train on successful search trajectories to learn optimal routing decisions
- Context-Aware Selection: Use graph structure + query context to predict best next hops
- Multi-Hop Reasoning: Consider k-step lookahead instead of greedy single-step
- Adaptive Routing: Adjust routing strategy based on query characteristics
The GNN will output edge selection probabilities for each node during search, replacing the greedy nearest-neighbor heuristic.
Expected Benefits
Quantified Performance Improvements:
- +25% QPS (Queries Per Second) through reduced search iterations
- -30% distance computations via smarter edge selection
- -15% average hop count to reach target nodes
- +18% recall@10 for challenging queries (edge cases, dense clusters)
Qualitative Benefits:
- Learns from query distribution patterns
- Adapts to graph topology changes
- Handles multi-modal embeddings better
- Reduces tail latencies (P99 improvement)
Technical Design
Architecture Diagram (ASCII Art)
┌─────────────────────────────────────────────────────────────────┐
│ GNN-Guided HNSW Search │
└─────────────────────────────────────────────────────────────────┘
Query Vector (q)
│
▼
┌─────────────────────────────────────────────────────────────────┐
│ Entry Point Selection (standard HNSW top layer) │
└─────────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────┐
│ Layer L → 0 Search Loop │
│ ┌───────────────────────────────────────────────────────────┐ │
│ │ Current Node (c) │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ┌──────────────────────────────────────────────┐ │ │
│ │ │ GNN Edge Scorer │ │ │
│ │ │ ┌────────────────────────────────────────┐ │ │ │
│ │ │ │ Input: [node_feat, query, edge_feat] │ │ │ │
│ │ │ │ Graph Context: k-hop neighborhood │ │ │ │
│ │ │ │ Attention Layer: Multi-head GAT │ │ │ │
│ │ │ │ Output: Edge selection probabilities │ │ │ │
│ │ │ └────────────────────────────────────────┘ │ │ │
│ │ └──────────────────────────────────────────────┘ │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ┌──────────────────────────────────────────────┐ │ │
│ │ │ Edge Selection Strategy │ │ │
│ │ │ - Top-k by GNN score │ │ │
│ │ │ - Temperature-based sampling (exploration) │ │ │
│ │ │ - Hybrid: GNN score * distance heuristic │ │ │
│ │ └──────────────────────────────────────────────┘ │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ Candidate Neighbors (N) │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ Update best candidates, move to next node │ │
│ └───────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────┘
│
▼
Return Top-K Results
┌─────────────────────────────────────────────────────────────────┐
│ Training Pipeline │
└─────────────────────────────────────────────────────────────────┘
Query Workload
│
▼
┌─────────────────────────────────────────────────────────────────┐
│ Path Collector (on standard HNSW) │
│ - Record: (query, node_seq, edges_taken, final_results) │
│ - Label: edges_on_optimal_path = 1, others = 0 │
└─────────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────┐
│ Offline Training (PyTorch/candle) │
│ - Loss: BCE(GNN_edge_score, optimal_edge_label) │
│ - Optimizer: AdamW with lr=1e-3 │
│ - Batch: 256 query trajectories │
└─────────────────────────────────────────────────────────────────┘
│
▼
Export ONNX Model → Load in ruvector-gnn (Rust)
Core Data Structures (Rust)
// File: crates/ruvector-gnn/src/routing/mod.rs
use ndarray::{Array1, Array2};
use ort::{Session, Value};
/// GNN-guided routing policy for HNSW search
pub struct GnnRoutingPolicy {
/// ONNX runtime session for GNN inference
session: Session,
/// Feature extractor for nodes and edges
feature_extractor: FeatureExtractor,
/// Configuration for routing behavior
config: RoutingConfig,
/// Performance metrics
metrics: RoutingMetrics,
}
/// Configuration for GNN routing
#[derive(Debug, Clone)]
pub struct RoutingConfig {
/// Number of top edges to consider per node
pub top_k_edges: usize,
/// Temperature for edge selection sampling (0.0 = greedy)
pub temperature: f32,
/// Hybrid weight: α * gnn_score + (1-α) * distance_score
pub hybrid_alpha: f32,
/// Maximum GNN inference batch size
pub inference_batch_size: usize,
/// Enable/disable GNN routing (fallback to greedy)
pub enabled: bool,
/// K-hop neighborhood size for graph context
pub context_hops: usize,
/// Beam width for candidate propagation between layers (used by the search loop)
pub beam_width: usize,
}
/// Feature extraction for nodes and edges
pub struct FeatureExtractor {
/// Dimensionality of node features
node_dim: usize,
/// Dimensionality of edge features
edge_dim: usize,
/// Cache for computed features
cache: FeatureCache,
}
/// Features for a single node in the graph
#[derive(Debug, Clone)]
pub struct NodeFeatures {
/// Node embedding vector
pub embedding: Array1<f32>,
/// Degree (number of neighbors)
pub degree: usize,
/// Layer in HNSW hierarchy
pub layer: usize,
/// Clustering coefficient
pub clustering_coef: f32,
/// Distance to query (dynamic)
pub query_distance: f32,
}
/// Features for an edge in the graph
#[derive(Debug, Clone)]
pub struct EdgeFeatures {
/// Euclidean distance between connected nodes
pub distance: f32,
/// Angular similarity (cosine)
pub angular_similarity: f32,
/// Edge betweenness (precomputed)
pub betweenness: f32,
/// Whether edge crosses layers
pub cross_layer: bool,
}
/// GNN inference result for edge selection
#[derive(Debug)]
pub struct EdgeScore {
/// Target node ID
pub target_node: u32,
/// GNN-predicted score [0, 1]
pub gnn_score: f32,
/// Distance-based heuristic score
pub distance_score: f32,
/// Final combined score
pub combined_score: f32,
}
/// Performance tracking for routing
#[derive(Debug, Default)]
pub struct RoutingMetrics {
/// Total GNN inference calls
pub total_inferences: u64,
/// Average inference latency (microseconds)
pub avg_inference_us: f64,
/// Total distance computations
pub distance_computations: u64,
/// Average hops per query
pub avg_hops: f64,
/// Cache hit rate for features
pub feature_cache_hit_rate: f64,
/// Queries that fell back to greedy routing after a GNN error
pub fallback_count: u64,
}
/// Training data collection for offline learning
pub struct PathTrajectory {
/// Query vector
pub query: Vec<f32>,
/// Sequence of nodes visited
pub node_sequence: Vec<u32>,
/// Edges taken at each step
pub edges_taken: Vec<(u32, u32)>,
/// All candidate edges at each step
pub candidate_edges: Vec<Vec<(u32, u32)>>,
/// Final k-NN results
pub results: Vec<(u32, f32)>,
}
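The `combined_score` field on `EdgeScore` and the `hybrid_alpha` knob in `RoutingConfig` follow one rule: α · gnn_score + (1 − α) · distance_score. A minimal Python sketch of that blend (function names are illustrative, not the crate's API):

```python
def distance_score(dist: float) -> float:
    """Map a raw distance to (0, 1]; closer neighbors score higher."""
    return 1.0 / (1.0 + dist)

def combined_score(gnn_score: float, dist: float, hybrid_alpha: float) -> float:
    """Blend the learned score with the distance heuristic.

    hybrid_alpha = 1.0 trusts the GNN fully; 0.0 reduces to greedy distance.
    """
    return hybrid_alpha * gnn_score + (1.0 - hybrid_alpha) * distance_score(dist)

# alpha = 0.8: the GNN dominates, but the distance term still breaks ties
print(combined_score(0.9, 4.0, 0.8))  # 0.8*0.9 + 0.2*0.2 ≈ 0.76
print(combined_score(0.6, 0.0, 0.8))  # 0.8*0.6 + 0.2*1.0 ≈ 0.68
```

Note that `hybrid_alpha = 0.0` recovers the greedy baseline exactly, which makes A/B comparisons against standard HNSW straightforward.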
Key Algorithms (Pseudocode)
1. GNN-Guided Search Algorithm
function gnn_guided_search(query: Vector, graph: HNSWGraph, k: int) -> List[Result]:
"""
HNSW search with GNN-guided routing instead of greedy selection.
"""
# Initialize from top layer entry point
current_nodes = {graph.entry_point}
layer = graph.max_layer
# Descend through layers
while layer >= 0:
# Find best candidates at this layer using GNN
candidates = priority_queue()
visited = set()
for node in current_nodes:
# Get neighbors at this layer
neighbors = graph.get_neighbors(node, layer)
# Extract features for GNN
node_features = extract_node_features(node, query, graph)
edge_features = [extract_edge_features(node, neighbor, graph)
for neighbor in neighbors]
# GNN inference: score all edges from current node
edge_scores = gnn_model.score_edges(
node_features,
edge_features,
query
)
# Select edges based on GNN scores (not greedy distance)
selected = select_edges_by_gnn_score(
neighbors,
edge_scores,
config.top_k_edges,
config.temperature
)
for neighbor in selected:
if neighbor not in visited:
distance = compute_distance(query, graph.get_vector(neighbor))
candidates.push(neighbor, distance)
visited.add(neighbor)
# Move to best candidates for next iteration
current_nodes = candidates.top(config.beam_width)
layer -= 1
# Return top-k results from layer 0
return candidates.top(k)
function select_edges_by_gnn_score(neighbors, scores, top_k, temperature):
"""
Select edges based on GNN scores with optional exploration.
Strategies:
- temperature = 0: greedy top-k
- temperature > 0: sampling from softmax distribution
- hybrid mode: combine GNN score with distance heuristic
"""
if temperature == 0:
# Greedy: select top-k by GNN score
return top_k_by_score(neighbors, scores)
else:
# Sampling: use temperature-scaled softmax
probs = softmax(scores / temperature)
return sample_without_replacement(neighbors, probs, top_k)
function extract_node_features(node, query, graph):
"""
Extract node-level features for GNN input.
"""
return NodeFeatures(
embedding=graph.get_vector(node),
degree=graph.get_degree(node),
layer=graph.get_layer(node),
clustering_coef=graph.get_clustering_coefficient(node),
query_distance=distance(query, graph.get_vector(node))
)
function extract_edge_features(source, target, graph):
"""
Extract edge-level features for GNN input.
"""
source_vec = graph.get_vector(source)
target_vec = graph.get_vector(target)
return EdgeFeatures(
distance=euclidean_distance(source_vec, target_vec),
angular_similarity=cosine_similarity(source_vec, target_vec),
betweenness=graph.get_edge_betweenness(source, target),
cross_layer=(graph.get_layer(source) != graph.get_layer(target))
)
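The selection strategies above can be made concrete with a dependency-free Python sketch. This mirrors the pseudocode (greedy top-k at temperature 0, temperature-scaled softmax sampling otherwise), not the Rust implementation:

```python
import math
import random

def select_edges_by_gnn_score(neighbors, scores, top_k, temperature, rng=random):
    """Pick next-hop candidates from GNN edge scores.

    temperature == 0 -> deterministic top-k; temperature > 0 -> sample
    without replacement from a temperature-scaled softmax (exploration).
    """
    if temperature == 0:
        ranked = sorted(zip(neighbors, scores), key=lambda t: t[1], reverse=True)
        return [n for n, _ in ranked[:top_k]]
    # Softmax with max-subtraction for numerical stability.
    m = max(scores)
    exps = [math.exp((s - m) / temperature) for s in scores]
    total = sum(exps)
    # Sampling without replacement via sequential renormalization.
    pool = [(n, e / total) for n, e in zip(neighbors, exps)]
    chosen = []
    for _ in range(min(top_k, len(pool))):
        z = sum(p for _, p in pool)
        r = rng.random() * z
        acc = 0.0
        for i, (n, p) in enumerate(pool):
            acc += p
            if r <= acc:
                chosen.append(n)
                pool.pop(i)
                break
        else:
            chosen.append(pool.pop()[0])  # floating-point guard: take the last item
    return chosen
```

For example, `select_edges_by_gnn_score([1, 2, 3, 4], [0.1, 0.9, 0.5, 0.3], 2, 0.0)` deterministically returns `[2, 3]`, while any positive temperature yields a randomized but duplicate-free subset.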
2. Offline Training Pipeline
function collect_training_data(graph, query_workload, n_samples):
"""
Collect path trajectories from standard HNSW for training.
"""
trajectories = []
for query in query_workload.sample(n_samples):
# Run standard greedy HNSW search with full logging
path = instrumented_hnsw_search(query, graph)
# Label edges: 1 if on optimal path, 0 otherwise
optimal_edges = set(path.edges_taken)
# For each node in path, get all candidate edges
for step in path.node_sequence:
node = step.node
neighbors = graph.get_neighbors(node, step.layer)
# Create training examples
for neighbor in neighbors:
edge = (node, neighbor)
label = 1.0 if edge in optimal_edges else 0.0
node_feat = extract_node_features(node, query, graph)
edge_feat = extract_edge_features(node, neighbor, graph)
trajectories.append({
'node_features': node_feat,
'edge_features': edge_feat,
'query': query,
'label': label,
'distance_to_query': distance(query, graph.get_vector(neighbor))
})
return trajectories
function train_gnn_routing_model(trajectories, config):
"""
Train GNN model to predict edge selection probabilities.
Architecture: Graph Attention Network (GAT)
- 3 attention layers with 4 heads each
- Hidden dim: 128
- Edge features concatenated with node features
- Output: single logit per edge (probability of selection)
"""
model = GAT(
node_dim=config.node_feature_dim,
edge_dim=config.edge_feature_dim,
hidden_dim=128,
num_layers=3,
num_heads=4,
output_dim=1
)
optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
loss_fn = BCEWithLogitsLoss()
for epoch in range(config.num_epochs):
for batch in DataLoader(trajectories, batch_size=256):
# Forward pass
edge_logits = model(
batch.node_features,
batch.edge_features,
batch.query
)
# Binary cross-entropy loss
loss = loss_fn(edge_logits, batch.labels)
# Add distance-aware regularization
# Encourage model to respect distance heuristic
distance_scores = 1.0 / (1.0 + batch.distance_to_query)
consistency_loss = mse_loss(sigmoid(edge_logits), distance_scores)
total_loss = loss + 0.1 * consistency_loss
# Backward pass
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
return model
function export_to_onnx(model, output_path):
"""
Export trained PyTorch model to ONNX for Rust inference.
"""
    # torch.onnx.export expects example inputs as a positional tuple
    dummy_input = (
        torch.randn(1, node_dim),       # node_features
        torch.randn(10, edge_dim),      # edge_features (up to 10 neighbors)
        torch.randn(1, embedding_dim),  # query
    )
torch.onnx.export(
model,
dummy_input,
output_path,
input_names=['node_features', 'edge_features', 'query'],
output_names=['edge_scores'],
dynamic_axes={
'edge_features': {0: 'num_edges'},
'edge_scores': {0: 'num_edges'}
},
opset_version=14
)
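The training objective combines BCE on the edge labels with the distance-consistency MSE term (weight 0.1). A dependency-free sketch of that loss over one batch, useful for sanity-checking the PyTorch version:

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def bce_with_logits(logit, label):
    """Per-edge binary cross-entropy, mirroring BCEWithLogitsLoss."""
    p = sigmoid(logit)
    return -(label * math.log(p) + (1.0 - label) * math.log(1.0 - p))

def routing_loss(logits, labels, dists, consistency_weight=0.1):
    """BCE on edge labels plus the distance-aware consistency penalty.

    The penalty nudges sigmoid(logit) toward the heuristic 1 / (1 + d),
    so the model never strays too far from the greedy distance signal.
    """
    n = len(logits)
    bce = sum(bce_with_logits(l, y) for l, y in zip(logits, labels)) / n
    targets = [1.0 / (1.0 + d) for d in dists]
    mse = sum((sigmoid(l) - t) ** 2 for l, t in zip(logits, targets)) / n
    return bce + consistency_weight * mse
```

A confident correct prediction (large positive logit on a label-1 edge) yields a near-zero loss, while an uncertain one (logit 0) pays the full ln 2 in the BCE term.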
API Design (Function Signatures)
// File: crates/ruvector-gnn/src/routing/mod.rs
impl GnnRoutingPolicy {
/// Create a new GNN routing policy from an ONNX model file
pub fn from_onnx(
model_path: impl AsRef<Path>,
config: RoutingConfig,
) -> Result<Self, GnnError>;
/// Score edges from a given node during HNSW search
///
/// # Arguments
/// * `current_node` - The node we're currently at
/// * `candidate_neighbors` - Potential next hops
/// * `query` - The query vector
/// * `graph` - Reference to HNSW graph for feature extraction
///
/// # Returns
/// Vector of `EdgeScore` sorted by combined_score (descending)
pub fn score_edges(
&mut self,
current_node: u32,
candidate_neighbors: &[u32],
query: &[f32],
graph: &HnswGraph,
) -> Result<Vec<EdgeScore>, GnnError>;
/// Select top-k edges based on GNN scores
pub fn select_top_k(
&self,
edge_scores: &[EdgeScore],
k: usize,
) -> Vec<u32>;
/// Get current routing metrics
pub fn metrics(&self) -> &RoutingMetrics;
/// Reset metrics counters
pub fn reset_metrics(&mut self);
/// Update configuration at runtime
pub fn update_config(&mut self, config: RoutingConfig);
}
impl FeatureExtractor {
/// Create a new feature extractor
pub fn new(node_dim: usize, edge_dim: usize) -> Self;
/// Extract node features for GNN input
pub fn extract_node_features(
&self,
node_id: u32,
query: &[f32],
graph: &HnswGraph,
) -> Result<NodeFeatures, GnnError>;
/// Extract edge features for GNN input
pub fn extract_edge_features(
&self,
source: u32,
target: u32,
graph: &HnswGraph,
) -> Result<EdgeFeatures, GnnError>;
/// Batch extract features for multiple edges
pub fn batch_extract_edge_features(
&self,
edges: &[(u32, u32)],
graph: &HnswGraph,
) -> Result<Vec<EdgeFeatures>, GnnError>;
/// Clear feature cache
pub fn clear_cache(&mut self);
}
// Integration with existing HNSW implementation
// File: crates/ruvector-core/src/index/hnsw.rs
impl HnswIndex {
/// Enable GNN-guided routing
pub fn set_gnn_routing(
&mut self,
policy: GnnRoutingPolicy,
) -> Result<(), HnswError>;
/// Disable GNN routing (fallback to greedy)
pub fn disable_gnn_routing(&mut self);
/// Get routing performance metrics
pub fn routing_metrics(&self) -> Option<&RoutingMetrics>;
}
// Training utilities
// File: crates/ruvector-gnn/src/routing/training.rs
/// Collect path trajectories from HNSW search for training
pub fn collect_training_trajectories(
graph: &HnswGraph,
queries: &[Vec<f32>],
output_path: impl AsRef<Path>,
) -> Result<usize, GnnError>;
/// Validate ONNX model compatibility
pub fn validate_onnx_model(
model_path: impl AsRef<Path>,
) -> Result<ModelInfo, GnnError>;
#[derive(Debug)]
pub struct ModelInfo {
pub input_dims: Vec<(String, Vec<i64>)>,
pub output_dims: Vec<(String, Vec<i64>)>,
pub opset_version: i64,
}
Integration Points
Affected Crates/Modules
1. ruvector-gnn (Primary)
   - New module: src/routing/mod.rs - GNN routing policy
   - New module: src/routing/features.rs - Feature extraction
   - New module: src/routing/training.rs - Training utilities
   - Modified: src/lib.rs - Export routing types
2. ruvector-core (Integration)
   - Modified: src/index/hnsw.rs - Integrate GNN routing into search
   - Modified: src/index/mod.rs - Add routing configuration
   - New: src/index/hnsw_gnn.rs - GNN-specific HNSW extensions
3. ruvector-api (Configuration)
   - Modified: src/config.rs - Add GNN routing config options
   - Modified: src/index_manager.rs - Support GNN model loading
4. ruvector-bindings (Exposure)
   - Modified: python/src/lib.rs - Expose routing config to Python
   - Modified: nodejs/src/lib.rs - Expose routing config to Node.js
New Modules to Create
crates/ruvector-gnn/
├── src/
│ ├── routing/
│ │ ├── mod.rs # Core routing policy
│ │ ├── features.rs # Feature extraction
│ │ ├── training.rs # Training data collection
│ │ ├── cache.rs # Feature caching
│ │ └── metrics.rs # Performance tracking
│ └── models/
│ └── routing_gnn.onnx # Pre-trained model (optional)
examples/
├── gnn_routing/
│ ├── train_routing_model.py # Python training script
│ ├── evaluate_routing.rs # Rust evaluation benchmark
│ └── README.md # Usage guide
Dependencies on Other Features
Independent - Can be implemented standalone
Synergies with:
- Incremental Graph Learning (Feature 2): Cached node features can be reused
- Neuro-Symbolic Query (Feature 3): GNN routing can incorporate symbolic constraints
- Existing Attention Mechanisms: Reuse attention layers from Issue #38
External Dependencies:
- ort (ONNX Runtime) - Already in use for GNN inference
- ndarray - Already in use for tensor operations
- parking_lot - For feature cache concurrency
Regression Prevention
What Existing Functionality Could Break
1. HNSW Search Correctness
   - Risk: GNN routing might skip true nearest neighbors
   - Impact: Degraded recall, incorrect results
2. Performance Degradation
   - Risk: GNN inference overhead exceeds routing savings
   - Impact: Lower QPS than baseline greedy search
3. Memory Usage
   - Risk: Feature caching and GNN model consume excessive RAM
   - Impact: OOM on large graphs
4. Thread Safety
   - Risk: Feature cache race conditions in concurrent queries
   - Impact: Corrupted features, crashes
5. Build/Deployment
   - Risk: ONNX model path resolution failures
   - Impact: Runtime errors, inability to use feature
Test Cases to Prevent Regressions
// File: crates/ruvector-gnn/tests/routing_regression_tests.rs
#[test]
fn test_gnn_routing_recall_matches_greedy() {
// GNN routing must achieve ≥95% of greedy baseline recall
let mut graph = build_test_hnsw(10_000, 512);
let queries = generate_test_queries(1000);
// Baseline: greedy search
graph.disable_gnn_routing();
let greedy_results = run_search_batch(&graph, &queries, 10);
// GNN routing
graph.set_gnn_routing(load_test_model()).unwrap();
let gnn_results = run_search_batch(&graph, &queries, 10);
let recall = compute_recall(&greedy_results, &gnn_results);
assert!(recall >= 0.95, "GNN recall: {}, expected ≥0.95", recall);
}
#[test]
fn test_gnn_routing_performance_improvement() {
// GNN routing must achieve ≥10% QPS improvement
let mut graph = build_test_hnsw(100_000, 512);
let queries = generate_test_queries(10_000);
// Baseline
graph.disable_gnn_routing();
let greedy_qps = benchmark_qps(&graph, &queries);
// GNN
graph.set_gnn_routing(load_test_model()).unwrap();
let gnn_qps = benchmark_qps(&graph, &queries);
let improvement = (gnn_qps - greedy_qps) / greedy_qps;
assert!(improvement >= 0.10, "QPS improvement: {:.2}%, expected ≥10%", improvement * 100.0);
}
#[test]
fn test_gnn_routing_distance_computation_reduction() {
// Must reduce distance computations by ≥20%
let mut graph = build_test_hnsw(50_000, 512);
let queries = generate_test_queries(1000);
graph.disable_gnn_routing();
graph.reset_metrics();
run_search_batch(&graph, &queries, 10);
let greedy_dists = graph.metrics().distance_computations;
graph.set_gnn_routing(load_test_model()).unwrap();
graph.reset_metrics();
run_search_batch(&graph, &queries, 10);
let gnn_dists = graph.metrics().distance_computations;
let reduction = (greedy_dists - gnn_dists) as f64 / greedy_dists as f64;
assert!(reduction >= 0.20, "Distance reduction: {:.2}%, expected ≥20%", reduction * 100.0);
}
#[test]
fn test_feature_cache_thread_safety() {
// Concurrent queries must not corrupt feature cache
let mut graph = build_test_hnsw(10_000, 512);
graph.set_gnn_routing(load_test_model()).unwrap();
let graph = Arc::new(graph);
let handles: Vec<_> = (0..16)
.map(|_| {
let g = Arc::clone(&graph);
thread::spawn(move || {
let queries = generate_test_queries(100);
run_search_batch(&g, &queries, 10)
})
})
.collect();
let results: Vec<_> = handles.into_iter()
.map(|h| h.join().unwrap())
.collect();
// All results should be valid (no panics/corruptions)
for result_set in results {
assert!(validate_results(&result_set));
}
}
#[test]
fn test_graceful_fallback_on_gnn_error() {
// If GNN fails, must fallback to greedy without crashing
let mut graph = build_test_hnsw(1000, 512);
// Inject a model that loads but fails at inference time
graph.set_gnn_routing(create_faulty_model()).unwrap();
let queries = generate_test_queries(100);
let results = run_search_batch(&graph, &queries, 10);
// Should get valid results (from fallback)
assert_eq!(results.len(), 100);
assert!(graph.routing_metrics().unwrap().fallback_count > 0);
}
Backward Compatibility Strategy
1. Default Disabled
   - GNN routing is opt-in via configuration
   - Existing deployments unaffected unless explicitly enabled
2. Configuration Migration
   # Old config (still works)
   hnsw:
     ef_construction: 200
     M: 16
   # New config (optional)
   hnsw:
     ef_construction: 200
     M: 16
     gnn_routing:
       enabled: false                # Default: disabled
       model_path: "./models/routing_gnn.onnx"
       top_k_edges: 5
       temperature: 0.0
       hybrid_alpha: 0.8
3. Feature Flags
   #[cfg(feature = "gnn-routing")]
   pub mod routing;
   - Can be compiled out if not needed
   - Reduces binary size and dependencies
4. Versioned Model Format
   - ONNX models include version metadata
   - Runtime checks for compatibility
   - Graceful degradation on version mismatch
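The runtime check itself can stay trivial; a sketch of a same-major-version policy (the metadata format and acceptance rule here are assumptions, not a finalized spec):

```python
SUPPORTED_MAJOR = 1  # hypothetical: this runtime accepts any 1.x routing model

def model_is_compatible(version_str) -> bool:
    """Accept same-major model versions; anything else triggers greedy fallback."""
    try:
        major = int(str(version_str).split(".")[0])
    except ValueError:
        return False  # malformed metadata: fall back rather than crash
    return major == SUPPORTED_MAJOR
```

On a `False` result, the index simply leaves GNN routing disabled and logs the mismatch, matching the graceful-degradation requirement above.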
Implementation Phases
Phase 1: Core Implementation (Week 1-2)
Goal: Working GNN routing with ONNX inference
Tasks:
- Implement FeatureExtractor for nodes and edges
- Implement GnnRoutingPolicy with ONNX runtime
- Add basic edge scoring logic
- Unit tests for feature extraction
- Unit tests for ONNX inference
Deliverables:
- ruvector-gnn/src/routing/mod.rs
- ruvector-gnn/src/routing/features.rs
- Passing unit tests
- Example ONNX model (mock, not trained)
Success Criteria:
- GNN can score edges without crashing
- Feature extraction produces valid tensors
- ONNX model loads and runs inference
Phase 2: Integration (Week 2-3)
Goal: Integrate GNN routing into HNSW search
Tasks:
- Modify HnswIndex to support GNN routing
- Implement routing selection strategies (greedy, sampling, hybrid)
- Add performance metrics tracking
- Add feature caching for performance
- Integration tests with real HNSW graphs
Deliverables:
- Modified ruvector-core/src/index/hnsw.rs
- Working end-to-end search with GNN
- Performance benchmarks vs baseline
- Feature cache implementation
Success Criteria:
- GNN routing produces correct k-NN results
- No crashes or panics in concurrent scenarios
- Metrics collection working
Phase 3: Optimization (Week 3-4)
Goal: Achieve +25% QPS, -30% distance computations
Tasks:
- Profile GNN inference overhead
- Optimize feature extraction (batching, caching)
- Tune hybrid_alpha and temperature parameters
- Implement batch inference for multiple edges
- Add SIMD optimizations where applicable
- Train actual GNN model on real query workload
Deliverables:
- Trained ONNX model with documented performance
- Python training script (examples/gnn_routing/train_routing_model.py)
- Performance tuning guide
- Optimized feature cache
Success Criteria:
- +25% QPS improvement on benchmark dataset
- -30% reduction in distance computations
- <2ms average GNN inference latency per query
- >80% feature cache hit rate
Phase 4: Production Hardening (Week 4-5)
Goal: Production-ready feature with safety guarantees
Tasks:
- Add comprehensive error handling
- Implement graceful fallback to greedy on GNN errors
- Add configuration validation
- Write regression tests (prevent regressions)
- Write documentation and examples
- Add telemetry/observability hooks
- Performance benchmarks on large-scale datasets (10M+ vectors)
Deliverables:
- Full regression test suite
- User documentation
- Performance benchmark report
- Example configurations
- Migration guide
Success Criteria:
- All regression tests passing
- Zero crashes in stress tests
- Documentation complete
- Ready for alpha release
Success Metrics
Performance Benchmarks
Primary Metrics (Must Achieve):
| Metric | Baseline (Greedy) | Target (GNN) | Measurement |
|---|---|---|---|
| QPS (1M vectors) | 10,000 | 12,500 (+25%) | queries/second @ 16 threads |
| Distance Computations | 150/query | 105/query (-30%) | average per query |
| Average Hops | 12.5 | 10.6 (-15%) | hops to reach target |
| P99 Latency | 15ms | 12ms (-20%) | 99th percentile query time |
Secondary Metrics (Nice to Have):
| Metric | Baseline | Target | Measurement |
|---|---|---|---|
| Feature Cache Hit Rate | N/A | >80% | cache hits / total accesses |
| GNN Inference Time | N/A | <2ms | average per query |
| Memory Overhead | N/A | <5% | additional RAM for GNN + cache |
| Recall@10 | 0.95 | 0.96 (+1pp) | fraction of true neighbors found |
Accuracy Metrics
Recall Preservation:
- GNN routing must achieve ≥95% of greedy baseline recall
- No degradation on edge-case queries (dense clusters, outliers)
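Recall here is measured against the greedy baseline rather than brute-force ground truth; per query it is a top-k set overlap, averaged over the workload. A small sketch of that computation:

```python
def recall_at_k(baseline_ids, candidate_ids, k):
    """Fraction of the baseline's top-k that the candidate run also returned."""
    base = set(baseline_ids[:k])
    cand = set(candidate_ids[:k])
    return len(base & cand) / len(base)

def mean_recall(baseline_runs, candidate_runs, k=10):
    """Average per-query recall of GNN routing relative to greedy results."""
    pairs = list(zip(baseline_runs, candidate_runs))
    return sum(recall_at_k(b, c, k) for b, c in pairs) / len(pairs)

# One query where GNN routing misses 1 of 5 baseline neighbors:
print(recall_at_k([1, 2, 3, 4, 5], [1, 2, 3, 4, 9], k=5))  # 0.8
```

The ≥95% gate in the regression tests is then simply `mean_recall(greedy, gnn) >= 0.95`.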
Path Optimality:
- GNN paths should be ≤5% longer than oracle optimal paths
- Measured by comparing against brute-force ground truth
Failure Rate:
- Graceful fallback to greedy on <1% of queries
- Zero crashes or incorrect results
Memory/Latency Targets
Memory:
- GNN model size: <50MB (ONNX file)
- Feature cache: <100MB per 1M vectors
- Total overhead: <5% of base HNSW index size
Latency:
- GNN inference: <2ms average, <5ms P99
- Feature extraction: <0.5ms per node
- Total query latency: ≤15ms P99 (no worse than the 15ms greedy baseline)
Throughput:
- Concurrent queries: 16+ threads with linear scaling
- Batch inference: 10+ edges per batch for efficiency
Risks and Mitigations
Technical Risks
Risk 1: GNN Inference Overhead Exceeds Routing Savings
Probability: Medium | Impact: High
Description: If GNN model is too complex, inference time could negate benefits of reduced hops.
Mitigation:
- Profile GNN inference early in Phase 1
- Set hard latency budget (<2ms per query)
- Use lightweight GNN architecture (3-layer GAT, not deep networks)
- Batch inference across multiple edges
- Implement feature caching to avoid recomputation
- Add fallback to greedy if inference exceeds budget
Contingency: If overhead too high, switch to simpler models (MLP instead of GNN) or hybrid mode (GNN only for hard queries).
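The budget-based fallback can be sketched as a wrapper that times the GNN call and degrades to the distance heuristic on error or overrun. Names are hypothetical; a production version would enforce the budget preemptively and feed the outcome into RoutingMetrics rather than returning a tag:

```python
import time

GNN_BUDGET_S = 0.002  # the 2 ms per-query inference budget from this plan

def score_edges_with_fallback(gnn_score_fn, greedy_score_fn, edges):
    """Try GNN scoring; on any error or budget overrun, fall back to the
    distance heuristic. Returns (scores, which_path_was_taken)."""
    start = time.perf_counter()
    try:
        scores = gnn_score_fn(edges)
        if time.perf_counter() - start <= GNN_BUDGET_S:
            return scores, "gnn"
    except Exception:
        pass  # any inference failure degrades to greedy, never crashes the query
    return greedy_score_fn(edges), "greedy"
```

Checking the budget after the call is a simplification: it caps repeated overruns across queries rather than bounding any single query, which is the trade-off this sketch makes for clarity.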
Risk 2: Training Data Scarcity
Probability: Medium | Impact: Medium
Description: May not have enough diverse queries to train robust GNN model.
Mitigation:
- Use query augmentation (add noise, rotations)
- Pretrain on synthetic queries (random vectors)
- Fine-tune on actual workload
- Support transfer learning from similar datasets
- Provide pre-trained baseline model
Contingency: Start with simple heuristic-based routing (e.g., distance + degree) and upgrade to GNN later.
Risk 3: Model Generalization Failures
Probability: Low | Impact: High
Description: GNN trained on one dataset might not generalize to different embedding distributions.
Mitigation:
- Train on diverse datasets (text, images, multi-modal)
- Use domain-agnostic features (degree, distance, structure)
- Add online learning to adapt to new query patterns
- Provide model retraining tools
- Extensive evaluation on held-out datasets
Contingency: Support per-index model training for critical use cases.
Risk 4: Feature Cache Memory Bloat
Probability: Low | Impact: Medium
Description: Caching node/edge features could consume excessive memory on large graphs.
Mitigation:
- Use LRU eviction policy (keep only recent features)
- Set cache size limits (e.g., max 100MB)
- Make caching optional (can disable for low-memory environments)
- Use compressed feature representations
- Profile memory usage in Phase 3
Contingency: Disable feature caching by default, enable only for latency-critical workloads.
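The LRU policy with a hard cap is straightforward; a Python sketch using an entry-count limit as a stand-in for the byte-size cap (the Rust version would guard this state with parking_lot):

```python
from collections import OrderedDict

class FeatureCache:
    """LRU cache for node features with a hard entry cap and hit-rate tracking."""

    def __init__(self, max_entries: int):
        self.max_entries = max_entries
        self._data = OrderedDict()
        self.hits = 0
        self.misses = 0

    def get(self, node_id):
        if node_id in self._data:
            self._data.move_to_end(node_id)  # mark as most recently used
            self.hits += 1
            return self._data[node_id]
        self.misses += 1
        return None

    def put(self, node_id, features):
        self._data[node_id] = features
        self._data.move_to_end(node_id)
        while len(self._data) > self.max_entries:
            self._data.popitem(last=False)  # evict least recently used

    def hit_rate(self):
        total = self.hits + self.misses
        return self.hits / total if total else 0.0
```

The `hit_rate()` figure maps directly onto the `feature_cache_hit_rate` metric and the >80% target in the success criteria.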
Risk 5: ONNX Compatibility Issues
Probability: Low | Impact: Medium
Description: ONNX runtime might not support specific GNN operations or have platform issues.
Mitigation:
- Use only standard ONNX ops (opset 14+)
- Test on multiple platforms (Linux, macOS, Windows)
- Provide model validation tool to check compatibility
- Fallback to pure Rust inference if ONNX unavailable
Contingency: Implement lightweight Rust-native GNN inference as fallback.
Risk 6: Regression in Recall
Probability: Medium | Impact: Critical
Description: GNN routing might skip true nearest neighbors, degrading result quality.
Mitigation:
- Extensive recall testing in Phase 2
- Set minimum recall threshold (≥95% of baseline)
- Add recall monitoring in production
- Use hybrid mode (GNN + distance heuristic) for safety
- Comprehensive regression test suite
Contingency: If recall drops, increase hybrid_alpha to rely more on distance heuristic, or disable GNN routing entirely.
Summary Risk Matrix
| Risk | Probability | Impact | Mitigation Priority |
|---|---|---|---|
| GNN inference overhead | Medium | High | HIGH - Profile early |
| Training data scarcity | Medium | Medium | Medium - Augmentation |
| Model generalization | Low | High | Medium - Diverse training |
| Feature cache bloat | Low | Medium | Low - Monitor in Phase 3 |
| ONNX compatibility | Low | Medium | Low - Validation tools |
| Recall regression | Medium | Critical | HIGH - Regression tests |
Next Steps
- Prototype Phase 1: Build minimal GNN routing with mock model (1 week)
- Collect Training Data: Run 100K queries on existing HNSW, log trajectories (3 days)
- Train Initial Model: Use collected data to train baseline GAT model (2 days)
- Integration Testing: Plug GNN into HNSW, measure initial performance (1 week)
- Iterate: Optimize based on profiling results (ongoing)
Key Decision Points:
- After Phase 1: Is GNN inference fast enough? (<5ms target)
- After Phase 2: Does GNN improve QPS? (>10% required to continue)
- After Phase 3: Does GNN meet all success metrics? (Go/No-Go for Phase 4)