Files
wifi-densepose/docs/research/latent-space/implementation-plans/agents/04-graph-attention.md
ruv d803bfe2b1 Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector
git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
2026-02-28 14:39:40 -05:00

39 KiB
Raw Blame History

Agent 4: Graph Attention Implementations

Overview

This agent implements graph-aware attention mechanisms that leverage both structural and latent space information from HNSW indices. The implementations bridge traditional GNN attention with modern transformer-style mechanisms adapted for graph structures.

Architecture Components

1. EdgeFeaturedAttention

2. GraphRoPE (Rotary Position Embeddings for Graphs)

3. DualSpaceAttention (Cross-Attention between Graph and Latent Space)


1. EdgeFeaturedAttention

Integrates edge features directly into attention computation, extending GAT to handle rich edge information stored in HNSW connections.

Design Principles

  • Edge-Aware Scoring: Attention coefficients incorporate edge features
  • LeakyReLU Activation: Standard GAT activation for gradient flow
  • HNSW Integration: Leverages edge weights and multi-level connections

Implementation

use ndarray::{Array1, Array2, ArrayView1, ArrayView2, s};
use std::collections::HashMap;

/// Edge-Featured Attention Mechanism
/// Computes attention over graph neighbors with edge feature integration
pub struct EdgeFeaturedAttention {
    /// Query projection weights [hidden_dim, hidden_dim]
    w_query: Array2<f32>,

    /// Key projection weights [hidden_dim, hidden_dim]
    w_key: Array2<f32>,

    /// Value projection weights [hidden_dim, hidden_dim]
    w_value: Array2<f32>,

    /// Edge feature projection [edge_dim, hidden_dim]
    w_edge: Array2<f32>,

    /// Attention scoring vector [2 * hidden_dim + hidden_dim]
    /// Concatenates [query || key || edge_features]
    a: Array1<f32>,

    /// LeakyReLU negative slope
    negative_slope: f32,

    /// Hidden dimension size
    hidden_dim: usize,

    /// Edge feature dimension
    edge_dim: usize,
}

