398 lines
11 KiB
Rust
398 lines
11 KiB
Rust
//! Pooling strategies for combining token embeddings into sentence embeddings
|
|
|
|
use crate::config::PoolingStrategy;
|
|
use rayon::prelude::*;
|
|
use tracing::{debug, instrument};
|
|
|
|
/// Pooler for combining token embeddings
|
|
#[derive(Debug, Clone)]
|
|
pub struct Pooler {
|
|
strategy: PoolingStrategy,
|
|
normalize: bool,
|
|
}
|
|
|
|
impl Pooler {
|
|
/// Create a new pooler with the given strategy
|
|
pub fn new(strategy: PoolingStrategy, normalize: bool) -> Self {
|
|
Self { strategy, normalize }
|
|
}
|
|
|
|
/// Pool token embeddings into sentence embeddings
|
|
///
|
|
/// # Arguments
|
|
/// * `token_embeddings` - Token embeddings for each sequence [batch][seq_len * hidden]
|
|
/// * `attention_mask` - Attention mask for each sequence [batch][seq_len]
|
|
/// * `seq_length` - Sequence length
|
|
/// * `hidden_size` - Hidden dimension size
|
|
#[instrument(skip_all, fields(batch_size = token_embeddings.len(), strategy = ?self.strategy))]
|
|
pub fn pool(
|
|
&self,
|
|
token_embeddings: &[Vec<f32>],
|
|
attention_mask: &[Vec<i64>],
|
|
seq_length: usize,
|
|
hidden_size: usize,
|
|
) -> Vec<Vec<f32>> {
|
|
debug!(
|
|
"Pooling {} sequences with strategy {:?}",
|
|
token_embeddings.len(),
|
|
self.strategy
|
|
);
|
|
|
|
let embeddings: Vec<Vec<f32>> = token_embeddings
|
|
.par_iter()
|
|
.zip(attention_mask.par_iter())
|
|
.map(|(tokens, mask)| {
|
|
self.pool_single(tokens, mask, seq_length, hidden_size)
|
|
})
|
|
.collect();
|
|
|
|
if self.normalize {
|
|
embeddings
|
|
.into_par_iter()
|
|
.map(|emb| Self::normalize_vector(&emb))
|
|
.collect()
|
|
} else {
|
|
embeddings
|
|
}
|
|
}
|
|
|
|
/// Pool a single sequence
|
|
fn pool_single(
|
|
&self,
|
|
token_embeddings: &[f32],
|
|
attention_mask: &[i64],
|
|
seq_length: usize,
|
|
hidden_size: usize,
|
|
) -> Vec<f32> {
|
|
match self.strategy {
|
|
PoolingStrategy::Mean => {
|
|
self.mean_pool(token_embeddings, attention_mask, seq_length, hidden_size)
|
|
}
|
|
PoolingStrategy::Cls => {
|
|
self.cls_pool(token_embeddings, hidden_size)
|
|
}
|
|
PoolingStrategy::Max => {
|
|
self.max_pool(token_embeddings, attention_mask, seq_length, hidden_size)
|
|
}
|
|
PoolingStrategy::MeanSqrtLen => {
|
|
self.mean_sqrt_len_pool(token_embeddings, attention_mask, seq_length, hidden_size)
|
|
}
|
|
PoolingStrategy::LastToken => {
|
|
self.last_token_pool(token_embeddings, attention_mask, seq_length, hidden_size)
|
|
}
|
|
PoolingStrategy::WeightedMean => {
|
|
self.weighted_mean_pool(token_embeddings, attention_mask, seq_length, hidden_size)
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Mean pooling over all tokens (weighted by attention mask)
|
|
fn mean_pool(
|
|
&self,
|
|
token_embeddings: &[f32],
|
|
attention_mask: &[i64],
|
|
seq_length: usize,
|
|
hidden_size: usize,
|
|
) -> Vec<f32> {
|
|
let mut result = vec![0.0f32; hidden_size];
|
|
let mut count = 0.0f32;
|
|
|
|
for (i, &mask) in attention_mask.iter().enumerate().take(seq_length) {
|
|
if mask == 1 {
|
|
let start = i * hidden_size;
|
|
let end = start + hidden_size;
|
|
for (j, val) in token_embeddings[start..end].iter().enumerate() {
|
|
result[j] += val;
|
|
}
|
|
count += 1.0;
|
|
}
|
|
}
|
|
|
|
if count > 0.0 {
|
|
for val in &mut result {
|
|
*val /= count;
|
|
}
|
|
}
|
|
|
|
result
|
|
}
|
|
|
|
/// CLS token pooling (first token)
|
|
fn cls_pool(&self, token_embeddings: &[f32], hidden_size: usize) -> Vec<f32> {
|
|
token_embeddings[..hidden_size].to_vec()
|
|
}
|
|
|
|
/// Max pooling over all tokens
|
|
fn max_pool(
|
|
&self,
|
|
token_embeddings: &[f32],
|
|
attention_mask: &[i64],
|
|
seq_length: usize,
|
|
hidden_size: usize,
|
|
) -> Vec<f32> {
|
|
let mut result = vec![f32::NEG_INFINITY; hidden_size];
|
|
|
|
for (i, &mask) in attention_mask.iter().enumerate().take(seq_length) {
|
|
if mask == 1 {
|
|
let start = i * hidden_size;
|
|
let end = start + hidden_size;
|
|
for (j, val) in token_embeddings[start..end].iter().enumerate() {
|
|
if *val > result[j] {
|
|
result[j] = *val;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Replace -inf with 0 for empty sequences
|
|
for val in &mut result {
|
|
if val.is_infinite() {
|
|
*val = 0.0;
|
|
}
|
|
}
|
|
|
|
result
|
|
}
|
|
|
|
/// Mean pooling with sqrt(length) scaling
|
|
fn mean_sqrt_len_pool(
|
|
&self,
|
|
token_embeddings: &[f32],
|
|
attention_mask: &[i64],
|
|
seq_length: usize,
|
|
hidden_size: usize,
|
|
) -> Vec<f32> {
|
|
let mut result = self.mean_pool(token_embeddings, attention_mask, seq_length, hidden_size);
|
|
let length: f32 = attention_mask.iter().filter(|&&m| m == 1).count() as f32;
|
|
|
|
if length > 0.0 {
|
|
let scale = length.sqrt();
|
|
for val in &mut result {
|
|
*val *= scale;
|
|
}
|
|
}
|
|
|
|
result
|
|
}
|
|
|
|
/// Last token pooling (for decoder models)
|
|
fn last_token_pool(
|
|
&self,
|
|
token_embeddings: &[f32],
|
|
attention_mask: &[i64],
|
|
_seq_length: usize,
|
|
hidden_size: usize,
|
|
) -> Vec<f32> {
|
|
// Find last non-padding token
|
|
let last_idx = attention_mask
|
|
.iter()
|
|
.rposition(|&m| m == 1)
|
|
.unwrap_or(0);
|
|
|
|
let start = last_idx * hidden_size;
|
|
let end = start + hidden_size;
|
|
|
|
if end <= token_embeddings.len() {
|
|
token_embeddings[start..end].to_vec()
|
|
} else {
|
|
self.cls_pool(token_embeddings, hidden_size)
|
|
}
|
|
}
|
|
|
|
/// Weighted mean pooling based on position
|
|
fn weighted_mean_pool(
|
|
&self,
|
|
token_embeddings: &[f32],
|
|
attention_mask: &[i64],
|
|
seq_length: usize,
|
|
hidden_size: usize,
|
|
) -> Vec<f32> {
|
|
let mut result = vec![0.0f32; hidden_size];
|
|
let mut total_weight = 0.0f32;
|
|
|
|
for (i, &mask) in attention_mask.iter().enumerate().take(seq_length) {
|
|
if mask == 1 {
|
|
// Weight decreases with position (more weight to early tokens)
|
|
let weight = 1.0 / (i + 1) as f32;
|
|
let start = i * hidden_size;
|
|
let end = start + hidden_size;
|
|
|
|
for (j, val) in token_embeddings[start..end].iter().enumerate() {
|
|
result[j] += val * weight;
|
|
}
|
|
total_weight += weight;
|
|
}
|
|
}
|
|
|
|
if total_weight > 0.0 {
|
|
for val in &mut result {
|
|
*val /= total_weight;
|
|
}
|
|
}
|
|
|
|
result
|
|
}
|
|
|
|
/// L2 normalize a vector
|
|
pub fn normalize_vector(vec: &[f32]) -> Vec<f32> {
|
|
let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
|
|
|
|
if norm > 1e-12 {
|
|
vec.iter().map(|x| x / norm).collect()
|
|
} else {
|
|
vec.to_vec()
|
|
}
|
|
}
|
|
|
|
/// Compute cosine similarity between two vectors (SIMD-optimized)
|
|
#[cfg(feature = "simsimd")]
|
|
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
|
|
use simsimd::SpatialSimilarity;
|
|
f32::cosine(a, b).unwrap_or(0.0) as f32
|
|
}
|
|
|
|
#[cfg(not(feature = "simsimd"))]
|
|
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
|
|
let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
|
|
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
|
|
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
|
|
|
|
if norm_a > 1e-12 && norm_b > 1e-12 {
|
|
dot / (norm_a * norm_b)
|
|
} else {
|
|
0.0
|
|
}
|
|
}
|
|
|
|
/// Compute dot product between two vectors (SIMD-optimized)
|
|
#[cfg(feature = "simsimd")]
|
|
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
|
|
use simsimd::SpatialSimilarity;
|
|
f32::dot(a, b).unwrap_or(0.0) as f32
|
|
}
|
|
|
|
#[cfg(not(feature = "simsimd"))]
|
|
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
|
|
a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
|
|
}
|
|
|
|
/// Compute Euclidean distance between two vectors (SIMD-optimized)
|
|
#[cfg(feature = "simsimd")]
|
|
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
|
|
use simsimd::SpatialSimilarity;
|
|
(f32::sqeuclidean(a, b).unwrap_or(0.0) as f32).sqrt()
|
|
}
|
|
|
|
#[cfg(not(feature = "simsimd"))]
|
|
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
|
|
a.iter()
|
|
.zip(b.iter())
|
|
.map(|(x, y)| (x - y).powi(2))
|
|
.sum::<f32>()
|
|
.sqrt()
|
|
}
|
|
}
|
|
|
|
impl Default for Pooler {
|
|
fn default() -> Self {
|
|
Self::new(PoolingStrategy::Mean, true)
|
|
}
|
|
}
|
|
|
|
/// Batch distance computation using ndarray
|
|
pub fn batch_cosine_similarity(
|
|
query: &[f32],
|
|
candidates: &[Vec<f32>],
|
|
) -> Vec<f32> {
|
|
candidates
|
|
.par_iter()
|
|
.map(|c| Pooler::cosine_similarity(query, c))
|
|
.collect()
|
|
}
|
|
|
|
/// Find top-k most similar vectors
|
|
pub fn top_k_similar(
|
|
query: &[f32],
|
|
candidates: &[Vec<f32>],
|
|
k: usize,
|
|
) -> Vec<(usize, f32)> {
|
|
let mut scores: Vec<(usize, f32)> = candidates
|
|
.par_iter()
|
|
.enumerate()
|
|
.map(|(i, c)| (i, Pooler::cosine_similarity(query, c)))
|
|
.collect();
|
|
|
|
// Sort by score descending
|
|
scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
|
|
|
scores.truncate(k);
|
|
scores
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert two float slices are element-wise equal within `tol`.
    fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!((a - e).abs() < tol, "index {}: got {}, expected {}", i, a, e);
        }
    }

    #[test]
    fn test_normalize_vector() {
        let vec = vec![3.0, 4.0];
        let normalized = Pooler::normalize_vector(&vec);

        let norm: f32 = normalized.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
        assert_close(&normalized, &[0.6, 0.8], 1e-6);
    }

    #[test]
    fn test_normalize_zero_vector_is_unchanged() {
        assert_eq!(Pooler::normalize_vector(&[0.0, 0.0]), vec![0.0, 0.0]);
    }

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let c = vec![0.0, 1.0, 0.0];

        assert!((Pooler::cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);
        assert!((Pooler::cosine_similarity(&a, &c)).abs() < 1e-6);
    }

    #[test]
    fn test_mean_pooling() {
        let pooler = Pooler::new(PoolingStrategy::Mean, false);

        // 2 tokens, 3 dimensions
        let embeddings = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mask = vec![1i64, 1];

        let result = pooler.pool_single(&embeddings, &mask, 2, 3);
        assert_close(&result, &[2.5, 3.5, 4.5], 1e-6);
    }

    #[test]
    fn test_cls_pooling() {
        let pooler = Pooler::new(PoolingStrategy::Cls, false);

        let embeddings = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mask = vec![1i64, 1];

        let result = pooler.pool_single(&embeddings, &mask, 2, 3);
        assert_eq!(result, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_max_pooling() {
        let pooler = Pooler::new(PoolingStrategy::Max, false);
        let embeddings = vec![1.0, 5.0, 3.0, 4.0, 2.0, 6.0];

        let both = pooler.pool_single(&embeddings, &[1, 1], 2, 3);
        assert_close(&both, &[4.0, 5.0, 6.0], 1e-6);

        // Padded positions must not contribute.
        let first_only = pooler.pool_single(&embeddings, &[1, 0], 2, 3);
        assert_close(&first_only, &[1.0, 5.0, 3.0], 1e-6);

        // A fully padded sequence yields zeros, not -inf.
        let empty = pooler.pool_single(&embeddings, &[0, 0], 2, 3);
        assert_eq!(empty, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_mean_sqrt_len_pooling() {
        let pooler = Pooler::new(PoolingStrategy::MeanSqrtLen, false);
        let embeddings = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let result = pooler.pool_single(&embeddings, &[1, 1], 2, 3);
        let s = 2.0f32.sqrt();
        assert_close(&result, &[5.0 / s, 7.0 / s, 9.0 / s], 1e-5);
    }

    #[test]
    fn test_last_token_pooling() {
        let pooler = Pooler::new(PoolingStrategy::LastToken, false);
        let embeddings = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        assert_eq!(pooler.pool_single(&embeddings, &[1, 1], 2, 3), vec![4.0, 5.0, 6.0]);
        assert_eq!(pooler.pool_single(&embeddings, &[1, 0], 2, 3), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_weighted_mean_pooling() {
        let pooler = Pooler::new(PoolingStrategy::WeightedMean, false);
        let embeddings = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        // Weights 1 and 1/2: (1*[1,2,3] + 0.5*[4,5,6]) / 1.5 = [2,3,4]
        let result = pooler.pool_single(&embeddings, &[1, 1], 2, 3);
        assert_close(&result, &[2.0, 3.0, 4.0], 1e-6);
    }

    #[test]
    fn test_pool_batch_order_and_normalization() {
        let embeddings = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let masks = vec![vec![1i64], vec![1i64]];

        let raw = Pooler::new(PoolingStrategy::Mean, false).pool(&embeddings, &masks, 1, 2);
        assert_eq!(raw, vec![vec![1.0, 2.0], vec![3.0, 4.0]]);

        let normalized = Pooler::new(PoolingStrategy::Mean, true).pool(&embeddings, &masks, 1, 2);
        assert_eq!(normalized.len(), 2);
        assert_close(&normalized[1], &[0.6, 0.8], 1e-6);
    }

    #[test]
    fn test_top_k_similar() {
        let query = vec![1.0, 0.0, 0.0];
        let candidates = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.707, 0.707, 0.0],
        ];

        let results = top_k_similar(&query, &candidates, 2);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, 0); // Most similar
        assert_eq!(results[1].0, 2);
    }

    #[test]
    fn test_top_k_similar_bounds() {
        let query = vec![1.0, 0.0];
        let candidates = vec![vec![0.0, 1.0], vec![1.0, 0.0]];

        assert!(top_k_similar(&query, &candidates, 0).is_empty());

        let all = top_k_similar(&query, &candidates, 5);
        let order: Vec<usize> = all.iter().map(|&(i, _)| i).collect();
        assert_eq!(order, vec![1, 0]);
    }
}
|