Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'

This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,356 @@
//! Curriculum learning for attention training
//!
//! Provides schedulers for progressive training difficulty.
/// Decay type for temperature/parameter annealing
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub enum DecayType {
    /// Straight-line interpolation from initial to final value (default).
    #[default]
    Linear,
    /// Exponential decay toward the final value.
    Exponential,
    /// Half-period cosine annealing from initial to final value.
    Cosine,
    /// Piecewise-constant drops every `step_size` steps.
    Step,
}

/// Curriculum learning stage
///
/// A named segment of training with a fixed difficulty, length,
/// softmax temperature, and negative-sample count.
#[derive(Clone, Debug)]
pub struct CurriculumStage {
    pub name: String,
    pub difficulty: f32,       // 0.0 = easy, 1.0 = hard
    pub duration: usize,       // Steps in this stage
    pub temperature: f32,      // Softmax temperature
    pub negative_count: usize, // Number of negatives
}

impl CurriculumStage {
    /// Create a stage with defaults: difficulty 0.5, 1000 steps,
    /// temperature 1.0, 10 negatives.
    pub fn new(name: &str) -> Self {
        Self {
            name: name.to_string(),
            difficulty: 0.5,
            duration: 1000,
            temperature: 1.0,
            negative_count: 10,
        }
    }
    /// Set the difficulty, clamped into [0.0, 1.0].
    pub fn difficulty(mut self, d: f32) -> Self {
        self.difficulty = d.clamp(0.0, 1.0);
        self
    }
    /// Set the number of steps in this stage.
    pub fn duration(mut self, d: usize) -> Self {
        self.duration = d;
        self
    }
    /// Set the softmax temperature (floored at 0.01).
    pub fn temperature(mut self, t: f32) -> Self {
        self.temperature = t.max(0.01);
        self
    }
    /// Set the negative-sample count (at least 1).
    pub fn negative_count(mut self, n: usize) -> Self {
        self.negative_count = n.max(1);
        self
    }
}

/// Curriculum scheduler for progressive training
///
/// Steps through an ordered list of stages, advancing once the current
/// stage's duration is exhausted; the final stage is held indefinitely.
pub struct CurriculumScheduler {
    stages: Vec<CurriculumStage>,
    current_stage: usize,  // Index into `stages`
    steps_in_stage: usize, // Steps taken inside the current stage
    total_steps: usize,    // Steps taken overall (keeps counting past the end)
}

impl CurriculumScheduler {
    /// Create an empty scheduler with no stages.
    pub fn new() -> Self {
        Self {
            stages: Vec::new(),
            current_stage: 0,
            steps_in_stage: 0,
            total_steps: 0,
        }
    }
    /// Add a stage to the curriculum (builder style).
    pub fn add_stage(mut self, stage: CurriculumStage) -> Self {
        self.stages.push(stage);
        self
    }
    /// Build a default easy-to-hard curriculum of four equal-length stages
    /// (warm_up → easy → medium → hard).
    pub fn default_curriculum(total_steps: usize) -> Self {
        let stage_duration = total_steps / 4;
        Self::new()
            .add_stage(
                CurriculumStage::new("warm_up")
                    .difficulty(0.1)
                    .duration(stage_duration)
                    .temperature(2.0)
                    .negative_count(5),
            )
            .add_stage(
                CurriculumStage::new("easy")
                    .difficulty(0.3)
                    .duration(stage_duration)
                    .temperature(1.0)
                    .negative_count(10),
            )
            .add_stage(
                CurriculumStage::new("medium")
                    .difficulty(0.6)
                    .duration(stage_duration)
                    .temperature(0.5)
                    .negative_count(20),
            )
            .add_stage(
                CurriculumStage::new("hard")
                    .difficulty(1.0)
                    .duration(stage_duration)
                    .temperature(0.1)
                    .negative_count(50),
            )
    }
    /// Get current stage (`None` when the scheduler has no stages).
    pub fn current_stage(&self) -> Option<&CurriculumStage> {
        self.stages.get(self.current_stage)
    }
    /// Advance one step and return current stage
    pub fn step(&mut self) -> Option<&CurriculumStage> {
        if self.stages.is_empty() {
            return None;
        }
        self.steps_in_stage += 1;
        self.total_steps += 1;
        // Advance to the next stage once this one's duration is spent;
        // the final stage is never left.
        if let Some(stage) = self.stages.get(self.current_stage) {
            if self.steps_in_stage >= stage.duration && self.current_stage < self.stages.len() - 1 {
                self.current_stage += 1;
                self.steps_in_stage = 0;
            }
        }
        self.current_stage()
    }
    /// Get current difficulty (0.0 to 1.0); 1.0 when there are no stages.
    pub fn difficulty(&self) -> f32 {
        self.current_stage().map(|s| s.difficulty).unwrap_or(1.0)
    }
    /// Get current temperature; 1.0 when there are no stages.
    pub fn temperature(&self) -> f32 {
        self.current_stage().map(|s| s.temperature).unwrap_or(1.0)
    }
    /// Get current negative count; 10 when there are no stages.
    pub fn negative_count(&self) -> usize {
        self.current_stage().map(|s| s.negative_count).unwrap_or(10)
    }
    /// Check if training is complete (last stage's duration exhausted).
    pub fn is_complete(&self) -> bool {
        if self.stages.is_empty() {
            return true;
        }
        self.current_stage >= self.stages.len() - 1
            && self.steps_in_stage >= self.stages.last().map(|s| s.duration).unwrap_or(0)
    }
    /// Get progress (0.0 to 1.0)
    pub fn progress(&self) -> f32 {
        let total_duration: usize = self.stages.iter().map(|s| s.duration).sum();
        if total_duration == 0 {
            return 1.0;
        }
        // Fix: `total_steps` keeps counting after the last stage finishes,
        // which previously pushed the reported progress past 1.0.
        (self.total_steps as f32 / total_duration as f32).min(1.0)
    }
    /// Reset curriculum to the first stage with zeroed counters.
    pub fn reset(&mut self) {
        self.current_stage = 0;
        self.steps_in_stage = 0;
        self.total_steps = 0;
    }
}

impl Default for CurriculumScheduler {
    fn default() -> Self {
        Self::new()
    }
}

/// Temperature annealing scheduler
///
/// Anneals from `initial_temp` to `final_temp` over `total_steps` using the
/// configured [`DecayType`].
pub struct TemperatureAnnealing {
    initial_temp: f32,
    final_temp: f32,
    total_steps: usize,
    current_step: usize,
    decay_type: DecayType,
    step_size: usize, // For step decay
}