impl EdgeFeaturedAttention {
    /// Create new EdgeFeaturedAttention layer
    pub fn new(hidden_dim: usize, edge_dim: usize, negative_slope: f32) -> Self {
        // Initialize with Xavier/Glorot uniform initialization
        let bound_h = (6.0 / (hidden_dim as f32 * 2.0)).sqrt();
        let bound_e = (6.0 / (edge_dim as f32 + hidden_dim as f32)).sqrt();
        let bound_a = (6.0 / (3.0 * hidden_dim as f32)).sqrt();

        Self {
            w_query: Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
                rand::random::<f32>() * 2.0 * bound_h - bound_h
            }),
            w_key: Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
                rand::random::<f32>() * 2.0 * bound_h - bound_h
            }),
            w_value: Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
                rand::random::<f32>() * 2.0 * bound_h - bound_h
            }),
            w_edge: Array2::from_shape_fn((edge_dim, hidden_dim), |_| {
                rand::random::<f32>() * 2.0 * bound_e - bound_e
            }),
            a: Array1::from_shape_fn(3 * hidden_dim, |_| {
                rand::random::<f32>() * 2.0 * bound_a - bound_a
            }),
            negative_slope,
            hidden_dim,
            edge_dim,
        }
    }

    /// Apply LeakyReLU activation
    #[inline]
    fn leaky_relu(&self, x: f32) -> f32 {
        if x >= 0.0 {
            x
        } else {
            self.negative_slope * x
        }
    }

    /// Compute attention coefficients for a single node
    ///
    /// # HNSW Integration Points:
    /// - `neighbor_ids`: Retrieved from HNSW.get_neighbors(node_id, layer)
    /// - `edge_features`: Stored in HNSW edge metadata or computed from distance
    /// - `layer`: HNSW layer level affects neighbor selection
    fn compute_attention_scores(
        &self,
        query_node: ArrayView1<f32>,           // [hidden_dim]
        neighbor_features: ArrayView2<f32>,   // [num_neighbors, hidden_dim]
        edge_features: ArrayView2<f32>,       // [num_neighbors, edge_dim]
    ) -> Array1<f32> {
        let num_neighbors = neighbor_features.nrows();

        // Project query, keys, and values
        let query = query_node.dot(&self.w_query);  // [hidden_dim]
        let keys = neighbor_features.dot(&self.w_key.t());  // [num_neighbors, hidden_dim]
        let edges_proj = edge_features.dot(&self.w_edge.t());  // [num_neighbors, hidden_dim]

        // Compute attention logits for each neighbor
        let mut logits = Array1::zeros(num_neighbors);

        for i in 0..num_neighbors {
            // Concatenate [query || key_i || edge_i]
            let mut concat = Array1::zeros(3 * self.hidden_dim);
            concat.slice_mut(s![0..self.hidden_dim]).assign(&query);
            concat.slice_mut(s![self.hidden_dim..2*self.hidden_dim])
                .assign(&keys.row(i));
            concat.slice_mut(s![2*self.hidden_dim..])
                .assign(&edges_proj.row(i));

            // Compute attention score: a^T * LeakyReLU(concat)
            let score: f32 = concat.iter()
                .zip(self.a.iter())
                .map(|(x, a)| self.leaky_relu(*x) * a)
                .sum();

            logits[i] = score;
        }

        // Apply softmax
        let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exp_logits: Array1<f32> = logits.mapv(|x| (x - max_logit).exp());
        let sum_exp = exp_logits.sum();

        exp_logits / sum_exp
    }

    /// Forward pass: compute attended features for all nodes
    ///
    /// # HNSW Integration:
    /// ```rust
    /// // Pseudo-code for HNSW integration:
    /// for node_id in graph.nodes() {
    ///     // Get neighbors from HNSW at specific layer
    ///     let neighbors = hnsw.get_neighbors(node_id, layer);
    ///
    ///     // Extract edge features from HNSW metadata
    ///     let edge_feats = neighbors.iter().map(|&n| {
    ///         hnsw.get_edge_features(node_id, n, layer)
    ///     }).collect();
    ///
    ///     // Compute attention
    ///     let attended = self.forward_single(
    ///         node_features.row(node_id),
    ///         neighbor_features,
    ///         edge_feats
    ///     );
    /// }
    /// ```
    pub fn forward(
        &self,
        node_features: ArrayView2<f32>,       // [num_nodes, hidden_dim]
        adjacency: &HashMap<usize, Vec<usize>>,  // node_id -> neighbor_ids
        edge_features_map: &HashMap<(usize, usize), Array1<f32>>,  // (src, dst) -> edge_feat
    ) -> Array2<f32> {
        let num_nodes = node_features.nrows();
        let mut output = Array2::zeros((num_nodes, self.hidden_dim));

        for node_id in 0..num_nodes {
            if let Some(neighbors) = adjacency.get(&node_id) {
                if neighbors.is_empty() {
                    // No neighbors, apply self-loop
                    output.row_mut(node_id).assign(&node_features.row(node_id));
                    continue;
                }

                // Gather neighbor features
                let num_neighbors = neighbors.len();
                let mut neighbor_feats = Array2::zeros((num_neighbors, self.hidden_dim));
                let mut edge_feats = Array2::zeros((num_neighbors, self.edge_dim));

                for (i, &neighbor_id) in neighbors.iter().enumerate() {
                    neighbor_feats.row_mut(i).assign(&node_features.row(neighbor_id));

                    if let Some(edge_feat) = edge_features_map.get(&(node_id, neighbor_id)) {
                        edge_feats.row_mut(i).assign(edge_feat);
                    }
                }

                // Compute attention scores
                let attention_weights = self.compute_attention_scores(
                    node_features.row(node_id),
                    neighbor_feats.view(),
                    edge_feats.view(),
                );

                // Project neighbor features to values
                let values = neighbor_feats.dot(&self.w_value.t());  // [num_neighbors, hidden_dim]

                // Weighted sum of values
                let mut attended = Array1::zeros(self.hidden_dim);
                for i in 0..num_neighbors {
                    attended = attended + &(values.row(i).to_owned() * attention_weights[i]);
                }

                output.row_mut(node_id).assign(&attended);
            } else {
                // Isolated node
                output.row_mut(node_id).assign(&node_features.row(node_id));
            }
        }

        output
    }

    /// Forward pass for a single node (useful for online inference)
    pub fn forward_single(
        &self,
        query_node: ArrayView1<f32>,
        neighbor_features: ArrayView2<f32>,
        edge_features: ArrayView2<f32>,
    ) -> Array1<f32> {
        let attention_weights = self.compute_attention_scores(
            query_node,
            neighbor_features,
            edge_features,
        );

        let values = neighbor_features.dot(&self.w_value.t());

        let mut attended = Array1::zeros(self.hidden_dim);
        for i in 0..values.nrows() {
            attended = attended + &(values.row(i).to_owned() * attention_weights[i]);
        }

        attended
    }
}

// HNSW Integration Helper
pub struct HNSWEdgeFeatureExtractor {
    /// Extract edge features from HNSW metadata
    /// Features can include:
    /// - Edge weight (inverse of distance)
    /// - Layer level
    /// - Neighbor degree
    /// - Edge directionality
}

impl HNSWEdgeFeatureExtractor {
    pub fn extract_features(
        &self,
        src_id: usize,
        dst_id: usize,
        distance: f32,
        layer: usize,
        dst_degree: usize,
    ) -> Array1<f32> {
        // Example edge features [edge_dim = 4]
        Array1::from_vec(vec![
            1.0 / (distance + 1e-6),  // Edge weight (inverse distance)
            layer as f32,              // HNSW layer
            (dst_degree as f32).ln(),  // Log degree of neighbor
            1.0,                       // Bias term
        ])
    }
}

Usage Example

