Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'

This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
7854 changed files with 3522914 additions and 0 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,293 @@
//! GPU Configuration for RuVector ONNX Embeddings
//!
//! Provides configuration options for GPU acceleration including
//! device selection, memory limits, and performance tuning.
use serde::{Deserialize, Serialize};
/// GPU execution mode
///
/// Selects which compute backend the embedding pipeline uses.
/// `Auto` defers the choice to runtime probing of available backends.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum GpuMode {
    /// Automatically select best available backend
    #[default]
    Auto,
    /// Force WebGPU backend
    WebGpu,
    /// Force CUDA-WASM transpiled backend
    CudaWasm,
    /// CPU-only (disable GPU); see `GpuConfig::should_use_gpu`, which always
    /// returns false in this mode
    CpuOnly,
}
/// Power preference for GPU device selection
///
/// Mirrors the usual WebGPU adapter power-preference options
/// (NOTE(review): presumably forwarded to the adapter request in the
/// backend — confirm against `backend::create_backend`).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum PowerPreference {
    /// Prefer low power consumption (integrated GPU)
    LowPower,
    /// Prefer high performance (discrete GPU)
    #[default]
    HighPerformance,
    /// No preference
    None,
}
/// GPU acceleration configuration
///
/// The workload thresholds (`min_batch_size`, `min_dimension`) gate GPU use
/// per call via [`GpuConfig::should_use_gpu`]; everything else tunes device
/// selection and shader execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuConfig {
    /// GPU execution mode
    pub mode: GpuMode,
    /// Power preference for device selection
    pub power_preference: PowerPreference,
    /// Maximum GPU memory usage (bytes, 0 = unlimited)
    pub max_memory: u64,
    /// Workgroup size for compute shaders (0 = auto)
    pub workgroup_size: u32,
    /// Enable async GPU operations
    pub async_compute: bool,
    /// Minimum batch size to use GPU (smaller batches use CPU)
    pub min_batch_size: usize,
    /// Minimum vector dimension to use GPU
    pub min_dimension: usize,
    /// Enable shader caching
    pub cache_shaders: bool,
    /// Enable profiling and timing
    pub enable_profiling: bool,
    /// Fallback to CPU on GPU error
    pub fallback_to_cpu: bool,
    /// Device index (for multi-GPU systems)
    pub device_index: u32,
}
impl Default for GpuConfig {
    /// Balanced defaults: auto backend selection, 256-wide workgroups, and
    /// the GPU engaged only for batches of at least 16 vectors of at least
    /// 128 dimensions, with CPU fallback enabled.
    fn default() -> Self {
        Self {
            mode: GpuMode::default(),                     // Auto
            power_preference: PowerPreference::default(), // HighPerformance
            device_index: 0,
            max_memory: 0, // 0 = no explicit cap
            workgroup_size: 256,
            async_compute: true,
            min_batch_size: 16,
            min_dimension: 128,
            cache_shaders: true,
            enable_profiling: false,
            fallback_to_cpu: true,
        }
    }
}
impl GpuConfig {
    /// Configuration with automatic settings (identical to `Default`).
    pub fn auto() -> Self {
        Self::default()
    }

    /// Tuned for throughput: discrete GPU, larger workgroups, and lower
    /// thresholds so the GPU is engaged for smaller workloads.
    pub fn high_performance() -> Self {
        Self::default()
            .with_mode(GpuMode::Auto)
            .with_power_preference(PowerPreference::HighPerformance)
            .with_workgroup_size(512)
            .with_min_batch_size(8)
            .with_min_dimension(64)
    }

    /// Tuned for battery life: integrated GPU, smaller workgroups, no async
    /// compute, and higher thresholds before the GPU is engaged.
    pub fn low_power() -> Self {
        Self {
            power_preference: PowerPreference::LowPower,
            workgroup_size: 128,
            async_compute: false,
            min_batch_size: 32,
            min_dimension: 256,
            ..Self::default()
        }
    }

    /// Configuration that never uses the GPU.
    pub fn cpu_only() -> Self {
        Self::default().with_mode(GpuMode::CpuOnly)
    }

    /// Configuration forcing the WebGPU backend.
    pub fn webgpu() -> Self {
        Self::default().with_mode(GpuMode::WebGpu)
    }

    /// Configuration forcing the CUDA-WASM transpiled backend.
    #[cfg(feature = "cuda-wasm")]
    pub fn cuda_wasm() -> Self {
        Self::default()
            .with_mode(GpuMode::CudaWasm)
            .with_workgroup_size(256)
    }

    // ----- Builder-style setters (consume and return self) -----

    /// Set GPU mode.
    pub fn with_mode(self, mode: GpuMode) -> Self {
        Self { mode, ..self }
    }

    /// Set power preference.
    pub fn with_power_preference(self, pref: PowerPreference) -> Self {
        Self { power_preference: pref, ..self }
    }

    /// Set maximum GPU memory in bytes (0 = unlimited).
    pub fn with_max_memory(self, bytes: u64) -> Self {
        Self { max_memory: bytes, ..self }
    }

    /// Set compute-shader workgroup size (0 = auto).
    pub fn with_workgroup_size(self, size: u32) -> Self {
        Self { workgroup_size: size, ..self }
    }

    /// Set the minimum batch size before the GPU path is taken.
    pub fn with_min_batch_size(self, size: usize) -> Self {
        Self { min_batch_size: size, ..self }
    }

    /// Set the minimum vector dimension before the GPU path is taken.
    pub fn with_min_dimension(self, dim: usize) -> Self {
        Self { min_dimension: dim, ..self }
    }

    /// Enable or disable profiling and timing.
    pub fn with_profiling(self, enable: bool) -> Self {
        Self { enable_profiling: enable, ..self }
    }

    /// Enable or disable CPU fallback on GPU errors.
    pub fn with_fallback(self, enable: bool) -> Self {
        Self { fallback_to_cpu: enable, ..self }
    }

    /// Select a device by index on multi-GPU systems.
    pub fn with_device(self, index: u32) -> Self {
        Self { device_index: index, ..self }
    }

    /// Whether the GPU should handle a workload of this shape: GPU must not
    /// be disabled and both size thresholds must be met.
    pub fn should_use_gpu(&self, batch_size: usize, dimension: usize) -> bool {
        if self.mode == GpuMode::CpuOnly {
            return false;
        }
        batch_size >= self.min_batch_size && dimension >= self.min_dimension
    }
}
/// GPU memory statistics
///
/// All values are in bytes. NOTE(review): presumably `free == total - used`
/// is maintained by whoever fills this in — confirm against the backend.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GpuMemoryStats {
    /// Total GPU memory (bytes)
    pub total: u64,
    /// Used GPU memory (bytes)
    pub used: u64,
    /// Free GPU memory (bytes)
    pub free: u64,
    /// Peak usage (bytes)
    pub peak: u64,
}
impl GpuMemoryStats {
    /// Fraction of total memory currently in use, as a percentage in
    /// `[0.0, 100.0]`. Returns 0.0 when `total` is zero (unknown capacity).
    pub fn usage_percent(&self) -> f32 {
        match self.total {
            0 => 0.0,
            total => (self.used as f32 / total as f32) * 100.0,
        }
    }
}
/// GPU profiling data
///
/// Aggregate counters for GPU work. NOTE(review): presumably only populated
/// when `GpuConfig::enable_profiling` is set — confirm against the backend.
#[allow(dead_code)]
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GpuProfilingData {
    /// Total operations executed
    pub operations: u64,
    /// Total GPU time (microseconds)
    pub gpu_time_us: u64,
    /// Total CPU time (microseconds)
    pub cpu_time_us: u64,
    /// GPU speedup over CPU
    pub speedup: f32,
    /// Memory transfers (bytes)
    pub memory_transferred: u64,
}
#[cfg(test)]
mod tests {
    use super::*;

    // Defaults should prefer auto-detection with safe CPU fallback.
    #[test]
    fn test_default_config() {
        let config = GpuConfig::default();
        assert_eq!(config.mode, GpuMode::Auto);
        assert_eq!(config.power_preference, PowerPreference::HighPerformance);
        assert!(config.fallback_to_cpu);
    }

    // The GPU path is only taken when BOTH thresholds are met.
    #[test]
    fn test_should_use_gpu() {
        let config = GpuConfig::default()
            .with_min_batch_size(16)
            .with_min_dimension(128);
        assert!(!config.should_use_gpu(8, 384)); // batch too small
        assert!(!config.should_use_gpu(32, 64)); // dimension too small
        assert!(config.should_use_gpu(32, 384)); // both ok
    }

    // CpuOnly must veto the GPU regardless of workload size.
    #[test]
    fn test_cpu_only() {
        let config = GpuConfig::cpu_only();
        assert!(!config.should_use_gpu(1000, 1000));
    }

    // Builder setters should round-trip into the stored fields.
    #[test]
    fn test_builder() {
        let config = GpuConfig::auto()
            .with_mode(GpuMode::WebGpu)
            .with_max_memory(1024 * 1024 * 1024)
            .with_workgroup_size(512)
            .with_profiling(true);
        assert_eq!(config.mode, GpuMode::WebGpu);
        assert_eq!(config.max_memory, 1024 * 1024 * 1024);
        assert_eq!(config.workgroup_size, 512);
        assert!(config.enable_profiling);
    }
}

View File

