RuVector Attention Architecture
Document Information
- Phase: Architecture (SPARC Methodology)
- Date: 2025-11-30
- Status: Draft
- Version: 1.0.0
Executive Summary
This document defines the complete system architecture for ruvector-attention, a high-performance attention mechanism library designed to integrate seamlessly with the RuVector ecosystem. The architecture supports multiple attention paradigms (standard, geometric, sparse, graph-aware) with platform-agnostic bindings (Rust, WASM, Node.js) and SIMD-optimized computation.
1. High-Level Architecture
1.1 System Overview
┌─────────────────────────────────────────────────────────────────────────┐
│ ruvector-attention │
│ Unified Attention Framework │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌──────────────────────────────────────────────────────────────────┐ │
│ │ Attention Layer │ │
│ ├──────────────────────────────────────────────────────────────────┤ │
│ │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ┌────────┐ │ │
│ │ │ Standard │ │ Geometric │ │ Sparse │ │ Graph │ │ │
│ │ │ Attention │ │ Attention │ │ Attention │ │ Attn │ │ │
│ │ │ │ │ │ │ │ │ │ │ │
│ │ │ • Scaled │ │ • Hyperbolic│ │ • Windowed │ │ • GAT │ │ │
│ │ │ • Dot-Prod │ │ • Riemannian│ │ • Strided │ │ • GIN │ │ │
│ │ │ • Multi-Head│ │ • Manifold │ │ • Sparse │ │ • SAGE │ │ │
│ │ │ • Cross │ │ Distance │ │ Masks │ │ │ │ │
│ │ └─────────────┘ └─────────────┘ └─────────────┘ └────────┘ │ │
│ │ │ │ │ │ │ │
│ └─────────┼────────────────┼────────────────┼──────────────┼────────┘ │
│ │ │ │ │ │
│ ┌─────────┴────────────────┴────────────────┴──────────────┴────────┐ │
│ │ Unified Attention Trait │ │
│ │ • forward(query, keys, values) -> output │ │
│ │ • forward_batch(queries, keys, values) -> outputs │ │
│ │ • get_attention_weights() -> weights │ │
│ │ • backward(grad_output) -> grad_query, grad_keys, grad_values │ │
│ └────────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ┌──────────────┬────────────┴────────────┬──────────────┐ │
│ │ │ │ │ │
│ │ ┌───────────┴────────┐ ┌────────────┴──────────┐ │ │
│ │ │ Compute Kernels │ │ Memory Management │ │ │
│ │ ├────────────────────┤ ├───────────────────────┤ │ │
│ │ │ • SIMD Operations │ │ • Tensor Pool │ │ │
│ │ │ • Batch MatMul │ │ • Gradient Cache │ │ │
│ │ │ • Softmax │ │ • Attention Cache │ │ │
│ │ │ • Layer Norm │ │ • Memory Mapping │ │ │
│ │ └────────────────────┘ └───────────────────────┘ │ │
│ │ │ │
│ │ ┌────────────────────┐ ┌───────────────────────┐ │ │
│ │ │ Training Support │ │ Compression │ │ │
│ │ ├────────────────────┤ ├───────────────────────┤ │ │
│ │ │ • Gradient Comp │ │ • Quantization │ │ │
│ │ │ • Backpropagation │ │ • Pruning │ │ │
│ │ │ • Optimizer Hooks │ │ • Knowledge Distill │ │ │
│ │ │ • EWC Integration │ │ • Low-Rank Approx │ │ │
│ │ └────────────────────┘ └───────────────────────┘ │ │
│ └──────────────────────────────────────────────────────┘ │
├─────────────────────────────────────────────────────────────────────────┤
│ Platform Bindings │
├─────────────────────────────────────────────────────────────────────────┤
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ Rust │ │ WASM │ │ NAPI-RS │ │ CLI │ │
│ │ SDK │ │ Bindings │ │ Bindings │ │ Interface │ │
│ ├─────────────┤ ├─────────────┤ ├─────────────┤ ├─────────────┤ │
│ │ • Native │ │ • Browser │ │ • Node.js │ │ • Benchmark │ │
│ │ Types │ │ Compat │ │ Async │ │ • Test │ │
│ │ • Zero-Copy │ │ • JS Arrays │ │ • Buffer │ │ • Visualize │ │
│ │ • Traits │ │ • Workers │ │ Transfer │ │ • Export │ │
│ └─────────────┘ └─────────────┘ └─────────────┘ └─────────────┘ │
└─────────────────────────────────────────────────────────────────────────┘
│
┌────────────┴────────────┐
│ │
┌──────────┴─────────┐ ┌──────────┴─────────┐
│ ruvector-core │ │ ruvector-gnn │
│ • Vector Storage │ │ • GNN Layers │
│ • HNSW Index │ │ • Training │
│ • SIMD Utils │ │ • Compression │
└────────────────────┘ └────────────────────┘
│
┌──────────┴─────────┐
│ ruvector-graph │
│ • Graph Storage │
│ • Traversal │
│ • Cypher Query │
└────────────────────┘
1.2 Design Principles
- Modularity: Each attention mechanism is a self-contained module with clear interfaces
- Performance: SIMD-first design with hardware-optimized kernels
- Extensibility: Trait-based architecture allows custom attention implementations
- Platform Agnostic: Core logic separated from platform-specific bindings
- Memory Efficiency: Aggressive caching and memory pooling for minimal allocations
- Type Safety: Leverages Rust's type system for compile-time guarantees
- Interoperability: First-class integration with existing RuVector crates
2. Core Traits and Interfaces
2.1 Base Attention Trait
/// Base attention mechanism trait that all attention variants implement
pub trait Attention: Send + Sync {
/// Compute attention output for a single query
///
/// # Arguments
/// * `query` - Query vector (dimension D)
/// * `keys` - Key vectors (N vectors of dimension D)
/// * `values` - Value vectors (N vectors of dimension V)
///
/// # Returns
/// Output vector of dimension V
fn forward(&self, query: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>]) -> Vec<f32>;
/// Batch processing for multiple queries
///
/// # Arguments
/// * `queries` - Batch of query vectors (B x D)
/// * `keys` - Key vectors (N x D)
/// * `values` - Value vectors (N x V)
///
/// # Returns
/// Batch of output vectors (B x V)
fn forward_batch(
&self,
queries: &[Vec<f32>],
keys: &[Vec<f32>],
values: &[Vec<f32>],
) -> Vec<Vec<f32>>;
/// Retrieve last computed attention weights (for visualization/debugging)
fn get_attention_weights(&self) -> Option<Vec<Vec<f32>>>;
/// Reset internal state (for stateful attention mechanisms)
fn reset(&mut self) {}
/// Get configuration metadata
fn config(&self) -> AttentionConfig;
}
/// Configuration for attention mechanisms
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionConfig {
pub embed_dim: usize,
pub num_heads: Option<usize>,
pub dropout_rate: Option<f32>,
pub attention_type: AttentionType,
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum AttentionType {
Standard,
MultiHead,
Geometric,
Sparse,
Graph,
}
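As an illustration of the `forward` contract above, the following is a minimal sketch of scaled dot-product attention in pure Rust (no multi-head projection, dropout, or masking), assuming keys and the query share dimension D:

```rust
/// Minimal scaled dot-product attention matching the `forward` signature
/// in the `Attention` trait. Illustrative sketch only.
fn scaled_dot_product_attention(
    query: &[f32],
    keys: &[Vec<f32>],
    values: &[Vec<f32>],
) -> Vec<f32> {
    let scale = (query.len() as f32).sqrt();
    // 1. Similarity scores: q . k_i / sqrt(D)
    let scores: Vec<f32> = keys
        .iter()
        .map(|k| query.iter().zip(k).map(|(q, ki)| q * ki).sum::<f32>() / scale)
        .collect();
    // 2. Numerically stable softmax
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let total: f32 = exps.iter().sum();
    let weights: Vec<f32> = exps.iter().map(|e| e / total).collect();
    // 3. Weighted sum of values
    let mut out = vec![0.0; values[0].len()];
    for (w, v) in weights.iter().zip(values) {
        for (o, vi) in out.iter_mut().zip(v) {
            *o += w * vi;
        }
    }
    out
}
```

With two identical keys the softmax is uniform, so the output is the mean of the two value vectors.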
2.2 Graph-Aware Attention
/// Attention mechanism that incorporates graph structure
pub trait GraphAttention: Attention {
/// Forward pass with edge information
///
/// # Arguments
/// * `query` - Query vector
/// * `keys` - Key vectors (nodes)
/// * `values` - Value vectors (node features)
/// * `edge_index` - Edge connectivity [(src, dst), ...]
/// * `edge_attr` - Edge attributes (optional)
fn forward_with_edges(
&self,
query: &[f32],
keys: &[Vec<f32>],
values: &[Vec<f32>],
edge_index: &[(usize, usize)],
edge_attr: Option<&[Vec<f32>]>,
) -> Vec<f32>;
/// Compute attention only over graph neighborhood
fn forward_neighborhood(
&self,
node_idx: usize,
node_features: &[Vec<f32>],
edge_index: &[(usize, usize)],
) -> Vec<f32>;
/// Multi-hop attention aggregation
fn forward_k_hop(
&self,
node_idx: usize,
node_features: &[Vec<f32>],
edge_index: &[(usize, usize)],
k: usize,
) -> Vec<f32>;
}
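A sketch of what an implementation of `forward_neighborhood` could look like: attend only over the graph neighbors of `node_idx`, with a dot-product score and a fallback to the node's own feature vector when it has no outgoing edges. The scoring here is a placeholder for whatever the concrete `GraphAttention` implementation uses:

```rust
/// Neighborhood-restricted attention: softmax over q . k scores of the
/// node's neighbors only, falling back to the node itself if isolated.
fn neighborhood_attention(
    node_idx: usize,
    node_features: &[Vec<f32>],
    edge_index: &[(usize, usize)],
) -> Vec<f32> {
    let query = &node_features[node_idx];
    // Gather neighbor indices from the edge list
    let neighbors: Vec<usize> = edge_index
        .iter()
        .filter(|(src, _)| *src == node_idx)
        .map(|(_, dst)| *dst)
        .collect();
    if neighbors.is_empty() {
        return query.clone();
    }
    // Softmax over dot-product scores restricted to the neighborhood
    let scores: Vec<f32> = neighbors
        .iter()
        .map(|&j| query.iter().zip(&node_features[j]).map(|(q, k)| q * k).sum())
        .collect();
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let total: f32 = exps.iter().sum();
    let mut out = vec![0.0; query.len()];
    for (e, &j) in exps.iter().zip(&neighbors) {
        let w = e / total;
        for (o, v) in out.iter_mut().zip(&node_features[j]) {
            *o += w * v;
        }
    }
    out
}
```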
2.3 Geometric Attention
/// Attention in non-Euclidean spaces (hyperbolic, spherical, etc.)
pub trait GeometricAttention: Attention {
type Space: ManifoldSpace;
/// Forward pass with manifold-aware distance computation
fn forward_geometric(
&self,
query: &[f32],
keys: &[Vec<f32>],
values: &[Vec<f32>],
) -> Vec<f32>;
/// Compute attention weights using manifold distance
fn compute_manifold_weights(
&self,
query: &[f32],
keys: &[Vec<f32>],
) -> Vec<f32>;
/// Project to manifold tangent space
fn project_tangent(&self, point: &[f32], base: &[f32]) -> Vec<f32>;
/// Exponential map for manifold operations
fn exp_map(&self, base: &[f32], tangent: &[f32]) -> Vec<f32>;
}
/// Trait for different manifold spaces
pub trait ManifoldSpace: Send + Sync {
/// Compute distance between two points on the manifold
fn distance(&self, x: &[f32], y: &[f32]) -> f32;
/// Inner product in tangent space
fn inner_product(&self, base: &[f32], u: &[f32], v: &[f32]) -> f32;
/// Manifold curvature parameter
fn curvature(&self) -> f32;
}
/// Hyperbolic space (Poincaré ball model)
pub struct HyperbolicSpace {
curvature: f32,
}
/// Spherical space
pub struct SphericalSpace {
radius: f32,
}
/// Product manifold (combination of different spaces)
pub struct ProductSpace {
components: Vec<Box<dyn ManifoldSpace>>,
}
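For `HyperbolicSpace`, the `ManifoldSpace::distance` method would compute the Poincaré-ball distance. A sketch for unit curvature (c = 1), where d(x, y) = arcosh(1 + 2‖x − y‖² / ((1 − ‖x‖²)(1 − ‖y‖²))):

```rust
/// Poincaré-ball distance, unit curvature. Inputs must lie strictly
/// inside the unit ball, i.e. ||x|| < 1 and ||y|| < 1.
fn poincare_distance(x: &[f32], y: &[f32]) -> f32 {
    let sq = |v: &[f32]| v.iter().map(|a| a * a).sum::<f32>();
    let diff_sq: f32 = x.iter().zip(y).map(|(a, b)| (a - b) * (a - b)).sum();
    let arg = 1.0 + 2.0 * diff_sq / ((1.0 - sq(x)) * (1.0 - sq(y)));
    arg.acosh()
}
```

A useful sanity check: the distance from the origin reduces to 2·artanh(‖y‖).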
2.4 Sparse Attention
/// Attention with sparse patterns for long sequences
pub trait SparseAttention: Attention {
/// Forward with explicit sparsity mask
fn forward_sparse(
&self,
query: &[f32],
keys: &[Vec<f32>],
values: &[Vec<f32>],
mask: &SparseMask,
) -> Vec<f32>;
/// Windowed (local) attention
fn forward_windowed(
&self,
queries: &[Vec<f32>],
keys: &[Vec<f32>],
values: &[Vec<f32>],
window_size: usize,
) -> Vec<Vec<f32>>;
/// Strided attention (Longformer-style)
fn forward_strided(
&self,
queries: &[Vec<f32>],
keys: &[Vec<f32>],
values: &[Vec<f32>],
window_size: usize,
stride: usize,
) -> Vec<Vec<f32>>;
/// Top-k sparse attention
fn forward_topk(
&self,
query: &[f32],
keys: &[Vec<f32>],
values: &[Vec<f32>],
k: usize,
) -> Vec<f32>;
}
/// Sparse attention mask representation
#[derive(Debug, Clone)]
pub enum SparseMask {
/// Dense boolean mask
Dense(Vec<Vec<bool>>),
/// Sparse coordinate format [(query_idx, key_idx)]
COO(Vec<(usize, usize)>),
/// Block-sparse pattern
BlockSparse {
block_size: usize,
active_blocks: Vec<(usize, usize)>,
},
/// Windowed pattern
Windowed {
window_size: usize,
global_tokens: Vec<usize>,
},
}
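A sketch of the `forward_topk` pattern: score every key, keep only the k highest-scoring ones, and softmax over the survivors. This reduces the softmax and the weighted sum from O(N) to O(k) at the cost of a full scoring pass:

```rust
/// Top-k sparse attention: only the k best-scoring keys participate in
/// the softmax and the weighted sum. Illustrative sketch.
fn topk_attention(
    query: &[f32],
    keys: &[Vec<f32>],
    values: &[Vec<f32>],
    k: usize,
) -> Vec<f32> {
    // Score every key, then keep the indices of the top k
    let mut scored: Vec<(usize, f32)> = keys
        .iter()
        .enumerate()
        .map(|(i, key)| (i, query.iter().zip(key).map(|(q, ki)| q * ki).sum()))
        .collect();
    scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    scored.truncate(k);
    // Softmax over the retained scores only
    let max = scored.iter().map(|(_, s)| *s).fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scored.iter().map(|(_, s)| (s - max).exp()).collect();
    let total: f32 = exps.iter().sum();
    let mut out = vec![0.0; values[0].len()];
    for ((i, _), e) in scored.iter().zip(&exps) {
        let w = e / total;
        for (o, v) in out.iter_mut().zip(&values[*i]) {
            *o += w * v;
        }
    }
    out
}
```

With k = 1 this degenerates to nearest-key lookup: the output is exactly the value of the best-matching key.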
2.5 Trainable Attention
/// Attention mechanism with gradient computation for training
pub trait TrainableAttention: Attention {
/// Backward pass for gradient computation
///
/// # Arguments
/// * `grad_output` - Gradient w.r.t. output
///
/// # Returns
/// (grad_query, grad_keys, grad_values, grad_params)
fn backward(
&mut self,
grad_output: &[f32],
) -> (Vec<f32>, Vec<Vec<f32>>, Vec<Vec<f32>>, Vec<f32>);
/// Update parameters with optimizer
fn update_parameters(&mut self, optimizer: &mut dyn Optimizer);
/// Get trainable parameters
fn parameters(&self) -> Vec<&[f32]>;
/// Get mutable trainable parameters
fn parameters_mut(&mut self) -> Vec<&mut [f32]>;
/// Number of trainable parameters
fn num_parameters(&self) -> usize;
/// Save checkpoint
fn save_checkpoint(&self, path: &str) -> Result<(), AttentionError>;
/// Load checkpoint
fn load_checkpoint(&mut self, path: &str) -> Result<(), AttentionError>;
}
3. Module Dependencies Graph
graph TB
subgraph "Core Modules"
TRAITS[traits.rs<br/>Base Traits]
CONFIG[config.rs<br/>Configuration]
ERRORS[error.rs<br/>Error Types]
end
subgraph "Attention Implementations"
STANDARD[standard.rs<br/>Scaled Dot-Product]
MULTIHEAD[multihead.rs<br/>Multi-Head Attention]
GEOMETRIC[geometric.rs<br/>Geometric Attention]
SPARSE[sparse.rs<br/>Sparse Patterns]
GRAPH[graph.rs<br/>Graph Attention]
end
subgraph "Compute Kernels"
MATMUL[matmul.rs<br/>Matrix Multiplication]
SOFTMAX[softmax.rs<br/>Softmax/Normalize]
SIMD[simd.rs<br/>SIMD Operations]
DISTANCE[distance.rs<br/>Distance Metrics]
end
subgraph "Memory Management"
TENSOR[tensor.rs<br/>Tensor Pool]
CACHE[cache.rs<br/>Attention Cache]
MMAP[mmap.rs<br/>Memory Mapping]
end
subgraph "Training Support"
GRADIENT[gradient.rs<br/>Gradient Comp]
OPTIMIZER[optimizer.rs<br/>Optimizers]
EWC[ewc.rs<br/>Continual Learning]
end
subgraph "Compression"
QUANT[quantize.rs<br/>Quantization]
PRUNE[prune.rs<br/>Pruning]
DISTILL[distill.rs<br/>Distillation]
end
subgraph "Platform Bindings"
WASM[wasm.rs<br/>WASM Bindings]
NAPI[napi.rs<br/>Node.js Bindings]
CLI[cli.rs<br/>CLI Interface]
end
subgraph "Integration"
CORE_INT[core_integration.rs<br/>RuVector Core]
GNN_INT[gnn_integration.rs<br/>RuVector GNN]
GRAPH_INT[graph_integration.rs<br/>RuVector Graph]
end
%% Dependencies
STANDARD --> TRAITS
MULTIHEAD --> TRAITS
GEOMETRIC --> TRAITS
SPARSE --> TRAITS
GRAPH --> TRAITS
STANDARD --> MATMUL
STANDARD --> SOFTMAX
MULTIHEAD --> MATMUL
MULTIHEAD --> SOFTMAX
GEOMETRIC --> DISTANCE
SPARSE --> MATMUL
MATMUL --> SIMD
SOFTMAX --> SIMD
DISTANCE --> SIMD
STANDARD --> TENSOR
STANDARD --> CACHE
MULTIHEAD --> TENSOR
GEOMETRIC --> TENSOR
GRADIENT --> STANDARD
GRADIENT --> MULTIHEAD
OPTIMIZER --> GRADIENT
QUANT --> STANDARD
PRUNE --> MULTIHEAD
DISTILL --> GRADIENT
WASM --> STANDARD
WASM --> MULTIHEAD
NAPI --> STANDARD
NAPI --> MULTIHEAD
CLI --> STANDARD
CORE_INT --> TRAITS
GNN_INT --> GRAPH
GRAPH_INT --> GRAPH
ERRORS --> TRAITS
CONFIG --> TRAITS
style TRAITS fill:#e1f5ff
style STANDARD fill:#fff4e1
style MULTIHEAD fill:#fff4e1
style SIMD fill:#e8f5e9
style WASM fill:#f3e5f5
style NAPI fill:#f3e5f5
3.1 Dependency Matrix
| Module | Core Dependencies | Optional Dependencies |
|---|---|---|
| traits | serde, ndarray | - |
| standard | traits, matmul, softmax | cache, simd |
| multihead | traits, matmul, softmax | cache, simd |
| geometric | traits, distance | simd |
| sparse | traits, matmul | cache |
| graph | traits, ruvector-graph | ruvector-gnn |
| matmul | ndarray | simd |
| simd | simsimd | - |
| wasm | standard, wasm-bindgen | multihead, sparse |
| napi | standard, napi-rs | multihead, graph |
4. Data Flow Architecture
4.1 Forward Pass Pipeline
sequenceDiagram
participant User
participant Attention
participant MatMul
participant Softmax
participant SIMD
participant Cache
User->>Attention: forward(query, keys, values)
Attention->>Cache: check_cache(query_hash)
alt Cache Hit
Cache-->>Attention: cached_result
Attention-->>User: result
else Cache Miss
Attention->>MatMul: compute_qk_similarity(query, keys)
MatMul->>SIMD: batch_dot_product(query, keys)
SIMD-->>MatMul: similarity_scores
MatMul-->>Attention: scores
Attention->>Softmax: normalize(scores)
Softmax->>SIMD: fast_exp(scores)
SIMD-->>Softmax: exp_scores
Softmax-->>Attention: attention_weights
Attention->>MatMul: weighted_sum(attention_weights, values)
MatMul->>SIMD: weighted_aggregate(weights, values)
SIMD-->>MatMul: output
MatMul-->>Attention: result
Attention->>Cache: store(query_hash, result)
Attention-->>User: result
end
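The `check_cache`/`store` steps in the pipeline above can be sketched as a memoizing wrapper keyed on a hash of the query's bit pattern. This is a simplification: a production cache would also be bounded (e.g. LRU, as the `cache` feature suggests) and keyed on the key/value set, not the query alone:

```rust
use std::collections::HashMap;

/// Memoizing attention wrapper keyed on the query's raw f32 bits.
struct CachedAttention {
    cache: HashMap<u64, Vec<f32>>,
    hits: usize,
}

impl CachedAttention {
    fn new() -> Self {
        Self { cache: HashMap::new(), hits: 0 }
    }

    /// FNV-1a over the raw f32 bits: cheap and deterministic
    fn query_hash(query: &[f32]) -> u64 {
        let mut h: u64 = 0xcbf29ce484222325;
        for x in query {
            for b in x.to_bits().to_le_bytes() {
                h ^= b as u64;
                h = h.wrapping_mul(0x100000001b3);
            }
        }
        h
    }

    /// Return the cached result for `query`, or run `compute` and store it
    fn forward(&mut self, query: &[f32], compute: impl Fn() -> Vec<f32>) -> Vec<f32> {
        let key = Self::query_hash(query);
        if let Some(cached) = self.cache.get(&key) {
            self.hits += 1;
            return cached.clone();
        }
        let result = compute();
        self.cache.insert(key, result.clone());
        result
    }
}
```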
4.2 Backward Pass (Training)
sequenceDiagram
participant Trainer
participant Attention
participant Gradient
participant MatMul
participant Optimizer
Trainer->>Attention: backward(grad_output)
Attention->>Gradient: compute_grad_values(grad_output, weights)
Gradient->>MatMul: matmul_transpose(weights^T, grad_output)
MatMul-->>Gradient: grad_values
Gradient-->>Attention: grad_values
Attention->>Gradient: compute_grad_weights(grad_output, values)
Gradient->>MatMul: outer_product(grad_output, values^T)
MatMul-->>Gradient: grad_weights_raw
Gradient->>Gradient: softmax_backward(grad_weights_raw)
Gradient-->>Attention: grad_weights
Attention->>Gradient: compute_grad_query(grad_weights, keys)
Gradient->>MatMul: matmul(grad_weights, keys)
MatMul-->>Gradient: grad_query
Gradient-->>Attention: grad_query
Attention->>Gradient: compute_grad_keys(grad_weights, query)
Gradient->>MatMul: outer_product(grad_weights^T, query)
MatMul-->>Gradient: grad_keys
Gradient-->>Attention: grad_keys
Attention-->>Trainer: (grad_query, grad_keys, grad_values)
Trainer->>Optimizer: update_parameters(gradients)
Optimizer-->>Attention: updated_parameters
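The `softmax_backward` step in the diagram above is the Jacobian-vector product of the softmax: given attention weights w = softmax(s) and the upstream gradient g = ∂L/∂w, the gradient with respect to the pre-softmax scores is ∂L/∂s_i = w_i · (g_i − Σ_j g_j w_j). A minimal sketch:

```rust
/// Softmax backward pass (Jacobian-vector product):
///   dL/ds_i = w_i * (g_i - sum_j g_j * w_j)
fn softmax_backward(weights: &[f32], grad_weights: &[f32]) -> Vec<f32> {
    let dot: f32 = grad_weights.iter().zip(weights).map(|(g, w)| g * w).sum();
    weights
        .iter()
        .zip(grad_weights)
        .map(|(w, g)| w * (g - dot))
        .collect()
}
```

Because softmax is invariant to a constant shift of the scores, the returned gradient always sums to zero, which is a cheap correctness check for the kernel.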
4.3 Graph Attention Flow
graph LR
A[Node Features] --> B[Gather Neighbors]
B --> C[Edge Attention]
C --> D[Attention Weights]
D --> E[Aggregate Messages]
E --> F[Update Node]
F --> G[Layer Norm]
G --> H[Output Features]
I[Edge Index] --> B
J[Edge Attributes] --> C
style A fill:#e3f2fd
style H fill:#c8e6c9
style D fill:#fff9c4
4.4 Sparse Attention Flow
graph TB
A[Input Sequence] --> B{Sparsity Pattern}
B -->|Windowed| C[Local Windows]
B -->|Strided| D[Strided Pattern]
B -->|Top-K| E[Top-K Selection]
B -->|Block| F[Block Sparse]
C --> G[Sparse Matmul]
D --> G
E --> G
F --> G
G --> H[Softmax]
H --> I[Sparse Weighted Sum]
I --> J[Output]
style B fill:#ffe0b2
style G fill:#f8bbd0
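The `Windowed` pattern above (and the `SparseMask::Windowed` variant in Section 2.4) can be materialized into a dense boolean mask for reference implementations and testing. A sketch: each query attends to keys within `window_size` positions, and global tokens attend to, and are attended by, every position:

```rust
/// Expand a windowed sparsity pattern into a dense boolean mask.
/// Intended for reference/testing; real kernels would keep it implicit.
fn expand_windowed_mask(
    seq_len: usize,
    window_size: usize,
    global_tokens: &[usize],
) -> Vec<Vec<bool>> {
    let mut mask = vec![vec![false; seq_len]; seq_len];
    for q in 0..seq_len {
        for k in 0..seq_len {
            let local = q.abs_diff(k) <= window_size;
            let global = global_tokens.contains(&q) || global_tokens.contains(&k);
            mask[q][k] = local || global;
        }
    }
    mask
}
```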
5. Integration with Existing Crates
5.1 RuVector Core Integration
// File: src/integration/core_integration.rs
use ruvector_core::{VectorDatabase, HnswIndex, Distance};
use crate::traits::Attention;
/// Bridge between attention mechanisms and vector database
pub struct AttentionVectorBridge {
db: VectorDatabase,
attention: Box<dyn Attention>,
}
impl AttentionVectorBridge {
/// Attention-based vector search
pub fn attention_search(
&self,
query: &[f32],
k: usize,
) -> Result<Vec<(u64, f32)>, Error> {
// 1. Get candidate vectors from HNSW
let candidates = self.db.search(query, k * 10)?;
// 2. Extract candidate vectors
let candidate_vecs: Vec<Vec<f32>> = candidates
.iter()
.map(|(id, _)| self.db.get(*id).unwrap())
.collect();
        // 3. Use attention weights for re-ranking (`forward` returns an
        //    attended output vector, not per-candidate scores)
        self.attention.forward(query, &candidate_vecs, &candidate_vecs);
        let attention_scores = self
            .attention
            .get_attention_weights()
            .expect("weights recorded during forward")
            .remove(0);
// 4. Return top-k by attention score
let mut results: Vec<_> = candidates
.into_iter()
.zip(attention_scores)
.collect();
results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
results.truncate(k);
Ok(results.into_iter().map(|(id_score, attn)| (id_score.0, attn)).collect())
}
/// Batch attention search with SIMD
pub fn batch_attention_search(
&self,
queries: &[Vec<f32>],
k: usize,
) -> Result<Vec<Vec<(u64, f32)>>, Error> {
queries
.par_iter()
.map(|q| self.attention_search(q, k))
.collect()
}
}
/// Use attention for learned distance metric
pub struct AttentionDistance {
attention: Box<dyn Attention>,
}
impl Distance for AttentionDistance {
fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
// Use attention mechanism as a learned distance
let score = self.attention.forward(a, &[b.to_vec()], &[b.to_vec()]);
// Convert to distance (higher attention = lower distance)
1.0 / (score[0] + 1e-6)
}
}
5.2 RuVector GNN Integration
// File: src/integration/gnn_integration.rs
use ruvector_gnn::{RuvectorLayer, training::Optimizer};
use crate::traits::{GraphAttention, TrainableAttention};
/// Combined bound: `dyn GraphAttention + TrainableAttention` is not valid
/// Rust (only one non-auto trait per trait object), so a supertrait gives
/// us a single boxable object
pub trait TrainableGraphAttention: GraphAttention + TrainableAttention {}
impl<T: GraphAttention + TrainableAttention> TrainableGraphAttention for T {}

/// GNN layer using custom attention mechanism
pub struct AttentionGNNLayer {
    attention: Box<dyn TrainableGraphAttention>,
    normalization: LayerNorm,
    feedforward: Linear,
}
impl AttentionGNNLayer {
/// Forward pass through GNN layer with attention
pub fn forward(
&self,
node_features: &[Vec<f32>],
edge_index: &[(usize, usize)],
) -> Vec<Vec<f32>> {
let mut outputs = Vec::with_capacity(node_features.len());
for node_idx in 0..node_features.len() {
// 1. Gather neighbor features via edges
let neighbors: Vec<usize> = edge_index
.iter()
.filter(|(src, _)| *src == node_idx)
.map(|(_, dst)| *dst)
.collect();
let neighbor_features: Vec<Vec<f32>> = neighbors
.iter()
.map(|&idx| node_features[idx].clone())
.collect();
if neighbor_features.is_empty() {
outputs.push(node_features[node_idx].clone());
continue;
}
// 2. Apply graph attention
let attended = self.attention.forward_neighborhood(
node_idx,
node_features,
edge_index,
);
// 3. Layer normalization
let normalized = self.normalization.forward(&attended);
// 4. Feedforward network
let output = self.feedforward.forward(&normalized);
outputs.push(output);
}
outputs
}
/// Training step with gradient computation
pub fn train_step(
&mut self,
node_features: &[Vec<f32>],
edge_index: &[(usize, usize)],
targets: &[Vec<f32>],
optimizer: &mut dyn Optimizer,
) -> f32 {
// Forward pass
let predictions = self.forward(node_features, edge_index);
// Compute loss (MSE for regression)
let loss = predictions
.iter()
.zip(targets)
.map(|(pred, target)| {
pred.iter()
.zip(target)
.map(|(p, t)| (p - t).powi(2))
.sum::<f32>()
})
.sum::<f32>()
/ predictions.len() as f32;
// Backward pass
let grad_predictions: Vec<Vec<f32>> = predictions
.iter()
.zip(targets)
.map(|(pred, target)| {
pred.iter()
.zip(target)
.map(|(p, t)| 2.0 * (p - t) / predictions.len() as f32)
.collect()
})
.collect();
// Backpropagate through attention
        for grad in &grad_predictions {
            self.attention.backward(grad);
        }
// Update parameters
self.attention.update_parameters(optimizer);
loss
}
}
/// Integration with RuVector GNN's existing training infrastructure
impl From<Box<dyn GraphAttention>> for RuvectorLayer {
fn from(attention: Box<dyn GraphAttention>) -> Self {
// Convert attention mechanism to RuvectorLayer format
// This allows using attention in existing GNN pipelines
unimplemented!("Adapter implementation")
}
}
5.3 RuVector Graph Integration
// File: src/integration/graph_integration.rs
use ruvector_graph::{Graph, Node, Edge, CypherQuery};
use crate::traits::GraphAttention;
/// Query graph database with attention-based traversal
pub struct AttentionGraphQuery {
graph: Graph,
attention: Box<dyn GraphAttention>,
}
impl AttentionGraphQuery {
/// Execute Cypher query with attention-based node selection
pub fn execute_with_attention(
&self,
query: &CypherQuery,
start_node: u64,
) -> Result<Vec<Node>, Error> {
// 1. Parse Cypher query to get traversal pattern
let pattern = query.parse_pattern()?;
// 2. Start from initial node
let mut current_nodes = vec![self.graph.get_node(start_node)?];
// 3. For each step in pattern, use attention to select next nodes
for step in pattern.steps {
let mut next_nodes = Vec::new();
            for node in &current_nodes {
// Get neighbors
let neighbors = self.graph.get_neighbors(node.id)?;
if neighbors.is_empty() {
continue;
}
// Extract neighbor features
let neighbor_features: Vec<Vec<f32>> = neighbors
.iter()
.map(|n| n.embedding.clone())
.collect();
                // Use attention weights to score neighbors (`forward`
                // returns an attended vector, not per-neighbor scores)
                self.attention.forward(&node.embedding, &neighbor_features, &neighbor_features);
                let attention_output = self
                    .attention
                    .get_attention_weights()
                    .expect("weights recorded during forward")
                    .remove(0);
// Select top-k neighbors by attention score
let k = step.limit.unwrap_or(10);
let mut scored_neighbors: Vec<_> = neighbors
.into_iter()
.zip(attention_output)
.collect();
scored_neighbors.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
scored_neighbors.truncate(k);
next_nodes.extend(scored_neighbors.into_iter().map(|(n, _)| n));
}
current_nodes = next_nodes;
}
Ok(current_nodes)
}
/// Graph embedding with attention aggregation
pub fn compute_graph_embedding(
&self,
subgraph: &[u64],
) -> Result<Vec<f32>, Error> {
// Aggregate subgraph nodes using attention
let node_features: Vec<Vec<f32>> = subgraph
.iter()
            .map(|&id| Ok(self.graph.get_node(id)?.embedding.clone()))
.collect::<Result<_, _>>()?;
// Create synthetic "graph query" as average of nodes
let query: Vec<f32> = node_features
.iter()
.fold(vec![0.0; node_features[0].len()], |mut acc, node| {
for (a, &n) in acc.iter_mut().zip(node) {
*a += n;
}
acc
})
.into_iter()
.map(|x| x / node_features.len() as f32)
.collect();
// Use attention to create graph-level embedding
Ok(self.attention.forward(&query, &node_features, &node_features))
}
}
6. Platform Architecture
6.1 Feature Flags
# Cargo.toml feature configuration
[features]
default = ["std", "simd"]
# Standard library support
std = []
# SIMD optimizations
simd = ["simsimd"]
# Platform-specific features
wasm = ["wasm-bindgen", "js-sys", "web-sys"]
napi = ["napi", "napi-derive"]
# Attention variants
standard = []
multihead = ["standard"]
geometric = ["manifold-math"]
sparse = []
graph = ["ruvector-graph"]
# All attention mechanisms
all-attention = ["standard", "multihead", "geometric", "sparse", "graph"]
# Training support
training = ["gradient", "optimizer"]
gradient = []
optimizer = []
# Compression
compression = ["quantization", "pruning"]
quantization = []
pruning = []
# Memory management
mmap = ["memmap2"]
cache = ["lru"]
# Full feature set
full = ["all-attention", "training", "compression", "simd", "cache"]
6.2 WASM Architecture
// File: src/platform/wasm.rs
use wasm_bindgen::prelude::*;
use js_sys::{Array, Float32Array};
use web_sys::console;
#[wasm_bindgen]
pub struct WasmAttention {
inner: Box<dyn Attention>,
}
#[wasm_bindgen]
impl WasmAttention {
/// Create new attention mechanism
#[wasm_bindgen(constructor)]
pub fn new(config: JsValue) -> Result<WasmAttention, JsValue> {
// Parse config from JavaScript
let config: AttentionConfig = serde_wasm_bindgen::from_value(config)
.map_err(|e| JsValue::from_str(&format!("Config parse error: {}", e)))?;
// Create appropriate attention type
let attention: Box<dyn Attention> = match config.attention_type {
AttentionType::Standard => Box::new(StandardAttention::new(config)),
AttentionType::MultiHead => Box::new(MultiHeadAttention::new(config)),
AttentionType::Sparse => Box::new(SparseAttention::new(config)),
_ => return Err(JsValue::from_str("Unsupported attention type for WASM")),
};
Ok(WasmAttention { inner: attention })
}
/// Forward pass with JS arrays
#[wasm_bindgen]
pub fn forward(
&self,
query: Float32Array,
keys: Array,
values: Array,
) -> Result<Float32Array, JsValue> {
// Convert JavaScript arrays to Rust vectors
let query_vec: Vec<f32> = query.to_vec();
let keys_vec: Vec<Vec<f32>> = (0..keys.length())
.map(|i| {
let arr = Float32Array::from(keys.get(i));
arr.to_vec()
})
.collect();
let values_vec: Vec<Vec<f32>> = (0..values.length())
.map(|i| {
let arr = Float32Array::from(values.get(i));
arr.to_vec()
})
.collect();
// Perform attention computation
let result = self.inner.forward(&query_vec, &keys_vec, &values_vec);
// Convert back to JavaScript
Ok(Float32Array::from(&result[..]))
}
/// Batch processing
#[wasm_bindgen]
pub fn forward_batch(
&self,
queries: Array,
keys: Array,
values: Array,
) -> Result<Array, JsValue> {
let queries_vec: Vec<Vec<f32>> = (0..queries.length())
.map(|i| Float32Array::from(queries.get(i)).to_vec())
.collect();
let keys_vec: Vec<Vec<f32>> = (0..keys.length())
.map(|i| Float32Array::from(keys.get(i)).to_vec())
.collect();
let values_vec: Vec<Vec<f32>> = (0..values.length())
.map(|i| Float32Array::from(values.get(i)).to_vec())
.collect();
let results = self.inner.forward_batch(&queries_vec, &keys_vec, &values_vec);
let js_results = Array::new();
for result in results {
js_results.push(&Float32Array::from(&result[..]));
}
Ok(js_results)
}
/// Get attention weights for visualization
#[wasm_bindgen]
pub fn get_attention_weights(&self) -> Option<Array> {
self.inner.get_attention_weights().map(|weights| {
let js_weights = Array::new();
for row in weights {
js_weights.push(&Float32Array::from(&row[..]));
}
js_weights
})
}
}
// Memory management for WASM
#[wasm_bindgen]
pub struct WasmMemoryPool {
tensors: Vec<Vec<f32>>,
}
#[wasm_bindgen]
impl WasmMemoryPool {
#[wasm_bindgen(constructor)]
pub fn new(capacity: usize) -> Self {
Self {
tensors: Vec::with_capacity(capacity),
}
}
/// Pre-allocate tensors to avoid runtime allocation
pub fn preallocate(&mut self, size: usize, count: usize) {
for _ in 0..count {
self.tensors.push(vec![0.0; size]);
}
}
}
WASM Limitations & Solutions:
| Limitation | Solution |
|---|---|
| No wide SIMD (128-bit `simd128` at most) | Manual loop unrolling plus 128-bit vectorization |
| No threads | Use Web Workers for parallelism |
| No mmap | Pure in-memory storage only |
| Memory limit | Aggressive pooling and deallocation |
| No HNSW (mmap) | Use in-memory index or disable graph features |
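The "manual loop unrolling" workaround from the table above can be sketched as a dot product with four independent accumulators, which lets the compiler or WASM engine keep several multiply-adds in flight without explicit SIMD intrinsics:

```rust
/// Portable unrolled dot product: four independent accumulators plus a
/// scalar tail for lengths not divisible by 4.
fn unrolled_dot_product(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());
    let mut acc = [0.0f32; 4];
    let chunks = a.len() / 4;
    for i in 0..chunks {
        let base = i * 4;
        for lane in 0..4 {
            acc[lane] += a[base + lane] * b[base + lane];
        }
    }
    // Scalar tail
    let mut sum = acc.iter().sum::<f32>();
    for i in chunks * 4..a.len() {
        sum += a[i] * b[i];
    }
    sum
}
```

Note that splitting the accumulator changes floating-point summation order, so results may differ from a strictly sequential sum in the last bits.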
6.3 NAPI-RS Architecture
// File: src/platform/napi.rs
use napi::{
bindgen_prelude::*,
JsObject, JsNumber, JsString,
};
use napi_derive::napi;
#[napi]
pub struct NapiAttention {
inner: Box<dyn Attention>,
runtime: tokio::runtime::Runtime,
}
#[napi]
impl NapiAttention {
/// Create new attention mechanism
#[napi(constructor)]
pub fn new(config: JsObject) -> Result<Self> {
        let embed_dim: u32 = config.get_named_property("embedDim")?;
        let num_heads: Option<u32> = config.get_named_property("numHeads").ok();
        let attention_config = AttentionConfig {
            embed_dim: embed_dim as usize,
            num_heads: num_heads.map(|h| h as usize),
            dropout_rate: config.get_named_property("dropoutRate").ok(),
attention_type: AttentionType::Standard, // TODO: parse from config
};
let attention = Box::new(StandardAttention::new(attention_config));
let runtime = tokio::runtime::Runtime::new()
.map_err(|e| Error::from_reason(format!("Runtime init failed: {}", e)))?;
Ok(Self { inner: attention, runtime })
}
/// Synchronous forward pass
#[napi]
pub fn forward(
&self,
query: Float32Array,
keys: Vec<Float32Array>,
values: Vec<Float32Array>,
) -> Result<Float32Array> {
let query_vec = query.to_vec();
let keys_vec: Vec<Vec<f32>> = keys.iter().map(|k| k.to_vec()).collect();
let values_vec: Vec<Vec<f32>> = values.iter().map(|v| v.to_vec()).collect();
let result = self.inner.forward(&query_vec, &keys_vec, &values_vec);
Ok(Float32Array::new(result))
}
/// Async batch processing (non-blocking)
#[napi]
pub async fn forward_batch_async(
&self,
queries: Vec<Float32Array>,
keys: Vec<Float32Array>,
values: Vec<Float32Array>,
) -> Result<Vec<Float32Array>> {
// Convert to Rust types
let queries_vec: Vec<Vec<f32>> = queries.iter().map(|q| q.to_vec()).collect();
let keys_vec: Vec<Vec<f32>> = keys.iter().map(|k| k.to_vec()).collect();
let values_vec: Vec<Vec<f32>> = values.iter().map(|v| v.to_vec()).collect();
// Spawn blocking task for CPU-intensive work
let inner = self.inner.clone(); // Need to make Attention cloneable or use Arc
let result = tokio::task::spawn_blocking(move || {
inner.forward_batch(&queries_vec, &keys_vec, &values_vec)
})
.await
.map_err(|e| Error::from_reason(format!("Task failed: {}", e)))?;
// Convert back to JS
Ok(result.into_iter().map(Float32Array::new).collect())
}
/// Stream processing with Node.js streams
#[napi]
pub fn create_stream(&self, env: Env) -> Result<JsObject> {
// Create a Transform stream that processes attention in chunks
let stream = env.create_object()?;
// Implementation would integrate with Node.js stream API
Ok(stream)
}
}
/// Buffer operations for zero-copy transfers
#[napi]
pub struct AttentionBuffer {
data: Vec<f32>,
}
#[napi]
impl AttentionBuffer {
/// Create buffer from Node.js Buffer (zero-copy if possible)
#[napi(factory)]
pub fn from_buffer(buffer: Buffer) -> Result<Self> {
// Try to avoid copy by using buffer directly
let data = buffer.as_ref().to_vec(); // TODO: true zero-copy
Ok(Self { data })
}
/// Export as Node.js Buffer
#[napi]
pub fn to_buffer(&self, env: Env) -> Result<Buffer> {
env.create_buffer_with_data(self.data.clone()) // TODO: avoid clone
}
}
NAPI-RS Advantages:
- Async/Await: Non-blocking operations for Node.js event loop
- Buffer Sharing: Near zero-copy data transfer
- TypeScript Definitions: Auto-generated .d.ts files
- Worker Threads: Offload CPU-intensive work
- Stream API: Integration with Node.js streams
6.4 CLI Architecture
// File: src/platform/cli.rs
use clap::{Parser, Subcommand};
use std::path::PathBuf;
#[derive(Parser)]
#[command(name = "ruvector-attention")]
#[command(about = "Attention mechanisms for RuVector", long_about = None)]
struct Cli {
#[command(subcommand)]
command: Commands,
}
#[derive(Subcommand)]
enum Commands {
/// Run attention benchmark
Bench {
/// Input dimension
#[arg(short, long, default_value_t = 128)]
dim: usize,
/// Number of keys/values
#[arg(short, long, default_value_t = 1000)]
num_keys: usize,
/// Attention type (standard, multihead, geometric, sparse)
#[arg(short, long, default_value = "standard")]
attention_type: String,
/// Number of iterations
#[arg(short, long, default_value_t = 100)]
iterations: usize,
},
/// Test attention correctness
Test {
/// Test suite (all, standard, multihead, geometric, sparse, graph)
#[arg(short, long, default_value = "all")]
suite: String,
/// Verbose output
#[arg(short, long)]
verbose: bool,
},
/// Visualize attention weights
Visualize {
/// Input file (queries, keys, values)
#[arg(short, long)]
input: PathBuf,
/// Output file (attention matrix visualization)
#[arg(short, long)]
output: PathBuf,
/// Attention type
#[arg(short, long, default_value = "standard")]
attention_type: String,
},
/// Export attention model
Export {
/// Model checkpoint path
#[arg(short, long)]
checkpoint: PathBuf,
/// Export format (onnx, rust, json)
#[arg(short, long, default_value = "onnx")]
format: String,
/// Output path
#[arg(short, long)]
output: PathBuf,
},
}
pub fn run_cli() -> Result<(), Box<dyn std::error::Error>> {
let cli = Cli::parse();
match cli.command {
Commands::Bench { dim, num_keys, attention_type, iterations } => {
run_benchmark(dim, num_keys, &attention_type, iterations)?;
}
Commands::Test { suite, verbose } => {
run_tests(&suite, verbose)?;
}
Commands::Visualize { input, output, attention_type } => {
visualize_attention(&input, &output, &attention_type)?;
}
Commands::Export { checkpoint, format, output } => {
export_model(&checkpoint, &format, &output)?;
}
}
Ok(())
}
fn run_benchmark(
dim: usize,
num_keys: usize,
attention_type: &str,
iterations: usize,
) -> Result<(), Box<dyn std::error::Error>> {
use std::time::Instant;
// Create attention mechanism
let attention = create_attention(attention_type, dim)?;
    // Generate synthetic data (constant values are fine for timing)
let query = vec![0.5; dim];
let keys: Vec<Vec<f32>> = (0..num_keys).map(|_| vec![0.5; dim]).collect();
let values: Vec<Vec<f32>> = (0..num_keys).map(|_| vec![0.5; dim]).collect();
// Warmup
for _ in 0..10 {
attention.forward(&query, &keys, &values);
}
    // Benchmark (black_box prevents the optimizer from eliding the calls)
    let start = Instant::now();
    for _ in 0..iterations {
        std::hint::black_box(attention.forward(&query, &keys, &values));
    }
let duration = start.elapsed();
let avg_ms = duration.as_secs_f64() * 1000.0 / iterations as f64;
let throughput = iterations as f64 / duration.as_secs_f64();
println!("=== Attention Benchmark ===");
println!("Type: {}", attention_type);
println!("Dimension: {}", dim);
println!("Keys/Values: {}", num_keys);
println!("Iterations: {}", iterations);
println!("Average: {:.3} ms", avg_ms);
println!("Throughput: {:.1} ops/sec", throughput);
Ok(())
}
7. Performance Architecture
7.1 SIMD Optimization Strategy
// File: src/kernels/simd.rs
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
/// SIMD-optimized dot product, dispatched at compile time on the enabled
/// target features (build with e.g. `RUSTFLAGS="-C target-feature=+avx2"`;
/// runtime dispatch via `is_x86_64_feature_detected!` is an alternative)
#[inline]
pub fn simd_dot_product(a: &[f32], b: &[f32]) -> f32 {
assert_eq!(a.len(), b.len());
#[cfg(target_feature = "avx2")]
unsafe {
simd_dot_product_avx2(a, b)
}
#[cfg(all(target_feature = "sse2", not(target_feature = "avx2")))]
unsafe {
simd_dot_product_sse2(a, b)
}
#[cfg(not(any(target_feature = "avx2", target_feature = "sse2")))]
{
scalar_dot_product(a, b)
}
}
#[cfg(target_feature = "avx2")]
#[target_feature(enable = "avx2")]
unsafe fn simd_dot_product_avx2(a: &[f32], b: &[f32]) -> f32 {
let len = a.len();
let mut sum = _mm256_setzero_ps();
let chunks = len / 8;
for i in 0..chunks {
let offset = i * 8;
let va = _mm256_loadu_ps(a.as_ptr().add(offset));
let vb = _mm256_loadu_ps(b.as_ptr().add(offset));
let prod = _mm256_mul_ps(va, vb);
sum = _mm256_add_ps(sum, prod);
}
// Horizontal sum
let mut result = [0f32; 8];
_mm256_storeu_ps(result.as_mut_ptr(), sum);
let mut total = result.iter().sum::<f32>();
// Handle remainder
for i in (chunks * 8)..len {
total += a[i] * b[i];
}
total
}
/// SIMD-optimized matrix multiplication.
///
/// Note: `b` holds the *columns* of the right-hand matrix (i.e. the right
/// matrix pre-transposed), so each output element is a single contiguous
/// dot product.
pub fn simd_batch_matmul(
    a: &[Vec<f32>], // rows of A: [m][k]
    b: &[Vec<f32>], // columns of B (B transposed): [n][k]
) -> Vec<Vec<f32>> {
use rayon::prelude::*;
a.par_iter()
.map(|row| {
b.iter()
.map(|col| simd_dot_product(row, col))
.collect()
})
.collect()
}
/// In-place softmax with max-subtraction for numerical stability
/// (scalar loops; SIMD vectorization can be layered on the same structure)
pub fn simd_softmax(x: &mut [f32]) {
// Find max for numerical stability
let max_val = x.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
// Exp(x - max)
for val in x.iter_mut() {
*val = (*val - max_val).exp();
}
// Normalize
let sum: f32 = x.iter().sum();
for val in x.iter_mut() {
*val /= sum;
}
}
/// In-place layer normalization (scalar reference implementation)
pub fn simd_layer_norm(x: &mut [f32], eps: f32) {
let n = x.len() as f32;
// Compute mean
let mean = x.iter().sum::<f32>() / n;
// Compute variance
let variance = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
let std_dev = (variance + eps).sqrt();
// Normalize
for val in x.iter_mut() {
*val = (*val - mean) / std_dev;
}
}
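The `cfg`-dispatch in `simd_dot_product` falls back to `scalar_dot_product` (and an SSE2 variant) that are not shown in this listing. A minimal scalar fallback, assuming the signature used above, could look like:

```rust
/// Scalar fallback for targets without SIMD support; the iterator chain
/// lets LLVM auto-vectorize where profitable.
pub fn scalar_dot_product(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}

fn main() {
    // 1*4 + 2*5 + 3*6 = 32
    assert!((scalar_dot_product(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]) - 32.0).abs() < 1e-6);
    println!("scalar fallback check passed");
}
```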
7.2 Cache Optimization
// File: src/cache.rs
use lru::LruCache;
use std::hash::{Hash, Hasher};
use std::collections::hash_map::DefaultHasher;
/// Hash query for cache lookup
fn hash_query(query: &[f32]) -> u64 {
let mut hasher = DefaultHasher::new();
// Quantize to 2 decimal places for fuzzy matching
for &val in query {
let quantized = (val * 100.0).round() as i32;
quantized.hash(&mut hasher);
}
hasher.finish()
}
pub struct AttentionCache {
cache: LruCache<u64, Vec<f32>>,
hits: usize,
misses: usize,
}
impl AttentionCache {
    pub fn new(capacity: usize) -> Self {
        Self {
            // LruCache requires a non-zero capacity
            cache: LruCache::new(capacity.try_into().expect("capacity must be non-zero")),
            hits: 0,
            misses: 0,
        }
    }
pub fn get(&mut self, query: &[f32]) -> Option<&Vec<f32>> {
let hash = hash_query(query);
match self.cache.get(&hash) {
Some(result) => {
self.hits += 1;
Some(result)
}
None => {
self.misses += 1;
None
}
}
}
pub fn put(&mut self, query: &[f32], result: Vec<f32>) {
let hash = hash_query(query);
self.cache.put(hash, result);
}
pub fn hit_rate(&self) -> f64 {
if self.hits + self.misses == 0 {
return 0.0;
}
self.hits as f64 / (self.hits + self.misses) as f64
}
}
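Quantizing to two decimal places makes the cache tolerant of small floating-point jitter: queries that differ by less than half a quantization step map to the same key. A self-contained check of that property (repeating `hash_query` so the snippet compiles on its own):

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

/// Same quantizing hash as in src/cache.rs above.
fn hash_query(query: &[f32]) -> u64 {
    let mut hasher = DefaultHasher::new();
    for &val in query {
        let quantized = (val * 100.0).round() as i32;
        quantized.hash(&mut hasher);
    }
    hasher.finish()
}

fn main() {
    // 0.5001 and 0.5003 both quantize to 50, so they share a cache slot;
    // 0.51 quantizes to 51 and gets its own.
    assert_eq!(hash_query(&[0.5001, 0.25]), hash_query(&[0.5003, 0.25]));
    assert_ne!(hash_query(&[0.5001, 0.25]), hash_query(&[0.51, 0.25]));
    println!("fuzzy hash check passed");
}
```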
7.3 Memory Pooling
// File: src/memory/pool.rs
use parking_lot::Mutex;
use std::sync::Arc;
pub struct TensorPool {
    // Each pool is tagged with its tensor size so that a pool can still be
    // matched on acquire/release even when it is temporarily empty.
    pools: Vec<(usize, Arc<Mutex<Vec<Vec<f32>>>>)>,
}
impl TensorPool {
    pub fn new() -> Self {
        // Pre-allocate pools for common sizes (kept sorted ascending)
        let sizes = [128, 256, 512, 768, 1024, 2048];
        let pools = sizes
            .iter()
            .map(|&size| {
                let pool: Vec<Vec<f32>> = (0..16)
                    .map(|_| vec![0.0; size])
                    .collect();
                (size, Arc::new(Mutex::new(pool)))
            })
            .collect();
        Self { pools }
    }
    pub fn acquire(&self, size: usize) -> Option<Vec<f32>> {
        // Smallest pool whose tensors are large enough (sizes are sorted)
        let (_, pool) = self.pools.iter().find(|(s, _)| *s >= size)?;
        pool.lock().pop()
    }
    pub fn release(&self, mut tensor: Vec<f32>) {
        let size = tensor.len();
        // Return to the exactly matching pool; otherwise just drop the tensor
        if let Some((_, pool)) = self.pools.iter().find(|(s, _)| *s == size) {
            // Zero before reuse so stale data never leaks between calls
            tensor.fill(0.0);
            pool.lock().push(tensor);
        }
    }
}
7.4 Parallel Execution
// File: src/parallel.rs
use rayon::prelude::*;
/// Parallel batch attention with work stealing
pub fn parallel_batch_attention<A: Attention + Sync>(
attention: &A,
queries: &[Vec<f32>],
keys: &[Vec<f32>],
values: &[Vec<f32>],
) -> Vec<Vec<f32>> {
queries
.par_iter()
.map(|query| attention.forward(query, keys, values))
.collect()
}
/// Parallel head splitting for multi-head attention (sketch: slices each
/// query into per-head sub-vectors in parallel; a full implementation would
/// run attention per head and concatenate the head outputs)
pub fn parallel_multihead_forward(
    queries: &[Vec<f32>],
    _keys: &[Vec<f32>],   // unused in this head-splitting sketch
    _values: &[Vec<f32>], // unused in this head-splitting sketch
    num_heads: usize,
) -> Vec<Vec<f32>> {
    // Split queries into per-head slices and process heads in parallel
(0..num_heads)
.into_par_iter()
.flat_map(|head| {
let head_dim = queries[0].len() / num_heads;
let start = head * head_dim;
let end = start + head_dim;
queries.iter().map(move |q| {
// Extract head-specific portion
q[start..end].to_vec()
})
})
.collect()
}
7.5 Performance Targets
| Operation | Target | SIMD Speedup | Cache Hit Rate |
|---|---|---|---|
| Dot Product (128-dim) | < 10 μs | 4-8x | N/A |
| Attention (1K keys) | < 500 μs | 2-4x | 60-80% |
| Multi-head (8 heads) | < 2 ms | 3-6x | 50-70% |
| Batch (32 queries) | < 50 ms | 8-16x | 40-60% |
| Graph attention (1K nodes) | < 5 ms | 2-4x | 30-50% |
8. Security Architecture
8.1 Input Validation
// File: src/validation.rs
/// Validate input tensors for attention computation
pub fn validate_attention_inputs(
query: &[f32],
keys: &[Vec<f32>],
values: &[Vec<f32>],
) -> Result<(), AttentionError> {
// Check dimensions
if query.is_empty() {
return Err(AttentionError::InvalidInput("Query is empty".into()));
}
if keys.is_empty() {
return Err(AttentionError::InvalidInput("Keys are empty".into()));
}
if values.is_empty() {
return Err(AttentionError::InvalidInput("Values are empty".into()));
}
if keys.len() != values.len() {
return Err(AttentionError::DimensionMismatch {
expected: keys.len(),
actual: values.len(),
});
}
// Check key dimensions match query
let query_dim = query.len();
    for key in keys {
if key.len() != query_dim {
return Err(AttentionError::DimensionMismatch {
expected: query_dim,
actual: key.len(),
});
}
}
// Check for NaN/Inf
for &val in query {
if !val.is_finite() {
return Err(AttentionError::InvalidInput("Query contains NaN/Inf".into()));
}
}
// Limit maximum sequence length to prevent DoS
const MAX_SEQUENCE_LENGTH: usize = 100_000;
if keys.len() > MAX_SEQUENCE_LENGTH {
return Err(AttentionError::InputTooLarge {
max: MAX_SEQUENCE_LENGTH,
actual: keys.len(),
});
}
Ok(())
}
8.2 Memory Safety
// File: src/safety.rs
/// Safe tensor operations with bounds checking
pub struct SafeTensor {
data: Vec<f32>,
shape: Vec<usize>,
}
impl SafeTensor {
pub fn new(data: Vec<f32>, shape: Vec<usize>) -> Result<Self, AttentionError> {
let expected_size: usize = shape.iter().product();
if data.len() != expected_size {
return Err(AttentionError::DimensionMismatch {
expected: expected_size,
actual: data.len(),
});
}
Ok(Self { data, shape })
}
pub fn get(&self, indices: &[usize]) -> Result<f32, AttentionError> {
if indices.len() != self.shape.len() {
return Err(AttentionError::InvalidIndex);
}
let flat_index = self.compute_flat_index(indices)?;
self.data.get(flat_index)
.copied()
.ok_or(AttentionError::IndexOutOfBounds)
}
fn compute_flat_index(&self, indices: &[usize]) -> Result<usize, AttentionError> {
let mut flat = 0;
let mut stride = 1;
for (i, &idx) in indices.iter().enumerate().rev() {
if idx >= self.shape[i] {
return Err(AttentionError::IndexOutOfBounds);
}
flat += idx * stride;
stride *= self.shape[i];
}
Ok(flat)
}
}
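As a concrete check of the stride logic in `compute_flat_index`: for shape `[2, 3]`, index `[1, 2]` should map to flat offset `1 * 3 + 2 = 5`. A standalone version of the same row-major computation (a hypothetical free function mirroring the method above):

```rust
/// Row-major flat index with bounds checking, mirroring
/// SafeTensor::compute_flat_index above.
fn flat_index(shape: &[usize], indices: &[usize]) -> Option<usize> {
    if shape.len() != indices.len() {
        return None;
    }
    let mut flat = 0;
    let mut stride = 1;
    // Walk dimensions from innermost to outermost, accumulating strides.
    for (i, &idx) in indices.iter().enumerate().rev() {
        if idx >= shape[i] {
            return None; // out of bounds
        }
        flat += idx * stride;
        stride *= shape[i];
    }
    Some(flat)
}

fn main() {
    assert_eq!(flat_index(&[2, 3], &[1, 2]), Some(5));
    assert_eq!(flat_index(&[2, 3], &[2, 0]), None); // first index out of range
    println!("flat index check passed");
}
```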
8.3 Numerical Stability
// File: src/numerical.rs
/// Numerically stable softmax computation
pub fn stable_softmax(x: &[f32]) -> Vec<f32> {
// Find max for numerical stability
let max_val = x.iter()
.cloned()
.fold(f32::NEG_INFINITY, f32::max);
// Compute exp(x - max)
let exp_vals: Vec<f32> = x.iter()
.map(|&val| (val - max_val).exp())
.collect();
// Normalize
let sum: f32 = exp_vals.iter().sum();
if sum == 0.0 || !sum.is_finite() {
// Fallback to uniform distribution
vec![1.0 / x.len() as f32; x.len()]
} else {
exp_vals.iter().map(|&val| val / sum).collect()
}
}
/// Safe division with epsilon for numerical stability
pub fn safe_div(a: f32, b: f32, eps: f32) -> f32 {
a / (b + eps)
}
/// Clamp values to prevent overflow
pub fn clamp_finite(x: f32, min: f32, max: f32) -> f32 {
if !x.is_finite() {
return 0.0;
}
x.clamp(min, max)
}
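A quick demonstration of why the max-shift matters: logits of 1000 overflow `f32` under a naive `exp`, while the shifted form stays finite and normalized (the function is repeated from above so the snippet compiles on its own):

```rust
/// Numerically stable softmax, as defined in src/numerical.rs above.
fn stable_softmax(x: &[f32]) -> Vec<f32> {
    let max_val = x.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exp_vals: Vec<f32> = x.iter().map(|&v| (v - max_val).exp()).collect();
    let sum: f32 = exp_vals.iter().sum();
    if sum == 0.0 || !sum.is_finite() {
        vec![1.0 / x.len() as f32; x.len()]
    } else {
        exp_vals.iter().map(|&v| v / sum).collect()
    }
}

fn main() {
    // Naive exp(1000.0) overflows to +inf in f32 ...
    assert!(1000.0_f32.exp().is_infinite());
    // ... but the shifted softmax stays finite and sums to 1.
    let out = stable_softmax(&[1000.0, 1000.0, 999.0]);
    assert!(out.iter().all(|v| v.is_finite()));
    assert!((out.iter().sum::<f32>() - 1.0).abs() < 1e-6);
    println!("stable softmax check passed");
}
```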
9. Deployment Architecture
9.1 Crate Structure
crates/ruvector-attention/
├── Cargo.toml
├── README.md
├── LICENSE
├── benches/
│ ├── attention_bench.rs
│ ├── simd_bench.rs
│ └── comparison_bench.rs
├── examples/
│ ├── basic_attention.rs
│ ├── graph_attention.rs
│ ├── geometric_attention.rs
│ └── training_example.rs
├── src/
│ ├── lib.rs
│ ├── config.rs
│ ├── error.rs
│ │
│ ├── traits/
│ │ ├── mod.rs
│ │ ├── attention.rs
│ │ ├── graph_attention.rs
│ │ ├── geometric_attention.rs
│ │ └── sparse_attention.rs
│ │
│ ├── attention/
│ │ ├── mod.rs
│ │ ├── standard.rs
│ │ ├── multihead.rs
│ │ ├── geometric.rs
│ │ ├── sparse.rs
│ │ └── graph.rs
│ │
│ ├── kernels/
│ │ ├── mod.rs
│ │ ├── matmul.rs
│ │ ├── softmax.rs
│ │ ├── simd.rs
│ │ └── distance.rs
│ │
│ ├── memory/
│ │ ├── mod.rs
│ │ ├── tensor.rs
│ │ ├── cache.rs
│ │ └── pool.rs
│ │
│ ├── training/
│ │ ├── mod.rs
│ │ ├── gradient.rs
│ │ ├── optimizer.rs
│ │ └── ewc.rs
│ │
│ ├── compression/
│ │ ├── mod.rs
│ │ ├── quantize.rs
│ │ ├── prune.rs
│ │ └── distill.rs
│ │
│ ├── integration/
│ │ ├── mod.rs
│ │ ├── core_integration.rs
│ │ ├── gnn_integration.rs
│ │ └── graph_integration.rs
│ │
│ ├── platform/
│ │ ├── mod.rs
│ │ ├── wasm.rs
│ │ ├── napi.rs
│ │ └── cli.rs
│ │
│ ├── manifold/
│ │ ├── mod.rs
│ │ ├── hyperbolic.rs
│ │ ├── spherical.rs
│ │ └── product.rs
│ │
│ └── utils/
│ ├── mod.rs
│ ├── validation.rs
│ ├── safety.rs
│ └── numerical.rs
│
└── tests/
├── integration_tests.rs
├── attention_tests.rs
├── geometric_tests.rs
└── correctness_tests.rs
9.2 Build Configuration
# Cargo.toml
[package]
name = "ruvector-attention"
version = "0.1.0"
edition = "2021"
rust-version = "1.77"
license = "MIT"
authors = ["Ruvector Team"]
repository = "https://github.com/ruvnet/ruvector"
description = "High-performance attention mechanisms for RuVector"
[dependencies]
# RuVector dependencies
ruvector-core = { version = "0.1", path = "../ruvector-core", optional = true }
ruvector-gnn = { version = "0.1", path = "../ruvector-gnn", optional = true }
ruvector-graph = { version = "0.1", path = "../ruvector-graph", optional = true }
# Math and numerics
ndarray = { workspace = true }
rand = { workspace = true }
rand_distr = { workspace = true }
# Performance
simsimd = { workspace = true, optional = true }
rayon = { workspace = true }
parking_lot = { workspace = true }
# Serialization
serde = { workspace = true }
serde_json = { workspace = true }
# Error handling
thiserror = { workspace = true }
anyhow = { workspace = true }
# Caching
lru = { version = "0.12", optional = true }
# Memory mapping
memmap2 = { workspace = true, optional = true }
# Platform-specific
napi = { workspace = true, optional = true }
napi-derive = { workspace = true, optional = true }
wasm-bindgen = { workspace = true, optional = true }
js-sys = { workspace = true, optional = true }
web-sys = { workspace = true, optional = true }
# CLI
clap = { workspace = true, optional = true }
[dev-dependencies]
criterion = { workspace = true }
proptest = { workspace = true }
tempfile = "3.10"
[features]
default = ["std", "simd"]
std = []
simd = ["simsimd"]
# Attention types
standard = []
multihead = ["standard"]
geometric = []
sparse = []
graph = ["ruvector-graph"]
all-attention = ["standard", "multihead", "geometric", "sparse", "graph"]
# Integration
core-integration = ["ruvector-core"]
gnn-integration = ["ruvector-gnn", "graph"]
graph-integration = ["ruvector-graph"]
# Training
training = []
# Compression
compression = []
# Memory
cache = ["lru"]
mmap = ["memmap2"]
# Platforms
wasm = ["wasm-bindgen", "js-sys", "web-sys"]
napi = ["dep:napi", "dep:napi-derive"]
cli = ["clap"]
# Full feature set
full = [
"all-attention",
"core-integration",
"gnn-integration",
"graph-integration",
"training",
"compression",
"cache",
"simd",
]
[[bench]]
name = "attention_bench"
harness = false
[lib]
# cdylib is required for the wasm and napi targets; rlib for Rust consumers
crate-type = ["cdylib", "rlib"]
bench = false
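A downstream crate would opt into just the features it needs; for example (feature names taken from the list above, version hypothetical):

```toml
[dependencies]
ruvector-attention = { version = "0.1", features = ["multihead", "cache", "cli"] }
```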
10. Testing Architecture
10.1 Test Strategy
// File: tests/attention_tests.rs
// Integration tests in tests/ compile as a separate crate, so import the
// library by name rather than via `super`.
mod tests {
    use ruvector_attention::*;
    use proptest::prelude::*;
/// Unit test: basic attention forward pass
#[test]
fn test_standard_attention_forward() {
let config = AttentionConfig {
embed_dim: 128,
num_heads: None,
dropout_rate: None,
attention_type: AttentionType::Standard,
};
let attention = StandardAttention::new(config);
let query = vec![0.5; 128];
let keys = vec![vec![0.3; 128]; 10];
let values = vec![vec![0.7; 128]; 10];
let result = attention.forward(&query, &keys, &values);
assert_eq!(result.len(), 128);
assert!(result.iter().all(|&x| x.is_finite()));
}
/// Property test: attention output dimension
proptest! {
#[test]
fn test_attention_output_dimension(
query_dim in 16usize..512,
value_dim in 16usize..512,
num_keys in 1usize..100,
) {
let query = vec![0.5; query_dim];
let keys = vec![vec![0.3; query_dim]; num_keys];
let values = vec![vec![0.7; value_dim]; num_keys];
let attention = create_standard_attention(query_dim);
let result = attention.forward(&query, &keys, &values);
prop_assert_eq!(result.len(), value_dim);
}
}
/// Property test: attention weights sum to 1
proptest! {
#[test]
fn test_attention_weights_normalize(
query_dim in 16usize..128,
num_keys in 1usize..50,
) {
let query = vec![0.5; query_dim];
let keys = vec![vec![0.3; query_dim]; num_keys];
let values = vec![vec![0.7; query_dim]; num_keys];
let attention = create_standard_attention(query_dim);
attention.forward(&query, &keys, &values);
if let Some(weights) = attention.get_attention_weights() {
for row in weights {
let sum: f32 = row.iter().sum();
prop_assert!((sum - 1.0).abs() < 1e-5);
}
}
}
}
/// Integration test: graph attention with edge index
#[test]
fn test_graph_attention_with_edges() {
let node_features = vec![
vec![0.1; 64],
vec![0.2; 64],
vec![0.3; 64],
];
let edge_index = vec![
(0, 1),
(1, 2),
(2, 0),
];
let attention = create_graph_attention(64);
let result = attention.forward_with_edges(
&node_features[0],
&node_features,
&node_features,
&edge_index,
None,
);
assert_eq!(result.len(), 64);
assert!(result.iter().all(|&x| x.is_finite()));
}
/// Benchmark comparison test
#[test]
fn test_simd_vs_scalar_performance() {
use std::time::Instant;
let size = 1000;
let a = vec![0.5; size];
let b = vec![0.3; size];
// Scalar version
let start = Instant::now();
let scalar_result = scalar_dot_product(&a, &b);
let scalar_time = start.elapsed();
// SIMD version
let start = Instant::now();
let simd_result = simd_dot_product(&a, &b);
let simd_time = start.elapsed();
// Results should be approximately equal
assert!((scalar_result - simd_result).abs() < 1e-3);
// SIMD should be faster (relaxed for CI)
println!("Scalar: {:?}, SIMD: {:?}", scalar_time, simd_time);
}
}
10.2 Correctness Tests
/// Test numerical stability with extreme (but finite) values
#[test]
fn test_numerical_stability() {
    // Logits around 6.4e9 would overflow a naive softmax's exp()
    let query = vec![1.0e4; 64];
    let keys = vec![vec![1.0e4; 64]; 10];
    let values = vec![vec![0.5; 64]; 10];
let attention = create_standard_attention(64);
let result = attention.forward(&query, &keys, &values);
// Should not produce NaN or Inf
assert!(result.iter().all(|&x| x.is_finite()));
}
/// Test gradient computation correctness
#[test]
fn test_gradient_computation() {
let mut attention = create_trainable_attention(64);
let query = vec![0.5; 64];
let keys = vec![vec![0.3; 64]; 10];
let values = vec![vec![0.7; 64]; 10];
// Forward pass
let output = attention.forward(&query, &keys, &values);
// Backward pass
let grad_output = vec![1.0; 64];
let (grad_query, grad_keys, grad_values, _) = attention.backward(&grad_output);
// Gradients should be finite
assert!(grad_query.iter().all(|&x| x.is_finite()));
assert!(grad_keys.iter().all(|row| row.iter().all(|&x| x.is_finite())));
assert!(grad_values.iter().all(|row| row.iter().all(|&x| x.is_finite())));
// Numerical gradient check (finite differences)
let eps = 1e-4;
// ... implementation
}
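The elided finite-difference check typically compares analytic gradients against central differences, `(f(x + eps*e_i) - f(x - eps*e_i)) / (2*eps)`. A generic sketch on a toy scalar function (not tied to the attention API; `numerical_gradient` is a hypothetical helper):

```rust
/// Central-difference numerical gradient of a scalar function f: R^n -> R.
fn numerical_gradient<F: Fn(&[f32]) -> f32>(f: F, x: &[f32], eps: f32) -> Vec<f32> {
    (0..x.len())
        .map(|i| {
            // Perturb coordinate i in both directions
            let mut plus = x.to_vec();
            let mut minus = x.to_vec();
            plus[i] += eps;
            minus[i] -= eps;
            (f(&plus) - f(&minus)) / (2.0 * eps)
        })
        .collect()
}

fn main() {
    // f(x) = sum x_i^2 has gradient 2x.
    let f = |x: &[f32]| x.iter().map(|v| v * v).sum::<f32>();
    let grad = numerical_gradient(f, &[1.0, 2.0], 1e-3);
    assert!((grad[0] - 2.0).abs() < 1e-2);
    assert!((grad[1] - 4.0).abs() < 1e-2);
    println!("gradient check passed");
}
```

In the attention test above, the same comparison would be run per parameter against `grad_query`, `grad_keys`, and `grad_values`.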
11. Documentation Architecture
11.1 API Documentation
//! # RuVector Attention
//!
//! High-performance attention mechanisms for the RuVector ecosystem.
//!
//! ## Features
//!
//! - **Multiple Attention Types**: Standard, multi-head, geometric, sparse, graph
//! - **SIMD Optimization**: Hardware-accelerated computation
//! - **Platform Agnostic**: Rust, WASM, Node.js bindings
//! - **Training Support**: Gradient computation and optimization
//! - **Memory Efficient**: Caching and pooling
//!
//! ## Quick Start
//!
//! ```rust
//! use ruvector_attention::{StandardAttention, AttentionConfig, Attention};
//!
//! let config = AttentionConfig {
//! embed_dim: 128,
//! num_heads: None,
//! dropout_rate: None,
//! attention_type: AttentionType::Standard,
//! };
//!
//! let attention = StandardAttention::new(config);
//!
//! let query = vec![0.5; 128];
//! let keys = vec![vec![0.3; 128]; 10];
//! let values = vec![vec![0.7; 128]; 10];
//!
//! let output = attention.forward(&query, &keys, &values);
//! ```
//!
//! ## Architecture
//!
//! See [architecture documentation](../docs/latent-space/implementation-plans/02-architecture.md)
//! for complete system design.
11.2 Example Documentation
Create comprehensive examples in examples/:
- `basic_attention.rs` - Simple attention usage
- `multihead_attention.rs` - Multi-head setup
- `graph_attention.rs` - Graph neural network integration
- `geometric_attention.rs` - Hyperbolic/spherical spaces
- `training_loop.rs` - End-to-end training
- `wasm_usage.rs` - WASM binding example
- `node_usage.js` - Node.js binding example
12. Migration Path
12.1 Phase 1: Core Implementation
- Implement base traits (`Attention`, `AttentionConfig`)
- Create standard attention mechanism
- Add SIMD kernels (matmul, softmax)
- Implement basic caching
12.2 Phase 2: Advanced Attention
- Multi-head attention
- Geometric attention (hyperbolic space)
- Sparse attention patterns
- Graph attention
12.3 Phase 3: Platform Bindings
- WASM bindings
- NAPI-RS bindings
- CLI tool
12.4 Phase 4: Integration
- RuVector Core integration
- RuVector GNN integration
- RuVector Graph integration
12.5 Phase 5: Optimization
- Performance tuning
- Memory optimization
- Benchmark suite
- Documentation
13. Success Criteria
13.1 Functional Requirements
- ✅ All attention types implemented
- ✅ Platform bindings functional
- ✅ Integration tests passing
- ✅ Examples working
13.2 Performance Requirements
- ✅ SIMD speedup >2x vs scalar
- ✅ Cache hit rate >60%
- ✅ Memory overhead <20%
- ✅ Throughput >100 ops/sec (standard benchmark)
13.3 Quality Requirements
- ✅ Code coverage >80%
- ✅ Documentation coverage 100%
- ✅ All `unsafe` blocks audited for soundness
- ✅ All clippy warnings resolved
Appendix A: Related Work
- Transformers: Vaswani et al., "Attention Is All You Need" (2017)
- Graph Attention Networks: Veličković et al., "Graph Attention Networks" (2018)
- Geometric Deep Learning: Bronstein et al., "Geometric Deep Learning" (2021)
- Sparse Transformers: Child et al., "Generating Long Sequences with Sparse Transformers" (2019)
Appendix B: References
- RuVector Core: `/crates/ruvector-core`
- RuVector GNN: `/crates/ruvector-gnn`
- RuVector Graph: `/crates/ruvector-graph`
- SIMD Library: `simsimd`
- NAPI-RS: napi-rs.dev
Document Status: ✅ Complete
Next Phase: 03-refinement.md (TDD Implementation)
Review: Architecture review scheduled for implementation phase