// Initialize attention layer
let attention = EdgeFeaturedAttention::new(
    128,    // hidden_dim
    4,      // edge_dim
    0.2,    // negative_slope for LeakyReLU
);

// HNSW Integration Example:
// 1. Query HNSW for neighbors
// let neighbors = hnsw.get_neighbors(node_id, layer);
//
// 2. Extract edge features
// let edge_extractor = HNSWEdgeFeatureExtractor::new();
// for &neighbor in neighbors {
//     let distance = hnsw.distance(node_id, neighbor);
//     let edge_feat = edge_extractor.extract_features(
//         node_id, neighbor, distance, layer, hnsw.degree(neighbor)
//     );
// }
//
// 3. Apply attention
// let output = attention.forward(node_features, adjacency, edge_features);

2. GraphRoPE (Rotary Position Embeddings for Graphs)

Adapts RoPE from transformers to encode structural positions in graphs using HNSW distances and layer information.

Design Principles

  • Distance-Based Rotation: Rotation angles based on graph distance
  • Layer-Aware Encoding: Different frequencies per HNSW layer
  • Relative Positioning: Encodes relative structural positions

Implementation

use ndarray::{Array1, Array2, ArrayView1, s};
use std::f32::consts::PI;

/// Graph Rotary Position Embeddings
/// Encodes graph structural positions via rotation in embedding space
pub struct GraphRoPE {
    /// Dimension of embeddings
    dim: usize,

    /// Base frequency for rotations
    base: f32,

    /// Maximum distance to encode
    max_distance: f32,

    /// Number of HNSW layers to support
    num_layers: usize,

    /// Precomputed frequency bands [dim/2]
    inv_freq: Array1<f32>,
}

impl GraphRoPE {
    /// Create new GraphRoPE encoder
    ///
    /// # Arguments
    /// * `dim` - Embedding dimension (must be even)
    /// * `base` - Base frequency (default: 10000.0 like in transformers)
    /// * `max_distance` - Maximum graph distance to encode
    /// * `num_layers` - Number of HNSW layers
    pub fn new(dim: usize, base: f32, max_distance: f32, num_layers: usize) -> Self {
        assert!(dim % 2 == 0, "Dimension must be even for RoPE");

        // Compute inverse frequencies: θ_i = base^(-2i/d) for i in [0, d/2)
        let inv_freq = Array1::from_shape_fn(dim / 2, |i| {
            1.0 / base.powf(2.0 * i as f32 / dim as f32)
        });

        Self {
            dim,
            base,
            max_distance,
            num_layers,
            inv_freq,
        }
    }

    /// Compute rotation matrix for a given distance and layer
    ///
    /// # HNSW Integration:
    /// - `distance`: Graph distance or HNSW distance metric
    /// - `layer`: HNSW layer level (0 = bottom, higher = more abstract)
    ///
    /// Returns: Rotation matrix [dim, dim] (block diagonal with 2x2 rotation blocks)
    fn compute_rotation_matrix(&self, distance: f32, layer: usize) -> Array2<f32> {
        // Layer-dependent frequency scaling
        // Higher layers = lower frequencies = coarser position encoding
        let layer_scale = 1.0 / (1.0 + layer as f32);

        // Normalize distance
        let normalized_dist = (distance / self.max_distance).min(1.0);

        let mut rotation = Array2::eye(self.dim);

        // Create 2D rotation blocks
        for i in 0..self.dim / 2 {
            let freq = self.inv_freq[i] * layer_scale;
            let theta = normalized_dist * freq;

            let cos_theta = theta.cos();
            let sin_theta = theta.sin();

            // 2x2 rotation block
            let idx1 = 2 * i;
            let idx2 = 2 * i + 1;

            rotation[[idx1, idx1]] = cos_theta;
            rotation[[idx1, idx2]] = -sin_theta;
            rotation[[idx2, idx1]] = sin_theta;
            rotation[[idx2, idx2]] = cos_theta;
        }

        rotation
    }

    /// Apply rotary position encoding to embeddings
    ///
    /// # Arguments
    /// * `embeddings` - Input embeddings [batch_size, dim]
    /// * `distances` - Graph distances for each embedding [batch_size]
    /// * `layers` - HNSW layer for each embedding [batch_size]
    ///
    /// # HNSW Integration Example:
    /// ```rust
    /// // For each node, compute distance from query node
    /// let distances = nodes.iter().map(|&node_id| {
    ///     hnsw.shortest_path_distance(query_id, node_id)
    /// }).collect();
    ///
    /// // Use HNSW layer information
    /// let layers = nodes.iter().map(|&node_id| {
    ///     hnsw.get_node_layer(node_id)
    /// }).collect();
    ///
    /// let rotated = rope.apply_rotation(embeddings, &distances, &layers);
    /// ```
    pub fn apply_rotation(
        &self,
        embeddings: ArrayView2<f32>,  // [batch_size, dim]
        distances: &[f32],             // [batch_size]
        layers: &[usize],              // [batch_size]
    ) -> Array2<f32> {
        let batch_size = embeddings.nrows();
        assert_eq!(batch_size, distances.len());
        assert_eq!(batch_size, layers.len());

        let mut output = Array2::zeros((batch_size, self.dim));

        for i in 0..batch_size {
            let rotation = self.compute_rotation_matrix(distances[i], layers[i]);
            let rotated = rotation.dot(&embeddings.row(i));
            output.row_mut(i).assign(&rotated);
        }

        output
    }

