Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
356
vendor/ruvector/crates/ruvector-attention/src/training/curriculum.rs
vendored
Normal file
356
vendor/ruvector/crates/ruvector-attention/src/training/curriculum.rs
vendored
Normal file
@@ -0,0 +1,356 @@
|
||||
//! Curriculum learning for attention training
//!
//! Provides schedulers for progressive training difficulty.

/// Decay type for temperature/parameter annealing.
///
/// Controls how an annealed value moves from its initial setting toward
/// its final setting over the course of training (see `TemperatureAnnealing`).
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub enum DecayType {
    /// Straight-line interpolation from initial to final value (the default).
    #[default]
    Linear,
    /// Exponential decay toward the final value.
    Exponential,
    /// Half-cosine schedule: slow start, slow finish.
    Cosine,
    /// Piecewise-constant drops at a fixed step interval.
    Step,
}
|
||||
|
||||
/// Curriculum learning stage
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct CurriculumStage {
|
||||
pub name: String,
|
||||
pub difficulty: f32, // 0.0 = easy, 1.0 = hard
|
||||
pub duration: usize, // Steps in this stage
|
||||
pub temperature: f32, // Softmax temperature
|
||||
pub negative_count: usize, // Number of negatives
|
||||
}
|
||||
|
||||
impl CurriculumStage {
|
||||
pub fn new(name: &str) -> Self {
|
||||
Self {
|
||||
name: name.to_string(),
|
||||
difficulty: 0.5,
|
||||
duration: 1000,
|
||||
temperature: 1.0,
|
||||
negative_count: 10,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn difficulty(mut self, d: f32) -> Self {
|
||||
self.difficulty = d.clamp(0.0, 1.0);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn duration(mut self, d: usize) -> Self {
|
||||
self.duration = d;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn temperature(mut self, t: f32) -> Self {
|
||||
self.temperature = t.max(0.01);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn negative_count(mut self, n: usize) -> Self {
|
||||
self.negative_count = n.max(1);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Curriculum scheduler for progressive training
|
||||
pub struct CurriculumScheduler {
|
||||
stages: Vec<CurriculumStage>,
|
||||
current_stage: usize,
|
||||
steps_in_stage: usize,
|
||||
total_steps: usize,
|
||||
}
|
||||
|
||||
impl CurriculumScheduler {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
stages: Vec::new(),
|
||||
current_stage: 0,
|
||||
steps_in_stage: 0,
|
||||
total_steps: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a stage to the curriculum
|
||||
pub fn add_stage(mut self, stage: CurriculumStage) -> Self {
|
||||
self.stages.push(stage);
|
||||
self
|
||||
}
|
||||
|
||||
/// Build a default easy-to-hard curriculum
|
||||
pub fn default_curriculum(total_steps: usize) -> Self {
|
||||
let stage_duration = total_steps / 4;
|
||||
|
||||
Self::new()
|
||||
.add_stage(
|
||||
CurriculumStage::new("warm_up")
|
||||
.difficulty(0.1)
|
||||
.duration(stage_duration)
|
||||
.temperature(2.0)
|
||||
.negative_count(5),
|
||||
)
|
||||
.add_stage(
|
||||
CurriculumStage::new("easy")
|
||||
.difficulty(0.3)
|
||||
.duration(stage_duration)
|
||||
.temperature(1.0)
|
||||
.negative_count(10),
|
||||
)
|
||||
.add_stage(
|
||||
CurriculumStage::new("medium")
|
||||
.difficulty(0.6)
|
||||
.duration(stage_duration)
|
||||
.temperature(0.5)
|
||||
.negative_count(20),
|
||||
)
|
||||
.add_stage(
|
||||
CurriculumStage::new("hard")
|
||||
.difficulty(1.0)
|
||||
.duration(stage_duration)
|
||||
.temperature(0.1)
|
||||
.negative_count(50),
|
||||
)
|
||||
}
|
||||
|
||||
/// Get current stage
|
||||
pub fn current_stage(&self) -> Option<&CurriculumStage> {
|
||||
self.stages.get(self.current_stage)
|
||||
}
|
||||
|
||||
/// Advance one step and return current stage
|
||||
pub fn step(&mut self) -> Option<&CurriculumStage> {
|
||||
if self.stages.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
self.steps_in_stage += 1;
|
||||
self.total_steps += 1;
|
||||
|
||||
// Check if we should advance to next stage
|
||||
if let Some(stage) = self.stages.get(self.current_stage) {
|
||||
if self.steps_in_stage >= stage.duration && self.current_stage < self.stages.len() - 1 {
|
||||
self.current_stage += 1;
|
||||
self.steps_in_stage = 0;
|
||||
}
|
||||
}
|
||||
|
||||
self.current_stage()
|
||||
}
|
||||
|
||||
/// Get current difficulty (0.0 to 1.0)
|
||||
pub fn difficulty(&self) -> f32 {
|
||||
self.current_stage().map(|s| s.difficulty).unwrap_or(1.0)
|
||||
}
|
||||
|
||||
/// Get current temperature
|
||||
pub fn temperature(&self) -> f32 {
|
||||
self.current_stage().map(|s| s.temperature).unwrap_or(1.0)
|
||||
}
|
||||
|
||||
/// Get current negative count
|
||||
pub fn negative_count(&self) -> usize {
|
||||
self.current_stage().map(|s| s.negative_count).unwrap_or(10)
|
||||
}
|
||||
|
||||
/// Check if training is complete
|
||||
pub fn is_complete(&self) -> bool {
|
||||
if self.stages.is_empty() {
|
||||
return true;
|
||||
}
|
||||
self.current_stage >= self.stages.len() - 1
|
||||
&& self.steps_in_stage >= self.stages.last().map(|s| s.duration).unwrap_or(0)
|
||||
}
|
||||
|
||||
/// Get progress (0.0 to 1.0)
|
||||
pub fn progress(&self) -> f32 {
|
||||
let total_duration: usize = self.stages.iter().map(|s| s.duration).sum();
|
||||
if total_duration == 0 {
|
||||
return 1.0;
|
||||
}
|
||||
self.total_steps as f32 / total_duration as f32
|
||||
}
|
||||
|
||||
/// Reset curriculum
|
||||
pub fn reset(&mut self) {
|
||||
self.current_stage = 0;
|
||||
self.steps_in_stage = 0;
|
||||
self.total_steps = 0;
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for CurriculumScheduler {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Temperature annealing scheduler
|
||||
pub struct TemperatureAnnealing {
|
||||
initial_temp: f32,
|
||||
final_temp: f32,
|
||||
total_steps: usize,
|
||||
current_step: usize,
|
||||
decay_type: DecayType,
|
||||
step_size: usize, // For step decay
|
||||
}
|
||||
|
||||
impl TemperatureAnnealing {
|
||||
pub fn new(initial: f32, final_temp: f32, steps: usize) -> Self {
|
||||
Self {
|
||||
initial_temp: initial,
|
||||
final_temp: final_temp,
|
||||
total_steps: steps,
|
||||
current_step: 0,
|
||||
decay_type: DecayType::Linear,
|
||||
step_size: steps / 10,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_decay(mut self, decay: DecayType) -> Self {
|
||||
self.decay_type = decay;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_step_size(mut self, size: usize) -> Self {
|
||||
self.step_size = size;
|
||||
self
|
||||
}
|
||||
|
||||
/// Get current temperature and advance
|
||||
pub fn step(&mut self) -> f32 {
|
||||
let temp = self.get_temp();
|
||||
self.current_step += 1;
|
||||
temp
|
||||
}
|
||||
|
||||
/// Get current temperature without advancing
|
||||
pub fn get_temp(&self) -> f32 {
|
||||
if self.current_step >= self.total_steps {
|
||||
return self.final_temp;
|
||||
}
|
||||
|
||||
let progress = self.current_step as f32 / self.total_steps as f32;
|
||||
let range = self.initial_temp - self.final_temp;
|
||||
|
||||
match self.decay_type {
|
||||
DecayType::Linear => self.initial_temp - range * progress,
|
||||
DecayType::Exponential => {
|
||||
let decay_rate =
|
||||
(self.final_temp / self.initial_temp).ln() / self.total_steps as f32;
|
||||
self.initial_temp * (decay_rate * self.current_step as f32).exp()
|
||||
}
|
||||
DecayType::Cosine => {
|
||||
self.final_temp + 0.5 * range * (1.0 + (std::f32::consts::PI * progress).cos())
|
||||
}
|
||||
DecayType::Step => {
|
||||
let num_steps = self.current_step / self.step_size.max(1);
|
||||
let step_decay =
|
||||
range * num_steps as f32 / (self.total_steps / self.step_size.max(1)) as f32;
|
||||
(self.initial_temp - step_decay).max(self.final_temp)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Reset annealing
|
||||
pub fn reset(&mut self) {
|
||||
self.current_step = 0;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Two 10-step stages: after exhausting the first, the scheduler must be
    // on the second with its difficulty.
    #[test]
    fn test_curriculum_stages() {
        let mut curriculum = CurriculumScheduler::new()
            .add_stage(CurriculumStage::new("easy").duration(10).difficulty(0.2))
            .add_stage(CurriculumStage::new("hard").duration(10).difficulty(0.8));

        assert_eq!(curriculum.current_stage().unwrap().name, "easy");
        assert!((curriculum.difficulty() - 0.2).abs() < 1e-5);

        // Progress through first stage
        for _ in 0..10 {
            curriculum.step();
        }

        assert_eq!(curriculum.current_stage().unwrap().name, "hard");
        assert!((curriculum.difficulty() - 0.8).abs() < 1e-5);
    }

    // The built-in curriculum has four stages and completes after exactly
    // `total_steps` steps (400 divides evenly into 4 stages of 100).
    #[test]
    fn test_default_curriculum() {
        let mut curriculum = CurriculumScheduler::default_curriculum(400);

        assert_eq!(curriculum.stages.len(), 4);
        assert_eq!(curriculum.current_stage().unwrap().name, "warm_up");

        // Progress to end
        for _ in 0..400 {
            curriculum.step();
        }

        assert!(curriculum.is_complete());
    }

    // Linear decay: starts near the initial temperature and ends near the
    // final one.
    #[test]
    fn test_temperature_linear() {
        let mut annealing = TemperatureAnnealing::new(1.0, 0.1, 100);

        let temp_start = annealing.step();
        assert!((temp_start - 1.0).abs() < 0.1);

        for _ in 0..99 {
            annealing.step();
        }

        let temp_end = annealing.get_temp();
        assert!((temp_end - 0.1).abs() < 0.1);
    }

    // Cosine decay is symmetric, so the midpoint sits near the middle value.
    #[test]
    fn test_temperature_cosine() {
        let mut annealing = TemperatureAnnealing::new(1.0, 0.0, 100).with_decay(DecayType::Cosine);

        // Halfway should be approximately middle value
        for _ in 0..50 {
            annealing.step();
        }

        let temp_mid = annealing.get_temp();
        assert!(temp_mid > 0.4 && temp_mid < 0.6);
    }

    // Step decay: one full interval must produce at least one drop.
    #[test]
    fn test_temperature_step() {
        let mut annealing = TemperatureAnnealing::new(1.0, 0.0, 100)
            .with_decay(DecayType::Step)
            .with_step_size(25);

        let temp_0 = annealing.get_temp();
        for _ in 0..25 {
            annealing.step();
        }
        let temp_25 = annealing.get_temp();

        // Should have dropped
        assert!(temp_25 < temp_0);
    }

    // Progress is the fraction of total duration consumed so far.
    #[test]
    fn test_curriculum_progress() {
        let mut curriculum = CurriculumScheduler::new()
            .add_stage(CurriculumStage::new("stage1").duration(50))
            .add_stage(CurriculumStage::new("stage2").duration(50));

        assert!((curriculum.progress() - 0.0).abs() < 1e-5);

        for _ in 0..50 {
            curriculum.step();
        }

        assert!((curriculum.progress() - 0.5).abs() < 0.05);
    }
}
|
||||
359
vendor/ruvector/crates/ruvector-attention/src/training/loss.rs
vendored
Normal file
359
vendor/ruvector/crates/ruvector-attention/src/training/loss.rs
vendored
Normal file
@@ -0,0 +1,359 @@
|
||||
//! Loss functions for attention-based learning
//!
//! Includes contrastive losses optimized for representation learning.

/// Reduction method for loss computation.
///
/// Controls how per-negative losses are collapsed into a single scalar.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub enum Reduction {
    /// Average the per-negative losses (the default).
    #[default]
    Mean,
    /// Sum the per-negative losses.
    Sum,
    /// No reduction; implementations return a single representative value.
    None,
}
|
||||
|
||||
/// Loss trait for attention training.
///
/// Implementors score an anchor embedding against one positive and a set
/// of negative embeddings.
pub trait Loss: Send + Sync {
    /// Compute the scalar loss value.
    fn compute(&self, anchor: &[f32], positive: &[f32], negatives: &[&[f32]]) -> f32;

    /// Compute the loss together with its gradient with respect to `anchor`
    /// (one entry per anchor dimension).
    fn compute_with_gradients(
        &self,
        anchor: &[f32],
        positive: &[f32],
        negatives: &[&[f32]],
    ) -> (f32, Vec<f32>);
}
|
||||
|
||||
/// InfoNCE contrastive loss.
///
/// L = -log(exp(sim(a,p)/τ) / Σexp(sim(a,n)/τ))
pub struct InfoNCELoss {
    temperature: f32,
}

impl InfoNCELoss {
    /// Create the loss with softmax temperature τ, floored at 0.01 to
    /// avoid division blow-ups.
    pub fn new(temperature: f32) -> Self {
        let temperature = temperature.max(0.01);
        Self { temperature }
    }

    /// Cosine similarity of two vectors; norms are floored at 1e-8 so a
    /// zero vector never divides by zero.
    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        let mut dot = 0.0f32;
        let mut sq_a = 0.0f32;
        let mut sq_b = 0.0f32;
        for (&x, &y) in a.iter().zip(b.iter()) {
            dot += x * y;
            sq_a += x * x;
            sq_b += y * y;
        }
        dot / (sq_a.sqrt().max(1e-8) * sq_b.sqrt().max(1e-8))
    }
}
|
||||
|
||||
impl Loss for InfoNCELoss {
    /// InfoNCE: -log( exp(sim(a,p)/τ) / Σ_k exp(sim(a,k)/τ) ), where k
    /// ranges over the positive and all negatives.
    fn compute(&self, anchor: &[f32], positive: &[f32], negatives: &[&[f32]]) -> f32 {
        let pos_sim = Self::cosine_similarity(anchor, positive) / self.temperature;

        let neg_sims: Vec<f32> = negatives
            .iter()
            .map(|n| Self::cosine_similarity(anchor, n) / self.temperature)
            .collect();

        // Stable log-sum-exp: shift by the max before exponentiating
        let max_sim = neg_sims
            .iter()
            .copied()
            .chain(std::iter::once(pos_sim))
            .fold(f32::NEG_INFINITY, f32::max);

        let sum_exp: f32 =
            neg_sims.iter().map(|s| (s - max_sim).exp()).sum::<f32>() + (pos_sim - max_sim).exp();

        let log_sum_exp = max_sim + sum_exp.ln();

        // Equals -log softmax(positive); non-negative by construction
        log_sum_exp - pos_sim
    }

    /// Loss plus the analytic gradient with respect to `anchor`.
    ///
    /// Softmax weights over the temperature-scaled similarities drive the
    /// gradient: the positive pulls the anchor (weight p_pos - 1 < 0), each
    /// negative pushes it away (weight p_neg_i > 0).
    fn compute_with_gradients(
        &self,
        anchor: &[f32],
        positive: &[f32],
        negatives: &[&[f32]],
    ) -> (f32, Vec<f32>) {
        let dim = anchor.len();
        let pos_sim = Self::cosine_similarity(anchor, positive) / self.temperature;

        let neg_sims: Vec<f32> = negatives
            .iter()
            .map(|n| Self::cosine_similarity(anchor, n) / self.temperature)
            .collect();

        // Compute softmax weights (max-shifted for numerical stability)
        let max_sim = neg_sims
            .iter()
            .copied()
            .chain(std::iter::once(pos_sim))
            .fold(f32::NEG_INFINITY, f32::max);

        let pos_exp = (pos_sim - max_sim).exp();
        let neg_exps: Vec<f32> = neg_sims.iter().map(|s| (s - max_sim).exp()).collect();
        let total_exp: f32 = pos_exp + neg_exps.iter().sum::<f32>();

        let pos_weight = pos_exp / total_exp;
        let neg_weights: Vec<f32> = neg_exps.iter().map(|e| e / total_exp).collect();

        // Loss value: -log softmax(positive)
        let loss = -(pos_weight.ln());

        // Gradient with respect to anchor
        // ∂L/∂anchor = (p_pos - 1) * ∂sim(a,p)/∂a + Σ p_neg_i * ∂sim(a,n_i)/∂a
        // with ∂sim(a,b)/∂a_i = b_i/(|a||b|) - a_i (a·b)/(|a|³|b|)
        let norm_a: f32 = anchor.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8);
        let norm_p: f32 = positive.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8);

        let mut gradients = vec![0.0f32; dim];

        // Gradient from positive (pulls the anchor toward it)
        let dot_ap: f32 = anchor.iter().zip(positive.iter()).map(|(a, p)| a * p).sum();
        for i in 0..dim {
            let d_sim = (positive[i] / (norm_a * norm_p))
                - (anchor[i] * dot_ap / (norm_a.powi(3) * norm_p));
            gradients[i] += (pos_weight - 1.0) * d_sim / self.temperature;
        }

        // Gradient from negatives (each pushes the anchor away, weighted by
        // its softmax probability)
        for (neg, &weight) in negatives.iter().zip(neg_weights.iter()) {
            let norm_n: f32 = neg.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8);
            let dot_an: f32 = anchor.iter().zip(neg.iter()).map(|(a, n)| a * n).sum();

            for i in 0..dim {
                let d_sim =
                    (neg[i] / (norm_a * norm_n)) - (anchor[i] * dot_an / (norm_a.powi(3) * norm_n));
                gradients[i] += weight * d_sim / self.temperature;
            }
        }

        (loss, gradients)
    }
}
|
||||
|
||||
/// Local contrastive loss for neighborhood preservation
|
||||
pub struct LocalContrastiveLoss {
|
||||
margin: f32,
|
||||
reduction: Reduction,
|
||||
}
|
||||
|
||||
impl LocalContrastiveLoss {
|
||||
pub fn new(margin: f32) -> Self {
|
||||
Self {
|
||||
margin,
|
||||
reduction: Reduction::Mean,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_reduction(mut self, reduction: Reduction) -> Self {
|
||||
self.reduction = reduction;
|
||||
self
|
||||
}
|
||||
|
||||
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
|
||||
a.iter()
|
||||
.zip(b.iter())
|
||||
.map(|(x, y)| (x - y).powi(2))
|
||||
.sum::<f32>()
|
||||
.sqrt()
|
||||
}
|
||||
}
|
||||
|
||||
impl Loss for LocalContrastiveLoss {
    /// Margin loss per negative: max(0, d(a,p) - d(a,n) + margin).
    fn compute(&self, anchor: &[f32], positive: &[f32], negatives: &[&[f32]]) -> f32 {
        let d_pos = Self::euclidean_distance(anchor, positive);

        let losses: Vec<f32> = negatives
            .iter()
            .map(|neg| {
                let d_neg = Self::euclidean_distance(anchor, neg);
                (d_pos - d_neg + self.margin).max(0.0)
            })
            .collect();

        match self.reduction {
            // .max(1) guards the empty-negatives case (0/0)
            Reduction::Mean => losses.iter().sum::<f32>() / losses.len().max(1) as f32,
            Reduction::Sum => losses.iter().sum(),
            // "None" here returns only the first per-negative loss
            Reduction::None => losses.first().copied().unwrap_or(0.0),
        }
    }

    /// Loss plus subgradient with respect to `anchor`.
    ///
    /// Only negatives with a positive margin violation contribute; distance
    /// gradients are skipped when a distance is ~0 (non-differentiable point).
    fn compute_with_gradients(
        &self,
        anchor: &[f32],
        positive: &[f32],
        negatives: &[&[f32]],
    ) -> (f32, Vec<f32>) {
        let dim = anchor.len();
        let d_pos = Self::euclidean_distance(anchor, positive);

        let mut total_loss = 0.0f32;
        let mut gradients = vec![0.0f32; dim];
        let mut active_count = 0; // Negatives that violate the margin

        for neg in negatives.iter() {
            let d_neg = Self::euclidean_distance(anchor, neg);
            let margin_loss = d_pos - d_neg + self.margin;

            if margin_loss > 0.0 {
                total_loss += margin_loss;
                active_count += 1;

                // Gradient: ∂L/∂a = (a - p)/d_pos - (a - n)/d_neg
                for i in 0..dim {
                    if d_pos > 1e-8 {
                        gradients[i] += (anchor[i] - positive[i]) / d_pos;
                    }
                    if d_neg > 1e-8 {
                        gradients[i] -= (anchor[i] - neg[i]) / d_neg;
                    }
                }
            }
        }

        let loss = match self.reduction {
            // NOTE(review): Mean here averages over *active* negatives,
            // while compute() averages over *all* negatives — confirm the
            // asymmetry is intended.
            Reduction::Mean if active_count > 0 => {
                gradients.iter_mut().for_each(|g| *g /= active_count as f32);
                total_loss / active_count as f32
            }
            Reduction::Sum => total_loss,
            // NOTE(review): this arm also catches Reduction::None, which
            // becomes a mean over all negatives, whereas compute() returns
            // only the first per-negative loss — confirm which is intended.
            _ => total_loss / negatives.len().max(1) as f32,
        };

        (loss, gradients)
    }
}
|
||||
|
||||
/// Spectral regularization for smooth representations.
///
/// Penalizes non-uniform variance across embedding dimensions (the
/// variance of the per-dimension variances), scaled by `weight`.
pub struct SpectralRegularization {
    weight: f32,
}

impl SpectralRegularization {
    /// Create the regularizer with the given scaling weight.
    pub fn new(weight: f32) -> Self {
        Self { weight }
    }

    /// Compute the regularization term for a batch of embeddings.
    ///
    /// Returns 0.0 for an empty batch. All embeddings are assumed to share
    /// the dimensionality of the first one.
    pub fn compute_batch(&self, embeddings: &[&[f32]]) -> f32 {
        if embeddings.is_empty() {
            return 0.0;
        }

        let dim = embeddings[0].len();
        let n = embeddings.len() as f32;

        // Per-dimension (population) variance, computed once.
        // The previous version recomputed every mean and variance a second
        // time for the spread calculation — same values, double the work.
        let vars: Vec<f32> = (0..dim)
            .map(|d| {
                let mean: f32 = embeddings.iter().map(|e| e[d]).sum::<f32>() / n;
                embeddings.iter().map(|e| (e[d] - mean).powi(2)).sum::<f32>() / n
            })
            .collect();

        // Encourage uniform variance across dimensions: penalize the spread
        // of the per-dimension variances around their average.
        let avg_var = vars.iter().sum::<f32>() / dim as f32;
        let var_of_var: f32 =
            vars.iter().map(|v| (v - avg_var).powi(2)).sum::<f32>() / dim as f32;

        self.weight * var_of_var
    }
}
|
||||
|
||||
impl Loss for SpectralRegularization {
    /// Regularization over the pooled set {anchor, positive} ∪ negatives.
    fn compute(&self, anchor: &[f32], positive: &[f32], negatives: &[&[f32]]) -> f32 {
        let mut batch: Vec<&[f32]> = vec![anchor, positive];
        batch.extend_from_slice(negatives);
        self.compute_batch(&batch)
    }

    /// Loss only; the gradient is all zeros (this term is typically used
    /// as an auxiliary regularizer).
    fn compute_with_gradients(
        &self,
        anchor: &[f32],
        positive: &[f32],
        negatives: &[&[f32]],
    ) -> (f32, Vec<f32>) {
        let loss = self.compute(anchor, positive, negatives);
        (loss, vec![0.0f32; anchor.len()])
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Orthogonal negatives against a near-identical positive: the loss must
    // be non-negative (InfoNCE is -log softmax, which is >= 0).
    #[test]
    fn test_infonce_loss() {
        let loss = InfoNCELoss::new(0.07);

        let anchor = vec![1.0, 0.0, 0.0];
        let positive = vec![0.9, 0.1, 0.0];
        let negatives: Vec<Vec<f32>> = vec![vec![0.0, 1.0, 0.0], vec![0.0, 0.0, 1.0]];
        let neg_refs: Vec<&[f32]> = negatives.iter().map(|n| n.as_slice()).collect();

        let loss_val = loss.compute(&anchor, &positive, &neg_refs);
        assert!(loss_val >= 0.0);
    }

    // Gradient output must match the anchor dimensionality.
    #[test]
    fn test_infonce_gradients() {
        let loss = InfoNCELoss::new(0.1);

        let anchor = vec![0.5; 64];
        let positive = vec![0.6; 64];
        let negatives: Vec<Vec<f32>> = vec![vec![0.1; 64]; 5];
        let neg_refs: Vec<&[f32]> = negatives.iter().map(|n| n.as_slice()).collect();

        let (loss_val, grads) = loss.compute_with_gradients(&anchor, &positive, &neg_refs);

        assert!(loss_val >= 0.0);
        assert_eq!(grads.len(), 64);
    }

    // Hinge losses are clamped at zero, so the reduced value is never negative.
    #[test]
    fn test_local_contrastive() {
        let loss = LocalContrastiveLoss::new(1.0);

        let anchor = vec![0.0, 0.0];
        let positive = vec![0.1, 0.0]; // Close
        let negatives: Vec<Vec<f32>> = vec![vec![2.0, 0.0], vec![0.0, 2.0]]; // Far
        let neg_refs: Vec<&[f32]> = negatives.iter().map(|n| n.as_slice()).collect();

        let loss_val = loss.compute(&anchor, &positive, &neg_refs);
        assert!(loss_val >= 0.0);
    }

    // The penalty is a (weighted) variance, hence non-negative.
    #[test]
    fn test_spectral_regularization() {
        let reg = SpectralRegularization::new(0.01);

        let embeddings: Vec<Vec<f32>> = (0..10).map(|i| vec![i as f32 * 0.1; 32]).collect();
        let emb_refs: Vec<&[f32]> = embeddings.iter().map(|e| e.as_slice()).collect();

        let loss_val = reg.compute_batch(&emb_refs);
        assert!(loss_val >= 0.0);
    }
}
|
||||
351
vendor/ruvector/crates/ruvector-attention/src/training/mining.rs
vendored
Normal file
351
vendor/ruvector/crates/ruvector-attention/src/training/mining.rs
vendored
Normal file
@@ -0,0 +1,351 @@
|
||||
//! Hard negative mining strategies
//!
//! Provides various methods for selecting informative negative samples.

/// Mining strategy enumeration.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub enum MiningStrategy {
    /// Uniform random selection from the candidate pool (the default).
    #[default]
    Random,
    /// Pick the candidates most similar to the anchor.
    HardNegative,
    /// Pick candidates slightly farther from the anchor than the positive.
    SemiHard,
    /// Sample candidates with probability proportional to
    /// softmax(similarity / temperature).
    DistanceWeighted,
}
|
||||
|
||||
/// Trait for negative sample mining.
pub trait NegativeMiner: Send + Sync {
    /// Mine negatives for an anchor from a candidate pool.
    ///
    /// Returns indices into `candidates`, at most `num_negatives` of them.
    fn mine(
        &self,
        anchor: &[f32],
        positive: &[f32],
        candidates: &[&[f32]],
        num_negatives: usize,
    ) -> Vec<usize>;

    /// Get the mining strategy this miner implements.
    fn strategy(&self) -> MiningStrategy;
}
|
||||
|
||||
/// Negative miner that selects candidates according to the configured
/// [`MiningStrategy`].
pub struct HardNegativeMiner {
    strategy: MiningStrategy,
    margin: f32,      // Width of the semi-hard window above the positive distance
    temperature: f32, // Softmax temperature for distance-weighted sampling
}

impl HardNegativeMiner {
    /// Create a miner with margin 0.1 and temperature 1.0.
    pub fn new(strategy: MiningStrategy) -> Self {
        Self {
            strategy,
            margin: 0.1,
            temperature: 1.0,
        }
    }

    /// Set the semi-hard margin.
    pub fn with_margin(mut self, margin: f32) -> Self {
        self.margin = margin;
        self
    }

    /// Set the sampling temperature for distance-weighted selection.
    pub fn with_temperature(mut self, temp: f32) -> Self {
        self.temperature = temp;
        self
    }

    /// Euclidean (L2) distance between two vectors.
    fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y).powi(2))
            .sum::<f32>()
            .sqrt()
    }

    /// Cosine similarity with norms floored at 1e-8 to avoid division by zero.
    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8);
        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8);
        dot / (norm_a * norm_b)
    }

    /// Select up to `num_select` indices via a seeded Fisher-Yates shuffle.
    ///
    /// The PRNG is a simple linear congruential generator (deterministic,
    /// not cryptographic), so results are reproducible for a given seed.
    fn random_selection(num_candidates: usize, num_select: usize, seed: u64) -> Vec<usize> {
        let mut indices: Vec<usize> = (0..num_candidates).collect();
        let mut current_seed = seed;

        // Fisher-Yates shuffle driven by the LCG
        for i in (1..indices.len()).rev() {
            current_seed = current_seed
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1);
            let j = (current_seed as usize) % (i + 1);
            indices.swap(i, j);
        }

        indices.truncate(num_select.min(num_candidates));
        indices
    }

    /// Select the hardest negatives: the candidates with the highest cosine
    /// similarity to the anchor.
    fn hard_negative_selection(
        &self,
        anchor: &[f32],
        candidates: &[&[f32]],
        num_select: usize,
    ) -> Vec<usize> {
        let mut indexed_sims: Vec<(usize, f32)> = candidates
            .iter()
            .enumerate()
            .map(|(i, c)| (i, Self::cosine_similarity(anchor, c)))
            .collect();

        // Sort by similarity descending (higher sim = harder negative).
        // NaN similarities compare as Equal and keep their relative order.
        indexed_sims.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        indexed_sims
            .into_iter()
            .take(num_select.min(candidates.len()))
            .map(|(i, _)| i)
            .collect()
    }

    /// Select semi-hard negatives: farther than the positive but within
    /// `margin` of it, preferring the closest; tops up with hard negatives
    /// when too few candidates fall in the window.
    fn semi_hard_selection(
        &self,
        anchor: &[f32],
        positive: &[f32],
        candidates: &[&[f32]],
        num_select: usize,
    ) -> Vec<usize> {
        let d_pos = Self::euclidean_distance(anchor, positive);

        let mut semi_hard: Vec<(usize, f32)> = candidates
            .iter()
            .enumerate()
            .filter_map(|(i, c)| {
                let d_neg = Self::euclidean_distance(anchor, c);
                // Semi-hard: d_pos < d_neg < d_pos + margin
                if d_neg > d_pos && d_neg < d_pos + self.margin {
                    Some((i, d_neg))
                } else {
                    None
                }
            })
            .collect();

        // Sort by distance ascending (prefer the harder, closer ones)
        semi_hard.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

        let mut result: Vec<usize> = semi_hard.into_iter().map(|(i, _)| i).collect();

        // If not enough semi-hard candidates, fill with hard negatives,
        // skipping indices already selected
        if result.len() < num_select {
            let hard = self.hard_negative_selection(anchor, candidates, num_select - result.len());
            for idx in hard {
                if !result.contains(&idx) {
                    result.push(idx);
                }
            }
        }

        result.truncate(num_select);
        result
    }

    /// Sample without replacement with probability proportional to
    /// softmax(similarity / temperature) — closer candidates are more likely.
    ///
    /// NOTE(review): the seed is fixed at 42, so the sampled set is
    /// identical on every call for the same inputs — confirm this
    /// determinism is intended rather than a leftover.
    fn distance_weighted_selection(
        &self,
        anchor: &[f32],
        candidates: &[&[f32]],
        num_select: usize,
    ) -> Vec<usize> {
        if candidates.is_empty() {
            return vec![];
        }

        // Compute weights based on similarity (closer = higher weight)
        let sims: Vec<f32> = candidates
            .iter()
            .map(|c| Self::cosine_similarity(anchor, c) / self.temperature)
            .collect();

        // Softmax weights (max-shifted for numerical stability)
        let max_sim = sims.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exp_sims: Vec<f32> = sims.iter().map(|s| (s - max_sim).exp()).collect();
        let sum_exp: f32 = exp_sims.iter().sum();
        let probs: Vec<f32> = exp_sims.iter().map(|e| e / sum_exp).collect();

        // Sample without replacement using the probabilities
        let mut remaining: Vec<(usize, f32)> = probs.into_iter().enumerate().collect();
        let mut selected = Vec::with_capacity(num_select);
        let mut seed = 42u64;

        while selected.len() < num_select && !remaining.is_empty() {
            // Random value in [0, 1) from the same LCG as random_selection
            seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
            let r = (seed as f32) / (u64::MAX as f32);

            // Inverse-CDF draw over the renormalized remaining probabilities;
            // select_idx stays 0 if r lands past the accumulated mass
            let total: f32 = remaining.iter().map(|(_, p)| p).sum();
            let mut cumsum = 0.0;
            let mut select_idx = 0;

            for (i, (_, p)) in remaining.iter().enumerate() {
                cumsum += p / total;
                if r < cumsum {
                    select_idx = i;
                    break;
                }
            }

            let (orig_idx, _) = remaining.remove(select_idx);
            selected.push(orig_idx);
        }

        selected
    }
}
|
||||
|
||||
impl NegativeMiner for HardNegativeMiner {
|
||||
fn mine(
|
||||
&self,
|
||||
anchor: &[f32],
|
||||
positive: &[f32],
|
||||
candidates: &[&[f32]],
|
||||
num_negatives: usize,
|
||||
) -> Vec<usize> {
|
||||
match self.strategy {
|
||||
MiningStrategy::Random => Self::random_selection(candidates.len(), num_negatives, 42),
|
||||
MiningStrategy::HardNegative => {
|
||||
self.hard_negative_selection(anchor, candidates, num_negatives)
|
||||
}
|
||||
MiningStrategy::SemiHard => {
|
||||
self.semi_hard_selection(anchor, positive, candidates, num_negatives)
|
||||
}
|
||||
MiningStrategy::DistanceWeighted => {
|
||||
self.distance_weighted_selection(anchor, candidates, num_negatives)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn strategy(&self) -> MiningStrategy {
|
||||
self.strategy
|
||||
}
|
||||
}
|
||||
|
||||
/// In-batch negative mining: treats the other items of a batch as negatives.
pub struct InBatchMiner {
    exclude_positive: bool, // When true, the positive index is never returned
}

impl InBatchMiner {
    /// Create a miner that excludes both the anchor and the positive index.
    pub fn new() -> Self {
        Self {
            exclude_positive: true,
        }
    }

    /// Allow the positive index to be returned as a negative.
    pub fn include_positive(mut self) -> Self {
        self.exclude_positive = false;
        self
    }

    /// Indices in `0..batch_size` usable as negatives for `anchor_idx`.
    ///
    /// The anchor index is always excluded; the positive index is excluded
    /// unless `include_positive` was called.
    pub fn get_negatives(
        &self,
        anchor_idx: usize,
        positive_idx: usize,
        batch_size: usize,
    ) -> Vec<usize> {
        let mut negatives = Vec::with_capacity(batch_size.saturating_sub(1));
        for i in 0..batch_size {
            let is_anchor = i == anchor_idx;
            let is_excluded_positive = self.exclude_positive && i == positive_idx;
            if !is_anchor && !is_excluded_positive {
                negatives.push(i);
            }
        }
        negatives
    }
}

impl Default for InBatchMiner {
    /// Equivalent to [`InBatchMiner::new`].
    fn default() -> Self {
        InBatchMiner::new()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Random mining with a fixed seed must still return exactly the
    // requested number of indices.
    #[test]
    fn test_random_mining() {
        let miner = HardNegativeMiner::new(MiningStrategy::Random);

        let anchor = vec![1.0, 0.0, 0.0];
        let positive = vec![0.9, 0.1, 0.0];
        let candidates: Vec<Vec<f32>> = (0..20).map(|i| vec![i as f32 * 0.05; 3]).collect();
        let cand_refs: Vec<&[f32]> = candidates.iter().map(|c| c.as_slice()).collect();

        let selected = miner.mine(&anchor, &positive, &cand_refs, 5);
        assert_eq!(selected.len(), 5);
    }

    // Hard mining sorts by similarity, so the candidate nearest the anchor
    // must be among the picks.
    #[test]
    fn test_hard_negative_mining() {
        let miner = HardNegativeMiner::new(MiningStrategy::HardNegative);

        let anchor = vec![1.0, 0.0, 0.0];
        let positive = vec![0.9, 0.1, 0.0];
        // Create candidates with varying similarity to anchor
        let candidates: Vec<Vec<f32>> = vec![
            vec![0.9, 0.1, 0.0], // Similar to anchor
            vec![0.5, 0.5, 0.0], // Medium
            vec![0.0, 1.0, 0.0], // Different
            vec![0.0, 0.0, 1.0], // Different
        ];
        let cand_refs: Vec<&[f32]> = candidates.iter().map(|c| c.as_slice()).collect();

        let selected = miner.mine(&anchor, &positive, &cand_refs, 2);

        // Should select the most similar ones first
        assert!(selected.contains(&0)); // Most similar
    }

    // Semi-hard window is (d_pos, d_pos + margin); with margin 1.0 and
    // d_pos 0.5, only distances in (0.5, 1.5) qualify directly.
    #[test]
    fn test_semi_hard_mining() {
        let miner = HardNegativeMiner::new(MiningStrategy::SemiHard).with_margin(1.0);

        let anchor = vec![0.0, 0.0];
        let positive = vec![0.5, 0.0]; // Distance 0.5
        let candidates: Vec<Vec<f32>> = vec![
            vec![0.3, 0.0], // Too easy (d = 0.3 < 0.5)
            vec![0.7, 0.0], // Semi-hard (0.5 < 0.7 < 1.5)
            vec![1.0, 0.0], // Semi-hard
            vec![3.0, 0.0], // Too hard (d = 3.0 > 1.5)
        ];
        let cand_refs: Vec<&[f32]> = candidates.iter().map(|c| c.as_slice()).collect();

        let selected = miner.mine(&anchor, &positive, &cand_refs, 2);
        assert!(!selected.is_empty());
    }

    // Distance-weighted sampling draws without replacement, so it returns
    // exactly the requested count when enough candidates exist.
    #[test]
    fn test_distance_weighted() {
        let miner = HardNegativeMiner::new(MiningStrategy::DistanceWeighted).with_temperature(0.5);

        let anchor = vec![1.0, 0.0];
        let positive = vec![0.9, 0.1];
        let candidates: Vec<Vec<f32>> = (0..10).map(|i| vec![0.1 * i as f32; 2]).collect();
        let cand_refs: Vec<&[f32]> = candidates.iter().map(|c| c.as_slice()).collect();

        let selected = miner.mine(&anchor, &positive, &cand_refs, 3);
        assert_eq!(selected.len(), 3);
    }

    // Default in-batch mining excludes both anchor and positive indices.
    #[test]
    fn test_in_batch_miner() {
        let miner = InBatchMiner::new();

        let negatives = miner.get_negatives(2, 5, 10);

        assert!(!negatives.contains(&2)); // Exclude anchor
        assert!(!negatives.contains(&5)); // Exclude positive
        assert_eq!(negatives.len(), 8);
    }
}
|
||||
42
vendor/ruvector/crates/ruvector-attention/src/training/mod.rs
vendored
Normal file
42
vendor/ruvector/crates/ruvector-attention/src/training/mod.rs
vendored
Normal file
@@ -0,0 +1,42 @@
|
||||
//! Training utilities for attention-based graph neural networks
//!
//! This module provides training infrastructure including:
//! - Loss functions (InfoNCE, contrastive, spectral regularization)
//! - Optimizers (SGD, Adam, AdamW)
//! - Curriculum learning schedulers
//! - Hard negative mining strategies

// Submodules: one file per training concern.
pub mod curriculum;
pub mod loss;
pub mod mining;
pub mod optimizer;

// Re-export the primary types so callers can write `training::Adam`,
// `training::InfoNCELoss`, etc. without naming the submodule.
pub use curriculum::{CurriculumScheduler, CurriculumStage, DecayType, TemperatureAnnealing};
pub use loss::{InfoNCELoss, LocalContrastiveLoss, Loss, Reduction, SpectralRegularization};
pub use mining::{HardNegativeMiner, MiningStrategy, NegativeMiner};
pub use optimizer::{Adam, AdamW, Optimizer, SGD};
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_training_components_integration() {
        // Wire an optimizer and a loss together end to end.
        let mut opt = Adam::new(128, 0.001);
        let loss = InfoNCELoss::new(0.07);

        let mut weights = vec![0.5; 128];
        let anchor = vec![1.0; 128];
        let positive = vec![0.9; 128];
        let negatives: Vec<Vec<f32>> = (0..5).map(|_| vec![0.1; 128]).collect();
        let neg_refs: Vec<&[f32]> = negatives.iter().map(|v| v.as_slice()).collect();

        let (value, grads) = loss.compute_with_gradients(&anchor, &positive, &neg_refs);

        // InfoNCE is non-negative and yields one gradient per anchor dim.
        assert!(value >= 0.0);
        assert_eq!(grads.len(), anchor.len());

        // The produced gradients must be consumable by the optimizer.
        opt.step(&mut weights, &grads);
    }
}
|
||||
400
vendor/ruvector/crates/ruvector-attention/src/training/optimizer.rs
vendored
Normal file
400
vendor/ruvector/crates/ruvector-attention/src/training/optimizer.rs
vendored
Normal file
@@ -0,0 +1,400 @@
|
||||
//! Optimizers for attention training
//!
//! Provides standard optimizers with momentum and adaptive learning rates.

/// Optimizer trait for parameter updates.
///
/// Implementations hold per-parameter state (momentum buffers, moment
/// estimates) and mutate `params` in place on each [`step`](Optimizer::step).
/// The `Send + Sync` bound lets an optimizer be moved or shared across
/// training threads.
pub trait Optimizer: Send + Sync {
    /// Update parameters in place using gradients.
    ///
    /// `params` and `gradients` are expected to be the same length;
    /// implementations index `gradients` by parameter position.
    fn step(&mut self, params: &mut [f32], gradients: &[f32]);

    /// Reset optimizer state (momentum/moment buffers, timesteps).
    fn reset(&mut self);

    /// Get current learning rate
    fn learning_rate(&self) -> f32;

    /// Set learning rate
    fn set_learning_rate(&mut self, lr: f32);
}
|
||||
|
||||
/// Stochastic Gradient Descent with optional momentum, L2 weight decay,
/// and Nesterov acceleration.
pub struct SGD {
    /// Step size.
    lr: f32,
    /// Momentum coefficient (0.0 disables momentum).
    momentum: f32,
    /// L2 penalty coefficient (0.0 disables weight decay).
    weight_decay: f32,
    /// Per-parameter momentum buffer; lazily resized on `step`.
    velocity: Vec<f32>,
    /// Whether to use Nesterov-style lookahead updates.
    nesterov: bool,
}

impl SGD {
    /// Create a plain SGD optimizer: no momentum, no weight decay,
    /// no Nesterov acceleration.
    pub fn new(dim: usize, lr: f32) -> Self {
        Self {
            velocity: vec![0.0; dim],
            momentum: 0.0,
            weight_decay: 0.0,
            nesterov: false,
            lr,
        }
    }

    /// Enable momentum with the given coefficient.
    pub fn with_momentum(mut self, coefficient: f32) -> Self {
        self.momentum = coefficient;
        self
    }

    /// Enable L2 weight decay with the given coefficient.
    pub fn with_weight_decay(mut self, decay: f32) -> Self {
        self.weight_decay = decay;
        self
    }

    /// Toggle Nesterov accelerated updates.
    pub fn with_nesterov(mut self, enabled: bool) -> Self {
        self.nesterov = enabled;
        self
    }
}
|
||||
|
||||
impl Optimizer for SGD {
    fn step(&mut self, params: &mut [f32], gradients: &[f32]) {
        // Lazily (re)allocate the momentum buffer if the parameter
        // count changed since construction.
        if self.velocity.len() != params.len() {
            self.velocity = vec![0.0; params.len()];
        }

        for idx in 0..params.len() {
            // Fold L2 weight decay into the raw gradient.
            let grad = if self.weight_decay > 0.0 {
                gradients[idx] + self.weight_decay * params[idx]
            } else {
                gradients[idx]
            };

            // Exponentially accumulate the velocity.
            let vel = self.momentum * self.velocity[idx] + grad;
            self.velocity[idx] = vel;

            // Nesterov looks one momentum step ahead; vanilla SGD
            // follows the velocity directly.
            let delta = if self.nesterov {
                grad + self.momentum * vel
            } else {
                vel
            };
            params[idx] -= self.lr * delta;
        }
    }

    fn reset(&mut self) {
        self.velocity.fill(0.0);
    }

    fn learning_rate(&self) -> f32 {
        self.lr
    }

    fn set_learning_rate(&mut self, lr: f32) {
        self.lr = lr;
    }
}
|
||||
|
||||
/// Adam optimizer with bias-corrected first/second moment estimates.
pub struct Adam {
    /// Step size.
    lr: f32,
    /// EMA decay rate for the first moment.
    beta1: f32,
    /// EMA decay rate for the second moment.
    beta2: f32,
    /// Fuzz term added to the denominator for numerical stability.
    epsilon: f32,
    /// L2 penalty folded into the parameter update (0.0 disables it).
    weight_decay: f32,
    /// First moment: EMA of gradients.
    m: Vec<f32>,
    /// Second moment: EMA of squared gradients.
    v: Vec<f32>,
    /// Update count; drives bias correction.
    t: usize,
}

impl Adam {
    /// Create an Adam optimizer with the standard defaults
    /// (beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8, no weight decay).
    pub fn new(dim: usize, lr: f32) -> Self {
        Self {
            m: vec![0.0; dim],
            v: vec![0.0; dim],
            t: 0,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            weight_decay: 0.0,
            lr,
        }
    }

    /// Override the moment decay rates.
    pub fn with_betas(mut self, beta1: f32, beta2: f32) -> Self {
        self.beta1 = beta1;
        self.beta2 = beta2;
        self
    }

    /// Override the stability epsilon.
    pub fn with_epsilon(mut self, eps: f32) -> Self {
        self.epsilon = eps;
        self
    }

    /// Enable L2 weight decay with the given coefficient.
    pub fn with_weight_decay(mut self, decay: f32) -> Self {
        self.weight_decay = decay;
        self
    }
}
|
||||
|
||||
impl Optimizer for Adam {
    fn step(&mut self, params: &mut [f32], gradients: &[f32]) {
        // Lazily (re)allocate moment buffers if the parameter count
        // changed since construction.
        if self.m.len() != params.len() {
            self.m = vec![0.0; params.len()];
            self.v = vec![0.0; params.len()];
        }

        // Advance the timestep and precompute the bias-correction
        // denominators, which are shared by every element.
        self.t += 1;
        let correction1 = 1.0 - self.beta1.powi(self.t as i32);
        let correction2 = 1.0 - self.beta2.powi(self.t as i32);

        for idx in 0..params.len() {
            let grad = gradients[idx];

            // Exponential moving averages of the gradient and its square.
            let m_new = self.beta1 * self.m[idx] + (1.0 - self.beta1) * grad;
            let v_new = self.beta2 * self.v[idx] + (1.0 - self.beta2) * grad * grad;
            self.m[idx] = m_new;
            self.v[idx] = v_new;

            // Bias-corrected moment estimates.
            let m_hat = m_new / correction1;
            let v_hat = v_new / correction2;

            // Adaptive step plus optional (coupled) weight decay.
            let adaptive = m_hat / (v_hat.sqrt() + self.epsilon);
            params[idx] -= self.lr * (adaptive + self.weight_decay * params[idx]);
        }
    }

    fn reset(&mut self) {
        self.m.fill(0.0);
        self.v.fill(0.0);
        self.t = 0;
    }

    fn learning_rate(&self) -> f32 {
        self.lr
    }

    fn set_learning_rate(&mut self, lr: f32) {
        self.lr = lr;
    }
}
|
||||
|
||||
/// AdamW optimizer (decoupled weight decay)
|
||||
pub struct AdamW {
|
||||
inner: Adam,
|
||||
weight_decay: f32,
|
||||
}
|
||||
|
||||
impl AdamW {
|
||||
pub fn new(dim: usize, lr: f32) -> Self {
|
||||
Self {
|
||||
inner: Adam::new(dim, lr),
|
||||
weight_decay: 0.01,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_weight_decay(mut self, wd: f32) -> Self {
|
||||
self.weight_decay = wd;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_betas(mut self, beta1: f32, beta2: f32) -> Self {
|
||||
self.inner = self.inner.with_betas(beta1, beta2);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Optimizer for AdamW {
    /// Decoupled weight decay (Loshchilov & Hutter): shrink the
    /// parameters directly, then apply a plain Adam step.
    ///
    /// `AdamW` never sets the inner Adam's `weight_decay` (it stays at
    /// 0.0), so delegating to [`Adam::step`] performs a pure Adam
    /// update — moment tracking, bias correction, lazy buffer resize
    /// and the timestep bump included. This replaces a hand-duplicated
    /// copy of the entire Adam inner loop that previously lived here.
    fn step(&mut self, params: &mut [f32], gradients: &[f32]) {
        // Multiplicative shrink before the gradient step. Each element
        // is independent, so decaying all elements first is equivalent
        // to interleaving decay and update per element as before.
        let shrink = 1.0 - self.inner.lr * self.weight_decay;
        for p in params.iter_mut() {
            *p *= shrink;
        }

        // Pure Adam update on the decayed parameters.
        self.inner.step(params, gradients);
    }

    fn reset(&mut self) {
        self.inner.reset();
    }

    fn learning_rate(&self) -> f32 {
        self.inner.lr
    }

    fn set_learning_rate(&mut self, lr: f32) {
        self.inner.lr = lr;
    }
}
|
||||
|
||||
/// Learning-rate scheduler: linear warmup followed by cosine decay.
pub struct LearningRateScheduler {
    /// Peak learning rate, reached at the end of warmup.
    initial_lr: f32,
    /// Number of linear warmup steps (0 = no warmup).
    warmup_steps: usize,
    /// Number of steps over which cosine decay runs after warmup.
    decay_steps: usize,
    /// Floor the learning rate decays toward.
    min_lr: f32,
    /// Steps taken so far.
    current_step: usize,
}

impl LearningRateScheduler {
    /// Create a scheduler with no warmup, a 100k-step cosine decay and
    /// a 1e-7 floor.
    pub fn new(initial_lr: f32) -> Self {
        Self {
            initial_lr,
            warmup_steps: 0,
            decay_steps: 100_000,
            min_lr: 1e-7,
            current_step: 0,
        }
    }

    /// Set the number of linear warmup steps.
    pub fn with_warmup(mut self, steps: usize) -> Self {
        self.warmup_steps = steps;
        self
    }

    /// Set the cosine decay horizon (in steps after warmup).
    pub fn with_decay(mut self, steps: usize) -> Self {
        self.decay_steps = steps;
        self
    }

    /// Set the minimum learning rate.
    pub fn with_min_lr(mut self, min_lr: f32) -> Self {
        self.min_lr = min_lr;
        self
    }

    /// Get current learning rate and advance one step.
    pub fn step(&mut self) -> f32 {
        let lr = self.get_lr();
        self.current_step += 1;
        lr
    }

    /// Get the learning rate for the current step without advancing.
    pub fn get_lr(&self) -> f32 {
        if self.current_step < self.warmup_steps {
            // Linear warmup: ramps from initial_lr / warmup_steps up to
            // exactly initial_lr on the last warmup step. This branch
            // implies warmup_steps >= 1, so the division is safe.
            self.initial_lr * (self.current_step + 1) as f32 / self.warmup_steps as f32
        } else {
            // Cosine decay from initial_lr down to min_lr over
            // decay_steps, then held at the floor.
            //
            // Guard decay_steps == 0 explicitly: the previous code
            // divided by zero here and only reached min_lr because
            // f32::min discards the resulting NaN/infinity. Making the
            // guard explicit preserves the same result without relying
            // on float edge-case semantics.
            let progress = if self.decay_steps == 0 {
                1.0
            } else {
                ((self.current_step - self.warmup_steps) as f32 / self.decay_steps as f32)
                    .min(1.0)
            };
            let decay = 0.5 * (1.0 + (std::f32::consts::PI * progress).cos());
            self.min_lr + (self.initial_lr - self.min_lr) * decay
        }
    }

    /// Restart the schedule from step 0.
    pub fn reset(&mut self) {
        self.current_step = 0;
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sgd() {
        let mut opt = SGD::new(4, 0.1);
        let mut weights = vec![1.0, 2.0, 3.0, 4.0];
        let grads = vec![0.1, 0.2, 0.3, 0.4];

        opt.step(&mut weights, &grads);

        // A positive gradient must move each weight downward.
        assert!(weights[0] < 1.0);
        assert!(weights[1] < 2.0);
    }

    #[test]
    fn test_sgd_momentum() {
        let mut opt = SGD::new(4, 0.1).with_momentum(0.9);
        let mut weights = vec![1.0; 4];
        let grads = vec![1.0; 4];

        // Repeated identical gradients accumulate velocity, pushing the
        // parameters past zero within a few steps.
        for _ in 0..5 {
            opt.step(&mut weights, &grads);
        }

        assert!(weights[0] < 0.0);
    }

    #[test]
    fn test_adam() {
        let mut opt = Adam::new(64, 0.001);
        let mut weights = vec![0.5; 64];
        let grads = vec![0.1; 64];

        for _ in 0..100 {
            opt.step(&mut weights, &grads);
        }

        // A constant positive gradient drags the weights toward zero.
        assert!(weights[0] < 0.5);
    }

    #[test]
    fn test_adamw() {
        let mut opt = AdamW::new(32, 0.001).with_weight_decay(0.01);
        let mut weights = vec![1.0; 32];
        let grads = vec![0.0; 32]; // zero gradient isolates the decay term

        for _ in 0..100 {
            opt.step(&mut weights, &grads);
        }

        // Decoupled weight decay alone must shrink the parameters.
        assert!(weights[0] < 1.0);
    }

    #[test]
    fn test_lr_scheduler_warmup() {
        let mut sched = LearningRateScheduler::new(0.001).with_warmup(100);

        // The first step is still ramping up.
        assert!(sched.step() < 0.001);

        for _ in 0..99 {
            sched.step();
        }

        // After the full warmup the rate has reached its peak.
        assert!((sched.get_lr() - 0.001).abs() < 1e-5);
    }

    #[test]
    fn test_lr_scheduler_decay() {
        let mut sched = LearningRateScheduler::new(0.001)
            .with_warmup(0)
            .with_decay(100)
            .with_min_lr(0.0001);

        // Without warmup, the very first step sees the full rate.
        assert!((sched.step() - 0.001).abs() < 1e-5);

        for _ in 0..100 {
            sched.step();
        }

        // Past the decay horizon the rate bottoms out at min_lr.
        assert!((sched.get_lr() - 0.0001).abs() < 1e-5);
    }
}
|
||||
Reference in New Issue
Block a user