impl TemperatureAnnealing {
    /// Create a linear annealer; `step_size` defaults to `steps / 10`
    /// (the Step branch guards against a zero step size).
    pub fn new(initial: f32, final_temp: f32, steps: usize) -> Self {
        Self {
            initial_temp: initial,
            final_temp,
            total_steps: steps,
            current_step: 0,
            decay_type: DecayType::Linear,
            step_size: steps / 10,
        }
    }
    /// Builder: choose the decay curve.
    pub fn with_decay(mut self, decay: DecayType) -> Self {
        self.decay_type = decay;
        self
    }
    /// Builder: set the bucket size used by `DecayType::Step`.
    pub fn with_step_size(mut self, size: usize) -> Self {
        self.step_size = size;
        self
    }
    /// Get current temperature and advance
    pub fn step(&mut self) -> f32 {
        let temp = self.get_temp();
        self.current_step += 1;
        temp
    }
    /// Get current temperature without advancing
    pub fn get_temp(&self) -> f32 {
        if self.current_step >= self.total_steps {
            return self.final_temp;
        }
        let progress = self.current_step as f32 / self.total_steps as f32;
        let range = self.initial_temp - self.final_temp;
        match self.decay_type {
            DecayType::Linear => self.initial_temp - range * progress,
            DecayType::Exponential => {
                // Fix: decay toward a small positive floor so a final
                // temperature of 0 no longer yields ln(0) = -inf and a NaN
                // temperature on every step before the schedule ends.
                // (Assumes initial_temp > 0 — TODO confirm with callers.)
                let target = self.final_temp.max(1e-6);
                let decay_rate = (target / self.initial_temp).ln() / self.total_steps as f32;
                self.initial_temp * (decay_rate * self.current_step as f32).exp()
            }
            DecayType::Cosine => {
                self.final_temp + 0.5 * range * (1.0 + (std::f32::consts::PI * progress).cos())
            }
            DecayType::Step => {
                let num_steps = self.current_step / self.step_size.max(1);
                // Fix: guard the bucket count so `step_size > total_steps`
                // cannot produce a 0/0.0 = NaN decay amount.
                let total_buckets = (self.total_steps / self.step_size.max(1)).max(1);
                let step_decay = range * num_steps as f32 / total_buckets as f32;
                (self.initial_temp - step_decay).max(self.final_temp)
            }
        }
    }
    /// Reset annealing
    pub fn reset(&mut self) {
        self.current_step = 0;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Stages advance in declared order once their durations are exhausted.
    #[test]
    fn test_curriculum_stages() {
        let mut curriculum = CurriculumScheduler::new()
            .add_stage(CurriculumStage::new("easy").duration(10).difficulty(0.2))
            .add_stage(CurriculumStage::new("hard").duration(10).difficulty(0.8));
        assert_eq!(curriculum.current_stage().unwrap().name, "easy");
        assert!((curriculum.difficulty() - 0.2).abs() < 1e-5);
        // Progress through first stage
        for _ in 0..10 {
            curriculum.step();
        }
        assert_eq!(curriculum.current_stage().unwrap().name, "hard");
        assert!((curriculum.difficulty() - 0.8).abs() < 1e-5);
    }

    /// The built-in curriculum has four stages and completes after
    /// `total_steps` steps.
    #[test]
    fn test_default_curriculum() {
        let mut curriculum = CurriculumScheduler::default_curriculum(400);
        assert_eq!(curriculum.stages.len(), 4);
        assert_eq!(curriculum.current_stage().unwrap().name, "warm_up");
        // Progress to end
        for _ in 0..400 {
            curriculum.step();
        }
        assert!(curriculum.is_complete());
    }

    /// Linear annealing runs from the initial toward the final temperature.
    #[test]
    fn test_temperature_linear() {
        let mut annealing = TemperatureAnnealing::new(1.0, 0.1, 100);
        let temp_start = annealing.step();
        assert!((temp_start - 1.0).abs() < 0.1);
        for _ in 0..99 {
            annealing.step();
        }
        let temp_end = annealing.get_temp();
        assert!((temp_end - 0.1).abs() < 0.1);
    }

    /// Cosine annealing passes roughly through the midpoint halfway through.
    #[test]
    fn test_temperature_cosine() {
        let mut annealing = TemperatureAnnealing::new(1.0, 0.0, 100).with_decay(DecayType::Cosine);
        // Halfway should be approximately middle value
        for _ in 0..50 {
            annealing.step();
        }
        let temp_mid = annealing.get_temp();
        assert!(temp_mid > 0.4 && temp_mid < 0.6);
    }

    /// Step decay drops the temperature after each bucket of `step_size` steps.
    #[test]
    fn test_temperature_step() {
        let mut annealing = TemperatureAnnealing::new(1.0, 0.0, 100)
            .with_decay(DecayType::Step)
            .with_step_size(25);
        let temp_0 = annealing.get_temp();
        for _ in 0..25 {
            annealing.step();
        }
        let temp_25 = annealing.get_temp();
        // Should have dropped
        assert!(temp_25 < temp_0);
    }

    /// Progress is proportional to steps taken over the total duration.
    #[test]
    fn test_curriculum_progress() {
        let mut curriculum = CurriculumScheduler::new()
            .add_stage(CurriculumStage::new("stage1").duration(50))
            .add_stage(CurriculumStage::new("stage2").duration(50));
        assert!((curriculum.progress() - 0.0).abs() < 1e-5);
        for _ in 0..50 {
            curriculum.step();
        }
        assert!((curriculum.progress() - 0.5).abs() < 0.05);
    }
}

View File

@@ -0,0 +1,359 @@
//! Loss functions for attention-based learning
//!
//! Includes contrastive losses optimized for representation learning.
/// Reduction method for loss computation
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub enum Reduction {
    /// Average the per-sample losses (default).
    #[default]
    Mean,
    /// Sum the per-sample losses.
    Sum,
    /// No aggregation; see each loss implementation for its exact handling.
    None,
}
/// Loss trait for attention training
///
/// All slices are embedding vectors; implementations zip them pairwise, so
/// `anchor`, `positive`, and each negative are expected to have equal length.
pub trait Loss: Send + Sync {
    /// Compute loss value
    fn compute(&self, anchor: &[f32], positive: &[f32], negatives: &[&[f32]]) -> f32;
    /// Compute loss with gradients for anchor
    ///
    /// Returns `(loss, d_loss/d_anchor)`; the gradient vector has the same
    /// length as `anchor`.
    fn compute_with_gradients(
        &self,
        anchor: &[f32],
        positive: &[f32],
        negatives: &[&[f32]],
    ) -> (f32, Vec<f32>);
}
/// InfoNCE contrastive loss
///
/// L = -log(exp(sim(a,p)/τ) / Σexp(sim(a,n)/τ))
pub struct InfoNCELoss {
    temperature: f32, // Softmax temperature τ, floored at 0.01
}

impl InfoNCELoss {
    /// Create the loss with the given temperature (floored at 0.01).
    pub fn new(temperature: f32) -> Self {
        Self {
            temperature: temperature.max(0.01),
        }
    }