    /// Apply rotation to a single embedding
    pub fn apply_rotation_single(
        &self,
        embedding: ArrayView1<f32>,
        distance: f32,
        layer: usize,
    ) -> Array1<f32> {
        let rotation = self.compute_rotation_matrix(distance, layer);
        rotation.dot(&embedding)
    }

    /// Compute relative rotary embeddings between two nodes
    /// This encodes the relative position in graph space
    ///
    /// # HNSW Integration:
    /// ```rust
    /// let query_emb = node_embeddings.row(query_id);
    /// let key_emb = node_embeddings.row(key_id);
    ///
    /// // Get HNSW distance
    /// let distance = hnsw.distance(query_id, key_id);
    /// let layer = hnsw.get_common_layer(query_id, key_id);
    ///
    /// let (rotated_q, rotated_k) = rope.apply_relative_rotation(
    ///     query_emb, key_emb, distance, layer
    /// );
    ///
    /// // Compute attention with relative position encoding
    /// let score = rotated_q.dot(&rotated_k);
    /// ```
    pub fn apply_relative_rotation(
        &self,
        query_emb: ArrayView1<f32>,
        key_emb: ArrayView1<f32>,
        distance: f32,
        layer: usize,
    ) -> (Array1<f32>, Array1<f32>) {
        let rotation = self.compute_rotation_matrix(distance, layer);

        // Apply rotation to both query and key
        let rotated_query = rotation.dot(&query_emb);
        let rotated_key = rotation.dot(&key_emb);

        (rotated_query, rotated_key)
    }

    /// Encode distances as sinusoidal features (alternative to rotation)
    /// Useful for edge features or distance embeddings
    pub fn encode_distance(&self, distance: f32, layer: usize) -> Array1<f32> {
        let layer_scale = 1.0 / (1.0 + layer as f32);
        let normalized_dist = (distance / self.max_distance).min(1.0);

        let mut encoding = Array1::zeros(self.dim);

        for i in 0..self.dim / 2 {
            let freq = self.inv_freq[i] * layer_scale;
            let angle = normalized_dist * freq;

            encoding[2 * i] = angle.sin();
            encoding[2 * i + 1] = angle.cos();
        }

        encoding
    }
}

/// HNSW-Aware Distance Computer
pub struct HNSWDistanceComputer {
    /// Compute graph distances using HNSW structure
}

impl HNSWDistanceComputer {
    /// Compute shortest path distance using HNSW layers
    /// Higher layers provide shortcuts for faster distance computation
    pub fn shortest_path_distance(
        &self,
        hnsw: &dyn HNSWInterface,
        source: usize,
        target: usize,
    ) -> f32 {
        // Start from highest layer for efficiency
        let max_layer = hnsw.get_max_level();

        // BFS with layer-aware traversal
        // Implementation would use HNSW's hierarchical structure
        // to compute distances efficiently

        // Placeholder implementation
        hnsw.distance(source, target)
    }

    /// Get the highest common layer between two nodes
    pub fn get_common_layer(
        &self,
        hnsw: &dyn HNSWInterface,
        node1: usize,
        node2: usize,
    ) -> usize {
        let layer1 = hnsw.get_node_layer(node1);
        let layer2 = hnsw.get_node_layer(node2);
        layer1.min(layer2)
    }
}

// Trait for HNSW interface (abstraction for integration)
pub trait HNSWInterface {
    fn distance(&self, id1: usize, id2: usize) -> f32;
    fn get_max_level(&self) -> usize;
    fn get_node_layer(&self, id: usize) -> usize;
    fn get_neighbors(&self, id: usize, layer: usize) -> Vec<usize>;
}

Usage Example

// Initialize GraphRoPE
let rope = GraphRoPE::new(
    128,      // dim
    10000.0,  // base (same as transformer RoPE)
    20.0,     // max_distance (graph hops)
    8,        // num_layers (HNSW layers)
);

// HNSW Integration:
// 1. Compute distances from query node
// let distance_computer = HNSWDistanceComputer::new();
// let distances: Vec<f32> = nodes.iter().map(|&node_id| {
//     distance_computer.shortest_path_distance(&hnsw, query_id, node_id)
// }).collect();
//
// 2. Get layer information
// let layers: Vec<usize> = nodes.iter().map(|&node_id| {
//     hnsw.get_node_layer(node_id)
// }).collect();
//
// 3. Apply rotary embeddings
// let rotated_embeddings = rope.apply_rotation(
//     embeddings.view(),
//     &distances,
//     &layers,
// );