@@ -0,0 +1,298 @@
//! GPU Acceleration Module for RuVector ONNX Embeddings
//!
//! This module provides optional GPU acceleration using cuda-wasm for:
//! - Pooling operations
//! - Similarity computations
//! - Batch vector operations
//!
//! ## Architecture
//!
//! ```text
//! ┌─────────────────────────────────────────────────────────────────┐
//! │ GPU Acceleration Layer │
//! ├─────────────────────────────────────────────────────────────────┤
//! │ ┌─────────────┐ ┌─────────────┐ ┌─────────────────────┐ │
//! │ │ GpuBackend │ -> │ Shaders │ -> │ WebGPU Runtime │ │
//! │ │ (Trait) │ │ (WGSL) │ │ (wgpu) │ │
//! │ └─────────────┘ └─────────────┘ └─────────────────────┘ │
//! │ │ │ │
//! │ v v │
//! │ ┌─────────────┐ ┌─────────────┐ ┌─────────────────────┐ │
//! │ │ GpuPooler │ │ GpuSimilar │ │ GpuVectorOps │ │
//! │ │ │ │ │ │ │ │
//! │ └─────────────┘ └─────────────┘ └─────────────────────┘ │
//! └─────────────────────────────────────────────────────────────────┘
//! ```
//!
//! ## Feature Flags
//!
//! - `gpu`: Enable GPU acceleration (WebGPU backend)
//! - `cuda-wasm`: Enable CUDA-WASM transpilation support
//!
//! ## Usage
//!
//! ```rust,ignore
//! use ruvector_onnx_embeddings::gpu::{GpuAccelerator, GpuConfig};
//!
//! // Create GPU accelerator with auto-detection
//! let gpu = GpuAccelerator::new(GpuConfig::auto()).await?;
//!
//! // GPU-accelerated similarity search
//! let similarities = gpu.batch_cosine_similarity(&query, &candidates)?;
//!
//! // GPU-accelerated pooling
//! let pooled = gpu.mean_pool(&token_embeddings, &attention_mask)?;
//! ```
mod backend;
mod config;
mod operations;
mod shaders;
#[cfg(test)]
mod tests;
pub use backend::{GpuBackend, GpuDevice, GpuInfo};
pub use config::{GpuConfig, GpuMode, PowerPreference};
pub use operations::{
GpuPooler, GpuSimilarity, GpuVectorOps,
batch_cosine_similarity_gpu, batch_dot_product_gpu, batch_euclidean_gpu,
};
pub use shaders::ShaderRegistry;
use crate::Result;
use std::sync::Arc;
/// GPU Accelerator - Main entry point for GPU operations
///
/// Provides unified access to GPU-accelerated operations with automatic
/// fallback to CPU when GPU is unavailable.
pub struct GpuAccelerator {
    // Shared backend handle; also cloned into each component below when a
    // GPU feature is enabled (see `GpuAccelerator::new`).
    backend: Arc<dyn GpuBackend>,
    // Configuration this accelerator was created with.
    config: GpuConfig,
    // Pooling operations (mean / CLS / max).
    pooler: GpuPooler,
    // Batch similarity operations (cosine / dot / euclidean / top-k).
    similarity: GpuSimilarity,
    // Vector math (normalize / matmul / add / scale).
    vector_ops: GpuVectorOps,
}
impl GpuAccelerator {
    /// Build an accelerator for `config`: initialize the backend, construct
    /// the operation components, and wire the shared backend handle into
    /// each of them.
    pub async fn new(config: GpuConfig) -> Result<Self> {
        let backend: Arc<dyn GpuBackend> = Arc::from(backend::create_backend(&config).await?);
        let registry = ShaderRegistry::new();

        let mut pooler = GpuPooler::new(backend.as_ref(), &registry)?;
        let mut similarity = GpuSimilarity::new(backend.as_ref(), &registry)?;
        let mut vector_ops = GpuVectorOps::new(backend.as_ref(), &registry)?;

        // Hand each component a shared backend handle for kernel dispatch.
        #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
        {
            pooler.set_backend(Arc::clone(&backend));
            similarity.set_backend(Arc::clone(&backend));
            vector_ops.set_backend(Arc::clone(&backend));
        }

        Ok(Self { backend, config, pooler, similarity, vector_ops })
    }

    /// Shorthand for `Self::new(GpuConfig::auto())`.
    pub async fn auto() -> Result<Self> {
        Self::new(GpuConfig::auto()).await
    }

    /// Whether a usable GPU backend was found.
    pub fn is_available(&self) -> bool {
        self.backend.is_available()
    }

    /// Information about the selected GPU device.
    pub fn device_info(&self) -> GpuInfo {
        self.backend.device_info()
    }

    /// The configuration this accelerator was built with.
    pub fn config(&self) -> &GpuConfig {
        &self.config
    }

    // -------- Pooling (delegates to GpuPooler) --------

    /// Mean pooling over token embeddings (GPU-accelerated).
    pub fn mean_pool(
        &self,
        token_embeddings: &[f32],
        attention_mask: &[i64],
        batch_size: usize,
        seq_length: usize,
        hidden_size: usize,
    ) -> Result<Vec<f32>> {
        self.pooler
            .mean_pool(token_embeddings, attention_mask, batch_size, seq_length, hidden_size)
    }

    /// CLS token pooling (GPU-accelerated).
    pub fn cls_pool(
        &self,
        token_embeddings: &[f32],
        batch_size: usize,
        hidden_size: usize,
    ) -> Result<Vec<f32>> {
        self.pooler.cls_pool(token_embeddings, batch_size, hidden_size)
    }

    /// Max pooling over token embeddings (GPU-accelerated).
    pub fn max_pool(
        &self,
        token_embeddings: &[f32],
        attention_mask: &[i64],
        batch_size: usize,
        seq_length: usize,
        hidden_size: usize,
    ) -> Result<Vec<f32>> {
        self.pooler
            .max_pool(token_embeddings, attention_mask, batch_size, seq_length, hidden_size)
    }

    // -------- Similarity (delegates to GpuSimilarity) --------

    /// Cosine similarity of `query` against every candidate (GPU-accelerated).
    pub fn batch_cosine_similarity(
        &self,
        query: &[f32],
        candidates: &[&[f32]],
    ) -> Result<Vec<f32>> {
        self.similarity.batch_cosine(query, candidates)
    }

    /// Dot product of `query` against every candidate (GPU-accelerated).
    pub fn batch_dot_product(
        &self,
        query: &[f32],
        candidates: &[&[f32]],
    ) -> Result<Vec<f32>> {
        self.similarity.batch_dot_product(query, candidates)
    }

    /// Euclidean distance of `query` to every candidate (GPU-accelerated).
    pub fn batch_euclidean_distance(
        &self,
        query: &[f32],
        candidates: &[&[f32]],
    ) -> Result<Vec<f32>> {
        self.similarity.batch_euclidean(query, candidates)
    }

    /// Indices and scores of the `k` candidates most similar to `query`
    /// (GPU-accelerated).
    pub fn top_k_similar(
        &self,
        query: &[f32],
        candidates: &[&[f32]],
        k: usize,
    ) -> Result<Vec<(usize, f32)>> {
        self.similarity.top_k(query, candidates, k)
    }

    // -------- Vector operations (delegates to GpuVectorOps) --------

    /// L2 normalize a flat batch of vectors in place (GPU-accelerated).
    pub fn normalize_batch(&self, vectors: &mut [f32], dimension: usize) -> Result<()> {
        self.vector_ops.normalize_batch(vectors, dimension)
    }

    /// Matrix-vector multiplication (GPU-accelerated).
    pub fn matmul(
        &self,
        matrix: &[f32],
        vector: &[f32],
        rows: usize,
        cols: usize,
    ) -> Result<Vec<f32>> {
        self.vector_ops.matmul(matrix, vector, rows, cols)
    }

    /// Element-wise sum of two equally-sized flat batches (GPU-accelerated).
    pub fn batch_add(&self, a: &[f32], b: &[f32]) -> Result<Vec<f32>> {
        self.vector_ops.batch_add(a, b)
    }

    /// Scale every element of a flat batch in place (GPU-accelerated).
    pub fn batch_scale(&self, vectors: &mut [f32], scale: f32) -> Result<()> {
        self.vector_ops.batch_scale(vectors, scale)
    }
}
/// Convenience function to check GPU availability without creating accelerator
///
/// Cheaper than `GpuAccelerator::new` when only a yes/no answer is needed.
pub async fn is_gpu_available() -> bool {
    backend::probe_gpu().await
}
/// Get GPU device info without full initialization
///
/// Returns `None` when no usable GPU device is found.
pub async fn get_gpu_info() -> Option<GpuInfo> {
    backend::get_device_info().await
}
/// Fallback wrapper that tries GPU first, then CPU
pub struct HybridAccelerator {
    // Some(_) only when GPU initialization succeeded at construction time.
    gpu: Option<GpuAccelerator>,
    // User-controllable switch; never true while `gpu` is None (see
    // `new` / `enable_gpu`, which derive it from `gpu.is_some()`).
    use_gpu: bool,
}
impl HybridAccelerator {
/// Create hybrid accelerator with GPU if available
pub async fn new() -> Self {
let gpu = GpuAccelerator::auto().await.ok();
let use_gpu = gpu.is_some();
Self { gpu, use_gpu }
}
/// Check if GPU is being used
pub fn using_gpu(&self) -> bool {
self.use_gpu && self.gpu.is_some()
}
/// Disable GPU (use CPU only)
pub fn disable_gpu(&mut self) {
self.use_gpu = false;
}
/// Enable GPU if available
pub fn enable_gpu(&mut self) {
self.use_gpu = self.gpu.is_some();
}
/// Batch cosine similarity with automatic backend selection
pub fn batch_cosine_similarity(
&self,
query: &[f32],
candidates: &[Vec<f32>],
) -> Vec<f32> {
if self.use_gpu {
if let Some(ref gpu) = self.gpu {
let refs: Vec<&[f32]> = candidates.iter().map(|v| v.as_slice()).collect();
if let Ok(result) = gpu.batch_cosine_similarity(query, &refs) {
return result;
}
}
}
// CPU fallback
crate::pooling::batch_cosine_similarity(query, candidates)
}
}

View File

@@ -0,0 +1,934 @@
//! GPU-Accelerated Operations
//!
//! High-level GPU operations for embeddings with automatic fallback to CPU.
use crate::{EmbeddingError, Result};
use super::backend::{GpuBackend, BufferUsage};
use super::shaders::ShaderRegistry;
use rayon::prelude::*;
use std::sync::Arc;
#[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
use bytemuck;
// ==================== GPU Pooler ====================
/// GPU-accelerated pooling operations
///
/// Decides GPU eligibility at construction time; the backend handle used for
/// actual dispatch is installed later via `set_backend` (only present when a
/// GPU feature is compiled in).
pub struct GpuPooler {
    // Probed at construction: device available AND supports compute.
    use_gpu: bool,
    // None until `GpuAccelerator::new` wires in the shared backend handle.
    #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
    backend: Option<Arc<dyn GpuBackend>>,
}
impl GpuPooler {
    /// Create a new GPU pooler.
    ///
    /// GPU dispatch additionally requires a backend handle, installed later
    /// via `set_backend`; until then all calls take the CPU path even when
    /// `use_gpu` is true.
    pub fn new(backend: &dyn GpuBackend, _shaders: &ShaderRegistry) -> Result<Self> {
        // Only consider GPU when the probed device exists and supports compute.
        let use_gpu = backend.is_available() && backend.device_info().supports_compute;
        Ok(Self {
            use_gpu,
            #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
            backend: None, // Will be set by GpuAccelerator
        })
    }

    /// Set the backend for GPU operations.
    #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
    pub fn set_backend(&mut self, backend: Arc<dyn GpuBackend>) {
        self.backend = Some(backend);
    }

    /// Mean pooling over tokens whose attention mask is 1.
    ///
    /// `token_embeddings` is row-major `[batch, seq, hidden]`; returns
    /// `batch_size * hidden_size` floats. Falls back to CPU for small
    /// batches or when no backend is wired in.
    pub fn mean_pool(
        &self,
        token_embeddings: &[f32],
        attention_mask: &[i64],
        batch_size: usize,
        seq_length: usize,
        hidden_size: usize,
    ) -> Result<Vec<f32>> {
        // GPU only pays off above a minimum batch size (transfer overhead).
        #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
        if self.use_gpu && batch_size >= 8 && self.backend.is_some() {
            return self.pool_gpu(
                token_embeddings,
                attention_mask,
                batch_size,
                seq_length,
                hidden_size,
                super::shaders::MEAN_POOL_SHADER,
                "mean_pool",
            );
        }
        Ok(self.mean_pool_cpu(token_embeddings, attention_mask, batch_size, seq_length, hidden_size))
    }

    /// CLS pooling: take the first token's embedding for each batch row.
    ///
    /// Always runs on CPU — it is a plain copy, so a GPU round-trip would
    /// only add transfer cost.
    pub fn cls_pool(
        &self,
        token_embeddings: &[f32],
        batch_size: usize,
        hidden_size: usize,
    ) -> Result<Vec<f32>> {
        Ok(self.cls_pool_cpu(token_embeddings, batch_size, hidden_size))
    }