    /// Cosine similarity; norms are floored at 1e-8 to avoid dividing by zero.
    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        let mut dot = 0.0f32;
        let mut sq_a = 0.0f32;
        let mut sq_b = 0.0f32;
        for (&x, &y) in a.iter().zip(b.iter()) {
            dot += x * y;
            sq_a += x * x;
            sq_b += y * y;
        }
        dot / (sq_a.sqrt().max(1e-8) * sq_b.sqrt().max(1e-8))
    }
}
impl Loss for InfoNCELoss {
    /// InfoNCE via a numerically stable log-sum-exp:
    /// loss = logsumexp({sims}/τ) - sim(a,p)/τ, which is ≥ 0.
    fn compute(&self, anchor: &[f32], positive: &[f32], negatives: &[&[f32]]) -> f32 {
        let pos_sim = Self::cosine_similarity(anchor, positive) / self.temperature;
        let neg_sims: Vec<f32> = negatives
            .iter()
            .map(|n| Self::cosine_similarity(anchor, n) / self.temperature)
            .collect();
        // Stable log-sum-exp: subtract the max before exponentiating so the
        // exponentials cannot overflow.
        let max_sim = neg_sims
            .iter()
            .copied()
            .chain(std::iter::once(pos_sim))
            .fold(f32::NEG_INFINITY, f32::max);
        let sum_exp: f32 =
            neg_sims.iter().map(|s| (s - max_sim).exp()).sum::<f32>() + (pos_sim - max_sim).exp();
        let log_sum_exp = max_sim + sum_exp.ln();
        log_sum_exp - pos_sim
    }
    /// Loss plus the analytic gradient with respect to `anchor`.
    ///
    /// The loss is recomputed as -ln(softmax weight of the positive), which
    /// equals the log-sum-exp form used by `compute`.
    fn compute_with_gradients(
        &self,
        anchor: &[f32],
        positive: &[f32],
        negatives: &[&[f32]],
    ) -> (f32, Vec<f32>) {
        let dim = anchor.len();
        let pos_sim = Self::cosine_similarity(anchor, positive) / self.temperature;
        let neg_sims: Vec<f32> = negatives
            .iter()
            .map(|n| Self::cosine_similarity(anchor, n) / self.temperature)
            .collect();
        // Compute softmax weights (max-shifted for numerical stability).
        let max_sim = neg_sims
            .iter()
            .copied()
            .chain(std::iter::once(pos_sim))
            .fold(f32::NEG_INFINITY, f32::max);
        let pos_exp = (pos_sim - max_sim).exp();
        let neg_exps: Vec<f32> = neg_sims.iter().map(|s| (s - max_sim).exp()).collect();
        let total_exp: f32 = pos_exp + neg_exps.iter().sum::<f32>();
        let pos_weight = pos_exp / total_exp;
        let neg_weights: Vec<f32> = neg_exps.iter().map(|e| e / total_exp).collect();
        // Loss value: -log softmax probability of the positive.
        let loss = -(pos_weight.ln());
        // Gradient with respect to anchor
        // ∂L/∂anchor = (p_pos - 1) * ∂sim(a,p)/∂a + Σ p_neg_i * ∂sim(a,n_i)/∂a
        let norm_a: f32 = anchor.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8);
        let norm_p: f32 = positive.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8);
        let mut gradients = vec![0.0f32; dim];
        // Gradient from the positive term. Uses the cosine-similarity
        // derivative: ∂sim/∂a_i = p_i/(|a||p|) - a_i (a·p)/(|a|³|p|).
        let dot_ap: f32 = anchor.iter().zip(positive.iter()).map(|(a, p)| a * p).sum();
        for i in 0..dim {
            let d_sim = (positive[i] / (norm_a * norm_p))
                - (anchor[i] * dot_ap / (norm_a.powi(3) * norm_p));
            gradients[i] += (pos_weight - 1.0) * d_sim / self.temperature;
        }
        // Gradient from the negatives, each weighted by its softmax mass.
        for (neg, &weight) in negatives.iter().zip(neg_weights.iter()) {
            let norm_n: f32 = neg.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8);
            let dot_an: f32 = anchor.iter().zip(neg.iter()).map(|(a, n)| a * n).sum();
            for i in 0..dim {
                let d_sim =
                    (neg[i] / (norm_a * norm_n)) - (anchor[i] * dot_an / (norm_a.powi(3) * norm_n));
                gradients[i] += weight * d_sim / self.temperature;
            }
        }
        (loss, gradients)
    }
}
/// Local contrastive loss for neighborhood preservation
pub struct LocalContrastiveLoss {
    margin: f32,          // Hinge margin added to d(a,p) - d(a,n)
    reduction: Reduction, // Aggregation applied to per-negative losses
}

impl LocalContrastiveLoss {
    /// Create the loss with the given margin; reduction defaults to Mean.
    pub fn new(margin: f32) -> Self {
        Self {
            margin,
            reduction: Reduction::Mean,
        }
    }

    /// Builder: choose the reduction method.
    pub fn with_reduction(mut self, reduction: Reduction) -> Self {
        self.reduction = reduction;
        self
    }

    /// Euclidean (L2) distance between two equal-length vectors.
    fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
        let mut sq_sum = 0.0f32;
        for (&x, &y) in a.iter().zip(b.iter()) {
            let d = x - y;
            sq_sum += d * d;
        }
        sq_sum.sqrt()
    }
}
impl Loss for LocalContrastiveLoss {
    /// Triplet-style hinge loss: max(0, d(a,p) - d(a,n) + margin) per negative,
    /// aggregated according to `self.reduction`.
    fn compute(&self, anchor: &[f32], positive: &[f32], negatives: &[&[f32]]) -> f32 {
        let d_pos = Self::euclidean_distance(anchor, positive);
        let losses: Vec<f32> = negatives
            .iter()
            .map(|neg| {
                let d_neg = Self::euclidean_distance(anchor, neg);
                (d_pos - d_neg + self.margin).max(0.0)
            })
            .collect();
        match self.reduction {
            Reduction::Mean => losses.iter().sum::<f32>() / losses.len().max(1) as f32,
            Reduction::Sum => losses.iter().sum(),
            // NOTE(review): `None` returns only the FIRST negative's loss here,
            // while compute_with_gradients averages over all negatives for the
            // same setting — confirm which behavior is intended.
            Reduction::None => losses.first().copied().unwrap_or(0.0),
        }
    }
    /// Loss plus gradient with respect to `anchor`; only margin-violating
    /// ("active") negatives contribute to either.
    fn compute_with_gradients(
        &self,
        anchor: &[f32],
        positive: &[f32],
        negatives: &[&[f32]],
    ) -> (f32, Vec<f32>) {
        let dim = anchor.len();
        let d_pos = Self::euclidean_distance(anchor, positive);
        let mut total_loss = 0.0f32;
        let mut gradients = vec![0.0f32; dim];
        let mut active_count = 0;
        for neg in negatives.iter() {
            let d_neg = Self::euclidean_distance(anchor, neg);
            let margin_loss = d_pos - d_neg + self.margin;
            if margin_loss > 0.0 {
                total_loss += margin_loss;
                active_count += 1;
                // Gradient: ∂L/∂a = (a - p)/d_pos - (a - n)/d_neg
                // (terms with near-zero distance are skipped to avoid
                // division by zero)
                for i in 0..dim {
                    if d_pos > 1e-8 {
                        gradients[i] += (anchor[i] - positive[i]) / d_pos;
                    }
                    if d_neg > 1e-8 {
                        gradients[i] -= (anchor[i] - neg[i]) / d_neg;
                    }
                }
            }
        }
        let loss = match self.reduction {
            // Mean over ACTIVE negatives; gradients are rescaled to match.
            Reduction::Mean if active_count > 0 => {
                gradients.iter_mut().for_each(|g| *g /= active_count as f32);
                total_loss / active_count as f32
            }
            Reduction::Sum => total_loss,
            // Fallback (Mean with no active negatives, and None): average over
            // all negatives — total_loss is 0 when nothing is active.
            _ => total_loss / negatives.len().max(1) as f32,
        };
        (loss, gradients)
    }
}
/// Spectral regularization for smooth representations
///
/// Penalizes non-uniform variance across embedding dimensions.
pub struct SpectralRegularization {
    weight: f32, // Scaling factor applied to the raw penalty
}