// For attention computation with relative positions:
// let (rotated_q, rotated_k) = rope.apply_relative_rotation(
//     query_embedding,
//     key_embedding,
//     distance,
//     layer,
// );
// let attention_score = rotated_q.dot(&rotated_k) / (dim as f32).sqrt();

3. DualSpaceAttention (Cross-Attention)

Performs cross-attention between graph-space neighbors (from original graph structure) and latent-space neighbors (from HNSW index), fusing both structural and semantic information.

Design Principles

  • Dual Neighbor Sets: Graph neighbors (structure) + Latent neighbors (semantics)
  • Cross-Attention Fusion: Attend across both spaces simultaneously
  • HNSW Latent Search: Leverage HNSW for efficient semantic neighbor retrieval

Implementation

use ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis, s};
use std::collections::{HashMap, HashSet};

/// Dual-Space Cross-Attention
/// Attends to both graph-structure neighbors and latent-space neighbors
pub struct DualSpaceAttention {
    /// Graph-space attention head
    graph_attention: GraphAttentionHead,

    /// Latent-space attention head
    latent_attention: LatentAttentionHead,

    /// Cross-attention fusion layer
    fusion: CrossAttentionFusion,

    /// Hidden dimension
    hidden_dim: usize,

    /// Number of latent neighbors to retrieve from HNSW
    k_latent: usize,
}

impl DualSpaceAttention {
    pub fn new(hidden_dim: usize, k_latent: usize) -> Self {
        Self {
            graph_attention: GraphAttentionHead::new(hidden_dim),
            latent_attention: LatentAttentionHead::new(hidden_dim),
            fusion: CrossAttentionFusion::new(hidden_dim),
            hidden_dim,
            k_latent,
        }
    }

    /// Forward pass with dual-space attention
    ///
    /// # HNSW Integration:
    /// ```rust
    /// // 1. Graph neighbors (from original graph structure)
    /// let graph_neighbors = graph.get_neighbors(node_id);
    ///
    /// // 2. Latent neighbors (from HNSW semantic search)
    /// let latent_neighbors = hnsw.search(
    ///     node_embeddings.row(node_id),
    ///     k_latent,
    ///     layer
    /// );
    ///
    /// // 3. Apply dual-space attention
    /// let output = dual_attention.forward(
    ///     node_id,
    ///     &node_embeddings,
    ///     &graph_neighbors,
    ///     &latent_neighbors,
    /// );
    /// ```
    pub fn forward(
        &self,
        node_id: usize,
        node_embeddings: ArrayView2<f32>,        // [num_nodes, hidden_dim]
        graph_neighbors: &[usize],                // From graph structure
        latent_neighbors: &[(usize, f32)],        // From HNSW (id, distance)
    ) -> Array1<f32> {
        // Query embedding
        let query = node_embeddings.row(node_id);

        // 1. Graph-space attention
        let graph_context = self.graph_attention.attend(
            query,
            graph_neighbors,
            node_embeddings,
        );

        // 2. Latent-space attention
        let latent_context = self.latent_attention.attend(
            query,
            latent_neighbors,
            node_embeddings,
        );

        // 3. Cross-attention fusion
        let fused = self.fusion.fuse(
            query,
            graph_context.view(),
            latent_context.view(),
        );

        fused
    }

    /// Find latent neighbors using HNSW
    ///
    /// # HNSW Integration:
    /// This is a critical integration point where we query HNSW index
    /// to find semantically similar nodes in the latent space
    pub fn find_latent_neighbors(
        &self,
        hnsw: &dyn HNSWInterface,
        query_embedding: ArrayView1<f32>,
        k: usize,
        layer: usize,
    ) -> Vec<(usize, f32)> {
        // HNSW search for k nearest neighbors in latent space
        // Returns: Vec<(node_id, distance)>

        // Pseudo-code for HNSW integration:
        // let results = hnsw.search_layer(
        //     query_embedding,
        //     k,
        //     layer,
        //     ef_search=50  // Search parameter
        // );

        // Placeholder implementation
        vec![]
    }

    /// Batch forward pass for multiple nodes
    pub fn forward_batch(
        &self,
        node_embeddings: ArrayView2<f32>,
        graph_adjacency: &HashMap<usize, Vec<usize>>,
        hnsw: &dyn HNSWInterface,
        layer: usize,
    ) -> Array2<f32> {
        let num_nodes = node_embeddings.nrows();
        let mut output = Array2::zeros((num_nodes, self.hidden_dim));

        for node_id in 0..num_nodes {
            // Get graph neighbors
            let graph_neighbors = graph_adjacency
                .get(&node_id)
                .map(|v| v.as_slice())
                .unwrap_or(&[]);

            // Get latent neighbors from HNSW
            let latent_neighbors = self.find_latent_neighbors(
                hnsw,
                node_embeddings.row(node_id),
                self.k_latent,
                layer,
            );

            // Apply dual-space attention
            let node_output = self.forward(
                node_id,
                node_embeddings,
                graph_neighbors,
                &latent_neighbors,
            );

            output.row_mut(node_id).assign(&node_output);
        }

        output
    }
}