    /// Max pooling over tokens whose attention mask is 1.
    ///
    /// Same layout and fallback rules as [`Self::mean_pool`].
    pub fn max_pool(
        &self,
        token_embeddings: &[f32],
        attention_mask: &[i64],
        batch_size: usize,
        seq_length: usize,
        hidden_size: usize,
    ) -> Result<Vec<f32>> {
        #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
        if self.use_gpu && batch_size >= 8 && self.backend.is_some() {
            return self.pool_gpu(
                token_embeddings,
                attention_mask,
                batch_size,
                seq_length,
                hidden_size,
                super::shaders::MAX_POOL_SHADER,
                "max_pool",
            );
        }
        Ok(self.max_pool_cpu(token_embeddings, attention_mask, batch_size, seq_length, hidden_size))
    }

    /// Shared GPU dispatch path for mean/max pooling — the two kernels take
    /// identical bindings (tokens, mask, output, params), so the previously
    /// duplicated `mean_pool_gpu` / `max_pool_gpu` bodies are unified here.
    ///
    /// NOTE(review): as in the original per-kernel copies, buffers created
    /// here are not released if a later backend call fails via `?`; this is
    /// only safe if the backend reclaims resources on drop — TODO confirm.
    #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
    #[allow(clippy::too_many_arguments)]
    fn pool_gpu(
        &self,
        token_embeddings: &[f32],
        attention_mask: &[i64],
        batch_size: usize,
        seq_length: usize,
        hidden_size: usize,
        shader: &str,
        entry_point: &str,
    ) -> Result<Vec<f32>> {
        let backend = self.backend.as_ref().ok_or_else(|| {
            EmbeddingError::GpuOperationFailed {
                operation: entry_point.to_string(),
                reason: "Backend not initialized".to_string(),
            }
        })?;
        // Storage buffers: f32 = 4 bytes, i64 mask = 8 bytes per element.
        let token_buf = backend.create_buffer(
            (token_embeddings.len() * 4) as u64,
            BufferUsage::Storage,
        )?;
        let mask_buf = backend.create_buffer(
            (attention_mask.len() * 8) as u64,
            BufferUsage::Storage,
        )?;
        let output_buf = backend.create_buffer(
            (batch_size * hidden_size * 4) as u64,
            BufferUsage::Storage,
        )?;
        // Uniform params (batch_size, seq_length, hidden_size); the buffer
        // is 16 bytes for uniform alignment even though only 12 are written.
        let params: [u32; 3] = [batch_size as u32, seq_length as u32, hidden_size as u32];
        let params_buf = backend.create_buffer(16, BufferUsage::Uniform)?;
        backend.write_buffer(&params_buf, bytemuck::cast_slice(&params))?;
        // Upload input data.
        backend.write_buffer(&token_buf, bytemuck::cast_slice(token_embeddings))?;
        backend.write_buffer(&mask_buf, bytemuck::cast_slice(attention_mask))?;
        let pipeline = backend.create_pipeline(shader, entry_point, [64, 1, 1])?;
        // One invocation per output element, in 64-wide workgroups.
        let total_outputs = batch_size * hidden_size;
        let workgroups = [total_outputs.div_ceil(64) as u32, 1, 1];
        backend.dispatch(&pipeline, &[&token_buf, &mask_buf, &output_buf, &params_buf], workgroups)?;
        backend.sync()?;
        // Read back and reinterpret the raw bytes as f32.
        let output_bytes = backend.read_buffer(&output_buf, (batch_size * hidden_size * 4) as u64)?;
        let output: Vec<f32> = bytemuck::cast_slice(&output_bytes).to_vec();
        // Cleanup.
        backend.release_buffer(token_buf)?;
        backend.release_buffer(mask_buf)?;
        backend.release_buffer(output_buf)?;
        backend.release_buffer(params_buf)?;
        backend.release_pipeline(pipeline)?;
        Ok(output)
    }

    // CPU implementations (rayon-parallel over batch rows).

    /// CPU mean pooling: per batch row, average the embeddings of tokens
    /// whose mask value is exactly 1; rows with no valid tokens stay zero.
    fn mean_pool_cpu(
        &self,
        token_embeddings: &[f32],
        attention_mask: &[i64],
        batch_size: usize,
        seq_length: usize,
        hidden_size: usize,
    ) -> Vec<f32> {
        let mut output = vec![0.0f32; batch_size * hidden_size];
        output
            .par_chunks_mut(hidden_size)
            .enumerate()
            .for_each(|(batch_idx, out_chunk)| {
                let tokens_base = batch_idx * seq_length * hidden_size;
                let mask_base = batch_idx * seq_length;
                let mut count = 0.0f32;
                for seq_idx in 0..seq_length {
                    if attention_mask[mask_base + seq_idx] == 1 {
                        let start = tokens_base + seq_idx * hidden_size;
                        for (j, out_val) in out_chunk.iter_mut().enumerate() {
                            *out_val += token_embeddings[start + j];
                        }
                        count += 1.0;
                    }
                }
                // Avoid division by zero for fully-masked rows.
                if count > 0.0 {
                    for val in out_chunk.iter_mut() {
                        *val /= count;
                    }
                }
            });
        output
    }

    /// CPU CLS pooling: copy the first token's embedding from each row.
    /// Derives `seq_length` from the slice length, so `token_embeddings.len()`
    /// must be an exact multiple of `batch_size * hidden_size`.
    fn cls_pool_cpu(
        &self,
        token_embeddings: &[f32],
        batch_size: usize,
        hidden_size: usize,
    ) -> Vec<f32> {
        let seq_length = token_embeddings.len() / (batch_size * hidden_size);
        let mut output = vec![0.0f32; batch_size * hidden_size];
        for batch_idx in 0..batch_size {
            let src_start = batch_idx * seq_length * hidden_size;
            let dst_start = batch_idx * hidden_size;
            output[dst_start..dst_start + hidden_size]
                .copy_from_slice(&token_embeddings[src_start..src_start + hidden_size]);
        }
        output
    }

    /// CPU max pooling: per batch row, element-wise maximum over tokens
    /// whose mask value is exactly 1; fully-masked positions become 0.
    fn max_pool_cpu(
        &self,
        token_embeddings: &[f32],
        attention_mask: &[i64],
        batch_size: usize,
        seq_length: usize,
        hidden_size: usize,
    ) -> Vec<f32> {
        let mut output = vec![f32::NEG_INFINITY; batch_size * hidden_size];
        output
            .par_chunks_mut(hidden_size)
            .enumerate()
            .for_each(|(batch_idx, out_chunk)| {
                let tokens_base = batch_idx * seq_length * hidden_size;
                let mask_base = batch_idx * seq_length;
                for seq_idx in 0..seq_length {
                    if attention_mask[mask_base + seq_idx] == 1 {
                        let start = tokens_base + seq_idx * hidden_size;
                        for (j, out_val) in out_chunk.iter_mut().enumerate() {
                            let val = token_embeddings[start + j];
                            if val > *out_val {
                                *out_val = val;
                            }
                        }
                    }
                }
                // Replace untouched -inf sentinels (fully-masked rows) with 0.
                for val in out_chunk.iter_mut() {
                    if val.is_infinite() {
                        *val = 0.0;
                    }
                }
            });
        output
    }
}
// ==================== GPU Similarity ====================
/// GPU-accelerated similarity computations
///
/// GPU is used only above `min_candidates` (transfer overhead dominates for
/// small candidate sets); the backend handle is installed later via
/// `set_backend` when a GPU feature is compiled in.
pub struct GpuSimilarity {
    // Probed at construction: device available AND supports compute.
    use_gpu: bool,
    // Candidate-count threshold below which the CPU path is always taken.
    min_candidates: usize,
    // None until `GpuAccelerator::new` wires in the shared backend handle.
    #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
    backend: Option<Arc<dyn GpuBackend>>,
}
impl GpuSimilarity {
    /// Create a new GPU similarity calculator.
    ///
    /// GPU dispatch additionally requires a backend handle, installed later
    /// via `set_backend`; until then all calls take the CPU path.
    pub fn new(backend: &dyn GpuBackend, _shaders: &ShaderRegistry) -> Result<Self> {
        Ok(Self {
            use_gpu: backend.is_available() && backend.device_info().supports_compute,
            // Below this candidate count, upload cost outweighs the GPU win.
            min_candidates: 64, // Minimum candidates to use GPU
            #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
            backend: None,
        })
    }

    /// Set the backend for GPU operations.
    #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
    pub fn set_backend(&mut self, backend: Arc<dyn GpuBackend>) {
        self.backend = Some(backend);
    }

    /// Cosine similarity of `query` against each candidate.
    pub fn batch_cosine(&self, query: &[f32], candidates: &[&[f32]]) -> Result<Vec<f32>> {
        #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
        if self.use_gpu && candidates.len() >= self.min_candidates && self.backend.is_some() {
            return self.similarity_gpu(
                query,
                candidates,
                super::shaders::BATCH_COSINE_SIMILARITY_SHADER,
                "batch_cosine_similarity",
                "batch_cosine",
            );
        }
        Ok(self.batch_cosine_cpu(query, candidates))
    }

    /// Dot product of `query` against each candidate.
    pub fn batch_dot_product(&self, query: &[f32], candidates: &[&[f32]]) -> Result<Vec<f32>> {
        #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
        if self.use_gpu && candidates.len() >= self.min_candidates && self.backend.is_some() {
            return self.similarity_gpu(
                query,
                candidates,
                super::shaders::DOT_PRODUCT_SHADER,
                "dot_product",
                "batch_dot_product",
            );
        }
        Ok(self.batch_dot_product_cpu(query, candidates))
    }

    /// Euclidean distance of `query` to each candidate.
    pub fn batch_euclidean(&self, query: &[f32], candidates: &[&[f32]]) -> Result<Vec<f32>> {
        #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
        if self.use_gpu && candidates.len() >= self.min_candidates && self.backend.is_some() {
            return self.similarity_gpu(
                query,
                candidates,
                super::shaders::EUCLIDEAN_DISTANCE_SHADER,
                "euclidean_distance",
                "batch_euclidean",
            );
        }
        Ok(self.batch_euclidean_cpu(query, candidates))
    }

    /// Indices and cosine scores of the `k` candidates most similar to
    /// `query`, best first.
    pub fn top_k(&self, query: &[f32], candidates: &[&[f32]], k: usize) -> Result<Vec<(usize, f32)>> {
        let similarities = self.batch_cosine(query, candidates)?;
        let mut indexed: Vec<(usize, f32)> = similarities.into_iter().enumerate().collect();
        // Fix: total_cmp provides a true total order over f32. The previous
        // partial_cmp(..).unwrap_or(Equal) comparator is not transitive when
        // NaN is present, which violates sort_by's total-order contract and
        // can yield inconsistent orderings (or a panic on newer std).
        indexed.sort_by(|a, b| b.1.total_cmp(&a.1));
        indexed.truncate(k);
        Ok(indexed)
    }