impl SpectralRegularization {
    /// Create a regularizer with the given penalty weight.
    pub fn new(weight: f32) -> Self {
        Self { weight }
    }
    /// Compute spectral norm regularization for a batch of embeddings.
    ///
    /// Returns `weight * Var_d(var_d)`, where `var_d` is the per-dimension
    /// variance over the batch — zero when all dimensions carry the same
    /// variance. Returns 0.0 for an empty batch. Assumes all embeddings
    /// share the first row's length — TODO confirm with callers.
    pub fn compute_batch(&self, embeddings: &[&[f32]]) -> f32 {
        if embeddings.is_empty() {
            return 0.0;
        }
        let dim = embeddings[0].len();
        let n = embeddings.len() as f32;
        // Per-dimension variance, computed once. (The original recomputed
        // every mean and variance a second time for the variance-of-variance
        // pass — same result, double the work.)
        let vars: Vec<f32> = (0..dim)
            .map(|d| {
                let mean = embeddings.iter().map(|e| e[d]).sum::<f32>() / n;
                embeddings.iter().map(|e| (e[d] - mean).powi(2)).sum::<f32>() / n
            })
            .collect();
        // Regularization: encourage uniform variance across dimensions by
        // penalizing the spread of per-dimension variances.
        let avg_var = vars.iter().sum::<f32>() / dim as f32;
        let var_of_var = vars.iter().map(|v| (v - avg_var).powi(2)).sum::<f32>() / dim as f32;
        self.weight * var_of_var
    }
}
impl Loss for SpectralRegularization {
    /// Treat anchor, positive, and all negatives as a single batch and
    /// regularize that batch.
    fn compute(&self, anchor: &[f32], positive: &[f32], negatives: &[&[f32]]) -> f32 {
        let mut batch: Vec<&[f32]> = Vec::with_capacity(negatives.len() + 2);
        batch.push(anchor);
        batch.push(positive);
        batch.extend_from_slice(negatives);
        self.compute_batch(&batch)
    }
    /// Loss only; the gradient is all zeros because this regularizer is used
    /// as an auxiliary penalty.
    fn compute_with_gradients(
        &self,
        anchor: &[f32],
        positive: &[f32],
        negatives: &[&[f32]],
    ) -> (f32, Vec<f32>) {
        let loss = self.compute(anchor, positive, negatives);
        (loss, vec![0.0f32; anchor.len()])
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// InfoNCE on a separable triplet is finite and non-negative.
    #[test]
    fn test_infonce_loss() {
        let loss = InfoNCELoss::new(0.07);
        let anchor = vec![1.0, 0.0, 0.0];
        let positive = vec![0.9, 0.1, 0.0];
        let negatives: Vec<Vec<f32>> = vec![vec![0.0, 1.0, 0.0], vec![0.0, 0.0, 1.0]];
        let neg_refs: Vec<&[f32]> = negatives.iter().map(|n| n.as_slice()).collect();
        let loss_val = loss.compute(&anchor, &positive, &neg_refs);
        assert!(loss_val >= 0.0);
    }

    /// The gradient vector matches the anchor's dimensionality.
    #[test]
    fn test_infonce_gradients() {
        let loss = InfoNCELoss::new(0.1);
        let anchor = vec![0.5; 64];
        let positive = vec![0.6; 64];
        let negatives: Vec<Vec<f32>> = vec![vec![0.1; 64]; 5];
        let neg_refs: Vec<&[f32]> = negatives.iter().map(|n| n.as_slice()).collect();
        let (loss_val, grads) = loss.compute_with_gradients(&anchor, &positive, &neg_refs);
        assert!(loss_val >= 0.0);
        assert_eq!(grads.len(), 64);
    }

    /// The hinge loss is non-negative on a well-separated triplet.
    #[test]
    fn test_local_contrastive() {
        let loss = LocalContrastiveLoss::new(1.0);
        let anchor = vec![0.0, 0.0];
        let positive = vec![0.1, 0.0]; // Close
        let negatives: Vec<Vec<f32>> = vec![vec![2.0, 0.0], vec![0.0, 2.0]]; // Far
        let neg_refs: Vec<&[f32]> = negatives.iter().map(|n| n.as_slice()).collect();
        let loss_val = loss.compute(&anchor, &positive, &neg_refs);
        assert!(loss_val >= 0.0);
    }

    /// The variance-uniformity penalty is non-negative.
    #[test]
    fn test_spectral_regularization() {
        let reg = SpectralRegularization::new(0.01);
        let embeddings: Vec<Vec<f32>> = (0..10).map(|i| vec![i as f32 * 0.1; 32]).collect();
        let emb_refs: Vec<&[f32]> = embeddings.iter().map(|e| e.as_slice()).collect();
        let loss_val = reg.compute_batch(&emb_refs);
        assert!(loss_val >= 0.0);
    }
}

View File

@@ -0,0 +1,351 @@
//! Hard negative mining strategies
//!
//! Provides various methods for selecting informative negative samples.
/// Mining strategy enumeration
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub enum MiningStrategy {
    /// Uniform pseudo-random choice of negatives (default).
    #[default]
    Random,
    /// Most similar (hardest) candidates first.
    HardNegative,
    /// Negatives slightly farther than the positive, within a margin.
    SemiHard,
    /// Probabilistic sampling weighted by softmax of similarity.
    DistanceWeighted,
}

/// Trait for negative sample mining
pub trait NegativeMiner: Send + Sync {
    /// Mine negatives for an anchor from a candidate pool, returning indices
    /// into `candidates`.
    fn mine(
        &self,
        anchor: &[f32],
        positive: &[f32],
        candidates: &[&[f32]],
        num_negatives: usize,
    ) -> Vec<usize>;
    /// Get mining strategy
    fn strategy(&self) -> MiningStrategy;
}

/// Hard negative miner that selects closest negatives
pub struct HardNegativeMiner {
    strategy: MiningStrategy,
    margin: f32,      // Width of the semi-hard window beyond d(a,p)
    temperature: f32, // Sharpness of the distance-weighted softmax
}

impl HardNegativeMiner {
    /// Create a miner with margin 0.1 and temperature 1.0.
    pub fn new(strategy: MiningStrategy) -> Self {
        Self {
            strategy,
            margin: 0.1,
            temperature: 1.0,
        }
    }
    /// Builder: set the semi-hard margin.
    pub fn with_margin(mut self, margin: f32) -> Self {
        self.margin = margin;
        self
    }
    /// Builder: set the distance-weighted sampling temperature.
    pub fn with_temperature(mut self, temp: f32) -> Self {
        self.temperature = temp;
        self
    }
    /// Euclidean (L2) distance between two equal-length vectors.
    fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y).powi(2))
            .sum::<f32>()
            .sqrt()
    }
    /// Cosine similarity; norms floored at 1e-8 to avoid division by zero.
    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8);
        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8);
        dot / (norm_a * norm_b)
    }
    /// Select up to `num_select` indices via a deterministic Fisher-Yates
    /// shuffle driven by an LCG seeded with `seed`.
    fn random_selection(num_candidates: usize, num_select: usize, seed: u64) -> Vec<usize> {
        let mut indices: Vec<usize> = (0..num_candidates).collect();
        let mut current_seed = seed;
        // Fisher-Yates shuffle
        for i in (1..indices.len()).rev() {
            current_seed = current_seed
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1);
            let j = (current_seed as usize) % (i + 1);
            indices.swap(i, j);
        }
        indices.truncate(num_select.min(num_candidates));
        indices
    }
    /// Select hardest negatives (highest cosine similarity to the anchor).
    fn hard_negative_selection(
        &self,
        anchor: &[f32],
        candidates: &[&[f32]],
        num_select: usize,
    ) -> Vec<usize> {
        let mut indexed_sims: Vec<(usize, f32)> = candidates
            .iter()
            .enumerate()
            .map(|(i, c)| (i, Self::cosine_similarity(anchor, c)))
            .collect();
        // Sort by similarity descending (higher sim = harder negative)
        indexed_sims.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        indexed_sims
            .into_iter()
            .take(num_select.min(candidates.len()))
            .map(|(i, _)| i)
            .collect()
    }
    /// Select semi-hard negatives: d_pos < d_neg < d_pos + margin, backfilled
    /// with hard negatives when too few candidates qualify.
    fn semi_hard_selection(
        &self,
        anchor: &[f32],
        positive: &[f32],
        candidates: &[&[f32]],
        num_select: usize,
    ) -> Vec<usize> {
        let d_pos = Self::euclidean_distance(anchor, positive);
        let mut semi_hard: Vec<(usize, f32)> = candidates
            .iter()
            .enumerate()
            .filter_map(|(i, c)| {
                let d_neg = Self::euclidean_distance(anchor, c);
                // Semi-hard: d_pos < d_neg < d_pos + margin
                if d_neg > d_pos && d_neg < d_pos + self.margin {
                    Some((i, d_neg))
                } else {
                    None
                }
            })
            .collect();
        // Sort by distance ascending (prefer the harder, closer ones)
        semi_hard.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
        let mut result: Vec<usize> = semi_hard.into_iter().map(|(i, _)| i).collect();
        // If not enough semi-hard, fill with hard negatives (skipping any
        // index already chosen).
        if result.len() < num_select {
            let hard = self.hard_negative_selection(anchor, candidates, num_select - result.len());
            for idx in hard {
                if !result.contains(&idx) {
                    result.push(idx);
                }
            }
        }
        result.truncate(num_select);
        result
    }
    /// Sample without replacement, weighting candidates by softmax of
    /// similarity / temperature. Deterministic: the internal LCG always
    /// starts from a fixed seed.
    fn distance_weighted_selection(
        &self,
        anchor: &[f32],
        candidates: &[&[f32]],
        num_select: usize,
    ) -> Vec<usize> {
        if candidates.is_empty() {
            return vec![];
        }
        // Compute weights based on similarity (closer = higher weight)
        let sims: Vec<f32> = candidates
            .iter()
            .map(|c| Self::cosine_similarity(anchor, c) / self.temperature)
            .collect();
        // Softmax weights (max-shifted for stability)
        let max_sim = sims.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exp_sims: Vec<f32> = sims.iter().map(|s| (s - max_sim).exp()).collect();
        let sum_exp: f32 = exp_sims.iter().sum();
        let probs: Vec<f32> = exp_sims.iter().map(|e| e / sum_exp).collect();
        // Sample without replacement using the probabilities
        let mut remaining: Vec<(usize, f32)> = probs.into_iter().enumerate().collect();
        let mut selected = Vec::with_capacity(num_select);
        let mut seed = 42u64;
        while selected.len() < num_select && !remaining.is_empty() {
            // Random value in [0, 1)
            seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
            let r = (seed as f32) / (u64::MAX as f32);
            // Select based on cumulative probability.
            // Fix: default to the LAST remaining candidate. Floating-point
            // rounding can leave the cumulative sum fractionally below `r`,
            // in which case the old default of index 0 silently biased the
            // sample toward the first candidate.
            let total: f32 = remaining.iter().map(|(_, p)| p).sum();
            let mut cumsum = 0.0;
            let mut select_idx = remaining.len() - 1;
            for (i, (_, p)) in remaining.iter().enumerate() {
                cumsum += p / total;
                if r < cumsum {
                    select_idx = i;
                    break;
                }
            }
            let (orig_idx, _) = remaining.remove(select_idx);
            selected.push(orig_idx);
        }
        selected
    }
}