/// Graph-space attention head (attends to structural neighbors)
struct GraphAttentionHead {
    w_query: Array2<f32>,
    w_key: Array2<f32>,
    w_value: Array2<f32>,
    hidden_dim: usize,
}

impl GraphAttentionHead {
    fn new(hidden_dim: usize) -> Self {
        let bound = (6.0 / (hidden_dim as f32 * 2.0)).sqrt();

        Self {
            w_query: Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
                rand::random::<f32>() * 2.0 * bound - bound
            }),
            w_key: Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
                rand::random::<f32>() * 2.0 * bound - bound
            }),
            w_value: Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
                rand::random::<f32>() * 2.0 * bound - bound
            }),
            hidden_dim,
        }
    }

    fn attend(
        &self,
        query: ArrayView1<f32>,
        neighbor_ids: &[usize],
        node_embeddings: ArrayView2<f32>,
    ) -> Array1<f32> {
        if neighbor_ids.is_empty() {
            return query.to_owned();
        }

        // Project query
        let q = query.dot(&self.w_query);  // [hidden_dim]

        // Gather and project keys and values
        let num_neighbors = neighbor_ids.len();
        let mut keys = Array2::zeros((num_neighbors, self.hidden_dim));
        let mut values = Array2::zeros((num_neighbors, self.hidden_dim));

        for (i, &neighbor_id) in neighbor_ids.iter().enumerate() {
            let neighbor_emb = node_embeddings.row(neighbor_id);
            keys.row_mut(i).assign(&neighbor_emb.dot(&self.w_key));
            values.row_mut(i).assign(&neighbor_emb.dot(&self.w_value));
        }

        // Compute attention scores
        let scale = 1.0 / (self.hidden_dim as f32).sqrt();
        let mut scores = Array1::zeros(num_neighbors);
        for i in 0..num_neighbors {
            scores[i] = q.dot(&keys.row(i)) * scale;
        }

        // Softmax
        let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exp_scores: Array1<f32> = scores.mapv(|x| (x - max_score).exp());
        let sum_exp = exp_scores.sum();
        let attention_weights = exp_scores / sum_exp;

        // Weighted sum of values
        let mut output = Array1::zeros(self.hidden_dim);
        for i in 0..num_neighbors {
            output = output + &(values.row(i).to_owned() * attention_weights[i]);
        }

        output
    }
}

/// Latent-space attention head (attends to semantic neighbors from HNSW)
struct LatentAttentionHead {
    w_query: Array2<f32>,
    w_key: Array2<f32>,
    w_value: Array2<f32>,
    hidden_dim: usize,
}

impl LatentAttentionHead {
    fn new(hidden_dim: usize) -> Self {
        let bound = (6.0 / (hidden_dim as f32 * 2.0)).sqrt();

        Self {
            w_query: Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
                rand::random::<f32>() * 2.0 * bound - bound
            }),
            w_key: Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
                rand::random::<f32>() * 2.0 * bound - bound
            }),
            w_value: Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
                rand::random::<f32>() * 2.0 * bound - bound
            }),
            hidden_dim,
        }
    }

    fn attend(
        &self,
        query: ArrayView1<f32>,
        latent_neighbors: &[(usize, f32)],  // (neighbor_id, distance)
        node_embeddings: ArrayView2<f32>,
    ) -> Array1<f32> {
        if latent_neighbors.is_empty() {
            return query.to_owned();
        }

        // Project query
        let q = query.dot(&self.w_query);

        let num_neighbors = latent_neighbors.len();
        let mut keys = Array2::zeros((num_neighbors, self.hidden_dim));
        let mut values = Array2::zeros((num_neighbors, self.hidden_dim));

        // Distance-weighted attention bias
        let mut distance_weights = Array1::zeros(num_neighbors);

        for (i, &(neighbor_id, distance)) in latent_neighbors.iter().enumerate() {
            let neighbor_emb = node_embeddings.row(neighbor_id);
            keys.row_mut(i).assign(&neighbor_emb.dot(&self.w_key));
            values.row_mut(i).assign(&neighbor_emb.dot(&self.w_value));

            // Convert HNSW distance to attention bias
            // Closer neighbors (smaller distance) get positive bias
            distance_weights[i] = -distance;
        }

        // Compute attention scores with distance bias
        let scale = 1.0 / (self.hidden_dim as f32).sqrt();
        let mut scores = Array1::zeros(num_neighbors);
        for i in 0..num_neighbors {
            scores[i] = q.dot(&keys.row(i)) * scale + distance_weights[i];
        }

        // Softmax
        let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exp_scores: Array1<f32> = scores.mapv(|x| (x - max_score).exp());
        let sum_exp = exp_scores.sum();
        let attention_weights = exp_scores / sum_exp;

        // Weighted sum
        let mut output = Array1::zeros(self.hidden_dim);
        for i in 0..num_neighbors {
            output = output + &(values.row(i).to_owned() * attention_weights[i]);
        }

        output
    }
}