    /// Shared GPU dispatch path for the three batch-similarity kernels —
    /// they take identical bindings (query, flattened candidates, output,
    /// params), so the previously duplicated `*_gpu` bodies are unified here.
    /// `op` is the operation name reported in errors.
    ///
    /// NOTE(review): as in the original per-kernel copies, buffers created
    /// here are not released if a later backend call fails via `?`; this is
    /// only safe if the backend reclaims resources on drop — TODO confirm.
    #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
    fn similarity_gpu(
        &self,
        query: &[f32],
        candidates: &[&[f32]],
        shader: &str,
        entry_point: &str,
        op: &str,
    ) -> Result<Vec<f32>> {
        let backend = self.backend.as_ref().ok_or_else(|| {
            EmbeddingError::GpuOperationFailed {
                operation: op.to_string(),
                reason: "Backend not initialized".to_string(),
            }
        })?;
        let dimension = query.len();
        let num_candidates = candidates.len();
        // Candidates must be contiguous for a single storage-buffer upload.
        let candidates_flat: Vec<f32> = candidates.iter().flat_map(|c| c.iter().copied()).collect();
        // Storage buffers: f32 = 4 bytes per element.
        let query_buf = backend.create_buffer((dimension * 4) as u64, BufferUsage::Storage)?;
        let candidates_buf =
            backend.create_buffer((candidates_flat.len() * 4) as u64, BufferUsage::Storage)?;
        let output_buf = backend.create_buffer((num_candidates * 4) as u64, BufferUsage::Storage)?;
        // Uniform params: (dimension, num_candidates).
        let params: [u32; 2] = [dimension as u32, num_candidates as u32];
        let params_buf = backend.create_buffer(8, BufferUsage::Uniform)?;
        backend.write_buffer(&params_buf, bytemuck::cast_slice(&params))?;
        // Upload input data.
        backend.write_buffer(&query_buf, bytemuck::cast_slice(query))?;
        backend.write_buffer(&candidates_buf, bytemuck::cast_slice(&candidates_flat))?;
        let pipeline = backend.create_pipeline(shader, entry_point, [256, 1, 1])?;
        // One invocation per candidate, in 256-wide workgroups.
        let workgroups = [num_candidates.div_ceil(256) as u32, 1, 1];
        backend.dispatch(
            &pipeline,
            &[&query_buf, &candidates_buf, &output_buf, &params_buf],
            workgroups,
        )?;
        backend.sync()?;
        // Read back and reinterpret the raw bytes as f32.
        let output_bytes = backend.read_buffer(&output_buf, (num_candidates * 4) as u64)?;
        let output: Vec<f32> = bytemuck::cast_slice(&output_bytes).to_vec();
        // Cleanup.
        backend.release_buffer(query_buf)?;
        backend.release_buffer(candidates_buf)?;
        backend.release_buffer(output_buf)?;
        backend.release_buffer(params_buf)?;
        backend.release_pipeline(pipeline)?;
        Ok(output)
    }

    // CPU fallbacks (rayon-parallel over candidates).

    fn batch_cosine_cpu(&self, query: &[f32], candidates: &[&[f32]]) -> Vec<f32> {
        candidates
            .par_iter()
            .map(|c| cosine_similarity_cpu(query, c))
            .collect()
    }

    fn batch_dot_product_cpu(&self, query: &[f32], candidates: &[&[f32]]) -> Vec<f32> {
        candidates
            .par_iter()
            .map(|c| dot_product_cpu(query, c))
            .collect()
    }

    fn batch_euclidean_cpu(&self, query: &[f32], candidates: &[&[f32]]) -> Vec<f32> {
        candidates
            .par_iter()
            .map(|c| euclidean_distance_cpu(query, c))
            .collect()
    }
}
// ==================== GPU Vector Operations ====================
/// GPU-accelerated vector operations
///
/// Normalization, matmul, add and scale over flat f32 batches, with CPU
/// (rayon) fallbacks; the backend handle is installed later via
/// `set_backend` when a GPU feature is compiled in.
pub struct GpuVectorOps {
    // Probed at construction: device available AND supports compute.
    use_gpu: bool,
    // None until `GpuAccelerator::new` wires in the shared backend handle.
    #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
    backend: Option<Arc<dyn GpuBackend>>,
}
impl GpuVectorOps {
    /// Create new GPU vector operations
    ///
    /// `use_gpu` is derived from the probed backend, but the handle itself
    /// is not stored here: `set_backend` must be called before any GPU path
    /// can run (every GPU branch also checks `self.backend.is_some()`).
    pub fn new(backend: &dyn GpuBackend, _shaders: &ShaderRegistry) -> Result<Self> {
        Ok(Self {
            use_gpu: backend.is_available() && backend.device_info().supports_compute,
            #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
            backend: None,
        })
    }
    /// Set the backend for GPU operations
    #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
    pub fn set_backend(&mut self, backend: Arc<dyn GpuBackend>) {
        self.backend = Some(backend);
    }
    /// L2 normalize batch of vectors
    ///
    /// `vectors` is a flat row-major buffer of `vectors.len() / dimension`
    /// vectors, normalized in place. The GPU path is taken only for at
    /// least 64 vectors (heuristic threshold) with a backend attached.
    pub fn normalize_batch(&self, vectors: &mut [f32], dimension: usize) -> Result<()> {
        #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
        if self.use_gpu && vectors.len() >= dimension * 64 && self.backend.is_some() {
            return self.normalize_batch_gpu(vectors, dimension);
        }
        self.normalize_batch_cpu(vectors, dimension);
        Ok(())
    }
    /// Matrix-vector multiplication
    ///
    /// `matrix` is row-major `rows x cols`; returns `rows` outputs. The GPU
    /// path requires at least 64 rows and an attached backend.
    pub fn matmul(&self, matrix: &[f32], vector: &[f32], rows: usize, cols: usize) -> Result<Vec<f32>> {
        #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
        if self.use_gpu && rows >= 64 && self.backend.is_some() {
            return self.matmul_gpu(matrix, vector, rows, cols);
        }
        Ok(self.matmul_cpu(matrix, vector, rows, cols))
    }
    /// Batch vector addition
    ///
    /// Element-wise `a + b`; errors when lengths differ. The GPU path
    /// requires at least 1024 elements and an attached backend.
    pub fn batch_add(&self, a: &[f32], b: &[f32]) -> Result<Vec<f32>> {
        if a.len() != b.len() {
            return Err(EmbeddingError::dimension_mismatch(a.len(), b.len()));
        }
        #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
        if self.use_gpu && a.len() >= 1024 && self.backend.is_some() {
            return self.batch_add_gpu(a, b);
        }
        Ok(a.par_iter().zip(b.par_iter()).map(|(x, y)| x + y).collect())
    }
    /// Batch vector scaling
    ///
    /// NOTE(review): always runs on the CPU even though a `vector_scale`
    /// shader exists in the registry — confirm whether a GPU path was
    /// intended here.
    pub fn batch_scale(&self, vectors: &mut [f32], scale: f32) -> Result<()> {
        vectors.par_iter_mut().for_each(|v| *v *= scale);
        Ok(())
    }
    // GPU implementations
    /// GPU L2 normalization: upload the batch, run the `l2_normalize`
    /// shader (one thread per vector), and copy the result back in place.
    #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
    fn normalize_batch_gpu(&self, vectors: &mut [f32], dimension: usize) -> Result<()> {
        let backend = self.backend.as_ref().ok_or_else(|| {
            EmbeddingError::GpuOperationFailed {
                operation: "normalize_batch".to_string(),
                reason: "Backend not initialized".to_string(),
            }
        })?;
        let num_vectors = vectors.len() / dimension;
        // Create buffers (input, dummy, output, params)
        // The 4-byte dummy buffer fills binding slot 1, which the shader
        // declares but never reads (all kernels share a 4-binding layout).
        let input_buf = backend.create_buffer((vectors.len() * 4) as u64, BufferUsage::Storage)?;
        let dummy_buf = backend.create_buffer(4, BufferUsage::Storage)?;
        let output_buf = backend.create_buffer((vectors.len() * 4) as u64, BufferUsage::Storage)?;
        // Create params buffer (dimension, num_vectors)
        let params: [u32; 2] = [dimension as u32, num_vectors as u32];
        let params_buf = backend.create_buffer(8, BufferUsage::Uniform)?;
        backend.write_buffer(&params_buf, bytemuck::cast_slice(&params))?;
        // Write input data
        backend.write_buffer(&input_buf, bytemuck::cast_slice(vectors))?;
        // Create pipeline
        let shader = super::shaders::L2_NORMALIZE_SHADER;
        let pipeline = backend.create_pipeline(shader, "l2_normalize", [256, 1, 1])?;
        // Dispatch with 4 bindings; one workgroup covers 256 vectors.
        let workgroups = [num_vectors.div_ceil(256) as u32, 1, 1];
        backend.dispatch(&pipeline, &[&input_buf, &dummy_buf, &output_buf, &params_buf], workgroups)?;
        backend.sync()?;
        // Read output
        let output_bytes = backend.read_buffer(&output_buf, (vectors.len() * 4) as u64)?;
        let output: &[f32] = bytemuck::cast_slice(&output_bytes);
        vectors.copy_from_slice(output);
        // Cleanup
        backend.release_buffer(input_buf)?;
        backend.release_buffer(dummy_buf)?;
        backend.release_buffer(output_buf)?;
        backend.release_buffer(params_buf)?;
        backend.release_pipeline(pipeline)?;
        Ok(())
    }
    /// GPU row-major matrix-vector multiply via the `matmul` shader.
    #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
    fn matmul_gpu(&self, matrix: &[f32], vector: &[f32], rows: usize, cols: usize) -> Result<Vec<f32>> {
        let backend = self.backend.as_ref().ok_or_else(|| {
            EmbeddingError::GpuOperationFailed {
                operation: "matmul".to_string(),
                reason: "Backend not initialized".to_string(),
            }
        })?;
        // Create buffers
        let mat_buf = backend.create_buffer((matrix.len() * 4) as u64, BufferUsage::Storage)?;
        let vec_buf = backend.create_buffer((vector.len() * 4) as u64, BufferUsage::Storage)?;
        let output_buf = backend.create_buffer((rows * 4) as u64, BufferUsage::Storage)?;
        // Create params buffer (rows, cols)
        let params: [u32; 2] = [rows as u32, cols as u32];
        let params_buf = backend.create_buffer(8, BufferUsage::Uniform)?;
        backend.write_buffer(&params_buf, bytemuck::cast_slice(&params))?;
        // Write input data
        backend.write_buffer(&mat_buf, bytemuck::cast_slice(matrix))?;
        backend.write_buffer(&vec_buf, bytemuck::cast_slice(vector))?;
        // Create pipeline
        let shader = super::shaders::MATMUL_SHADER;
        let pipeline = backend.create_pipeline(shader, "matmul", [16, 16, 1])?;
        // Dispatch with params buffer as 4th binding
        // NOTE(review): the matmul shader only consumes gid.x, so the
        // 16-wide y dimension of each workgroup recomputes the same rows —
        // wasteful but result-identical; confirm the intended dispatch shape.
        let workgroups = [rows.div_ceil(16) as u32, 1, 1];
        backend.dispatch(&pipeline, &[&mat_buf, &vec_buf, &output_buf, &params_buf], workgroups)?;
        backend.sync()?;
        // Read output
        let output_bytes = backend.read_buffer(&output_buf, (rows * 4) as u64)?;
        let output: Vec<f32> = bytemuck::cast_slice(&output_bytes).to_vec();
        // Cleanup
        backend.release_buffer(mat_buf)?;
        backend.release_buffer(vec_buf)?;
        backend.release_buffer(output_buf)?;
        backend.release_buffer(params_buf)?;
        backend.release_pipeline(pipeline)?;
        Ok(output)
    }
    /// GPU element-wise addition via the `vector_add` shader.
    #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
    fn batch_add_gpu(&self, a: &[f32], b: &[f32]) -> Result<Vec<f32>> {
        let backend = self.backend.as_ref().ok_or_else(|| {
            EmbeddingError::GpuOperationFailed {
                operation: "batch_add".to_string(),
                reason: "Backend not initialized".to_string(),
            }
        })?;
        // Create buffers
        let buf_a = backend.create_buffer((a.len() * 4) as u64, BufferUsage::Storage)?;
        let buf_b = backend.create_buffer((b.len() * 4) as u64, BufferUsage::Storage)?;
        let output_buf = backend.create_buffer((a.len() * 4) as u64, BufferUsage::Storage)?;
        // Create params buffer (length)
        let params: [u32; 1] = [a.len() as u32];
        let params_buf = backend.create_buffer(4, BufferUsage::Uniform)?;
        backend.write_buffer(&params_buf, bytemuck::cast_slice(&params))?;
        // Write input data
        backend.write_buffer(&buf_a, bytemuck::cast_slice(a))?;
        backend.write_buffer(&buf_b, bytemuck::cast_slice(b))?;
        // Create pipeline
        let shader = super::shaders::VECTOR_ADD_SHADER;
        let pipeline = backend.create_pipeline(shader, "vector_add", [256, 1, 1])?;
        // Dispatch with params buffer as 4th binding; one thread per element.
        let workgroups = [a.len().div_ceil(256) as u32, 1, 1];
        backend.dispatch(&pipeline, &[&buf_a, &buf_b, &output_buf, &params_buf], workgroups)?;
        backend.sync()?;
        // Read output
        let output_bytes = backend.read_buffer(&output_buf, (a.len() * 4) as u64)?;
        let output: Vec<f32> = bytemuck::cast_slice(&output_bytes).to_vec();
        // Cleanup
        backend.release_buffer(buf_a)?;
        backend.release_buffer(buf_b)?;
        backend.release_buffer(output_buf)?;
        backend.release_buffer(params_buf)?;
        backend.release_pipeline(pipeline)?;
        Ok(output)
    }
    // CPU implementations
    /// Rayon-parallel in-place L2 normalization; vectors with near-zero
    /// norm (<= 1e-12) are left unchanged.
    fn normalize_batch_cpu(&self, vectors: &mut [f32], dimension: usize) {
        vectors
            .par_chunks_mut(dimension)
            .for_each(|chunk| {
                let norm: f32 = chunk.iter().map(|x| x * x).sum::<f32>().sqrt();
                if norm > 1e-12 {
                    for val in chunk.iter_mut() {
                        *val /= norm;
                    }
                }
            });
    }
    /// Rayon-parallel row-major matrix-vector multiply (one task per row).
    fn matmul_cpu(&self, matrix: &[f32], vector: &[f32], rows: usize, cols: usize) -> Vec<f32> {
        let mut result = vec![0.0f32; rows];
        result
            .par_iter_mut()
            .enumerate()
            .for_each(|(row, out)| {
                let row_start = row * cols;
                *out = matrix[row_start..row_start + cols]
                    .iter()
                    .zip(vector.iter())
                    .map(|(m, v)| m * v)
                    .sum();
            });
        result
    }
}
// ==================== Standalone Functions ====================
/// Batch cosine similarity (GPU-accelerated if available)
///
/// NOTE(review): despite the `_gpu` suffix this standalone helper always
/// runs on the CPU via rayon; use the struct-based API for a GPU path.
pub fn batch_cosine_similarity_gpu(query: &[f32], candidates: &[&[f32]]) -> Vec<f32> {
    candidates
        .par_iter()
        .map(|candidate| cosine_similarity_cpu(query, candidate))
        .collect()
}
/// Batch dot product (GPU-accelerated if available)
///
/// NOTE(review): currently CPU-only (rayon) despite the `_gpu` suffix.
pub fn batch_dot_product_gpu(query: &[f32], candidates: &[&[f32]]) -> Vec<f32> {
    candidates
        .par_iter()
        .map(|candidate| dot_product_cpu(query, candidate))
        .collect()
}
/// Batch Euclidean distance (GPU-accelerated if available)
///
/// NOTE(review): currently CPU-only (rayon) despite the `_gpu` suffix.
pub fn batch_euclidean_gpu(query: &[f32], candidates: &[&[f32]]) -> Vec<f32> {
    candidates
        .par_iter()
        .map(|candidate| euclidean_distance_cpu(query, candidate))
        .collect()
}
// ==================== CPU Helper Functions ====================
/// Cosine similarity of `a` and `b` on the CPU.
///
/// Norms are computed over the full length of each slice (the dot product
/// runs over the shorter of the two). Returns 0.0 when either vector has
/// (near-)zero magnitude, avoiding a division by zero.
#[inline]
fn cosine_similarity_cpu(a: &[f32], b: &[f32]) -> f32 {
    let dot = a.iter().zip(b.iter()).fold(0.0f32, |acc, (x, y)| acc + x * y);
    let norm_a = a.iter().fold(0.0f32, |acc, x| acc + x * x).sqrt();
    let norm_b = b.iter().fold(0.0f32, |acc, x| acc + x * x).sqrt();
    if norm_a <= 1e-12 || norm_b <= 1e-12 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}