impl NegativeMiner for HardNegativeMiner {
    /// Dispatch to the selection routine for the configured strategy.
    fn mine(
        &self,
        anchor: &[f32],
        positive: &[f32],
        candidates: &[&[f32]],
        num_negatives: usize,
    ) -> Vec<usize> {
        match self.strategy {
            MiningStrategy::Random => Self::random_selection(candidates.len(), num_negatives, 42),
            MiningStrategy::HardNegative => {
                self.hard_negative_selection(anchor, candidates, num_negatives)
            }
            MiningStrategy::SemiHard => {
                self.semi_hard_selection(anchor, positive, candidates, num_negatives)
            }
            MiningStrategy::DistanceWeighted => {
                self.distance_weighted_selection(anchor, candidates, num_negatives)
            }
        }
    }
    fn strategy(&self) -> MiningStrategy {
        self.strategy
    }
}
/// In-batch negative mining (uses other batch items as negatives)
pub struct InBatchMiner {
    exclude_positive: bool, // When true, the positive index is never returned
}

impl InBatchMiner {
    /// Create a miner that excludes the positive example from the negatives.
    pub fn new() -> Self {
        Self {
            exclude_positive: true,
        }
    }

    /// Builder toggle: also allow the positive index as a negative.
    pub fn include_positive(mut self) -> Self {
        self.exclude_positive = false;
        self
    }