/// Cross-attention fusion layer
/// Fuses information from graph-space and latent-space contexts
struct CrossAttentionFusion {
    // Cross-attention: query=original, keys/values=graph_context
    w_graph_key: Array2<f32>,
    w_graph_value: Array2<f32>,

    // Cross-attention: query=original, keys/values=latent_context
    w_latent_key: Array2<f32>,
    w_latent_value: Array2<f32>,

    // Fusion weights
    w_fusion: Array2<f32>,

    hidden_dim: usize,
}

impl CrossAttentionFusion {
    fn new(hidden_dim: usize) -> Self {
        let bound = (6.0 / (hidden_dim as f32 * 2.0)).sqrt();

        Self {
            w_graph_key: Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
                rand::random::<f32>() * 2.0 * bound - bound
            }),
            w_graph_value: Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
                rand::random::<f32>() * 2.0 * bound - bound
            }),
            w_latent_key: Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
                rand::random::<f32>() * 2.0 * bound - bound
            }),
            w_latent_value: Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
                rand::random::<f32>() * 2.0 * bound - bound
            }),
            w_fusion: Array2::from_shape_fn((2 * hidden_dim, hidden_dim), |_| {
                rand::random::<f32>() * 2.0 * bound - bound
            }),
            hidden_dim,
        }
    }

    fn fuse(
        &self,
        query: ArrayView1<f32>,           // Original node embedding
        graph_context: ArrayView1<f32>,   // From graph-space attention
        latent_context: ArrayView1<f32>,  // From latent-space attention
    ) -> Array1<f32> {
        // Cross-attention with graph context
        let graph_key = graph_context.dot(&self.w_graph_key);
        let graph_value = graph_context.dot(&self.w_graph_value);
        let graph_score = query.dot(&graph_key) / (self.hidden_dim as f32).sqrt();
        let graph_attended = graph_value * graph_score.tanh();  // Gated by attention

        // Cross-attention with latent context
        let latent_key = latent_context.dot(&self.w_latent_key);
        let latent_value = latent_context.dot(&self.w_latent_value);
        let latent_score = query.dot(&latent_key) / (self.hidden_dim as f32).sqrt();
        let latent_attended = latent_value * latent_score.tanh();

        // Concatenate and fuse
        let mut concat = Array1::zeros(2 * self.hidden_dim);
        concat.slice_mut(s![0..self.hidden_dim]).assign(&graph_attended);
        concat.slice_mut(s![self.hidden_dim..]).assign(&latent_attended);

        let fused = concat.dot(&self.w_fusion);

        // Residual connection
        query.to_owned() + fused
    }
}

/// HNSW Search Integration Helper
pub struct HNSWLatentSearch {
    /// Search parameters
    ef_search: usize,
}

impl HNSWLatentSearch {
    pub fn new(ef_search: usize) -> Self {
        Self { ef_search }
    }

    /// Search HNSW for k nearest neighbors in latent space
    ///
    /// # Arguments
    /// * `hnsw` - HNSW index
    /// * `query_embedding` - Query vector
    /// * `k` - Number of neighbors to return
    /// * `layer` - HNSW layer to search (higher = more abstract)
    ///
    /// # Returns
    /// Vec<(node_id, distance)> sorted by distance (ascending)
    pub fn search(
        &self,
        hnsw: &dyn HNSWInterface,
        query_embedding: ArrayView1<f32>,
        k: usize,
        layer: usize,
    ) -> Vec<(usize, f32)> {
        // Pseudo-code for HNSW integration:
        //
        // 1. Start from entry point at given layer
        // let entry_point = hnsw.get_entry_point(layer);
        //
        // 2. Greedy search to find closest node
        // let mut current = entry_point;
        // loop {
        //     let neighbors = hnsw.get_neighbors(current, layer);
        //     let (closest, dist) = find_closest(neighbors, query_embedding);
        //     if dist >= current_dist { break; }
        //     current = closest;
        // }
        //
        // 3. Beam search for k neighbors
        // let mut candidates = PriorityQueue::new();
        // let mut results = PriorityQueue::new();
        // candidates.push(current, distance);
        //
        // while !candidates.is_empty() && results.len() < ef_search {
        //     let (node, dist) = candidates.pop();
        //     if dist > results.peek().dist { break; }
        //
        //     for neighbor in hnsw.get_neighbors(node, layer) {
        //         let neighbor_dist = distance(query_embedding, neighbor_embedding);
        //         if neighbor_dist < results.peek().dist {
        //             candidates.push(neighbor, neighbor_dist);
        //             results.push(neighbor, neighbor_dist);
        //             if results.len() > ef_search {
        //                 results.pop();
        //             }
        //         }
        //     }
        // }
        //
        // 4. Return top-k results
        // results.into_iter().take(k).collect()

        // Placeholder
        vec![]
    }
}