/// Dot product of `a` and `b` over their common length.
#[inline]
fn dot_product_cpu(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b.iter()).fold(0.0f32, |acc, (x, y)| acc + x * y)
}
/// Euclidean (L2) distance between `a` and `b` over their common length.
#[inline]
fn euclidean_distance_cpu(a: &[f32], b: &[f32]) -> f32 {
    let sum_sq = a.iter().zip(b.iter()).fold(0.0f32, |acc, (x, y)| {
        let diff = x - y;
        acc + diff * diff
    });
    sum_sq.sqrt()
}
#[cfg(test)]
mod tests {
    use super::*;
    // Unit tests for the CPU similarity helpers and pooling fallback.
    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let c = vec![0.0, 1.0, 0.0];
        // Identical vectors -> 1.0; orthogonal vectors -> 0.0.
        assert!((cosine_similarity_cpu(&a, &b) - 1.0).abs() < 1e-6);
        assert!(cosine_similarity_cpu(&a, &c).abs() < 1e-6);
    }
    #[test]
    fn test_dot_product() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        // 1*4 + 2*5 + 3*6 = 32
        assert!((dot_product_cpu(&a, &b) - 32.0).abs() < 1e-6);
    }
    #[test]
    fn test_euclidean_distance() {
        let a = vec![0.0, 0.0, 0.0];
        let b = vec![3.0, 4.0, 0.0];
        // 3-4-5 right triangle.
        assert!((euclidean_distance_cpu(&a, &b) - 5.0).abs() < 1e-6);
    }
    #[test]
    fn test_batch_cosine() {
        let query = vec![1.0, 0.0, 0.0];
        let candidates: Vec<&[f32]> = vec![
            &[1.0, 0.0, 0.0][..],
            &[0.0, 1.0, 0.0][..],
            &[0.707, 0.707, 0.0][..],
        ];
        let results = batch_cosine_similarity_gpu(&query, &candidates);
        assert_eq!(results.len(), 3);
        assert!((results[0] - 1.0).abs() < 1e-6);
        assert!(results[1].abs() < 1e-6);
    }
    #[test]
    fn test_mean_pool_cpu() {
        // Construct the pooler directly to force the CPU path.
        // NOTE(review): the field here is gated on `gpu` only, while
        // GpuVectorOps gates its backend on `any(gpu, cuda-wasm)` — confirm
        // GpuPooler's declaration matches this initializer under `cuda-wasm`.
        let pooler = GpuPooler {
            use_gpu: false,
            #[cfg(feature = "gpu")]
            backend: None,
        };
        // batch=2, seq=2, hidden=3
        let tokens = vec![
            1.0, 2.0, 3.0, // batch 0, seq 0
            4.0, 5.0, 6.0, // batch 0, seq 1
            7.0, 8.0, 9.0, // batch 1, seq 0
            10.0, 11.0, 12.0, // batch 1, seq 1
        ];
        let mask = vec![1i64, 1, 1, 1];
        let result = pooler.mean_pool_cpu(&tokens, &mask, 2, 2, 3);
        assert_eq!(result.len(), 6);
        // Batch 0: mean of [1,2,3] and [4,5,6] = [2.5, 3.5, 4.5]
        assert!((result[0] - 2.5).abs() < 1e-6);
        assert!((result[1] - 3.5).abs() < 1e-6);
        assert!((result[2] - 4.5).abs() < 1e-6);
    }
}

View File

