wifi-densepose/vendor/ruvector/crates/ruvector-attention-wasm/src/training.rs

use ruvector_attention::training::{Adam, AdamW, InfoNCELoss, Loss, Optimizer, SGD};
use wasm_bindgen::prelude::*;

/// InfoNCE contrastive loss for training
#[wasm_bindgen]
pub struct WasmInfoNCELoss {
    inner: InfoNCELoss,
}

#[wasm_bindgen]
impl WasmInfoNCELoss {
    /// Create a new InfoNCE loss instance
    ///
    /// # Arguments
    /// * `temperature` - Temperature parameter for softmax
    #[wasm_bindgen(constructor)]
    pub fn new(temperature: f32) -> WasmInfoNCELoss {
        Self {
            inner: InfoNCELoss::new(temperature),
        }
    }

    /// Compute InfoNCE loss
    ///
    /// # Arguments
    /// * `anchor` - Anchor embedding
    /// * `positive` - Positive example embedding
    /// * `negatives` - Array of negative example embeddings
    pub fn compute(
        &self,
        anchor: &[f32],
        positive: &[f32],
        negatives: JsValue,
    ) -> Result<f32, JsError> {
        let negatives_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(negatives)?;
        let negatives_refs: Vec<&[f32]> = negatives_vec.iter().map(|n| n.as_slice()).collect();
        Ok(self.inner.compute(anchor, positive, &negatives_refs))
    }
}
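
// Illustrative usage sketch (not part of the vendored file): exercises
// `WasmInfoNCELoss::compute` end to end. It assumes `wasm-bindgen-test` is
// available as a dev-dependency and that tests run under a wasm runner such
// as `wasm-pack test --node`.
#[cfg(test)]
mod infonce_usage_sketch {
    use super::*;
    use wasm_bindgen_test::wasm_bindgen_test;

    #[wasm_bindgen_test]
    fn computes_a_finite_loss() {
        let loss = WasmInfoNCELoss::new(0.07);
        let anchor = [1.0_f32, 0.0, 0.0];
        let positive = [0.9_f32, 0.1, 0.0];
        // Negatives cross the JS boundary as an array of arrays, so they are
        // serialized into a `JsValue` first.
        let negatives = serde_wasm_bindgen::to_value(&vec![
            vec![0.0_f32, 1.0, 0.0],
            vec![0.0_f32, 0.0, 1.0],
        ])
        .unwrap();
        let value = loss.compute(&anchor, &positive, negatives).unwrap();
        assert!(value.is_finite());
    }
}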

/// Adam optimizer
#[wasm_bindgen]
pub struct WasmAdam {
    inner: Adam,
}

#[wasm_bindgen]
impl WasmAdam {
    /// Create a new Adam optimizer
    ///
    /// # Arguments
    /// * `param_count` - Number of parameters
    /// * `learning_rate` - Learning rate
    #[wasm_bindgen(constructor)]
    pub fn new(param_count: usize, learning_rate: f32) -> WasmAdam {
        Self {
            inner: Adam::new(param_count, learning_rate),
        }
    }

    /// Perform optimization step
    ///
    /// # Arguments
    /// * `params` - Current parameter values (will be updated in-place)
    /// * `gradients` - Gradient values
    pub fn step(&mut self, params: &mut [f32], gradients: &[f32]) {
        self.inner.step(params, gradients);
    }

    /// Reset optimizer state
    pub fn reset(&mut self) {
        self.inner.reset();
    }

    /// Get current learning rate
    #[wasm_bindgen(getter)]
    pub fn learning_rate(&self) -> f32 {
        self.inner.learning_rate()
    }

    /// Set learning rate
    #[wasm_bindgen(setter)]
    pub fn set_learning_rate(&mut self, lr: f32) {
        self.inner.set_learning_rate(lr);
    }
}
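
// Illustrative usage sketch (not part of the vendored file): a single Adam
// update on a two-parameter vector. Assumes `wasm-bindgen-test` is available
// as a dev-dependency and the tests run under a wasm runner.
#[cfg(test)]
mod adam_usage_sketch {
    use super::*;
    use wasm_bindgen_test::wasm_bindgen_test;

    #[wasm_bindgen_test]
    fn step_updates_parameters_in_place() {
        let mut optimizer = WasmAdam::new(2, 0.01);
        let mut params = [1.0_f32, -1.0];
        // Gradients of the simple quadratic loss 0.5 * (p0^2 + p1^2) at `params`.
        let gradients = [params[0], params[1]];
        optimizer.step(&mut params, &gradients);
        // A non-zero gradient should move the parameters and keep them finite.
        assert!(params.iter().all(|p| p.is_finite()));
        assert!(params != [1.0_f32, -1.0]);
    }
}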

/// AdamW optimizer (Adam with decoupled weight decay)
#[wasm_bindgen]
pub struct WasmAdamW {
    inner: AdamW,
    wd: f32,
}

#[wasm_bindgen]
impl WasmAdamW {
    /// Create a new AdamW optimizer
    ///
    /// # Arguments
    /// * `param_count` - Number of parameters
    /// * `learning_rate` - Learning rate
    /// * `weight_decay` - Weight decay coefficient
    #[wasm_bindgen(constructor)]
    pub fn new(param_count: usize, learning_rate: f32, weight_decay: f32) -> WasmAdamW {
        let optimizer = AdamW::new(param_count, learning_rate).with_weight_decay(weight_decay);
        Self {
            inner: optimizer,
            wd: weight_decay,
        }
    }

    /// Perform optimization step with weight decay
    pub fn step(&mut self, params: &mut [f32], gradients: &[f32]) {
        self.inner.step(params, gradients);
    }

    /// Reset optimizer state
    pub fn reset(&mut self) {
        self.inner.reset();
    }

    /// Get current learning rate
    #[wasm_bindgen(getter)]
    pub fn learning_rate(&self) -> f32 {
        self.inner.learning_rate()
    }

    /// Set learning rate
    #[wasm_bindgen(setter)]
    pub fn set_learning_rate(&mut self, lr: f32) {
        self.inner.set_learning_rate(lr);
    }

    /// Get weight decay
    #[wasm_bindgen(getter)]
    pub fn weight_decay(&self) -> f32 {
        self.wd
    }
}

/// SGD optimizer with momentum
#[wasm_bindgen]
pub struct WasmSGD {
    inner: SGD,
}

#[wasm_bindgen]
impl WasmSGD {
    /// Create a new SGD optimizer
    ///
    /// # Arguments
    /// * `param_count` - Number of parameters
    /// * `learning_rate` - Learning rate
    /// * `momentum` - Momentum coefficient (default: 0)
    #[wasm_bindgen(constructor)]
    pub fn new(param_count: usize, learning_rate: f32, momentum: Option<f32>) -> WasmSGD {
        let mut optimizer = SGD::new(param_count, learning_rate);
        if let Some(m) = momentum {
            optimizer = optimizer.with_momentum(m);
        }
        Self { inner: optimizer }
    }

    /// Perform optimization step
    pub fn step(&mut self, params: &mut [f32], gradients: &[f32]) {
        self.inner.step(params, gradients);
    }

    /// Reset optimizer state
    pub fn reset(&mut self) {
        self.inner.reset();
    }

    /// Get current learning rate
    #[wasm_bindgen(getter)]
    pub fn learning_rate(&self) -> f32 {
        self.inner.learning_rate()
    }

    /// Set learning rate
    #[wasm_bindgen(setter)]
    pub fn set_learning_rate(&mut self, lr: f32) {
        self.inner.set_learning_rate(lr);
    }
}
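
// Illustrative usage sketch (not part of the vendored file): plain SGD versus
// SGD with momentum via the optional `momentum` argument. Assumes
// `wasm-bindgen-test` is available as a dev-dependency.
#[cfg(test)]
mod sgd_usage_sketch {
    use super::*;
    use wasm_bindgen_test::wasm_bindgen_test;

    #[wasm_bindgen_test]
    fn optional_momentum_is_accepted() {
        // `None` falls back to vanilla SGD; `Some(0.9)` enables momentum.
        let mut plain = WasmSGD::new(1, 0.1, None);
        let mut heavy_ball = WasmSGD::new(1, 0.1, Some(0.9));

        let mut a = [1.0_f32];
        let mut b = [1.0_f32];
        for _ in 0..3 {
            // Gradient of 0.5 * p^2 is p itself.
            let ga = [a[0]];
            let gb = [b[0]];
            plain.step(&mut a, &ga);
            heavy_ball.step(&mut b, &gb);
        }
        assert!(a[0].is_finite() && b[0].is_finite());
    }
}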

/// Learning rate scheduler
#[wasm_bindgen]
pub struct WasmLRScheduler {
    initial_lr: f32,
    current_step: usize,
    warmup_steps: usize,
    total_steps: usize,
}

#[wasm_bindgen]
impl WasmLRScheduler {
    /// Create a new learning rate scheduler with warmup and cosine decay
    ///
    /// # Arguments
    /// * `initial_lr` - Initial learning rate
    /// * `warmup_steps` - Number of warmup steps
    /// * `total_steps` - Total training steps
    #[wasm_bindgen(constructor)]
    pub fn new(initial_lr: f32, warmup_steps: usize, total_steps: usize) -> WasmLRScheduler {
        Self {
            initial_lr,
            current_step: 0,
            warmup_steps,
            total_steps,
        }
    }
pub fn get_lr(&self) -> f32 {
if self.current_step < self.warmup_steps {
// Linear warmup
self.initial_lr * (self.current_step as f32 / self.warmup_steps as f32)
} else {
// Cosine decay
let progress = (self.current_step - self.warmup_steps) as f32
/ (self.total_steps - self.warmup_steps) as f32;
let cosine = 0.5 * (1.0 + (std::f32::consts::PI * progress).cos());
self.initial_lr * cosine
}
}

    /// Advance to next step
    pub fn step(&mut self) {
        self.current_step += 1;
    }

    /// Reset scheduler
    pub fn reset(&mut self) {
        self.current_step = 0;
    }
}
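
// Illustrative usage sketch (not part of the vendored file): driving an
// optimizer's learning rate from the warmup/cosine schedule. Assumes
// `wasm-bindgen-test` is available as a dev-dependency; the constants below
// are arbitrary.
#[cfg(test)]
mod scheduler_usage_sketch {
    use super::*;
    use wasm_bindgen_test::wasm_bindgen_test;

    #[wasm_bindgen_test]
    fn warmup_then_cosine_decay_drives_the_optimizer() {
        let mut scheduler = WasmLRScheduler::new(0.001, 10, 100);
        let mut optimizer = WasmAdamW::new(2, 0.001, 0.01);
        let mut params = [0.5_f32, -0.5];

        let mut previous_lr = scheduler.get_lr();
        for step in 0..100 {
            let lr = scheduler.get_lr();
            optimizer.set_learning_rate(lr);

            // Gradients of 0.5 * (p0^2 + p1^2).
            let gradients = [params[0], params[1]];
            optimizer.step(&mut params, &gradients);
            scheduler.step();

            if step < 10 {
                // Linear warmup: the rate should not decrease.
                assert!(lr >= previous_lr);
            }
            previous_lr = lr;
        }
        // After warmup the cosine schedule decays back toward zero.
        assert!(scheduler.get_lr() <= 0.001);
    }
}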