Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
412
vendor/ruvector/crates/ruvector-attention/src/graph/dual_space.rs
vendored
Normal file
412
vendor/ruvector/crates/ruvector-attention/src/graph/dual_space.rs
vendored
Normal file
@@ -0,0 +1,412 @@
|
||||
//! Dual-space attention combining Euclidean and Hyperbolic geometries
|
||||
//!
|
||||
//! This module implements attention that operates in both Euclidean and hyperbolic
|
||||
//! spaces, combining their complementary properties:
|
||||
//! - Euclidean: Good for flat, local structure
|
||||
//! - Hyperbolic: Good for hierarchical, tree-like structure
|
||||
|
||||
use crate::error::{AttentionError, AttentionResult};
|
||||
use crate::hyperbolic::project_to_ball;
|
||||
use crate::traits::Attention;
|
||||
use crate::utils::stable_softmax;
|
||||
|
||||
/// Poincaré-ball distance between two points under the given curvature.
///
/// Implements `d(u, v) = (1/√c) · acosh(1 + 2c·‖u−v‖² / ((1 − c‖u‖²)(1 − c‖v‖²)))`
/// with the denominators clamped away from zero and the `acosh` argument
/// clamped to ≥ 1 so points near the ball boundary stay numerically stable.
fn poincare_dist(u: &[f32], v: &[f32], curvature: f32) -> f32 {
    let c = curvature.abs();
    let sqrt_c = c.sqrt();

    // Squared-norm helper; norms are taken over each full slice.
    let sq_norm = |xs: &[f32]| xs.iter().map(|x| x * x).sum::<f32>();

    let diff_sq = u
        .iter()
        .zip(v)
        .fold(0.0f32, |acc, (a, b)| acc + (a - b) * (a - b));

    // Clamp each factor to avoid division blow-up at the ball boundary.
    let denom = (1.0 - c * sq_norm(u)).max(1e-7) * (1.0 - c * sq_norm(v)).max(1e-7);
    let arg = 1.0 + 2.0 * c * diff_sq / denom;

    arg.max(1.0).acosh() / sqrt_c
}
|
||||
|
||||
/// Configuration for dual-space attention
#[derive(Clone, Debug)]
pub struct DualSpaceConfig {
    /// Embedding dimensionality of queries, keys and projections.
    pub dim: usize,
    /// Curvature of the hyperbolic (Poincaré-ball) space; the sign is ignored.
    pub curvature: f32,
    /// Mixing weight applied to the Euclidean similarity score.
    pub euclidean_weight: f32,
    /// Mixing weight applied to the hyperbolic similarity score.
    pub hyperbolic_weight: f32,
    // NOTE(review): not read anywhere in this module — presumably consumed by
    // a training loop elsewhere; confirm before relying on it.
    pub learn_weights: bool,
    /// Softmax temperature; combined scores are divided by this value.
    pub temperature: f32,
}
|
||||
|
||||
impl Default for DualSpaceConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
dim: 256,
|
||||
curvature: 1.0,
|
||||
euclidean_weight: 0.5,
|
||||
hyperbolic_weight: 0.5,
|
||||
learn_weights: false,
|
||||
temperature: 1.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DualSpaceConfig {
|
||||
pub fn builder() -> DualSpaceConfigBuilder {
|
||||
DualSpaceConfigBuilder::default()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct DualSpaceConfigBuilder {
|
||||
config: DualSpaceConfig,
|
||||
}
|
||||
|
||||
impl DualSpaceConfigBuilder {
|
||||
pub fn dim(mut self, d: usize) -> Self {
|
||||
self.config.dim = d;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn curvature(mut self, c: f32) -> Self {
|
||||
self.config.curvature = c;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn euclidean_weight(mut self, w: f32) -> Self {
|
||||
self.config.euclidean_weight = w;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn hyperbolic_weight(mut self, w: f32) -> Self {
|
||||
self.config.hyperbolic_weight = w;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn temperature(mut self, t: f32) -> Self {
|
||||
self.config.temperature = t;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> DualSpaceConfig {
|
||||
self.config
|
||||
}
|
||||
}
|
||||
|
||||
/// Dual-space attention layer
///
/// Holds the configuration plus three flattened `dim × dim` projection
/// matrices (row-major) initialised deterministically in [`DualSpaceAttention::new`].
pub struct DualSpaceAttention {
    config: DualSpaceConfig,
    /// Dot-product scale, `1 / sqrt(dim)` (set in `new`).
    scale: f32,
    /// Linear projection for Euclidean space
    w_euclidean: Vec<f32>,
    /// Linear projection for hyperbolic space
    w_hyperbolic: Vec<f32>,
    /// Output projection
    w_out: Vec<f32>,
}
|
||||
|
||||
impl DualSpaceAttention {
    /// Build an attention layer, deterministically initialising the three
    /// `dim × dim` projection matrices with an LCG-driven Xavier scheme.
    pub fn new(config: DualSpaceConfig) -> Self {
        let dim = config.dim;
        let scale = 1.0 / (dim as f32).sqrt();

        // Xavier initialization
        let w_scale = (2.0 / (dim + dim) as f32).sqrt();
        let mut seed = 42u64;
        // Fixed-seed linear congruential generator: construction is fully
        // reproducible without pulling in an RNG crate. The draw order
        // (euclidean, hyperbolic, out) determines the exact weights.
        let mut rand = || {
            seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
            ((seed as f32) / (u64::MAX as f32) - 0.5) * 2.0 * w_scale
        };

        let w_euclidean: Vec<f32> = (0..dim * dim).map(|_| rand()).collect();
        let w_hyperbolic: Vec<f32> = (0..dim * dim).map(|_| rand()).collect();
        let w_out: Vec<f32> = (0..dim * dim).map(|_| rand()).collect();

        Self {
            config,
            scale,
            w_euclidean,
            w_hyperbolic,
            w_out,
        }
    }

    /// Project to Euclidean representation
    ///
    /// Row-major mat-vec: `out[i] = Σ_j w_euclidean[i*dim + j] * x[j]`.
    fn to_euclidean(&self, x: &[f32]) -> Vec<f32> {
        let dim = self.config.dim;
        (0..dim)
            .map(|i| {
                x.iter()
                    .enumerate()
                    .map(|(j, &xj)| xj * self.w_euclidean[i * dim + j])
                    .sum()
            })
            .collect()
    }

    /// Project to hyperbolic representation (Poincaré ball)
    ///
    /// Same mat-vec as `to_euclidean` but with `w_hyperbolic`, followed by a
    /// projection onto the ball for the configured curvature.
    fn to_hyperbolic(&self, x: &[f32]) -> Vec<f32> {
        let dim = self.config.dim;
        let projected: Vec<f32> = (0..dim)
            .map(|i| {
                x.iter()
                    .enumerate()
                    .map(|(j, &xj)| xj * self.w_hyperbolic[i * dim + j])
                    .sum()
            })
            .collect();

        // Project to ball with curvature
        project_to_ball(&projected, self.config.curvature, 1e-5)
    }

    /// Compute Euclidean similarity (dot product)
    ///
    /// Scaled by `1 / sqrt(dim)` as in standard attention.
    fn euclidean_similarity(&self, q: &[f32], k: &[f32]) -> f32 {
        q.iter().zip(k.iter()).map(|(a, b)| a * b).sum::<f32>() * self.scale
    }

    /// Compute hyperbolic similarity (negative Poincaré distance)
    fn hyperbolic_similarity(&self, q: &[f32], k: &[f32]) -> f32 {
        -poincare_dist(q, k, self.config.curvature)
    }

    /// Output projection
    ///
    /// Row-major mat-vec with `w_out`; only applied by `compute` when the
    /// value dimension matches `config.dim`.
    fn project_output(&self, x: &[f32]) -> Vec<f32> {
        let dim = self.config.dim;
        (0..dim)
            .map(|i| {
                x.iter()
                    .enumerate()
                    .map(|(j, &xj)| xj * self.w_out[i * dim + j])
                    .sum()
            })
            .collect()
    }

    /// Get the contribution weights for analysis
    ///
    /// Returns the raw (pre-softmax) per-key scores for each space so callers
    /// can inspect how much each geometry contributes.
    pub fn get_space_contributions(&self, query: &[f32], keys: &[&[f32]]) -> (Vec<f32>, Vec<f32>) {
        let q_euc = self.to_euclidean(query);
        let q_hyp = self.to_hyperbolic(query);

        let euc_scores: Vec<f32> = keys
            .iter()
            .map(|k| {
                let k_euc = self.to_euclidean(k);
                self.euclidean_similarity(&q_euc, &k_euc)
            })
            .collect();

        let hyp_scores: Vec<f32> = keys
            .iter()
            .map(|k| {
                let k_hyp = self.to_hyperbolic(k);
                self.hyperbolic_similarity(&q_hyp, &k_hyp)
            })
            .collect();

        (euc_scores, hyp_scores)
    }
}
|
||||
|
||||
impl Attention for DualSpaceAttention {
|
||||
fn compute(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: &[&[f32]],
|
||||
values: &[&[f32]],
|
||||
) -> AttentionResult<Vec<f32>> {
|
||||
if keys.is_empty() {
|
||||
return Err(AttentionError::InvalidConfig("Empty keys".to_string()));
|
||||
}
|
||||
if query.len() != self.config.dim {
|
||||
return Err(AttentionError::DimensionMismatch {
|
||||
expected: self.config.dim,
|
||||
actual: query.len(),
|
||||
});
|
||||
}
|
||||
|
||||
let n = keys.len();
|
||||
let value_dim = values[0].len();
|
||||
let temp = self.config.temperature;
|
||||
|
||||
// Project query to both spaces
|
||||
let q_euc = self.to_euclidean(query);
|
||||
let q_hyp = self.to_hyperbolic(query);
|
||||
|
||||
// Compute combined scores
|
||||
let mut combined_scores = Vec::with_capacity(n);
|
||||
|
||||
for key in keys.iter() {
|
||||
let k_euc = self.to_euclidean(key);
|
||||
let k_hyp = self.to_hyperbolic(key);
|
||||
|
||||
let euc_score = self.euclidean_similarity(&q_euc, &k_euc);
|
||||
let hyp_score = self.hyperbolic_similarity(&q_hyp, &k_hyp);
|
||||
|
||||
// Weighted combination
|
||||
let combined = (self.config.euclidean_weight * euc_score
|
||||
+ self.config.hyperbolic_weight * hyp_score)
|
||||
/ temp;
|
||||
|
||||
combined_scores.push(combined);
|
||||
}
|
||||
|
||||
// Softmax over combined scores
|
||||
let weights = stable_softmax(&combined_scores);
|
||||
|
||||
// Weighted sum of values
|
||||
let mut output = vec![0.0f32; value_dim];
|
||||
for (w, v) in weights.iter().zip(values.iter()) {
|
||||
for (o, &vi) in output.iter_mut().zip(v.iter()) {
|
||||
*o += w * vi;
|
||||
}
|
||||
}
|
||||
|
||||
// Output projection
|
||||
if value_dim == self.config.dim {
|
||||
Ok(self.project_output(&output))
|
||||
} else {
|
||||
Ok(output)
|
||||
}
|
||||
}
|
||||
|
||||
fn compute_with_mask(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: &[&[f32]],
|
||||
values: &[&[f32]],
|
||||
mask: Option<&[bool]>,
|
||||
) -> AttentionResult<Vec<f32>> {
|
||||
if let Some(m) = mask {
|
||||
let filtered: Vec<(usize, bool)> = m
|
||||
.iter()
|
||||
.copied()
|
||||
.enumerate()
|
||||
.filter(|(_, keep)| *keep)
|
||||
.collect();
|
||||
let filtered_keys: Vec<&[f32]> = filtered.iter().map(|(i, _)| keys[*i]).collect();
|
||||
let filtered_values: Vec<&[f32]> = filtered.iter().map(|(i, _)| values[*i]).collect();
|
||||
self.compute(query, &filtered_keys, &filtered_values)
|
||||
} else {
|
||||
self.compute(query, keys, values)
|
||||
}
|
||||
}
|
||||
|
||||
fn dim(&self) -> usize {
|
||||
self.config.dim
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Smoke test: balanced mixing over identical keys must produce an output
    // of the configured dimension.
    #[test]
    fn test_dual_space_basic() {
        let config = DualSpaceConfig::builder()
            .dim(64)
            .curvature(1.0)
            .euclidean_weight(0.5)
            .hyperbolic_weight(0.5)
            .build();

        let attn = DualSpaceAttention::new(config);

        let query = vec![0.1; 64];
        let keys: Vec<Vec<f32>> = (0..10).map(|_| vec![0.1; 64]).collect();
        let values: Vec<Vec<f32>> = (0..10).map(|_| vec![1.0; 64]).collect();

        let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();
        let values_refs: Vec<&[f32]> = values.iter().map(|v| v.as_slice()).collect();

        let result = attn.compute(&query, &keys_refs, &values_refs).unwrap();
        assert_eq!(result.len(), 64);
    }

    // Pure-Euclidean configuration (hyperbolic weight zero) still computes.
    #[test]
    fn test_euclidean_dominant() {
        let config = DualSpaceConfig::builder()
            .dim(32)
            .euclidean_weight(1.0)
            .hyperbolic_weight(0.0)
            .build();

        let attn = DualSpaceAttention::new(config);

        let query = vec![0.5; 32];
        let keys: Vec<Vec<f32>> = vec![vec![0.3; 32]; 5];
        let values: Vec<Vec<f32>> = vec![vec![1.0; 32]; 5];

        let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();
        let values_refs: Vec<&[f32]> = values.iter().map(|v| v.as_slice()).collect();

        let result = attn.compute(&query, &keys_refs, &values_refs).unwrap();
        assert_eq!(result.len(), 32);
    }

    // Pure-hyperbolic configuration (Euclidean weight zero) still computes.
    #[test]
    fn test_hyperbolic_dominant() {
        let config = DualSpaceConfig::builder()
            .dim(32)
            .curvature(0.5)
            .euclidean_weight(0.0)
            .hyperbolic_weight(1.0)
            .build();

        let attn = DualSpaceAttention::new(config);

        let query = vec![0.1; 32]; // Small values for Poincaré ball
        let keys: Vec<Vec<f32>> = vec![vec![0.1; 32]; 5];
        let values: Vec<Vec<f32>> = vec![vec![1.0; 32]; 5];

        let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();
        let values_refs: Vec<&[f32]> = values.iter().map(|v| v.as_slice()).collect();

        let result = attn.compute(&query, &keys_refs, &values_refs).unwrap();
        assert_eq!(result.len(), 32);
    }

    // The analysis helper returns one raw score per key for each space.
    #[test]
    fn test_space_contributions() {
        let config = DualSpaceConfig::builder()
            .dim(16)
            .euclidean_weight(0.5)
            .hyperbolic_weight(0.5)
            .build();

        let attn = DualSpaceAttention::new(config);

        let query = vec![0.2; 16];
        let keys: Vec<Vec<f32>> = vec![vec![0.2; 16]; 3];
        let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();

        let (euc_scores, hyp_scores) = attn.get_space_contributions(&query, &keys_refs);

        assert_eq!(euc_scores.len(), 3);
        assert_eq!(hyp_scores.len(), 3);
    }

    // Temperature changes only sharpen/flatten the softmax; both extremes
    // must produce valid outputs of the configured dimension.
    #[test]
    fn test_temperature_scaling() {
        let config_low_temp = DualSpaceConfig::builder().dim(16).temperature(0.5).build();

        let config_high_temp = DualSpaceConfig::builder().dim(16).temperature(2.0).build();

        let attn_low = DualSpaceAttention::new(config_low_temp);
        let attn_high = DualSpaceAttention::new(config_high_temp);

        let query = vec![0.5; 16];
        let keys: Vec<Vec<f32>> = vec![vec![0.8; 16], vec![0.2; 16]];
        let values: Vec<Vec<f32>> = vec![vec![1.0; 16], vec![0.0; 16]];

        let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();
        let values_refs: Vec<&[f32]> = values.iter().map(|v| v.as_slice()).collect();

        let result_low = attn_low.compute(&query, &keys_refs, &values_refs).unwrap();
        let result_high = attn_high.compute(&query, &keys_refs, &values_refs).unwrap();

        // Low temperature should be more peaked (closer to [1,0,0...])
        // High temperature should be more uniform
        // We just verify both compute successfully
        assert_eq!(result_low.len(), 16);
        assert_eq!(result_high.len(), 16);
    }
}
|
||||
394
vendor/ruvector/crates/ruvector-attention/src/graph/edge_featured.rs
vendored
Normal file
394
vendor/ruvector/crates/ruvector-attention/src/graph/edge_featured.rs
vendored
Normal file
@@ -0,0 +1,394 @@
|
||||
//! Edge-featured graph attention (GATv2 style)
|
||||
//!
|
||||
//! Extends standard graph attention with edge feature integration.
|
||||
|
||||
use crate::error::{AttentionError, AttentionResult};
|
||||
use crate::traits::Attention;
|
||||
use crate::utils::stable_softmax;
|
||||
|
||||
/// Configuration for edge-featured attention
#[derive(Clone, Debug)]
pub struct EdgeFeaturedConfig {
    /// Node feature dimensionality (input width; also the output width when
    /// heads are concatenated and `node_dim` divides evenly by `num_heads`).
    pub node_dim: usize,
    /// Edge feature dimensionality.
    pub edge_dim: usize,
    /// Number of attention heads; `head_dim()` is `node_dim / num_heads`.
    pub num_heads: usize,
    // NOTE(review): not applied anywhere in this module — presumably used by
    // a training wrapper; confirm.
    pub dropout: f32,
    /// Concatenate head outputs (`true`) or average them (`false`).
    pub concat_heads: bool,
    // NOTE(review): not read by the code in this module — confirm.
    pub add_self_loops: bool,
    pub negative_slope: f32, // LeakyReLU slope
}
|
||||
|
||||
impl Default for EdgeFeaturedConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
node_dim: 256,
|
||||
edge_dim: 64,
|
||||
num_heads: 4,
|
||||
dropout: 0.0,
|
||||
concat_heads: true,
|
||||
add_self_loops: true,
|
||||
negative_slope: 0.2,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl EdgeFeaturedConfig {
|
||||
pub fn builder() -> EdgeFeaturedConfigBuilder {
|
||||
EdgeFeaturedConfigBuilder::default()
|
||||
}
|
||||
|
||||
pub fn head_dim(&self) -> usize {
|
||||
self.node_dim / self.num_heads
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct EdgeFeaturedConfigBuilder {
|
||||
config: EdgeFeaturedConfig,
|
||||
}
|
||||
|
||||
impl EdgeFeaturedConfigBuilder {
|
||||
pub fn node_dim(mut self, d: usize) -> Self {
|
||||
self.config.node_dim = d;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn edge_dim(mut self, d: usize) -> Self {
|
||||
self.config.edge_dim = d;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn num_heads(mut self, n: usize) -> Self {
|
||||
self.config.num_heads = n;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn dropout(mut self, d: f32) -> Self {
|
||||
self.config.dropout = d;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn concat_heads(mut self, c: bool) -> Self {
|
||||
self.config.concat_heads = c;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn negative_slope(mut self, s: f32) -> Self {
|
||||
self.config.negative_slope = s;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> EdgeFeaturedConfig {
|
||||
self.config
|
||||
}
|
||||
}
|
||||
|
||||
/// Edge-featured graph attention layer
///
/// All weight tensors are stored flattened row-major and are initialised
/// deterministically in [`EdgeFeaturedAttention::new`].
pub struct EdgeFeaturedAttention {
    config: EdgeFeaturedConfig,
    // Weight matrices (would be learnable in training)
    w_node: Vec<f32>, // [num_heads, head_dim, node_dim]
    w_edge: Vec<f32>, // [num_heads, head_dim, edge_dim]
    a_src: Vec<f32>,  // [num_heads, head_dim]
    a_dst: Vec<f32>,  // [num_heads, head_dim]
    a_edge: Vec<f32>, // [num_heads, head_dim]
}
|
||||
|
||||
impl EdgeFeaturedAttention {
    /// Build the layer, deterministically initialising all weights with a
    /// fixed-seed LCG (Xavier-style scales per tensor). The draw order
    /// (w_node, w_edge, a_src, a_dst, a_edge) determines the exact values.
    pub fn new(config: EdgeFeaturedConfig) -> Self {
        let head_dim = config.head_dim();
        let num_heads = config.num_heads;

        // Xavier initialization
        let node_scale = (2.0 / (config.node_dim + head_dim) as f32).sqrt();
        let edge_scale = (2.0 / (config.edge_dim + head_dim) as f32).sqrt();
        let attn_scale = (1.0 / head_dim as f32).sqrt();

        let mut seed = 42u64;
        // Fixed-seed LCG in [-0.5, 0.5); scaled per tensor below.
        let mut rand = || {
            seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
            (seed as f32) / (u64::MAX as f32) - 0.5
        };

        let w_node: Vec<f32> = (0..num_heads * head_dim * config.node_dim)
            .map(|_| rand() * 2.0 * node_scale)
            .collect();

        let w_edge: Vec<f32> = (0..num_heads * head_dim * config.edge_dim)
            .map(|_| rand() * 2.0 * edge_scale)
            .collect();

        let a_src: Vec<f32> = (0..num_heads * head_dim)
            .map(|_| rand() * 2.0 * attn_scale)
            .collect();

        let a_dst: Vec<f32> = (0..num_heads * head_dim)
            .map(|_| rand() * 2.0 * attn_scale)
            .collect();

        let a_edge: Vec<f32> = (0..num_heads * head_dim)
            .map(|_| rand() * 2.0 * attn_scale)
            .collect();

        Self {
            config,
            w_node,
            w_edge,
            a_src,
            a_dst,
            a_edge,
        }
    }

    /// Transform node features for a specific head
    ///
    /// Row-major mat-vec against the `head`-th `head_dim × node_dim` slice of
    /// `w_node`; returns a `head_dim`-length vector.
    fn transform_node(&self, node: &[f32], head: usize) -> Vec<f32> {
        let head_dim = self.config.head_dim();
        let node_dim = self.config.node_dim;

        (0..head_dim)
            .map(|i| {
                node.iter()
                    .enumerate()
                    .map(|(j, &nj)| nj * self.w_node[head * head_dim * node_dim + i * node_dim + j])
                    .sum()
            })
            .collect()
    }

    /// Transform edge features for a specific head
    ///
    /// Same layout as `transform_node` but against `w_edge`.
    fn transform_edge(&self, edge: &[f32], head: usize) -> Vec<f32> {
        let head_dim = self.config.head_dim();
        let edge_dim = self.config.edge_dim;

        (0..head_dim)
            .map(|i| {
                edge.iter()
                    .enumerate()
                    .map(|(j, &ej)| ej * self.w_edge[head * head_dim * edge_dim + i * edge_dim + j])
                    .sum()
            })
            .collect()
    }

    /// Compute attention coefficient with LeakyReLU
    ///
    /// GAT-style additive score: dot products of the (already transformed)
    /// source, destination and edge vectors with the per-head attention
    /// vectors, passed through LeakyReLU.
    fn attention_coeff(&self, src: &[f32], dst: &[f32], edge: &[f32], head: usize) -> f32 {
        let head_dim = self.config.head_dim();

        let mut score = 0.0f32;
        for i in 0..head_dim {
            let offset = head * head_dim + i;
            score += src[i] * self.a_src[offset];
            score += dst[i] * self.a_dst[offset];
            score += edge[i] * self.a_edge[offset];
        }

        // LeakyReLU
        if score < 0.0 {
            self.config.negative_slope * score
        } else {
            score
        }
    }
}
|
||||
|
||||
impl EdgeFeaturedAttention {
|
||||
/// Compute attention with explicit edge features
|
||||
pub fn compute_with_edges(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: &[&[f32]],
|
||||
values: &[&[f32]],
|
||||
edges: &[&[f32]],
|
||||
) -> AttentionResult<Vec<f32>> {
|
||||
if keys.len() != edges.len() {
|
||||
return Err(AttentionError::InvalidConfig(
|
||||
"Keys and edges must have same length".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let num_heads = self.config.num_heads;
|
||||
let head_dim = self.config.head_dim();
|
||||
let n = keys.len();
|
||||
|
||||
// Transform query once per head
|
||||
let query_transformed: Vec<Vec<f32>> = (0..num_heads)
|
||||
.map(|h| self.transform_node(query, h))
|
||||
.collect();
|
||||
|
||||
// Compute per-head outputs
|
||||
let mut head_outputs: Vec<Vec<f32>> = Vec::with_capacity(num_heads);
|
||||
|
||||
for h in 0..num_heads {
|
||||
// Transform all keys and edges
|
||||
let keys_t: Vec<Vec<f32>> = keys.iter().map(|k| self.transform_node(k, h)).collect();
|
||||
let edges_t: Vec<Vec<f32>> = edges.iter().map(|e| self.transform_edge(e, h)).collect();
|
||||
|
||||
// Compute attention coefficients
|
||||
let coeffs: Vec<f32> = (0..n)
|
||||
.map(|i| self.attention_coeff(&query_transformed[h], &keys_t[i], &edges_t[i], h))
|
||||
.collect();
|
||||
|
||||
// Softmax
|
||||
let weights = stable_softmax(&coeffs);
|
||||
|
||||
// Weighted sum of values
|
||||
let mut head_out = vec![0.0f32; head_dim];
|
||||
for (i, &w) in weights.iter().enumerate() {
|
||||
let value_t = self.transform_node(values[i], h);
|
||||
for (j, &vj) in value_t.iter().enumerate() {
|
||||
head_out[j] += w * vj;
|
||||
}
|
||||
}
|
||||
|
||||
head_outputs.push(head_out);
|
||||
}
|
||||
|
||||
// Concatenate or average heads
|
||||
if self.config.concat_heads {
|
||||
Ok(head_outputs.into_iter().flatten().collect())
|
||||
} else {
|
||||
let mut output = vec![0.0f32; head_dim];
|
||||
for head_out in &head_outputs {
|
||||
for (i, &v) in head_out.iter().enumerate() {
|
||||
output[i] += v / num_heads as f32;
|
||||
}
|
||||
}
|
||||
Ok(output)
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the edge feature dimension
|
||||
pub fn edge_dim(&self) -> usize {
|
||||
self.config.edge_dim
|
||||
}
|
||||
}
|
||||
|
||||
impl Attention for EdgeFeaturedAttention {
|
||||
fn compute(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: &[&[f32]],
|
||||
values: &[&[f32]],
|
||||
) -> AttentionResult<Vec<f32>> {
|
||||
if keys.is_empty() {
|
||||
return Err(AttentionError::InvalidConfig("Empty keys".to_string()));
|
||||
}
|
||||
if query.len() != self.config.node_dim {
|
||||
return Err(AttentionError::DimensionMismatch {
|
||||
expected: self.config.node_dim,
|
||||
actual: query.len(),
|
||||
});
|
||||
}
|
||||
|
||||
// Use zero edge features for basic attention
|
||||
let zero_edge = vec![0.0f32; self.config.edge_dim];
|
||||
let edges: Vec<&[f32]> = (0..keys.len()).map(|_| zero_edge.as_slice()).collect();
|
||||
|
||||
self.compute_with_edges(query, keys, values, &edges)
|
||||
}
|
||||
|
||||
fn compute_with_mask(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: &[&[f32]],
|
||||
values: &[&[f32]],
|
||||
mask: Option<&[bool]>,
|
||||
) -> AttentionResult<Vec<f32>> {
|
||||
// Apply mask by filtering keys/values
|
||||
if let Some(m) = mask {
|
||||
let filtered: Vec<(usize, bool)> = m
|
||||
.iter()
|
||||
.copied()
|
||||
.enumerate()
|
||||
.filter(|(_, keep)| *keep)
|
||||
.collect();
|
||||
let filtered_keys: Vec<&[f32]> = filtered.iter().map(|(i, _)| keys[*i]).collect();
|
||||
let filtered_values: Vec<&[f32]> = filtered.iter().map(|(i, _)| values[*i]).collect();
|
||||
self.compute(query, &filtered_keys, &filtered_values)
|
||||
} else {
|
||||
self.compute(query, keys, values)
|
||||
}
|
||||
}
|
||||
|
||||
fn dim(&self) -> usize {
|
||||
if self.config.concat_heads {
|
||||
self.config.node_dim
|
||||
} else {
|
||||
self.config.head_dim()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // With concat_heads (default) the output width equals node_dim.
    #[test]
    fn test_edge_featured_attention() {
        let config = EdgeFeaturedConfig::builder()
            .node_dim(64)
            .edge_dim(16)
            .num_heads(4)
            .build();

        let attn = EdgeFeaturedAttention::new(config);

        let query = vec![0.5; 64];
        let keys: Vec<Vec<f32>> = (0..10).map(|_| vec![0.3; 64]).collect();
        let values: Vec<Vec<f32>> = (0..10).map(|_| vec![1.0; 64]).collect();
        let edges: Vec<Vec<f32>> = (0..10).map(|_| vec![0.2; 16]).collect();

        let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();
        let values_refs: Vec<&[f32]> = values.iter().map(|v| v.as_slice()).collect();
        let edges_refs: Vec<&[f32]> = edges.iter().map(|e| e.as_slice()).collect();

        let result = attn
            .compute_with_edges(&query, &keys_refs, &values_refs, &edges_refs)
            .unwrap();
        assert_eq!(result.len(), 64);
    }

    // The plain `compute` path substitutes zero edge features internally.
    #[test]
    fn test_without_edges() {
        let config = EdgeFeaturedConfig::builder()
            .node_dim(32)
            .edge_dim(8)
            .num_heads(2)
            .build();

        let attn = EdgeFeaturedAttention::new(config);

        let query = vec![0.5; 32];
        let keys: Vec<Vec<f32>> = vec![vec![0.3; 32]; 5];
        let values: Vec<Vec<f32>> = vec![vec![1.0; 32]; 5];

        let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();
        let values_refs: Vec<&[f32]> = values.iter().map(|v| v.as_slice()).collect();

        let result = attn.compute(&query, &keys_refs, &values_refs).unwrap();
        assert_eq!(result.len(), 32);
    }

    // Negative inputs exercise the LeakyReLU branch of attention_coeff.
    #[test]
    fn test_leaky_relu() {
        let config = EdgeFeaturedConfig::builder()
            .node_dim(16)
            .edge_dim(4)
            .num_heads(1)
            .negative_slope(0.2)
            .build();

        let attn = EdgeFeaturedAttention::new(config);

        // Just verify it computes without error
        let query = vec![-1.0; 16];
        let keys: Vec<Vec<f32>> = vec![vec![-0.5; 16]; 3];
        let values: Vec<Vec<f32>> = vec![vec![1.0; 16]; 3];

        let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();
        let values_refs: Vec<&[f32]> = values.iter().map(|v| v.as_slice()).collect();

        let result = attn.compute(&query, &keys_refs, &values_refs).unwrap();
        assert_eq!(result.len(), 16);
    }
}
|
||||
14
vendor/ruvector/crates/ruvector-attention/src/graph/mod.rs
vendored
Normal file
14
vendor/ruvector/crates/ruvector-attention/src/graph/mod.rs
vendored
Normal file
@@ -0,0 +1,14 @@
|
||||
//! Graph attention mechanisms for GNN applications
//!
//! This module provides graph-specific attention implementations:
//! - Edge-featured attention (GAT with edge features)
//! - Rotary position embeddings for graphs (RoPE)
//! - Dual-space attention (Euclidean + Hyperbolic)

pub mod dual_space;
pub mod edge_featured;
pub mod rope;

// Re-export the primary types so callers can write `graph::DualSpaceAttention`
// etc. without naming the submodule.
pub use dual_space::{DualSpaceAttention, DualSpaceConfig};
pub use edge_featured::{EdgeFeaturedAttention, EdgeFeaturedConfig};
pub use rope::{GraphRoPE, RoPEConfig};
|
||||
318
vendor/ruvector/crates/ruvector-attention/src/graph/rope.rs
vendored
Normal file
318
vendor/ruvector/crates/ruvector-attention/src/graph/rope.rs
vendored
Normal file
@@ -0,0 +1,318 @@
|
||||
//! Rotary Position Embeddings (RoPE) for Graph Attention
|
||||
//!
|
||||
//! Adapts RoPE for graph structures where positions are defined by graph topology
|
||||
//! (e.g., hop distance, shortest path length, or learned positional encodings).
|
||||
|
||||
use crate::error::{AttentionError, AttentionResult};
|
||||
use crate::traits::Attention;
|
||||
use crate::utils::stable_softmax;
|
||||
|
||||
/// Configuration for Graph RoPE
#[derive(Clone, Debug)]
pub struct RoPEConfig {
    /// Embedding dimensionality; rotation pairs elements `i` and `i + dim/2`,
    /// so an even `dim` is expected.
    pub dim: usize,
    /// Frequency base (10000.0 in the original RoPE formulation).
    pub base: f32,
    /// Number of positions precomputed into the cos/sin caches.
    pub max_position: usize,
    /// Positions are divided by this factor before computing angles.
    pub scaling_factor: f32,
}
|
||||
|
||||
impl Default for RoPEConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
dim: 256,
|
||||
base: 10000.0,
|
||||
max_position: 512,
|
||||
scaling_factor: 1.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RoPEConfig {
|
||||
pub fn builder() -> RoPEConfigBuilder {
|
||||
RoPEConfigBuilder::default()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct RoPEConfigBuilder {
|
||||
config: RoPEConfig,
|
||||
}
|
||||
|
||||
impl RoPEConfigBuilder {
|
||||
pub fn dim(mut self, d: usize) -> Self {
|
||||
self.config.dim = d;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn base(mut self, b: f32) -> Self {
|
||||
self.config.base = b;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn max_position(mut self, m: usize) -> Self {
|
||||
self.config.max_position = m;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn scaling_factor(mut self, s: f32) -> Self {
|
||||
self.config.scaling_factor = s;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> RoPEConfig {
|
||||
self.config
|
||||
}
|
||||
}
|
||||
|
||||
/// Graph attention with Rotary Position Embeddings
pub struct GraphRoPE {
    config: RoPEConfig,
    /// Precomputed cos/sin tables: [max_position, dim]
    cos_cache: Vec<f32>,
    /// Same layout as `cos_cache`.
    sin_cache: Vec<f32>,
    /// Dot-product scale, `1 / sqrt(dim)` (set in `new`).
    scale: f32,
}
|
||||
|
||||
impl GraphRoPE {
|
||||
pub fn new(config: RoPEConfig) -> Self {
|
||||
let dim = config.dim;
|
||||
let max_pos = config.max_position;
|
||||
let base = config.base;
|
||||
let scaling = config.scaling_factor;
|
||||
|
||||
// Compute frequency bands
|
||||
let half_dim = dim / 2;
|
||||
let inv_freq: Vec<f32> = (0..half_dim)
|
||||
.map(|i| 1.0 / (base.powf(2.0 * i as f32 / dim as f32)))
|
||||
.collect();
|
||||
|
||||
// Precompute cos/sin for all positions
|
||||
let mut cos_cache = Vec::with_capacity(max_pos * dim);
|
||||
let mut sin_cache = Vec::with_capacity(max_pos * dim);
|
||||
|
||||
for pos in 0..max_pos {
|
||||
let scaled_pos = pos as f32 / scaling;
|
||||
for i in 0..half_dim {
|
||||
let theta = scaled_pos * inv_freq[i];
|
||||
cos_cache.push(theta.cos());
|
||||
sin_cache.push(theta.sin());
|
||||
}
|
||||
// Duplicate for both halves (interleaved format)
|
||||
for i in 0..half_dim {
|
||||
let theta = scaled_pos * inv_freq[i];
|
||||
cos_cache.push(theta.cos());
|
||||
sin_cache.push(theta.sin());
|
||||
}
|
||||
}
|
||||
|
||||
Self {
|
||||
scale: 1.0 / (dim as f32).sqrt(),
|
||||
config,
|
||||
cos_cache,
|
||||
sin_cache,
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply rotary embedding to a vector at given position
|
||||
pub fn apply_rotary(&self, x: &[f32], position: usize) -> Vec<f32> {
|
||||
let dim = self.config.dim;
|
||||
let half = dim / 2;
|
||||
let pos = position.min(self.config.max_position - 1);
|
||||
let offset = pos * dim;
|
||||
|
||||
let mut result = vec![0.0f32; dim];
|
||||
|
||||
// Apply rotation to first half
|
||||
for i in 0..half {
|
||||
let cos = self.cos_cache[offset + i];
|
||||
let sin = self.sin_cache[offset + i];
|
||||
result[i] = x[i] * cos - x[half + i] * sin;
|
||||
result[half + i] = x[i] * sin + x[half + i] * cos;
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Compute attention with positional encoding based on graph distances
|
||||
pub fn compute_with_positions(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: &[&[f32]],
|
||||
values: &[&[f32]],
|
||||
query_pos: usize,
|
||||
key_positions: &[usize],
|
||||
) -> AttentionResult<Vec<f32>> {
|
||||
if keys.is_empty() {
|
||||
return Err(AttentionError::InvalidConfig("Empty keys".to_string()));
|
||||
}
|
||||
if keys.len() != key_positions.len() {
|
||||
return Err(AttentionError::InvalidConfig(
|
||||
"Keys and positions must have same length".to_string(),
|
||||
));
|
||||
}
|
||||
if query.len() != self.config.dim {
|
||||
return Err(AttentionError::DimensionMismatch {
|
||||
expected: self.config.dim,
|
||||
actual: query.len(),
|
||||
});
|
||||
}
|
||||
|
||||
// Apply rotary to query
|
||||
let q_rot = self.apply_rotary(query, query_pos);
|
||||
|
||||
// Compute attention scores with rotary keys
|
||||
let scores: Vec<f32> = keys
|
||||
.iter()
|
||||
.zip(key_positions.iter())
|
||||
.map(|(key, &pos)| {
|
||||
let k_rot = self.apply_rotary(key, pos);
|
||||
q_rot
|
||||
.iter()
|
||||
.zip(k_rot.iter())
|
||||
.map(|(q, k)| q * k)
|
||||
.sum::<f32>()
|
||||
* self.scale
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Softmax
|
||||
let weights = stable_softmax(&scores);
|
||||
|
||||
// Weighted sum
|
||||
let value_dim = values[0].len();
|
||||
let mut output = vec![0.0f32; value_dim];
|
||||
for (w, v) in weights.iter().zip(values.iter()) {
|
||||
for (o, &vi) in output.iter_mut().zip(v.iter()) {
|
||||
*o += w * vi;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// Get relative position for graph distance
/// Converts graph hop distance to position index
///
/// Distances up to 8 map directly to themselves; larger distances are
/// bucketized logarithmically and clamped so the result stays within the
/// position budget implied by `max_distance`.
pub fn distance_to_position(distance: usize, max_distance: usize) -> usize {
    if distance <= 8 {
        distance
    } else {
        // Bucketize larger distances logarithmically so huge graphs do not
        // need a position table proportional to their diameter.
        let log_dist = (distance as f32).log2().ceil() as usize;
        // saturating_sub avoids the integer-underflow panic the previous
        // `max_distance - 8` hit whenever max_distance < 8.
        8 + log_dist.min(max_distance.saturating_sub(8))
    }
}
|
||||
}
|
||||
|
||||
impl Attention for GraphRoPE {
|
||||
fn compute(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: &[&[f32]],
|
||||
values: &[&[f32]],
|
||||
) -> AttentionResult<Vec<f32>> {
|
||||
// Default: use sequential positions (0, 1, 2, ...)
|
||||
let query_pos = 0;
|
||||
let key_positions: Vec<usize> = (0..keys.len()).collect();
|
||||
self.compute_with_positions(query, keys, values, query_pos, &key_positions)
|
||||
}
|
||||
|
||||
fn compute_with_mask(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: &[&[f32]],
|
||||
values: &[&[f32]],
|
||||
mask: Option<&[bool]>,
|
||||
) -> AttentionResult<Vec<f32>> {
|
||||
if let Some(m) = mask {
|
||||
let filtered: Vec<(usize, bool)> = m
|
||||
.iter()
|
||||
.copied()
|
||||
.enumerate()
|
||||
.filter(|(_, keep)| *keep)
|
||||
.collect();
|
||||
let filtered_keys: Vec<&[f32]> = filtered.iter().map(|(i, _)| keys[*i]).collect();
|
||||
let filtered_values: Vec<&[f32]> = filtered.iter().map(|(i, _)| values[*i]).collect();
|
||||
self.compute(query, &filtered_keys, &filtered_values)
|
||||
} else {
|
||||
self.compute(query, keys, values)
|
||||
}
|
||||
}
|
||||
|
||||
fn dim(&self) -> usize {
|
||||
self.config.dim
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Borrow a list of owned vectors as a list of slices.
    fn as_slices(rows: &[Vec<f32>]) -> Vec<&[f32]> {
        rows.iter().map(Vec::as_slice).collect()
    }

    #[test]
    fn test_rope_basic() {
        let config = RoPEConfig::builder().dim(64).max_position(100).build();
        let rope = GraphRoPE::new(config);

        let query = vec![0.5; 64];
        let keys = vec![vec![0.3; 64]; 10];
        let values = vec![vec![1.0; 64]; 10];

        let result = rope
            .compute(&query, &as_slices(&keys), &as_slices(&values))
            .unwrap();
        assert_eq!(result.len(), 64);
    }

    #[test]
    fn test_rope_with_positions() {
        let config = RoPEConfig::builder().dim(32).max_position(50).build();
        let rope = GraphRoPE::new(config);

        let query = vec![0.5; 32];
        let keys = vec![vec![0.3; 32]; 5];
        let values = vec![vec![1.0; 32]; 5];

        // Positions supplied as graph hop distances rather than indices.
        let key_positions: Vec<usize> = vec![1, 2, 3, 2, 4];

        let result = rope
            .compute_with_positions(
                &query,
                &as_slices(&keys),
                &as_slices(&values),
                0,
                &key_positions,
            )
            .unwrap();
        assert_eq!(result.len(), 32);
    }

    #[test]
    fn test_rotary_embedding() {
        let config = RoPEConfig::builder().dim(16).max_position(10).build();
        let rope = GraphRoPE::new(config);

        // Rotations are norm-preserving, so the L2 norm should survive
        // the embedding up to floating-point error.
        let l2 = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();

        let x = vec![1.0; 16];
        let rotated = rope.apply_rotary(&x, 5);

        assert!((l2(&x) - l2(&rotated)).abs() < 1e-5);
    }

    #[test]
    fn test_distance_to_position() {
        // Small hop counts map directly to position indices.
        for d in [0usize, 5, 8] {
            assert_eq!(GraphRoPE::distance_to_position(d, 20), d);
        }

        // Beyond 8 hops the mapping is logarithmic but still increasing.
        let pos_16 = GraphRoPE::distance_to_position(16, 20);
        let pos_32 = GraphRoPE::distance_to_position(32, 20);
        assert!(pos_16 > 8);
        assert!(pos_32 > pos_16);
    }
}
|
||||
Reference in New Issue
Block a user