    /// Get negative indices from a batch for a given anchor index.
    ///
    /// The anchor itself is always excluded; the positive is excluded unless
    /// `include_positive` was set.
    pub fn get_negatives(
        &self,
        anchor_idx: usize,
        positive_idx: usize,
        batch_size: usize,
    ) -> Vec<usize> {
        let mut negatives = Vec::with_capacity(batch_size);
        for i in 0..batch_size {
            if i == anchor_idx {
                continue;
            }
            if self.exclude_positive && i == positive_idx {
                continue;
            }
            negatives.push(i);
        }
        negatives
    }
}
impl Default for InBatchMiner {
    /// Same as [`InBatchMiner::new`]: the positive index is excluded.
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Random mining returns exactly the requested number of indices.
    #[test]
    fn test_random_mining() {
        let miner = HardNegativeMiner::new(MiningStrategy::Random);
        let anchor = vec![1.0, 0.0, 0.0];
        let positive = vec![0.9, 0.1, 0.0];
        let candidates: Vec<Vec<f32>> = (0..20).map(|i| vec![i as f32 * 0.05; 3]).collect();
        let cand_refs: Vec<&[f32]> = candidates.iter().map(|c| c.as_slice()).collect();
        let selected = miner.mine(&anchor, &positive, &cand_refs, 5);
        assert_eq!(selected.len(), 5);
    }

    /// Hard mining ranks by cosine similarity to the anchor.
    #[test]
    fn test_hard_negative_mining() {
        let miner = HardNegativeMiner::new(MiningStrategy::HardNegative);
        let anchor = vec![1.0, 0.0, 0.0];
        let positive = vec![0.9, 0.1, 0.0];
        // Create candidates with varying similarity to anchor
        let candidates: Vec<Vec<f32>> = vec![
            vec![0.9, 0.1, 0.0], // Similar to anchor
            vec![0.5, 0.5, 0.0], // Medium
            vec![0.0, 1.0, 0.0], // Different
            vec![0.0, 0.0, 1.0], // Different
        ];
        let cand_refs: Vec<&[f32]> = candidates.iter().map(|c| c.as_slice()).collect();
        let selected = miner.mine(&anchor, &positive, &cand_refs, 2);
        // Should select the most similar ones first
        assert!(selected.contains(&0)); // Most similar
    }

    /// Semi-hard mining returns something when the window contains candidates.
    #[test]
    fn test_semi_hard_mining() {
        let miner = HardNegativeMiner::new(MiningStrategy::SemiHard).with_margin(1.0);
        let anchor = vec![0.0, 0.0];
        let positive = vec![0.5, 0.0]; // Distance 0.5
        let candidates: Vec<Vec<f32>> = vec![
            vec![0.3, 0.0], // Too easy (d = 0.3 < 0.5)
            vec![0.7, 0.0], // Semi-hard (0.5 < 0.7 < 1.5)
            vec![1.0, 0.0], // Semi-hard
            vec![3.0, 0.0], // Too hard (d = 3.0 > 1.5)
        ];
        let cand_refs: Vec<&[f32]> = candidates.iter().map(|c| c.as_slice()).collect();
        let selected = miner.mine(&anchor, &positive, &cand_refs, 2);
        assert!(!selected.is_empty());
    }

    /// Distance-weighted sampling yields the requested sample size.
    #[test]
    fn test_distance_weighted() {
        let miner = HardNegativeMiner::new(MiningStrategy::DistanceWeighted).with_temperature(0.5);
        let anchor = vec![1.0, 0.0];
        let positive = vec![0.9, 0.1];
        let candidates: Vec<Vec<f32>> = (0..10).map(|i| vec![0.1 * i as f32; 2]).collect();
        let cand_refs: Vec<&[f32]> = candidates.iter().map(|c| c.as_slice()).collect();
        let selected = miner.mine(&anchor, &positive, &cand_refs, 3);
        assert_eq!(selected.len(), 3);
    }

    /// In-batch mining drops both anchor and positive indices by default.
    #[test]
    fn test_in_batch_miner() {
        let miner = InBatchMiner::new();
        let negatives = miner.get_negatives(2, 5, 10);
        assert!(!negatives.contains(&2)); // Exclude anchor
        assert!(!negatives.contains(&5)); // Exclude positive
        assert_eq!(negatives.len(), 8);
    }
}

View File

@@ -0,0 +1,42 @@
//! Training utilities for attention-based graph neural networks
//!
//! This module provides training infrastructure including:
//! - Loss functions (InfoNCE, contrastive, spectral regularization)
//! - Optimizers (SGD, Adam, AdamW)
//! - Curriculum learning schedulers
//! - Hard negative mining strategies
pub mod curriculum;
pub mod loss;
pub mod mining;
pub mod optimizer;
pub use curriculum::{CurriculumScheduler, CurriculumStage, DecayType, TemperatureAnnealing};
pub use loss::{InfoNCELoss, LocalContrastiveLoss, Loss, Reduction, SpectralRegularization};
pub use mining::{HardNegativeMiner, MiningStrategy, NegativeMiner};
pub use optimizer::{Adam, AdamW, Optimizer, SGD};
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: loss, gradient, and optimizer interoperate on
    /// matching-length vectors.
    #[test]
    fn test_training_components_integration() {
        // Test optimizer with loss
        let mut optimizer = Adam::new(128, 0.001);
        let loss = InfoNCELoss::new(0.07);
        let mut params = vec![0.5; 128];
        let anchor = vec![1.0; 128];
        let positive = vec![0.9; 128];
        let negatives: Vec<Vec<f32>> = (0..5).map(|_| vec![0.1; 128]).collect();
        let neg_refs: Vec<&[f32]> = negatives.iter().map(|v| v.as_slice()).collect();
        let (loss_val, gradients) = loss.compute_with_gradients(&anchor, &positive, &neg_refs);
        assert!(loss_val >= 0.0);
        assert_eq!(gradients.len(), anchor.len());
        // One optimizer step must run without panicking.
        optimizer.step(&mut params, &gradients);
    }
}

View File

