//! Pooling strategies for combining token embeddings into sentence embeddings use crate::config::PoolingStrategy; use rayon::prelude::*; use tracing::{debug, instrument}; /// Pooler for combining token embeddings #[derive(Debug, Clone)] pub struct Pooler { strategy: PoolingStrategy, normalize: bool, } impl Pooler { /// Create a new pooler with the given strategy pub fn new(strategy: PoolingStrategy, normalize: bool) -> Self { Self { strategy, normalize } } /// Pool token embeddings into sentence embeddings /// /// # Arguments /// * `token_embeddings` - Token embeddings for each sequence [batch][seq_len * hidden] /// * `attention_mask` - Attention mask for each sequence [batch][seq_len] /// * `seq_length` - Sequence length /// * `hidden_size` - Hidden dimension size #[instrument(skip_all, fields(batch_size = token_embeddings.len(), strategy = ?self.strategy))] pub fn pool( &self, token_embeddings: &[Vec], attention_mask: &[Vec], seq_length: usize, hidden_size: usize, ) -> Vec> { debug!( "Pooling {} sequences with strategy {:?}", token_embeddings.len(), self.strategy ); let embeddings: Vec> = token_embeddings .par_iter() .zip(attention_mask.par_iter()) .map(|(tokens, mask)| { self.pool_single(tokens, mask, seq_length, hidden_size) }) .collect(); if self.normalize { embeddings .into_par_iter() .map(|emb| Self::normalize_vector(&emb)) .collect() } else { embeddings } } /// Pool a single sequence fn pool_single( &self, token_embeddings: &[f32], attention_mask: &[i64], seq_length: usize, hidden_size: usize, ) -> Vec { match self.strategy { PoolingStrategy::Mean => { self.mean_pool(token_embeddings, attention_mask, seq_length, hidden_size) } PoolingStrategy::Cls => { self.cls_pool(token_embeddings, hidden_size) } PoolingStrategy::Max => { self.max_pool(token_embeddings, attention_mask, seq_length, hidden_size) } PoolingStrategy::MeanSqrtLen => { self.mean_sqrt_len_pool(token_embeddings, attention_mask, seq_length, hidden_size) } PoolingStrategy::LastToken => { self.last_token_pool(token_embeddings, attention_mask, seq_length, hidden_size) } PoolingStrategy::WeightedMean => { self.weighted_mean_pool(token_embeddings, attention_mask, seq_length, hidden_size) } } } /// Mean pooling over all tokens (weighted by attention mask) fn mean_pool( &self, token_embeddings: &[f32], attention_mask: &[i64], seq_length: usize, hidden_size: usize, ) -> Vec { let mut result = vec![0.0f32; hidden_size]; let mut count = 0.0f32; for (i, &mask) in attention_mask.iter().enumerate().take(seq_length) { if mask == 1 { let start = i * hidden_size; let end = start + hidden_size; for (j, val) in token_embeddings[start..end].iter().enumerate() { result[j] += val; } count += 1.0; } } if count > 0.0 { for val in &mut result { *val /= count; } } result } /// CLS token pooling (first token) fn cls_pool(&self, token_embeddings: &[f32], hidden_size: usize) -> Vec { token_embeddings[..hidden_size].to_vec() } /// Max pooling over all tokens fn max_pool( &self, token_embeddings: &[f32], attention_mask: &[i64], seq_length: usize, hidden_size: usize, ) -> Vec { let mut result = vec![f32::NEG_INFINITY; hidden_size]; for (i, &mask) in attention_mask.iter().enumerate().take(seq_length) { if mask == 1 { let start = i * hidden_size; let end = start + hidden_size; for (j, val) in token_embeddings[start..end].iter().enumerate() { if *val > result[j] { result[j] = *val; } } } } // Replace -inf with 0 for empty sequences for val in &mut result { if val.is_infinite() { *val = 0.0; } } result } /// Mean pooling with sqrt(length) scaling fn mean_sqrt_len_pool( &self, token_embeddings: &[f32], attention_mask: &[i64], seq_length: usize, hidden_size: usize, ) -> Vec { let mut result = self.mean_pool(token_embeddings, attention_mask, seq_length, hidden_size); let length: f32 = attention_mask.iter().filter(|&&m| m == 1).count() as f32; if length > 0.0 { let scale = length.sqrt(); for val in &mut result { *val *= scale; } } result } /// Last token pooling (for decoder models) fn last_token_pool( &self, token_embeddings: &[f32], attention_mask: &[i64], _seq_length: usize, hidden_size: usize, ) -> Vec { // Find last non-padding token let last_idx = attention_mask .iter() .rposition(|&m| m == 1) .unwrap_or(0); let start = last_idx * hidden_size; let end = start + hidden_size; if end <= token_embeddings.len() { token_embeddings[start..end].to_vec() } else { self.cls_pool(token_embeddings, hidden_size) } } /// Weighted mean pooling based on position fn weighted_mean_pool( &self, token_embeddings: &[f32], attention_mask: &[i64], seq_length: usize, hidden_size: usize, ) -> Vec { let mut result = vec![0.0f32; hidden_size]; let mut total_weight = 0.0f32; for (i, &mask) in attention_mask.iter().enumerate().take(seq_length) { if mask == 1 { // Weight decreases with position (more weight to early tokens) let weight = 1.0 / (i + 1) as f32; let start = i * hidden_size; let end = start + hidden_size; for (j, val) in token_embeddings[start..end].iter().enumerate() { result[j] += val * weight; } total_weight += weight; } } if total_weight > 0.0 { for val in &mut result { *val /= total_weight; } } result } /// L2 normalize a vector pub fn normalize_vector(vec: &[f32]) -> Vec { let norm: f32 = vec.iter().map(|x| x * x).sum::().sqrt(); if norm > 1e-12 { vec.iter().map(|x| x / norm).collect() } else { vec.to_vec() } } /// Compute cosine similarity between two vectors (SIMD-optimized) #[cfg(feature = "simsimd")] pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 { use simsimd::SpatialSimilarity; f32::cosine(a, b).unwrap_or(0.0) as f32 } #[cfg(not(feature = "simsimd"))] pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 { let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum(); let norm_a: f32 = a.iter().map(|x| x * x).sum::().sqrt(); let norm_b: f32 = b.iter().map(|x| x * x).sum::().sqrt(); if norm_a > 1e-12 && norm_b > 1e-12 { dot / (norm_a * norm_b) } else { 0.0 } } /// Compute dot product between two vectors (SIMD-optimized) #[cfg(feature = "simsimd")] pub fn dot_product(a: &[f32], b: &[f32]) -> f32 { use simsimd::SpatialSimilarity; f32::dot(a, b).unwrap_or(0.0) as f32 } #[cfg(not(feature = "simsimd"))] pub fn dot_product(a: &[f32], b: &[f32]) -> f32 { a.iter().zip(b.iter()).map(|(x, y)| x * y).sum() } /// Compute Euclidean distance between two vectors (SIMD-optimized) #[cfg(feature = "simsimd")] pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 { use simsimd::SpatialSimilarity; (f32::sqeuclidean(a, b).unwrap_or(0.0) as f32).sqrt() } #[cfg(not(feature = "simsimd"))] pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 { a.iter() .zip(b.iter()) .map(|(x, y)| (x - y).powi(2)) .sum::() .sqrt() } } impl Default for Pooler { fn default() -> Self { Self::new(PoolingStrategy::Mean, true) } } /// Batch distance computation using ndarray pub fn batch_cosine_similarity( query: &[f32], candidates: &[Vec], ) -> Vec { candidates .par_iter() .map(|c| Pooler::cosine_similarity(query, c)) .collect() } /// Find top-k most similar vectors pub fn top_k_similar( query: &[f32], candidates: &[Vec], k: usize, ) -> Vec<(usize, f32)> { let mut scores: Vec<(usize, f32)> = candidates .par_iter() .enumerate() .map(|(i, c)| (i, Pooler::cosine_similarity(query, c))) .collect(); // Sort by score descending scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); scores.truncate(k); scores } #[cfg(test)] mod tests { use super::*; #[test] fn test_normalize_vector() { let vec = vec![3.0, 4.0]; let normalized = Pooler::normalize_vector(&vec); let norm: f32 = normalized.iter().map(|x| x * x).sum::().sqrt(); assert!((norm - 1.0).abs() < 1e-6); } #[test] fn test_cosine_similarity() { let a = vec![1.0, 0.0, 0.0]; let b = vec![1.0, 0.0, 0.0]; let c = vec![0.0, 1.0, 0.0]; assert!((Pooler::cosine_similarity(&a, &b) - 1.0).abs() < 1e-6); assert!((Pooler::cosine_similarity(&a, &c)).abs() < 1e-6); } #[test] fn test_mean_pooling() { let pooler = Pooler::new(PoolingStrategy::Mean, false); // 2 tokens, 3 dimensions let embeddings = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; let mask = vec![1i64, 1]; let result = pooler.pool_single(&embeddings, &mask, 2, 3); assert_eq!(result.len(), 3); assert!((result[0] - 2.5).abs() < 1e-6); assert!((result[1] - 3.5).abs() < 1e-6); assert!((result[2] - 4.5).abs() < 1e-6); } #[test] fn test_cls_pooling() { let pooler = Pooler::new(PoolingStrategy::Cls, false); let embeddings = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; let mask = vec![1i64, 1]; let result = pooler.pool_single(&embeddings, &mask, 2, 3); assert_eq!(result, vec![1.0, 2.0, 3.0]); } #[test] fn test_top_k_similar() { let query = vec![1.0, 0.0, 0.0]; let candidates = vec![ vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0], vec![0.707, 0.707, 0.0], ]; let results = top_k_similar(&query, &candidates, 2); assert_eq!(results.len(), 2); assert_eq!(results[0].0, 0); // Most similar } }