@@ -0,0 +1,613 @@
//! GPU Compute Shaders for RuVector Operations
//!
//! WGSL (WebGPU Shading Language) implementations for:
//! - Pooling operations
//! - Similarity computations
//! - Vector normalization
//! - Matrix operations
use std::collections::HashMap;
/// Shader registry for managing compute shaders
///
/// Keyed by shader name and pre-populated with every built-in WGSL kernel
/// by `ShaderRegistry::new`.
#[derive(Debug)]
pub struct ShaderRegistry {
    // name -> module; HashMap, so iteration order is unspecified.
    shaders: HashMap<String, ShaderModule>,
}
/// Shader module information
#[derive(Debug, Clone)]
pub struct ShaderModule {
    /// Shader name (registry key)
    pub name: String,
    /// WGSL source code
    pub source: String,
    /// Entry point function
    ///
    /// For every built-in shader this equals `name`.
    pub entry_point: String,
    /// Default workgroup size as `[x, y, z]` thread counts
    pub workgroup_size: [u32; 3],
}
impl ShaderRegistry {
/// Create new registry with built-in shaders
pub fn new() -> Self {
let mut shaders = HashMap::new();
// Register all built-in shaders
for shader in Self::builtin_shaders() {
shaders.insert(shader.name.clone(), shader);
}
Self { shaders }
}
/// Get shader by name
pub fn get(&self, name: &str) -> Option<&ShaderModule> {
self.shaders.get(name)
}
/// Register custom shader
pub fn register(&mut self, shader: ShaderModule) {
self.shaders.insert(shader.name.clone(), shader);
}
/// List all available shaders
pub fn list(&self) -> Vec<&str> {
self.shaders.keys().map(|s| s.as_str()).collect()
}
/// Get built-in shader definitions
fn builtin_shaders() -> Vec<ShaderModule> {
vec![
// Cosine Similarity
ShaderModule {
name: "cosine_similarity".to_string(),
source: SHADER_COSINE_SIMILARITY.to_string(),
entry_point: "cosine_similarity".to_string(),
workgroup_size: [256, 1, 1],
},
// Batch Cosine Similarity
ShaderModule {
name: "batch_cosine_similarity".to_string(),
source: SHADER_BATCH_COSINE_SIMILARITY.to_string(),
entry_point: "batch_cosine_similarity".to_string(),
workgroup_size: [256, 1, 1],
},
// Dot Product
ShaderModule {
name: "dot_product".to_string(),
source: SHADER_DOT_PRODUCT.to_string(),
entry_point: "dot_product".to_string(),
workgroup_size: [256, 1, 1],
},
// Euclidean Distance
ShaderModule {
name: "euclidean_distance".to_string(),
source: SHADER_EUCLIDEAN_DISTANCE.to_string(),
entry_point: "euclidean_distance".to_string(),
workgroup_size: [256, 1, 1],
},
// L2 Normalize
ShaderModule {
name: "l2_normalize".to_string(),
source: SHADER_L2_NORMALIZE.to_string(),
entry_point: "l2_normalize".to_string(),
workgroup_size: [256, 1, 1],
},
// Mean Pooling
ShaderModule {
name: "mean_pool".to_string(),
source: SHADER_MEAN_POOL.to_string(),
entry_point: "mean_pool".to_string(),
workgroup_size: [64, 1, 1],
},
// Max Pooling
ShaderModule {
name: "max_pool".to_string(),
source: SHADER_MAX_POOL.to_string(),
entry_point: "max_pool".to_string(),
workgroup_size: [64, 1, 1],
},
// CLS Pooling
ShaderModule {
name: "cls_pool".to_string(),
source: SHADER_CLS_POOL.to_string(),
entry_point: "cls_pool".to_string(),
workgroup_size: [64, 1, 1],
},
// Matrix-Vector Multiplication
ShaderModule {
name: "matmul".to_string(),
source: SHADER_MATMUL.to_string(),
entry_point: "matmul".to_string(),
workgroup_size: [16, 16, 1],
},
// Vector Addition
ShaderModule {
name: "vector_add".to_string(),
source: SHADER_VECTOR_ADD.to_string(),
entry_point: "vector_add".to_string(),
workgroup_size: [256, 1, 1],
},
// Vector Scale
ShaderModule {
name: "vector_scale".to_string(),
source: SHADER_VECTOR_SCALE.to_string(),
entry_point: "vector_scale".to_string(),
workgroup_size: [256, 1, 1],
},
]
}
}
impl Default for ShaderRegistry {
    /// Equivalent to [`ShaderRegistry::new`]: all built-ins registered.
    fn default() -> Self {
        Self::new()
    }
}
// ==================== Shader Source Code ====================
// Public aliases for operations.rs
//
// operations.rs references shaders through these `*_SHADER` names rather
// than the `SHADER_*` constants defined below.
pub const MEAN_POOL_SHADER: &str = SHADER_MEAN_POOL;
pub const MAX_POOL_SHADER: &str = SHADER_MAX_POOL;
pub const BATCH_COSINE_SIMILARITY_SHADER: &str = SHADER_BATCH_COSINE_SIMILARITY;
pub const DOT_PRODUCT_SHADER: &str = SHADER_DOT_PRODUCT;
pub const EUCLIDEAN_DISTANCE_SHADER: &str = SHADER_EUCLIDEAN_DISTANCE;
pub const L2_NORMALIZE_SHADER: &str = SHADER_L2_NORMALIZE;
pub const MATMUL_SHADER: &str = SHADER_MATMUL;
pub const VECTOR_ADD_SHADER: &str = SHADER_VECTOR_ADD;
/// Cosine similarity between two vectors
///
/// Bindings: 0 = query, 1 = candidate, 2 = result (single f32), 3 = params.
/// One 256-thread workgroup accumulates strided partial sums, then a
/// shared-memory tree reduction; thread 0 writes `result[0]`.
pub const SHADER_COSINE_SIMILARITY: &str = r#"
struct Params {
dimension: u32,
count: u32,
}
@group(0) @binding(0) var<storage, read> query: array<f32>;
@group(0) @binding(1) var<storage, read> candidate: array<f32>;
@group(0) @binding(2) var<storage, read_write> result: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
var<workgroup> shared_dot: array<f32, 256>;
var<workgroup> shared_norm_a: array<f32, 256>;
var<workgroup> shared_norm_b: array<f32, 256>;
@compute @workgroup_size(256)
fn cosine_similarity(@builtin(global_invocation_id) gid: vec3<u32>,
@builtin(local_invocation_id) lid: vec3<u32>) {
let idx = gid.x;
let local_idx = lid.x;
var dot: f32 = 0.0;
var norm_a: f32 = 0.0;
var norm_b: f32 = 0.0;
// Compute partial sums
var i = local_idx;
while (i < params.dimension) {
let a = query[i];
let b = candidate[i];
dot += a * b;
norm_a += a * a;
norm_b += b * b;
i += 256u;
}
// Store in shared memory
shared_dot[local_idx] = dot;
shared_norm_a[local_idx] = norm_a;
shared_norm_b[local_idx] = norm_b;
workgroupBarrier();
// Reduction
for (var stride = 128u; stride > 0u; stride >>= 1u) {
if (local_idx < stride) {
shared_dot[local_idx] += shared_dot[local_idx + stride];
shared_norm_a[local_idx] += shared_norm_a[local_idx + stride];
shared_norm_b[local_idx] += shared_norm_b[local_idx + stride];
}
workgroupBarrier();
}
// Write result
if (local_idx == 0u) {
let norm_product = sqrt(shared_norm_a[0] * shared_norm_b[0]);
if (norm_product > 1e-12) {
result[0] = shared_dot[0] / norm_product;
} else {
result[0] = 0.0;
}
}
}
"#;
/// Batch cosine similarity - one query vs many candidates
///
/// Bindings: 0 = query, 1 = flattened row-major candidates, 2 = results
/// (one f32 per candidate), 3 = params. One thread per candidate row.
pub const SHADER_BATCH_COSINE_SIMILARITY: &str = r#"
struct Params {
dimension: u32,
num_candidates: u32,
}
@group(0) @binding(0) var<storage, read> query: array<f32>;
@group(0) @binding(1) var<storage, read> candidates: array<f32>;
@group(0) @binding(2) var<storage, read_write> results: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
@compute @workgroup_size(256)
fn batch_cosine_similarity(@builtin(global_invocation_id) gid: vec3<u32>) {
let candidate_idx = gid.x;
if (candidate_idx >= params.num_candidates) {
return;
}
let base = candidate_idx * params.dimension;
var dot: f32 = 0.0;
var norm_a: f32 = 0.0;
var norm_b: f32 = 0.0;
for (var i = 0u; i < params.dimension; i++) {
let a = query[i];
let b = candidates[base + i];
dot += a * b;
norm_a += a * a;
norm_b += b * b;
}
let norm_product = sqrt(norm_a * norm_b);
if (norm_product > 1e-12) {
results[candidate_idx] = dot / norm_product;
} else {
results[candidate_idx] = 0.0;
}
}
"#;
/// Dot product computation
///
/// Bindings: 0 = query, 1 = flattened candidates, 2 = results, 3 = params.
/// One thread per candidate row.
pub const SHADER_DOT_PRODUCT: &str = r#"
struct Params {
dimension: u32,
num_candidates: u32,
}
@group(0) @binding(0) var<storage, read> query: array<f32>;
@group(0) @binding(1) var<storage, read> candidates: array<f32>;
@group(0) @binding(2) var<storage, read_write> results: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
@compute @workgroup_size(256)
fn dot_product(@builtin(global_invocation_id) gid: vec3<u32>) {
let candidate_idx = gid.x;
if (candidate_idx >= params.num_candidates) {
return;
}
let base = candidate_idx * params.dimension;
var dot: f32 = 0.0;
for (var i = 0u; i < params.dimension; i++) {
dot += query[i] * candidates[base + i];
}
results[candidate_idx] = dot;
}
"#;
/// Euclidean distance computation
///
/// Bindings: 0 = query, 1 = flattened candidates, 2 = results, 3 = params.
/// One thread per candidate row; writes sqrt of the summed squared diffs.
pub const SHADER_EUCLIDEAN_DISTANCE: &str = r#"
struct Params {
dimension: u32,
num_candidates: u32,
}
@group(0) @binding(0) var<storage, read> query: array<f32>;
@group(0) @binding(1) var<storage, read> candidates: array<f32>;
@group(0) @binding(2) var<storage, read_write> results: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
@compute @workgroup_size(256)
fn euclidean_distance(@builtin(global_invocation_id) gid: vec3<u32>) {
let candidate_idx = gid.x;
if (candidate_idx >= params.num_candidates) {
return;
}
let base = candidate_idx * params.dimension;
var sum_sq: f32 = 0.0;
for (var i = 0u; i < params.dimension; i++) {
let diff = query[i] - candidates[base + i];
sum_sq += diff * diff;
}
results[candidate_idx] = sqrt(sum_sq);
}
"#;
/// L2 normalization
///
/// Bindings: 0 = input vectors, 1 = unused dummy (kept so all kernels
/// share a 4-binding layout), 2 = output vectors, 3 = params. One thread
/// per vector; near-zero-norm vectors are copied through unchanged.
pub const SHADER_L2_NORMALIZE: &str = r#"
struct Params {
dimension: u32,
num_vectors: u32,
}
@group(0) @binding(0) var<storage, read> input_vectors: array<f32>;
@group(0) @binding(1) var<storage, read> _dummy: array<f32>;
@group(0) @binding(2) var<storage, read_write> output_vectors: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
@compute @workgroup_size(256)
fn l2_normalize(@builtin(global_invocation_id) gid: vec3<u32>) {
let vec_idx = gid.x;
if (vec_idx >= params.num_vectors) {
return;
}
let base = vec_idx * params.dimension;
// Compute norm
var norm_sq: f32 = 0.0;
for (var i = 0u; i < params.dimension; i++) {
let val = input_vectors[base + i];
norm_sq += val * val;
}
let norm = sqrt(norm_sq);
// Normalize and write to output
if (norm > 1e-12) {
for (var i = 0u; i < params.dimension; i++) {
output_vectors[base + i] = input_vectors[base + i] / norm;
}
} else {
for (var i = 0u; i < params.dimension; i++) {
output_vectors[base + i] = input_vectors[base + i];
}
}
}
"#;
/// Mean pooling over sequence
///
/// Bindings: 0 = token embeddings (batch x seq x hidden, row-major),
/// 1 = attention mask (batch x seq, i32), 2 = pooled output
/// (batch x hidden), 3 = params. One thread per (batch, hidden) element;
/// positions with mask != 1 are excluded and fully-masked rows yield 0.0.
pub const SHADER_MEAN_POOL: &str = r#"
struct Params {
batch_size: u32,
seq_length: u32,
hidden_size: u32,
}
@group(0) @binding(0) var<storage, read> tokens: array<f32>;
@group(0) @binding(1) var<storage, read> attention_mask: array<i32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
@compute @workgroup_size(64)
fn mean_pool(@builtin(global_invocation_id) gid: vec3<u32>) {
let batch_idx = gid.x / params.hidden_size;
let hidden_idx = gid.x % params.hidden_size;
if (batch_idx >= params.batch_size) {
return;
}
let tokens_base = batch_idx * params.seq_length * params.hidden_size;
let mask_base = batch_idx * params.seq_length;
var sum: f32 = 0.0;
var count: f32 = 0.0;
for (var i = 0u; i < params.seq_length; i++) {
if (attention_mask[mask_base + i] == 1) {
sum += tokens[tokens_base + i * params.hidden_size + hidden_idx];
count += 1.0;
}
}
let out_idx = batch_idx * params.hidden_size + hidden_idx;
if (count > 0.0) {
output[out_idx] = sum / count;
} else {
output[out_idx] = 0.0;
}
}
"#;
/// Max pooling over sequence
///
/// Same binding layout as `mean_pool`. One thread per (batch, hidden)
/// element; takes the max over unmasked positions, and rows with no
/// unmasked token yield 0.0 (via `select`).
pub const SHADER_MAX_POOL: &str = r#"
struct Params {
batch_size: u32,
seq_length: u32,
hidden_size: u32,
}
@group(0) @binding(0) var<storage, read> tokens: array<f32>;
@group(0) @binding(1) var<storage, read> attention_mask: array<i32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
@compute @workgroup_size(64)
fn max_pool(@builtin(global_invocation_id) gid: vec3<u32>) {
let batch_idx = gid.x / params.hidden_size;
let hidden_idx = gid.x % params.hidden_size;
if (batch_idx >= params.batch_size) {
return;
}
let tokens_base = batch_idx * params.seq_length * params.hidden_size;
let mask_base = batch_idx * params.seq_length;
var max_val: f32 = -3.402823e+38; // -FLT_MAX
var found: bool = false;
for (var i = 0u; i < params.seq_length; i++) {
if (attention_mask[mask_base + i] == 1) {
let val = tokens[tokens_base + i * params.hidden_size + hidden_idx];
if (!found || val > max_val) {
max_val = val;
found = true;
}
}
}
let out_idx = batch_idx * params.hidden_size + hidden_idx;
output[out_idx] = select(0.0, max_val, found);
}
"#;
/// CLS token pooling (first token)
///
/// Copies each batch row's first-token hidden vector to the output.
/// Binding 1 is an unused dummy kept for the shared 4-binding layout.
pub const SHADER_CLS_POOL: &str = r#"
struct Params {
batch_size: u32,
seq_length: u32,
hidden_size: u32,
}
@group(0) @binding(0) var<storage, read> tokens: array<f32>;
@group(0) @binding(1) var<storage, read> _dummy: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
@compute @workgroup_size(64)
fn cls_pool(@builtin(global_invocation_id) gid: vec3<u32>) {
let batch_idx = gid.x / params.hidden_size;
let hidden_idx = gid.x % params.hidden_size;
if (batch_idx >= params.batch_size) {
return;
}
// CLS is first token
let tokens_base = batch_idx * params.seq_length * params.hidden_size;
let out_idx = batch_idx * params.hidden_size + hidden_idx;
output[out_idx] = tokens[tokens_base + hidden_idx];
}
"#;
/// Matrix-vector multiplication
///
/// Row-major `rows x cols` matrix times a `cols` vector. One thread per
/// output row; only gid.x is consumed even though the workgroup is
/// declared 16x16 (the y threads redundantly recompute the same rows).
pub const SHADER_MATMUL: &str = r#"
struct Params {
rows: u32,
cols: u32,
}
@group(0) @binding(0) var<storage, read> matrix: array<f32>;
@group(0) @binding(1) var<storage, read> vector: array<f32>;
@group(0) @binding(2) var<storage, read_write> result: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
@compute @workgroup_size(16, 16)
fn matmul(@builtin(global_invocation_id) gid: vec3<u32>) {
let row = gid.x;
if (row >= params.rows) {
return;
}
var sum: f32 = 0.0;
for (var col = 0u; col < params.cols; col++) {
sum += matrix[row * params.cols + col] * vector[col];
}
result[row] = sum;
}
"#;
/// Vector addition
///
/// Element-wise `a + b`; one thread per element, guarded by `params.length`.
pub const SHADER_VECTOR_ADD: &str = r#"
struct Params {
length: u32,
}
@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> result: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
@compute @workgroup_size(256)
fn vector_add(@builtin(global_invocation_id) gid: vec3<u32>) {
let idx = gid.x;
if (idx >= params.length) {
return;
}
result[idx] = a[idx] + b[idx];
}
"#;
/// Vector scaling
///
/// Element-wise multiply by `params.scale`; one thread per element.
/// Binding 1 is an unused dummy kept for the shared 4-binding layout.
pub const SHADER_VECTOR_SCALE: &str = r#"
struct Params {
length: u32,
scale: f32,
}
@group(0) @binding(0) var<storage, read> input_vector: array<f32>;
@group(0) @binding(1) var<storage, read> _dummy: array<f32>;
@group(0) @binding(2) var<storage, read_write> output_vector: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
@compute @workgroup_size(256)
fn vector_scale(@builtin(global_invocation_id) gid: vec3<u32>) {
let idx = gid.x;
if (idx >= params.length) {
return;
}
output_vector[idx] = input_vector[idx] * params.scale;
}
"#;
#[cfg(test)]
mod tests {
    use super::*;
    // Registry construction registers every built-in shader by name.
    #[test]
    fn test_shader_registry() {
        let registry = ShaderRegistry::new();
        // Check all built-in shaders are registered
        assert!(registry.get("cosine_similarity").is_some());
        assert!(registry.get("batch_cosine_similarity").is_some());
        assert!(registry.get("dot_product").is_some());
        assert!(registry.get("euclidean_distance").is_some());
        assert!(registry.get("l2_normalize").is_some());
        assert!(registry.get("mean_pool").is_some());
        assert!(registry.get("max_pool").is_some());
        assert!(registry.get("cls_pool").is_some());
        assert!(registry.get("matmul").is_some());
        assert!(registry.get("vector_add").is_some());
        assert!(registry.get("vector_scale").is_some());
    }
    // Spot-check one module's source and entry point.
    #[test]
    fn test_shader_content() {
        let registry = ShaderRegistry::new();
        let cosine = registry.get("cosine_similarity").unwrap();
        assert!(cosine.source.contains("@compute"));
        assert!(cosine.source.contains("workgroup_size"));
        assert_eq!(cosine.entry_point, "cosine_similarity");
    }
    // User-registered shaders are retrievable alongside built-ins.
    #[test]
    fn test_custom_shader() {
        let mut registry = ShaderRegistry::new();
        registry.register(ShaderModule {
            name: "custom_op".to_string(),
            source: "// custom shader".to_string(),
            entry_point: "custom".to_string(),
            workgroup_size: [128, 1, 1],
        });
        assert!(registry.get("custom_op").is_some());
    }
}