Usage Example

// Initialize dual-space attention
let dual_attention = DualSpaceAttention::new(
    128,  // hidden_dim
    16,   // k_latent neighbors
);

// HNSW Integration Example:
// 1. Get graph neighbors (from original graph)
// let graph_neighbors = graph.adjacency.get(&node_id).unwrap();
//
// 2. Search HNSW for latent neighbors
// let hnsw_search = HNSWLatentSearch::new(50);  // ef_search=50
// let latent_neighbors = hnsw_search.search(
//     &hnsw,
//     node_embeddings.row(node_id),
//     16,    // k
//     layer,
// );
//
// 3. Apply dual-space attention
// let output = dual_attention.forward(
//     node_id,
//     node_embeddings.view(),
//     graph_neighbors,
//     &latent_neighbors,
// );

// Batch processing:
// let output_embeddings = dual_attention.forward_batch(
//     node_embeddings.view(),
//     &graph_adjacency,
//     &hnsw,
//     layer,
// );

Integration Architecture

Complete Pipeline with HNSW

pub struct GraphAttentionPipeline {
    /// Edge-featured attention for local neighborhood
    edge_attention: EdgeFeaturedAttention,

    /// Graph RoPE for positional encoding
    rope: GraphRoPE,

    /// Dual-space attention for graph-latent fusion
    dual_attention: DualSpaceAttention,

    /// HNSW index for latent space search
    hnsw: Box<dyn HNSWInterface>,
}

impl GraphAttentionPipeline {
    pub fn forward(
        &mut self,
        node_features: ArrayView2<f32>,
        graph_adjacency: &HashMap<usize, Vec<usize>>,
        edge_features: &HashMap<(usize, usize), Array1<f32>>,
        layer: usize,
    ) -> Array2<f32> {
        let num_nodes = node_features.nrows();

        // 1. Apply EdgeFeaturedAttention for local context
        let local_context = self.edge_attention.forward(
            node_features,
            graph_adjacency,
            edge_features,
        );

        // 2. Compute distances for RoPE
        let query_id = 0;  // Example: use node 0 as reference
        let distances: Vec<f32> = (0..num_nodes)
            .map(|node_id| {
                self.hnsw.distance(query_id, node_id)
            })
            .collect();

        let layers: Vec<usize> = (0..num_nodes)
            .map(|node_id| self.hnsw.get_node_layer(node_id))
            .collect();

        // 3. Apply GraphRoPE positional encoding
        let positioned = self.rope.apply_rotation(
            local_context.view(),
            &distances,
            &layers,
        );

        // 4. Apply DualSpaceAttention for graph-latent fusion
        let output = self.dual_attention.forward_batch(
            positioned.view(),
            graph_adjacency,
            self.hnsw.as_ref(),
            layer,
        );

        output
    }
}

Performance Considerations

Memory Efficiency

  • Sparse Attention: Only attend to k-nearest neighbors
  • Layer-wise Processing: Process HNSW layers incrementally
  • Batch Operations: Vectorize attention computation

Computational Complexity

  • EdgeFeaturedAttention: O(|E| × d²) where |E| is number of edges
  • GraphRoPE: O(n × d²) where n is number of nodes
  • DualSpaceAttention: O(n × (k_graph + k_latent) × d²)

HNSW Integration Benefits

  1. Fast Neighbor Search: O(log n) latent neighbor retrieval
  2. Multi-Scale Structure: Layer-aware attention at different resolutions
  3. Distance Metrics: Pre-computed distances for efficient RoPE
  4. Dynamic Updates: Add new nodes without full retraining

Testing Strategy

Unit Tests

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_edge_featured_attention() {
        let attention = EdgeFeaturedAttention::new(64, 4, 0.2);
        // Test with synthetic graph
    }

    #[test]
    fn test_graph_rope() {
        let rope = GraphRoPE::new(64, 10000.0, 10.0, 4);
        // Test rotation properties
    }

    #[test]
    fn test_dual_space_attention() {
        let dual = DualSpaceAttention::new(64, 8);
        // Test with mock HNSW
    }
}

Integration Tests

  • Test with real HNSW indices
  • Validate attention distributions
  • Benchmark search performance
  • Verify gradient flow

Next Steps

  1. Implement HNSW Interface: Create concrete implementation or adapter
  2. Gradient Computation: Add backward pass for training
  3. Multi-Head Attention: Extend to multi-head versions
  4. Layer Normalization: Add normalization for stable training
  5. Benchmarking: Compare with baseline GNN attention mechanisms

References

  • GAT (Graph Attention Networks): Veličković et al., 2018
  • RoPE (Rotary Position Embeddings): Su et al., 2021
  • HNSW (Hierarchical Navigable Small World): Malkov & Yashunin, 2018
  • Cross-Attention Mechanisms: Vaswani et al., 2017 (Transformers)