@@ -0,0 +1,400 @@
//! Optimizers for attention training
//!
//! Provides standard optimizers with momentum and adaptive learning rates.
/// Optimizer trait for parameter updates
///
/// Implementations keep per-parameter state (momentum/moment buffers) and
/// update `params` in place from a `gradients` slice of the same length.
pub trait Optimizer: Send + Sync {
    /// Update parameters using gradients
    fn step(&mut self, params: &mut [f32], gradients: &[f32]);
    /// Reset optimizer state
    fn reset(&mut self);
    /// Get current learning rate
    fn learning_rate(&self) -> f32;
    /// Set learning rate
    fn set_learning_rate(&mut self, lr: f32);
}
/// Stochastic Gradient Descent with momentum
pub struct SGD {
    lr: f32,            // Learning rate
    momentum: f32,      // Momentum coefficient (0.0 = plain SGD)
    weight_decay: f32,  // L2 penalty folded into the gradient
    velocity: Vec<f32>, // Momentum buffer, one entry per parameter
    nesterov: bool,     // Use Nesterov-style lookahead update
}

impl SGD {
    /// Plain SGD over `dim` parameters: no momentum, no weight decay.
    pub fn new(dim: usize, lr: f32) -> Self {
        Self {
            lr,
            momentum: 0.0,
            weight_decay: 0.0,
            velocity: vec![0.0; dim],
            nesterov: false,
        }
    }

    /// Builder: set the momentum coefficient.
    pub fn with_momentum(self, momentum: f32) -> Self {
        Self { momentum, ..self }
    }

    /// Builder: set the L2 weight-decay factor.
    pub fn with_weight_decay(self, wd: f32) -> Self {
        Self {
            weight_decay: wd,
            ..self
        }
    }

    /// Builder: enable or disable Nesterov lookahead.
    pub fn with_nesterov(self, nesterov: bool) -> Self {
        Self { nesterov, ..self }
    }
}
impl Optimizer for SGD {
    /// SGD update with optional L2 weight decay, momentum, and Nesterov
    /// lookahead.
    fn step(&mut self, params: &mut [f32], gradients: &[f32]) {
        // Lazily (re)size the momentum buffer; note this discards any
        // accumulated momentum whenever the parameter count changes.
        if self.velocity.len() != params.len() {
            self.velocity = vec![0.0; params.len()];
        }
        for i in 0..params.len() {
            let mut g = gradients[i];
            // Weight decay as an L2 gradient term: g += wd * w
            if self.weight_decay > 0.0 {
                g += self.weight_decay * params[i];
            }
            // Update velocity: v = momentum * v + g
            self.velocity[i] = self.momentum * self.velocity[i] + g;
            // Update parameters
            if self.nesterov {
                // Nesterov form: step along g + momentum * v (lookahead).
                params[i] -= self.lr * (g + self.momentum * self.velocity[i]);
            } else {
                params[i] -= self.lr * self.velocity[i];
            }
        }
    }
    /// Zero the momentum buffer; hyperparameters are kept.
    fn reset(&mut self) {
        self.velocity.fill(0.0);
    }
    fn learning_rate(&self) -> f32 {
        self.lr
    }
    fn set_learning_rate(&mut self, lr: f32) {
        self.lr = lr;
    }
}
/// Adam optimizer with bias correction
pub struct Adam {
    lr: f32,           // Learning rate
    beta1: f32,        // First-moment EMA coefficient
    beta2: f32,        // Second-moment EMA coefficient
    epsilon: f32,      // Denominator stabilizer
    weight_decay: f32, // L2 penalty folded into the update
    m: Vec<f32>,       // First moment
    v: Vec<f32>,       // Second moment
    t: usize,          // Timestep
}

impl Adam {
    /// Adam with the standard defaults: β1 = 0.9, β2 = 0.999, ε = 1e-8,
    /// no weight decay.
    pub fn new(dim: usize, lr: f32) -> Self {
        Self {
            lr,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            weight_decay: 0.0,
            m: vec![0.0; dim],
            v: vec![0.0; dim],
            t: 0,
        }
    }

    /// Builder: override both moment coefficients.
    pub fn with_betas(self, beta1: f32, beta2: f32) -> Self {
        Self {
            beta1,
            beta2,
            ..self
        }
    }

    /// Builder: override the epsilon term.
    pub fn with_epsilon(self, eps: f32) -> Self {
        Self {
            epsilon: eps,
            ..self
        }
    }

    /// Builder: enable L2-style weight decay.
    pub fn with_weight_decay(self, wd: f32) -> Self {
        Self {
            weight_decay: wd,
            ..self
        }
    }
}
impl Optimizer for Adam {
    /// Adam update with bias-corrected moments. Weight decay (if set) is
    /// folded into the update L2-style — classic Adam, not decoupled AdamW.
    fn step(&mut self, params: &mut [f32], gradients: &[f32]) {
        // Lazily (re)size the moment buffers; a size change discards all
        // accumulated moment state.
        if self.m.len() != params.len() {
            self.m = vec![0.0; params.len()];
            self.v = vec![0.0; params.len()];
        }
        self.t += 1;
        // Corrections 1 - β^t counteract the zero initialization of m and v.
        let bias_correction1 = 1.0 - self.beta1.powi(self.t as i32);
        let bias_correction2 = 1.0 - self.beta2.powi(self.t as i32);
        for i in 0..params.len() {
            let g = gradients[i];
            // Exponential moving averages of g and g².
            self.m[i] = self.beta1 * self.m[i] + (1.0 - self.beta1) * g;
            self.v[i] = self.beta2 * self.v[i] + (1.0 - self.beta2) * g * g;
            // Bias-corrected estimates
            let m_hat = self.m[i] / bias_correction1;
            let v_hat = self.v[i] / bias_correction2;
            // Update with optional weight decay
            let update = m_hat / (v_hat.sqrt() + self.epsilon);
            params[i] -= self.lr * (update + self.weight_decay * params[i]);
        }
    }
    /// Clear both moment buffers and the timestep.
    fn reset(&mut self) {
        self.m.fill(0.0);
        self.v.fill(0.0);
        self.t = 0;
    }
    fn learning_rate(&self) -> f32 {
        self.lr
    }
    fn set_learning_rate(&mut self, lr: f32) {
        self.lr = lr;
    }
}
/// AdamW optimizer (decoupled weight decay)
pub struct AdamW {
    inner: Adam,       // Underlying Adam moment state (its own decay stays 0)
    weight_decay: f32, // Decoupled decay applied directly to parameters
}

impl AdamW {
    /// AdamW with a default decoupled weight decay of 0.01.
    pub fn new(dim: usize, lr: f32) -> Self {
        let inner = Adam::new(dim, lr);
        Self {
            inner,
            weight_decay: 0.01,
        }
    }

    /// Builder: set the decoupled weight-decay factor.
    pub fn with_weight_decay(self, wd: f32) -> Self {
        Self {
            weight_decay: wd,
            ..self
        }
    }