View File

@@ -0,0 +1,424 @@
//! GPU Module Tests
//!
//! Comprehensive tests for GPU acceleration functionality.
use super::*;
use super::config::{GpuConfig, GpuMode, PowerPreference, GpuMemoryStats};
use super::backend::CpuBackend;
use super::shaders::ShaderModule;
// ==================== Configuration Tests ====================
#[test]
fn test_gpu_config_default() {
    // Exercise GpuConfig::default() against its advertised defaults.
    let config = GpuConfig::default();
    assert_eq!(config.mode, GpuMode::Auto);
    assert_eq!(config.power_preference, PowerPreference::HighPerformance);
    assert_eq!(config.workgroup_size, 256);
    assert!(config.fallback_to_cpu);
    assert!(config.cache_shaders);
}
#[test]
fn test_gpu_config_builder() {
    // Each with_* builder call should set exactly its field.
    let config = GpuConfig::auto()
        .with_mode(GpuMode::WebGpu)
        .with_power_preference(PowerPreference::LowPower)
        .with_workgroup_size(512)
        .with_min_batch_size(32)
        .with_min_dimension(256)
        .with_profiling(true);
    assert_eq!(config.mode, GpuMode::WebGpu);
    assert_eq!(config.power_preference, PowerPreference::LowPower);
    assert_eq!(config.workgroup_size, 512);
    assert_eq!(config.min_batch_size, 32);
    assert_eq!(config.min_dimension, 256);
    assert!(config.enable_profiling);
}
#[test]
fn test_should_use_gpu() {
    // GPU is selected only when batch size AND dimension thresholds are met.
    let config = GpuConfig::default()
        .with_min_batch_size(16)
        .with_min_dimension(128);
    // Below minimum batch size
    assert!(!config.should_use_gpu(8, 384));
    // Below minimum dimension
    assert!(!config.should_use_gpu(32, 64));
    // Both conditions met
    assert!(config.should_use_gpu(32, 384));
    // CPU only mode
    let cpu_config = GpuConfig::cpu_only();
    assert!(!cpu_config.should_use_gpu(1000, 1000));
}
#[test]
fn test_preset_configs() {
    // Spot-check distinguishing fields of each preset constructor.
    let high_perf = GpuConfig::high_performance();
    assert_eq!(high_perf.workgroup_size, 512);
    assert_eq!(high_perf.min_batch_size, 8);
    let low_power = GpuConfig::low_power();
    assert_eq!(low_power.power_preference, PowerPreference::LowPower);
    assert_eq!(low_power.workgroup_size, 128);
    let cpu_only = GpuConfig::cpu_only();
    assert_eq!(cpu_only.mode, GpuMode::CpuOnly);
}
// ==================== Shader Tests ====================
#[test]
fn test_shader_registry_initialization() {
    // Every built-in kernel must be present under its canonical name.
    let registry = ShaderRegistry::new();
    let expected_shaders = vec![
        "cosine_similarity",
        "batch_cosine_similarity",
        "dot_product",
        "euclidean_distance",
        "l2_normalize",
        "mean_pool",
        "max_pool",
        "cls_pool",
        "matmul",
        "vector_add",
        "vector_scale",
    ];
    for name in expected_shaders {
        assert!(registry.get(name).is_some(), "Missing shader: {}", name);
    }
}
#[test]
fn test_shader_module_content() {
let registry = ShaderRegistry::new();
// Check cosine similarity shader
let cosine = registry.get("cosine_similarity").unwrap();
assert!(cosine.source.contains("@compute"));
assert!(cosine.source.contains("workgroup_size"));
assert!(cosine.source.contains("cosine_similarity"));
assert_eq!(cosine.entry_point, "cosine_similarity");
assert_eq!(cosine.workgroup_size, [256, 1, 1]);
// Check mean pool shader
let mean_pool = registry.get("mean_pool").unwrap();
assert!(mean_pool.source.contains("attention_mask"));
assert!(mean_pool.source.contains("hidden_size"));
assert_eq!(mean_pool.entry_point, "mean_pool");
}
#[test]
fn test_custom_shader_registration() {
    // Custom modules round-trip through register/get.
    let mut registry = ShaderRegistry::new();
    let custom = ShaderModule {
        name: "custom_kernel".to_string(),
        source: "@compute @workgroup_size(64) fn custom() {}".to_string(),
        entry_point: "custom".to_string(),
        workgroup_size: [64, 1, 1],
    };
    registry.register(custom);
    assert!(registry.get("custom_kernel").is_some());
    let retrieved = registry.get("custom_kernel").unwrap();
    assert_eq!(retrieved.entry_point, "custom");
}
// ==================== Batch Operations Tests ====================
#[test]
fn test_batch_cosine_similarity() {
    // Parallel, identical, orthogonal, and opposite vectors.
    let query = vec![1.0, 0.0, 0.0];
    let candidates: Vec<&[f32]> = vec![
        &[1.0, 0.0, 0.0][..], // similarity = 1.0
        &[0.0, 1.0, 0.0][..], // similarity = 0.0
        &[-1.0, 0.0, 0.0][..], // similarity = -1.0
    ];
    let results = batch_cosine_similarity_gpu(&query, &candidates);
    assert_eq!(results.len(), 3);
    assert!((results[0] - 1.0).abs() < 1e-6);
    assert!(results[1].abs() < 1e-6);
    assert!((results[2] - (-1.0)).abs() < 1e-6);
}
#[test]
fn test_batch_dot_product() {
    // Results preserve candidate order.
    let query = vec![1.0, 1.0, 1.0];
    let candidates: Vec<&[f32]> = vec![
        &[1.0, 1.0, 1.0][..], // dot = 3.0
        &[2.0, 2.0, 2.0][..], // dot = 6.0
        &[0.0, 0.0, 0.0][..], // dot = 0.0
    ];
    let results = batch_dot_product_gpu(&query, &candidates);
    assert_eq!(results.len(), 3);
    assert!((results[0] - 3.0).abs() < 1e-6);
    assert!((results[1] - 6.0).abs() < 1e-6);
    assert!(results[2].abs() < 1e-6);
}
#[test]
fn test_batch_euclidean() {
    // 3-4-5 triangle, unit distance, and zero distance.
    let query = vec![0.0, 0.0, 0.0];
    let candidates: Vec<&[f32]> = vec![
        &[3.0, 4.0, 0.0][..], // dist = 5.0
        &[1.0, 0.0, 0.0][..], // dist = 1.0
        &[0.0, 0.0, 0.0][..], // dist = 0.0
    ];
    let results = batch_euclidean_gpu(&query, &candidates);
    assert_eq!(results.len(), 3);
    assert!((results[0] - 5.0).abs() < 1e-6);
    assert!((results[1] - 1.0).abs() < 1e-6);
    assert!(results[2].abs() < 1e-6);
}
// ==================== Pooling Tests (using public API) ====================
#[test]
fn test_mean_pool_via_api() {
    // CpuBackend forces the CPU pooling path through the public API.
    let backend = CpuBackend;
    let shaders = ShaderRegistry::new();
    let pooler = GpuPooler::new(&backend, &shaders).unwrap();
    // batch=2, seq=2, hidden=3
    let tokens = vec![
        1.0, 2.0, 3.0, // batch 0, seq 0
        4.0, 5.0, 6.0, // batch 0, seq 1
        7.0, 8.0, 9.0, // batch 1, seq 0
        10.0, 11.0, 12.0, // batch 1, seq 1
    ];
    let mask = vec![1i64, 1, 1, 1];
    let result = pooler.mean_pool(&tokens, &mask, 2, 2, 3).unwrap();
    assert_eq!(result.len(), 6);
    // Batch 0: mean of [1,2,3] and [4,5,6] = [2.5, 3.5, 4.5]
    assert!((result[0] - 2.5).abs() < 1e-6);
    assert!((result[1] - 3.5).abs() < 1e-6);
    assert!((result[2] - 4.5).abs() < 1e-6);
}
#[test]
fn test_cls_pool_via_api() {
    // CLS pooling copies only each batch row's first token.
    let backend = CpuBackend;
    let shaders = ShaderRegistry::new();
    let pooler = GpuPooler::new(&backend, &shaders).unwrap();
    // batch=2, seq=3, hidden=4
    let tokens = vec![
        // Batch 0
        1.0, 2.0, 3.0, 4.0, // CLS token
        5.0, 6.0, 7.0, 8.0,
        9.0, 10.0, 11.0, 12.0,
        // Batch 1
        10.0, 20.0, 30.0, 40.0, // CLS token
        50.0, 60.0, 70.0, 80.0,
        90.0, 100.0, 110.0, 120.0,
    ];
    let result = pooler.cls_pool(&tokens, 2, 4).unwrap();
    assert_eq!(result.len(), 8);
    // Batch 0: first token
    assert!((result[0] - 1.0).abs() < 1e-6);
    assert!((result[1] - 2.0).abs() < 1e-6);
    assert!((result[2] - 3.0).abs() < 1e-6);
    assert!((result[3] - 4.0).abs() < 1e-6);
    // Batch 1: first token
    assert!((result[4] - 10.0).abs() < 1e-6);
    assert!((result[5] - 20.0).abs() < 1e-6);
    assert!((result[6] - 30.0).abs() < 1e-6);
    assert!((result[7] - 40.0).abs() < 1e-6);
}
#[test]
fn test_max_pool_via_api() {
    // Per-dimension max across the (fully unmasked) sequence.
    let backend = CpuBackend;
    let shaders = ShaderRegistry::new();
    let pooler = GpuPooler::new(&backend, &shaders).unwrap();
    // batch=1, seq=3, hidden=4
    let tokens = vec![
        1.0, 10.0, 3.0, 4.0, // seq 0
        5.0, 2.0, 7.0, 8.0, // seq 1
        9.0, 6.0, 11.0, 0.0, // seq 2
    ];
    let mask = vec![1i64, 1, 1];
    let result = pooler.max_pool(&tokens, &mask, 1, 3, 4).unwrap();
    assert_eq!(result.len(), 4);
    // Max across all sequences for each dimension
    assert!((result[0] - 9.0).abs() < 1e-6); // max(1, 5, 9)
    assert!((result[1] - 10.0).abs() < 1e-6); // max(10, 2, 6)
    assert!((result[2] - 11.0).abs() < 1e-6); // max(3, 7, 11)
    assert!((result[3] - 8.0).abs() < 1e-6); // max(4, 8, 0)
}
// ==================== Vector Operations Tests ====================
// L2-normalization is applied in place, one vector at a time.
#[test]
fn test_normalize_batch() {
    let backend = CpuBackend;
    let shaders = ShaderRegistry::new();
    let ops = GpuVectorOps::new(&backend, &shaders).unwrap();
    // Two 3-d vectors, both with L2 norm 5.
    let mut vectors = vec![
        3.0, 4.0, 0.0, // -> [0.6, 0.8, 0.0]
        0.0, 0.0, 5.0, // -> [0.0, 0.0, 1.0]
    ];
    ops.normalize_batch(&mut vectors, 3).unwrap();
    let expected = [0.6f32, 0.8, 0.0, 0.0, 0.0, 1.0];
    for (got, want) in vectors.iter().zip(expected) {
        assert!((got - want).abs() < 1e-6);
    }
}
// Matrix-vector product: multiplying by the all-ones vector yields row sums.
#[test]
fn test_matmul() {
    let backend = CpuBackend;
    let shaders = ShaderRegistry::new();
    let ops = GpuVectorOps::new(&backend, &shaders).unwrap();
    // 2x3 row-major matrix times a length-3 ones vector.
    let matrix = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let ones = vec![1.0, 1.0, 1.0];
    let product = ops.matmul(&matrix, &ones, 2, 3).unwrap();
    assert_eq!(product.len(), 2);
    let expected = [6.0f32, 15.0]; // row sums: 1+2+3 and 4+5+6
    for (got, want) in product.iter().zip(expected) {
        assert!((got - want).abs() < 1e-6);
    }
}
// Element-wise addition of two equal-length buffers.
#[test]
fn test_batch_add() {
    let backend = CpuBackend;
    let shaders = ShaderRegistry::new();
    let ops = GpuVectorOps::new(&backend, &shaders).unwrap();
    let lhs = vec![1.0, 2.0, 3.0, 4.0];
    let rhs = vec![5.0, 6.0, 7.0, 8.0];
    let sum = ops.batch_add(&lhs, &rhs).unwrap();
    assert_eq!(sum, vec![6.0, 8.0, 10.0, 12.0]);
}
// Scalar multiplication is applied in place to the whole buffer.
#[test]
fn test_batch_scale() {
    let backend = CpuBackend;
    let shaders = ShaderRegistry::new();
    let ops = GpuVectorOps::new(&backend, &shaders).unwrap();
    let mut data = vec![1.0, 2.0, 3.0, 4.0];
    ops.batch_scale(&mut data, 2.0).unwrap();
    assert_eq!(data, vec![2.0, 4.0, 6.0, 8.0]);
}
// ==================== Integration Tests ====================
// Smoke test: batch cosine similarity runs through the CPU backend.
#[test]
fn test_gpu_similarity_with_backend() {
    let backend = CpuBackend;
    let shaders = ShaderRegistry::new();
    let sim = GpuSimilarity::new(&backend, &shaders).unwrap();
    // Unit query along x; one identical candidate, one orthogonal.
    let query = vec![1.0, 0.0, 0.0];
    let candidates: Vec<&[f32]> = vec![&[1.0, 0.0, 0.0][..], &[0.0, 1.0, 0.0][..]];
    let scores = sim.batch_cosine(&query, &candidates).unwrap();
    assert_eq!(scores.len(), 2);
    // Identical vectors have cosine similarity 1.
    assert!((scores[0] - 1.0).abs() < 1e-6);
}
// top_k must return candidate indices ranked by descending cosine similarity.
#[test]
fn test_top_k_similar() {
    let backend = CpuBackend;
    let shaders = ShaderRegistry::new();
    let sim = GpuSimilarity::new(&backend, &shaders).unwrap();
    // Query along the x-axis; candidates span the full similarity range.
    let query = vec![1.0, 0.0, 0.0];
    let candidates: Vec<&[f32]> = vec![
        &[0.0, 1.0, 0.0][..],  // orthogonal, sim = 0
        &[1.0, 0.0, 0.0][..],  // identical, sim = 1 (best)
        &[0.5, 0.5, 0.0][..],  // 45 degrees, sim ≈ 0.707
        &[-1.0, 0.0, 0.0][..], // opposite, sim = -1 (worst)
    ];
    let top2 = sim.top_k(&query, &candidates, 2).unwrap();
    assert_eq!(top2.len(), 2);
    assert_eq!(top2[0].0, 1); // exact match ranks first
    assert_eq!(top2[1].0, 2); // 45-degree vector ranks second
}
// ==================== Memory Stats Tests ====================
// usage_percent reports used/total as a percentage.
#[test]
fn test_memory_stats() {
    const MIB: u64 = 1024 * 1024;
    let stats = GpuMemoryStats {
        total: 1024 * MIB, // 1 GiB
        used: 512 * MIB,
        free: 512 * MIB,
        peak: 768 * MIB,
    };
    // Exactly half of the total is in use.
    assert!((stats.usage_percent() - 50.0).abs() < 0.1);
}
// A default (all-zero) stats struct must report 0% usage, not divide by zero.
#[test]
fn test_empty_memory_stats() {
    let zero_stats = GpuMemoryStats::default();
    assert_eq!(zero_stats.usage_percent(), 0.0);
}
// ==================== Backend Tests ====================
// The CPU fallback backend is always available and advertises no
// compute-shader support.
#[test]
fn test_cpu_backend_info() {
    let cpu = CpuBackend;
    assert!(cpu.is_available());
    let info = cpu.device_info();
    assert_eq!(info.backend, "CPU");
    assert!(!info.supports_compute);
}