Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
308
vendor/ruvector/crates/ruvector-attention-wasm/src/attention.rs
vendored
Normal file
308
vendor/ruvector/crates/ruvector-attention-wasm/src/attention.rs
vendored
Normal file
@@ -0,0 +1,308 @@
|
||||
use ruvector_attention::{
|
||||
attention::{MultiHeadAttention, ScaledDotProductAttention},
|
||||
hyperbolic::{HyperbolicAttention, HyperbolicAttentionConfig},
|
||||
moe::{MoEAttention, MoEConfig},
|
||||
sparse::{FlashAttention, LinearAttention, LocalGlobalAttention},
|
||||
traits::Attention,
|
||||
};
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
/// Compute scaled dot-product attention
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `query` - Query vector as Float32Array
|
||||
/// * `keys` - Array of key vectors
|
||||
/// * `values` - Array of value vectors
|
||||
/// * `scale` - Optional scaling factor (defaults to 1/sqrt(dim))
|
||||
#[wasm_bindgen]
|
||||
pub fn scaled_dot_attention(
|
||||
query: &[f32],
|
||||
keys: JsValue,
|
||||
values: JsValue,
|
||||
scale: Option<f32>,
|
||||
) -> Result<Vec<f32>, JsError> {
|
||||
let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)
|
||||
.map_err(|e| JsError::new(&format!("Failed to parse keys: {}", e)))?;
|
||||
let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)
|
||||
.map_err(|e| JsError::new(&format!("Failed to parse values: {}", e)))?;
|
||||
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
let attention = ScaledDotProductAttention::new(query.len());
|
||||
attention
|
||||
.compute(query, &keys_refs, &values_refs)
|
||||
.map_err(|e| JsError::new(&e.to_string()))
|
||||
}
|
||||
|
||||
/// Multi-head attention mechanism
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmMultiHeadAttention {
|
||||
inner: MultiHeadAttention,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmMultiHeadAttention {
|
||||
/// Create a new multi-head attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `dim` - Embedding dimension
|
||||
/// * `num_heads` - Number of attention heads
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(dim: usize, num_heads: usize) -> Result<WasmMultiHeadAttention, JsError> {
|
||||
if dim % num_heads != 0 {
|
||||
return Err(JsError::new(&format!(
|
||||
"Dimension {} must be divisible by number of heads {}",
|
||||
dim, num_heads
|
||||
)));
|
||||
}
|
||||
Ok(Self {
|
||||
inner: MultiHeadAttention::new(dim, num_heads),
|
||||
})
|
||||
}
|
||||
|
||||
/// Compute multi-head attention
|
||||
pub fn compute(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: JsValue,
|
||||
values: JsValue,
|
||||
) -> Result<Vec<f32>, JsError> {
|
||||
let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
|
||||
let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
|
||||
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
self.inner
|
||||
.compute(query, &keys_refs, &values_refs)
|
||||
.map_err(|e| JsError::new(&e.to_string()))
|
||||
}
|
||||
|
||||
/// Get the number of heads
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn num_heads(&self) -> usize {
|
||||
self.inner.num_heads()
|
||||
}
|
||||
|
||||
/// Get the dimension
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn dim(&self) -> usize {
|
||||
self.inner.dim()
|
||||
}
|
||||
}
|
||||
|
||||
/// Hyperbolic attention mechanism
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmHyperbolicAttention {
|
||||
inner: HyperbolicAttention,
|
||||
curvature_value: f32,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmHyperbolicAttention {
|
||||
/// Create a new hyperbolic attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `dim` - Embedding dimension
|
||||
/// * `curvature` - Hyperbolic curvature parameter
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(dim: usize, curvature: f32) -> WasmHyperbolicAttention {
|
||||
let config = HyperbolicAttentionConfig {
|
||||
dim,
|
||||
curvature,
|
||||
..Default::default()
|
||||
};
|
||||
Self {
|
||||
inner: HyperbolicAttention::new(config),
|
||||
curvature_value: curvature,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute hyperbolic attention
|
||||
pub fn compute(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: JsValue,
|
||||
values: JsValue,
|
||||
) -> Result<Vec<f32>, JsError> {
|
||||
let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
|
||||
let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
|
||||
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
self.inner
|
||||
.compute(query, &keys_refs, &values_refs)
|
||||
.map_err(|e| JsError::new(&e.to_string()))
|
||||
}
|
||||
|
||||
/// Get the curvature
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn curvature(&self) -> f32 {
|
||||
self.curvature_value
|
||||
}
|
||||
}
|
||||
|
||||
/// Linear attention (Performer-style)
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmLinearAttention {
|
||||
inner: LinearAttention,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmLinearAttention {
|
||||
/// Create a new linear attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `dim` - Embedding dimension
|
||||
/// * `num_features` - Number of random features
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(dim: usize, num_features: usize) -> WasmLinearAttention {
|
||||
Self {
|
||||
inner: LinearAttention::new(dim, num_features),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute linear attention
|
||||
pub fn compute(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: JsValue,
|
||||
values: JsValue,
|
||||
) -> Result<Vec<f32>, JsError> {
|
||||
let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
|
||||
let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
|
||||
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
self.inner
|
||||
.compute(query, &keys_refs, &values_refs)
|
||||
.map_err(|e| JsError::new(&e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Flash attention mechanism
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmFlashAttention {
|
||||
inner: FlashAttention,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmFlashAttention {
|
||||
/// Create a new flash attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `dim` - Embedding dimension
|
||||
/// * `block_size` - Block size for tiling
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(dim: usize, block_size: usize) -> WasmFlashAttention {
|
||||
Self {
|
||||
inner: FlashAttention::new(dim, block_size),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute flash attention
|
||||
pub fn compute(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: JsValue,
|
||||
values: JsValue,
|
||||
) -> Result<Vec<f32>, JsError> {
|
||||
let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
|
||||
let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
|
||||
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
self.inner
|
||||
.compute(query, &keys_refs, &values_refs)
|
||||
.map_err(|e| JsError::new(&e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Local-global attention mechanism
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmLocalGlobalAttention {
|
||||
inner: LocalGlobalAttention,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmLocalGlobalAttention {
|
||||
/// Create a new local-global attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `dim` - Embedding dimension
|
||||
/// * `local_window` - Size of local attention window
|
||||
/// * `global_tokens` - Number of global attention tokens
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(dim: usize, local_window: usize, global_tokens: usize) -> WasmLocalGlobalAttention {
|
||||
Self {
|
||||
inner: LocalGlobalAttention::new(dim, local_window, global_tokens),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute local-global attention
|
||||
pub fn compute(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: JsValue,
|
||||
values: JsValue,
|
||||
) -> Result<Vec<f32>, JsError> {
|
||||
let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
|
||||
let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
|
||||
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
self.inner
|
||||
.compute(query, &keys_refs, &values_refs)
|
||||
.map_err(|e| JsError::new(&e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Mixture of Experts (MoE) attention
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmMoEAttention {
|
||||
inner: MoEAttention,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmMoEAttention {
|
||||
/// Create a new MoE attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `dim` - Embedding dimension
|
||||
/// * `num_experts` - Number of expert attention mechanisms
|
||||
/// * `top_k` - Number of experts to use per query
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(dim: usize, num_experts: usize, top_k: usize) -> WasmMoEAttention {
|
||||
let config = MoEConfig::builder()
|
||||
.dim(dim)
|
||||
.num_experts(num_experts)
|
||||
.top_k(top_k)
|
||||
.build();
|
||||
Self {
|
||||
inner: MoEAttention::new(config),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute MoE attention
|
||||
pub fn compute(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: JsValue,
|
||||
values: JsValue,
|
||||
) -> Result<Vec<f32>, JsError> {
|
||||
let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
|
||||
let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
|
||||
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
self.inner
|
||||
.compute(query, &keys_refs, &values_refs)
|
||||
.map_err(|e| JsError::new(&e.to_string()))
|
||||
}
|
||||
}
|
||||
33
vendor/ruvector/crates/ruvector-attention-wasm/src/lib.rs
vendored
Normal file
33
vendor/ruvector/crates/ruvector-attention-wasm/src/lib.rs
vendored
Normal file
@@ -0,0 +1,33 @@
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
pub mod attention;
|
||||
pub mod training;
|
||||
pub mod utils;
|
||||
|
||||
/// Initialize the WASM module with panic hook
///
/// Runs automatically on module instantiation (`wasm_bindgen(start)`).
/// When the `console_error_panic_hook` feature is enabled, Rust panics are
/// forwarded to the browser's `console.error` instead of being swallowed.
#[wasm_bindgen(start)]
pub fn init() {
    // set_once is idempotent, so repeated instantiation is safe.
    #[cfg(feature = "console_error_panic_hook")]
    console_error_panic_hook::set_once();
}
|
||||
|
||||
/// Get the version of the ruvector-attention-wasm crate
|
||||
#[wasm_bindgen]
|
||||
pub fn version() -> String {
|
||||
env!("CARGO_PKG_VERSION").to_string()
|
||||
}
|
||||
|
||||
/// Get information about available attention mechanisms
|
||||
#[wasm_bindgen]
|
||||
pub fn available_mechanisms() -> JsValue {
|
||||
let mechanisms = vec![
|
||||
"scaled_dot_product",
|
||||
"multi_head",
|
||||
"hyperbolic",
|
||||
"linear",
|
||||
"flash",
|
||||
"local_global",
|
||||
"moe",
|
||||
];
|
||||
serde_wasm_bindgen::to_value(&mechanisms).unwrap()
|
||||
}
|
||||
238
vendor/ruvector/crates/ruvector-attention-wasm/src/training.rs
vendored
Normal file
238
vendor/ruvector/crates/ruvector-attention-wasm/src/training.rs
vendored
Normal file
@@ -0,0 +1,238 @@
|
||||
use ruvector_attention::training::{Adam, AdamW, InfoNCELoss, Loss, Optimizer, SGD};
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
/// InfoNCE contrastive loss for training
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmInfoNCELoss {
|
||||
inner: InfoNCELoss,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmInfoNCELoss {
|
||||
/// Create a new InfoNCE loss instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `temperature` - Temperature parameter for softmax
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(temperature: f32) -> WasmInfoNCELoss {
|
||||
Self {
|
||||
inner: InfoNCELoss::new(temperature),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute InfoNCE loss
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `anchor` - Anchor embedding
|
||||
/// * `positive` - Positive example embedding
|
||||
/// * `negatives` - Array of negative example embeddings
|
||||
pub fn compute(
|
||||
&self,
|
||||
anchor: &[f32],
|
||||
positive: &[f32],
|
||||
negatives: JsValue,
|
||||
) -> Result<f32, JsError> {
|
||||
let negatives_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(negatives)?;
|
||||
let negatives_refs: Vec<&[f32]> = negatives_vec.iter().map(|n| n.as_slice()).collect();
|
||||
|
||||
Ok(self.inner.compute(anchor, positive, &negatives_refs))
|
||||
}
|
||||
}
|
||||
|
||||
/// Adam optimizer
///
/// Thin WASM wrapper delegating all behavior to the inner
/// `ruvector_attention::training::Adam`.
#[wasm_bindgen]
pub struct WasmAdam {
    inner: Adam,
}

#[wasm_bindgen]
impl WasmAdam {
    /// Create a new Adam optimizer
    ///
    /// # Arguments
    /// * `param_count` - Number of parameters
    /// * `learning_rate` - Learning rate
    #[wasm_bindgen(constructor)]
    pub fn new(param_count: usize, learning_rate: f32) -> WasmAdam {
        Self {
            inner: Adam::new(param_count, learning_rate),
        }
    }

    /// Perform optimization step
    ///
    /// # Arguments
    /// * `params` - Current parameter values (will be updated in-place)
    /// * `gradients` - Gradient values
    // NOTE(review): length mismatches between `params`, `gradients`, and
    // `param_count` are handled (or not) by the inner optimizer — confirm
    // its contract before relying on this from JS.
    pub fn step(&mut self, params: &mut [f32], gradients: &[f32]) {
        self.inner.step(params, gradients);
    }

    /// Reset optimizer state
    pub fn reset(&mut self) {
        self.inner.reset();
    }

    /// Get current learning rate
    #[wasm_bindgen(getter)]
    pub fn learning_rate(&self) -> f32 {
        self.inner.learning_rate()
    }

    /// Set learning rate
    #[wasm_bindgen(setter)]
    pub fn set_learning_rate(&mut self, lr: f32) {
        self.inner.set_learning_rate(lr);
    }
}
|
||||
|
||||
/// AdamW optimizer (Adam with decoupled weight decay)
#[wasm_bindgen]
pub struct WasmAdamW {
    inner: AdamW,
    // Weight decay is cached here for the getter; presumably the inner
    // `AdamW` exposes no accessor for it — confirm against the crate API.
    wd: f32,
}

#[wasm_bindgen]
impl WasmAdamW {
    /// Create a new AdamW optimizer
    ///
    /// # Arguments
    /// * `param_count` - Number of parameters
    /// * `learning_rate` - Learning rate
    /// * `weight_decay` - Weight decay coefficient
    #[wasm_bindgen(constructor)]
    pub fn new(param_count: usize, learning_rate: f32, weight_decay: f32) -> WasmAdamW {
        let optimizer = AdamW::new(param_count, learning_rate).with_weight_decay(weight_decay);
        Self {
            inner: optimizer,
            wd: weight_decay,
        }
    }

    /// Perform optimization step with weight decay
    pub fn step(&mut self, params: &mut [f32], gradients: &[f32]) {
        self.inner.step(params, gradients);
    }

    /// Reset optimizer state
    pub fn reset(&mut self) {
        self.inner.reset();
    }

    /// Get current learning rate
    #[wasm_bindgen(getter)]
    pub fn learning_rate(&self) -> f32 {
        self.inner.learning_rate()
    }

    /// Set learning rate
    #[wasm_bindgen(setter)]
    pub fn set_learning_rate(&mut self, lr: f32) {
        self.inner.set_learning_rate(lr);
    }

    /// Get weight decay
    #[wasm_bindgen(getter)]
    pub fn weight_decay(&self) -> f32 {
        self.wd
    }
}
|
||||
|
||||
/// SGD optimizer with momentum
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmSGD {
|
||||
inner: SGD,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmSGD {
|
||||
/// Create a new SGD optimizer
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `param_count` - Number of parameters
|
||||
/// * `learning_rate` - Learning rate
|
||||
/// * `momentum` - Momentum coefficient (default: 0)
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(param_count: usize, learning_rate: f32, momentum: Option<f32>) -> WasmSGD {
|
||||
let mut optimizer = SGD::new(param_count, learning_rate);
|
||||
if let Some(m) = momentum {
|
||||
optimizer = optimizer.with_momentum(m);
|
||||
}
|
||||
Self { inner: optimizer }
|
||||
}
|
||||
|
||||
/// Perform optimization step
|
||||
pub fn step(&mut self, params: &mut [f32], gradients: &[f32]) {
|
||||
self.inner.step(params, gradients);
|
||||
}
|
||||
|
||||
/// Reset optimizer state
|
||||
pub fn reset(&mut self) {
|
||||
self.inner.reset();
|
||||
}
|
||||
|
||||
/// Get current learning rate
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn learning_rate(&self) -> f32 {
|
||||
self.inner.learning_rate()
|
||||
}
|
||||
|
||||
/// Set learning rate
|
||||
#[wasm_bindgen(setter)]
|
||||
pub fn set_learning_rate(&mut self, lr: f32) {
|
||||
self.inner.set_learning_rate(lr);
|
||||
}
|
||||
}
|
||||
|
||||
/// Learning rate scheduler
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmLRScheduler {
|
||||
initial_lr: f32,
|
||||
current_step: usize,
|
||||
warmup_steps: usize,
|
||||
total_steps: usize,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmLRScheduler {
|
||||
/// Create a new learning rate scheduler with warmup and cosine decay
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `initial_lr` - Initial learning rate
|
||||
/// * `warmup_steps` - Number of warmup steps
|
||||
/// * `total_steps` - Total training steps
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(initial_lr: f32, warmup_steps: usize, total_steps: usize) -> WasmLRScheduler {
|
||||
Self {
|
||||
initial_lr,
|
||||
current_step: 0,
|
||||
warmup_steps,
|
||||
total_steps,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get learning rate for current step
|
||||
pub fn get_lr(&self) -> f32 {
|
||||
if self.current_step < self.warmup_steps {
|
||||
// Linear warmup
|
||||
self.initial_lr * (self.current_step as f32 / self.warmup_steps as f32)
|
||||
} else {
|
||||
// Cosine decay
|
||||
let progress = (self.current_step - self.warmup_steps) as f32
|
||||
/ (self.total_steps - self.warmup_steps) as f32;
|
||||
let cosine = 0.5 * (1.0 + (std::f32::consts::PI * progress).cos());
|
||||
self.initial_lr * cosine
|
||||
}
|
||||
}
|
||||
|
||||
/// Advance to next step
|
||||
pub fn step(&mut self) {
|
||||
self.current_step += 1;
|
||||
}
|
||||
|
||||
/// Reset scheduler
|
||||
pub fn reset(&mut self) {
|
||||
self.current_step = 0;
|
||||
}
|
||||
}
|
||||
201
vendor/ruvector/crates/ruvector-attention-wasm/src/utils.rs
vendored
Normal file
201
vendor/ruvector/crates/ruvector-attention-wasm/src/utils.rs
vendored
Normal file
@@ -0,0 +1,201 @@
|
||||
use wasm_bindgen::prelude::*;
|
||||
use web_sys::console;
|
||||
|
||||
/// Log a message to the browser console
///
/// Forwards `message` verbatim to `console.log`.
#[wasm_bindgen]
pub fn log(message: &str) {
    console::log_1(&message.into());
}
|
||||
|
||||
/// Log an error to the browser console
///
/// Forwards `message` verbatim to `console.error`.
#[wasm_bindgen]
pub fn log_error(message: &str) {
    console::error_1(&message.into());
}
|
||||
|
||||
/// Compute cosine similarity between two vectors
|
||||
#[wasm_bindgen]
|
||||
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> Result<f32, JsError> {
|
||||
if a.len() != b.len() {
|
||||
return Err(JsError::new("Vectors must have same length"));
|
||||
}
|
||||
|
||||
let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
|
||||
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
|
||||
if norm_a == 0.0 || norm_b == 0.0 {
|
||||
return Err(JsError::new("Cannot compute similarity for zero vector"));
|
||||
}
|
||||
|
||||
Ok(dot / (norm_a * norm_b))
|
||||
}
|
||||
|
||||
/// Compute L2 norm of a vector
///
/// Returns `sqrt(sum(x_i^2))`; an empty slice yields 0.0.
#[wasm_bindgen]
pub fn l2_norm(vec: &[f32]) -> f32 {
    vec.iter().map(|x| x * x).sum::<f32>().sqrt()
}
|
||||
|
||||
/// Normalize a vector to unit length
|
||||
#[wasm_bindgen]
|
||||
pub fn normalize(vec: &mut [f32]) -> Result<(), JsError> {
|
||||
let norm = l2_norm(vec);
|
||||
if norm == 0.0 {
|
||||
return Err(JsError::new("Cannot normalize zero vector"));
|
||||
}
|
||||
|
||||
for x in vec.iter_mut() {
|
||||
*x /= norm;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Compute softmax of a vector
|
||||
#[wasm_bindgen]
|
||||
pub fn softmax(vec: &mut [f32]) {
|
||||
// Subtract max for numerical stability
|
||||
let max = vec.iter().copied().fold(f32::NEG_INFINITY, f32::max);
|
||||
|
||||
// Compute exp and sum
|
||||
let mut sum = 0.0;
|
||||
for x in vec.iter_mut() {
|
||||
*x = (*x - max).exp();
|
||||
sum += *x;
|
||||
}
|
||||
|
||||
// Normalize
|
||||
for x in vec.iter_mut() {
|
||||
*x /= sum;
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute attention weights from scores
|
||||
#[wasm_bindgen]
|
||||
pub fn attention_weights(scores: &mut [f32], temperature: Option<f32>) {
|
||||
let temp = temperature.unwrap_or(1.0);
|
||||
|
||||
// Scale by temperature
|
||||
for score in scores.iter_mut() {
|
||||
*score /= temp;
|
||||
}
|
||||
|
||||
// Apply softmax
|
||||
softmax(scores);
|
||||
}
|
||||
|
||||
/// Batch normalize vectors
|
||||
#[wasm_bindgen]
|
||||
pub fn batch_normalize(vectors: JsValue, epsilon: Option<f32>) -> Result<Vec<f32>, JsError> {
|
||||
let eps = epsilon.unwrap_or(1e-8);
|
||||
let mut vecs: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(vectors)?;
|
||||
|
||||
if vecs.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let dim = vecs[0].len();
|
||||
let batch_size = vecs.len();
|
||||
|
||||
// Compute mean
|
||||
let mut mean = vec![0.0; dim];
|
||||
for vec in &vecs {
|
||||
for (i, &val) in vec.iter().enumerate() {
|
||||
mean[i] += val;
|
||||
}
|
||||
}
|
||||
for m in &mut mean {
|
||||
*m /= batch_size as f32;
|
||||
}
|
||||
|
||||
// Compute variance
|
||||
let mut variance = vec![0.0; dim];
|
||||
for vec in &vecs {
|
||||
for (i, &val) in vec.iter().enumerate() {
|
||||
let diff = val - mean[i];
|
||||
variance[i] += diff * diff;
|
||||
}
|
||||
}
|
||||
for v in &mut variance {
|
||||
*v /= batch_size as f32;
|
||||
}
|
||||
|
||||
// Normalize
|
||||
for vec in &mut vecs {
|
||||
for (i, val) in vec.iter_mut().enumerate() {
|
||||
*val = (*val - mean[i]) / (variance[i] + eps).sqrt();
|
||||
}
|
||||
}
|
||||
|
||||
Ok(vecs.into_iter().flatten().collect())
|
||||
}
|
||||
|
||||
/// Generate random orthogonal matrix (for initialization)
|
||||
#[wasm_bindgen]
|
||||
pub fn random_orthogonal_matrix(dim: usize) -> Vec<f32> {
|
||||
use js_sys::Math;
|
||||
|
||||
let mut matrix = vec![0.0; dim * dim];
|
||||
|
||||
// Generate random matrix
|
||||
for i in 0..dim {
|
||||
for j in 0..dim {
|
||||
matrix[i * dim + j] = (Math::random() as f32 - 0.5) * 2.0;
|
||||
}
|
||||
}
|
||||
|
||||
// QR decomposition (simplified Gram-Schmidt)
|
||||
for i in 0..dim {
|
||||
// Normalize column i
|
||||
let mut norm = 0.0;
|
||||
for j in 0..dim {
|
||||
let val = matrix[j * dim + i];
|
||||
norm += val * val;
|
||||
}
|
||||
norm = norm.sqrt();
|
||||
|
||||
for j in 0..dim {
|
||||
matrix[j * dim + i] /= norm;
|
||||
}
|
||||
|
||||
// Orthogonalize remaining columns
|
||||
for k in (i + 1)..dim {
|
||||
let mut dot = 0.0;
|
||||
for j in 0..dim {
|
||||
dot += matrix[j * dim + i] * matrix[j * dim + k];
|
||||
}
|
||||
|
||||
for j in 0..dim {
|
||||
matrix[j * dim + k] -= dot * matrix[j * dim + i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
matrix
|
||||
}
|
||||
|
||||
/// Compute pairwise distances between vectors
|
||||
#[wasm_bindgen]
|
||||
pub fn pairwise_distances(vectors: JsValue) -> Result<Vec<f32>, JsError> {
|
||||
let vecs: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(vectors)?;
|
||||
let n = vecs.len();
|
||||
let mut distances = vec![0.0; n * n];
|
||||
|
||||
for i in 0..n {
|
||||
for j in 0..n {
|
||||
if i == j {
|
||||
distances[i * n + j] = 0.0;
|
||||
} else {
|
||||
let mut dist = 0.0;
|
||||
for k in 0..vecs[i].len() {
|
||||
let diff = vecs[i][k] - vecs[j][k];
|
||||
dist += diff * diff;
|
||||
}
|
||||
distances[i * n + j] = dist.sqrt();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(distances)
|
||||
}
|
||||
Reference in New Issue
Block a user