    /// Builder: forward β overrides to the inner Adam.
    pub fn with_betas(mut self, beta1: f32, beta2: f32) -> Self {
        self.inner = self.inner.with_betas(beta1, beta2);
        self
    }
}
impl Optimizer for AdamW {
    /// AdamW step: decoupled weight decay shrinks each parameter directly
    /// (before the Adam update), then the bias-corrected Adam step is applied.
    ///
    /// NOTE(review): this duplicates Adam's update math on `inner`'s buffers
    /// instead of delegating; keep the two in sync if either changes.
    fn step(&mut self, params: &mut [f32], gradients: &[f32]) {
        // Lazily (re)size the inner moment buffers on a size change.
        if self.inner.m.len() != params.len() {
            self.inner.m = vec![0.0; params.len()];
            self.inner.v = vec![0.0; params.len()];
        }
        self.inner.t += 1;
        // Bias corrections 1 - β^t for the zero-initialized moments.
        let bias_correction1 = 1.0 - self.inner.beta1.powi(self.inner.t as i32);
        let bias_correction2 = 1.0 - self.inner.beta2.powi(self.inner.t as i32);
        for i in 0..params.len() {
            let g = gradients[i];
            // Exponential moving averages of g and g².
            self.inner.m[i] = self.inner.beta1 * self.inner.m[i] + (1.0 - self.inner.beta1) * g;
            self.inner.v[i] = self.inner.beta2 * self.inner.v[i] + (1.0 - self.inner.beta2) * g * g;
            // Bias-corrected estimates
            let m_hat = self.inner.m[i] / bias_correction1;
            let v_hat = self.inner.v[i] / bias_correction2;
            // Decoupled weight decay (applied to params directly, not through gradient)
            params[i] *= 1.0 - self.inner.lr * self.weight_decay;
            // Adam update
            params[i] -= self.inner.lr * m_hat / (v_hat.sqrt() + self.inner.epsilon);
        }
    }
    /// Clear the inner Adam's moments and timestep.
    fn reset(&mut self) {
        self.inner.reset();
    }
    fn learning_rate(&self) -> f32 {
        self.inner.lr
    }
    fn set_learning_rate(&mut self, lr: f32) {
        self.inner.lr = lr;
    }
}
/// Learning rate scheduler
///
/// Linear warmup for `warmup_steps`, then cosine decay over `decay_steps`
/// down to `min_lr`, where the rate is then held.
pub struct LearningRateScheduler {
    initial_lr: f32,     // Peak rate, reached at the end of warmup
    warmup_steps: usize, // Length of the linear warmup phase
    decay_steps: usize,  // Length of the cosine decay phase
    min_lr: f32,         // Floor reached at the end of decay
    current_step: usize, // Steps taken so far
}

impl LearningRateScheduler {
    /// Scheduler with no warmup, 100k decay steps, and a 1e-7 floor.
    pub fn new(initial_lr: f32) -> Self {
        Self {
            initial_lr,
            warmup_steps: 0,
            decay_steps: 100000,
            min_lr: 1e-7,
            current_step: 0,
        }
    }

    /// Builder: set the warmup length.
    pub fn with_warmup(self, steps: usize) -> Self {
        Self {
            warmup_steps: steps,
            ..self
        }
    }

    /// Builder: set the decay length.
    pub fn with_decay(self, steps: usize) -> Self {
        Self {
            decay_steps: steps,
            ..self
        }
    }

    /// Builder: set the learning-rate floor.
    pub fn with_min_lr(self, min_lr: f32) -> Self {
        Self { min_lr, ..self }
    }

    /// Get current learning rate and advance step
    pub fn step(&mut self) -> f32 {
        let lr = self.get_lr();
        self.current_step += 1;
        lr
    }

    /// Get learning rate without advancing
    pub fn get_lr(&self) -> f32 {
        if self.current_step < self.warmup_steps {
            // Linear warmup: ramp from initial_lr/warmup_steps to initial_lr.
            return self.initial_lr * (self.current_step + 1) as f32 / self.warmup_steps as f32;
        }
        // Cosine decay from initial_lr to min_lr; progress is capped at 1.0
        // so the rate holds at min_lr afterwards.
        let progress = (self.current_step - self.warmup_steps) as f32 / self.decay_steps as f32;
        let cosine = 0.5 * (1.0 + (std::f32::consts::PI * progress.min(1.0)).cos());
        self.min_lr + (self.initial_lr - self.min_lr) * cosine
    }

    /// Reset scheduler
    pub fn reset(&mut self) {
        self.current_step = 0;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// One SGD step moves parameters opposite the gradient.
    #[test]
    fn test_sgd() {
        let mut opt = SGD::new(4, 0.1);
        let mut params = vec![1.0, 2.0, 3.0, 4.0];
        let gradients = vec![0.1, 0.2, 0.3, 0.4];
        opt.step(&mut params, &gradients);
        assert!(params[0] < 1.0);
        assert!(params[1] < 2.0);
    }

    /// Momentum accumulates across steps, accelerating the descent.
    #[test]
    fn test_sgd_momentum() {
        let mut opt = SGD::new(4, 0.1).with_momentum(0.9);
        let mut params = vec![1.0; 4];
        let gradients = vec![1.0; 4];
        // Multiple steps should accumulate momentum
        for _ in 0..5 {
            opt.step(&mut params, &gradients);
        }
        assert!(params[0] < 0.0);
    }

    /// Repeated Adam steps with a constant positive gradient shrink params.
    #[test]
    fn test_adam() {
        let mut opt = Adam::new(64, 0.001);
        let mut params = vec![0.5; 64];
        let gradients = vec![0.1; 64];
        for _ in 0..100 {
            opt.step(&mut params, &gradients);
        }
        // Should have moved toward 0
        assert!(params[0] < 0.5);
    }

    /// With zero gradients, decoupled weight decay alone shrinks params.
    #[test]
    fn test_adamw() {
        let mut opt = AdamW::new(32, 0.001).with_weight_decay(0.01);
        let mut params = vec![1.0; 32];
        let gradients = vec![0.0; 32]; // No gradient, only weight decay
        for _ in 0..100 {
            opt.step(&mut params, &gradients);
        }
        // Weight decay should shrink params
        assert!(params[0] < 1.0);
    }

    /// The warmup ramp reaches the peak rate at its final step.
    #[test]
    fn test_lr_scheduler_warmup() {
        let mut scheduler = LearningRateScheduler::new(0.001).with_warmup(100);
        let lr_start = scheduler.step();
        assert!(lr_start < 0.001); // Still warming up
        for _ in 0..99 {
            scheduler.step();
        }
        let lr_end_warmup = scheduler.get_lr();
        assert!((lr_end_warmup - 0.001).abs() < 1e-5);
    }

    /// Cosine decay lands on the configured minimum rate.
    #[test]
    fn test_lr_scheduler_decay() {
        let mut scheduler = LearningRateScheduler::new(0.001)
            .with_warmup(0)
            .with_decay(100)
            .with_min_lr(0.0001);
        let lr_start = scheduler.step();
        assert!((lr_start - 0.001).abs() < 1e-5);
        for _ in 0..100 {
            scheduler.step();
        }
        let lr_end = scheduler.get_lr();
        assert!((lr_end - 0.0001).abs() < 1e-5);
    }
}